refactor code
This commit is contained in:
parent
5c776a6a67
commit
4832d6f00b
17 changed files with 725 additions and 326 deletions
|
@ -5,12 +5,17 @@ Dockerfile
|
|||
.dockerignore
|
||||
config.json
|
||||
config.json.sample
|
||||
bot
|
||||
db
|
||||
bot.log
|
||||
venv
|
||||
compose.yaml
|
||||
.venv
|
||||
*.yaml
|
||||
*.yml
|
||||
.git
|
||||
.idea
|
||||
__pycache__
|
||||
venv
|
||||
.env
|
||||
.env.example
|
||||
.github
|
||||
settings.js
|
||||
|
||||
|
|
10
.env.example
Normal file
10
.env.example
Normal file
|
@ -0,0 +1,10 @@
|
|||
HOMESERVER="https://matrix.xxxxxx.xxxx" # required
|
||||
USER_ID="@lullap:xxxxxxxxxxxxx.xxx" # required
|
||||
PASSWORD="xxxxxxxxxxxxxxx" # required
|
||||
DEVICE_ID="xxxxxxxxxxxxxx" # required
|
||||
ROOM_ID="!FYCmBSkCRUXXXXXXXXX:matrix.XXX.XXX" # Optional, if the property is blank, bot will work on the room it is in (Unencrypted room only as for now)
|
||||
OPENAI_API_KEY="xxxxxxxxxxxxxxxxx" # Optional, for !chat and !gpt command
|
||||
BING_API_ENDPOINT="xxxxxxxxxxxxxxx" # Optional, for !bing command
|
||||
ACCESS_TOKEN="xxxxxxxxxxxxxxxxxxxxx" # Optional, use user_id and password is recommended
|
||||
JAILBREAKENABLED="true" # Optional
|
||||
BING_AUTH_COOKIE="xxxxxxxxxxxxxxxxxxx" # _U cookie, Optional, for Bing Image Creator
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -26,7 +26,7 @@ share/python-wheels/
|
|||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
bot
|
||||
db
|
||||
bot.log
|
||||
|
||||
# image generation folder
|
||||
|
@ -135,6 +135,7 @@ env.bak/
|
|||
venv.bak/
|
||||
config.json
|
||||
compose_local_build.yaml
|
||||
settings.js
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
|
|
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"python.analysis.typeCheckingMode": "off"
|
||||
}
|
|
@ -105,16 +105,5 @@ class ImageGen:
|
|||
with open(f"{output_dir}/{image_name}.jpeg", "wb") as output_file:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
output_file.write(chunk)
|
||||
# image_num = 0
|
||||
# for link in links:
|
||||
# with self.session.get(link, stream=True) as response:
|
||||
# # save response to file
|
||||
# response.raise_for_status()
|
||||
# with open(f"{output_dir}/{image_num}.jpeg", "wb") as output_file:
|
||||
# for chunk in response.iter_content(chunk_size=8192):
|
||||
# output_file.write(chunk)
|
||||
#
|
||||
# image_num += 1
|
||||
|
||||
# return image path
|
||||
return f"{output_dir}/{image_name}.jpeg"
|
||||
|
|
|
@ -2,7 +2,7 @@ FROM python:3.11-alpine as base
|
|||
|
||||
FROM base as pybuilder
|
||||
RUN sed -i 's|v3\.\d*|edge|' /etc/apk/repositories
|
||||
RUN apk update && apk add olm-dev gcc musl-dev tzdata libmagic
|
||||
RUN apk update && apk add olm-dev gcc musl-dev libmagic
|
||||
COPY requirements.txt /requirements.txt
|
||||
RUN pip3 install --user -r /requirements.txt && rm /requirements.txt
|
||||
|
||||
|
@ -11,12 +11,10 @@ FROM base as runner
|
|||
LABEL "org.opencontainers.image.source"="https://github.com/hibobmaster/matrix_chatgpt_bot"
|
||||
RUN apk update && apk add olm-dev libmagic
|
||||
COPY --from=pybuilder /root/.local /usr/local
|
||||
COPY --from=pybuilder /usr/share/zoneinfo /usr/share/zoneinfo
|
||||
COPY . /app
|
||||
|
||||
|
||||
FROM runner
|
||||
ENV TZ=Asia/Shanghai
|
||||
WORKDIR /app
|
||||
CMD ["python", "main.py"]
|
||||
|
||||
|
|
|
@ -3,11 +3,11 @@ This is a simple Matrix bot that uses OpenAI's GPT API and Bing AI to generate r
|
|||
![demo](https://i.imgur.com/kK4rnPf.jpeg "demo")
|
||||
|
||||
## Installation and Setup
|
||||
Docker method:<br>
|
||||
Edit `config.json` with proper values <br>
|
||||
Docker method(Recommended):<br>
|
||||
Edit `config.json` or `.env` with proper values <br>
|
||||
Create an empty file, for persist database only<br>
|
||||
```bash
|
||||
touch bot
|
||||
touch db
|
||||
sudo docker compose up -d
|
||||
```
|
||||
<hr>
|
||||
|
|
|
@ -4,8 +4,11 @@ import json
|
|||
from log import getlogger
|
||||
logger = getlogger()
|
||||
|
||||
class askGPT:
|
||||
def __init__(self):
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
|
||||
async def oneTimeAsk(self, prompt: str, api_endpoint: str, headers: dict) -> str:
|
||||
jsons = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
|
@ -15,11 +18,10 @@ async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
|
|||
},
|
||||
],
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
max_try = 5
|
||||
while max_try > 0:
|
||||
try:
|
||||
async with session.post(url=api_endpoint,
|
||||
async with self.session.post(url=api_endpoint,
|
||||
json=jsons, headers=headers, timeout=30) as response:
|
||||
status_code = response.status
|
||||
if not status_code == 200:
|
||||
|
@ -31,9 +33,6 @@ async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
|
|||
continue
|
||||
|
||||
resp = await response.read()
|
||||
await session.close()
|
||||
return json.loads(resp)['choices'][0]['message']['content']
|
||||
except Exception as e:
|
||||
logger.error("Error Exception", exc_info=True)
|
||||
print(e)
|
||||
pass
|
11
bing.py
11
bing.py
|
@ -4,31 +4,28 @@ import asyncio
|
|||
from log import getlogger
|
||||
# api_endpoint = "http://localhost:3000/conversation"
|
||||
logger = getlogger()
|
||||
python_boolean_to_json = {
|
||||
"true": True,
|
||||
}
|
||||
|
||||
|
||||
class BingBot:
|
||||
def __init__(self, bing_api_endpoint: str, jailbreakEnabled: bool = False):
|
||||
self.data = {
|
||||
# 'jailbreakConversationId': json.dumps(python_boolean_to_json['true']),
|
||||
'clientOptions.clientToUse': 'bing',
|
||||
}
|
||||
self.bing_api_endpoint = bing_api_endpoint
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
self.jailbreakEnabled = jailbreakEnabled
|
||||
|
||||
if self.jailbreakEnabled:
|
||||
self.data['jailbreakConversationId'] = json.dumps(python_boolean_to_json['true'])
|
||||
self.data['jailbreakConversationId'] = json.dumps(True)
|
||||
|
||||
async def ask_bing(self, prompt) -> str:
|
||||
self.data['message'] = prompt
|
||||
async with aiohttp.ClientSession() as session:
|
||||
max_try = 5
|
||||
while max_try > 0:
|
||||
try:
|
||||
resp = await session.post(url=self.bing_api_endpoint, json=self.data)
|
||||
resp = await self.session.post(url=self.bing_api_endpoint, json=self.data)
|
||||
status_code = resp.status
|
||||
body = await resp.read()
|
||||
if not status_code == 200:
|
||||
|
|
391
bot.py
391
bot.py
|
@ -2,23 +2,33 @@ import sys
|
|||
import asyncio
|
||||
import re
|
||||
import os
|
||||
from typing import Optional
|
||||
from nio import AsyncClient, MatrixRoom, RoomMessageText, LoginResponse, AsyncClientConfig
|
||||
import traceback
|
||||
from typing import Optional, Union
|
||||
from nio import (
|
||||
AsyncClient,
|
||||
MatrixRoom,
|
||||
RoomMessageText,
|
||||
InviteMemberEvent,
|
||||
LoginResponse,
|
||||
JoinError,
|
||||
ToDeviceError,
|
||||
LocalProtocolError,
|
||||
KeyVerificationEvent,
|
||||
KeyVerificationStart,
|
||||
KeyVerificationCancel,
|
||||
KeyVerificationKey,
|
||||
KeyVerificationMac,
|
||||
AsyncClientConfig
|
||||
)
|
||||
from nio.store.database import SqliteStore
|
||||
from ask_gpt import ask
|
||||
from askgpt import askGPT
|
||||
from send_message import send_room_message
|
||||
from v3 import Chatbot
|
||||
from log import getlogger
|
||||
from bing import BingBot
|
||||
from BingImageGen import ImageGen
|
||||
from send_image import send_room_image
|
||||
"""
|
||||
free api_endpoint from https://github.com/ayaka14732/ChatGPTAPIFree
|
||||
"""
|
||||
chatgpt_api_endpoint_list = {
|
||||
"free": "https://chatgpt-api.shn.hk/v1/",
|
||||
"paid": "https://api.openai.com/v1/chat/completions"
|
||||
}
|
||||
|
||||
logger = getlogger()
|
||||
|
||||
|
||||
|
@ -27,35 +37,45 @@ class Bot:
|
|||
self,
|
||||
homeserver: str,
|
||||
user_id: str,
|
||||
password: str,
|
||||
device_id: str,
|
||||
api_key: Optional[str] = "",
|
||||
chatgpt_api_endpoint: str = os.environ.get("CHATGPT_API_ENDPOINT") or "https://api.openai.com/v1/chat/completions",
|
||||
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") or "",
|
||||
room_id: Optional[str] = '',
|
||||
bing_api_endpoint: Optional[str] = '',
|
||||
access_token: Optional[str] = '',
|
||||
jailbreakEnabled: Optional[bool] = False,
|
||||
password: Union[str, None] = None,
|
||||
access_token: Union[str, None] = None,
|
||||
jailbreakEnabled: Optional[bool] = True,
|
||||
bing_auth_cookie: Optional[str] = '',
|
||||
):
|
||||
self.homeserver = homeserver
|
||||
self.user_id = user_id
|
||||
self.password = password
|
||||
self.access_token = access_token
|
||||
self.device_id = device_id
|
||||
self.room_id = room_id
|
||||
self.api_key = api_key
|
||||
self.chatgpt_api_endpoint = chatgpt_api_endpoint
|
||||
self.bing_api_endpoint = bing_api_endpoint
|
||||
self.jailbreakEnabled = jailbreakEnabled
|
||||
self.bing_auth_cookie = bing_auth_cookie
|
||||
# initialize AsyncClient object
|
||||
self.store_path = os.getcwd()
|
||||
self.config = AsyncClientConfig(store=SqliteStore,
|
||||
store_name="bot",
|
||||
store_name="db",
|
||||
store_sync_tokens=True,
|
||||
encryption_enabled=True,
|
||||
)
|
||||
self.client = AsyncClient(self.homeserver, user=self.user_id, device_id=self.device_id,
|
||||
self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id,
|
||||
config=self.config, store_path=self.store_path,)
|
||||
if access_token != '':
|
||||
self.client.access_token = access_token
|
||||
|
||||
if self.access_token is not None:
|
||||
self.client.access_token = self.access_token
|
||||
|
||||
# setup event callbacks
|
||||
self.client.add_event_callback(self.message_callback, (RoomMessageText, ))
|
||||
self.client.add_event_callback(self.invite_callback, (InviteMemberEvent, ))
|
||||
self.client.add_to_device_callback(self.to_device_callback, (KeyVerificationEvent, ))
|
||||
|
||||
# regular expression to match keyword [!gpt {prompt}] [!chat {prompt}]
|
||||
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
||||
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
||||
|
@ -67,18 +87,21 @@ class Bot:
|
|||
if self.api_key != '':
|
||||
self.chatbot = Chatbot(api_key=self.api_key)
|
||||
|
||||
self.chatgpt_api_endpoint = chatgpt_api_endpoint_list['paid']
|
||||
self.chatgpt_api_endpoint = self.chatgpt_api_endpoint
|
||||
# request header for !gpt command
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key,
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
else:
|
||||
self.chatgpt_api_endpoint = chatgpt_api_endpoint_list['free']
|
||||
self.chatgpt_api_endpoint = self.chatgpt_api_endpoint
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# initialize askGPT class
|
||||
self.askgpt = askGPT()
|
||||
|
||||
# initialize bingbot
|
||||
if self.bing_api_endpoint != '':
|
||||
self.bingbot = BingBot(bing_api_endpoint, jailbreakEnabled=self.jailbreakEnabled)
|
||||
|
@ -100,86 +123,330 @@ class Bot:
|
|||
# reply event_id
|
||||
reply_to_event_id = event.event_id
|
||||
|
||||
# sender_id
|
||||
sender_id = event.sender
|
||||
|
||||
# user_message
|
||||
raw_user_message = event.body
|
||||
|
||||
# print info to console
|
||||
print(
|
||||
f"Message received in room {room.display_name}\n"
|
||||
f"{room.user_name(event.sender)} | {event.body}"
|
||||
f"{room.user_name(event.sender)} | {raw_user_message}"
|
||||
)
|
||||
|
||||
# prevent command trigger loop
|
||||
if self.user_id != event.sender:
|
||||
# remove newline character from event.body
|
||||
event.body = re.sub("\r\n|\r|\n", " ", event.body)
|
||||
content_body = re.sub("\r\n|\r|\n", " ", raw_user_message)
|
||||
|
||||
# chatgpt
|
||||
n = self.chat_prog.match(event.body)
|
||||
n = self.chat_prog.match(content_body)
|
||||
if n:
|
||||
prompt = n.group(1)
|
||||
if self.api_key != '':
|
||||
await self.gpt(room_id, reply_to_event_id, prompt)
|
||||
await self.chat(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
|
||||
else:
|
||||
logger.warning("No API_KEY provided")
|
||||
await send_room_message(self.client, room_id, send_text="API_KEY not provided")
|
||||
|
||||
m = self.gpt_prog.match(event.body)
|
||||
m = self.gpt_prog.match(content_body)
|
||||
if m:
|
||||
prompt = m.group(1)
|
||||
await self.chat(room_id, reply_to_event_id, prompt)
|
||||
await self.gpt(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
|
||||
|
||||
# bing ai
|
||||
if self.bing_api_endpoint != '':
|
||||
b = self.bing_prog.match(event.body)
|
||||
b = self.bing_prog.match(content_body)
|
||||
if b:
|
||||
prompt = b.group(1)
|
||||
await self.bing(room_id, reply_to_event_id, prompt)
|
||||
# raw_content_body used for construct formatted_body
|
||||
await self.bing(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
|
||||
|
||||
# Image Generation by Microsoft Bing
|
||||
if self.bing_auth_cookie != '':
|
||||
i = self.pic_prog.match(event.body)
|
||||
i = self.pic_prog.match(content_body)
|
||||
if i:
|
||||
prompt = i.group(1)
|
||||
await self.pic(room_id, prompt)
|
||||
|
||||
# help command
|
||||
h = self.help_prog.match(event.body)
|
||||
h = self.help_prog.match(content_body)
|
||||
if h:
|
||||
await self.help(room_id)
|
||||
|
||||
# !gpt command
|
||||
async def gpt(self, room_id, reply_to_event_id, prompt):
|
||||
await self.client.room_typing(room_id)
|
||||
# invite_callback event
|
||||
async def invite_callback(self, room: MatrixRoom, event: InviteMemberEvent) -> None:
|
||||
"""Handle an incoming invite event.
|
||||
https://github.com/8go/matrix-eno-bot/blob/ad037e02bd2960941109e9526c1033dd157bb212/callbacks.py#L104
|
||||
If an invite is received, then join the room specified in the invite.
|
||||
code copied from:
|
||||
"""
|
||||
logger.debug(f"Got invite to {room.room_id} from {event.sender}.")
|
||||
# Attempt to join 3 times before giving up
|
||||
for attempt in range(3):
|
||||
result = await self.client.join(room.room_id)
|
||||
if type(result) == JoinError:
|
||||
logger.error(
|
||||
f"Error joining room {room.room_id} (attempt %d): %s",
|
||||
attempt, result.message,
|
||||
)
|
||||
else:
|
||||
break
|
||||
else:
|
||||
logger.error("Unable to join room: %s", room.room_id)
|
||||
|
||||
# Successfully joined room
|
||||
logger.info(f"Joined {room.room_id}")
|
||||
|
||||
# to_device_callback event
|
||||
async def to_device_callback(self, event: KeyVerificationEvent) -> None:
|
||||
"""Handle events sent to device.
|
||||
|
||||
Specifically this will perform Emoji verification.
|
||||
It will accept an incoming Emoji verification requests
|
||||
and follow the verification protocol.
|
||||
code copied from: https://github.com/8go/matrix-eno-bot/blob/ad037e02bd2960941109e9526c1033dd157bb212/callbacks.py#L127
|
||||
"""
|
||||
try:
|
||||
# run synchronous function in different thread
|
||||
text = await asyncio.to_thread(self.chatbot.ask, prompt)
|
||||
text = text.strip()
|
||||
await send_room_message(self.client, room_id, send_text=text,
|
||||
reply_to_event_id=reply_to_event_id)
|
||||
client = self.client
|
||||
logger.debug(
|
||||
f"Device Event of type {type(event)} received in "
|
||||
"to_device_cb().")
|
||||
|
||||
if isinstance(event, KeyVerificationStart): # first step
|
||||
""" first step: receive KeyVerificationStart
|
||||
KeyVerificationStart(
|
||||
source={'content':
|
||||
{'method': 'm.sas.v1',
|
||||
'from_device': 'DEVICEIDXY',
|
||||
'key_agreement_protocols':
|
||||
['curve25519-hkdf-sha256', 'curve25519'],
|
||||
'hashes': ['sha256'],
|
||||
'message_authentication_codes':
|
||||
['hkdf-hmac-sha256', 'hmac-sha256'],
|
||||
'short_authentication_string':
|
||||
['decimal', 'emoji'],
|
||||
'transaction_id': 'SomeTxId'
|
||||
},
|
||||
'type': 'm.key.verification.start',
|
||||
'sender': '@user2:example.org'
|
||||
},
|
||||
sender='@user2:example.org',
|
||||
transaction_id='SomeTxId',
|
||||
from_device='DEVICEIDXY',
|
||||
method='m.sas.v1',
|
||||
key_agreement_protocols=[
|
||||
'curve25519-hkdf-sha256', 'curve25519'],
|
||||
hashes=['sha256'],
|
||||
message_authentication_codes=[
|
||||
'hkdf-hmac-sha256', 'hmac-sha256'],
|
||||
short_authentication_string=['decimal', 'emoji'])
|
||||
"""
|
||||
|
||||
if "emoji" not in event.short_authentication_string:
|
||||
estr = ("Other device does not support emoji verification "
|
||||
f"{event.short_authentication_string}. Aborting.")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
return
|
||||
resp = await client.accept_key_verification(
|
||||
event.transaction_id)
|
||||
if isinstance(resp, ToDeviceError):
|
||||
estr = f"accept_key_verification() failed with {resp}"
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
|
||||
sas = client.key_verifications[event.transaction_id]
|
||||
|
||||
todevice_msg = sas.share_key()
|
||||
resp = await client.to_device(todevice_msg)
|
||||
if isinstance(resp, ToDeviceError):
|
||||
estr = f"to_device() failed with {resp}"
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
|
||||
elif isinstance(event, KeyVerificationCancel): # anytime
|
||||
""" at any time: receive KeyVerificationCancel
|
||||
KeyVerificationCancel(source={
|
||||
'content': {'code': 'm.mismatched_sas',
|
||||
'reason': 'Mismatched authentication string',
|
||||
'transaction_id': 'SomeTxId'},
|
||||
'type': 'm.key.verification.cancel',
|
||||
'sender': '@user2:example.org'},
|
||||
sender='@user2:example.org',
|
||||
transaction_id='SomeTxId',
|
||||
code='m.mismatched_sas',
|
||||
reason='Mismatched short authentication string')
|
||||
"""
|
||||
|
||||
# There is no need to issue a
|
||||
# client.cancel_key_verification(tx_id, reject=False)
|
||||
# here. The SAS flow is already cancelled.
|
||||
# We only need to inform the user.
|
||||
estr = (f"Verification has been cancelled by {event.sender} "
|
||||
f"for reason \"{event.reason}\".")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
|
||||
elif isinstance(event, KeyVerificationKey): # second step
|
||||
""" Second step is to receive KeyVerificationKey
|
||||
KeyVerificationKey(
|
||||
source={'content': {
|
||||
'key': 'SomeCryptoKey',
|
||||
'transaction_id': 'SomeTxId'},
|
||||
'type': 'm.key.verification.key',
|
||||
'sender': '@user2:example.org'
|
||||
},
|
||||
sender='@user2:example.org',
|
||||
transaction_id='SomeTxId',
|
||||
key='SomeCryptoKey')
|
||||
"""
|
||||
sas = client.key_verifications[event.transaction_id]
|
||||
|
||||
print(f"{sas.get_emoji()}")
|
||||
# don't log the emojis
|
||||
|
||||
# The bot process must run in forground with a screen and
|
||||
# keyboard so that user can accept/reject via keyboard.
|
||||
# For emoji verification bot must not run as service or
|
||||
# in background.
|
||||
yn = input("Do the emojis match? (Y/N) (C for Cancel) ")
|
||||
if yn.lower() == "y":
|
||||
estr = ("Match! The verification for this "
|
||||
"device will be accepted.")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
resp = await client.confirm_short_auth_string(
|
||||
event.transaction_id)
|
||||
if isinstance(resp, ToDeviceError):
|
||||
estr = ("confirm_short_auth_string() "
|
||||
f"failed with {resp}")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
elif yn.lower() == "n": # no, don't match, reject
|
||||
estr = ("No match! Device will NOT be verified "
|
||||
"by rejecting verification.")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
resp = await client.cancel_key_verification(
|
||||
event.transaction_id, reject=True)
|
||||
if isinstance(resp, ToDeviceError):
|
||||
estr = (f"cancel_key_verification failed with {resp}")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
else: # C or anything for cancel
|
||||
estr = ("Cancelled by user! Verification will be "
|
||||
"cancelled.")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
resp = await client.cancel_key_verification(
|
||||
event.transaction_id, reject=False)
|
||||
if isinstance(resp, ToDeviceError):
|
||||
estr = (f"cancel_key_verification failed with {resp}")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
|
||||
elif isinstance(event, KeyVerificationMac): # third step
|
||||
""" Third step is to receive KeyVerificationMac
|
||||
KeyVerificationMac(
|
||||
source={'content': {
|
||||
'mac': {'ed25519:DEVICEIDXY': 'SomeKey1',
|
||||
'ed25519:SomeKey2': 'SomeKey3'},
|
||||
'keys': 'SomeCryptoKey4',
|
||||
'transaction_id': 'SomeTxId'},
|
||||
'type': 'm.key.verification.mac',
|
||||
'sender': '@user2:example.org'},
|
||||
sender='@user2:example.org',
|
||||
transaction_id='SomeTxId',
|
||||
mac={'ed25519:DEVICEIDXY': 'SomeKey1',
|
||||
'ed25519:SomeKey2': 'SomeKey3'},
|
||||
keys='SomeCryptoKey4')
|
||||
"""
|
||||
sas = client.key_verifications[event.transaction_id]
|
||||
try:
|
||||
todevice_msg = sas.get_mac()
|
||||
except LocalProtocolError as e:
|
||||
# e.g. it might have been cancelled by ourselves
|
||||
estr = (f"Cancelled or protocol error: Reason: {e}.\n"
|
||||
f"Verification with {event.sender} not concluded. "
|
||||
"Try again?")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
else:
|
||||
resp = await client.to_device(todevice_msg)
|
||||
if isinstance(resp, ToDeviceError):
|
||||
estr = f"to_device failed with {resp}"
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
estr = (f"sas.we_started_it = {sas.we_started_it}\n"
|
||||
f"sas.sas_accepted = {sas.sas_accepted}\n"
|
||||
f"sas.canceled = {sas.canceled}\n"
|
||||
f"sas.timed_out = {sas.timed_out}\n"
|
||||
f"sas.verified = {sas.verified}\n"
|
||||
f"sas.verified_devices = {sas.verified_devices}\n")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
estr = ("Emoji verification was successful!\n"
|
||||
"Initiate another Emoji verification from "
|
||||
"another device or room if desired. "
|
||||
"Or if done verifying, hit Control-C to stop the "
|
||||
"bot in order to restart it as a service or to "
|
||||
"run it in the background.")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
else:
|
||||
estr = (f"Received unexpected event type {type(event)}. "
|
||||
f"Event is {event}. Event will be ignored.")
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
except BaseException:
|
||||
estr = traceback.format_exc()
|
||||
print(estr)
|
||||
logger.info(estr)
|
||||
|
||||
# !chat command
|
||||
async def chat(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
|
||||
await self.client.room_typing(room_id, timeout=120000)
|
||||
try:
|
||||
text = await asyncio.wait_for(self.chatbot.ask_async(prompt), timeout=120)
|
||||
except TimeoutError as e:
|
||||
logger.error("timeoutException", exc_info=True)
|
||||
text = "Timeout error"
|
||||
except Exception as e:
|
||||
logger.error("Error", exc_info=True)
|
||||
print(f"Error: {e}")
|
||||
|
||||
# !chat command
|
||||
async def chat(self, room_id, reply_to_event_id, prompt):
|
||||
text = text.strip()
|
||||
try:
|
||||
await send_room_message(self.client, room_id, reply_message=text,
|
||||
reply_to_event_id=reply_to_event_id, sender_id=sender_id, user_message=raw_user_message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}", exc_info=True)
|
||||
|
||||
# !gpt command
|
||||
async def gpt(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
|
||||
try:
|
||||
# sending typing state
|
||||
await self.client.room_typing(room_id)
|
||||
await self.client.room_typing(room_id, timeout=120000)
|
||||
# timeout 120s
|
||||
text = await asyncio.wait_for(ask(prompt, self.chatgpt_api_endpoint, self.headers), timeout=120)
|
||||
text = await asyncio.wait_for(self.askgpt.oneTimeAsk(prompt, self.chatgpt_api_endpoint, self.headers), timeout=120)
|
||||
except TimeoutError:
|
||||
logger.error("timeoutException", exc_info=True)
|
||||
text = "Timeout error"
|
||||
|
||||
text = text.strip()
|
||||
try:
|
||||
await send_room_message(self.client, room_id, send_text=text,
|
||||
reply_to_event_id=reply_to_event_id)
|
||||
await send_room_message(self.client, room_id, reply_message=text,
|
||||
reply_to_event_id=reply_to_event_id, sender_id=sender_id, user_message=raw_user_message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}", exc_info=True)
|
||||
|
||||
# !bing command
|
||||
async def bing(self, room_id, reply_to_event_id, prompt):
|
||||
async def bing(self, room_id, reply_to_event_id, prompt, sender_id, raw_content_body):
|
||||
try:
|
||||
# sending typing state
|
||||
await self.client.room_typing(room_id)
|
||||
await self.client.room_typing(room_id, timeout=120000)
|
||||
# timeout 120s
|
||||
text = await asyncio.wait_for(self.bingbot.ask_bing(prompt), timeout=120)
|
||||
except TimeoutError:
|
||||
|
@ -187,8 +454,8 @@ class Bot:
|
|||
text = "Timeout error"
|
||||
text = text.strip()
|
||||
try:
|
||||
await send_room_message(self.client, room_id, send_text=text,
|
||||
reply_to_event_id=reply_to_event_id)
|
||||
await send_room_message(self.client, room_id, reply_message=text,
|
||||
reply_to_event_id=reply_to_event_id, sender=sender_id, raw_content_body=raw_content_body)
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}", exc_info=True)
|
||||
|
||||
|
@ -216,7 +483,7 @@ class Bot:
|
|||
"!bing [content], chat with context conversation powered by Bing AI\n" + \
|
||||
"!pic [prompt], Image generation by Microsoft Bing"
|
||||
|
||||
await send_room_message(self.client, room_id, send_text=help_info)
|
||||
await send_room_message(self.client, room_id, reply_message=help_info)
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}", exc_info=True)
|
||||
|
||||
|
@ -232,6 +499,24 @@ class Bot:
|
|||
logger.error(f"Error: {e}", exc_info=True)
|
||||
|
||||
# sync messages in the room
|
||||
async def sync_forever(self, timeout=30000):
|
||||
self.client.add_event_callback(self.message_callback, RoomMessageText)
|
||||
await self.client.sync_forever(timeout=timeout, full_state=True)
|
||||
async def sync_forever(self, timeout=30000, full_state=True):
|
||||
await self.client.sync_forever(timeout=timeout, full_state=full_state)
|
||||
|
||||
# Sync encryption keys with the server
|
||||
async def sync_encryption_key(self):
|
||||
if self.client.should_upload_keys:
|
||||
await self.client.keys_upload()
|
||||
|
||||
# Trust own devices
|
||||
async def trust_own_devices(self):
|
||||
await self.client.sync(timeout=30000, full_state=True)
|
||||
for device_id, olm_device in self.client.device_store[
|
||||
self.user_id].items():
|
||||
logger.debug("My other devices are: "
|
||||
f"device_id={device_id}, "
|
||||
f"olm_device={olm_device}.")
|
||||
logger.info("Setting up trust for my own "
|
||||
f"device {device_id} and session key "
|
||||
f"{olm_device.keys['ed25519']}.")
|
||||
self.client.verify_device(olm_device)
|
||||
|
|
@ -6,10 +6,13 @@ services:
|
|||
# build:
|
||||
# context: .
|
||||
# dockerfile: ./Dockerfile
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./config.json:/app/config.json
|
||||
# use touch to create an empty file bot, for persist database only
|
||||
- ./bot:/app/bot
|
||||
# use env file or config.json
|
||||
# - ./config.json:/app/config.json
|
||||
# use touch to create an empty file db, for persist database only
|
||||
- ./db:/app/db
|
||||
networks:
|
||||
- matrix_network
|
||||
# api:
|
||||
|
|
51
main.py
51
main.py
|
@ -1,30 +1,53 @@
|
|||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from bot import Bot
|
||||
from nio import Api, SyncResponse
|
||||
from log import getlogger
|
||||
|
||||
logger = getlogger()
|
||||
|
||||
async def main():
|
||||
fp = open('config.json', 'r')
|
||||
if os.path.exists('config.json'):
|
||||
fp = open('config.json', 'r', encoding="utf8")
|
||||
config = json.load(fp)
|
||||
matrix_bot = Bot(homeserver=config['homeserver'],
|
||||
user_id=config['user_id'],
|
||||
password=config.get('password', ''), # provide a default value when the key does not exist
|
||||
device_id=config['device_id'],
|
||||
room_id=config.get('room_id', ''),
|
||||
api_key=config.get('api_key', ''),
|
||||
bing_api_endpoint=config.get('bing_api_endpoint', ''),
|
||||
access_token=config.get('access_token', ''),
|
||||
jailbreakEnabled=config.get('jailbreakEnabled', False),
|
||||
bing_auth_cookie=config.get('bing_auth_cookie', ''),
|
||||
|
||||
matrix_bot = Bot(homeserver=os.environ.get("HOMESERVER") or config.get('homeserver'),
|
||||
user_id=os.environ.get("USER_ID") or config.get('user_id') ,
|
||||
password=os.environ.get("PASSWORD") or config.get('password'),
|
||||
device_id=os.environ.get("DEVICE_ID") or config.get('device_id'),
|
||||
room_id=os.environ.get("ROOM_ID") or config.get('room_id'),
|
||||
api_key=os.environ.get("OPENAI_API_KEY") or config.get('api_key'),
|
||||
bing_api_endpoint=os.environ.get("BING_API_ENDPOINT") or config.get('bing_api_endpoint'),
|
||||
access_token=os.environ.get("ACCESS_TOKEN") or config.get('access_token'),
|
||||
jailbreakEnabled=os.environ.get("JAILBREAKENABLED", "False").lower() in ('true', '1') or config.get('jailbreakEnabled'),
|
||||
bing_auth_cookie=os.environ.get("BING_AUTH_COOKIE") or config.get('bing_auth_cookie'),
|
||||
)
|
||||
if config.get('access_token', '') == '':
|
||||
# if not set access_token, then login via password
|
||||
# if os.path.exists('config.json'):
|
||||
# fp = open('config.json', 'r', encoding="utf8")
|
||||
# config = json.load(fp)
|
||||
# if os.environ.get("ACCESS_TOKEN") is None and config.get("access_token") is None:
|
||||
# await matrix_bot.login()
|
||||
|
||||
# elif os.environ.get("ACCESS_TOKEN") is None:
|
||||
# await matrix_bot.login()
|
||||
|
||||
await matrix_bot.login()
|
||||
await matrix_bot.sync_forever()
|
||||
|
||||
# await matrix_bot.sync_encryption_key()
|
||||
|
||||
# await matrix_bot.trust_own_devices()
|
||||
|
||||
try:
|
||||
await matrix_bot.sync_forever(timeout=3000, full_state=True)
|
||||
finally:
|
||||
await matrix_bot.client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("matrix chatgpt bot start.....")
|
||||
logger.debug("matrix chatgpt bot start.....")
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
|
|
|
@ -1,41 +1,46 @@
|
|||
aiofiles==0.6.0
|
||||
aiofiles==23.1.0
|
||||
aiohttp==3.8.4
|
||||
aiohttp-socks==0.7.1
|
||||
aiosignal==1.3.1
|
||||
anyio==3.6.2
|
||||
async-timeout==4.0.2
|
||||
atomicwrites==1.4.1
|
||||
attrs==22.2.0
|
||||
blobfile==2.0.1
|
||||
cachetools==4.2.4
|
||||
cachetools==5.3.0
|
||||
certifi==2022.12.7
|
||||
cffi==1.15.1
|
||||
charset-normalizer==3.0.1
|
||||
filelock==3.9.0
|
||||
charset-normalizer==3.1.0
|
||||
filelock==3.10.7
|
||||
frozenlist==1.3.3
|
||||
future==0.18.3
|
||||
h11==0.12.0
|
||||
h11==0.14.0
|
||||
h2==4.1.0
|
||||
hpack==4.0.0
|
||||
httpcore==0.16.3
|
||||
httpx==0.23.3
|
||||
hyperframe==6.0.1
|
||||
idna==3.4
|
||||
jsonschema==4.17.3
|
||||
Logbook==1.5.3
|
||||
lxml==4.9.2
|
||||
matrix-nio==0.20.1
|
||||
matrix-nio[e2e]
|
||||
multidict==6.0.4
|
||||
peewee==3.16.0
|
||||
Pillow==9.4.0
|
||||
Pillow==9.5.0
|
||||
pycparser==2.21
|
||||
pycryptodome==3.17
|
||||
pycryptodomex==3.17
|
||||
pyrsistent==0.19.3
|
||||
python-magic==0.4.27
|
||||
python-olm==3.1.3
|
||||
python-socks==2.1.1
|
||||
regex==2022.10.31
|
||||
python-socks==2.2.0
|
||||
regex==2023.3.23
|
||||
requests==2.28.2
|
||||
tiktoken==0.3.0
|
||||
rfc3986==1.5.0
|
||||
sniffio==1.3.0
|
||||
unpaddedbase64==2.1.0
|
||||
urllib3==1.26.14
|
||||
urllib3==1.26.15
|
||||
wcwidth==0.2.6
|
||||
yarl==1.8.2
|
||||
python-olm >= '3.1.0'
|
||||
tiktoken==0.3.3
|
||||
|
|
|
@ -1,14 +1,21 @@
|
|||
from nio import AsyncClient
|
||||
|
||||
|
||||
async def send_room_message(client: AsyncClient,
|
||||
room_id: str,
|
||||
send_text: str,
|
||||
reply_message: str,
|
||||
sender_id: str = '',
|
||||
user_message: str = '',
|
||||
reply_to_event_id: str = '') -> None:
|
||||
if reply_to_event_id == '':
|
||||
content = {"msgtype": "m.text", "body": f"{send_text}", }
|
||||
content = {"msgtype": "m.text", "body": reply_message, }
|
||||
else:
|
||||
content={"msgtype": "m.text", "body": f"{send_text}",
|
||||
body = r'> <' + sender_id + r'> ' + user_message + r'\n\n' + reply_message
|
||||
format = r'org.matrix.custom.html'
|
||||
formatted_body = r'<mx-reply><blockquote><a href=\"https://matrix.to/#/' + room_id + r'/' + reply_to_event_id
|
||||
+ r'\">In reply to</a> <a href=\"https://matrix.to/#/' + sender_id + r'\">' + sender_id
|
||||
+ r'</a><br>' + user_message + r'</blockquote></mx-reply>' + reply_message
|
||||
|
||||
content={"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body,
|
||||
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
|
||||
await client.room_send(
|
||||
room_id,
|
||||
|
|
58
test.py
58
test.py
|
@ -1,58 +0,0 @@
|
|||
from v3 import Chatbot
|
||||
import asyncio
|
||||
from ask_gpt import ask
|
||||
import json
|
||||
from bing import BingBot
|
||||
|
||||
fp = open("config.json", "r")
|
||||
config = json.load(fp)
|
||||
api_key = config.get('api_key', '')
|
||||
bing_api_endpoint = config.get('bing_api_endpoint', '')
|
||||
api_endpoint_list = {
|
||||
"free": "https://chatgpt-api.shn.hk/v1/",
|
||||
"paid": "https://api.openai.com/v1/chat/completions"
|
||||
}
|
||||
|
||||
|
||||
def test_v3(prompt: str):
|
||||
bot = Chatbot(api_key=api_key)
|
||||
resp = bot.ask(prompt=prompt)
|
||||
print(resp)
|
||||
|
||||
|
||||
async def test_ask_gpt_paid(prompt: str):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + api_key,
|
||||
}
|
||||
api_endpoint = api_endpoint_list['paid']
|
||||
# test ask_gpt.py ask()
|
||||
print(await ask(prompt, api_endpoint, headers))
|
||||
|
||||
|
||||
async def test_ask_gpt_free(prompt: str):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
api_endpoint = api_endpoint_list['free']
|
||||
print(await ask(prompt, api_endpoint, headers))
|
||||
|
||||
|
||||
async def test_bingbot():
|
||||
if bing_api_endpoint != '':
|
||||
bingbot = BingBot(bing_api_endpoint)
|
||||
prompt1 = "Hello World"
|
||||
prompt2 = "Do you know Victor Marie Hugo"
|
||||
prompt3 = "Can you tell me something about him?"
|
||||
resp1 = await bingbot.ask_bing(prompt1)
|
||||
resp2 = await bingbot.ask_bing(prompt2)
|
||||
resp3 = await bingbot.ask_bing(prompt3)
|
||||
print(resp1)
|
||||
print(resp2)
|
||||
print(resp3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_v3("Hello World")
|
||||
asyncio.run(test_ask_gpt_paid("Hello World"))
|
||||
asyncio.run(test_ask_gpt_free("Hello World"))
|
||||
asyncio.run(test_bingbot())
|
250
v3.py
250
v3.py
|
@ -1,20 +1,12 @@
|
|||
"""
|
||||
A simple wrapper for the official ChatGPT API
|
||||
https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
import tiktoken
|
||||
|
||||
|
||||
ENGINE = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo"
|
||||
ENCODER = tiktoken.get_encoding("gpt2")
|
||||
|
||||
|
||||
class Chatbot:
|
||||
"""
|
||||
Official ChatGPT API
|
||||
|
@ -22,29 +14,63 @@ class Chatbot:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = None,
|
||||
engine: str = None,
|
||||
api_key: str,
|
||||
engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
|
||||
proxy: str = None,
|
||||
max_tokens: int = 4096,
|
||||
timeout: float = None,
|
||||
max_tokens: int = None,
|
||||
temperature: float = 0.5,
|
||||
top_p: float = 1.0,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
reply_count: int = 1,
|
||||
system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
||||
"""
|
||||
self.engine = engine or ENGINE
|
||||
self.engine: str = engine
|
||||
self.api_key: str = api_key
|
||||
self.system_prompt: str = system_prompt
|
||||
self.max_tokens: int = max_tokens or (
|
||||
31000 if engine == "gpt-4-32k" else 7000 if engine == "gpt-4" else 4000
|
||||
)
|
||||
self.truncate_limit: int = (
|
||||
30500 if engine == "gpt-4-32k" else 6500 if engine == "gpt-4" else 3500
|
||||
)
|
||||
self.temperature: float = temperature
|
||||
self.top_p: float = top_p
|
||||
self.presence_penalty: float = presence_penalty
|
||||
self.frequency_penalty: float = frequency_penalty
|
||||
self.reply_count: int = reply_count
|
||||
self.timeout: float = timeout
|
||||
self.proxy = proxy
|
||||
self.session = requests.Session()
|
||||
self.api_key = api_key
|
||||
# self.proxy = proxy
|
||||
# if self.proxy:
|
||||
# proxies = {
|
||||
# "http": self.proxy,
|
||||
# "https": self.proxy,
|
||||
# }
|
||||
# self.session.proxies = proxies
|
||||
self.conversation: dict = {
|
||||
self.session.proxies.update(
|
||||
{
|
||||
"http": proxy,
|
||||
"https": proxy,
|
||||
},
|
||||
)
|
||||
proxy = (
|
||||
proxy or os.environ.get("all_proxy") or os.environ.get("ALL_PROXY") or None
|
||||
)
|
||||
|
||||
if proxy:
|
||||
if "socks5h" not in proxy:
|
||||
self.aclient = httpx.AsyncClient(
|
||||
follow_redirects=True,
|
||||
proxies=proxy,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
self.aclient = httpx.AsyncClient(
|
||||
follow_redirects=True,
|
||||
proxies=proxy,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.conversation: dict[str, list[dict]] = {
|
||||
"default": [
|
||||
{
|
||||
"role": "system",
|
||||
|
@ -52,20 +78,13 @@ class Chatbot:
|
|||
},
|
||||
],
|
||||
}
|
||||
self.system_prompt = system_prompt
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.reply_count = reply_count
|
||||
|
||||
initial_conversation = "\n".join(
|
||||
[x["content"] for x in self.conversation["default"]],
|
||||
)
|
||||
if len(ENCODER.encode(initial_conversation)) > self.max_tokens:
|
||||
raise Exception("System prompt is too long")
|
||||
|
||||
def add_to_conversation(
|
||||
self, message: str, role: str, convo_id: str = "default"
|
||||
self,
|
||||
message: str,
|
||||
role: str,
|
||||
convo_id: str = "default",
|
||||
) -> None:
|
||||
"""
|
||||
Add a message to the conversation
|
||||
|
@ -77,12 +96,8 @@ class Chatbot:
|
|||
Truncate the conversation
|
||||
"""
|
||||
while True:
|
||||
full_conversation = "".join(
|
||||
message["role"] + ": " + message["content"] + "\n"
|
||||
for message in self.conversation[convo_id]
|
||||
)
|
||||
if (
|
||||
len(ENCODER.encode(full_conversation)) > self.max_tokens
|
||||
self.get_token_count(convo_id) > self.truncate_limit
|
||||
and len(self.conversation[convo_id]) > 1
|
||||
):
|
||||
# Don't remove the first message
|
||||
|
@ -90,15 +105,41 @@ class Chatbot:
|
|||
else:
|
||||
break
|
||||
|
||||
|
||||
def get_token_count(self, convo_id: str = "default") -> int:
|
||||
"""
|
||||
Get token count
|
||||
"""
|
||||
if self.engine not in [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314",
|
||||
]:
|
||||
raise NotImplementedError("Unsupported engine {self.engine}")
|
||||
|
||||
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
|
||||
|
||||
encoding = tiktoken.encoding_for_model(self.engine)
|
||||
|
||||
num_tokens = 0
|
||||
for message in self.conversation[convo_id]:
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
num_tokens += 5
|
||||
for key, value in message.items():
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name": # if there's a name, the role is omitted
|
||||
num_tokens += 5 # role is always required and always 1 token
|
||||
num_tokens += 5 # every reply is primed with <im_start>assistant
|
||||
return num_tokens
|
||||
|
||||
def get_max_tokens(self, convo_id: str) -> int:
|
||||
"""
|
||||
Get max tokens
|
||||
"""
|
||||
full_conversation = "".join(
|
||||
message["role"] + ": " + message["content"] + "\n"
|
||||
for message in self.conversation[convo_id]
|
||||
)
|
||||
return 4000 - len(ENCODER.encode(full_conversation))
|
||||
return self.max_tokens - self.get_token_count(convo_id)
|
||||
|
||||
def ask_stream(
|
||||
self,
|
||||
|
@ -106,7 +147,7 @@ class Chatbot:
|
|||
role: str = "user",
|
||||
convo_id: str = "default",
|
||||
**kwargs,
|
||||
) -> str:
|
||||
):
|
||||
"""
|
||||
Ask a question
|
||||
"""
|
||||
|
@ -117,7 +158,7 @@ class Chatbot:
|
|||
self.__truncate_conversation(convo_id=convo_id)
|
||||
# Get response
|
||||
response = self.session.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||
json={
|
||||
"model": self.engine,
|
||||
|
@ -126,16 +167,22 @@ class Chatbot:
|
|||
# kwargs
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
"top_p": kwargs.get("top_p", self.top_p),
|
||||
"presence_penalty": kwargs.get(
|
||||
"presence_penalty",
|
||||
self.presence_penalty,
|
||||
),
|
||||
"frequency_penalty": kwargs.get(
|
||||
"frequency_penalty",
|
||||
self.frequency_penalty,
|
||||
),
|
||||
"n": kwargs.get("n", self.reply_count),
|
||||
"user": role,
|
||||
# "max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||
"max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||
},
|
||||
timeout=kwargs.get("timeout", self.timeout),
|
||||
stream=True,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Error: {response.status_code} {response.reason} {response.text}",
|
||||
)
|
||||
|
||||
response_role: str = None
|
||||
full_response: str = ""
|
||||
for line in response.iter_lines():
|
||||
|
@ -160,8 +207,100 @@ class Chatbot:
|
|||
yield content
|
||||
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
|
||||
|
||||
async def ask_stream_async(
|
||||
self,
|
||||
prompt: str,
|
||||
role: str = "user",
|
||||
convo_id: str = "default",
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Ask a question
|
||||
"""
|
||||
# Make conversation if it doesn't exist
|
||||
if convo_id not in self.conversation:
|
||||
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
|
||||
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
||||
self.__truncate_conversation(convo_id=convo_id)
|
||||
# Get response
|
||||
async with self.aclient.stream(
|
||||
"post",
|
||||
os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||
json={
|
||||
"model": self.engine,
|
||||
"messages": self.conversation[convo_id],
|
||||
"stream": True,
|
||||
# kwargs
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
"top_p": kwargs.get("top_p", self.top_p),
|
||||
"presence_penalty": kwargs.get(
|
||||
"presence_penalty",
|
||||
self.presence_penalty,
|
||||
),
|
||||
"frequency_penalty": kwargs.get(
|
||||
"frequency_penalty",
|
||||
self.frequency_penalty,
|
||||
),
|
||||
"n": kwargs.get("n", self.reply_count),
|
||||
"user": role,
|
||||
"max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||
},
|
||||
timeout=kwargs.get("timeout", self.timeout),
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
await response.aread()
|
||||
|
||||
response_role: str = ""
|
||||
full_response: str = ""
|
||||
async for line in response.aiter_lines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Remove "data: "
|
||||
line = line[6:]
|
||||
if line == "[DONE]":
|
||||
break
|
||||
resp: dict = json.loads(line)
|
||||
choices = resp.get("choices")
|
||||
if not choices:
|
||||
continue
|
||||
delta: dict[str, str] = choices[0].get("delta")
|
||||
if not delta:
|
||||
continue
|
||||
if "role" in delta:
|
||||
response_role = delta["role"]
|
||||
if "content" in delta:
|
||||
content: str = delta["content"]
|
||||
full_response += content
|
||||
yield content
|
||||
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
|
||||
|
||||
async def ask_async(
|
||||
self,
|
||||
prompt: str,
|
||||
role: str = "user",
|
||||
convo_id: str = "default",
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Non-streaming ask
|
||||
"""
|
||||
response = self.ask_stream_async(
|
||||
prompt=prompt,
|
||||
role=role,
|
||||
convo_id=convo_id,
|
||||
**kwargs,
|
||||
)
|
||||
full_response: str = "".join([r async for r in response])
|
||||
return full_response
|
||||
|
||||
def ask(
|
||||
self, prompt: str, role: str = "user", convo_id: str = "default", **kwargs
|
||||
self,
|
||||
prompt: str,
|
||||
role: str = "user",
|
||||
convo_id: str = "default",
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Non-streaming ask
|
||||
|
@ -175,12 +314,6 @@ class Chatbot:
|
|||
full_response: str = "".join(response)
|
||||
return full_response
|
||||
|
||||
def rollback(self, n: int = 1, convo_id: str = "default") -> None:
|
||||
"""
|
||||
Rollback the conversation
|
||||
"""
|
||||
for _ in range(n):
|
||||
self.conversation[convo_id].pop()
|
||||
|
||||
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
|
||||
"""
|
||||
|
@ -189,4 +322,3 @@ class Chatbot:
|
|||
self.conversation[convo_id] = [
|
||||
{"role": "system", "content": system_prompt or self.system_prompt},
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in a new issue