From bf95dc0f42a35f6eb58bff2d9730b6c4105cc4ce Mon Sep 17 00:00:00 2001 From: hibobmaster <32976627+hibobmaster@users.noreply.github.com> Date: Sun, 17 Sep 2023 12:27:16 +0800 Subject: [PATCH] refactor: image generation --- .full-env.example | 2 + README.md | 9 +-- full-config.json.sample | 2 + src/bot.py | 120 ++++++++++++++++++++++++++++++---------- src/imagegen.py | 69 +++++++++++++++++++++++ src/main.py | 4 ++ 6 files changed, 172 insertions(+), 34 deletions(-) create mode 100644 src/imagegen.py diff --git a/.full-env.example b/.full-env.example index d1c9f2c..a4565e6 100644 --- a/.full-env.example +++ b/.full-env.example @@ -17,4 +17,6 @@ SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respo TEMPERATURE=0.8 FLOWISE_API_URL="http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21" FLOWISE_API_KEY="U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=" +IMAGE_GENERATION_ENDPOINT="http://localai:8080/v1/images/generations" +IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui TIMEOUT=120.0 diff --git a/README.md b/README.md index d25e6fb..96bd2b8 100644 --- a/README.md +++ b/README.md @@ -5,10 +5,11 @@ This is a simple Matrix bot that support using OpenAI API, Langchain to generate ## Feature -1. Support official openai api and self host models([LocalAI](https://github.com/go-skynet/LocalAI)) +1. Support official openai api and self host models([LocalAI](https://localai.io/model-compatibility/)) 2. Support E2E Encrypted Room 3. Colorful code blocks 4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise)) +5. Image Generation with [DALLĀ·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API) ## Installation and Setup @@ -67,7 +68,7 @@ python src/main.py ## Usage -To interact with the bot, simply send a message to the bot in the Matrix room with one of the two prompts:
+To interact with the bot, simply send a message to the bot in the Matrix room with one of the following prompts:
- `!help` help message - `!gpt` To generate a one time response: @@ -95,8 +96,8 @@ To interact with the bot, simply send a message to the bot in the Matrix room wi ## Image Generation - - +![demo1](https://i.imgur.com/voeomsF.jpg) +![demo2](https://i.imgur.com/BKZktWd.jpg) https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/
diff --git a/full-config.json.sample b/full-config.json.sample index 6d62d4e..3d91f94 100644 --- a/full-config.json.sample +++ b/full-config.json.sample @@ -18,5 +18,7 @@ "system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally", "flowise_api_url": "http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21", "flowise_api_key": "U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=", + "image_generation_endpoint": "http://localai:8080/v1/images/generations", + "image_generation_backend": "openai", "timeout": 120.0 } diff --git a/src/bot.py b/src/bot.py index e1e7e95..08acca9 100644 --- a/src/bot.py +++ b/src/bot.py @@ -5,6 +5,7 @@ import re import sys import traceback from typing import Union, Optional +import aiofiles.os import httpx @@ -33,6 +34,7 @@ from send_image import send_room_image from send_message import send_room_message from flowise import flowise_query from gptbot import Chatbot +import imagegen logger = getlogger() DEVICE_NAME = "MatrixChatGPTBot" @@ -61,6 +63,8 @@ class Bot: temperature: Union[float, None] = None, flowise_api_url: Optional[str] = None, flowise_api_key: Optional[str] = None, + image_generation_endpoint: Optional[str] = None, + image_generation_backend: Optional[str] = None, timeout: Union[float, None] = None, ): if homeserver is None or user_id is None or device_id is None: @@ -71,6 +75,14 @@ class Bot: logger.warning("password is required") sys.exit(1) + if image_generation_endpoint and image_generation_backend not in [ + "openai", + "sdwui", + None, + ]: + logger.warning("image_generation_backend must be openai or sdwui") + sys.exit(1) + self.homeserver: str = homeserver self.user_id: str = user_id self.password: str = password @@ -98,11 +110,16 @@ class Bot: self.import_keys_password: str = import_keys_password self.flowise_api_url: str = flowise_api_url self.flowise_api_key: str = flowise_api_key + self.image_generation_endpoint: str = image_generation_endpoint + self.image_generation_backend: str = image_generation_backend self.timeout: float = timeout or 120.0 self.base_path = Path(os.path.dirname(__file__)).parent + if not os.path.exists(self.base_path / "images"): + os.mkdir(self.base_path / "images") + self.httpx_client = httpx.AsyncClient( follow_redirects=True, timeout=self.timeout, @@ -270,6 +287,23 @@ class Bot: except Exception as e: logger.error(e, exc_info=True) + # !pic command + p = self.pic_prog.match(content_body) + if p: + prompt = p.group(1) + try: + asyncio.create_task( + self.pic( + room_id, + prompt, + reply_to_event_id, + sender_id, + raw_user_message, + ) + ) + except Exception as e: + logger.error(e, exc_info=True) + # help command h = self.help_prog.match(content_body) if h: @@ -523,9 +557,7 @@ class Bot: logger.info(estr) # !chat command - async def chat( - self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message - ): + async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message): try: await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000) content = await self.chatbot.ask_async( @@ -538,16 +570,17 @@ class Bot: reply_message=content, reply_to_event_id=reply_to_event_id, sender_id=sender_id, - user_message=raw_user_message, + user_message=user_message, ) - except Exception: + except Exception as e: + logger.error(e, exc_info=True) await self.send_general_error_message( - room_id, reply_to_event_id, sender_id, raw_user_message + room_id, reply_to_event_id, sender_id, user_message ) # !gpt command async def gpt( - self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message + self, room_id, reply_to_event_id, prompt, sender_id, user_message ) -> None: try: # sending typing state, seconds to milliseconds @@ -562,17 +595,17 @@ class Bot: reply_message=responseMessage.strip(), reply_to_event_id=reply_to_event_id, sender_id=sender_id, - user_message=raw_user_message, + user_message=user_message, ) except Exception as e: + logger.error(e, exc_info=True) await self.send_general_error_message( - room_id, reply_to_event_id, sender_id, raw_user_message + room_id, reply_to_event_id, sender_id, user_message ) - logger.error(e) # !lc command async def lc( - self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message + self, room_id, reply_to_event_id, prompt, sender_id, user_message ) -> None: try: # sending typing state @@ -592,11 +625,12 @@ class Bot: reply_message=responseMessage.strip(), reply_to_event_id=reply_to_event_id, sender_id=sender_id, - user_message=raw_user_message, + user_message=user_message, ) - except Exception: + except Exception as e: + logger.error(e, exc_info=True) await self.send_general_error_message( - room_id, reply_to_event_id, sender_id, raw_user_message + room_id, reply_to_event_id, sender_id, user_message ) # !new command @@ -605,7 +639,7 @@ class Bot: room_id, reply_to_event_id, sender_id, - raw_user_message, + user_message, new_command, ) -> None: try: @@ -623,32 +657,58 @@ class Bot: reply_message=content, reply_to_event_id=reply_to_event_id, sender_id=sender_id, - user_message=raw_user_message, + user_message=user_message, ) - except Exception: + except Exception as e: + logger.error(e, exc_info=True) await self.send_general_error_message( - room_id, reply_to_event_id, sender_id, raw_user_message + room_id, reply_to_event_id, sender_id, user_message ) # !pic command - async def pic(self, room_id, prompt, replay_to_event_id): + async def pic(self, room_id, prompt, replay_to_event_id, sender_id, user_message): try: - await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000) - # generate image - links = await self.imageGen.get_images(prompt) - image_path_list = await self.imageGen.save_images( - links, self.base_path / "images", self.output_four_images - ) - # send image - for image_path in image_path_list: - await send_room_image(self.client, room_id, image_path) - await self.client.room_typing(room_id, typing_state=False) + if self.image_generation_endpoint is not None: + await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000) + # generate image + b64_datas = await imagegen.get_images( + self.httpx_client, + self.image_generation_endpoint, + prompt, + self.image_generation_backend, + timeount=self.timeout, + api_key=self.openai_api_key, + n=1, + size="256x256", + ) + image_path_list = await asyncio.to_thread( + imagegen.save_images, + b64_datas, + self.base_path / "images", + ) + # send image + for image_path in image_path_list: + await send_room_image(self.client, room_id, image_path) + await aiofiles.os.remove(image_path) + await self.client.room_typing(room_id, typing_state=False) + else: + await send_room_message( + self.client, + room_id, + reply_message="Image generation endpoint not provided", + reply_to_event_id=replay_to_event_id, + sender_id=sender_id, + user_message=user_message, + ) except Exception as e: + logger.error(e, exc_info=True) await send_room_message( self.client, room_id, - reply_message=str(e), + reply_message="Image generation failed", reply_to_event_id=replay_to_event_id, + user_message=user_message, + sender_id=sender_id, ) # !help command diff --git a/src/imagegen.py b/src/imagegen.py new file mode 100644 index 0000000..fb54f14 --- /dev/null +++ b/src/imagegen.py @@ -0,0 +1,69 @@ +import httpx +from pathlib import Path +import uuid +import base64 +import io +from PIL import Image + + +async def get_images( + aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs +) -> list[str]: + timeout = kwargs.get("timeout", 120.0) + if backend_type == "openai": + resp = await aclient.post( + url, + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer " + kwargs.get("api_key"), + }, + json={ + "prompt": prompt, + "n": kwargs.get("n", 1), + "size": kwargs.get("size", "256x256"), + "response_format": "b64_json", + }, + timeout=timeout, + ) + if resp.status_code == 200: + b64_datas = [] + for data in resp.json()["data"]: + b64_datas.append(data["b64_json"]) + return b64_datas + else: + raise Exception( + f"{resp.status_code} {resp.reason_phrase} {resp.text}", + ) + elif backend_type == "sdwui": + resp = await aclient.post( + url, + headers={ + "Content-Type": "application/json", + }, + json={ + "prompt": prompt, + "sampler_name": kwargs.get("sampler_name", "Euler a"), + "batch_size": kwargs.get("n", 1), + "steps": kwargs.get("steps", 20), + "width": 256 if "256" in kwargs.get("size") else 512, + "height": 256 if "256" in kwargs.get("size") else 512, + }, + timeout=timeout, + ) + if resp.status_code == 200: + b64_datas = resp.json()["images"] + return b64_datas + else: + raise Exception( + f"{resp.status_code} {resp.reason_phrase} {resp.text}", + ) + + +def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[str]: + images = [] + for b64_data in b64_datas: + image_path = path / (str(uuid.uuid4()) + ".jpeg") + img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8")))) + img.save(image_path) + images.append(image_path) + return images diff --git a/src/main.py b/src/main.py index fef7d57..5f835c1 100644 --- a/src/main.py +++ b/src/main.py @@ -42,6 +42,8 @@ async def main(): temperature=float(config.get("temperature")), flowise_api_url=config.get("flowise_api_url"), flowise_api_key=config.get("flowise_api_key"), + image_generation_endpoint=config.get("image_generation_endpoint"), + image_generation_backend=config.get("image_generation_backend"), timeout=float(config.get("timeout")), ) if ( @@ -71,6 +73,8 @@ async def main(): temperature=float(os.environ.get("TEMPERATURE")), flowise_api_url=os.environ.get("FLOWISE_API_URL"), flowise_api_key=os.environ.get("FLOWISE_API_KEY"), + image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"), + image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"), timeout=float(os.environ.get("TIMEOUT")), ) if (