feat: refactor image genderation backend

This commit is contained in:
hibobmaster 2023-09-18 15:19:27 +08:00
commit 0b20e0ac1a
Signed by: bobmaster
SSH key fingerprint: SHA256:5ZYgd8fg+PcNZNy4SzcSKu5JtqZyBF8kUhY7/k2viDk
4 changed files with 51 additions and 29 deletions

View file

@ -4,9 +4,11 @@ import json
import asyncio
import re
import os
from pathlib import Path
from gptbot import Chatbot
from log import getlogger
import httpx
import imagegen
logger = getlogger()
@ -78,6 +80,11 @@ class Bot:
self.image_generation_backend: str = image_generation_backend
self.timeout = timeout or 120.0
self.base_path = Path(os.path.dirname(__file__)).parent
if not os.path.exists(self.base_path / "images"):
os.mkdir(self.base_path / "images")
# httpx session
self.httpx_client = httpx.AsyncClient()
@ -213,22 +220,38 @@ class Bot:
raise Exception(e)
# !pic command trigger handler
if self.pic_prog.match(message):
prompt = self.pic_prog.match(message).group(1)
# generate image
try:
links = await self.imagegen.get_images(prompt)
image_path = await self.imagegen.save_images(links, "images")
except Exception as e:
logger.error(e, exc_info=True)
raise Exception(e)
# send image
try:
await self.send_file(channel_id, prompt, image_path)
except Exception as e:
logger.error(e, exc_info=True)
raise Exception(e)
if self.image_generation_endpoint and self.image_generation_backend:
if self.pic_prog.match(message):
prompt = self.pic_prog.match(message).group(1)
# generate image
try:
# generate image
b64_datas = await imagegen.get_images(
self.httpx_client,
self.image_generation_endpoint,
prompt,
self.image_generation_backend,
timeount=self.timeout,
api_key=self.openai_api_key,
n=1,
size="256x256",
)
image_path_list = await asyncio.to_thread(
imagegen.save_images,
b64_datas,
self.base_path / "images",
)
# send image
for image_path in image_path_list:
await self.send_file(
channel_id,
f"{prompt}",
image_path,
root_id,
)
except Exception as e:
logger.error(e, exc_info=True)
raise Exception(e)
# !help command trigger handler
if self.help_prog.match(message):
@ -248,7 +271,9 @@ class Bot:
)
# send file to room
async def send_file(self, channel_id: str, message: str, filepath: str) -> None:
async def send_file(
self, channel_id: str, message: str, filepath: str, root_id: str
) -> None:
filename = os.path.split(filepath)[-1]
try:
file_id = await self.driver.files.upload_file(
@ -256,7 +281,8 @@ class Bot:
files={
"files": (filename, open(filepath, "rb")),
},
)["file_infos"][0]["id"]
)
file_id = file_id["file_infos"][0]["id"]
except Exception as e:
logger.error(e, exc_info=True)
raise Exception(e)
@ -267,6 +293,7 @@ class Bot:
"channel_id": channel_id,
"message": message,
"file_ids": [file_id],
"root_id": root_id,
}
)
# remove image after posting

View file

@ -15,7 +15,7 @@ async def get_images(
url,
headers={
"Content-Type": "application/json",
"Authorization": "Bearer " + kwargs.get("api_key"),
"Authorization": f"Bearer {kwargs.get('api_key')}",
},
json={
"prompt": prompt,