refactor code
This commit is contained in:
parent
5c776a6a67
commit
4832d6f00b
17 changed files with 725 additions and 326 deletions
|
@ -5,12 +5,17 @@ Dockerfile
|
||||||
.dockerignore
|
.dockerignore
|
||||||
config.json
|
config.json
|
||||||
config.json.sample
|
config.json.sample
|
||||||
bot
|
db
|
||||||
bot.log
|
bot.log
|
||||||
venv
|
venv
|
||||||
compose.yaml
|
.venv
|
||||||
|
*.yaml
|
||||||
|
*.yml
|
||||||
.git
|
.git
|
||||||
.idea
|
.idea
|
||||||
__pycache__
|
__pycache__
|
||||||
venv
|
.env
|
||||||
|
.env.example
|
||||||
|
.github
|
||||||
|
settings.js
|
||||||
|
|
||||||
|
|
10
.env.example
Normal file
10
.env.example
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
HOMESERVER="https://matrix.xxxxxx.xxxx" # required
|
||||||
|
USER_ID="@lullap:xxxxxxxxxxxxx.xxx" # required
|
||||||
|
PASSWORD="xxxxxxxxxxxxxxx" # required
|
||||||
|
DEVICE_ID="xxxxxxxxxxxxxx" # required
|
||||||
|
ROOM_ID="!FYCmBSkCRUXXXXXXXXX:matrix.XXX.XXX" # Optional, if the property is blank, bot will work on the room it is in (Unencrypted room only as for now)
|
||||||
|
OPENAI_API_KEY="xxxxxxxxxxxxxxxxx" # Optional, for !chat and !gpt command
|
||||||
|
BING_API_ENDPOINT="xxxxxxxxxxxxxxx" # Optional, for !bing command
|
||||||
|
ACCESS_TOKEN="xxxxxxxxxxxxxxxxxxxxx" # Optional, use user_id and password is recommended
|
||||||
|
JAILBREAKENABLED="true" # Optional
|
||||||
|
BING_AUTH_COOKIE="xxxxxxxxxxxxxxxxxxx" # _U cookie, Optional, for Bing Image Creator
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -26,7 +26,7 @@ share/python-wheels/
|
||||||
.installed.cfg
|
.installed.cfg
|
||||||
*.egg
|
*.egg
|
||||||
MANIFEST
|
MANIFEST
|
||||||
bot
|
db
|
||||||
bot.log
|
bot.log
|
||||||
|
|
||||||
# image generation folder
|
# image generation folder
|
||||||
|
@ -135,6 +135,7 @@ env.bak/
|
||||||
venv.bak/
|
venv.bak/
|
||||||
config.json
|
config.json
|
||||||
compose_local_build.yaml
|
compose_local_build.yaml
|
||||||
|
settings.js
|
||||||
|
|
||||||
# Spyder project settings
|
# Spyder project settings
|
||||||
.spyderproject
|
.spyderproject
|
||||||
|
|
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"python.analysis.typeCheckingMode": "off"
|
||||||
|
}
|
|
@ -105,16 +105,5 @@ class ImageGen:
|
||||||
with open(f"{output_dir}/{image_name}.jpeg", "wb") as output_file:
|
with open(f"{output_dir}/{image_name}.jpeg", "wb") as output_file:
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
output_file.write(chunk)
|
output_file.write(chunk)
|
||||||
# image_num = 0
|
|
||||||
# for link in links:
|
|
||||||
# with self.session.get(link, stream=True) as response:
|
|
||||||
# # save response to file
|
|
||||||
# response.raise_for_status()
|
|
||||||
# with open(f"{output_dir}/{image_num}.jpeg", "wb") as output_file:
|
|
||||||
# for chunk in response.iter_content(chunk_size=8192):
|
|
||||||
# output_file.write(chunk)
|
|
||||||
#
|
|
||||||
# image_num += 1
|
|
||||||
|
|
||||||
# return image path
|
|
||||||
return f"{output_dir}/{image_name}.jpeg"
|
return f"{output_dir}/{image_name}.jpeg"
|
||||||
|
|
|
@ -2,7 +2,7 @@ FROM python:3.11-alpine as base
|
||||||
|
|
||||||
FROM base as pybuilder
|
FROM base as pybuilder
|
||||||
RUN sed -i 's|v3\.\d*|edge|' /etc/apk/repositories
|
RUN sed -i 's|v3\.\d*|edge|' /etc/apk/repositories
|
||||||
RUN apk update && apk add olm-dev gcc musl-dev tzdata libmagic
|
RUN apk update && apk add olm-dev gcc musl-dev libmagic
|
||||||
COPY requirements.txt /requirements.txt
|
COPY requirements.txt /requirements.txt
|
||||||
RUN pip3 install --user -r /requirements.txt && rm /requirements.txt
|
RUN pip3 install --user -r /requirements.txt && rm /requirements.txt
|
||||||
|
|
||||||
|
@ -11,12 +11,10 @@ FROM base as runner
|
||||||
LABEL "org.opencontainers.image.source"="https://github.com/hibobmaster/matrix_chatgpt_bot"
|
LABEL "org.opencontainers.image.source"="https://github.com/hibobmaster/matrix_chatgpt_bot"
|
||||||
RUN apk update && apk add olm-dev libmagic
|
RUN apk update && apk add olm-dev libmagic
|
||||||
COPY --from=pybuilder /root/.local /usr/local
|
COPY --from=pybuilder /root/.local /usr/local
|
||||||
COPY --from=pybuilder /usr/share/zoneinfo /usr/share/zoneinfo
|
|
||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
|
|
||||||
FROM runner
|
FROM runner
|
||||||
ENV TZ=Asia/Shanghai
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
CMD ["python", "main.py"]
|
CMD ["python", "main.py"]
|
||||||
|
|
||||||
|
|
|
@ -3,11 +3,11 @@ This is a simple Matrix bot that uses OpenAI's GPT API and Bing AI to generate r
|
||||||
![demo](https://i.imgur.com/kK4rnPf.jpeg "demo")
|
![demo](https://i.imgur.com/kK4rnPf.jpeg "demo")
|
||||||
|
|
||||||
## Installation and Setup
|
## Installation and Setup
|
||||||
Docker method:<br>
|
Docker method(Recommended):<br>
|
||||||
Edit `config.json` with proper values <br>
|
Edit `config.json` or `.env` with proper values <br>
|
||||||
Create an empty file, for persist database only<br>
|
Create an empty file, for persist database only<br>
|
||||||
```bash
|
```bash
|
||||||
touch bot
|
touch db
|
||||||
sudo docker compose up -d
|
sudo docker compose up -d
|
||||||
```
|
```
|
||||||
<hr>
|
<hr>
|
||||||
|
|
|
@ -4,8 +4,11 @@ import json
|
||||||
from log import getlogger
|
from log import getlogger
|
||||||
logger = getlogger()
|
logger = getlogger()
|
||||||
|
|
||||||
|
class askGPT:
|
||||||
|
def __init__(self):
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
|
||||||
async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
|
async def oneTimeAsk(self, prompt: str, api_endpoint: str, headers: dict) -> str:
|
||||||
jsons = {
|
jsons = {
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
@ -15,11 +18,10 @@ async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
max_try = 5
|
max_try = 5
|
||||||
while max_try > 0:
|
while max_try > 0:
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_endpoint,
|
async with self.session.post(url=api_endpoint,
|
||||||
json=jsons, headers=headers, timeout=30) as response:
|
json=jsons, headers=headers, timeout=30) as response:
|
||||||
status_code = response.status
|
status_code = response.status
|
||||||
if not status_code == 200:
|
if not status_code == 200:
|
||||||
|
@ -31,9 +33,6 @@ async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
resp = await response.read()
|
resp = await response.read()
|
||||||
await session.close()
|
|
||||||
return json.loads(resp)['choices'][0]['message']['content']
|
return json.loads(resp)['choices'][0]['message']['content']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error Exception", exc_info=True)
|
logger.error("Error Exception", exc_info=True)
|
||||||
print(e)
|
|
||||||
pass
|
|
11
bing.py
11
bing.py
|
@ -4,31 +4,28 @@ import asyncio
|
||||||
from log import getlogger
|
from log import getlogger
|
||||||
# api_endpoint = "http://localhost:3000/conversation"
|
# api_endpoint = "http://localhost:3000/conversation"
|
||||||
logger = getlogger()
|
logger = getlogger()
|
||||||
python_boolean_to_json = {
|
|
||||||
"true": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class BingBot:
|
class BingBot:
|
||||||
def __init__(self, bing_api_endpoint: str, jailbreakEnabled: bool = False):
|
def __init__(self, bing_api_endpoint: str, jailbreakEnabled: bool = False):
|
||||||
self.data = {
|
self.data = {
|
||||||
# 'jailbreakConversationId': json.dumps(python_boolean_to_json['true']),
|
|
||||||
'clientOptions.clientToUse': 'bing',
|
'clientOptions.clientToUse': 'bing',
|
||||||
}
|
}
|
||||||
self.bing_api_endpoint = bing_api_endpoint
|
self.bing_api_endpoint = bing_api_endpoint
|
||||||
|
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
|
||||||
self.jailbreakEnabled = jailbreakEnabled
|
self.jailbreakEnabled = jailbreakEnabled
|
||||||
|
|
||||||
if self.jailbreakEnabled:
|
if self.jailbreakEnabled:
|
||||||
self.data['jailbreakConversationId'] = json.dumps(python_boolean_to_json['true'])
|
self.data['jailbreakConversationId'] = json.dumps(True)
|
||||||
|
|
||||||
async def ask_bing(self, prompt) -> str:
|
async def ask_bing(self, prompt) -> str:
|
||||||
self.data['message'] = prompt
|
self.data['message'] = prompt
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
max_try = 5
|
max_try = 5
|
||||||
while max_try > 0:
|
while max_try > 0:
|
||||||
try:
|
try:
|
||||||
resp = await session.post(url=self.bing_api_endpoint, json=self.data)
|
resp = await self.session.post(url=self.bing_api_endpoint, json=self.data)
|
||||||
status_code = resp.status
|
status_code = resp.status
|
||||||
body = await resp.read()
|
body = await resp.read()
|
||||||
if not status_code == 200:
|
if not status_code == 200:
|
||||||
|
|
391
bot.py
391
bot.py
|
@ -2,23 +2,33 @@ import sys
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
import traceback
|
||||||
from nio import AsyncClient, MatrixRoom, RoomMessageText, LoginResponse, AsyncClientConfig
|
from typing import Optional, Union
|
||||||
|
from nio import (
|
||||||
|
AsyncClient,
|
||||||
|
MatrixRoom,
|
||||||
|
RoomMessageText,
|
||||||
|
InviteMemberEvent,
|
||||||
|
LoginResponse,
|
||||||
|
JoinError,
|
||||||
|
ToDeviceError,
|
||||||
|
LocalProtocolError,
|
||||||
|
KeyVerificationEvent,
|
||||||
|
KeyVerificationStart,
|
||||||
|
KeyVerificationCancel,
|
||||||
|
KeyVerificationKey,
|
||||||
|
KeyVerificationMac,
|
||||||
|
AsyncClientConfig
|
||||||
|
)
|
||||||
from nio.store.database import SqliteStore
|
from nio.store.database import SqliteStore
|
||||||
from ask_gpt import ask
|
from askgpt import askGPT
|
||||||
from send_message import send_room_message
|
from send_message import send_room_message
|
||||||
from v3 import Chatbot
|
from v3 import Chatbot
|
||||||
from log import getlogger
|
from log import getlogger
|
||||||
from bing import BingBot
|
from bing import BingBot
|
||||||
from BingImageGen import ImageGen
|
from BingImageGen import ImageGen
|
||||||
from send_image import send_room_image
|
from send_image import send_room_image
|
||||||
"""
|
|
||||||
free api_endpoint from https://github.com/ayaka14732/ChatGPTAPIFree
|
|
||||||
"""
|
|
||||||
chatgpt_api_endpoint_list = {
|
|
||||||
"free": "https://chatgpt-api.shn.hk/v1/",
|
|
||||||
"paid": "https://api.openai.com/v1/chat/completions"
|
|
||||||
}
|
|
||||||
logger = getlogger()
|
logger = getlogger()
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,35 +37,45 @@ class Bot:
|
||||||
self,
|
self,
|
||||||
homeserver: str,
|
homeserver: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
password: str,
|
|
||||||
device_id: str,
|
device_id: str,
|
||||||
api_key: Optional[str] = "",
|
chatgpt_api_endpoint: str = os.environ.get("CHATGPT_API_ENDPOINT") or "https://api.openai.com/v1/chat/completions",
|
||||||
|
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") or "",
|
||||||
room_id: Optional[str] = '',
|
room_id: Optional[str] = '',
|
||||||
bing_api_endpoint: Optional[str] = '',
|
bing_api_endpoint: Optional[str] = '',
|
||||||
access_token: Optional[str] = '',
|
password: Union[str, None] = None,
|
||||||
jailbreakEnabled: Optional[bool] = False,
|
access_token: Union[str, None] = None,
|
||||||
|
jailbreakEnabled: Optional[bool] = True,
|
||||||
bing_auth_cookie: Optional[str] = '',
|
bing_auth_cookie: Optional[str] = '',
|
||||||
):
|
):
|
||||||
self.homeserver = homeserver
|
self.homeserver = homeserver
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.password = password
|
self.password = password
|
||||||
|
self.access_token = access_token
|
||||||
self.device_id = device_id
|
self.device_id = device_id
|
||||||
self.room_id = room_id
|
self.room_id = room_id
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.chatgpt_api_endpoint = chatgpt_api_endpoint
|
||||||
self.bing_api_endpoint = bing_api_endpoint
|
self.bing_api_endpoint = bing_api_endpoint
|
||||||
self.jailbreakEnabled = jailbreakEnabled
|
self.jailbreakEnabled = jailbreakEnabled
|
||||||
self.bing_auth_cookie = bing_auth_cookie
|
self.bing_auth_cookie = bing_auth_cookie
|
||||||
# initialize AsyncClient object
|
# initialize AsyncClient object
|
||||||
self.store_path = os.getcwd()
|
self.store_path = os.getcwd()
|
||||||
self.config = AsyncClientConfig(store=SqliteStore,
|
self.config = AsyncClientConfig(store=SqliteStore,
|
||||||
store_name="bot",
|
store_name="db",
|
||||||
store_sync_tokens=True,
|
store_sync_tokens=True,
|
||||||
encryption_enabled=True,
|
encryption_enabled=True,
|
||||||
)
|
)
|
||||||
self.client = AsyncClient(self.homeserver, user=self.user_id, device_id=self.device_id,
|
self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id,
|
||||||
config=self.config, store_path=self.store_path,)
|
config=self.config, store_path=self.store_path,)
|
||||||
if access_token != '':
|
|
||||||
self.client.access_token = access_token
|
if self.access_token is not None:
|
||||||
|
self.client.access_token = self.access_token
|
||||||
|
|
||||||
|
# setup event callbacks
|
||||||
|
self.client.add_event_callback(self.message_callback, (RoomMessageText, ))
|
||||||
|
self.client.add_event_callback(self.invite_callback, (InviteMemberEvent, ))
|
||||||
|
self.client.add_to_device_callback(self.to_device_callback, (KeyVerificationEvent, ))
|
||||||
|
|
||||||
# regular expression to match keyword [!gpt {prompt}] [!chat {prompt}]
|
# regular expression to match keyword [!gpt {prompt}] [!chat {prompt}]
|
||||||
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
||||||
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
||||||
|
@ -67,18 +87,21 @@ class Bot:
|
||||||
if self.api_key != '':
|
if self.api_key != '':
|
||||||
self.chatbot = Chatbot(api_key=self.api_key)
|
self.chatbot = Chatbot(api_key=self.api_key)
|
||||||
|
|
||||||
self.chatgpt_api_endpoint = chatgpt_api_endpoint_list['paid']
|
self.chatgpt_api_endpoint = self.chatgpt_api_endpoint
|
||||||
# request header for !gpt command
|
# request header for !gpt command
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": "Bearer " + self.api_key,
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
self.chatgpt_api_endpoint = chatgpt_api_endpoint_list['free']
|
self.chatgpt_api_endpoint = self.chatgpt_api_endpoint
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# initialize askGPT class
|
||||||
|
self.askgpt = askGPT()
|
||||||
|
|
||||||
# initialize bingbot
|
# initialize bingbot
|
||||||
if self.bing_api_endpoint != '':
|
if self.bing_api_endpoint != '':
|
||||||
self.bingbot = BingBot(bing_api_endpoint, jailbreakEnabled=self.jailbreakEnabled)
|
self.bingbot = BingBot(bing_api_endpoint, jailbreakEnabled=self.jailbreakEnabled)
|
||||||
|
@ -100,86 +123,330 @@ class Bot:
|
||||||
# reply event_id
|
# reply event_id
|
||||||
reply_to_event_id = event.event_id
|
reply_to_event_id = event.event_id
|
||||||
|
|
||||||
|
# sender_id
|
||||||
|
sender_id = event.sender
|
||||||
|
|
||||||
|
# user_message
|
||||||
|
raw_user_message = event.body
|
||||||
|
|
||||||
# print info to console
|
# print info to console
|
||||||
print(
|
print(
|
||||||
f"Message received in room {room.display_name}\n"
|
f"Message received in room {room.display_name}\n"
|
||||||
f"{room.user_name(event.sender)} | {event.body}"
|
f"{room.user_name(event.sender)} | {raw_user_message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# prevent command trigger loop
|
||||||
if self.user_id != event.sender:
|
if self.user_id != event.sender:
|
||||||
# remove newline character from event.body
|
# remove newline character from event.body
|
||||||
event.body = re.sub("\r\n|\r|\n", " ", event.body)
|
content_body = re.sub("\r\n|\r|\n", " ", raw_user_message)
|
||||||
|
|
||||||
# chatgpt
|
# chatgpt
|
||||||
n = self.chat_prog.match(event.body)
|
n = self.chat_prog.match(content_body)
|
||||||
if n:
|
if n:
|
||||||
prompt = n.group(1)
|
prompt = n.group(1)
|
||||||
if self.api_key != '':
|
if self.api_key != '':
|
||||||
await self.gpt(room_id, reply_to_event_id, prompt)
|
await self.chat(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
|
||||||
else:
|
else:
|
||||||
logger.warning("No API_KEY provided")
|
logger.warning("No API_KEY provided")
|
||||||
await send_room_message(self.client, room_id, send_text="API_KEY not provided")
|
await send_room_message(self.client, room_id, send_text="API_KEY not provided")
|
||||||
|
|
||||||
m = self.gpt_prog.match(event.body)
|
m = self.gpt_prog.match(content_body)
|
||||||
if m:
|
if m:
|
||||||
prompt = m.group(1)
|
prompt = m.group(1)
|
||||||
await self.chat(room_id, reply_to_event_id, prompt)
|
await self.gpt(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
|
||||||
|
|
||||||
# bing ai
|
# bing ai
|
||||||
if self.bing_api_endpoint != '':
|
if self.bing_api_endpoint != '':
|
||||||
b = self.bing_prog.match(event.body)
|
b = self.bing_prog.match(content_body)
|
||||||
if b:
|
if b:
|
||||||
prompt = b.group(1)
|
prompt = b.group(1)
|
||||||
await self.bing(room_id, reply_to_event_id, prompt)
|
# raw_content_body used for construct formatted_body
|
||||||
|
await self.bing(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
|
||||||
|
|
||||||
# Image Generation by Microsoft Bing
|
# Image Generation by Microsoft Bing
|
||||||
if self.bing_auth_cookie != '':
|
if self.bing_auth_cookie != '':
|
||||||
i = self.pic_prog.match(event.body)
|
i = self.pic_prog.match(content_body)
|
||||||
if i:
|
if i:
|
||||||
prompt = i.group(1)
|
prompt = i.group(1)
|
||||||
await self.pic(room_id, prompt)
|
await self.pic(room_id, prompt)
|
||||||
|
|
||||||
# help command
|
# help command
|
||||||
h = self.help_prog.match(event.body)
|
h = self.help_prog.match(content_body)
|
||||||
if h:
|
if h:
|
||||||
await self.help(room_id)
|
await self.help(room_id)
|
||||||
|
|
||||||
# !gpt command
|
# invite_callback event
|
||||||
async def gpt(self, room_id, reply_to_event_id, prompt):
|
async def invite_callback(self, room: MatrixRoom, event: InviteMemberEvent) -> None:
|
||||||
await self.client.room_typing(room_id)
|
"""Handle an incoming invite event.
|
||||||
|
https://github.com/8go/matrix-eno-bot/blob/ad037e02bd2960941109e9526c1033dd157bb212/callbacks.py#L104
|
||||||
|
If an invite is received, then join the room specified in the invite.
|
||||||
|
code copied from:
|
||||||
|
"""
|
||||||
|
logger.debug(f"Got invite to {room.room_id} from {event.sender}.")
|
||||||
|
# Attempt to join 3 times before giving up
|
||||||
|
for attempt in range(3):
|
||||||
|
result = await self.client.join(room.room_id)
|
||||||
|
if type(result) == JoinError:
|
||||||
|
logger.error(
|
||||||
|
f"Error joining room {room.room_id} (attempt %d): %s",
|
||||||
|
attempt, result.message,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.error("Unable to join room: %s", room.room_id)
|
||||||
|
|
||||||
|
# Successfully joined room
|
||||||
|
logger.info(f"Joined {room.room_id}")
|
||||||
|
|
||||||
|
# to_device_callback event
|
||||||
|
async def to_device_callback(self, event: KeyVerificationEvent) -> None:
|
||||||
|
"""Handle events sent to device.
|
||||||
|
|
||||||
|
Specifically this will perform Emoji verification.
|
||||||
|
It will accept an incoming Emoji verification requests
|
||||||
|
and follow the verification protocol.
|
||||||
|
code copied from: https://github.com/8go/matrix-eno-bot/blob/ad037e02bd2960941109e9526c1033dd157bb212/callbacks.py#L127
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# run synchronous function in different thread
|
client = self.client
|
||||||
text = await asyncio.to_thread(self.chatbot.ask, prompt)
|
logger.debug(
|
||||||
text = text.strip()
|
f"Device Event of type {type(event)} received in "
|
||||||
await send_room_message(self.client, room_id, send_text=text,
|
"to_device_cb().")
|
||||||
reply_to_event_id=reply_to_event_id)
|
|
||||||
|
if isinstance(event, KeyVerificationStart): # first step
|
||||||
|
""" first step: receive KeyVerificationStart
|
||||||
|
KeyVerificationStart(
|
||||||
|
source={'content':
|
||||||
|
{'method': 'm.sas.v1',
|
||||||
|
'from_device': 'DEVICEIDXY',
|
||||||
|
'key_agreement_protocols':
|
||||||
|
['curve25519-hkdf-sha256', 'curve25519'],
|
||||||
|
'hashes': ['sha256'],
|
||||||
|
'message_authentication_codes':
|
||||||
|
['hkdf-hmac-sha256', 'hmac-sha256'],
|
||||||
|
'short_authentication_string':
|
||||||
|
['decimal', 'emoji'],
|
||||||
|
'transaction_id': 'SomeTxId'
|
||||||
|
},
|
||||||
|
'type': 'm.key.verification.start',
|
||||||
|
'sender': '@user2:example.org'
|
||||||
|
},
|
||||||
|
sender='@user2:example.org',
|
||||||
|
transaction_id='SomeTxId',
|
||||||
|
from_device='DEVICEIDXY',
|
||||||
|
method='m.sas.v1',
|
||||||
|
key_agreement_protocols=[
|
||||||
|
'curve25519-hkdf-sha256', 'curve25519'],
|
||||||
|
hashes=['sha256'],
|
||||||
|
message_authentication_codes=[
|
||||||
|
'hkdf-hmac-sha256', 'hmac-sha256'],
|
||||||
|
short_authentication_string=['decimal', 'emoji'])
|
||||||
|
"""
|
||||||
|
|
||||||
|
if "emoji" not in event.short_authentication_string:
|
||||||
|
estr = ("Other device does not support emoji verification "
|
||||||
|
f"{event.short_authentication_string}. Aborting.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
return
|
||||||
|
resp = await client.accept_key_verification(
|
||||||
|
event.transaction_id)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = f"accept_key_verification() failed with {resp}"
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
sas = client.key_verifications[event.transaction_id]
|
||||||
|
|
||||||
|
todevice_msg = sas.share_key()
|
||||||
|
resp = await client.to_device(todevice_msg)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = f"to_device() failed with {resp}"
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
elif isinstance(event, KeyVerificationCancel): # anytime
|
||||||
|
""" at any time: receive KeyVerificationCancel
|
||||||
|
KeyVerificationCancel(source={
|
||||||
|
'content': {'code': 'm.mismatched_sas',
|
||||||
|
'reason': 'Mismatched authentication string',
|
||||||
|
'transaction_id': 'SomeTxId'},
|
||||||
|
'type': 'm.key.verification.cancel',
|
||||||
|
'sender': '@user2:example.org'},
|
||||||
|
sender='@user2:example.org',
|
||||||
|
transaction_id='SomeTxId',
|
||||||
|
code='m.mismatched_sas',
|
||||||
|
reason='Mismatched short authentication string')
|
||||||
|
"""
|
||||||
|
|
||||||
|
# There is no need to issue a
|
||||||
|
# client.cancel_key_verification(tx_id, reject=False)
|
||||||
|
# here. The SAS flow is already cancelled.
|
||||||
|
# We only need to inform the user.
|
||||||
|
estr = (f"Verification has been cancelled by {event.sender} "
|
||||||
|
f"for reason \"{event.reason}\".")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
elif isinstance(event, KeyVerificationKey): # second step
|
||||||
|
""" Second step is to receive KeyVerificationKey
|
||||||
|
KeyVerificationKey(
|
||||||
|
source={'content': {
|
||||||
|
'key': 'SomeCryptoKey',
|
||||||
|
'transaction_id': 'SomeTxId'},
|
||||||
|
'type': 'm.key.verification.key',
|
||||||
|
'sender': '@user2:example.org'
|
||||||
|
},
|
||||||
|
sender='@user2:example.org',
|
||||||
|
transaction_id='SomeTxId',
|
||||||
|
key='SomeCryptoKey')
|
||||||
|
"""
|
||||||
|
sas = client.key_verifications[event.transaction_id]
|
||||||
|
|
||||||
|
print(f"{sas.get_emoji()}")
|
||||||
|
# don't log the emojis
|
||||||
|
|
||||||
|
# The bot process must run in forground with a screen and
|
||||||
|
# keyboard so that user can accept/reject via keyboard.
|
||||||
|
# For emoji verification bot must not run as service or
|
||||||
|
# in background.
|
||||||
|
yn = input("Do the emojis match? (Y/N) (C for Cancel) ")
|
||||||
|
if yn.lower() == "y":
|
||||||
|
estr = ("Match! The verification for this "
|
||||||
|
"device will be accepted.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
resp = await client.confirm_short_auth_string(
|
||||||
|
event.transaction_id)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = ("confirm_short_auth_string() "
|
||||||
|
f"failed with {resp}")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
elif yn.lower() == "n": # no, don't match, reject
|
||||||
|
estr = ("No match! Device will NOT be verified "
|
||||||
|
"by rejecting verification.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
resp = await client.cancel_key_verification(
|
||||||
|
event.transaction_id, reject=True)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = (f"cancel_key_verification failed with {resp}")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
else: # C or anything for cancel
|
||||||
|
estr = ("Cancelled by user! Verification will be "
|
||||||
|
"cancelled.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
resp = await client.cancel_key_verification(
|
||||||
|
event.transaction_id, reject=False)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = (f"cancel_key_verification failed with {resp}")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
elif isinstance(event, KeyVerificationMac): # third step
|
||||||
|
""" Third step is to receive KeyVerificationMac
|
||||||
|
KeyVerificationMac(
|
||||||
|
source={'content': {
|
||||||
|
'mac': {'ed25519:DEVICEIDXY': 'SomeKey1',
|
||||||
|
'ed25519:SomeKey2': 'SomeKey3'},
|
||||||
|
'keys': 'SomeCryptoKey4',
|
||||||
|
'transaction_id': 'SomeTxId'},
|
||||||
|
'type': 'm.key.verification.mac',
|
||||||
|
'sender': '@user2:example.org'},
|
||||||
|
sender='@user2:example.org',
|
||||||
|
transaction_id='SomeTxId',
|
||||||
|
mac={'ed25519:DEVICEIDXY': 'SomeKey1',
|
||||||
|
'ed25519:SomeKey2': 'SomeKey3'},
|
||||||
|
keys='SomeCryptoKey4')
|
||||||
|
"""
|
||||||
|
sas = client.key_verifications[event.transaction_id]
|
||||||
|
try:
|
||||||
|
todevice_msg = sas.get_mac()
|
||||||
|
except LocalProtocolError as e:
|
||||||
|
# e.g. it might have been cancelled by ourselves
|
||||||
|
estr = (f"Cancelled or protocol error: Reason: {e}.\n"
|
||||||
|
f"Verification with {event.sender} not concluded. "
|
||||||
|
"Try again?")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
else:
|
||||||
|
resp = await client.to_device(todevice_msg)
|
||||||
|
if isinstance(resp, ToDeviceError):
|
||||||
|
estr = f"to_device failed with {resp}"
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
estr = (f"sas.we_started_it = {sas.we_started_it}\n"
|
||||||
|
f"sas.sas_accepted = {sas.sas_accepted}\n"
|
||||||
|
f"sas.canceled = {sas.canceled}\n"
|
||||||
|
f"sas.timed_out = {sas.timed_out}\n"
|
||||||
|
f"sas.verified = {sas.verified}\n"
|
||||||
|
f"sas.verified_devices = {sas.verified_devices}\n")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
estr = ("Emoji verification was successful!\n"
|
||||||
|
"Initiate another Emoji verification from "
|
||||||
|
"another device or room if desired. "
|
||||||
|
"Or if done verifying, hit Control-C to stop the "
|
||||||
|
"bot in order to restart it as a service or to "
|
||||||
|
"run it in the background.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
else:
|
||||||
|
estr = (f"Received unexpected event type {type(event)}. "
|
||||||
|
f"Event is {event}. Event will be ignored.")
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
except BaseException:
|
||||||
|
estr = traceback.format_exc()
|
||||||
|
print(estr)
|
||||||
|
logger.info(estr)
|
||||||
|
|
||||||
|
# !chat command
|
||||||
|
async def chat(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
|
||||||
|
await self.client.room_typing(room_id, timeout=120000)
|
||||||
|
try:
|
||||||
|
text = await asyncio.wait_for(self.chatbot.ask_async(prompt), timeout=120)
|
||||||
|
except TimeoutError as e:
|
||||||
|
logger.error("timeoutException", exc_info=True)
|
||||||
|
text = "Timeout error"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error", exc_info=True)
|
logger.error("Error", exc_info=True)
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
|
|
||||||
# !chat command
|
text = text.strip()
|
||||||
async def chat(self, room_id, reply_to_event_id, prompt):
|
try:
|
||||||
|
await send_room_message(self.client, room_id, reply_message=text,
|
||||||
|
reply_to_event_id=reply_to_event_id, sender_id=sender_id, user_message=raw_user_message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# !gpt command
|
||||||
|
async def gpt(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
|
||||||
try:
|
try:
|
||||||
# sending typing state
|
# sending typing state
|
||||||
await self.client.room_typing(room_id)
|
await self.client.room_typing(room_id, timeout=120000)
|
||||||
# timeout 120s
|
# timeout 120s
|
||||||
text = await asyncio.wait_for(ask(prompt, self.chatgpt_api_endpoint, self.headers), timeout=120)
|
text = await asyncio.wait_for(self.askgpt.oneTimeAsk(prompt, self.chatgpt_api_endpoint, self.headers), timeout=120)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
logger.error("timeoutException", exc_info=True)
|
logger.error("timeoutException", exc_info=True)
|
||||||
text = "Timeout error"
|
text = "Timeout error"
|
||||||
|
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
try:
|
try:
|
||||||
await send_room_message(self.client, room_id, send_text=text,
|
await send_room_message(self.client, room_id, reply_message=text,
|
||||||
reply_to_event_id=reply_to_event_id)
|
reply_to_event_id=reply_to_event_id, sender_id=sender_id, user_message=raw_user_message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error: {e}", exc_info=True)
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
# !bing command
|
# !bing command
|
||||||
async def bing(self, room_id, reply_to_event_id, prompt):
|
async def bing(self, room_id, reply_to_event_id, prompt, sender_id, raw_content_body):
|
||||||
try:
|
try:
|
||||||
# sending typing state
|
# sending typing state
|
||||||
await self.client.room_typing(room_id)
|
await self.client.room_typing(room_id, timeout=120000)
|
||||||
# timeout 120s
|
# timeout 120s
|
||||||
text = await asyncio.wait_for(self.bingbot.ask_bing(prompt), timeout=120)
|
text = await asyncio.wait_for(self.bingbot.ask_bing(prompt), timeout=120)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
|
@ -187,8 +454,8 @@ class Bot:
|
||||||
text = "Timeout error"
|
text = "Timeout error"
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
try:
|
try:
|
||||||
await send_room_message(self.client, room_id, send_text=text,
|
await send_room_message(self.client, room_id, reply_message=text,
|
||||||
reply_to_event_id=reply_to_event_id)
|
reply_to_event_id=reply_to_event_id, sender=sender_id, raw_content_body=raw_content_body)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error: {e}", exc_info=True)
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
@ -216,7 +483,7 @@ class Bot:
|
||||||
"!bing [content], chat with context conversation powered by Bing AI\n" + \
|
"!bing [content], chat with context conversation powered by Bing AI\n" + \
|
||||||
"!pic [prompt], Image generation by Microsoft Bing"
|
"!pic [prompt], Image generation by Microsoft Bing"
|
||||||
|
|
||||||
await send_room_message(self.client, room_id, send_text=help_info)
|
await send_room_message(self.client, room_id, reply_message=help_info)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error: {e}", exc_info=True)
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
@ -232,6 +499,24 @@ class Bot:
|
||||||
logger.error(f"Error: {e}", exc_info=True)
|
logger.error(f"Error: {e}", exc_info=True)
|
||||||
|
|
||||||
# sync messages in the room
|
# sync messages in the room
|
||||||
async def sync_forever(self, timeout=30000):
|
async def sync_forever(self, timeout=30000, full_state=True):
|
||||||
self.client.add_event_callback(self.message_callback, RoomMessageText)
|
await self.client.sync_forever(timeout=timeout, full_state=full_state)
|
||||||
await self.client.sync_forever(timeout=timeout, full_state=True)
|
|
||||||
|
# Sync encryption keys with the server
|
||||||
|
async def sync_encryption_key(self):
|
||||||
|
if self.client.should_upload_keys:
|
||||||
|
await self.client.keys_upload()
|
||||||
|
|
||||||
|
# Trust own devices
|
||||||
|
async def trust_own_devices(self):
|
||||||
|
await self.client.sync(timeout=30000, full_state=True)
|
||||||
|
for device_id, olm_device in self.client.device_store[
|
||||||
|
self.user_id].items():
|
||||||
|
logger.debug("My other devices are: "
|
||||||
|
f"device_id={device_id}, "
|
||||||
|
f"olm_device={olm_device}.")
|
||||||
|
logger.info("Setting up trust for my own "
|
||||||
|
f"device {device_id} and session key "
|
||||||
|
f"{olm_device.keys['ed25519']}.")
|
||||||
|
self.client.verify_device(olm_device)
|
||||||
|
|
|
@ -6,10 +6,13 @@ services:
|
||||||
# build:
|
# build:
|
||||||
# context: .
|
# context: .
|
||||||
# dockerfile: ./Dockerfile
|
# dockerfile: ./Dockerfile
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
volumes:
|
volumes:
|
||||||
- ./config.json:/app/config.json
|
# use env file or config.json
|
||||||
# use touch to create an empty file bot, for persist database only
|
# - ./config.json:/app/config.json
|
||||||
- ./bot:/app/bot
|
# use touch to create an empty file db, for persist database only
|
||||||
|
- ./db:/app/db
|
||||||
networks:
|
networks:
|
||||||
- matrix_network
|
- matrix_network
|
||||||
# api:
|
# api:
|
||||||
|
|
51
main.py
51
main.py
|
@ -1,30 +1,53 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
from bot import Bot
|
from bot import Bot
|
||||||
|
from nio import Api, SyncResponse
|
||||||
|
from log import getlogger
|
||||||
|
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
fp = open('config.json', 'r')
|
if os.path.exists('config.json'):
|
||||||
|
fp = open('config.json', 'r', encoding="utf8")
|
||||||
config = json.load(fp)
|
config = json.load(fp)
|
||||||
matrix_bot = Bot(homeserver=config['homeserver'],
|
|
||||||
user_id=config['user_id'],
|
matrix_bot = Bot(homeserver=os.environ.get("HOMESERVER") or config.get('homeserver'),
|
||||||
password=config.get('password', ''), # provide a default value when the key does not exist
|
user_id=os.environ.get("USER_ID") or config.get('user_id') ,
|
||||||
device_id=config['device_id'],
|
password=os.environ.get("PASSWORD") or config.get('password'),
|
||||||
room_id=config.get('room_id', ''),
|
device_id=os.environ.get("DEVICE_ID") or config.get('device_id'),
|
||||||
api_key=config.get('api_key', ''),
|
room_id=os.environ.get("ROOM_ID") or config.get('room_id'),
|
||||||
bing_api_endpoint=config.get('bing_api_endpoint', ''),
|
api_key=os.environ.get("OPENAI_API_KEY") or config.get('api_key'),
|
||||||
access_token=config.get('access_token', ''),
|
bing_api_endpoint=os.environ.get("BING_API_ENDPOINT") or config.get('bing_api_endpoint'),
|
||||||
jailbreakEnabled=config.get('jailbreakEnabled', False),
|
access_token=os.environ.get("ACCESS_TOKEN") or config.get('access_token'),
|
||||||
bing_auth_cookie=config.get('bing_auth_cookie', ''),
|
jailbreakEnabled=os.environ.get("JAILBREAKENABLED", "False").lower() in ('true', '1') or config.get('jailbreakEnabled'),
|
||||||
|
bing_auth_cookie=os.environ.get("BING_AUTH_COOKIE") or config.get('bing_auth_cookie'),
|
||||||
)
|
)
|
||||||
if config.get('access_token', '') == '':
|
# if not set access_token, then login via password
|
||||||
|
# if os.path.exists('config.json'):
|
||||||
|
# fp = open('config.json', 'r', encoding="utf8")
|
||||||
|
# config = json.load(fp)
|
||||||
|
# if os.environ.get("ACCESS_TOKEN") is None and config.get("access_token") is None:
|
||||||
|
# await matrix_bot.login()
|
||||||
|
|
||||||
|
# elif os.environ.get("ACCESS_TOKEN") is None:
|
||||||
|
# await matrix_bot.login()
|
||||||
|
|
||||||
await matrix_bot.login()
|
await matrix_bot.login()
|
||||||
await matrix_bot.sync_forever()
|
|
||||||
|
# await matrix_bot.sync_encryption_key()
|
||||||
|
|
||||||
|
# await matrix_bot.trust_own_devices()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await matrix_bot.sync_forever(timeout=3000, full_state=True)
|
||||||
|
finally:
|
||||||
|
await matrix_bot.client.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("matrix chatgpt bot start.....")
|
logger.debug("matrix chatgpt bot start.....")
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
|
|
@ -1,41 +1,46 @@
|
||||||
aiofiles==0.6.0
|
aiofiles==23.1.0
|
||||||
aiohttp==3.8.4
|
aiohttp==3.8.4
|
||||||
aiohttp-socks==0.7.1
|
aiohttp-socks==0.7.1
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
|
anyio==3.6.2
|
||||||
async-timeout==4.0.2
|
async-timeout==4.0.2
|
||||||
atomicwrites==1.4.1
|
atomicwrites==1.4.1
|
||||||
attrs==22.2.0
|
attrs==22.2.0
|
||||||
blobfile==2.0.1
|
blobfile==2.0.1
|
||||||
cachetools==4.2.4
|
cachetools==5.3.0
|
||||||
certifi==2022.12.7
|
certifi==2022.12.7
|
||||||
cffi==1.15.1
|
cffi==1.15.1
|
||||||
charset-normalizer==3.0.1
|
charset-normalizer==3.1.0
|
||||||
filelock==3.9.0
|
filelock==3.10.7
|
||||||
frozenlist==1.3.3
|
frozenlist==1.3.3
|
||||||
future==0.18.3
|
future==0.18.3
|
||||||
h11==0.12.0
|
h11==0.14.0
|
||||||
h2==4.1.0
|
h2==4.1.0
|
||||||
hpack==4.0.0
|
hpack==4.0.0
|
||||||
|
httpcore==0.16.3
|
||||||
|
httpx==0.23.3
|
||||||
hyperframe==6.0.1
|
hyperframe==6.0.1
|
||||||
idna==3.4
|
idna==3.4
|
||||||
jsonschema==4.17.3
|
jsonschema==4.17.3
|
||||||
Logbook==1.5.3
|
Logbook==1.5.3
|
||||||
lxml==4.9.2
|
lxml==4.9.2
|
||||||
matrix-nio==0.20.1
|
matrix-nio[e2e]
|
||||||
multidict==6.0.4
|
multidict==6.0.4
|
||||||
peewee==3.16.0
|
peewee==3.16.0
|
||||||
Pillow==9.4.0
|
Pillow==9.5.0
|
||||||
pycparser==2.21
|
pycparser==2.21
|
||||||
pycryptodome==3.17
|
pycryptodome==3.17
|
||||||
pycryptodomex==3.17
|
pycryptodomex==3.17
|
||||||
pyrsistent==0.19.3
|
pyrsistent==0.19.3
|
||||||
python-magic==0.4.27
|
python-magic==0.4.27
|
||||||
python-olm==3.1.3
|
python-socks==2.2.0
|
||||||
python-socks==2.1.1
|
regex==2023.3.23
|
||||||
regex==2022.10.31
|
|
||||||
requests==2.28.2
|
requests==2.28.2
|
||||||
tiktoken==0.3.0
|
rfc3986==1.5.0
|
||||||
|
sniffio==1.3.0
|
||||||
unpaddedbase64==2.1.0
|
unpaddedbase64==2.1.0
|
||||||
urllib3==1.26.14
|
urllib3==1.26.15
|
||||||
wcwidth==0.2.6
|
wcwidth==0.2.6
|
||||||
yarl==1.8.2
|
yarl==1.8.2
|
||||||
|
python-olm >= '3.1.0'
|
||||||
|
tiktoken==0.3.3
|
||||||
|
|
|
@ -1,14 +1,21 @@
|
||||||
from nio import AsyncClient
|
from nio import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
async def send_room_message(client: AsyncClient,
|
async def send_room_message(client: AsyncClient,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
send_text: str,
|
reply_message: str,
|
||||||
|
sender_id: str = '',
|
||||||
|
user_message: str = '',
|
||||||
reply_to_event_id: str = '') -> None:
|
reply_to_event_id: str = '') -> None:
|
||||||
if reply_to_event_id == '':
|
if reply_to_event_id == '':
|
||||||
content = {"msgtype": "m.text", "body": f"{send_text}", }
|
content = {"msgtype": "m.text", "body": reply_message, }
|
||||||
else:
|
else:
|
||||||
content={"msgtype": "m.text", "body": f"{send_text}",
|
body = r'> <' + sender_id + r'> ' + user_message + r'\n\n' + reply_message
|
||||||
|
format = r'org.matrix.custom.html'
|
||||||
|
formatted_body = r'<mx-reply><blockquote><a href=\"https://matrix.to/#/' + room_id + r'/' + reply_to_event_id
|
||||||
|
+ r'\">In reply to</a> <a href=\"https://matrix.to/#/' + sender_id + r'\">' + sender_id
|
||||||
|
+ r'</a><br>' + user_message + r'</blockquote></mx-reply>' + reply_message
|
||||||
|
|
||||||
|
content={"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body,
|
||||||
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
|
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
|
||||||
await client.room_send(
|
await client.room_send(
|
||||||
room_id,
|
room_id,
|
||||||
|
|
58
test.py
58
test.py
|
@ -1,58 +0,0 @@
|
||||||
from v3 import Chatbot
|
|
||||||
import asyncio
|
|
||||||
from ask_gpt import ask
|
|
||||||
import json
|
|
||||||
from bing import BingBot
|
|
||||||
|
|
||||||
fp = open("config.json", "r")
|
|
||||||
config = json.load(fp)
|
|
||||||
api_key = config.get('api_key', '')
|
|
||||||
bing_api_endpoint = config.get('bing_api_endpoint', '')
|
|
||||||
api_endpoint_list = {
|
|
||||||
"free": "https://chatgpt-api.shn.hk/v1/",
|
|
||||||
"paid": "https://api.openai.com/v1/chat/completions"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_v3(prompt: str):
|
|
||||||
bot = Chatbot(api_key=api_key)
|
|
||||||
resp = bot.ask(prompt=prompt)
|
|
||||||
print(resp)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_ask_gpt_paid(prompt: str):
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": "Bearer " + api_key,
|
|
||||||
}
|
|
||||||
api_endpoint = api_endpoint_list['paid']
|
|
||||||
# test ask_gpt.py ask()
|
|
||||||
print(await ask(prompt, api_endpoint, headers))
|
|
||||||
|
|
||||||
|
|
||||||
async def test_ask_gpt_free(prompt: str):
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
api_endpoint = api_endpoint_list['free']
|
|
||||||
print(await ask(prompt, api_endpoint, headers))
|
|
||||||
|
|
||||||
|
|
||||||
async def test_bingbot():
|
|
||||||
if bing_api_endpoint != '':
|
|
||||||
bingbot = BingBot(bing_api_endpoint)
|
|
||||||
prompt1 = "Hello World"
|
|
||||||
prompt2 = "Do you know Victor Marie Hugo"
|
|
||||||
prompt3 = "Can you tell me something about him?"
|
|
||||||
resp1 = await bingbot.ask_bing(prompt1)
|
|
||||||
resp2 = await bingbot.ask_bing(prompt2)
|
|
||||||
resp3 = await bingbot.ask_bing(prompt3)
|
|
||||||
print(resp1)
|
|
||||||
print(resp2)
|
|
||||||
print(resp3)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_v3("Hello World")
|
|
||||||
asyncio.run(test_ask_gpt_paid("Hello World"))
|
|
||||||
asyncio.run(test_ask_gpt_free("Hello World"))
|
|
||||||
asyncio.run(test_bingbot())
|
|
250
v3.py
250
v3.py
|
@ -1,20 +1,12 @@
|
||||||
"""
|
|
||||||
A simple wrapper for the official ChatGPT API
|
|
||||||
https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import httpx
|
||||||
import requests
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
ENGINE = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo"
|
|
||||||
ENCODER = tiktoken.get_encoding("gpt2")
|
|
||||||
|
|
||||||
|
|
||||||
class Chatbot:
|
class Chatbot:
|
||||||
"""
|
"""
|
||||||
Official ChatGPT API
|
Official ChatGPT API
|
||||||
|
@ -22,29 +14,63 @@ class Chatbot:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str = None,
|
api_key: str,
|
||||||
engine: str = None,
|
engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
|
||||||
proxy: str = None,
|
proxy: str = None,
|
||||||
max_tokens: int = 4096,
|
timeout: float = None,
|
||||||
|
max_tokens: int = None,
|
||||||
temperature: float = 0.5,
|
temperature: float = 0.5,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
|
presence_penalty: float = 0.0,
|
||||||
|
frequency_penalty: float = 0.0,
|
||||||
reply_count: int = 1,
|
reply_count: int = 1,
|
||||||
system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
||||||
"""
|
"""
|
||||||
self.engine = engine or ENGINE
|
self.engine: str = engine
|
||||||
|
self.api_key: str = api_key
|
||||||
|
self.system_prompt: str = system_prompt
|
||||||
|
self.max_tokens: int = max_tokens or (
|
||||||
|
31000 if engine == "gpt-4-32k" else 7000 if engine == "gpt-4" else 4000
|
||||||
|
)
|
||||||
|
self.truncate_limit: int = (
|
||||||
|
30500 if engine == "gpt-4-32k" else 6500 if engine == "gpt-4" else 3500
|
||||||
|
)
|
||||||
|
self.temperature: float = temperature
|
||||||
|
self.top_p: float = top_p
|
||||||
|
self.presence_penalty: float = presence_penalty
|
||||||
|
self.frequency_penalty: float = frequency_penalty
|
||||||
|
self.reply_count: int = reply_count
|
||||||
|
self.timeout: float = timeout
|
||||||
|
self.proxy = proxy
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
self.api_key = api_key
|
self.session.proxies.update(
|
||||||
# self.proxy = proxy
|
{
|
||||||
# if self.proxy:
|
"http": proxy,
|
||||||
# proxies = {
|
"https": proxy,
|
||||||
# "http": self.proxy,
|
},
|
||||||
# "https": self.proxy,
|
)
|
||||||
# }
|
proxy = (
|
||||||
# self.session.proxies = proxies
|
proxy or os.environ.get("all_proxy") or os.environ.get("ALL_PROXY") or None
|
||||||
self.conversation: dict = {
|
)
|
||||||
|
|
||||||
|
if proxy:
|
||||||
|
if "socks5h" not in proxy:
|
||||||
|
self.aclient = httpx.AsyncClient(
|
||||||
|
follow_redirects=True,
|
||||||
|
proxies=proxy,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.aclient = httpx.AsyncClient(
|
||||||
|
follow_redirects=True,
|
||||||
|
proxies=proxy,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conversation: dict[str, list[dict]] = {
|
||||||
"default": [
|
"default": [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -52,20 +78,13 @@ class Chatbot:
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
self.system_prompt = system_prompt
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.temperature = temperature
|
|
||||||
self.top_p = top_p
|
|
||||||
self.reply_count = reply_count
|
|
||||||
|
|
||||||
initial_conversation = "\n".join(
|
|
||||||
[x["content"] for x in self.conversation["default"]],
|
|
||||||
)
|
|
||||||
if len(ENCODER.encode(initial_conversation)) > self.max_tokens:
|
|
||||||
raise Exception("System prompt is too long")
|
|
||||||
|
|
||||||
def add_to_conversation(
|
def add_to_conversation(
|
||||||
self, message: str, role: str, convo_id: str = "default"
|
self,
|
||||||
|
message: str,
|
||||||
|
role: str,
|
||||||
|
convo_id: str = "default",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add a message to the conversation
|
Add a message to the conversation
|
||||||
|
@ -77,12 +96,8 @@ class Chatbot:
|
||||||
Truncate the conversation
|
Truncate the conversation
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
full_conversation = "".join(
|
|
||||||
message["role"] + ": " + message["content"] + "\n"
|
|
||||||
for message in self.conversation[convo_id]
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
len(ENCODER.encode(full_conversation)) > self.max_tokens
|
self.get_token_count(convo_id) > self.truncate_limit
|
||||||
and len(self.conversation[convo_id]) > 1
|
and len(self.conversation[convo_id]) > 1
|
||||||
):
|
):
|
||||||
# Don't remove the first message
|
# Don't remove the first message
|
||||||
|
@ -90,15 +105,41 @@ class Chatbot:
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_count(self, convo_id: str = "default") -> int:
|
||||||
|
"""
|
||||||
|
Get token count
|
||||||
|
"""
|
||||||
|
if self.engine not in [
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-3.5-turbo-0301",
|
||||||
|
"gpt-4",
|
||||||
|
"gpt-4-0314",
|
||||||
|
"gpt-4-32k",
|
||||||
|
"gpt-4-32k-0314",
|
||||||
|
]:
|
||||||
|
raise NotImplementedError("Unsupported engine {self.engine}")
|
||||||
|
|
||||||
|
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
|
||||||
|
|
||||||
|
encoding = tiktoken.encoding_for_model(self.engine)
|
||||||
|
|
||||||
|
num_tokens = 0
|
||||||
|
for message in self.conversation[convo_id]:
|
||||||
|
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||||
|
num_tokens += 5
|
||||||
|
for key, value in message.items():
|
||||||
|
num_tokens += len(encoding.encode(value))
|
||||||
|
if key == "name": # if there's a name, the role is omitted
|
||||||
|
num_tokens += 5 # role is always required and always 1 token
|
||||||
|
num_tokens += 5 # every reply is primed with <im_start>assistant
|
||||||
|
return num_tokens
|
||||||
|
|
||||||
def get_max_tokens(self, convo_id: str) -> int:
|
def get_max_tokens(self, convo_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
Get max tokens
|
Get max tokens
|
||||||
"""
|
"""
|
||||||
full_conversation = "".join(
|
return self.max_tokens - self.get_token_count(convo_id)
|
||||||
message["role"] + ": " + message["content"] + "\n"
|
|
||||||
for message in self.conversation[convo_id]
|
|
||||||
)
|
|
||||||
return 4000 - len(ENCODER.encode(full_conversation))
|
|
||||||
|
|
||||||
def ask_stream(
|
def ask_stream(
|
||||||
self,
|
self,
|
||||||
|
@ -106,7 +147,7 @@ class Chatbot:
|
||||||
role: str = "user",
|
role: str = "user",
|
||||||
convo_id: str = "default",
|
convo_id: str = "default",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
):
|
||||||
"""
|
"""
|
||||||
Ask a question
|
Ask a question
|
||||||
"""
|
"""
|
||||||
|
@ -117,7 +158,7 @@ class Chatbot:
|
||||||
self.__truncate_conversation(convo_id=convo_id)
|
self.__truncate_conversation(convo_id=convo_id)
|
||||||
# Get response
|
# Get response
|
||||||
response = self.session.post(
|
response = self.session.post(
|
||||||
"https://api.openai.com/v1/chat/completions",
|
os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
|
||||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||||
json={
|
json={
|
||||||
"model": self.engine,
|
"model": self.engine,
|
||||||
|
@ -126,16 +167,22 @@ class Chatbot:
|
||||||
# kwargs
|
# kwargs
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
"top_p": kwargs.get("top_p", self.top_p),
|
"top_p": kwargs.get("top_p", self.top_p),
|
||||||
|
"presence_penalty": kwargs.get(
|
||||||
|
"presence_penalty",
|
||||||
|
self.presence_penalty,
|
||||||
|
),
|
||||||
|
"frequency_penalty": kwargs.get(
|
||||||
|
"frequency_penalty",
|
||||||
|
self.frequency_penalty,
|
||||||
|
),
|
||||||
"n": kwargs.get("n", self.reply_count),
|
"n": kwargs.get("n", self.reply_count),
|
||||||
"user": role,
|
"user": role,
|
||||||
# "max_tokens": self.get_max_tokens(convo_id=convo_id),
|
"max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||||
},
|
},
|
||||||
|
timeout=kwargs.get("timeout", self.timeout),
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
|
||||||
raise Exception(
|
|
||||||
f"Error: {response.status_code} {response.reason} {response.text}",
|
|
||||||
)
|
|
||||||
response_role: str = None
|
response_role: str = None
|
||||||
full_response: str = ""
|
full_response: str = ""
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
|
@ -160,8 +207,100 @@ class Chatbot:
|
||||||
yield content
|
yield content
|
||||||
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
|
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
|
||||||
|
|
||||||
|
async def ask_stream_async(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
role: str = "user",
|
||||||
|
convo_id: str = "default",
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""
|
||||||
|
Ask a question
|
||||||
|
"""
|
||||||
|
# Make conversation if it doesn't exist
|
||||||
|
if convo_id not in self.conversation:
|
||||||
|
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
|
||||||
|
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
||||||
|
self.__truncate_conversation(convo_id=convo_id)
|
||||||
|
# Get response
|
||||||
|
async with self.aclient.stream(
|
||||||
|
"post",
|
||||||
|
os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
|
||||||
|
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||||
|
json={
|
||||||
|
"model": self.engine,
|
||||||
|
"messages": self.conversation[convo_id],
|
||||||
|
"stream": True,
|
||||||
|
# kwargs
|
||||||
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
|
"top_p": kwargs.get("top_p", self.top_p),
|
||||||
|
"presence_penalty": kwargs.get(
|
||||||
|
"presence_penalty",
|
||||||
|
self.presence_penalty,
|
||||||
|
),
|
||||||
|
"frequency_penalty": kwargs.get(
|
||||||
|
"frequency_penalty",
|
||||||
|
self.frequency_penalty,
|
||||||
|
),
|
||||||
|
"n": kwargs.get("n", self.reply_count),
|
||||||
|
"user": role,
|
||||||
|
"max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||||
|
},
|
||||||
|
timeout=kwargs.get("timeout", self.timeout),
|
||||||
|
) as response:
|
||||||
|
if response.status_code != 200:
|
||||||
|
await response.aread()
|
||||||
|
|
||||||
|
response_role: str = ""
|
||||||
|
full_response: str = ""
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
# Remove "data: "
|
||||||
|
line = line[6:]
|
||||||
|
if line == "[DONE]":
|
||||||
|
break
|
||||||
|
resp: dict = json.loads(line)
|
||||||
|
choices = resp.get("choices")
|
||||||
|
if not choices:
|
||||||
|
continue
|
||||||
|
delta: dict[str, str] = choices[0].get("delta")
|
||||||
|
if not delta:
|
||||||
|
continue
|
||||||
|
if "role" in delta:
|
||||||
|
response_role = delta["role"]
|
||||||
|
if "content" in delta:
|
||||||
|
content: str = delta["content"]
|
||||||
|
full_response += content
|
||||||
|
yield content
|
||||||
|
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
|
||||||
|
|
||||||
|
async def ask_async(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
role: str = "user",
|
||||||
|
convo_id: str = "default",
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Non-streaming ask
|
||||||
|
"""
|
||||||
|
response = self.ask_stream_async(
|
||||||
|
prompt=prompt,
|
||||||
|
role=role,
|
||||||
|
convo_id=convo_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
full_response: str = "".join([r async for r in response])
|
||||||
|
return full_response
|
||||||
|
|
||||||
def ask(
|
def ask(
|
||||||
self, prompt: str, role: str = "user", convo_id: str = "default", **kwargs
|
self,
|
||||||
|
prompt: str,
|
||||||
|
role: str = "user",
|
||||||
|
convo_id: str = "default",
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Non-streaming ask
|
Non-streaming ask
|
||||||
|
@ -175,12 +314,6 @@ class Chatbot:
|
||||||
full_response: str = "".join(response)
|
full_response: str = "".join(response)
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
def rollback(self, n: int = 1, convo_id: str = "default") -> None:
|
|
||||||
"""
|
|
||||||
Rollback the conversation
|
|
||||||
"""
|
|
||||||
for _ in range(n):
|
|
||||||
self.conversation[convo_id].pop()
|
|
||||||
|
|
||||||
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
|
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -189,4 +322,3 @@ class Chatbot:
|
||||||
self.conversation[convo_id] = [
|
self.conversation[convo_id] = [
|
||||||
{"role": "system", "content": system_prompt or self.system_prompt},
|
{"role": "system", "content": system_prompt or self.system_prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue