🎉 initial commit
This commit is contained in:
commit
fe1a53780d
10 changed files with 636 additions and 0 deletions
163
.gitignore
vendored
Normal file
163
.gitignore
vendored
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
bin/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
bot
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
config.json
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
.idea/
|
39
README.md
Normal file
39
README.md
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
## Introduction
|
||||||
|
This is a simple Matrix bot that uses OpenAI's GPT API and a Chatbot to generate responses to user inputs. The bot responds to two types of prompts: `!gpt` and `!chat`, depending on the first word of the prompt.
|
||||||
|
![demo](https://i.imgur.com/kK4rnPf.jpeg "demo")
|
||||||
|
|
||||||
|
## Installation and Setup
|
||||||
|
To run this application, follow the steps below:<br>
|
||||||
|
1. Clone the repository:
|
||||||
|
```
|
||||||
|
git clone https://github.com/hibobmaster/matrix_chatgpt_bot.git
|
||||||
|
```
|
||||||
|
2. Install the required dependencies:<br>
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
3. Create a new config.json file and fill it with the necessary information:<br>
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"homeserver": "YOUR_HOMESERVER",
|
||||||
|
"user_id": "YOUR_USER_ID",
|
||||||
|
"password": "YOUR_PASSWORD",
|
||||||
|
"device_id": "YOUR_DEVICE_ID",
|
||||||
|
"room_id": "YOUR_ROOM_ID",
|
||||||
|
"api_key": "YOUR_API_KEY"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
4. Start the bot:
|
||||||
|
```
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
## Usage
|
||||||
|
To interact with the bot, simply send a message to the bot in the Matrix room with one of the two prompts:<br>
|
||||||
|
- `!gpt` To generate a response using free_endpoint API:
|
||||||
|
```
|
||||||
|
!gpt What is the meaning of life?
|
||||||
|
```
|
||||||
|
- `!chat` To chat using official api with context associated support
|
||||||
|
```
|
||||||
|
!chat Can you tell me a joke?
|
||||||
|
```
|
51
ask_gpt.py
Normal file
51
ask_gpt.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
"""
|
||||||
|
api_endpoint from https://github.com/ayaka14732/ChatGPTAPIFree
|
||||||
|
"""
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
api_endpoint_free = "https://chatgpt-api.shn.hk/v1/"
|
||||||
|
headers = {'Content-Type': "application/json"}
|
||||||
|
|
||||||
|
|
||||||
|
async def ask(prompt: str) -> str:
|
||||||
|
jsons = {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
async with session.post(url=api_endpoint_free,
|
||||||
|
json=jsons, headers=headers, timeout=10) as response:
|
||||||
|
status_code = response.status
|
||||||
|
if not status_code == 200:
|
||||||
|
# wait 2s
|
||||||
|
time.sleep(2)
|
||||||
|
continue
|
||||||
|
|
||||||
|
resp = await response.read()
|
||||||
|
await session.close()
|
||||||
|
return json.loads(resp)['choices'][0]['message']['content']
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def test() -> None:
|
||||||
|
resp = await ask("Hello World")
|
||||||
|
print(resp)
|
||||||
|
# type: str
|
||||||
|
print(type(resp))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test())
|
82
bot.py
Normal file
82
bot.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from nio import AsyncClient, MatrixRoom, RoomMessageText, LoginResponse, AsyncClientConfig
|
||||||
|
from nio.store.database import SqliteStore
|
||||||
|
from ask_gpt import ask
|
||||||
|
from send_message import send_room_message
|
||||||
|
from v3 import Chatbot
|
||||||
|
|
||||||
|
|
||||||
|
class Bot:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
homeserver: str,
|
||||||
|
user_id: str,
|
||||||
|
password: str,
|
||||||
|
device_id: str,
|
||||||
|
api_key: str = "",
|
||||||
|
room_id: Optional[str] = '',
|
||||||
|
):
|
||||||
|
self.homeserver = homeserver
|
||||||
|
self.user_id = user_id
|
||||||
|
self.password = password
|
||||||
|
self.device_id = device_id
|
||||||
|
self.room_id = room_id
|
||||||
|
self.api_key = api_key
|
||||||
|
# initialize AsyncClient object
|
||||||
|
self.store_path = os.getcwd()
|
||||||
|
self.config = AsyncClientConfig(store=SqliteStore,
|
||||||
|
store_name="bot",
|
||||||
|
store_sync_tokens=True,
|
||||||
|
)
|
||||||
|
self.client = AsyncClient(self.homeserver, user=self.user_id, device_id=self.device_id,
|
||||||
|
config=self.config, store_path=self.store_path)
|
||||||
|
# regular expression to match keyword [!gpt {prompt}] [!chat {prompt}]
|
||||||
|
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
||||||
|
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
||||||
|
# initialize chatbot
|
||||||
|
self.chatbot = Chatbot(api_key=self.api_key)
|
||||||
|
|
||||||
|
# message_callback event
|
||||||
|
async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||||
|
# chatgpt
|
||||||
|
m = self.gpt_prog.match(event.body)
|
||||||
|
if m:
|
||||||
|
# sending typing state
|
||||||
|
await self.client.room_typing(self.room_id)
|
||||||
|
prompt = m.group(1)
|
||||||
|
text = await ask(prompt)
|
||||||
|
text = text.strip()
|
||||||
|
await send_room_message(self.client, self.room_id, send_text=text)
|
||||||
|
|
||||||
|
n = self.chat_prog.match(event.body)
|
||||||
|
if n:
|
||||||
|
# sending typing state
|
||||||
|
await self.client.room_typing(self.room_id)
|
||||||
|
prompt = n.group(1)
|
||||||
|
try:
|
||||||
|
text = self.chatbot.ask(prompt).strip()
|
||||||
|
await send_room_message(self.client, self.room_id, send_text=text)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
# print info to console
|
||||||
|
# print(
|
||||||
|
# f"Message received in room {room.display_name}\n"
|
||||||
|
# f"{room.user_name(event.sender)} | {event.body}"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# bot login
|
||||||
|
async def login(self) -> None:
|
||||||
|
resp = await self.client.login(password=self.password)
|
||||||
|
if not isinstance(resp, LoginResponse):
|
||||||
|
print(f"Login Failed: {resp}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# sync messages in the room
|
||||||
|
async def sync_forever(self, timeout=30000):
|
||||||
|
self.client.add_event_callback(self.message_callback, RoomMessageText)
|
||||||
|
await self.client.sync_forever(timeout=timeout, full_state=True)
|
8
config.json.sample
Normal file
8
config.json.sample
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
{
|
||||||
|
"homeserver": "https://matrix.qqs.tw",
|
||||||
|
"user_id": "@lullap:xxxxx.org",
|
||||||
|
"password": "xxxxxxxxxxxxxxxxxx",
|
||||||
|
"device_id": "ECYEOKVPLG",
|
||||||
|
"room_id": "!FYCmBSkCRUNvZDBaDQ:matrix.qqs.tw",
|
||||||
|
"api_key": "xxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
|
}
|
30
main.py
Normal file
30
main.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from bot import Bot
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
fp = open('config.json', 'r')
|
||||||
|
config = json.load(fp)
|
||||||
|
matrix_bot = Bot(homeserver=config['homeserver'],
|
||||||
|
user_id=config['user_id'],
|
||||||
|
password=config['password'],
|
||||||
|
device_id=config['device_id'],
|
||||||
|
room_id=config['room_id'],
|
||||||
|
api_key=config['api_key'])
|
||||||
|
await matrix_bot.login()
|
||||||
|
await matrix_bot.sync_forever()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
loop.close()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# asyncio.get_event_loop().run_until_complete(main())
|
39
requirements.txt
Normal file
39
requirements.txt
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
aiofiles==0.6.0
|
||||||
|
aiohttp==3.8.4
|
||||||
|
aiohttp-socks==0.7.1
|
||||||
|
aiosignal==1.3.1
|
||||||
|
async-timeout==4.0.2
|
||||||
|
atomicwrites==1.4.1
|
||||||
|
attrs==22.2.0
|
||||||
|
blobfile==2.0.1
|
||||||
|
cachetools==4.2.4
|
||||||
|
certifi==2022.12.7
|
||||||
|
cffi==1.15.1
|
||||||
|
charset-normalizer==3.0.1
|
||||||
|
filelock==3.9.0
|
||||||
|
frozenlist==1.3.3
|
||||||
|
future==0.18.3
|
||||||
|
h11==0.12.0
|
||||||
|
h2==4.1.0
|
||||||
|
hpack==4.0.0
|
||||||
|
hyperframe==6.0.1
|
||||||
|
idna==3.4
|
||||||
|
jsonschema==4.17.3
|
||||||
|
Logbook==1.5.3
|
||||||
|
lxml==4.9.2
|
||||||
|
matrix-nio==0.20.1
|
||||||
|
multidict==6.0.4
|
||||||
|
peewee==3.16.0
|
||||||
|
pycparser==2.21
|
||||||
|
pycryptodome==3.17
|
||||||
|
pycryptodomex==3.17
|
||||||
|
pyrsistent==0.19.3
|
||||||
|
python-olm==3.1.3
|
||||||
|
python-socks==2.1.1
|
||||||
|
regex==2022.10.31
|
||||||
|
requests==2.28.2
|
||||||
|
tiktoken==0.3.0
|
||||||
|
unpaddedbase64==2.1.0
|
||||||
|
urllib3==1.26.14
|
||||||
|
wcwidth==0.2.6
|
||||||
|
yarl==1.8.2
|
12
send_message.py
Normal file
12
send_message.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
from nio import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
async def send_room_message(client: AsyncClient,
|
||||||
|
room_id: str,
|
||||||
|
send_text: str) -> None:
|
||||||
|
await client.room_send(
|
||||||
|
room_id,
|
||||||
|
message_type="m.room.message",
|
||||||
|
content={"msgtype": "m.text", "body": f"{send_text}"},
|
||||||
|
)
|
||||||
|
await client.room_typing(room_id, typing_state=False)
|
20
test.py
Normal file
20
test.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
from v3 import Chatbot
|
||||||
|
import asyncio
|
||||||
|
from ask_gpt import ask
|
||||||
|
import json
|
||||||
|
fp = open("config.json", "r")
|
||||||
|
config = json.load(fp)
|
||||||
|
|
||||||
|
|
||||||
|
def test_v3(prompt: str):
|
||||||
|
bot = Chatbot(api_key=config['api_key'])
|
||||||
|
resp = bot.ask(prompt=prompt)
|
||||||
|
print(resp)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ask(prompt: str):
|
||||||
|
print(await ask(prompt=prompt))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_v3("Hello World")
|
||||||
|
asyncio.run(test_ask("Hello World"))
|
192
v3.py
Normal file
192
v3.py
Normal file
|
@ -0,0 +1,192 @@
|
||||||
|
"""
|
||||||
|
A simple wrapper for the official ChatGPT API
|
||||||
|
https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
|
ENGINE = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo"
|
||||||
|
ENCODER = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
|
|
||||||
|
class Chatbot:
|
||||||
|
"""
|
||||||
|
Official ChatGPT API
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str = None,
|
||||||
|
engine: str = None,
|
||||||
|
proxy: str = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.5,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
reply_count: int = 1,
|
||||||
|
system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
||||||
|
"""
|
||||||
|
self.engine = engine or ENGINE
|
||||||
|
self.session = requests.Session()
|
||||||
|
self.api_key = api_key
|
||||||
|
# self.proxy = proxy
|
||||||
|
# if self.proxy:
|
||||||
|
# proxies = {
|
||||||
|
# "http": self.proxy,
|
||||||
|
# "https": self.proxy,
|
||||||
|
# }
|
||||||
|
# self.session.proxies = proxies
|
||||||
|
self.conversation: dict = {
|
||||||
|
"default": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": system_prompt,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.reply_count = reply_count
|
||||||
|
|
||||||
|
initial_conversation = "\n".join(
|
||||||
|
[x["content"] for x in self.conversation["default"]],
|
||||||
|
)
|
||||||
|
if len(ENCODER.encode(initial_conversation)) > self.max_tokens:
|
||||||
|
raise Exception("System prompt is too long")
|
||||||
|
|
||||||
|
def add_to_conversation(
|
||||||
|
self, message: str, role: str, convo_id: str = "default"
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Add a message to the conversation
|
||||||
|
"""
|
||||||
|
self.conversation[convo_id].append({"role": role, "content": message})
|
||||||
|
|
||||||
|
def __truncate_conversation(self, convo_id: str = "default") -> None:
|
||||||
|
"""
|
||||||
|
Truncate the conversation
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
full_conversation = "".join(
|
||||||
|
message["role"] + ": " + message["content"] + "\n"
|
||||||
|
for message in self.conversation[convo_id]
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
len(ENCODER.encode(full_conversation)) > self.max_tokens
|
||||||
|
and len(self.conversation[convo_id]) > 1
|
||||||
|
):
|
||||||
|
# Don't remove the first message
|
||||||
|
self.conversation[convo_id].pop(1)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def get_max_tokens(self, convo_id: str) -> int:
|
||||||
|
"""
|
||||||
|
Get max tokens
|
||||||
|
"""
|
||||||
|
full_conversation = "".join(
|
||||||
|
message["role"] + ": " + message["content"] + "\n"
|
||||||
|
for message in self.conversation[convo_id]
|
||||||
|
)
|
||||||
|
return 4000 - len(ENCODER.encode(full_conversation))
|
||||||
|
|
||||||
|
def ask_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
role: str = "user",
|
||||||
|
convo_id: str = "default",
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Ask a question
|
||||||
|
"""
|
||||||
|
# Make conversation if it doesn't exist
|
||||||
|
if convo_id not in self.conversation:
|
||||||
|
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
|
||||||
|
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
||||||
|
self.__truncate_conversation(convo_id=convo_id)
|
||||||
|
# Get response
|
||||||
|
response = self.session.post(
|
||||||
|
"https://api.openai.com/v1/chat/completions",
|
||||||
|
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||||
|
json={
|
||||||
|
"model": self.engine,
|
||||||
|
"messages": self.conversation[convo_id],
|
||||||
|
"stream": True,
|
||||||
|
# kwargs
|
||||||
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
|
"top_p": kwargs.get("top_p", self.top_p),
|
||||||
|
"n": kwargs.get("n", self.reply_count),
|
||||||
|
"user": role,
|
||||||
|
# "max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(
|
||||||
|
f"Error: {response.status_code} {response.reason} {response.text}",
|
||||||
|
)
|
||||||
|
response_role: str = None
|
||||||
|
full_response: str = ""
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
# Remove "data: "
|
||||||
|
line = line.decode("utf-8")[6:]
|
||||||
|
if line == "[DONE]":
|
||||||
|
break
|
||||||
|
resp: dict = json.loads(line)
|
||||||
|
choices = resp.get("choices")
|
||||||
|
if not choices:
|
||||||
|
continue
|
||||||
|
delta = choices[0].get("delta")
|
||||||
|
if not delta:
|
||||||
|
continue
|
||||||
|
if "role" in delta:
|
||||||
|
response_role = delta["role"]
|
||||||
|
if "content" in delta:
|
||||||
|
content = delta["content"]
|
||||||
|
full_response += content
|
||||||
|
yield content
|
||||||
|
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
|
||||||
|
|
||||||
|
def ask(
|
||||||
|
self, prompt: str, role: str = "user", convo_id: str = "default", **kwargs
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Non-streaming ask
|
||||||
|
"""
|
||||||
|
response = self.ask_stream(
|
||||||
|
prompt=prompt,
|
||||||
|
role=role,
|
||||||
|
convo_id=convo_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
full_response: str = "".join(response)
|
||||||
|
return full_response
|
||||||
|
|
||||||
|
def rollback(self, n: int = 1, convo_id: str = "default") -> None:
|
||||||
|
"""
|
||||||
|
Rollback the conversation
|
||||||
|
"""
|
||||||
|
for _ in range(n):
|
||||||
|
self.conversation[convo_id].pop()
|
||||||
|
|
||||||
|
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
|
||||||
|
"""
|
||||||
|
Reset the conversation
|
||||||
|
"""
|
||||||
|
self.conversation[convo_id] = [
|
||||||
|
{"role": "system", "content": system_prompt or self.system_prompt},
|
||||||
|
]
|
||||||
|
|
Loading…
Reference in a new issue