feat: refactor image genderation backend
This commit is contained in:
parent
06ccd8c61c
commit
0b20e0ac1a
4 changed files with 51 additions and 29 deletions
63
src/bot.py
63
src/bot.py
|
@ -4,9 +4,11 @@ import json
|
|||
import asyncio
|
||||
import re
|
||||
import os
|
||||
from pathlib import Path
|
||||
from gptbot import Chatbot
|
||||
from log import getlogger
|
||||
import httpx
|
||||
import imagegen
|
||||
|
||||
logger = getlogger()
|
||||
|
||||
|
@ -78,6 +80,11 @@ class Bot:
|
|||
self.image_generation_backend: str = image_generation_backend
|
||||
self.timeout = timeout or 120.0
|
||||
|
||||
self.base_path = Path(os.path.dirname(__file__)).parent
|
||||
|
||||
if not os.path.exists(self.base_path / "images"):
|
||||
os.mkdir(self.base_path / "images")
|
||||
|
||||
# httpx session
|
||||
self.httpx_client = httpx.AsyncClient()
|
||||
|
||||
|
@ -213,22 +220,38 @@ class Bot:
|
|||
raise Exception(e)
|
||||
|
||||
# !pic command trigger handler
|
||||
if self.pic_prog.match(message):
|
||||
prompt = self.pic_prog.match(message).group(1)
|
||||
# generate image
|
||||
try:
|
||||
links = await self.imagegen.get_images(prompt)
|
||||
image_path = await self.imagegen.save_images(links, "images")
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
# send image
|
||||
try:
|
||||
await self.send_file(channel_id, prompt, image_path)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
if self.image_generation_endpoint and self.image_generation_backend:
|
||||
if self.pic_prog.match(message):
|
||||
prompt = self.pic_prog.match(message).group(1)
|
||||
# generate image
|
||||
try:
|
||||
# generate image
|
||||
b64_datas = await imagegen.get_images(
|
||||
self.httpx_client,
|
||||
self.image_generation_endpoint,
|
||||
prompt,
|
||||
self.image_generation_backend,
|
||||
timeount=self.timeout,
|
||||
api_key=self.openai_api_key,
|
||||
n=1,
|
||||
size="256x256",
|
||||
)
|
||||
image_path_list = await asyncio.to_thread(
|
||||
imagegen.save_images,
|
||||
b64_datas,
|
||||
self.base_path / "images",
|
||||
)
|
||||
# send image
|
||||
for image_path in image_path_list:
|
||||
await self.send_file(
|
||||
channel_id,
|
||||
f"{prompt}",
|
||||
image_path,
|
||||
root_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
||||
# !help command trigger handler
|
||||
if self.help_prog.match(message):
|
||||
|
@ -248,7 +271,9 @@ class Bot:
|
|||
)
|
||||
|
||||
# send file to room
|
||||
async def send_file(self, channel_id: str, message: str, filepath: str) -> None:
|
||||
async def send_file(
|
||||
self, channel_id: str, message: str, filepath: str, root_id: str
|
||||
) -> None:
|
||||
filename = os.path.split(filepath)[-1]
|
||||
try:
|
||||
file_id = await self.driver.files.upload_file(
|
||||
|
@ -256,7 +281,8 @@ class Bot:
|
|||
files={
|
||||
"files": (filename, open(filepath, "rb")),
|
||||
},
|
||||
)["file_infos"][0]["id"]
|
||||
)
|
||||
file_id = file_id["file_infos"][0]["id"]
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
@ -267,6 +293,7 @@ class Bot:
|
|||
"channel_id": channel_id,
|
||||
"message": message,
|
||||
"file_ids": [file_id],
|
||||
"root_id": root_id,
|
||||
}
|
||||
)
|
||||
# remove image after posting
|
||||
|
|
|
@ -15,7 +15,7 @@ async def get_images(
|
|||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + kwargs.get("api_key"),
|
||||
"Authorization": f"Bearer {kwargs.get('api_key')}",
|
||||
},
|
||||
json={
|
||||
"prompt": prompt,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue