diff --git a/.dockerignore b/.dockerignore
index ce10b6f..66cc0ee 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -5,12 +5,17 @@ Dockerfile
.dockerignore
config.json
config.json.sample
-bot
+db
bot.log
venv
-compose.yaml
+.venv
+*.yaml
+*.yml
.git
.idea
__pycache__
-venv
+.env
+.env.example
+.github
+settings.js
diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..a0dfd51
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,10 @@
+HOMESERVER="https://matrix.xxxxxx.xxxx" # required
+USER_ID="@lullap:xxxxxxxxxxxxx.xxx" # required
+PASSWORD="xxxxxxxxxxxxxxx" # required
+DEVICE_ID="xxxxxxxxxxxxxx" # required
+ROOM_ID="!FYCmBSkCRUXXXXXXXXX:matrix.XXX.XXX" # Optional, if left blank the bot will work in whichever room it is in (unencrypted rooms only for now)
+OPENAI_API_KEY="xxxxxxxxxxxxxxxxx" # Optional, for the !chat and !gpt commands
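+CHATGPT_API_ENDPOINT="https://api.openai.com/v1/chat/completions" # Optional, custom ChatGPT API endpoint; defaults to the official OpenAI chat completions URL (read in bot.py)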
+BING_API_ENDPOINT="xxxxxxxxxxxxxxx" # Optional, for the !bing command
+ACCESS_TOKEN="xxxxxxxxxxxxxxxxxxxxx" # Optional, using user_id and password instead is recommended
+JAILBREAKENABLED="true" # Optional
+BING_AUTH_COOKIE="xxxxxxxxxxxxxxxxxxx" # Optional, the _U cookie, used by Bing Image Creator (!pic command)
diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml
index 18e1252..2859316 100644
--- a/.github/workflows/docker-release.yml
+++ b/.github/workflows/docker-release.yml
@@ -1,51 +1,51 @@
-name: Publish Docker image
-
-on:
- workflow_dispatch:
- release:
- types: [published]
-
-jobs:
- push_to_registry:
- name: Push Docker image to Docker Hub
- runs-on: ubuntu-latest
- steps:
- -
- name: Check out the repo
- uses: actions/checkout@v3
- -
- name: Log in to Docker Hub
- uses: docker/login-action@v2
- with:
- username: ${{ secrets.DOCKER_USERNAME }}
- password: ${{ secrets.DOCKER_PASSWORD }}
-
- -
- name: Docker metadata
- id: meta
- uses: docker/metadata-action@v4
- with:
- images: hibobmaster/matrixchatgptbot
- tags: |
- type=raw,value=latest
- type=ref,event=tag
-
- -
- name: Set up QEMU
- uses: docker/setup-qemu-action@v2
-
- -
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v2
-
- -
- name: Build and push Docker image
- uses: docker/build-push-action@v4
- with:
- context: .
- platforms: linux/386,linux/amd64,linux/arm/v6,linux/arm/v7,linux/arm64/v8,linux/ppc64le,linux/s390x
- push: true
- tags: ${{ steps.meta.outputs.tags }}
- labels: ${{ steps.meta.outputs.labels }}
- cache-from: type=gha
+name: Publish Docker image
+
+on:
+ workflow_dispatch:
+ release:
+ types: [published]
+
+jobs:
+ push_to_registry:
+ name: Push Docker image to Docker Hub
+ runs-on: ubuntu-latest
+ steps:
+ -
+ name: Check out the repo
+ uses: actions/checkout@v3
+ -
+ name: Log in to Docker Hub
+ uses: docker/login-action@v2
+ with:
+ username: ${{ secrets.DOCKER_USERNAME }}
+ password: ${{ secrets.DOCKER_PASSWORD }}
+
+ -
+ name: Docker metadata
+ id: meta
+ uses: docker/metadata-action@v4
+ with:
+ images: hibobmaster/matrixchatgptbot
+ tags: |
+ type=raw,value=latest
+ type=ref,event=tag
+
+ -
+ name: Set up QEMU
+ uses: docker/setup-qemu-action@v2
+
+ -
+ name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v2
+
+ -
+ name: Build and push Docker image
+ uses: docker/build-push-action@v4
+ with:
+ context: .
+ platforms: linux/386,linux/amd64,linux/arm/v6,linux/arm/v7,linux/arm64/v8,linux/ppc64le,linux/s390x
+ push: true
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
+ cache-from: type=gha
cache-to: type=gha,mode=max
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index e05db87..f4817b8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,7 +26,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
-bot
+db
bot.log
# image generation folder
@@ -135,6 +135,7 @@ env.bak/
venv.bak/
config.json
compose_local_build.yaml
+settings.js
# Spyder project settings
.spyderproject
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..a6735e5
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+ "python.analysis.typeCheckingMode": "off"
+}
\ No newline at end of file
diff --git a/BingImageGen.py b/BingImageGen.py
index 65c83a0..76c4936 100644
--- a/BingImageGen.py
+++ b/BingImageGen.py
@@ -105,16 +105,5 @@ class ImageGen:
with open(f"{output_dir}/{image_name}.jpeg", "wb") as output_file:
for chunk in response.iter_content(chunk_size=8192):
output_file.write(chunk)
- # image_num = 0
- # for link in links:
- # with self.session.get(link, stream=True) as response:
- # # save response to file
- # response.raise_for_status()
- # with open(f"{output_dir}/{image_num}.jpeg", "wb") as output_file:
- # for chunk in response.iter_content(chunk_size=8192):
- # output_file.write(chunk)
- #
- # image_num += 1
- # return image path
return f"{output_dir}/{image_name}.jpeg"
diff --git a/Dockerfile b/Dockerfile
index d4755fb..afe6df6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM python:3.11-alpine as base
FROM base as pybuilder
RUN sed -i 's|v3\.\d*|edge|' /etc/apk/repositories
-RUN apk update && apk add olm-dev gcc musl-dev tzdata libmagic
+RUN apk update && apk add olm-dev gcc musl-dev libmagic
COPY requirements.txt /requirements.txt
RUN pip3 install --user -r /requirements.txt && rm /requirements.txt
@@ -11,12 +11,10 @@ FROM base as runner
LABEL "org.opencontainers.image.source"="https://github.com/hibobmaster/matrix_chatgpt_bot"
RUN apk update && apk add olm-dev libmagic
COPY --from=pybuilder /root/.local /usr/local
-COPY --from=pybuilder /usr/share/zoneinfo /usr/share/zoneinfo
COPY . /app
FROM runner
-ENV TZ=Asia/Shanghai
WORKDIR /app
CMD ["python", "main.py"]
diff --git a/README.md b/README.md
index 051910b..414a97e 100644
--- a/README.md
+++ b/README.md
@@ -3,11 +3,11 @@ This is a simple Matrix bot that uses OpenAI's GPT API and Bing AI to generate r
![demo](https://i.imgur.com/kK4rnPf.jpeg "demo")
## Installation and Setup
-Docker method:
-Edit `config.json` with proper values
+Docker method (Recommended):
+Edit `config.json` or `.env` with proper values
Create an empty file, for persist database only
```bash
-touch bot
+touch db
sudo docker compose up -d
```
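+
+If you use the `.env` file, copy the provided example and fill in the values before starting the container:
+```bash
+cp .env.example .env
+```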
diff --git a/ask_gpt.py b/askgpt.py
similarity index 64%
rename from ask_gpt.py
rename to askgpt.py
index 299741c..dd25ef7 100644
--- a/ask_gpt.py
+++ b/askgpt.py
@@ -4,22 +4,24 @@ import json
from log import getlogger
logger = getlogger()
+class askGPT:
+ def __init__(self):
+ self.session = aiohttp.ClientSession()
-async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
- jsons = {
- "model": "gpt-3.5-turbo",
- "messages": [
- {
- "role": "user",
- "content": prompt,
- },
- ],
- }
- async with aiohttp.ClientSession() as session:
+ async def oneTimeAsk(self, prompt: str, api_endpoint: str, headers: dict) -> str:
+ jsons = {
+ "model": "gpt-3.5-turbo",
+ "messages": [
+ {
+ "role": "user",
+ "content": prompt,
+ },
+ ],
+ }
max_try = 5
while max_try > 0:
try:
- async with session.post(url=api_endpoint,
+ async with self.session.post(url=api_endpoint,
json=jsons, headers=headers, timeout=30) as response:
status_code = response.status
if not status_code == 200:
@@ -31,9 +33,6 @@ async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
continue
resp = await response.read()
- await session.close()
return json.loads(resp)['choices'][0]['message']['content']
except Exception as e:
logger.error("Error Exception", exc_info=True)
- print(e)
- pass
diff --git a/bing.py b/bing.py
index 7fa7f1f..643b648 100644
--- a/bing.py
+++ b/bing.py
@@ -4,52 +4,49 @@ import asyncio
from log import getlogger
# api_endpoint = "http://localhost:3000/conversation"
logger = getlogger()
-python_boolean_to_json = {
- "true": True,
-}
class BingBot:
def __init__(self, bing_api_endpoint: str, jailbreakEnabled: bool = False):
self.data = {
- # 'jailbreakConversationId': json.dumps(python_boolean_to_json['true']),
'clientOptions.clientToUse': 'bing',
}
self.bing_api_endpoint = bing_api_endpoint
+ self.session = aiohttp.ClientSession()
+
self.jailbreakEnabled = jailbreakEnabled
if self.jailbreakEnabled:
- self.data['jailbreakConversationId'] = json.dumps(python_boolean_to_json['true'])
+ self.data['jailbreakConversationId'] = json.dumps(True)
async def ask_bing(self, prompt) -> str:
self.data['message'] = prompt
- async with aiohttp.ClientSession() as session:
- max_try = 5
- while max_try > 0:
- try:
- resp = await session.post(url=self.bing_api_endpoint, json=self.data)
- status_code = resp.status
- body = await resp.read()
- if not status_code == 200:
- # print failed reason
- logger.warning(str(resp.reason))
- max_try = max_try - 1
- # print(await resp.text())
- await asyncio.sleep(2)
- continue
- json_body = json.loads(body)
- if self.jailbreakEnabled:
- self.data['jailbreakConversationId'] = json_body['jailbreakConversationId']
- self.data['parentMessageId'] = json_body['messageId']
- else:
- self.data['conversationSignature'] = json_body['conversationSignature']
- self.data['conversationId'] = json_body['conversationId']
- self.data['clientId'] = json_body['clientId']
- self.data['invocationId'] = json_body['invocationId']
- return json_body['response']
- except Exception as e:
- logger.error("Error Exception", exc_info=True)
- print(f"Error: {e}")
- pass
- return "Error, please retry"
+ max_try = 5
+ while max_try > 0:
+ try:
+ resp = await self.session.post(url=self.bing_api_endpoint, json=self.data)
+ status_code = resp.status
+ body = await resp.read()
+ if not status_code == 200:
+ # print failed reason
+ logger.warning(str(resp.reason))
+ max_try = max_try - 1
+ # print(await resp.text())
+ await asyncio.sleep(2)
+ continue
+ json_body = json.loads(body)
+ if self.jailbreakEnabled:
+ self.data['jailbreakConversationId'] = json_body['jailbreakConversationId']
+ self.data['parentMessageId'] = json_body['messageId']
+ else:
+ self.data['conversationSignature'] = json_body['conversationSignature']
+ self.data['conversationId'] = json_body['conversationId']
+ self.data['clientId'] = json_body['clientId']
+ self.data['invocationId'] = json_body['invocationId']
+ return json_body['response']
+ except Exception as e:
+ logger.error("Error Exception", exc_info=True)
+ print(f"Error: {e}")
+ pass
+ return "Error, please retry"
diff --git a/bot.py b/bot.py
index c28d9a4..1a95f27 100644
--- a/bot.py
+++ b/bot.py
@@ -2,23 +2,33 @@ import sys
import asyncio
import re
import os
-from typing import Optional
-from nio import AsyncClient, MatrixRoom, RoomMessageText, LoginResponse, AsyncClientConfig
+import traceback
+from typing import Optional, Union
+from nio import (
+ AsyncClient,
+ MatrixRoom,
+ RoomMessageText,
+ InviteMemberEvent,
+ LoginResponse,
+ JoinError,
+ ToDeviceError,
+ LocalProtocolError,
+ KeyVerificationEvent,
+ KeyVerificationStart,
+ KeyVerificationCancel,
+ KeyVerificationKey,
+ KeyVerificationMac,
+ AsyncClientConfig
+ )
from nio.store.database import SqliteStore
-from ask_gpt import ask
+from askgpt import askGPT
from send_message import send_room_message
from v3 import Chatbot
from log import getlogger
from bing import BingBot
from BingImageGen import ImageGen
from send_image import send_room_image
-"""
-free api_endpoint from https://github.com/ayaka14732/ChatGPTAPIFree
-"""
-chatgpt_api_endpoint_list = {
- "free": "https://chatgpt-api.shn.hk/v1/",
- "paid": "https://api.openai.com/v1/chat/completions"
-}
+
logger = getlogger()
@@ -27,35 +37,45 @@ class Bot:
self,
homeserver: str,
user_id: str,
- password: str,
device_id: str,
- api_key: Optional[str] = "",
+ chatgpt_api_endpoint: str = os.environ.get("CHATGPT_API_ENDPOINT") or "https://api.openai.com/v1/chat/completions",
+ api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") or "",
room_id: Optional[str] = '',
bing_api_endpoint: Optional[str] = '',
- access_token: Optional[str] = '',
- jailbreakEnabled: Optional[bool] = False,
+ password: Union[str, None] = None,
+ access_token: Union[str, None] = None,
+ jailbreakEnabled: Optional[bool] = True,
bing_auth_cookie: Optional[str] = '',
):
self.homeserver = homeserver
self.user_id = user_id
self.password = password
+ self.access_token = access_token
self.device_id = device_id
self.room_id = room_id
self.api_key = api_key
+ self.chatgpt_api_endpoint = chatgpt_api_endpoint
self.bing_api_endpoint = bing_api_endpoint
self.jailbreakEnabled = jailbreakEnabled
self.bing_auth_cookie = bing_auth_cookie
# initialize AsyncClient object
self.store_path = os.getcwd()
self.config = AsyncClientConfig(store=SqliteStore,
- store_name="bot",
+ store_name="db",
store_sync_tokens=True,
encryption_enabled=True,
)
- self.client = AsyncClient(self.homeserver, user=self.user_id, device_id=self.device_id,
+ self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, device_id=self.device_id,
config=self.config, store_path=self.store_path,)
- if access_token != '':
- self.client.access_token = access_token
+
+ if self.access_token is not None:
+ self.client.access_token = self.access_token
+
+ # setup event callbacks
+ self.client.add_event_callback(self.message_callback, (RoomMessageText, ))
+ self.client.add_event_callback(self.invite_callback, (InviteMemberEvent, ))
+ self.client.add_to_device_callback(self.to_device_callback, (KeyVerificationEvent, ))
+
# regular expression to match keyword [!gpt {prompt}] [!chat {prompt}]
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
@@ -67,18 +87,21 @@ class Bot:
if self.api_key != '':
self.chatbot = Chatbot(api_key=self.api_key)
- self.chatgpt_api_endpoint = chatgpt_api_endpoint_list['paid']
+ self.chatgpt_api_endpoint = self.chatgpt_api_endpoint
# request header for !gpt command
self.headers = {
"Content-Type": "application/json",
- "Authorization": "Bearer " + self.api_key,
+ "Authorization": f"Bearer {self.api_key}",
}
else:
- self.chatgpt_api_endpoint = chatgpt_api_endpoint_list['free']
+ self.chatgpt_api_endpoint = self.chatgpt_api_endpoint
self.headers = {
"Content-Type": "application/json",
}
+ # initialize askGPT class
+ self.askgpt = askGPT()
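+        # self.askgpt serves the one-shot !gpt command, while self.chatbot
+        # (v3.Chatbot) keeps per-conversation context for the !chat command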
+
# initialize bingbot
if self.bing_api_endpoint != '':
self.bingbot = BingBot(bing_api_endpoint, jailbreakEnabled=self.jailbreakEnabled)
@@ -100,86 +123,330 @@ class Bot:
# reply event_id
reply_to_event_id = event.event_id
+ # sender_id
+ sender_id = event.sender
+
+ # user_message
+ raw_user_message = event.body
+
# print info to console
print(
f"Message received in room {room.display_name}\n"
- f"{room.user_name(event.sender)} | {event.body}"
+ f"{room.user_name(event.sender)} | {raw_user_message}"
)
+ # prevent command trigger loop
if self.user_id != event.sender:
# remove newline character from event.body
- event.body = re.sub("\r\n|\r|\n", " ", event.body)
+ content_body = re.sub("\r\n|\r|\n", " ", raw_user_message)
# chatgpt
- n = self.chat_prog.match(event.body)
+ n = self.chat_prog.match(content_body)
if n:
prompt = n.group(1)
if self.api_key != '':
- await self.gpt(room_id, reply_to_event_id, prompt)
+ await self.chat(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
else:
logger.warning("No API_KEY provided")
await send_room_message(self.client, room_id, send_text="API_KEY not provided")
- m = self.gpt_prog.match(event.body)
+ m = self.gpt_prog.match(content_body)
if m:
prompt = m.group(1)
- await self.chat(room_id, reply_to_event_id, prompt)
+ await self.gpt(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
# bing ai
if self.bing_api_endpoint != '':
- b = self.bing_prog.match(event.body)
+ b = self.bing_prog.match(content_body)
if b:
prompt = b.group(1)
- await self.bing(room_id, reply_to_event_id, prompt)
+ # raw_content_body used for construct formatted_body
+ await self.bing(room_id, reply_to_event_id, prompt, sender_id, raw_user_message)
# Image Generation by Microsoft Bing
if self.bing_auth_cookie != '':
- i = self.pic_prog.match(event.body)
+ i = self.pic_prog.match(content_body)
if i:
prompt = i.group(1)
await self.pic(room_id, prompt)
# help command
- h = self.help_prog.match(event.body)
+ h = self.help_prog.match(content_body)
if h:
await self.help(room_id)
- # !gpt command
- async def gpt(self, room_id, reply_to_event_id, prompt):
- await self.client.room_typing(room_id)
+ # invite_callback event
+ async def invite_callback(self, room: MatrixRoom, event: InviteMemberEvent) -> None:
+ """Handle an incoming invite event.
+        If an invite is received, then join the room specified in the invite.
+        Code copied from:
+        https://github.com/8go/matrix-eno-bot/blob/ad037e02bd2960941109e9526c1033dd157bb212/callbacks.py#L104
+ """
+ logger.debug(f"Got invite to {room.room_id} from {event.sender}.")
+ # Attempt to join 3 times before giving up
+ for attempt in range(3):
+ result = await self.client.join(room.room_id)
+ if type(result) == JoinError:
+ logger.error(
+                    f"Error joining room {room.room_id} (attempt {attempt}): {result.message}",
+ )
+ else:
+ break
+ else:
+ logger.error("Unable to join room: %s", room.room_id)
+
+ # Successfully joined room
+ logger.info(f"Joined {room.room_id}")
+
+ # to_device_callback event
+ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
+ """Handle events sent to device.
+
+ Specifically this will perform Emoji verification.
+        It will accept incoming Emoji verification requests
+ and follow the verification protocol.
+ code copied from: https://github.com/8go/matrix-eno-bot/blob/ad037e02bd2960941109e9526c1033dd157bb212/callbacks.py#L127
+ """
try:
- # run synchronous function in different thread
- text = await asyncio.to_thread(self.chatbot.ask, prompt)
- text = text.strip()
- await send_room_message(self.client, room_id, send_text=text,
- reply_to_event_id=reply_to_event_id)
+ client = self.client
+ logger.debug(
+ f"Device Event of type {type(event)} received in "
+ "to_device_cb().")
+
+ if isinstance(event, KeyVerificationStart): # first step
+ """ first step: receive KeyVerificationStart
+ KeyVerificationStart(
+ source={'content':
+ {'method': 'm.sas.v1',
+ 'from_device': 'DEVICEIDXY',
+ 'key_agreement_protocols':
+ ['curve25519-hkdf-sha256', 'curve25519'],
+ 'hashes': ['sha256'],
+ 'message_authentication_codes':
+ ['hkdf-hmac-sha256', 'hmac-sha256'],
+ 'short_authentication_string':
+ ['decimal', 'emoji'],
+ 'transaction_id': 'SomeTxId'
+ },
+ 'type': 'm.key.verification.start',
+ 'sender': '@user2:example.org'
+ },
+ sender='@user2:example.org',
+ transaction_id='SomeTxId',
+ from_device='DEVICEIDXY',
+ method='m.sas.v1',
+ key_agreement_protocols=[
+ 'curve25519-hkdf-sha256', 'curve25519'],
+ hashes=['sha256'],
+ message_authentication_codes=[
+ 'hkdf-hmac-sha256', 'hmac-sha256'],
+ short_authentication_string=['decimal', 'emoji'])
+ """
+
+ if "emoji" not in event.short_authentication_string:
+ estr = ("Other device does not support emoji verification "
+ f"{event.short_authentication_string}. Aborting.")
+ print(estr)
+ logger.info(estr)
+ return
+ resp = await client.accept_key_verification(
+ event.transaction_id)
+ if isinstance(resp, ToDeviceError):
+ estr = f"accept_key_verification() failed with {resp}"
+ print(estr)
+ logger.info(estr)
+
+ sas = client.key_verifications[event.transaction_id]
+
+ todevice_msg = sas.share_key()
+ resp = await client.to_device(todevice_msg)
+ if isinstance(resp, ToDeviceError):
+ estr = f"to_device() failed with {resp}"
+ print(estr)
+ logger.info(estr)
+
+ elif isinstance(event, KeyVerificationCancel): # anytime
+ """ at any time: receive KeyVerificationCancel
+ KeyVerificationCancel(source={
+ 'content': {'code': 'm.mismatched_sas',
+ 'reason': 'Mismatched authentication string',
+ 'transaction_id': 'SomeTxId'},
+ 'type': 'm.key.verification.cancel',
+ 'sender': '@user2:example.org'},
+ sender='@user2:example.org',
+ transaction_id='SomeTxId',
+ code='m.mismatched_sas',
+ reason='Mismatched short authentication string')
+ """
+
+ # There is no need to issue a
+ # client.cancel_key_verification(tx_id, reject=False)
+ # here. The SAS flow is already cancelled.
+ # We only need to inform the user.
+ estr = (f"Verification has been cancelled by {event.sender} "
+ f"for reason \"{event.reason}\".")
+ print(estr)
+ logger.info(estr)
+
+ elif isinstance(event, KeyVerificationKey): # second step
+ """ Second step is to receive KeyVerificationKey
+ KeyVerificationKey(
+ source={'content': {
+ 'key': 'SomeCryptoKey',
+ 'transaction_id': 'SomeTxId'},
+ 'type': 'm.key.verification.key',
+ 'sender': '@user2:example.org'
+ },
+ sender='@user2:example.org',
+ transaction_id='SomeTxId',
+ key='SomeCryptoKey')
+ """
+ sas = client.key_verifications[event.transaction_id]
+
+ print(f"{sas.get_emoji()}")
+ # don't log the emojis
+
+                # The bot process must run in the foreground with a screen and
+                # keyboard so that the user can accept/reject via keyboard.
+                # For emoji verification the bot must not run as a service or
+                # in the background.
+ yn = input("Do the emojis match? (Y/N) (C for Cancel) ")
+ if yn.lower() == "y":
+ estr = ("Match! The verification for this "
+ "device will be accepted.")
+ print(estr)
+ logger.info(estr)
+ resp = await client.confirm_short_auth_string(
+ event.transaction_id)
+ if isinstance(resp, ToDeviceError):
+ estr = ("confirm_short_auth_string() "
+ f"failed with {resp}")
+ print(estr)
+ logger.info(estr)
+ elif yn.lower() == "n": # no, don't match, reject
+ estr = ("No match! Device will NOT be verified "
+ "by rejecting verification.")
+ print(estr)
+ logger.info(estr)
+ resp = await client.cancel_key_verification(
+ event.transaction_id, reject=True)
+ if isinstance(resp, ToDeviceError):
+ estr = (f"cancel_key_verification failed with {resp}")
+ print(estr)
+ logger.info(estr)
+ else: # C or anything for cancel
+ estr = ("Cancelled by user! Verification will be "
+ "cancelled.")
+ print(estr)
+ logger.info(estr)
+ resp = await client.cancel_key_verification(
+ event.transaction_id, reject=False)
+ if isinstance(resp, ToDeviceError):
+ estr = (f"cancel_key_verification failed with {resp}")
+ print(estr)
+ logger.info(estr)
+
+ elif isinstance(event, KeyVerificationMac): # third step
+ """ Third step is to receive KeyVerificationMac
+ KeyVerificationMac(
+ source={'content': {
+ 'mac': {'ed25519:DEVICEIDXY': 'SomeKey1',
+ 'ed25519:SomeKey2': 'SomeKey3'},
+ 'keys': 'SomeCryptoKey4',
+ 'transaction_id': 'SomeTxId'},
+ 'type': 'm.key.verification.mac',
+ 'sender': '@user2:example.org'},
+ sender='@user2:example.org',
+ transaction_id='SomeTxId',
+ mac={'ed25519:DEVICEIDXY': 'SomeKey1',
+ 'ed25519:SomeKey2': 'SomeKey3'},
+ keys='SomeCryptoKey4')
+ """
+ sas = client.key_verifications[event.transaction_id]
+ try:
+ todevice_msg = sas.get_mac()
+ except LocalProtocolError as e:
+ # e.g. it might have been cancelled by ourselves
+ estr = (f"Cancelled or protocol error: Reason: {e}.\n"
+ f"Verification with {event.sender} not concluded. "
+ "Try again?")
+ print(estr)
+ logger.info(estr)
+ else:
+ resp = await client.to_device(todevice_msg)
+ if isinstance(resp, ToDeviceError):
+ estr = f"to_device failed with {resp}"
+ print(estr)
+ logger.info(estr)
+ estr = (f"sas.we_started_it = {sas.we_started_it}\n"
+ f"sas.sas_accepted = {sas.sas_accepted}\n"
+ f"sas.canceled = {sas.canceled}\n"
+ f"sas.timed_out = {sas.timed_out}\n"
+ f"sas.verified = {sas.verified}\n"
+ f"sas.verified_devices = {sas.verified_devices}\n")
+ print(estr)
+ logger.info(estr)
+ estr = ("Emoji verification was successful!\n"
+ "Initiate another Emoji verification from "
+ "another device or room if desired. "
+ "Or if done verifying, hit Control-C to stop the "
+ "bot in order to restart it as a service or to "
+ "run it in the background.")
+ print(estr)
+ logger.info(estr)
+ else:
+ estr = (f"Received unexpected event type {type(event)}. "
+ f"Event is {event}. Event will be ignored.")
+ print(estr)
+ logger.info(estr)
+ except BaseException:
+ estr = traceback.format_exc()
+ print(estr)
+ logger.info(estr)
+
+ # !chat command
+ async def chat(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
+ await self.client.room_typing(room_id, timeout=120000)
+ try:
+ text = await asyncio.wait_for(self.chatbot.ask_async(prompt), timeout=120)
+ except TimeoutError as e:
+ logger.error("timeoutException", exc_info=True)
+ text = "Timeout error"
except Exception as e:
logger.error("Error", exc_info=True)
print(f"Error: {e}")
- # !chat command
- async def chat(self, room_id, reply_to_event_id, prompt):
+ text = text.strip()
+ try:
+ await send_room_message(self.client, room_id, reply_message=text,
+ reply_to_event_id=reply_to_event_id, sender_id=sender_id, user_message=raw_user_message)
+ except Exception as e:
+ logger.error(f"Error: {e}", exc_info=True)
+
+ # !gpt command
+ async def gpt(self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message):
try:
# sending typing state
- await self.client.room_typing(room_id)
+ await self.client.room_typing(room_id, timeout=120000)
# timeout 120s
- text = await asyncio.wait_for(ask(prompt, self.chatgpt_api_endpoint, self.headers), timeout=120)
+ text = await asyncio.wait_for(self.askgpt.oneTimeAsk(prompt, self.chatgpt_api_endpoint, self.headers), timeout=120)
except TimeoutError:
logger.error("timeoutException", exc_info=True)
text = "Timeout error"
text = text.strip()
try:
- await send_room_message(self.client, room_id, send_text=text,
- reply_to_event_id=reply_to_event_id)
+ await send_room_message(self.client, room_id, reply_message=text,
+ reply_to_event_id=reply_to_event_id, sender_id=sender_id, user_message=raw_user_message)
except Exception as e:
logger.error(f"Error: {e}", exc_info=True)
# !bing command
- async def bing(self, room_id, reply_to_event_id, prompt):
+ async def bing(self, room_id, reply_to_event_id, prompt, sender_id, raw_content_body):
try:
# sending typing state
- await self.client.room_typing(room_id)
+ await self.client.room_typing(room_id, timeout=120000)
# timeout 120s
text = await asyncio.wait_for(self.bingbot.ask_bing(prompt), timeout=120)
except TimeoutError:
@@ -187,8 +454,8 @@ class Bot:
text = "Timeout error"
text = text.strip()
try:
- await send_room_message(self.client, room_id, send_text=text,
- reply_to_event_id=reply_to_event_id)
+ await send_room_message(self.client, room_id, reply_message=text,
+                                    reply_to_event_id=reply_to_event_id, sender_id=sender_id, user_message=raw_content_body)
except Exception as e:
logger.error(f"Error: {e}", exc_info=True)
@@ -216,7 +483,7 @@ class Bot:
"!bing [content], chat with context conversation powered by Bing AI\n" + \
"!pic [prompt], Image generation by Microsoft Bing"
- await send_room_message(self.client, room_id, send_text=help_info)
+ await send_room_message(self.client, room_id, reply_message=help_info)
except Exception as e:
logger.error(f"Error: {e}", exc_info=True)
@@ -232,6 +499,24 @@ class Bot:
logger.error(f"Error: {e}", exc_info=True)
# sync messages in the room
- async def sync_forever(self, timeout=30000):
- self.client.add_event_callback(self.message_callback, RoomMessageText)
- await self.client.sync_forever(timeout=timeout, full_state=True)
+ async def sync_forever(self, timeout=30000, full_state=True):
+ await self.client.sync_forever(timeout=timeout, full_state=full_state)
+
+ # Sync encryption keys with the server
+ async def sync_encryption_key(self):
+ if self.client.should_upload_keys:
+ await self.client.keys_upload()
+
+ # Trust own devices
+ async def trust_own_devices(self):
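+        # mark every device on the bot's own account as verified in the local store,
+        # so encrypted rooms can be used without manually verifying each own device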
+ await self.client.sync(timeout=30000, full_state=True)
+ for device_id, olm_device in self.client.device_store[
+ self.user_id].items():
+ logger.debug("My other devices are: "
+ f"device_id={device_id}, "
+ f"olm_device={olm_device}.")
+ logger.info("Setting up trust for my own "
+ f"device {device_id} and session key "
+ f"{olm_device.keys['ed25519']}.")
+ self.client.verify_device(olm_device)
+
\ No newline at end of file
diff --git a/compose.yaml b/compose.yaml
index cf7ae30..1f7c56f 100644
--- a/compose.yaml
+++ b/compose.yaml
@@ -6,10 +6,13 @@ services:
# build:
# context: .
# dockerfile: ./Dockerfile
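+    # read bot credentials and API keys from .env (see .env.example)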
+ env_file:
+ - .env
volumes:
- - ./config.json:/app/config.json
- # use touch to create an empty file bot, for persist database only
- - ./bot:/app/bot
+ # use env file or config.json
+ # - ./config.json:/app/config.json
+      # use "touch db" to create an empty file, used only to persist the database
+ - ./db:/app/db
networks:
- matrix_network
# api:
diff --git a/main.py b/main.py
index 82b2375..477e301 100644
--- a/main.py
+++ b/main.py
@@ -1,30 +1,53 @@
#!/usr/bin/env python3
import json
+import os
import asyncio
from bot import Bot
+from nio import Api, SyncResponse
+from log import getlogger
+logger = getlogger()
async def main():
- fp = open('config.json', 'r')
- config = json.load(fp)
- matrix_bot = Bot(homeserver=config['homeserver'],
- user_id=config['user_id'],
- password=config.get('password', ''), # provide a default value when the key does not exist
- device_id=config['device_id'],
- room_id=config.get('room_id', ''),
- api_key=config.get('api_key', ''),
- bing_api_endpoint=config.get('bing_api_endpoint', ''),
- access_token=config.get('access_token', ''),
- jailbreakEnabled=config.get('jailbreakEnabled', False),
- bing_auth_cookie=config.get('bing_auth_cookie', ''),
+    # fall back to an empty dict so config.get() works when only env vars are used
+    config = {}
+    if os.path.exists('config.json'):
+        fp = open('config.json', 'r', encoding="utf8")
+        config = json.load(fp)
+
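+    # environment variables take precedence over values from config.json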
+ matrix_bot = Bot(homeserver=os.environ.get("HOMESERVER") or config.get('homeserver'),
+                     user_id=os.environ.get("USER_ID") or config.get('user_id'),
+ password=os.environ.get("PASSWORD") or config.get('password'),
+ device_id=os.environ.get("DEVICE_ID") or config.get('device_id'),
+ room_id=os.environ.get("ROOM_ID") or config.get('room_id'),
+ api_key=os.environ.get("OPENAI_API_KEY") or config.get('api_key'),
+ bing_api_endpoint=os.environ.get("BING_API_ENDPOINT") or config.get('bing_api_endpoint'),
+ access_token=os.environ.get("ACCESS_TOKEN") or config.get('access_token'),
+ jailbreakEnabled=os.environ.get("JAILBREAKENABLED", "False").lower() in ('true', '1') or config.get('jailbreakEnabled'),
+ bing_auth_cookie=os.environ.get("BING_AUTH_COOKIE") or config.get('bing_auth_cookie'),
)
- if config.get('access_token', '') == '':
- await matrix_bot.login()
- await matrix_bot.sync_forever()
+ # if not set access_token, then login via password
+ # if os.path.exists('config.json'):
+ # fp = open('config.json', 'r', encoding="utf8")
+ # config = json.load(fp)
+ # if os.environ.get("ACCESS_TOKEN") is None and config.get("access_token") is None:
+ # await matrix_bot.login()
+
+ # elif os.environ.get("ACCESS_TOKEN") is None:
+ # await matrix_bot.login()
+
+ await matrix_bot.login()
+
+ # await matrix_bot.sync_encryption_key()
+
+ # await matrix_bot.trust_own_devices()
+
+ try:
+ await matrix_bot.sync_forever(timeout=3000, full_state=True)
+ finally:
+ await matrix_bot.client.close()
if __name__ == "__main__":
- print("matrix chatgpt bot start.....")
+ logger.debug("matrix chatgpt bot start.....")
try:
loop = asyncio.get_running_loop()
except RuntimeError:
diff --git a/requirements.txt b/requirements.txt
index 8bef71b..47c538f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,41 +1,46 @@
-aiofiles==0.6.0
+aiofiles==23.1.0
aiohttp==3.8.4
aiohttp-socks==0.7.1
aiosignal==1.3.1
+anyio==3.6.2
async-timeout==4.0.2
atomicwrites==1.4.1
attrs==22.2.0
blobfile==2.0.1
-cachetools==4.2.4
+cachetools==5.3.0
certifi==2022.12.7
cffi==1.15.1
-charset-normalizer==3.0.1
-filelock==3.9.0
+charset-normalizer==3.1.0
+filelock==3.10.7
frozenlist==1.3.3
future==0.18.3
-h11==0.12.0
+h11==0.14.0
h2==4.1.0
hpack==4.0.0
+httpcore==0.16.3
+httpx==0.23.3
hyperframe==6.0.1
idna==3.4
jsonschema==4.17.3
Logbook==1.5.3
lxml==4.9.2
-matrix-nio==0.20.1
+matrix-nio[e2e]
multidict==6.0.4
peewee==3.16.0
-Pillow==9.4.0
+Pillow==9.5.0
pycparser==2.21
pycryptodome==3.17
pycryptodomex==3.17
pyrsistent==0.19.3
python-magic==0.4.27
-python-olm==3.1.3
-python-socks==2.1.1
-regex==2022.10.31
+python-socks==2.2.0
+regex==2023.3.23
requests==2.28.2
-tiktoken==0.3.0
+rfc3986==1.5.0
+sniffio==1.3.0
unpaddedbase64==2.1.0
-urllib3==1.26.14
+urllib3==1.26.15
wcwidth==0.2.6
yarl==1.8.2
+python-olm>=3.1.0
+tiktoken==0.3.3
diff --git a/send_message.py b/send_message.py
index 6151c2e..58f4aa5 100644
--- a/send_message.py
+++ b/send_message.py
@@ -1,14 +1,21 @@
from nio import AsyncClient
-
async def send_room_message(client: AsyncClient,
room_id: str,
- send_text: str,
+ reply_message: str,
+ sender_id: str = '',
+ user_message: str = '',
reply_to_event_id: str = '') -> None:
if reply_to_event_id == '':
- content = {"msgtype": "m.text", "body": f"{send_text}", }
+ content = {"msgtype": "m.text", "body": reply_message, }
else:
- content={"msgtype": "m.text", "body": f"{send_text}",
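+            # Matrix rich reply: "body" is the plain-text fallback that quotes the
+            # original sender, "formatted_body" carries the HTML rendering of the reply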
+            body = r'> <' + sender_id + r'> ' + user_message + '\n\n' + reply_message
+ format = r'org.matrix.custom.html'
+            formatted_body = r'<mx-reply><blockquote>In reply to ' + sender_id \
+                + r'<br>' + user_message + r'</blockquote></mx-reply>' + reply_message
+
+ content={"msgtype": "m.text", "body": body, "format": format, "formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}}, }
await client.room_send(
room_id,
diff --git a/test.py b/test.py
deleted file mode 100644
index bc41f13..0000000
--- a/test.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from v3 import Chatbot
-import asyncio
-from ask_gpt import ask
-import json
-from bing import BingBot
-
-fp = open("config.json", "r")
-config = json.load(fp)
-api_key = config.get('api_key', '')
-bing_api_endpoint = config.get('bing_api_endpoint', '')
-api_endpoint_list = {
- "free": "https://chatgpt-api.shn.hk/v1/",
- "paid": "https://api.openai.com/v1/chat/completions"
-}
-
-
-def test_v3(prompt: str):
- bot = Chatbot(api_key=api_key)
- resp = bot.ask(prompt=prompt)
- print(resp)
-
-
-async def test_ask_gpt_paid(prompt: str):
- headers = {
- "Content-Type": "application/json",
- "Authorization": "Bearer " + api_key,
- }
- api_endpoint = api_endpoint_list['paid']
- # test ask_gpt.py ask()
- print(await ask(prompt, api_endpoint, headers))
-
-
-async def test_ask_gpt_free(prompt: str):
- headers = {
- "Content-Type": "application/json",
- }
- api_endpoint = api_endpoint_list['free']
- print(await ask(prompt, api_endpoint, headers))
-
-
-async def test_bingbot():
- if bing_api_endpoint != '':
- bingbot = BingBot(bing_api_endpoint)
- prompt1 = "Hello World"
- prompt2 = "Do you know Victor Marie Hugo"
- prompt3 = "Can you tell me something about him?"
- resp1 = await bingbot.ask_bing(prompt1)
- resp2 = await bingbot.ask_bing(prompt2)
- resp3 = await bingbot.ask_bing(prompt3)
- print(resp1)
- print(resp2)
- print(resp3)
-
-if __name__ == "__main__":
- test_v3("Hello World")
- asyncio.run(test_ask_gpt_paid("Hello World"))
- asyncio.run(test_ask_gpt_free("Hello World"))
- asyncio.run(test_bingbot())
diff --git a/v3.py b/v3.py
index 570ed8a..2c1e6ee 100644
--- a/v3.py
+++ b/v3.py
@@ -1,20 +1,12 @@
-"""
-A simple wrapper for the official ChatGPT API
-https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
-"""
-
import json
import os
+from typing import AsyncGenerator
-
+import httpx
import requests
import tiktoken
-ENGINE = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo"
-ENCODER = tiktoken.get_encoding("gpt2")
-
-
class Chatbot:
"""
Official ChatGPT API
@@ -22,29 +14,63 @@ class Chatbot:
def __init__(
self,
- api_key: str = None,
- engine: str = None,
+ api_key: str,
+ engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
proxy: str = None,
- max_tokens: int = 4096,
+ timeout: float = None,
+ max_tokens: int = None,
temperature: float = 0.5,
top_p: float = 1.0,
+ presence_penalty: float = 0.0,
+ frequency_penalty: float = 0.0,
reply_count: int = 1,
system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
) -> None:
"""
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
"""
- self.engine = engine or ENGINE
+ self.engine: str = engine
+ self.api_key: str = api_key
+ self.system_prompt: str = system_prompt
+ self.max_tokens: int = max_tokens or (
+ 31000 if engine == "gpt-4-32k" else 7000 if engine == "gpt-4" else 4000
+ )
+ self.truncate_limit: int = (
+ 30500 if engine == "gpt-4-32k" else 6500 if engine == "gpt-4" else 3500
+ )
+ self.temperature: float = temperature
+ self.top_p: float = top_p
+ self.presence_penalty: float = presence_penalty
+ self.frequency_penalty: float = frequency_penalty
+ self.reply_count: int = reply_count
+ self.timeout: float = timeout
+ self.proxy = proxy
self.session = requests.Session()
- self.api_key = api_key
- # self.proxy = proxy
- # if self.proxy:
- # proxies = {
- # "http": self.proxy,
- # "https": self.proxy,
- # }
- # self.session.proxies = proxies
- self.conversation: dict = {
+ self.session.proxies.update(
+ {
+ "http": proxy,
+ "https": proxy,
+ },
+ )
+ proxy = (
+ proxy or os.environ.get("all_proxy") or os.environ.get("ALL_PROXY") or None
+ )
+
+        # async client used by ask_stream_async(); route it through the proxy when one is set
+        if proxy:
+            self.aclient = httpx.AsyncClient(
+                follow_redirects=True,
+                proxies=proxy,
+                timeout=timeout,
+            )
+        else:
+            self.aclient = httpx.AsyncClient(
+                follow_redirects=True,
+                timeout=timeout,
+            )
+
+ self.conversation: dict[str, list[dict]] = {
"default": [
{
"role": "system",
@@ -52,20 +78,13 @@ class Chatbot:
},
],
}
- self.system_prompt = system_prompt
- self.max_tokens = max_tokens
- self.temperature = temperature
- self.top_p = top_p
- self.reply_count = reply_count
- initial_conversation = "\n".join(
- [x["content"] for x in self.conversation["default"]],
- )
- if len(ENCODER.encode(initial_conversation)) > self.max_tokens:
- raise Exception("System prompt is too long")
def add_to_conversation(
- self, message: str, role: str, convo_id: str = "default"
+ self,
+ message: str,
+ role: str,
+ convo_id: str = "default",
) -> None:
"""
Add a message to the conversation
@@ -77,12 +96,8 @@ class Chatbot:
Truncate the conversation
"""
while True:
- full_conversation = "".join(
- message["role"] + ": " + message["content"] + "\n"
- for message in self.conversation[convo_id]
- )
if (
- len(ENCODER.encode(full_conversation)) > self.max_tokens
+ self.get_token_count(convo_id) > self.truncate_limit
and len(self.conversation[convo_id]) > 1
):
# Don't remove the first message
@@ -90,15 +105,41 @@ class Chatbot:
else:
break
+
+ def get_token_count(self, convo_id: str = "default") -> int:
+ """
+ Get token count
+ """
+ if self.engine not in [
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-0301",
+ "gpt-4",
+ "gpt-4-0314",
+ "gpt-4-32k",
+ "gpt-4-32k-0314",
+ ]:
+            raise NotImplementedError(f"Unsupported engine {self.engine}")
+
+ tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
+
+ encoding = tiktoken.encoding_for_model(self.engine)
+
+ num_tokens = 0
+ for message in self.conversation[convo_id]:
+ # every message follows {role/name}\n{content}\n
+ num_tokens += 5
+ for key, value in message.items():
+ num_tokens += len(encoding.encode(value))
+ if key == "name": # if there's a name, the role is omitted
+ num_tokens += 5 # role is always required and always 1 token
+ num_tokens += 5 # every reply is primed with assistant
+ return num_tokens
+
def get_max_tokens(self, convo_id: str) -> int:
"""
Get max tokens
"""
- full_conversation = "".join(
- message["role"] + ": " + message["content"] + "\n"
- for message in self.conversation[convo_id]
- )
- return 4000 - len(ENCODER.encode(full_conversation))
+ return self.max_tokens - self.get_token_count(convo_id)
def ask_stream(
self,
@@ -106,7 +147,7 @@ class Chatbot:
role: str = "user",
convo_id: str = "default",
**kwargs,
- ) -> str:
+ ):
"""
Ask a question
"""
@@ -117,7 +158,7 @@ class Chatbot:
self.__truncate_conversation(convo_id=convo_id)
# Get response
response = self.session.post(
- "https://api.openai.com/v1/chat/completions",
+ os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
json={
"model": self.engine,
@@ -126,16 +167,22 @@ class Chatbot:
# kwargs
"temperature": kwargs.get("temperature", self.temperature),
"top_p": kwargs.get("top_p", self.top_p),
+ "presence_penalty": kwargs.get(
+ "presence_penalty",
+ self.presence_penalty,
+ ),
+ "frequency_penalty": kwargs.get(
+ "frequency_penalty",
+ self.frequency_penalty,
+ ),
"n": kwargs.get("n", self.reply_count),
"user": role,
- # "max_tokens": self.get_max_tokens(convo_id=convo_id),
+ "max_tokens": self.get_max_tokens(convo_id=convo_id),
},
+ timeout=kwargs.get("timeout", self.timeout),
stream=True,
)
- if response.status_code != 200:
- raise Exception(
- f"Error: {response.status_code} {response.reason} {response.text}",
- )
+
response_role: str = None
full_response: str = ""
for line in response.iter_lines():
@@ -160,8 +207,100 @@ class Chatbot:
yield content
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
+ async def ask_stream_async(
+ self,
+ prompt: str,
+ role: str = "user",
+ convo_id: str = "default",
+ **kwargs,
+ ) -> AsyncGenerator[str, None]:
+ """
+ Ask a question
+ """
+ # Make conversation if it doesn't exist
+ if convo_id not in self.conversation:
+ self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
+ self.add_to_conversation(prompt, "user", convo_id=convo_id)
+ self.__truncate_conversation(convo_id=convo_id)
+ # Get response
+ async with self.aclient.stream(
+ "post",
+ os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
+ headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
+ json={
+ "model": self.engine,
+ "messages": self.conversation[convo_id],
+ "stream": True,
+ # kwargs
+ "temperature": kwargs.get("temperature", self.temperature),
+ "top_p": kwargs.get("top_p", self.top_p),
+ "presence_penalty": kwargs.get(
+ "presence_penalty",
+ self.presence_penalty,
+ ),
+ "frequency_penalty": kwargs.get(
+ "frequency_penalty",
+ self.frequency_penalty,
+ ),
+ "n": kwargs.get("n", self.reply_count),
+ "user": role,
+ "max_tokens": self.get_max_tokens(convo_id=convo_id),
+ },
+ timeout=kwargs.get("timeout", self.timeout),
+ ) as response:
+            if response.status_code != 200:
+                await response.aread()
+                raise Exception(f"{response.status_code} {response.reason_phrase} {response.text}")
+
+ response_role: str = ""
+ full_response: str = ""
+ async for line in response.aiter_lines():
+ line = line.strip()
+ if not line:
+ continue
+ # Remove "data: "
+ line = line[6:]
+ if line == "[DONE]":
+ break
+ resp: dict = json.loads(line)
+ choices = resp.get("choices")
+ if not choices:
+ continue
+ delta: dict[str, str] = choices[0].get("delta")
+ if not delta:
+ continue
+ if "role" in delta:
+ response_role = delta["role"]
+ if "content" in delta:
+ content: str = delta["content"]
+ full_response += content
+ yield content
+ self.add_to_conversation(full_response, response_role, convo_id=convo_id)
+
+ async def ask_async(
+ self,
+ prompt: str,
+ role: str = "user",
+ convo_id: str = "default",
+ **kwargs,
+ ) -> str:
+ """
+ Non-streaming ask
+ """
+ response = self.ask_stream_async(
+ prompt=prompt,
+ role=role,
+ convo_id=convo_id,
+ **kwargs,
+ )
+ full_response: str = "".join([r async for r in response])
+ return full_response
+
def ask(
- self, prompt: str, role: str = "user", convo_id: str = "default", **kwargs
+ self,
+ prompt: str,
+ role: str = "user",
+ convo_id: str = "default",
+ **kwargs,
) -> str:
"""
Non-streaming ask
@@ -175,12 +314,6 @@ class Chatbot:
full_response: str = "".join(response)
return full_response
- def rollback(self, n: int = 1, convo_id: str = "default") -> None:
- """
- Rollback the conversation
- """
- for _ in range(n):
- self.conversation[convo_id].pop()
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
"""
@@ -188,5 +321,4 @@ class Chatbot:
"""
self.conversation[convo_id] = [
{"role": "system", "content": system_prompt or self.system_prompt},
- ]
-
+ ]
\ No newline at end of file