refactor code structure and remove unused
This commit is contained in:
parent
d0b93b454d
commit
2ead99a06b
20 changed files with 1004 additions and 1515 deletions
165
src/BingImageGen.py
Normal file
165
src/BingImageGen.py
Normal file
|
@ -0,0 +1,165 @@
|
|||
"""
|
||||
Code derived from:
|
||||
https://github.com/acheong08/EdgeGPT/blob/f940cecd24a4818015a8b42a2443dd97c3c2a8f4/src/ImageGen.py
|
||||
"""
|
||||
from log import getlogger
|
||||
from uuid import uuid4
|
||||
import os
|
||||
import contextlib
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import random
|
||||
import requests
|
||||
import regex
|
||||
|
||||
logger = getlogger()
|
||||
|
||||
BING_URL = "https://www.bing.com"
|
||||
# Generate random IP between range 13.104.0.0/14
|
||||
FORWARDED_IP = (
|
||||
f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
|
||||
)
|
||||
HEADERS = {
|
||||
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||
"accept-language": "en-US,en;q=0.9",
|
||||
"cache-control": "max-age=0",
|
||||
"content-type": "application/x-www-form-urlencoded",
|
||||
"referrer": "https://www.bing.com/images/create/",
|
||||
"origin": "https://www.bing.com",
|
||||
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.63",
|
||||
"x-forwarded-for": FORWARDED_IP,
|
||||
}
|
||||
|
||||
|
||||
class ImageGenAsync:
|
||||
"""
|
||||
Image generation by Microsoft Bing
|
||||
Parameters:
|
||||
auth_cookie: str
|
||||
"""
|
||||
|
||||
def __init__(self, auth_cookie: str, quiet: bool = True) -> None:
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=HEADERS,
|
||||
cookies={"_U": auth_cookie},
|
||||
)
|
||||
self.quiet = quiet
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *excinfo) -> None:
|
||||
await self.session.close()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._close())
|
||||
|
||||
async def _close(self):
|
||||
await self.session.close()
|
||||
|
||||
async def get_images(self, prompt: str) -> list:
|
||||
"""
|
||||
Fetches image links from Bing
|
||||
Parameters:
|
||||
prompt: str
|
||||
"""
|
||||
if not self.quiet:
|
||||
print("Sending request...")
|
||||
url_encoded_prompt = requests.utils.quote(prompt)
|
||||
# https://www.bing.com/images/create?q=<PROMPT>&rt=3&FORM=GENCRE
|
||||
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
|
||||
async with self.session.post(url, allow_redirects=False) as response:
|
||||
content = await response.text()
|
||||
if "this prompt has been blocked" in content.lower():
|
||||
raise Exception(
|
||||
"Your prompt has been blocked by Bing. Try to change any bad words and try again.",
|
||||
)
|
||||
if response.status != 302:
|
||||
# if rt4 fails, try rt3
|
||||
url = (
|
||||
f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=3&FORM=GENCRE"
|
||||
)
|
||||
async with self.session.post(
|
||||
url,
|
||||
allow_redirects=False,
|
||||
timeout=200,
|
||||
) as response3:
|
||||
if response3.status != 302:
|
||||
print(f"ERROR: {response3.text}")
|
||||
raise Exception("Redirect failed")
|
||||
response = response3
|
||||
# Get redirect URL
|
||||
redirect_url = response.headers["Location"].replace("&nfy=1", "")
|
||||
request_id = redirect_url.split("id=")[-1]
|
||||
await self.session.get(f"{BING_URL}{redirect_url}")
|
||||
# https://www.bing.com/images/create/async/results/{ID}?q={PROMPT}
|
||||
polling_url = f"{BING_URL}/images/create/async/results/{request_id}?q={url_encoded_prompt}"
|
||||
# Poll for results
|
||||
if not self.quiet:
|
||||
print("Waiting for results...")
|
||||
while True:
|
||||
if not self.quiet:
|
||||
print(".", end="", flush=True)
|
||||
# By default, timeout is 300s, change as needed
|
||||
response = await self.session.get(polling_url)
|
||||
if response.status != 200:
|
||||
raise Exception("Could not get results")
|
||||
content = await response.text()
|
||||
if content and content.find("errorMessage") == -1:
|
||||
break
|
||||
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
# Use regex to search for src=""
|
||||
image_links = regex.findall(r'src="([^"]+)"', content)
|
||||
# Remove size limit
|
||||
normal_image_links = [link.split("?w=")[0] for link in image_links]
|
||||
# Remove duplicates
|
||||
normal_image_links = list(set(normal_image_links))
|
||||
|
||||
# Bad images
|
||||
bad_images = [
|
||||
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
|
||||
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
|
||||
]
|
||||
for im in normal_image_links:
|
||||
if im in bad_images:
|
||||
raise Exception("Bad images")
|
||||
# No images
|
||||
if not normal_image_links:
|
||||
raise Exception("No images")
|
||||
return normal_image_links
|
||||
|
||||
async def save_images(self, links: list, output_dir: str) -> str:
|
||||
"""
|
||||
Saves images to output directory
|
||||
"""
|
||||
if not self.quiet:
|
||||
print("\nDownloading images...")
|
||||
with contextlib.suppress(FileExistsError):
|
||||
os.mkdir(output_dir)
|
||||
|
||||
# image name
|
||||
image_name = str(uuid4())
|
||||
# we just need one image for better display in chat room
|
||||
if links:
|
||||
link = links.pop()
|
||||
|
||||
image_path = os.path.join(output_dir, f"{image_name}.jpeg")
|
||||
try:
|
||||
async with self.session.get(link, raise_for_status=True) as response:
|
||||
# save response to file
|
||||
with open(image_path, "wb") as output_file:
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
output_file.write(chunk)
|
||||
return f"{output_dir}/{image_name}.jpeg"
|
||||
|
||||
except aiohttp.client_exceptions.InvalidURL as url_exception:
|
||||
raise Exception(
|
||||
"Inappropriate contents found in the generated images. Please try again or try another prompt.",
|
||||
) from url_exception
|
46
src/askgpt.py
Normal file
46
src/askgpt.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
import aiohttp
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from log import getlogger
|
||||
|
||||
logger = getlogger()
|
||||
|
||||
|
||||
class askGPT:
|
||||
def __init__(
|
||||
self, session: aiohttp.ClientSession, headers: str
|
||||
) -> None:
|
||||
self.session = session
|
||||
self.api_endpoint = "https://api.openai.com/v1/chat/completions"
|
||||
self.headers = headers
|
||||
|
||||
async def oneTimeAsk(self, prompt: str) -> str:
|
||||
jsons = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
],
|
||||
}
|
||||
max_try = 2
|
||||
while max_try > 0:
|
||||
try:
|
||||
async with self.session.post(
|
||||
url=self.api_endpoint, json=jsons, headers=self.headers, timeout=120
|
||||
) as response:
|
||||
status_code = response.status
|
||||
if not status_code == 200:
|
||||
# print failed reason
|
||||
logger.warning(str(response.reason))
|
||||
max_try = max_try - 1
|
||||
# wait 2s
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
resp = await response.read()
|
||||
return json.loads(resp)["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
raise Exception(e)
|
410
src/bot.py
Normal file
410
src/bot.py
Normal file
|
@ -0,0 +1,410 @@
|
|||
from mattermostdriver import AsyncDriver
|
||||
from typing import Optional
|
||||
import json
|
||||
import asyncio
|
||||
import re
|
||||
import os
|
||||
import aiohttp
|
||||
from askgpt import askGPT
|
||||
from revChatGPT.V3 import Chatbot as GPTChatBot
|
||||
from BingImageGen import ImageGenAsync
|
||||
from log import getlogger
|
||||
from pandora import Pandora
|
||||
import uuid
|
||||
|
||||
logger = getlogger()
|
||||
|
||||
ENGINES = [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
]
|
||||
|
||||
|
||||
class Bot:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
username: str,
|
||||
access_token: Optional[str] = None,
|
||||
login_id: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
openai_api_key: Optional[str] = None,
|
||||
pandora_api_endpoint: Optional[str] = None,
|
||||
pandora_api_model: Optional[str] = None,
|
||||
bing_auth_cookie: Optional[str] = None,
|
||||
port: int = 443,
|
||||
scheme: str = "https",
|
||||
timeout: int = 30,
|
||||
gpt_engine: str = "gpt-3.5-turbo",
|
||||
) -> None:
|
||||
if server_url is None:
|
||||
raise ValueError("server url must be provided")
|
||||
|
||||
if port is None:
|
||||
self.port = 443
|
||||
else:
|
||||
if port < 0 or port > 65535:
|
||||
raise ValueError("port must be between 0 and 65535")
|
||||
self.port = port
|
||||
|
||||
if scheme is None:
|
||||
self.scheme = "https"
|
||||
else:
|
||||
if scheme.strip().lower() not in ["http", "https"]:
|
||||
raise ValueError("scheme must be either http or https")
|
||||
self.scheme = scheme
|
||||
|
||||
if timeout is None:
|
||||
self.timeout = 30
|
||||
else:
|
||||
self.timeout = timeout
|
||||
|
||||
if gpt_engine is None:
|
||||
self.gpt_engine = "gpt-3.5-turbo"
|
||||
else:
|
||||
if gpt_engine not in ENGINES:
|
||||
raise ValueError("gpt_engine must be one of {}".format(ENGINES))
|
||||
self.gpt_engine = gpt_engine
|
||||
|
||||
# login relative info
|
||||
if access_token is None and password is None:
|
||||
raise ValueError("Either token or password must be provided")
|
||||
|
||||
if access_token is not None:
|
||||
self.driver = AsyncDriver(
|
||||
{
|
||||
"token": access_token,
|
||||
"url": server_url,
|
||||
"port": self.port,
|
||||
"request_timeout": self.timeout,
|
||||
"scheme": self.scheme,
|
||||
}
|
||||
)
|
||||
else:
|
||||
self.driver = AsyncDriver(
|
||||
{
|
||||
"login_id": login_id,
|
||||
"password": password,
|
||||
"url": server_url,
|
||||
"port": self.port,
|
||||
"request_timeout": self.timeout,
|
||||
"scheme": self.scheme,
|
||||
}
|
||||
)
|
||||
|
||||
# @chatgpt
|
||||
if username is None:
|
||||
raise ValueError("username must be provided")
|
||||
else:
|
||||
self.username = username
|
||||
|
||||
# aiohttp session
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
# initialize chatGPT class
|
||||
self.openai_api_key = openai_api_key
|
||||
if openai_api_key is not None:
|
||||
# request header for !gpt command
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.openai_api_key}",
|
||||
}
|
||||
|
||||
self.askgpt = askGPT(
|
||||
self.session,
|
||||
self.headers,
|
||||
)
|
||||
|
||||
self.gptchatbot = GPTChatBot(
|
||||
api_key=self.openai_api_key, engine=self.gpt_engine
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"openai_api_key is not provided, !gpt and !chat command will not work"
|
||||
)
|
||||
|
||||
# initialize pandora
|
||||
self.pandora_api_endpoint = pandora_api_endpoint
|
||||
if pandora_api_endpoint is not None:
|
||||
self.pandora = Pandora(
|
||||
api_endpoint=pandora_api_endpoint, clientSession=self.session
|
||||
)
|
||||
if pandora_api_model is None:
|
||||
self.pandora_api_model = "text-davinci-002-render-sha-mobile"
|
||||
else:
|
||||
self.pandora_api_model = pandora_api_model
|
||||
self.pandora_data = {}
|
||||
|
||||
# initialize image generator
|
||||
self.bing_auth_cookie = bing_auth_cookie
|
||||
if bing_auth_cookie is not None:
|
||||
self.imagegen = ImageGenAsync(auth_cookie=self.bing_auth_cookie)
|
||||
else:
|
||||
logger.warning(
|
||||
"bing_auth_cookie is not provided, !pic command will not work"
|
||||
)
|
||||
|
||||
# regular expression to match keyword
|
||||
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
||||
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
||||
self.pic_prog = re.compile(r"^\s*!pic\s*(.+)$")
|
||||
self.help_prog = re.compile(r"^\s*!help\s*.*$")
|
||||
self.talk_prog = re.compile(r"^\s*!talk\s*(.+)$")
|
||||
self.goon_prog = re.compile(r"^\s*!goon\s*.*$")
|
||||
self.new_prog = re.compile(r"^\s*!new\s*.*$")
|
||||
|
||||
# close session
|
||||
async def close(self, task: asyncio.Task) -> None:
|
||||
await self.session.close()
|
||||
self.driver.disconnect()
|
||||
task.cancel()
|
||||
|
||||
async def login(self) -> None:
|
||||
await self.driver.login()
|
||||
|
||||
def pandora_init(self, user_id: str) -> None:
|
||||
self.pandora_data[user_id] = {
|
||||
"conversation_id": None,
|
||||
"parent_message_id": str(uuid.uuid4()),
|
||||
"first_time": True,
|
||||
}
|
||||
|
||||
async def run(self) -> None:
|
||||
await self.driver.init_websocket(self.websocket_handler)
|
||||
|
||||
# websocket handler
|
||||
async def websocket_handler(self, message) -> None:
|
||||
logger.info(message)
|
||||
response = json.loads(message)
|
||||
if "event" in response:
|
||||
event_type = response["event"]
|
||||
if event_type == "posted":
|
||||
raw_data = response["data"]["post"]
|
||||
raw_data_dict = json.loads(raw_data)
|
||||
user_id = raw_data_dict["user_id"]
|
||||
channel_id = raw_data_dict["channel_id"]
|
||||
sender_name = response["data"]["sender_name"]
|
||||
raw_message = raw_data_dict["message"]
|
||||
|
||||
if user_id not in self.pandora_data:
|
||||
self.pandora_init(user_id)
|
||||
|
||||
try:
|
||||
asyncio.create_task(
|
||||
self.message_callback(
|
||||
raw_message, channel_id, user_id, sender_name
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
await self.send_message(channel_id, f"{e}")
|
||||
|
||||
# message callback
|
||||
async def message_callback(
|
||||
self, raw_message: str, channel_id: str, user_id: str, sender_name: str
|
||||
) -> None:
|
||||
# prevent command trigger loop
|
||||
if sender_name != self.username:
|
||||
message = raw_message
|
||||
|
||||
if self.openai_api_key is not None:
|
||||
# !gpt command trigger handler
|
||||
if self.gpt_prog.match(message):
|
||||
prompt = self.gpt_prog.match(message).group(1)
|
||||
try:
|
||||
response = await self.gpt(prompt)
|
||||
await self.send_message(channel_id, f"{response}")
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
# !chat command trigger handler
|
||||
elif self.chat_prog.match(message):
|
||||
prompt = self.chat_prog.match(message).group(1)
|
||||
try:
|
||||
response = await self.chat(prompt)
|
||||
await self.send_message(channel_id, f"{response}")
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
if self.pandora_api_endpoint is not None:
|
||||
# !talk command trigger handler
|
||||
if self.talk_prog.match(message):
|
||||
prompt = self.talk_prog.match(message).group(1)
|
||||
try:
|
||||
if self.pandora_data[user_id]["conversation_id"] is not None:
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"model": self.pandora_api_model,
|
||||
"parent_message_id": self.pandora_data[user_id][
|
||||
"parent_message_id"
|
||||
],
|
||||
"conversation_id": self.pandora_data[user_id][
|
||||
"conversation_id"
|
||||
],
|
||||
"stream": False,
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"model": self.pandora_api_model,
|
||||
"parent_message_id": self.pandora_data[user_id][
|
||||
"parent_message_id"
|
||||
],
|
||||
"stream": False,
|
||||
}
|
||||
response = await self.pandora.talk(data)
|
||||
self.pandora_data[user_id]["conversation_id"] = response[
|
||||
"conversation_id"
|
||||
]
|
||||
self.pandora_data[user_id]["parent_message_id"] = response[
|
||||
"message"
|
||||
]["id"]
|
||||
content = response["message"]["content"]["parts"][0]
|
||||
if self.pandora_data[user_id]["first_time"]:
|
||||
self.pandora_data[user_id]["first_time"] = False
|
||||
data = {
|
||||
"model": self.pandora_api_model,
|
||||
"message_id": self.pandora_data[user_id][
|
||||
"parent_message_id"
|
||||
],
|
||||
}
|
||||
await self.pandora.gen_title(
|
||||
data, self.pandora_data[user_id]["conversation_id"]
|
||||
)
|
||||
|
||||
await self.send_message(channel_id, f"{content}")
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
# !goon command trigger handler
|
||||
if (
|
||||
self.goon_prog.match(message)
|
||||
and self.pandora_data[user_id]["conversation_id"] is not None
|
||||
):
|
||||
try:
|
||||
data = {
|
||||
"model": self.pandora_api_model,
|
||||
"parent_message_id": self.pandora_data[user_id][
|
||||
"parent_message_id"
|
||||
],
|
||||
"conversation_id": self.pandora_data[user_id][
|
||||
"conversation_id"
|
||||
],
|
||||
"stream": False,
|
||||
}
|
||||
response = await self.pandora.goon(data)
|
||||
self.pandora_data[user_id]["conversation_id"] = response[
|
||||
"conversation_id"
|
||||
]
|
||||
self.pandora_data[user_id]["parent_message_id"] = response[
|
||||
"message"
|
||||
]["id"]
|
||||
content = response["message"]["content"]["parts"][0]
|
||||
await self.send_message(channel_id, f"{content}")
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
# !new command trigger handler
|
||||
if self.new_prog.match(message):
|
||||
self.pandora_init(user_id)
|
||||
try:
|
||||
await self.send_message(
|
||||
channel_id,
|
||||
"New conversation created, " +
|
||||
"please use !talk to start chatting!",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.bing_auth_cookie is not None:
|
||||
# !pic command trigger handler
|
||||
if self.pic_prog.match(message):
|
||||
prompt = self.pic_prog.match(message).group(1)
|
||||
# generate image
|
||||
try:
|
||||
links = await self.imagegen.get_images(prompt)
|
||||
image_path = await self.imagegen.save_images(links, "images")
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
# send image
|
||||
try:
|
||||
await self.send_file(channel_id, prompt, image_path)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
# !help command trigger handler
|
||||
if self.help_prog.match(message):
|
||||
try:
|
||||
await self.send_message(channel_id, self.help())
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
# send message to room
|
||||
async def send_message(self, channel_id: str, message: str) -> None:
|
||||
await self.driver.posts.create_post(
|
||||
options={"channel_id": channel_id, "message": message}
|
||||
)
|
||||
|
||||
# send file to room
|
||||
async def send_file(self, channel_id: str, message: str, filepath: str) -> None:
|
||||
filename = os.path.split(filepath)[-1]
|
||||
try:
|
||||
file_id = await self.driver.files.upload_file(
|
||||
channel_id=channel_id,
|
||||
files={
|
||||
"files": (filename, open(filepath, "rb")),
|
||||
},
|
||||
)["file_infos"][0]["id"]
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
try:
|
||||
await self.driver.posts.create_post(
|
||||
options={
|
||||
"channel_id": channel_id,
|
||||
"message": message,
|
||||
"file_ids": [file_id],
|
||||
}
|
||||
)
|
||||
# remove image after posting
|
||||
os.remove(filepath)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
# !gpt command function
|
||||
async def gpt(self, prompt: str) -> str:
|
||||
return await self.askgpt.oneTimeAsk(prompt)
|
||||
|
||||
# !chat command function
|
||||
async def chat(self, prompt: str) -> str:
|
||||
return await self.gptchatbot.ask_async(prompt)
|
||||
|
||||
# !help command function
|
||||
def help(self) -> str:
|
||||
help_info = (
|
||||
"!gpt [content], generate response without context conversation\n"
|
||||
+ "!chat [content], chat with context conversation\n"
|
||||
+ "!pic [prompt], Image generation by Microsoft Bing\n"
|
||||
+ "!talk [content], talk using chatgpt web\n"
|
||||
+ "!goon, continue the incomplete conversation\n"
|
||||
+ "!new, start a new conversation\n"
|
||||
+ "!help, help message"
|
||||
)
|
||||
return help_info
|
30
src/log.py
Normal file
30
src/log.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import logging
|
||||
|
||||
|
||||
def getlogger():
|
||||
# create a custom logger if not already created
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logger.hasHandlers():
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# create handlers
|
||||
info_handler = logging.StreamHandler()
|
||||
error_handler = logging.FileHandler("bot.log", mode="a")
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
info_handler.setLevel(logging.INFO)
|
||||
|
||||
# create formatters
|
||||
error_format = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(funcName)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
info_format = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
# set formatter
|
||||
error_handler.setFormatter(error_format)
|
||||
info_handler.setFormatter(info_format)
|
||||
|
||||
# add handlers to logger
|
||||
logger.addHandler(error_handler)
|
||||
logger.addHandler(info_handler)
|
||||
|
||||
return logger
|
70
src/main.py
Normal file
70
src/main.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import signal
|
||||
from bot import Bot
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from log import getlogger
|
||||
|
||||
logger = getlogger()
|
||||
|
||||
|
||||
async def main():
|
||||
config_path = Path(os.path.dirname(__file__)).parent / "config.json"
|
||||
if os.path.isfile(config_path):
|
||||
fp = open("config.json", "r", encoding="utf-8")
|
||||
config = json.load(fp)
|
||||
|
||||
mattermost_bot = Bot(
|
||||
server_url=config.get("server_url"),
|
||||
access_token=config.get("access_token"),
|
||||
login_id=config.get("login_id"),
|
||||
password=config.get("password"),
|
||||
username=config.get("username"),
|
||||
openai_api_key=config.get("openai_api_key"),
|
||||
bing_auth_cookie=config.get("bing_auth_cookie"),
|
||||
pandora_api_endpoint=config.get("pandora_api_endpoint"),
|
||||
pandora_api_model=config.get("pandora_api_model"),
|
||||
port=config.get("port"),
|
||||
scheme=config.get("scheme"),
|
||||
timeout=config.get("timeout"),
|
||||
gpt_engine=config.get("gpt_engine"),
|
||||
)
|
||||
|
||||
else:
|
||||
mattermost_bot = Bot(
|
||||
server_url=os.environ.get("SERVER_URL"),
|
||||
access_token=os.environ.get("ACCESS_TOKEN"),
|
||||
login_id=os.environ.get("LOGIN_ID"),
|
||||
password=os.environ.get("PASSWORD"),
|
||||
username=os.environ.get("USERNAME"),
|
||||
openai_api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
bing_auth_cookie=os.environ.get("BING_AUTH_COOKIE"),
|
||||
pandora_api_endpoint=os.environ.get("PANDORA_API_ENDPOINT"),
|
||||
pandora_api_model=os.environ.get("PANDORA_API_MODEL"),
|
||||
port=os.environ.get("PORT"),
|
||||
scheme=os.environ.get("SCHEME"),
|
||||
timeout=os.environ.get("TIMEOUT"),
|
||||
gpt_engine=os.environ.get("GPT_ENGINE"),
|
||||
)
|
||||
|
||||
await mattermost_bot.login()
|
||||
|
||||
task = asyncio.create_task(mattermost_bot.run())
|
||||
|
||||
# handle signal interrupt
|
||||
loop = asyncio.get_running_loop()
|
||||
for signame in ("SIGINT", "SIGTERM"):
|
||||
loop.add_signal_handler(
|
||||
getattr(signal, signame),
|
||||
lambda: asyncio.create_task(mattermost_bot.close(task)),
|
||||
)
|
||||
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Bot stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
106
src/pandora.py
Normal file
106
src/pandora.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
# https://github.com/pengzhile/pandora/blob/master/doc/HTTP-API.md
|
||||
import uuid
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
|
||||
class Pandora:
|
||||
def __init__(self, api_endpoint: str, clientSession: aiohttp.ClientSession) -> None:
|
||||
self.api_endpoint = api_endpoint.rstrip("/")
|
||||
self.session = clientSession
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.session.close()
|
||||
|
||||
async def gen_title(self, data: dict, conversation_id: str) -> None:
|
||||
"""
|
||||
data = {
|
||||
"model": "",
|
||||
"message_id": "",
|
||||
}
|
||||
:param data: dict
|
||||
:param conversation_id: str
|
||||
:return: None
|
||||
"""
|
||||
api_endpoint = (
|
||||
self.api_endpoint + f"/api/conversation/gen_title/{conversation_id}"
|
||||
)
|
||||
async with self.session.post(api_endpoint, json=data) as resp:
|
||||
return await resp.json()
|
||||
|
||||
async def talk(self, data: dict) -> None:
|
||||
api_endpoint = self.api_endpoint + "/api/conversation/talk"
|
||||
"""
|
||||
data = {
|
||||
"prompt": "",
|
||||
"model": "",
|
||||
"parent_message_id": "",
|
||||
"conversation_id": "", # ignore at the first time
|
||||
"stream": True,
|
||||
}
|
||||
:param data: dict
|
||||
:return: None
|
||||
"""
|
||||
data["message_id"] = str(uuid.uuid4())
|
||||
async with self.session.post(api_endpoint, json=data) as resp:
|
||||
return await resp.json()
|
||||
|
||||
async def goon(self, data: dict) -> None:
|
||||
"""
|
||||
data = {
|
||||
"model": "",
|
||||
"parent_message_id": "",
|
||||
"conversation_id": "",
|
||||
"stream": True,
|
||||
}
|
||||
"""
|
||||
api_endpoint = self.api_endpoint + "/api/conversation/goon"
|
||||
async with self.session.post(api_endpoint, json=data) as resp:
|
||||
return await resp.json()
|
||||
|
||||
|
||||
async def test():
|
||||
model = "text-davinci-002-render-sha-mobile"
|
||||
api_endpoint = "http://127.0.0.1:8008"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
client = Pandora(api_endpoint, session)
|
||||
conversation_id = None
|
||||
parent_message_id = str(uuid.uuid4())
|
||||
first_time = True
|
||||
async with client:
|
||||
while True:
|
||||
prompt = input("BobMaster: ")
|
||||
if conversation_id:
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"parent_message_id": parent_message_id,
|
||||
"conversation_id": conversation_id,
|
||||
"stream": False,
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"parent_message_id": parent_message_id,
|
||||
"stream": False,
|
||||
}
|
||||
response = await client.talk(data)
|
||||
conversation_id = response["conversation_id"]
|
||||
parent_message_id = response["message"]["id"]
|
||||
content = response["message"]["content"]["parts"][0]
|
||||
print("ChatGPT: " + content + "\n")
|
||||
if first_time:
|
||||
first_time = False
|
||||
data = {
|
||||
"model": model,
|
||||
"message_id": parent_message_id,
|
||||
}
|
||||
response = await client.gen_title(data, conversation_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test())
|
Loading…
Add table
Add a link
Reference in a new issue