refactor: image generation

This commit is contained in:
hibobmaster 2023-09-17 12:27:16 +08:00
parent 0197e8b3d2
commit bf95dc0f42
Signed by: bobmaster
SSH key fingerprint: SHA256:5ZYgd8fg+PcNZNy4SzcSKu5JtqZyBF8kUhY7/k2viDk
6 changed files with 172 additions and 34 deletions

View file

@ -17,4 +17,6 @@ SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respo
TEMPERATURE=0.8 TEMPERATURE=0.8
FLOWISE_API_URL="http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21" FLOWISE_API_URL="http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21"
FLOWISE_API_KEY="U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=" FLOWISE_API_KEY="U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A="
IMAGE_GENERATION_ENDPOINT="http://localai:8080/v1/images/generations"
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui
TIMEOUT=120.0 TIMEOUT=120.0

View file

@ -5,10 +5,11 @@ This is a simple Matrix bot that support using OpenAI API, Langchain to generate
## Feature ## Feature
1. Support official openai api and self host models([LocalAI](https://github.com/go-skynet/LocalAI)) 1. Support official openai api and self host models([LocalAI](https://localai.io/model-compatibility/))
2. Support E2E Encrypted Room 2. Support E2E Encrypted Room
3. Colorful code blocks 3. Colorful code blocks
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise)) 4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
5. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
## Installation and Setup ## Installation and Setup
@ -67,7 +68,7 @@ python src/main.py
## Usage ## Usage
To interact with the bot, simply send a message to the bot in the Matrix room with one of the two prompts:<br> To interact with the bot, simply send a message to the bot in the Matrix room with one of the following prompts:<br>
- `!help` help message - `!help` help message
- `!gpt` To generate a one time response: - `!gpt` To generate a one time response:
@ -95,8 +96,8 @@ To interact with the bot, simply send a message to the bot in the Matrix room wi
## Image Generation ## Image Generation
![demo1](https://i.imgur.com/voeomsF.jpg)
![demo2](https://i.imgur.com/BKZktWd.jpg)
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/ <br> https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/ <br>

View file

@ -18,5 +18,7 @@
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally", "system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
"flowise_api_url": "http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21", "flowise_api_url": "http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21",
"flowise_api_key": "U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=", "flowise_api_key": "U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=",
"image_generation_endpoint": "http://localai:8080/v1/images/generations",
"image_generation_backend": "openai",
"timeout": 120.0 "timeout": 120.0
} }

View file

@ -5,6 +5,7 @@ import re
import sys import sys
import traceback import traceback
from typing import Union, Optional from typing import Union, Optional
import aiofiles.os
import httpx import httpx
@ -33,6 +34,7 @@ from send_image import send_room_image
from send_message import send_room_message from send_message import send_room_message
from flowise import flowise_query from flowise import flowise_query
from gptbot import Chatbot from gptbot import Chatbot
import imagegen
logger = getlogger() logger = getlogger()
DEVICE_NAME = "MatrixChatGPTBot" DEVICE_NAME = "MatrixChatGPTBot"
@ -61,6 +63,8 @@ class Bot:
temperature: Union[float, None] = None, temperature: Union[float, None] = None,
flowise_api_url: Optional[str] = None, flowise_api_url: Optional[str] = None,
flowise_api_key: Optional[str] = None, flowise_api_key: Optional[str] = None,
image_generation_endpoint: Optional[str] = None,
image_generation_backend: Optional[str] = None,
timeout: Union[float, None] = None, timeout: Union[float, None] = None,
): ):
if homeserver is None or user_id is None or device_id is None: if homeserver is None or user_id is None or device_id is None:
@ -71,6 +75,14 @@ class Bot:
logger.warning("password is required") logger.warning("password is required")
sys.exit(1) sys.exit(1)
if image_generation_endpoint and image_generation_backend not in [
"openai",
"sdwui",
None,
]:
logger.warning("image_generation_backend must be openai or sdwui")
sys.exit(1)
self.homeserver: str = homeserver self.homeserver: str = homeserver
self.user_id: str = user_id self.user_id: str = user_id
self.password: str = password self.password: str = password
@ -98,11 +110,16 @@ class Bot:
self.import_keys_password: str = import_keys_password self.import_keys_password: str = import_keys_password
self.flowise_api_url: str = flowise_api_url self.flowise_api_url: str = flowise_api_url
self.flowise_api_key: str = flowise_api_key self.flowise_api_key: str = flowise_api_key
self.image_generation_endpoint: str = image_generation_endpoint
self.image_generation_backend: str = image_generation_backend
self.timeout: float = timeout or 120.0 self.timeout: float = timeout or 120.0
self.base_path = Path(os.path.dirname(__file__)).parent self.base_path = Path(os.path.dirname(__file__)).parent
if not os.path.exists(self.base_path / "images"):
os.mkdir(self.base_path / "images")
self.httpx_client = httpx.AsyncClient( self.httpx_client = httpx.AsyncClient(
follow_redirects=True, follow_redirects=True,
timeout=self.timeout, timeout=self.timeout,
@ -270,6 +287,23 @@ class Bot:
except Exception as e: except Exception as e:
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
# !pic command
p = self.pic_prog.match(content_body)
if p:
prompt = p.group(1)
try:
asyncio.create_task(
self.pic(
room_id,
prompt,
reply_to_event_id,
sender_id,
raw_user_message,
)
)
except Exception as e:
logger.error(e, exc_info=True)
# help command # help command
h = self.help_prog.match(content_body) h = self.help_prog.match(content_body)
if h: if h:
@ -523,9 +557,7 @@ class Bot:
logger.info(estr) logger.info(estr)
# !chat command # !chat command
async def chat( async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message):
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
):
try: try:
await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000) await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
content = await self.chatbot.ask_async( content = await self.chatbot.ask_async(
@ -538,16 +570,17 @@ class Bot:
reply_message=content, reply_message=content,
reply_to_event_id=reply_to_event_id, reply_to_event_id=reply_to_event_id,
sender_id=sender_id, sender_id=sender_id,
user_message=raw_user_message, user_message=user_message,
) )
except Exception: except Exception as e:
logger.error(e, exc_info=True)
await self.send_general_error_message( await self.send_general_error_message(
room_id, reply_to_event_id, sender_id, raw_user_message room_id, reply_to_event_id, sender_id, user_message
) )
# !gpt command # !gpt command
async def gpt( async def gpt(
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message self, room_id, reply_to_event_id, prompt, sender_id, user_message
) -> None: ) -> None:
try: try:
# sending typing state, seconds to milliseconds # sending typing state, seconds to milliseconds
@ -562,17 +595,17 @@ class Bot:
reply_message=responseMessage.strip(), reply_message=responseMessage.strip(),
reply_to_event_id=reply_to_event_id, reply_to_event_id=reply_to_event_id,
sender_id=sender_id, sender_id=sender_id,
user_message=raw_user_message, user_message=user_message,
) )
except Exception as e: except Exception as e:
logger.error(e, exc_info=True)
await self.send_general_error_message( await self.send_general_error_message(
room_id, reply_to_event_id, sender_id, raw_user_message room_id, reply_to_event_id, sender_id, user_message
) )
logger.error(e)
# !lc command # !lc command
async def lc( async def lc(
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message self, room_id, reply_to_event_id, prompt, sender_id, user_message
) -> None: ) -> None:
try: try:
# sending typing state # sending typing state
@ -592,11 +625,12 @@ class Bot:
reply_message=responseMessage.strip(), reply_message=responseMessage.strip(),
reply_to_event_id=reply_to_event_id, reply_to_event_id=reply_to_event_id,
sender_id=sender_id, sender_id=sender_id,
user_message=raw_user_message, user_message=user_message,
) )
except Exception: except Exception as e:
logger.error(e, exc_info=True)
await self.send_general_error_message( await self.send_general_error_message(
room_id, reply_to_event_id, sender_id, raw_user_message room_id, reply_to_event_id, sender_id, user_message
) )
# !new command # !new command
@ -605,7 +639,7 @@ class Bot:
room_id, room_id,
reply_to_event_id, reply_to_event_id,
sender_id, sender_id,
raw_user_message, user_message,
new_command, new_command,
) -> None: ) -> None:
try: try:
@ -623,32 +657,58 @@ class Bot:
reply_message=content, reply_message=content,
reply_to_event_id=reply_to_event_id, reply_to_event_id=reply_to_event_id,
sender_id=sender_id, sender_id=sender_id,
user_message=raw_user_message, user_message=user_message,
) )
except Exception: except Exception as e:
logger.error(e, exc_info=True)
await self.send_general_error_message( await self.send_general_error_message(
room_id, reply_to_event_id, sender_id, raw_user_message room_id, reply_to_event_id, sender_id, user_message
) )
# !pic command # !pic command
async def pic(self, room_id, prompt, replay_to_event_id): async def pic(self, room_id, prompt, replay_to_event_id, sender_id, user_message):
try: try:
if self.image_generation_endpoint is not None:
await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000) await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
# generate image # generate image
links = await self.imageGen.get_images(prompt) b64_datas = await imagegen.get_images(
image_path_list = await self.imageGen.save_images( self.httpx_client,
links, self.base_path / "images", self.output_four_images self.image_generation_endpoint,
prompt,
self.image_generation_backend,
timeount=self.timeout,
api_key=self.openai_api_key,
n=1,
size="256x256",
)
image_path_list = await asyncio.to_thread(
imagegen.save_images,
b64_datas,
self.base_path / "images",
) )
# send image # send image
for image_path in image_path_list: for image_path in image_path_list:
await send_room_image(self.client, room_id, image_path) await send_room_image(self.client, room_id, image_path)
await aiofiles.os.remove(image_path)
await self.client.room_typing(room_id, typing_state=False) await self.client.room_typing(room_id, typing_state=False)
except Exception as e: else:
await send_room_message( await send_room_message(
self.client, self.client,
room_id, room_id,
reply_message=str(e), reply_message="Image generation endpoint not provided",
reply_to_event_id=replay_to_event_id, reply_to_event_id=replay_to_event_id,
sender_id=sender_id,
user_message=user_message,
)
except Exception as e:
logger.error(e, exc_info=True)
await send_room_message(
self.client,
room_id,
reply_message="Image generation failed",
reply_to_event_id=replay_to_event_id,
user_message=user_message,
sender_id=sender_id,
) )
# !help command # !help command

69
src/imagegen.py Normal file
View file

@ -0,0 +1,69 @@
import httpx
from pathlib import Path
import uuid
import base64
import io
from PIL import Image
async def get_images(
aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs
) -> list[str]:
timeout = kwargs.get("timeout", 120.0)
if backend_type == "openai":
resp = await aclient.post(
url,
headers={
"Content-Type": "application/json",
"Authorization": "Bearer " + kwargs.get("api_key"),
},
json={
"prompt": prompt,
"n": kwargs.get("n", 1),
"size": kwargs.get("size", "256x256"),
"response_format": "b64_json",
},
timeout=timeout,
)
if resp.status_code == 200:
b64_datas = []
for data in resp.json()["data"]:
b64_datas.append(data["b64_json"])
return b64_datas
else:
raise Exception(
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
)
elif backend_type == "sdwui":
resp = await aclient.post(
url,
headers={
"Content-Type": "application/json",
},
json={
"prompt": prompt,
"sampler_name": kwargs.get("sampler_name", "Euler a"),
"batch_size": kwargs.get("n", 1),
"steps": kwargs.get("steps", 20),
"width": 256 if "256" in kwargs.get("size") else 512,
"height": 256 if "256" in kwargs.get("size") else 512,
},
timeout=timeout,
)
if resp.status_code == 200:
b64_datas = resp.json()["images"]
return b64_datas
else:
raise Exception(
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
)
def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
images = []
for b64_data in b64_datas:
image_path = path / (str(uuid.uuid4()) + ".jpeg")
img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
img.save(image_path)
images.append(image_path)
return images

View file

@ -42,6 +42,8 @@ async def main():
temperature=float(config.get("temperature")), temperature=float(config.get("temperature")),
flowise_api_url=config.get("flowise_api_url"), flowise_api_url=config.get("flowise_api_url"),
flowise_api_key=config.get("flowise_api_key"), flowise_api_key=config.get("flowise_api_key"),
image_generation_endpoint=config.get("image_generation_endpoint"),
image_generation_backend=config.get("image_generation_backend"),
timeout=float(config.get("timeout")), timeout=float(config.get("timeout")),
) )
if ( if (
@ -71,6 +73,8 @@ async def main():
temperature=float(os.environ.get("TEMPERATURE")), temperature=float(os.environ.get("TEMPERATURE")),
flowise_api_url=os.environ.get("FLOWISE_API_URL"), flowise_api_url=os.environ.get("FLOWISE_API_URL"),
flowise_api_key=os.environ.get("FLOWISE_API_KEY"), flowise_api_key=os.environ.get("FLOWISE_API_KEY"),
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
timeout=float(os.environ.get("TIMEOUT")), timeout=float(os.environ.get("TIMEOUT")),
) )
if ( if (