Compare commits
No commits in common. "main" and "v1.0.0" have entirely different histories.
26 changed files with 1394 additions and 1335 deletions
|
@ -3,8 +3,6 @@ images
|
||||||
*.md
|
*.md
|
||||||
Dockerfile
|
Dockerfile
|
||||||
Dockerfile-dev
|
Dockerfile-dev
|
||||||
compose.yaml
|
|
||||||
compose-dev.yaml
|
|
||||||
.dockerignore
|
.dockerignore
|
||||||
config.json
|
config.json
|
||||||
config.json.sample
|
config.json.sample
|
||||||
|
@ -17,13 +15,7 @@ venv
|
||||||
.git
|
.git
|
||||||
.idea
|
.idea
|
||||||
__pycache__
|
__pycache__
|
||||||
src/__pycache__
|
|
||||||
.env
|
.env
|
||||||
.env.example
|
.env.example
|
||||||
.github
|
.github
|
||||||
settings.js
|
settings.js
|
||||||
mattermost-server
|
|
||||||
tests
|
|
||||||
full-config.json.example
|
|
||||||
config.json.example
|
|
||||||
.full-env.example
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
SERVER_URL="xxxxx.xxxxxx.xxxxxxxxx"
|
SERVER_URL="xxxxx.xxxxxx.xxxxxxxxx"
|
||||||
|
ACCESS_TOKEN="xxxxxxxxxxxxxxxxx"
|
||||||
USERNAME="@chatgpt"
|
USERNAME="@chatgpt"
|
||||||
EMAIL="xxxxxx"
|
OPENAI_API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
PASSWORD="xxxxxxxxxxxxxx"
|
BING_API_ENDPOINT="http://api:3000/conversation"
|
||||||
OPENAI_API_KEY="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
BARD_TOKEN="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx."
|
||||||
GPT_MODEL="gpt-3.5-turbo"
|
BING_AUTH_COOKIE="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
|
@ -1,24 +0,0 @@
|
||||||
SERVER_URL="xxxxx.xxxxxx.xxxxxxxxx"
|
|
||||||
EMAIL="xxxxxx"
|
|
||||||
USERNAME="@chatgpt"
|
|
||||||
PASSWORD="xxxxxxxxxxxxxx"
|
|
||||||
PORT=443
|
|
||||||
SCHEME="https"
|
|
||||||
OPENAI_API_KEY="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
|
||||||
GPT_API_ENDPOINT="https://api.openai.com/v1/chat/completions"
|
|
||||||
GPT_MODEL="gpt-3.5-turbo"
|
|
||||||
MAX_TOKENS=4000
|
|
||||||
TOP_P=1.0
|
|
||||||
PRESENCE_PENALTY=0.0
|
|
||||||
FREQUENCY_PENALTY=0.0
|
|
||||||
REPLY_COUNT=1
|
|
||||||
SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respond conversationally"
|
|
||||||
TEMPERATURE=0.8
|
|
||||||
IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img"
|
|
||||||
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai
|
|
||||||
IMAGE_GENERATION_SIZE="512x512"
|
|
||||||
IMAGE_FORMAT="jpeg"
|
|
||||||
SDWUI_STEPS=20
|
|
||||||
SDWUI_SAMPLER_NAME="Euler a"
|
|
||||||
SDWUI_CFG_SCALE=7
|
|
||||||
TIMEOUT=120.0
|
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -134,7 +134,3 @@ dmypy.json
|
||||||
|
|
||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
# custom
|
|
||||||
compose-dev.yaml
|
|
||||||
mattermost-server
|
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
repos:
|
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
||||||
rev: v4.5.0
|
|
||||||
hooks:
|
|
||||||
- id: trailing-whitespace
|
|
||||||
- id: end-of-file-fixer
|
|
||||||
- id: check-yaml
|
|
||||||
- repo: https://github.com/psf/black
|
|
||||||
rev: 23.12.0
|
|
||||||
hooks:
|
|
||||||
- id: black
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
||||||
rev: v0.1.7
|
|
||||||
hooks:
|
|
||||||
- id: ruff
|
|
||||||
args: [--fix, --exit-non-zero-on-fix]
|
|
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"python.formatting.provider": "black"
|
||||||
|
}
|
165
BingImageGen.py
Normal file
165
BingImageGen.py
Normal file
|
@ -0,0 +1,165 @@
|
||||||
|
"""
|
||||||
|
Code derived from:
|
||||||
|
https://github.com/acheong08/EdgeGPT/blob/f940cecd24a4818015a8b42a2443dd97c3c2a8f4/src/ImageGen.py
|
||||||
|
"""
|
||||||
|
from log import getlogger
|
||||||
|
from uuid import uuid4
|
||||||
|
import os
|
||||||
|
import contextlib
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import requests
|
||||||
|
import regex
|
||||||
|
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
BING_URL = "https://www.bing.com"
|
||||||
|
# Generate random IP between range 13.104.0.0/14
|
||||||
|
FORWARDED_IP = (
|
||||||
|
f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
|
||||||
|
)
|
||||||
|
HEADERS = {
|
||||||
|
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||||
|
"accept-language": "en-US,en;q=0.9",
|
||||||
|
"cache-control": "max-age=0",
|
||||||
|
"content-type": "application/x-www-form-urlencoded",
|
||||||
|
"referrer": "https://www.bing.com/images/create/",
|
||||||
|
"origin": "https://www.bing.com",
|
||||||
|
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.63",
|
||||||
|
"x-forwarded-for": FORWARDED_IP,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGenAsync:
|
||||||
|
"""
|
||||||
|
Image generation by Microsoft Bing
|
||||||
|
Parameters:
|
||||||
|
auth_cookie: str
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, auth_cookie: str, quiet: bool = True) -> None:
|
||||||
|
self.session = aiohttp.ClientSession(
|
||||||
|
headers=HEADERS,
|
||||||
|
cookies={"_U": auth_cookie},
|
||||||
|
)
|
||||||
|
self.quiet = quiet
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *excinfo) -> None:
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_until_complete(self._close())
|
||||||
|
|
||||||
|
async def _close(self):
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
|
async def get_images(self, prompt: str) -> list:
|
||||||
|
"""
|
||||||
|
Fetches image links from Bing
|
||||||
|
Parameters:
|
||||||
|
prompt: str
|
||||||
|
"""
|
||||||
|
if not self.quiet:
|
||||||
|
print("Sending request...")
|
||||||
|
url_encoded_prompt = requests.utils.quote(prompt)
|
||||||
|
# https://www.bing.com/images/create?q=<PROMPT>&rt=3&FORM=GENCRE
|
||||||
|
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
|
||||||
|
async with self.session.post(url, allow_redirects=False) as response:
|
||||||
|
content = await response.text()
|
||||||
|
if "this prompt has been blocked" in content.lower():
|
||||||
|
raise Exception(
|
||||||
|
"Your prompt has been blocked by Bing. Try to change any bad words and try again.",
|
||||||
|
)
|
||||||
|
if response.status != 302:
|
||||||
|
# if rt4 fails, try rt3
|
||||||
|
url = (
|
||||||
|
f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=3&FORM=GENCRE"
|
||||||
|
)
|
||||||
|
async with self.session.post(
|
||||||
|
url,
|
||||||
|
allow_redirects=False,
|
||||||
|
timeout=200,
|
||||||
|
) as response3:
|
||||||
|
if response3.status != 302:
|
||||||
|
print(f"ERROR: {response3.text}")
|
||||||
|
raise Exception("Redirect failed")
|
||||||
|
response = response3
|
||||||
|
# Get redirect URL
|
||||||
|
redirect_url = response.headers["Location"].replace("&nfy=1", "")
|
||||||
|
request_id = redirect_url.split("id=")[-1]
|
||||||
|
await self.session.get(f"{BING_URL}{redirect_url}")
|
||||||
|
# https://www.bing.com/images/create/async/results/{ID}?q={PROMPT}
|
||||||
|
polling_url = f"{BING_URL}/images/create/async/results/{request_id}?q={url_encoded_prompt}"
|
||||||
|
# Poll for results
|
||||||
|
if not self.quiet:
|
||||||
|
print("Waiting for results...")
|
||||||
|
while True:
|
||||||
|
if not self.quiet:
|
||||||
|
print(".", end="", flush=True)
|
||||||
|
# By default, timeout is 300s, change as needed
|
||||||
|
response = await self.session.get(polling_url)
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception("Could not get results")
|
||||||
|
content = await response.text()
|
||||||
|
if content and content.find("errorMessage") == -1:
|
||||||
|
break
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
# Use regex to search for src=""
|
||||||
|
image_links = regex.findall(r'src="([^"]+)"', content)
|
||||||
|
# Remove size limit
|
||||||
|
normal_image_links = [link.split("?w=")[0] for link in image_links]
|
||||||
|
# Remove duplicates
|
||||||
|
normal_image_links = list(set(normal_image_links))
|
||||||
|
|
||||||
|
# Bad images
|
||||||
|
bad_images = [
|
||||||
|
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
|
||||||
|
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
|
||||||
|
]
|
||||||
|
for im in normal_image_links:
|
||||||
|
if im in bad_images:
|
||||||
|
raise Exception("Bad images")
|
||||||
|
# No images
|
||||||
|
if not normal_image_links:
|
||||||
|
raise Exception("No images")
|
||||||
|
return normal_image_links
|
||||||
|
|
||||||
|
async def save_images(self, links: list, output_dir: str) -> str:
|
||||||
|
"""
|
||||||
|
Saves images to output directory
|
||||||
|
"""
|
||||||
|
if not self.quiet:
|
||||||
|
print("\nDownloading images...")
|
||||||
|
with contextlib.suppress(FileExistsError):
|
||||||
|
os.mkdir(output_dir)
|
||||||
|
|
||||||
|
# image name
|
||||||
|
image_name = str(uuid4())
|
||||||
|
# we just need one image for better display in chat room
|
||||||
|
if links:
|
||||||
|
link = links.pop()
|
||||||
|
|
||||||
|
image_path = os.path.join(output_dir, f"{image_name}.jpeg")
|
||||||
|
try:
|
||||||
|
async with self.session.get(link, raise_for_status=True) as response:
|
||||||
|
# save response to file
|
||||||
|
with open(image_path, "wb") as output_file:
|
||||||
|
async for chunk in response.content.iter_chunked(8192):
|
||||||
|
output_file.write(chunk)
|
||||||
|
return f"{output_dir}/{image_name}.jpeg"
|
||||||
|
|
||||||
|
except aiohttp.client_exceptions.InvalidURL as url_exception:
|
||||||
|
raise Exception(
|
||||||
|
"Inappropriate contents found in the generated images. Please try again or try another prompt.",
|
||||||
|
) from url_exception
|
26
CHANGELOG.md
26
CHANGELOG.md
|
@ -1,26 +0,0 @@
|
||||||
# Changelog
|
|
||||||
|
|
||||||
## v1.3.2
|
|
||||||
- Make gptbot more compatible
|
|
||||||
|
|
||||||
## v1.3.1
|
|
||||||
- Expose more stable diffusion webui api parameters
|
|
||||||
|
|
||||||
## v1.3.0
|
|
||||||
- Fix localai v2.0+ image generation
|
|
||||||
- Support specific output image format(jpeg, png) and size
|
|
||||||
|
|
||||||
## v1.2.0
|
|
||||||
- support sending typing state
|
|
||||||
|
|
||||||
## v1.1.0
|
|
||||||
- remove pandora
|
|
||||||
- refactor chat and image genderation backend
|
|
||||||
- reply in thread by default
|
|
||||||
- introduce pre-commit hooks
|
|
||||||
|
|
||||||
## v1.0.4
|
|
||||||
|
|
||||||
- refactor code structure and remove unused
|
|
||||||
- remove Bing AI and Google Bard due to technical problems
|
|
||||||
- bug fix and improvement
|
|
|
@ -13,4 +13,4 @@ COPY . /app
|
||||||
|
|
||||||
FROM runner
|
FROM runner
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
CMD ["python", "src/main.py"]
|
CMD ["python", "main.py"]
|
||||||
|
|
26
README.md
26
README.md
|
@ -1,11 +1,12 @@
|
||||||
## Introduction
|
## Introduction
|
||||||
|
|
||||||
This is a simple Mattermost Bot that uses OpenAI's GPT API(or self-host models) to generate responses to user inputs. The bot responds to these commands: `!gpt`, `!chat` and `!new` and `!help` depending on the first word of the prompt.
|
This is a simple Mattermost Bot that uses OpenAI's GPT API and Bing AI and Google Bard to generate responses to user inputs. The bot responds to six types of prompts: `!gpt`, `!chat` and `!bing` and `!pic` and `!bard` and `!help` depending on the first word of the prompt.
|
||||||
|
|
||||||
## Feature
|
## Feature
|
||||||
|
|
||||||
1. Support official openai api and self host models([LocalAI](https://localai.io/model-compatibility/))
|
1. Support Openai ChatGPT and Bing AI and Google Bard(US only at the moment)
|
||||||
2. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
|
2. Support Bing Image Creator
|
||||||
|
|
||||||
## Installation and Setup
|
## Installation and Setup
|
||||||
|
|
||||||
See https://github.com/hibobmaster/mattermost_bot/wiki
|
See https://github.com/hibobmaster/mattermost_bot/wiki
|
||||||
|
@ -16,21 +17,8 @@ Edit `config.json` or `.env` with proper values
|
||||||
docker compose up -d
|
docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
## Commands
|
|
||||||
|
|
||||||
- `!help` help message
|
|
||||||
- `!gpt + [prompt]` generate a one time response from chatGPT
|
|
||||||
- `!chat + [prompt]` chat using official chatGPT api with context conversation
|
|
||||||
- `!pic + [prompt]` Image generation with DALL·E or LocalAI or stable-diffusion-webui
|
|
||||||
|
|
||||||
- `!new` start a new converstaion
|
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
Remove support for Bing AI, Google Bard due to technical problems.
|
|
||||||
![gpt command](https://imgur.com/vdT83Ln.jpg)
|
|
||||||
![image generation](https://i.imgur.com/DQ3i3wW.jpg)
|
|
||||||
|
|
||||||
## Thanks
|
![demo1](https://i.imgur.com/XRAQB4B.jpg)
|
||||||
<a href="https://jb.gg/OpenSourceSupport" target="_blank">
|
![demo2](https://i.imgur.com/if72kyH.jpg)
|
||||||
<img src="https://resources.jetbrains.com/storage/products/company/brand/logos/jb_beam.png" alt="JetBrains Logo (Main) logo." width="200" height="200">
|
![demo3](https://i.imgur.com/GHczfkv.jpg)
|
||||||
</a>
|
|
46
askgpt.py
Normal file
46
askgpt.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from log import getlogger
|
||||||
|
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
|
||||||
|
class askGPT:
|
||||||
|
def __init__(
|
||||||
|
self, session: aiohttp.ClientSession, api_endpoint: str, headers: str
|
||||||
|
) -> None:
|
||||||
|
self.session = session
|
||||||
|
self.api_endpoint = api_endpoint
|
||||||
|
self.headers = headers
|
||||||
|
|
||||||
|
async def oneTimeAsk(self, prompt: str) -> str:
|
||||||
|
jsons = {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
max_try = 2
|
||||||
|
while max_try > 0:
|
||||||
|
try:
|
||||||
|
async with self.session.post(
|
||||||
|
url=self.api_endpoint, json=jsons, headers=self.headers, timeout=120
|
||||||
|
) as response:
|
||||||
|
status_code = response.status
|
||||||
|
if not status_code == 200:
|
||||||
|
# print failed reason
|
||||||
|
logger.warning(str(response.reason))
|
||||||
|
max_try = max_try - 1
|
||||||
|
# wait 2s
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
continue
|
||||||
|
|
||||||
|
resp = await response.read()
|
||||||
|
return json.loads(resp)["choices"][0]["message"]["content"]
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(e)
|
104
bard.py
Normal file
104
bard.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
"""
|
||||||
|
Code derived from: https://github.com/acheong08/Bard/blob/main/src/Bard.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class Bardbot:
|
||||||
|
"""
|
||||||
|
A class to interact with Google Bard.
|
||||||
|
Parameters
|
||||||
|
session_id: str
|
||||||
|
The __Secure-1PSID cookie.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = [
|
||||||
|
"headers",
|
||||||
|
"_reqid",
|
||||||
|
"SNlM0e",
|
||||||
|
"conversation_id",
|
||||||
|
"response_id",
|
||||||
|
"choice_id",
|
||||||
|
"session",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, session_id):
|
||||||
|
headers = {
|
||||||
|
"Host": "bard.google.com",
|
||||||
|
"X-Same-Domain": "1",
|
||||||
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36",
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded;charset=UTF-8",
|
||||||
|
"Origin": "https://bard.google.com",
|
||||||
|
"Referer": "https://bard.google.com/",
|
||||||
|
}
|
||||||
|
self._reqid = int("".join(random.choices(string.digits, k=4)))
|
||||||
|
self.conversation_id = ""
|
||||||
|
self.response_id = ""
|
||||||
|
self.choice_id = ""
|
||||||
|
self.session = requests.Session()
|
||||||
|
self.session.headers = headers
|
||||||
|
self.session.cookies.set("__Secure-1PSID", session_id)
|
||||||
|
self.SNlM0e = self.__get_snlm0e()
|
||||||
|
|
||||||
|
def __get_snlm0e(self):
|
||||||
|
resp = self.session.get(url="https://bard.google.com/", timeout=10)
|
||||||
|
# Find "SNlM0e":"<ID>"
|
||||||
|
if resp.status_code != 200:
|
||||||
|
raise Exception("Could not get Google Bard")
|
||||||
|
SNlM0e = re.search(r"SNlM0e\":\"(.*?)\"", resp.text).group(1)
|
||||||
|
return SNlM0e
|
||||||
|
|
||||||
|
def ask(self, message: str) -> dict:
|
||||||
|
"""
|
||||||
|
Send a message to Google Bard and return the response.
|
||||||
|
:param message: The message to send to Google Bard.
|
||||||
|
:return: A dict containing the response from Google Bard.
|
||||||
|
"""
|
||||||
|
# url params
|
||||||
|
params = {
|
||||||
|
"bl": "boq_assistant-bard-web-server_20230326.21_p0",
|
||||||
|
"_reqid": str(self._reqid),
|
||||||
|
"rt": "c",
|
||||||
|
}
|
||||||
|
|
||||||
|
# message arr -> data["f.req"]. Message is double json stringified
|
||||||
|
message_struct = [
|
||||||
|
[message],
|
||||||
|
None,
|
||||||
|
[self.conversation_id, self.response_id, self.choice_id],
|
||||||
|
]
|
||||||
|
data = {
|
||||||
|
"f.req": json.dumps([None, json.dumps(message_struct)]),
|
||||||
|
"at": self.SNlM0e,
|
||||||
|
}
|
||||||
|
|
||||||
|
# do the request!
|
||||||
|
resp = self.session.post(
|
||||||
|
"https://bard.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate",
|
||||||
|
params=params,
|
||||||
|
data=data,
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_data = json.loads(resp.content.splitlines()[3])[0][2]
|
||||||
|
if not chat_data:
|
||||||
|
return {"content": f"Google Bard encountered an error: {resp.content}."}
|
||||||
|
json_chat_data = json.loads(chat_data)
|
||||||
|
results = {
|
||||||
|
"content": json_chat_data[0][0],
|
||||||
|
"conversation_id": json_chat_data[1][0],
|
||||||
|
"response_id": json_chat_data[1][1],
|
||||||
|
"factualityQueries": json_chat_data[3],
|
||||||
|
"textQuery": json_chat_data[2][0] if json_chat_data[2] is not None else "",
|
||||||
|
"choices": [{"id": i[0], "content": i[1]} for i in json_chat_data[4]],
|
||||||
|
}
|
||||||
|
self.conversation_id = results["conversation_id"]
|
||||||
|
self.response_id = results["response_id"]
|
||||||
|
self.choice_id = results["choices"][0]["id"]
|
||||||
|
self._reqid += 100000
|
||||||
|
return results
|
64
bing.py
Normal file
64
bing.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from log import getlogger
|
||||||
|
|
||||||
|
# api_endpoint = "http://localhost:3000/conversation"
|
||||||
|
from log import getlogger
|
||||||
|
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
|
||||||
|
class BingBot:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: aiohttp.ClientSession,
|
||||||
|
bing_api_endpoint: str,
|
||||||
|
jailbreakEnabled: bool = True,
|
||||||
|
):
|
||||||
|
self.data = {
|
||||||
|
"clientOptions.clientToUse": "bing",
|
||||||
|
}
|
||||||
|
self.bing_api_endpoint = bing_api_endpoint
|
||||||
|
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
self.jailbreakEnabled = jailbreakEnabled
|
||||||
|
|
||||||
|
if self.jailbreakEnabled:
|
||||||
|
self.data["jailbreakConversationId"] = True
|
||||||
|
|
||||||
|
async def ask_bing(self, prompt) -> str:
|
||||||
|
self.data["message"] = prompt
|
||||||
|
max_try = 2
|
||||||
|
while max_try > 0:
|
||||||
|
try:
|
||||||
|
resp = await self.session.post(
|
||||||
|
url=self.bing_api_endpoint, json=self.data, timeout=120
|
||||||
|
)
|
||||||
|
status_code = resp.status
|
||||||
|
body = await resp.read()
|
||||||
|
if not status_code == 200:
|
||||||
|
# print failed reason
|
||||||
|
logger.warning(str(resp.reason))
|
||||||
|
max_try = max_try - 1
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
continue
|
||||||
|
json_body = json.loads(body)
|
||||||
|
if self.jailbreakEnabled:
|
||||||
|
self.data["jailbreakConversationId"] = json_body[
|
||||||
|
"jailbreakConversationId"
|
||||||
|
]
|
||||||
|
self.data["parentMessageId"] = json_body["messageId"]
|
||||||
|
else:
|
||||||
|
self.data["conversationSignature"] = json_body[
|
||||||
|
"conversationSignature"
|
||||||
|
]
|
||||||
|
self.data["conversationId"] = json_body["conversationId"]
|
||||||
|
self.data["clientId"] = json_body["clientId"]
|
||||||
|
self.data["invocationId"] = json_body["invocationId"]
|
||||||
|
return json_body["details"]["adaptiveCards"][0]["body"][0]["text"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error Exception", exc_info=True)
|
||||||
|
|
||||||
|
return "Error, please retry"
|
326
bot.py
Normal file
326
bot.py
Normal file
|
@ -0,0 +1,326 @@
|
||||||
|
from mattermostdriver import Driver
|
||||||
|
from typing import Optional
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import aiohttp
|
||||||
|
from askgpt import askGPT
|
||||||
|
from v3 import Chatbot
|
||||||
|
from bing import BingBot
|
||||||
|
from bard import Bardbot
|
||||||
|
from BingImageGen import ImageGenAsync
|
||||||
|
from log import getlogger
|
||||||
|
|
||||||
|
logger = getlogger()
|
||||||
|
|
||||||
|
|
||||||
|
class Bot:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
username: str,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
login_id: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
openai_api_key: Optional[str] = None,
|
||||||
|
openai_api_endpoint: Optional[str] = None,
|
||||||
|
bing_api_endpoint: Optional[str] = None,
|
||||||
|
bard_token: Optional[str] = None,
|
||||||
|
bing_auth_cookie: Optional[str] = None,
|
||||||
|
port: int = 443,
|
||||||
|
timeout: int = 30,
|
||||||
|
) -> None:
|
||||||
|
if server_url is None:
|
||||||
|
raise ValueError("server url must be provided")
|
||||||
|
|
||||||
|
if port is None:
|
||||||
|
self.port = 443
|
||||||
|
|
||||||
|
if timeout is None:
|
||||||
|
self.timeout = 30
|
||||||
|
|
||||||
|
# login relative info
|
||||||
|
if access_token is None and password is None:
|
||||||
|
raise ValueError("Either token or password must be provided")
|
||||||
|
|
||||||
|
if access_token is not None:
|
||||||
|
self.driver = Driver(
|
||||||
|
{
|
||||||
|
"token": access_token,
|
||||||
|
"url": server_url,
|
||||||
|
"port": self.port,
|
||||||
|
"request_timeout": self.timeout,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.driver = Driver(
|
||||||
|
{
|
||||||
|
"login_id": login_id,
|
||||||
|
"password": password,
|
||||||
|
"url": server_url,
|
||||||
|
"port": self.port,
|
||||||
|
"request_timeout": self.timeout,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# @chatgpt
|
||||||
|
if username is None:
|
||||||
|
raise ValueError("username must be provided")
|
||||||
|
else:
|
||||||
|
self.username = username
|
||||||
|
|
||||||
|
# openai_api_endpoint
|
||||||
|
if openai_api_endpoint is None:
|
||||||
|
self.openai_api_endpoint = "https://api.openai.com/v1/chat/completions"
|
||||||
|
else:
|
||||||
|
self.openai_api_endpoint = openai_api_endpoint
|
||||||
|
|
||||||
|
# aiohttp session
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
self.openai_api_key = openai_api_key
|
||||||
|
# initialize chatGPT class
|
||||||
|
if self.openai_api_key is not None:
|
||||||
|
# request header for !gpt command
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.openai_api_key}",
|
||||||
|
}
|
||||||
|
|
||||||
|
self.askgpt = askGPT(
|
||||||
|
self.session,
|
||||||
|
self.openai_api_endpoint,
|
||||||
|
self.headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.chatbot = Chatbot(api_key=self.openai_api_key)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"openai_api_key is not provided, !gpt and !chat command will not work"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bing_api_endpoint = bing_api_endpoint
|
||||||
|
# initialize bingbot
|
||||||
|
if self.bing_api_endpoint is not None:
|
||||||
|
self.bingbot = BingBot(
|
||||||
|
session=self.session,
|
||||||
|
bing_api_endpoint=self.bing_api_endpoint,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"bing_api_endpoint is not provided, !bing command will not work"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bard_token = bard_token
|
||||||
|
# initialize bard
|
||||||
|
if self.bard_token is not None:
|
||||||
|
self.bardbot = Bardbot(session_id=self.bard_token)
|
||||||
|
else:
|
||||||
|
logger.warning("bard_token is not provided, !bard command will not work")
|
||||||
|
|
||||||
|
self.bing_auth_cookie = bing_auth_cookie
|
||||||
|
# initialize image generator
|
||||||
|
if self.bing_auth_cookie is not None:
|
||||||
|
self.imagegen = ImageGenAsync(auth_cookie=self.bing_auth_cookie)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"bing_auth_cookie is not provided, !pic command will not work"
|
||||||
|
)
|
||||||
|
|
||||||
|
# regular expression to match keyword [!gpt {prompt}] [!chat {prompt}] [!bing {prompt}] [!pic {prompt}] [!bard {prompt}]
|
||||||
|
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
||||||
|
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
||||||
|
self.bing_prog = re.compile(r"^\s*!bing\s*(.+)$")
|
||||||
|
self.bard_prog = re.compile(r"^\s*!bard\s*(.+)$")
|
||||||
|
self.pic_prog = re.compile(r"^\s*!pic\s*(.+)$")
|
||||||
|
self.help_prog = re.compile(r"^\s*!help\s*.*$")
|
||||||
|
|
||||||
|
# close session
|
||||||
|
def __del__(self) -> None:
|
||||||
|
self.driver.disconnect()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
|
def login(self) -> None:
|
||||||
|
self.driver.login()
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
await self.driver.init_websocket(self.websocket_handler)
|
||||||
|
|
||||||
|
# websocket handler
|
||||||
|
async def websocket_handler(self, message) -> None:
|
||||||
|
print(message)
|
||||||
|
response = json.loads(message)
|
||||||
|
if "event" in response:
|
||||||
|
event_type = response["event"]
|
||||||
|
if event_type == "posted":
|
||||||
|
raw_data = response["data"]["post"]
|
||||||
|
raw_data_dict = json.loads(raw_data)
|
||||||
|
user_id = raw_data_dict["user_id"]
|
||||||
|
channel_id = raw_data_dict["channel_id"]
|
||||||
|
sender_name = response["data"]["sender_name"]
|
||||||
|
raw_message = raw_data_dict["message"]
|
||||||
|
try:
|
||||||
|
asyncio.create_task(
|
||||||
|
self.message_callback(
|
||||||
|
raw_message, channel_id, user_id, sender_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
await asyncio.to_thread(self.send_message, channel_id, f"{e}")
|
||||||
|
|
||||||
|
# message callback
|
||||||
|
async def message_callback(
|
||||||
|
self, raw_message: str, channel_id: str, user_id: str, sender_name: str
|
||||||
|
) -> None:
|
||||||
|
# prevent command trigger loop
|
||||||
|
if sender_name != self.username:
|
||||||
|
message = raw_message
|
||||||
|
|
||||||
|
if self.openai_api_key is not None:
|
||||||
|
# !gpt command trigger handler
|
||||||
|
if self.gpt_prog.match(message):
|
||||||
|
prompt = self.gpt_prog.match(message).group(1)
|
||||||
|
try:
|
||||||
|
response = await self.gpt(prompt)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.send_message, channel_id, f"{response}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
# !chat command trigger handler
|
||||||
|
elif self.chat_prog.match(message):
|
||||||
|
prompt = self.chat_prog.match(message).group(1)
|
||||||
|
try:
|
||||||
|
response = await self.chat(prompt)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.send_message, channel_id, f"{response}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
if self.bing_api_endpoint is not None:
|
||||||
|
# !bing command trigger handler
|
||||||
|
if self.bing_prog.match(message):
|
||||||
|
prompt = self.bing_prog.match(message).group(1)
|
||||||
|
try:
|
||||||
|
response = await self.bingbot.ask_bing(prompt)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.send_message, channel_id, f"{response}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
if self.bard_token is not None:
|
||||||
|
# !bard command trigger handler
|
||||||
|
if self.bard_prog.match(message):
|
||||||
|
prompt = self.bard_prog.match(message).group(1)
|
||||||
|
try:
|
||||||
|
# response is dict object
|
||||||
|
response = await self.bard(prompt)
|
||||||
|
content = str(response["content"]).strip()
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.send_message, channel_id, f"{content}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
if self.bing_auth_cookie is not None:
|
||||||
|
# !pic command trigger handler
|
||||||
|
if self.pic_prog.match(message):
|
||||||
|
prompt = self.pic_prog.match(message).group(1)
|
||||||
|
# generate image
|
||||||
|
try:
|
||||||
|
links = await self.imagegen.get_images(prompt)
|
||||||
|
image_path = await self.imagegen.save_images(links, "images")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
# send image
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.send_file, channel_id, prompt, image_path
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
# !help command trigger handler
|
||||||
|
if self.help_prog.match(message):
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(self.send_message, channel_id, self.help())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
|
||||||
|
# send message to room
|
||||||
|
def send_message(self, channel_id: str, message: str) -> None:
|
||||||
|
self.driver.posts.create_post(
|
||||||
|
options={
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# send file to room
|
||||||
|
def send_file(self, channel_id: str, message: str, filepath: str) -> None:
|
||||||
|
filename = os.path.split(filepath)[-1]
|
||||||
|
try:
|
||||||
|
file_id = self.driver.files.upload_file(
|
||||||
|
channel_id=channel_id,
|
||||||
|
files={
|
||||||
|
"files": (filename, open(filepath, "rb")),
|
||||||
|
},
|
||||||
|
)["file_infos"][0]["id"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.driver.posts.create_post(
|
||||||
|
options={
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"message": message,
|
||||||
|
"file_ids": [file_id],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# remove image after posting
|
||||||
|
os.remove(filepath)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise Exception(e)
|
||||||
|
|
||||||
|
# !gpt command function
|
||||||
|
async def gpt(self, prompt: str) -> str:
|
||||||
|
return await self.askgpt.oneTimeAsk(prompt)
|
||||||
|
|
||||||
|
# !chat command function
|
||||||
|
async def chat(self, prompt: str) -> str:
|
||||||
|
return await self.chatbot.ask_async(prompt)
|
||||||
|
|
||||||
|
# !bing command function
|
||||||
|
async def bing(self, prompt: str) -> str:
|
||||||
|
return await self.bingbot.ask_bing(prompt)
|
||||||
|
|
||||||
|
# !bard command function
|
||||||
|
async def bard(self, prompt: str) -> str:
|
||||||
|
return await asyncio.to_thread(self.bardbot.ask, prompt)
|
||||||
|
|
||||||
|
# !help command function
|
||||||
|
def help(self) -> str:
|
||||||
|
help_info = (
|
||||||
|
"!gpt [content], generate response without context conversation\n"
|
||||||
|
+ "!chat [content], chat with context conversation\n"
|
||||||
|
+ "!bing [content], chat with context conversation powered by Bing AI\n"
|
||||||
|
+ "!bard [content], chat with Google's Bard\n"
|
||||||
|
+ "!pic [prompt], Image generation by Microsoft Bing\n"
|
||||||
|
+ "!help, help message"
|
||||||
|
)
|
||||||
|
return help_info
|
14
compose.yaml
14
compose.yaml
|
@ -2,7 +2,7 @@ services:
|
||||||
app:
|
app:
|
||||||
image: ghcr.io/hibobmaster/mattermost_bot:latest
|
image: ghcr.io/hibobmaster/mattermost_bot:latest
|
||||||
container_name: mattermost_bot
|
container_name: mattermost_bot
|
||||||
restart: unless-stopped
|
restart: always
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
# volumes:
|
# volumes:
|
||||||
|
@ -11,13 +11,11 @@ services:
|
||||||
networks:
|
networks:
|
||||||
- mattermost_network
|
- mattermost_network
|
||||||
|
|
||||||
# pandora:
|
# api:
|
||||||
# image: pengzhile/pandora
|
# image: hibobmaster/node-chatgpt-api:latest
|
||||||
# container_name: pandora
|
# container_name: node-chatgpt-api
|
||||||
# restart: unless-stopped
|
# volumes:
|
||||||
# environment:
|
# - ./settings.js:/var/chatgpt-api/settings.js
|
||||||
# - PANDORA_ACCESS_TOKEN=xxxxxxxxxxxxxx
|
|
||||||
# - PANDORA_SERVER=0.0.0.0:8008
|
|
||||||
# networks:
|
# networks:
|
||||||
# - mattermost_network
|
# - mattermost_network
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
{
|
{
|
||||||
"server_url": "xxxx.xxxx.xxxxx",
|
"server_url": "xxxx.xxxx.xxxxx",
|
||||||
"email": "xxxxx",
|
"access_token": "xxxxxxxxxxxxxxxxxxxxxx",
|
||||||
"username": "@chatgpt",
|
"username": "@chatgpt",
|
||||||
"password": "xxxxxxxxxxxxxxxxx",
|
"openai_api_key": "sk-xxxxxxxxxxxxxxxxxxx",
|
||||||
"openai_api_key": "xxxxxxxxxxxxxxxxxxxxxxxxx",
|
"bing_api_endpoint": "http://api:3000/conversation",
|
||||||
"gpt_model": "gpt-3.5-turbo"
|
"bard_token": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxx.",
|
||||||
|
"bing_auth_cookie": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
}
|
}
|
|
@ -1,26 +0,0 @@
|
||||||
{
|
|
||||||
"server_url": "localhost",
|
|
||||||
"email": "bot@hibobmaster.com",
|
|
||||||
"username": "@bot",
|
|
||||||
"password": "SfBKY%K7*e&a%ZX$3g@Am&jQ",
|
|
||||||
"port": 8065,
|
|
||||||
"scheme": "http",
|
|
||||||
"openai_api_key": "xxxxxxxxxxxxxxxxxxxxxxxx",
|
|
||||||
"gpt_api_endpoint": "https://api.openai.com/v1/chat/completions",
|
|
||||||
"gpt_model": "gpt-3.5-turbo",
|
|
||||||
"max_tokens": 4000,
|
|
||||||
"top_p": 1.0,
|
|
||||||
"presence_penalty": 0.0,
|
|
||||||
"frequency_penalty": 0.0,
|
|
||||||
"reply_count": 1,
|
|
||||||
"temperature": 0.8,
|
|
||||||
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
|
||||||
"image_generation_endpoint": "http://localai:8080/v1/images/generations",
|
|
||||||
"image_generation_backend": "localai",
|
|
||||||
"image_generation_size": "512x512",
|
|
||||||
"sdwui_steps": 20,
|
|
||||||
"sdwui_sampler_name": "Euler a",
|
|
||||||
"sdwui_cfg_scale": 7,
|
|
||||||
"image_format": "jpeg",
|
|
||||||
"timeout": 120.0
|
|
||||||
}
|
|
49
main.py
Normal file
49
main.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
from bot import Bot
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
if os.path.exists("config.json"):
|
||||||
|
fp = open("config.json", "r", encoding="utf-8")
|
||||||
|
config = json.load(fp)
|
||||||
|
|
||||||
|
mattermost_bot = Bot(
|
||||||
|
server_url=config.get("server_url"),
|
||||||
|
access_token=config.get("access_token"),
|
||||||
|
login_id=config.get("login_id"),
|
||||||
|
password=config.get("password"),
|
||||||
|
username=config.get("username"),
|
||||||
|
openai_api_key=config.get("openai_api_key"),
|
||||||
|
openai_api_endpoint=config.get("openai_api_endpoint"),
|
||||||
|
bing_api_endpoint=config.get("bing_api_endpoint"),
|
||||||
|
bard_token=config.get("bard_token"),
|
||||||
|
bing_auth_cookie=config.get("bing_auth_cookie"),
|
||||||
|
port=config.get("port"),
|
||||||
|
timeout=config.get("timeout"),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
mattermost_bot = Bot(
|
||||||
|
server_url=os.environ.get("SERVER_URL"),
|
||||||
|
access_token=os.environ.get("ACCESS_TOKEN"),
|
||||||
|
login_id=os.environ.get("LOGIN_ID"),
|
||||||
|
password=os.environ.get("PASSWORD"),
|
||||||
|
username=os.environ.get("USERNAME"),
|
||||||
|
openai_api_key=os.environ.get("OPENAI_API_KEY"),
|
||||||
|
openai_api_endpoint=os.environ.get("OPENAI_API_ENDPOINT"),
|
||||||
|
bing_api_endpoint=os.environ.get("BING_API_ENDPOINT"),
|
||||||
|
bard_token=os.environ.get("BARD_TOKEN"),
|
||||||
|
bing_auth_cookie=os.environ.get("BING_AUTH_COOKIE"),
|
||||||
|
port=os.environ.get("PORT"),
|
||||||
|
timeout=os.environ.get("TIMEOUT"),
|
||||||
|
)
|
||||||
|
|
||||||
|
mattermost_bot.login()
|
||||||
|
|
||||||
|
await mattermost_bot.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
|
@ -1,6 +1,27 @@
|
||||||
httpx
|
aiohttp==3.8.4
|
||||||
Pillow
|
aiosignal==1.3.1
|
||||||
tiktoken
|
anyio==3.6.2
|
||||||
tenacity
|
async-timeout==4.0.2
|
||||||
aiofiles
|
attrs==23.1.0
|
||||||
|
certifi==2022.12.7
|
||||||
|
charset-normalizer==3.1.0
|
||||||
|
click==8.1.3
|
||||||
|
colorama==0.4.6
|
||||||
|
frozenlist==1.3.3
|
||||||
|
h11==0.14.0
|
||||||
|
httpcore==0.17.0
|
||||||
|
httpx==0.24.0
|
||||||
|
idna==3.4
|
||||||
mattermostdriver @ git+https://github.com/hibobmaster/python-mattermost-driver
|
mattermostdriver @ git+https://github.com/hibobmaster/python-mattermost-driver
|
||||||
|
multidict==6.0.4
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
packaging==23.1
|
||||||
|
pathspec==0.11.1
|
||||||
|
platformdirs==3.2.0
|
||||||
|
regex==2023.3.23
|
||||||
|
requests==2.28.2
|
||||||
|
sniffio==1.3.0
|
||||||
|
tiktoken==0.3.3
|
||||||
|
urllib3==1.26.15
|
||||||
|
websockets==11.0.1
|
||||||
|
yarl==1.8.2
|
||||||
|
|
380
src/bot.py
380
src/bot.py
|
@ -1,380 +0,0 @@
|
||||||
import sys
|
|
||||||
import aiofiles.os
|
|
||||||
from mattermostdriver import AsyncDriver
|
|
||||||
from typing import Optional
|
|
||||||
import json
|
|
||||||
import asyncio
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from gptbot import Chatbot
|
|
||||||
from log import getlogger
|
|
||||||
import httpx
|
|
||||||
import imagegen
|
|
||||||
|
|
||||||
logger = getlogger()
|
|
||||||
|
|
||||||
|
|
||||||
class Bot:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
username: str,
|
|
||||||
email: str,
|
|
||||||
password: str,
|
|
||||||
port: Optional[int] = 443,
|
|
||||||
scheme: Optional[str] = "https",
|
|
||||||
openai_api_key: Optional[str] = None,
|
|
||||||
gpt_api_endpoint: Optional[str] = None,
|
|
||||||
gpt_model: Optional[str] = None,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
top_p: Optional[float] = None,
|
|
||||||
presence_penalty: Optional[float] = None,
|
|
||||||
frequency_penalty: Optional[float] = None,
|
|
||||||
reply_count: Optional[int] = None,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
image_generation_endpoint: Optional[str] = None,
|
|
||||||
image_generation_backend: Optional[str] = None,
|
|
||||||
image_generation_size: Optional[str] = None,
|
|
||||||
sdwui_steps: Optional[int] = None,
|
|
||||||
sdwui_sampler_name: Optional[str] = None,
|
|
||||||
sdwui_cfg_scale: Optional[float] = None,
|
|
||||||
image_format: Optional[str] = None,
|
|
||||||
timeout: Optional[float] = 120.0,
|
|
||||||
) -> None:
|
|
||||||
if server_url is None:
|
|
||||||
raise ValueError("server url must be provided")
|
|
||||||
|
|
||||||
if port is None:
|
|
||||||
self.port = 443
|
|
||||||
else:
|
|
||||||
port = int(port)
|
|
||||||
if port <= 0 or port > 65535:
|
|
||||||
raise ValueError("port must be between 0 and 65535")
|
|
||||||
self.port = port
|
|
||||||
|
|
||||||
if scheme is None:
|
|
||||||
self.scheme = "https"
|
|
||||||
else:
|
|
||||||
if scheme.strip().lower() not in ["http", "https"]:
|
|
||||||
raise ValueError("scheme must be either http or https")
|
|
||||||
self.scheme = scheme
|
|
||||||
|
|
||||||
if image_generation_endpoint and image_generation_backend not in [
|
|
||||||
"openai",
|
|
||||||
"sdwui",
|
|
||||||
"localai",
|
|
||||||
None,
|
|
||||||
]:
|
|
||||||
logger.error("image_generation_backend must be openai or sdwui or localai")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if image_format not in ["jpeg", "png", None]:
|
|
||||||
logger.error("image_format should be jpeg or png, leave blank for jpeg")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# @chatgpt
|
|
||||||
if username is None:
|
|
||||||
raise ValueError("username must be provided")
|
|
||||||
else:
|
|
||||||
self.username = username
|
|
||||||
|
|
||||||
self.openai_api_key: str = openai_api_key
|
|
||||||
self.gpt_api_endpoint = (
|
|
||||||
gpt_api_endpoint or "https://api.openai.com/v1/chat/completions"
|
|
||||||
)
|
|
||||||
self.gpt_model: str = gpt_model or "gpt-3.5-turbo"
|
|
||||||
self.max_tokens: int = max_tokens or 4000
|
|
||||||
self.top_p: float = top_p or 1.0
|
|
||||||
self.temperature: float = temperature or 0.8
|
|
||||||
self.presence_penalty: float = presence_penalty or 0.0
|
|
||||||
self.frequency_penalty: float = frequency_penalty or 0.0
|
|
||||||
self.reply_count: int = reply_count or 1
|
|
||||||
self.system_prompt: str = (
|
|
||||||
system_prompt
|
|
||||||
or "You are ChatGPT, \
|
|
||||||
a large language model trained by OpenAI. Respond conversationally"
|
|
||||||
)
|
|
||||||
self.image_generation_endpoint: str = image_generation_endpoint
|
|
||||||
self.image_generation_backend: str = image_generation_backend
|
|
||||||
|
|
||||||
if image_format:
|
|
||||||
self.image_format: str = image_format
|
|
||||||
else:
|
|
||||||
self.image_format = "jpeg"
|
|
||||||
|
|
||||||
if image_generation_size is None:
|
|
||||||
self.image_generation_size = "512x512"
|
|
||||||
self.image_generation_width = 512
|
|
||||||
self.image_generation_height = 512
|
|
||||||
else:
|
|
||||||
self.image_generation_size = image_generation_size
|
|
||||||
self.image_generation_width = self.image_generation_size.split("x")[0]
|
|
||||||
self.image_generation_height = self.image_generation_size.split("x")[1]
|
|
||||||
|
|
||||||
self.sdwui_steps = sdwui_steps
|
|
||||||
self.sdwui_sampler_name = sdwui_sampler_name
|
|
||||||
self.sdwui_cfg_scale = sdwui_cfg_scale
|
|
||||||
|
|
||||||
self.timeout = timeout or 120.0
|
|
||||||
|
|
||||||
self.bot_id = None
|
|
||||||
|
|
||||||
self.base_path = Path(os.path.dirname(__file__)).parent
|
|
||||||
|
|
||||||
if not os.path.exists(self.base_path / "images"):
|
|
||||||
os.mkdir(self.base_path / "images")
|
|
||||||
|
|
||||||
# httpx session
|
|
||||||
self.httpx_client = httpx.AsyncClient()
|
|
||||||
|
|
||||||
# initialize Chatbot object
|
|
||||||
self.chatbot = Chatbot(
|
|
||||||
aclient=self.httpx_client,
|
|
||||||
api_key=self.openai_api_key,
|
|
||||||
api_url=self.gpt_api_endpoint,
|
|
||||||
engine=self.gpt_model,
|
|
||||||
timeout=self.timeout,
|
|
||||||
max_tokens=self.max_tokens,
|
|
||||||
top_p=self.top_p,
|
|
||||||
presence_penalty=self.presence_penalty,
|
|
||||||
frequency_penalty=self.frequency_penalty,
|
|
||||||
reply_count=self.reply_count,
|
|
||||||
system_prompt=self.system_prompt,
|
|
||||||
temperature=self.temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
# login relative info
|
|
||||||
if email is None and password is None:
|
|
||||||
raise ValueError("user email and password must be provided")
|
|
||||||
|
|
||||||
self.driver = AsyncDriver(
|
|
||||||
{
|
|
||||||
"login_id": email,
|
|
||||||
"password": password,
|
|
||||||
"url": server_url,
|
|
||||||
"port": self.port,
|
|
||||||
"request_timeout": self.timeout,
|
|
||||||
"scheme": self.scheme,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# regular expression to match keyword
|
|
||||||
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
|
||||||
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
|
||||||
self.pic_prog = re.compile(r"^\s*!pic\s*(.+)$")
|
|
||||||
self.help_prog = re.compile(r"^\s*!help\s*.*$")
|
|
||||||
self.new_prog = re.compile(r"^\s*!new\s*.*$")
|
|
||||||
|
|
||||||
# close session
|
|
||||||
async def close(self, task: asyncio.Task) -> None:
|
|
||||||
await self.httpx_client.aclose()
|
|
||||||
self.driver.disconnect()
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
async def login(self) -> None:
|
|
||||||
await self.driver.login()
|
|
||||||
# get user id
|
|
||||||
resp = await self.driver.users.get_user(user_id="me")
|
|
||||||
self.bot_id = resp["id"]
|
|
||||||
|
|
||||||
async def run(self) -> None:
|
|
||||||
await self.driver.init_websocket(self.websocket_handler)
|
|
||||||
|
|
||||||
# websocket handler
|
|
||||||
async def websocket_handler(self, message) -> None:
|
|
||||||
logger.info(message)
|
|
||||||
response = json.loads(message)
|
|
||||||
if "event" in response:
|
|
||||||
event_type = response["event"]
|
|
||||||
if event_type == "posted":
|
|
||||||
raw_data = response["data"]["post"]
|
|
||||||
raw_data_dict = json.loads(raw_data)
|
|
||||||
user_id = raw_data_dict["user_id"]
|
|
||||||
root_id = (
|
|
||||||
raw_data_dict["root_id"]
|
|
||||||
if raw_data_dict["root_id"]
|
|
||||||
else raw_data_dict["id"]
|
|
||||||
)
|
|
||||||
channel_id = raw_data_dict["channel_id"]
|
|
||||||
sender_name = response["data"]["sender_name"]
|
|
||||||
raw_message = raw_data_dict["message"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
asyncio.create_task(
|
|
||||||
self.message_callback(
|
|
||||||
raw_message, channel_id, user_id, sender_name, root_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
await self.send_message(channel_id, f"{e}", root_id)
|
|
||||||
|
|
||||||
# message callback
|
|
||||||
async def message_callback(
|
|
||||||
self,
|
|
||||||
raw_message: str,
|
|
||||||
channel_id: str,
|
|
||||||
user_id: str,
|
|
||||||
sender_name: str,
|
|
||||||
root_id: str,
|
|
||||||
) -> None:
|
|
||||||
# prevent command trigger loop
|
|
||||||
if sender_name != self.username:
|
|
||||||
message = raw_message
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.openai_api_key is not None
|
|
||||||
or self.gpt_api_endpoint != "https://api.openai.com/v1/chat/completions"
|
|
||||||
):
|
|
||||||
# !gpt command trigger handler
|
|
||||||
if self.gpt_prog.match(message):
|
|
||||||
prompt = self.gpt_prog.match(message).group(1)
|
|
||||||
try:
|
|
||||||
# sending typing state
|
|
||||||
await self.driver.users.publish_user_typing(
|
|
||||||
self.bot_id,
|
|
||||||
options={
|
|
||||||
"channel_id": channel_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response = await self.chatbot.oneTimeAsk(prompt)
|
|
||||||
await self.send_message(channel_id, f"{response}", root_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise Exception(e)
|
|
||||||
|
|
||||||
# !chat command trigger handler
|
|
||||||
elif self.chat_prog.match(message):
|
|
||||||
prompt = self.chat_prog.match(message).group(1)
|
|
||||||
try:
|
|
||||||
# sending typing state
|
|
||||||
await self.driver.users.publish_user_typing(
|
|
||||||
self.bot_id,
|
|
||||||
options={
|
|
||||||
"channel_id": channel_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response = await self.chatbot.ask_async_v2(
|
|
||||||
prompt=prompt, convo_id=user_id
|
|
||||||
)
|
|
||||||
await self.send_message(channel_id, f"{response}", root_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise Exception(e)
|
|
||||||
|
|
||||||
# !new command trigger handler
|
|
||||||
if self.new_prog.match(message):
|
|
||||||
self.chatbot.reset(convo_id=user_id)
|
|
||||||
try:
|
|
||||||
await self.send_message(
|
|
||||||
channel_id,
|
|
||||||
"New conversation created, "
|
|
||||||
+ "please use !chat to start chatting!",
|
|
||||||
root_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise Exception(e)
|
|
||||||
|
|
||||||
# !pic command trigger handler
|
|
||||||
if self.image_generation_endpoint and self.image_generation_backend:
|
|
||||||
if self.pic_prog.match(message):
|
|
||||||
prompt = self.pic_prog.match(message).group(1)
|
|
||||||
# generate image
|
|
||||||
try:
|
|
||||||
# sending typing state
|
|
||||||
await self.driver.users.publish_user_typing(
|
|
||||||
self.bot_id,
|
|
||||||
options={
|
|
||||||
"channel_id": channel_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
image_path_list = await imagegen.get_images(
|
|
||||||
self.httpx_client,
|
|
||||||
self.image_generation_endpoint,
|
|
||||||
prompt,
|
|
||||||
self.image_generation_backend,
|
|
||||||
timeount=self.timeout,
|
|
||||||
api_key=self.openai_api_key,
|
|
||||||
output_path=self.base_path / "images",
|
|
||||||
n=1,
|
|
||||||
size=self.image_generation_size,
|
|
||||||
width=self.image_generation_width,
|
|
||||||
height=self.image_generation_height,
|
|
||||||
steps=self.sdwui_steps,
|
|
||||||
sampler_name=self.sdwui_sampler_name,
|
|
||||||
cfg_scale=self.sdwui_cfg_scale,
|
|
||||||
image_format=self.image_format,
|
|
||||||
)
|
|
||||||
# send image
|
|
||||||
for image_path in image_path_list:
|
|
||||||
await self.send_file(
|
|
||||||
channel_id,
|
|
||||||
f"{prompt}",
|
|
||||||
image_path,
|
|
||||||
root_id,
|
|
||||||
)
|
|
||||||
await aiofiles.os.remove(image_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise Exception(e)
|
|
||||||
|
|
||||||
# !help command trigger handler
|
|
||||||
if self.help_prog.match(message):
|
|
||||||
try:
|
|
||||||
await self.send_message(channel_id, self.help(), root_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
|
|
||||||
# send message to room
|
|
||||||
async def send_message(self, channel_id: str, message: str, root_id: str) -> None:
|
|
||||||
await self.driver.posts.create_post(
|
|
||||||
options={
|
|
||||||
"channel_id": channel_id,
|
|
||||||
"message": message,
|
|
||||||
"root_id": root_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# send file to room
|
|
||||||
async def send_file(
|
|
||||||
self, channel_id: str, message: str, filepath: str, root_id: str
|
|
||||||
) -> None:
|
|
||||||
filename = os.path.split(filepath)[-1]
|
|
||||||
try:
|
|
||||||
file_id = await self.driver.files.upload_file(
|
|
||||||
channel_id=channel_id,
|
|
||||||
files={
|
|
||||||
"files": (filename, open(filepath, "rb")),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
file_id = file_id["file_infos"][0]["id"]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise Exception(e)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.driver.posts.create_post(
|
|
||||||
options={
|
|
||||||
"channel_id": channel_id,
|
|
||||||
"message": message,
|
|
||||||
"file_ids": [file_id],
|
|
||||||
"root_id": root_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
raise Exception(e)
|
|
||||||
|
|
||||||
# !help command function
|
|
||||||
def help(self) -> str:
|
|
||||||
help_info = (
|
|
||||||
"!gpt [content], generate response without context conversation\n"
|
|
||||||
+ "!chat [content], chat with context conversation\n"
|
|
||||||
+ "!pic [prompt], Image generation with DALL·E or LocalAI or stable-diffusion-webui\n" # noqa: E501
|
|
||||||
+ "!new, start a new conversation\n"
|
|
||||||
+ "!help, help message"
|
|
||||||
)
|
|
||||||
return help_info
|
|
106
src/imagegen.py
106
src/imagegen.py
|
@ -1,106 +0,0 @@
|
||||||
import httpx
|
|
||||||
from pathlib import Path
|
|
||||||
import uuid
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
async def get_images(
|
|
||||||
aclient: httpx.AsyncClient,
|
|
||||||
url: str,
|
|
||||||
prompt: str,
|
|
||||||
backend_type: str,
|
|
||||||
output_path: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> list[str]:
|
|
||||||
timeout = kwargs.get("timeout", 180.0)
|
|
||||||
if backend_type == "openai":
|
|
||||||
resp = await aclient.post(
|
|
||||||
url,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {kwargs.get('api_key')}",
|
|
||||||
},
|
|
||||||
json={
|
|
||||||
"prompt": prompt,
|
|
||||||
"n": kwargs.get("n", 1),
|
|
||||||
"size": kwargs.get("size", "512x512"),
|
|
||||||
"response_format": "b64_json",
|
|
||||||
},
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
b64_datas = []
|
|
||||||
for data in resp.json()["data"]:
|
|
||||||
b64_datas.append(data["b64_json"])
|
|
||||||
return save_images_b64(b64_datas, output_path, **kwargs)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
|
||||||
)
|
|
||||||
elif backend_type == "sdwui":
|
|
||||||
resp = await aclient.post(
|
|
||||||
url,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
json={
|
|
||||||
"prompt": prompt,
|
|
||||||
"sampler_name": kwargs.get("sampler_name", "Euler a"),
|
|
||||||
"cfg_scale": kwargs.get("cfg_scale", 7),
|
|
||||||
"batch_size": kwargs.get("n", 1),
|
|
||||||
"steps": kwargs.get("steps", 20),
|
|
||||||
"width": kwargs.get("width", 512),
|
|
||||||
"height": kwargs.get("height", 512),
|
|
||||||
},
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
b64_datas = resp.json()["images"]
|
|
||||||
return save_images_b64(b64_datas, output_path, **kwargs)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
|
||||||
)
|
|
||||||
elif backend_type == "localai":
|
|
||||||
resp = await aclient.post(
|
|
||||||
url,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {kwargs.get('api_key')}",
|
|
||||||
},
|
|
||||||
json={
|
|
||||||
"prompt": prompt,
|
|
||||||
"size": kwargs.get("size", "512x512"),
|
|
||||||
},
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
if resp.status_code == 200:
|
|
||||||
image_url = resp.json()["data"][0]["url"]
|
|
||||||
return await save_image_url(image_url, aclient, output_path, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def save_images_b64(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
|
|
||||||
images_path_list = []
|
|
||||||
for b64_data in b64_datas:
|
|
||||||
image_path = path / (
|
|
||||||
str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg")
|
|
||||||
)
|
|
||||||
img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
|
|
||||||
img.save(image_path)
|
|
||||||
images_path_list.append(image_path)
|
|
||||||
return images_path_list
|
|
||||||
|
|
||||||
|
|
||||||
async def save_image_url(
|
|
||||||
url: str, aclient: httpx.AsyncClient, path: Path, **kwargs
|
|
||||||
) -> list[str]:
|
|
||||||
images_path_list = []
|
|
||||||
r = await aclient.get(url)
|
|
||||||
image_path = path / (str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg"))
|
|
||||||
if r.status_code == 200:
|
|
||||||
img = Image.open(io.BytesIO(r.content))
|
|
||||||
img.save(image_path)
|
|
||||||
images_path_list.append(image_path)
|
|
||||||
return images_path_list
|
|
97
src/main.py
97
src/main.py
|
@ -1,97 +0,0 @@
|
||||||
import signal
|
|
||||||
from bot import Bot
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import asyncio
|
|
||||||
from pathlib import Path
|
|
||||||
from log import getlogger
|
|
||||||
|
|
||||||
logger = getlogger()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
config_path = Path(os.path.dirname(__file__)).parent / "config.json"
|
|
||||||
if os.path.isfile(config_path):
|
|
||||||
fp = open("config.json", "r", encoding="utf-8")
|
|
||||||
try:
|
|
||||||
config = json.load(fp)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e, exc_info=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
mattermost_bot = Bot(
|
|
||||||
server_url=config.get("server_url"),
|
|
||||||
email=config.get("email"),
|
|
||||||
password=config.get("password"),
|
|
||||||
username=config.get("username"),
|
|
||||||
port=config.get("port"),
|
|
||||||
scheme=config.get("scheme"),
|
|
||||||
openai_api_key=config.get("openai_api_key"),
|
|
||||||
gpt_api_endpoint=config.get("gpt_api_endpoint"),
|
|
||||||
gpt_model=config.get("gpt_model"),
|
|
||||||
max_tokens=config.get("max_tokens"),
|
|
||||||
top_p=config.get("top_p"),
|
|
||||||
presence_penalty=config.get("presence_penalty"),
|
|
||||||
frequency_penalty=config.get("frequency_penalty"),
|
|
||||||
reply_count=config.get("reply_count"),
|
|
||||||
system_prompt=config.get("system_prompt"),
|
|
||||||
temperature=config.get("temperature"),
|
|
||||||
image_generation_endpoint=config.get("image_generation_endpoint"),
|
|
||||||
image_generation_backend=config.get("image_generation_backend"),
|
|
||||||
image_generation_size=config.get("image_generation_size"),
|
|
||||||
sdwui_steps=config.get("sdwui_steps"),
|
|
||||||
sdwui_sampler_name=config.get("sdwui_sampler_name"),
|
|
||||||
sdwui_cfg_scale=config.get("sdwui_cfg_scale"),
|
|
||||||
image_format=config.get("image_format"),
|
|
||||||
timeout=config.get("timeout"),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
mattermost_bot = Bot(
|
|
||||||
server_url=os.environ.get("SERVER_URL"),
|
|
||||||
email=os.environ.get("EMAIL"),
|
|
||||||
password=os.environ.get("PASSWORD"),
|
|
||||||
username=os.environ.get("USERNAME"),
|
|
||||||
port=int(os.environ.get("PORT", 443)),
|
|
||||||
scheme=os.environ.get("SCHEME"),
|
|
||||||
openai_api_key=os.environ.get("OPENAI_API_KEY"),
|
|
||||||
gpt_api_endpoint=os.environ.get("GPT_API_ENDPOINT"),
|
|
||||||
gpt_model=os.environ.get("GPT_MODEL"),
|
|
||||||
max_tokens=int(os.environ.get("MAX_TOKENS", 4000)),
|
|
||||||
top_p=float(os.environ.get("TOP_P", 1.0)),
|
|
||||||
presence_penalty=float(os.environ.get("PRESENCE_PENALTY", 0.0)),
|
|
||||||
frequency_penalty=float(os.environ.get("FREQUENCY_PENALTY", 0.0)),
|
|
||||||
reply_count=int(os.environ.get("REPLY_COUNT", 1)),
|
|
||||||
system_prompt=os.environ.get("SYSTEM_PROMPT"),
|
|
||||||
temperature=float(os.environ.get("TEMPERATURE", 0.8)),
|
|
||||||
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
|
|
||||||
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
|
|
||||||
image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"),
|
|
||||||
sdwui_steps=int(os.environ.get("SDWUI_STEPS", 20)),
|
|
||||||
sdwui_sampler_name=os.environ.get("SDWUI_SAMPLER_NAME"),
|
|
||||||
sdwui_cfg_scale=float(os.environ.get("SDWUI_CFG_SCALE", 7)),
|
|
||||||
image_format=os.environ.get("IMAGE_FORMAT"),
|
|
||||||
timeout=float(os.environ.get("TIMEOUT", 120.0)),
|
|
||||||
)
|
|
||||||
|
|
||||||
await mattermost_bot.login()
|
|
||||||
|
|
||||||
task = asyncio.create_task(mattermost_bot.run())
|
|
||||||
|
|
||||||
# handle signal interrupt
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
for signame in ("SIGINT", "SIGTERM"):
|
|
||||||
loop.add_signal_handler(
|
|
||||||
getattr(signal, signame),
|
|
||||||
lambda: asyncio.create_task(mattermost_bot.close(task)),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("Bot stopped")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
|
@ -1,26 +1,15 @@
|
||||||
"""
|
"""
|
||||||
Code derived from https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
|
Code derived from: https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
|
||||||
A simple wrapper for the official ChatGPT API
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
ENGINES = [
|
|
||||||
"gpt-3.5-turbo",
|
|
||||||
"gpt-3.5-turbo-16k",
|
|
||||||
"gpt-3.5-turbo-0613",
|
|
||||||
"gpt-3.5-turbo-16k-0613",
|
|
||||||
"gpt-4",
|
|
||||||
"gpt-4-32k",
|
|
||||||
"gpt-4-0613",
|
|
||||||
"gpt-4-32k-0613",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Chatbot:
|
class Chatbot:
|
||||||
"""
|
"""
|
||||||
Official ChatGPT API
|
Official ChatGPT API
|
||||||
|
@ -28,48 +17,29 @@ class Chatbot:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
aclient: httpx.AsyncClient,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_url: str = None,
|
engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
|
||||||
engine: str = None,
|
proxy: str = None,
|
||||||
timeout: float = None,
|
timeout: float = None,
|
||||||
max_tokens: int = None,
|
max_tokens: int = None,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.5,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
reply_count: int = 1,
|
reply_count: int = 1,
|
||||||
truncate_limit: int = None,
|
system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||||
system_prompt: str = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
||||||
"""
|
"""
|
||||||
self.engine: str = engine or "gpt-3.5-turbo"
|
self.engine: str = engine
|
||||||
self.api_key: str = api_key
|
self.api_key: str = api_key
|
||||||
self.api_url: str = api_url or "https://api.openai.com/v1/chat/completions"
|
self.system_prompt: str = system_prompt
|
||||||
self.system_prompt: str = (
|
|
||||||
system_prompt
|
|
||||||
or "You are ChatGPT, \
|
|
||||||
a large language model trained by OpenAI. Respond conversationally"
|
|
||||||
)
|
|
||||||
self.max_tokens: int = max_tokens or (
|
self.max_tokens: int = max_tokens or (
|
||||||
31000
|
31000 if engine == "gpt-4-32k" else 7000 if engine == "gpt-4" else 4000
|
||||||
if "gpt-4-32k" in engine
|
|
||||||
else 7000
|
|
||||||
if "gpt-4" in engine
|
|
||||||
else 15000
|
|
||||||
if "gpt-3.5-turbo-16k" in engine
|
|
||||||
else 4000
|
|
||||||
)
|
)
|
||||||
self.truncate_limit: int = truncate_limit or (
|
self.truncate_limit: int = (
|
||||||
30500
|
30500 if engine == "gpt-4-32k" else 6500 if engine == "gpt-4" else 3500
|
||||||
if "gpt-4-32k" in engine
|
|
||||||
else 6500
|
|
||||||
if "gpt-4" in engine
|
|
||||||
else 14500
|
|
||||||
if "gpt-3.5-turbo-16k" in engine
|
|
||||||
else 3500
|
|
||||||
)
|
)
|
||||||
self.temperature: float = temperature
|
self.temperature: float = temperature
|
||||||
self.top_p: float = top_p
|
self.top_p: float = top_p
|
||||||
|
@ -77,8 +47,31 @@ class Chatbot:
|
||||||
self.frequency_penalty: float = frequency_penalty
|
self.frequency_penalty: float = frequency_penalty
|
||||||
self.reply_count: int = reply_count
|
self.reply_count: int = reply_count
|
||||||
self.timeout: float = timeout
|
self.timeout: float = timeout
|
||||||
|
self.proxy = proxy
|
||||||
|
self.session = requests.Session()
|
||||||
|
self.session.proxies.update(
|
||||||
|
{
|
||||||
|
"http": proxy,
|
||||||
|
"https": proxy,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
proxy = (
|
||||||
|
proxy or os.environ.get("all_proxy") or os.environ.get("ALL_PROXY") or None
|
||||||
|
)
|
||||||
|
|
||||||
self.aclient = aclient
|
if proxy:
|
||||||
|
if "socks5h" not in proxy:
|
||||||
|
self.aclient = httpx.AsyncClient(
|
||||||
|
follow_redirects=True,
|
||||||
|
proxies=proxy,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.aclient = httpx.AsyncClient(
|
||||||
|
follow_redirects=True,
|
||||||
|
proxies=proxy,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
self.conversation: dict[str, list[dict]] = {
|
self.conversation: dict[str, list[dict]] = {
|
||||||
"default": [
|
"default": [
|
||||||
|
@ -89,9 +82,6 @@ class Chatbot:
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.get_token_count("default") > self.max_tokens:
|
|
||||||
raise Exception("System prompt is too long")
|
|
||||||
|
|
||||||
def add_to_conversation(
|
def add_to_conversation(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
|
@ -117,25 +107,29 @@ class Chatbot:
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
def get_token_count(self, convo_id: str = "default") -> int:
|
def get_token_count(self, convo_id: str = "default") -> int:
|
||||||
"""
|
"""
|
||||||
Get token count
|
Get token count
|
||||||
"""
|
"""
|
||||||
_engine = self.engine
|
if self.engine not in [
|
||||||
if self.engine not in ENGINES:
|
"gpt-3.5-turbo",
|
||||||
# use gpt-3.5-turbo to caculate token
|
"gpt-3.5-turbo-0301",
|
||||||
_engine = "gpt-3.5-turbo"
|
"gpt-4",
|
||||||
|
"gpt-4-0314",
|
||||||
|
"gpt-4-32k",
|
||||||
|
"gpt-4-32k-0314",
|
||||||
|
]:
|
||||||
|
raise NotImplementedError("Unsupported engine {self.engine}")
|
||||||
|
|
||||||
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
|
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
|
||||||
|
|
||||||
encoding = tiktoken.encoding_for_model(_engine)
|
encoding = tiktoken.encoding_for_model(self.engine)
|
||||||
|
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
for message in self.conversation[convo_id]:
|
for message in self.conversation[convo_id]:
|
||||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||||
num_tokens += 5
|
num_tokens += 5
|
||||||
for key, value in message.items():
|
for key, value in message.items():
|
||||||
if value:
|
|
||||||
num_tokens += len(encoding.encode(value))
|
num_tokens += len(encoding.encode(value))
|
||||||
if key == "name": # if there's a name, the role is omitted
|
if key == "name": # if there's a name, the role is omitted
|
||||||
num_tokens += 5 # role is always required and always 1 token
|
num_tokens += 5 # role is always required and always 1 token
|
||||||
|
@ -148,15 +142,13 @@ class Chatbot:
|
||||||
"""
|
"""
|
||||||
return self.max_tokens - self.get_token_count(convo_id)
|
return self.max_tokens - self.get_token_count(convo_id)
|
||||||
|
|
||||||
async def ask_stream_async(
|
def ask_stream(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
role: str = "user",
|
role: str = "user",
|
||||||
convo_id: str = "default",
|
convo_id: str = "default",
|
||||||
model: str = None,
|
|
||||||
pass_history: bool = True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
):
|
||||||
"""
|
"""
|
||||||
Ask a question
|
Ask a question
|
||||||
"""
|
"""
|
||||||
|
@ -166,13 +158,12 @@ class Chatbot:
|
||||||
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
||||||
self.__truncate_conversation(convo_id=convo_id)
|
self.__truncate_conversation(convo_id=convo_id)
|
||||||
# Get response
|
# Get response
|
||||||
async with self.aclient.stream(
|
response = self.session.post(
|
||||||
"post",
|
os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
|
||||||
self.api_url,
|
|
||||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||||
json={
|
json={
|
||||||
"model": model or self.engine,
|
"model": self.engine,
|
||||||
"messages": self.conversation[convo_id] if pass_history else [prompt],
|
"messages": self.conversation[convo_id],
|
||||||
"stream": True,
|
"stream": True,
|
||||||
# kwargs
|
# kwargs
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
|
@ -187,18 +178,79 @@ class Chatbot:
|
||||||
),
|
),
|
||||||
"n": kwargs.get("n", self.reply_count),
|
"n": kwargs.get("n", self.reply_count),
|
||||||
"user": role,
|
"user": role,
|
||||||
"max_tokens": min(
|
"max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||||
self.get_max_tokens(convo_id=convo_id),
|
},
|
||||||
kwargs.get("max_tokens", self.max_tokens),
|
timeout=kwargs.get("timeout", self.timeout),
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_role: str = None
|
||||||
|
full_response: str = ""
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
# Remove "data: "
|
||||||
|
line = line.decode("utf-8")[6:]
|
||||||
|
if line == "[DONE]":
|
||||||
|
break
|
||||||
|
resp: dict = json.loads(line)
|
||||||
|
choices = resp.get("choices")
|
||||||
|
if not choices:
|
||||||
|
continue
|
||||||
|
delta = choices[0].get("delta")
|
||||||
|
if not delta:
|
||||||
|
continue
|
||||||
|
if "role" in delta:
|
||||||
|
response_role = delta["role"]
|
||||||
|
if "content" in delta:
|
||||||
|
content = delta["content"]
|
||||||
|
full_response += content
|
||||||
|
yield content
|
||||||
|
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
|
||||||
|
|
||||||
|
async def ask_stream_async(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
role: str = "user",
|
||||||
|
convo_id: str = "default",
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""
|
||||||
|
Ask a question
|
||||||
|
"""
|
||||||
|
# Make conversation if it doesn't exist
|
||||||
|
if convo_id not in self.conversation:
|
||||||
|
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
|
||||||
|
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
||||||
|
self.__truncate_conversation(convo_id=convo_id)
|
||||||
|
# Get response
|
||||||
|
async with self.aclient.stream(
|
||||||
|
"post",
|
||||||
|
os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
|
||||||
|
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
||||||
|
json={
|
||||||
|
"model": self.engine,
|
||||||
|
"messages": self.conversation[convo_id],
|
||||||
|
"stream": True,
|
||||||
|
# kwargs
|
||||||
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
|
"top_p": kwargs.get("top_p", self.top_p),
|
||||||
|
"presence_penalty": kwargs.get(
|
||||||
|
"presence_penalty",
|
||||||
|
self.presence_penalty,
|
||||||
),
|
),
|
||||||
|
"frequency_penalty": kwargs.get(
|
||||||
|
"frequency_penalty",
|
||||||
|
self.frequency_penalty,
|
||||||
|
),
|
||||||
|
"n": kwargs.get("n", self.reply_count),
|
||||||
|
"user": role,
|
||||||
|
"max_tokens": self.get_max_tokens(convo_id=convo_id),
|
||||||
},
|
},
|
||||||
timeout=kwargs.get("timeout", self.timeout),
|
timeout=kwargs.get("timeout", self.timeout),
|
||||||
) as response:
|
) as response:
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
await response.aread()
|
await response.aread()
|
||||||
raise Exception(
|
|
||||||
f"{response.status_code} {response.reason_phrase} {response.text}",
|
|
||||||
)
|
|
||||||
|
|
||||||
response_role: str = ""
|
response_role: str = ""
|
||||||
full_response: str = ""
|
full_response: str = ""
|
||||||
|
@ -211,8 +263,6 @@ class Chatbot:
|
||||||
if line == "[DONE]":
|
if line == "[DONE]":
|
||||||
break
|
break
|
||||||
resp: dict = json.loads(line)
|
resp: dict = json.loads(line)
|
||||||
if "error" in resp:
|
|
||||||
raise Exception(f"{resp['error']}")
|
|
||||||
choices = resp.get("choices")
|
choices = resp.get("choices")
|
||||||
if not choices:
|
if not choices:
|
||||||
continue
|
continue
|
||||||
|
@ -232,8 +282,6 @@ class Chatbot:
|
||||||
prompt: str,
|
prompt: str,
|
||||||
role: str = "user",
|
role: str = "user",
|
||||||
convo_id: str = "default",
|
convo_id: str = "default",
|
||||||
model: str = None,
|
|
||||||
pass_history: bool = True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -243,59 +291,28 @@ class Chatbot:
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
role=role,
|
role=role,
|
||||||
convo_id=convo_id,
|
convo_id=convo_id,
|
||||||
model=model,
|
|
||||||
pass_history=pass_history,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
full_response: str = "".join([r async for r in response])
|
full_response: str = "".join([r async for r in response])
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
async def ask_async_v2(
|
def ask(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
role: str = "user",
|
role: str = "user",
|
||||||
convo_id: str = "default",
|
convo_id: str = "default",
|
||||||
model: str = None,
|
|
||||||
pass_history: bool = True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Make conversation if it doesn't exist
|
"""
|
||||||
if convo_id not in self.conversation:
|
Non-streaming ask
|
||||||
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
|
"""
|
||||||
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
response = self.ask_stream(
|
||||||
self.__truncate_conversation(convo_id=convo_id)
|
prompt=prompt,
|
||||||
# Get response
|
role=role,
|
||||||
response = await self.aclient.post(
|
convo_id=convo_id,
|
||||||
url=self.api_url,
|
**kwargs,
|
||||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
|
||||||
json={
|
|
||||||
"model": model or self.engine,
|
|
||||||
"messages": self.conversation[convo_id] if pass_history else [prompt],
|
|
||||||
# kwargs
|
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
|
||||||
"top_p": kwargs.get("top_p", self.top_p),
|
|
||||||
"presence_penalty": kwargs.get(
|
|
||||||
"presence_penalty",
|
|
||||||
self.presence_penalty,
|
|
||||||
),
|
|
||||||
"frequency_penalty": kwargs.get(
|
|
||||||
"frequency_penalty",
|
|
||||||
self.frequency_penalty,
|
|
||||||
),
|
|
||||||
"n": kwargs.get("n", self.reply_count),
|
|
||||||
"user": role,
|
|
||||||
"max_tokens": min(
|
|
||||||
self.get_max_tokens(convo_id=convo_id),
|
|
||||||
kwargs.get("max_tokens", self.max_tokens),
|
|
||||||
),
|
|
||||||
},
|
|
||||||
timeout=kwargs.get("timeout", self.timeout),
|
|
||||||
)
|
|
||||||
resp = response.json()
|
|
||||||
full_response = resp["choices"][0]["message"]["content"]
|
|
||||||
self.add_to_conversation(
|
|
||||||
full_response, resp["choices"][0]["message"]["role"], convo_id=convo_id
|
|
||||||
)
|
)
|
||||||
|
full_response: str = "".join(response)
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
|
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
|
||||||
|
@ -305,40 +322,3 @@ class Chatbot:
|
||||||
self.conversation[convo_id] = [
|
self.conversation[convo_id] = [
|
||||||
{"role": "system", "content": system_prompt or self.system_prompt},
|
{"role": "system", "content": system_prompt or self.system_prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
@retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(3))
|
|
||||||
async def oneTimeAsk(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
role: str = "user",
|
|
||||||
model: str = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
response = await self.aclient.post(
|
|
||||||
url=self.api_url,
|
|
||||||
json={
|
|
||||||
"model": model or self.engine,
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": role,
|
|
||||||
"content": prompt,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
# kwargs
|
|
||||||
"temperature": kwargs.get("temperature", self.temperature),
|
|
||||||
"top_p": kwargs.get("top_p", self.top_p),
|
|
||||||
"presence_penalty": kwargs.get(
|
|
||||||
"presence_penalty",
|
|
||||||
self.presence_penalty,
|
|
||||||
),
|
|
||||||
"frequency_penalty": kwargs.get(
|
|
||||||
"frequency_penalty",
|
|
||||||
self.frequency_penalty,
|
|
||||||
),
|
|
||||||
"user": role,
|
|
||||||
},
|
|
||||||
headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
|
|
||||||
timeout=kwargs.get("timeout", self.timeout),
|
|
||||||
)
|
|
||||||
resp = response.json()
|
|
||||||
return resp["choices"][0]["message"]["content"]
|
|
Loading…
Reference in a new issue