From bf95dc0f42a35f6eb58bff2d9730b6c4105cc4ce Mon Sep 17 00:00:00 2001
From: hibobmaster <32976627+hibobmaster@users.noreply.github.com>
Date: Sun, 17 Sep 2023 12:27:16 +0800
Subject: [PATCH] refactor: image generation
---
.full-env.example | 2 +
README.md | 9 +--
full-config.json.sample | 2 +
src/bot.py | 120 ++++++++++++++++++++++++++++++----------
src/imagegen.py | 69 +++++++++++++++++++++++
src/main.py | 4 ++
6 files changed, 172 insertions(+), 34 deletions(-)
create mode 100644 src/imagegen.py
diff --git a/.full-env.example b/.full-env.example
index d1c9f2c..a4565e6 100644
--- a/.full-env.example
+++ b/.full-env.example
@@ -17,4 +17,6 @@ SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respo
TEMPERATURE=0.8
FLOWISE_API_URL="http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21"
FLOWISE_API_KEY="U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A="
+IMAGE_GENERATION_ENDPOINT="http://localai:8080/v1/images/generations"
+IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui
TIMEOUT=120.0
diff --git a/README.md b/README.md
index d25e6fb..96bd2b8 100644
--- a/README.md
+++ b/README.md
@@ -5,10 +5,11 @@ This is a simple Matrix bot that support using OpenAI API, Langchain to generate
## Feature
-1. Support official openai api and self host models([LocalAI](https://github.com/go-skynet/LocalAI))
+1. Support official openai api and self host models([LocalAI](https://localai.io/model-compatibility/))
2. Support E2E Encrypted Room
3. Colorful code blocks
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
+5. Image Generation with [DALLĀ·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
## Installation and Setup
@@ -67,7 +68,7 @@ python src/main.py
## Usage
-To interact with the bot, simply send a message to the bot in the Matrix room with one of the two prompts:
+To interact with the bot, simply send a message to the bot in the Matrix room with one of the following prompts:
- `!help` help message
- `!gpt` To generate a one time response:
@@ -95,8 +96,8 @@ To interact with the bot, simply send a message to the bot in the Matrix room wi
## Image Generation
-
-
+![demo1](https://i.imgur.com/voeomsF.jpg)
+![demo2](https://i.imgur.com/BKZktWd.jpg)
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/
diff --git a/full-config.json.sample b/full-config.json.sample
index 6d62d4e..3d91f94 100644
--- a/full-config.json.sample
+++ b/full-config.json.sample
@@ -18,5 +18,7 @@
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
"flowise_api_url": "http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21",
"flowise_api_key": "U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=",
+ "image_generation_endpoint": "http://localai:8080/v1/images/generations",
+ "image_generation_backend": "openai",
"timeout": 120.0
}
diff --git a/src/bot.py b/src/bot.py
index e1e7e95..08acca9 100644
--- a/src/bot.py
+++ b/src/bot.py
@@ -5,6 +5,7 @@ import re
import sys
import traceback
from typing import Union, Optional
+import aiofiles.os
import httpx
@@ -33,6 +34,7 @@ from send_image import send_room_image
from send_message import send_room_message
from flowise import flowise_query
from gptbot import Chatbot
+import imagegen
logger = getlogger()
DEVICE_NAME = "MatrixChatGPTBot"
@@ -61,6 +63,8 @@ class Bot:
temperature: Union[float, None] = None,
flowise_api_url: Optional[str] = None,
flowise_api_key: Optional[str] = None,
+ image_generation_endpoint: Optional[str] = None,
+ image_generation_backend: Optional[str] = None,
timeout: Union[float, None] = None,
):
if homeserver is None or user_id is None or device_id is None:
@@ -71,6 +75,14 @@ class Bot:
logger.warning("password is required")
sys.exit(1)
+ if image_generation_endpoint and image_generation_backend not in [
+ "openai",
+ "sdwui",
+ None,
+ ]:
+ logger.warning("image_generation_backend must be openai or sdwui")
+ sys.exit(1)
+
self.homeserver: str = homeserver
self.user_id: str = user_id
self.password: str = password
@@ -98,11 +110,16 @@ class Bot:
self.import_keys_password: str = import_keys_password
self.flowise_api_url: str = flowise_api_url
self.flowise_api_key: str = flowise_api_key
+ self.image_generation_endpoint: str = image_generation_endpoint
+ self.image_generation_backend: str = image_generation_backend
self.timeout: float = timeout or 120.0
self.base_path = Path(os.path.dirname(__file__)).parent
+ if not os.path.exists(self.base_path / "images"):
+ os.mkdir(self.base_path / "images")
+
self.httpx_client = httpx.AsyncClient(
follow_redirects=True,
timeout=self.timeout,
@@ -270,6 +287,23 @@ class Bot:
except Exception as e:
logger.error(e, exc_info=True)
+ # !pic command
+ p = self.pic_prog.match(content_body)
+ if p:
+ prompt = p.group(1)
+ try:
+ asyncio.create_task(
+ self.pic(
+ room_id,
+ prompt,
+ reply_to_event_id,
+ sender_id,
+ raw_user_message,
+ )
+ )
+ except Exception as e:
+ logger.error(e, exc_info=True)
+
# help command
h = self.help_prog.match(content_body)
if h:
@@ -523,9 +557,7 @@ class Bot:
logger.info(estr)
# !chat command
- async def chat(
- self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
- ):
+ async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message):
try:
await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
content = await self.chatbot.ask_async(
@@ -538,16 +570,17 @@ class Bot:
reply_message=content,
reply_to_event_id=reply_to_event_id,
sender_id=sender_id,
- user_message=raw_user_message,
+ user_message=user_message,
)
- except Exception:
+ except Exception as e:
+ logger.error(e, exc_info=True)
await self.send_general_error_message(
- room_id, reply_to_event_id, sender_id, raw_user_message
+ room_id, reply_to_event_id, sender_id, user_message
)
# !gpt command
async def gpt(
- self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
+ self, room_id, reply_to_event_id, prompt, sender_id, user_message
) -> None:
try:
# sending typing state, seconds to milliseconds
@@ -562,17 +595,17 @@ class Bot:
reply_message=responseMessage.strip(),
reply_to_event_id=reply_to_event_id,
sender_id=sender_id,
- user_message=raw_user_message,
+ user_message=user_message,
)
except Exception as e:
+ logger.error(e, exc_info=True)
await self.send_general_error_message(
- room_id, reply_to_event_id, sender_id, raw_user_message
+ room_id, reply_to_event_id, sender_id, user_message
)
- logger.error(e)
# !lc command
async def lc(
- self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
+ self, room_id, reply_to_event_id, prompt, sender_id, user_message
) -> None:
try:
# sending typing state
@@ -592,11 +625,12 @@ class Bot:
reply_message=responseMessage.strip(),
reply_to_event_id=reply_to_event_id,
sender_id=sender_id,
- user_message=raw_user_message,
+ user_message=user_message,
)
- except Exception:
+ except Exception as e:
+ logger.error(e, exc_info=True)
await self.send_general_error_message(
- room_id, reply_to_event_id, sender_id, raw_user_message
+ room_id, reply_to_event_id, sender_id, user_message
)
# !new command
@@ -605,7 +639,7 @@ class Bot:
room_id,
reply_to_event_id,
sender_id,
- raw_user_message,
+ user_message,
new_command,
) -> None:
try:
@@ -623,32 +657,58 @@ class Bot:
reply_message=content,
reply_to_event_id=reply_to_event_id,
sender_id=sender_id,
- user_message=raw_user_message,
+ user_message=user_message,
)
- except Exception:
+ except Exception as e:
+ logger.error(e, exc_info=True)
await self.send_general_error_message(
- room_id, reply_to_event_id, sender_id, raw_user_message
+ room_id, reply_to_event_id, sender_id, user_message
)
# !pic command
- async def pic(self, room_id, prompt, replay_to_event_id):
+ async def pic(self, room_id, prompt, replay_to_event_id, sender_id, user_message):
try:
- await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
- # generate image
- links = await self.imageGen.get_images(prompt)
- image_path_list = await self.imageGen.save_images(
- links, self.base_path / "images", self.output_four_images
- )
- # send image
- for image_path in image_path_list:
- await send_room_image(self.client, room_id, image_path)
- await self.client.room_typing(room_id, typing_state=False)
+ if self.image_generation_endpoint is not None:
+ await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
+ # generate image
+ b64_datas = await imagegen.get_images(
+ self.httpx_client,
+ self.image_generation_endpoint,
+ prompt,
+ self.image_generation_backend,
+ timeount=self.timeout,
+ api_key=self.openai_api_key,
+ n=1,
+ size="256x256",
+ )
+ image_path_list = await asyncio.to_thread(
+ imagegen.save_images,
+ b64_datas,
+ self.base_path / "images",
+ )
+ # send image
+ for image_path in image_path_list:
+ await send_room_image(self.client, room_id, image_path)
+ await aiofiles.os.remove(image_path)
+ await self.client.room_typing(room_id, typing_state=False)
+ else:
+ await send_room_message(
+ self.client,
+ room_id,
+ reply_message="Image generation endpoint not provided",
+ reply_to_event_id=replay_to_event_id,
+ sender_id=sender_id,
+ user_message=user_message,
+ )
except Exception as e:
+ logger.error(e, exc_info=True)
await send_room_message(
self.client,
room_id,
- reply_message=str(e),
+ reply_message="Image generation failed",
reply_to_event_id=replay_to_event_id,
+ user_message=user_message,
+ sender_id=sender_id,
)
# !help command
diff --git a/src/imagegen.py b/src/imagegen.py
new file mode 100644
index 0000000..fb54f14
--- /dev/null
+++ b/src/imagegen.py
@@ -0,0 +1,69 @@
+import httpx
+from pathlib import Path
+import uuid
+import base64
+import io
+from PIL import Image
+
+
+async def get_images(
+ aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs
+) -> list[str]:
+ timeout = kwargs.get("timeout", 120.0)
+ if backend_type == "openai":
+ resp = await aclient.post(
+ url,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": "Bearer " + kwargs.get("api_key"),
+ },
+ json={
+ "prompt": prompt,
+ "n": kwargs.get("n", 1),
+ "size": kwargs.get("size", "256x256"),
+ "response_format": "b64_json",
+ },
+ timeout=timeout,
+ )
+ if resp.status_code == 200:
+ b64_datas = []
+ for data in resp.json()["data"]:
+ b64_datas.append(data["b64_json"])
+ return b64_datas
+ else:
+ raise Exception(
+ f"{resp.status_code} {resp.reason_phrase} {resp.text}",
+ )
+ elif backend_type == "sdwui":
+ resp = await aclient.post(
+ url,
+ headers={
+ "Content-Type": "application/json",
+ },
+ json={
+ "prompt": prompt,
+ "sampler_name": kwargs.get("sampler_name", "Euler a"),
+ "batch_size": kwargs.get("n", 1),
+ "steps": kwargs.get("steps", 20),
+ "width": 256 if "256" in kwargs.get("size") else 512,
+ "height": 256 if "256" in kwargs.get("size") else 512,
+ },
+ timeout=timeout,
+ )
+ if resp.status_code == 200:
+ b64_datas = resp.json()["images"]
+ return b64_datas
+ else:
+ raise Exception(
+ f"{resp.status_code} {resp.reason_phrase} {resp.text}",
+ )
+
+
+def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
+ images = []
+ for b64_data in b64_datas:
+ image_path = path / (str(uuid.uuid4()) + ".jpeg")
+ img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
+ img.save(image_path)
+ images.append(image_path)
+ return images
diff --git a/src/main.py b/src/main.py
index fef7d57..5f835c1 100644
--- a/src/main.py
+++ b/src/main.py
@@ -42,6 +42,8 @@ async def main():
temperature=float(config.get("temperature")),
flowise_api_url=config.get("flowise_api_url"),
flowise_api_key=config.get("flowise_api_key"),
+ image_generation_endpoint=config.get("image_generation_endpoint"),
+ image_generation_backend=config.get("image_generation_backend"),
timeout=float(config.get("timeout")),
)
if (
@@ -71,6 +73,8 @@ async def main():
temperature=float(os.environ.get("TEMPERATURE")),
flowise_api_url=os.environ.get("FLOWISE_API_URL"),
flowise_api_key=os.environ.get("FLOWISE_API_KEY"),
+ image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
+ image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
timeout=float(os.environ.get("TIMEOUT")),
)
if (