refactor: image generation
This commit is contained in:
parent
0197e8b3d2
commit
bf95dc0f42
6 changed files with 172 additions and 34 deletions
|
@ -17,4 +17,6 @@ SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respo
|
|||
TEMPERATURE=0.8
|
||||
FLOWISE_API_URL="http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21"
|
||||
FLOWISE_API_KEY="U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A="
|
||||
IMAGE_GENERATION_ENDPOINT="http://localai:8080/v1/images/generations"
|
||||
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui
|
||||
TIMEOUT=120.0
|
||||
|
|
|
@ -5,10 +5,11 @@ This is a simple Matrix bot that supports using OpenAI API, Langchain to generate
|
|||
|
||||
## Feature
|
||||
|
||||
1. Support official openai api and self host models([LocalAI](https://github.com/go-skynet/LocalAI))
|
||||
1. Support official openai api and self host models([LocalAI](https://localai.io/model-compatibility/))
|
||||
2. Support E2E Encrypted Room
|
||||
3. Colorful code blocks
|
||||
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
|
||||
5. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
|
||||
|
||||
|
||||
## Installation and Setup
|
||||
|
@ -67,7 +68,7 @@ python src/main.py
|
|||
|
||||
## Usage
|
||||
|
||||
To interact with the bot, simply send a message to the bot in the Matrix room with one of the two prompts:<br>
|
||||
To interact with the bot, simply send a message to the bot in the Matrix room with one of the following prompts:<br>
|
||||
- `!help` help message
|
||||
|
||||
- `!gpt` To generate a one time response:
|
||||
|
@ -95,8 +96,8 @@ To interact with the bot, simply send a message to the bot in the Matrix room wi
|
|||
|
||||
|
||||
## Image Generation
|
||||
|
||||
|
||||
![demo1](https://i.imgur.com/voeomsF.jpg)
|
||||
![demo2](https://i.imgur.com/BKZktWd.jpg)
|
||||
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/ <br>
|
||||
|
||||
|
||||
|
|
|
@ -18,5 +18,7 @@
|
|||
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||
"flowise_api_url": "http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21",
|
||||
"flowise_api_key": "U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=",
|
||||
"image_generation_endpoint": "http://localai:8080/v1/images/generations",
|
||||
"image_generation_backend": "openai",
|
||||
"timeout": 120.0
|
||||
}
|
||||
|
|
120
src/bot.py
120
src/bot.py
|
@ -5,6 +5,7 @@ import re
|
|||
import sys
|
||||
import traceback
|
||||
from typing import Union, Optional
|
||||
import aiofiles.os
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -33,6 +34,7 @@ from send_image import send_room_image
|
|||
from send_message import send_room_message
|
||||
from flowise import flowise_query
|
||||
from gptbot import Chatbot
|
||||
import imagegen
|
||||
|
||||
logger = getlogger()
|
||||
DEVICE_NAME = "MatrixChatGPTBot"
|
||||
|
@ -61,6 +63,8 @@ class Bot:
|
|||
temperature: Union[float, None] = None,
|
||||
flowise_api_url: Optional[str] = None,
|
||||
flowise_api_key: Optional[str] = None,
|
||||
image_generation_endpoint: Optional[str] = None,
|
||||
image_generation_backend: Optional[str] = None,
|
||||
timeout: Union[float, None] = None,
|
||||
):
|
||||
if homeserver is None or user_id is None or device_id is None:
|
||||
|
@ -71,6 +75,14 @@ class Bot:
|
|||
logger.warning("password is required")
|
||||
sys.exit(1)
|
||||
|
||||
if image_generation_endpoint and image_generation_backend not in [
|
||||
"openai",
|
||||
"sdwui",
|
||||
None,
|
||||
]:
|
||||
logger.warning("image_generation_backend must be openai or sdwui")
|
||||
sys.exit(1)
|
||||
|
||||
self.homeserver: str = homeserver
|
||||
self.user_id: str = user_id
|
||||
self.password: str = password
|
||||
|
@ -98,11 +110,16 @@ class Bot:
|
|||
self.import_keys_password: str = import_keys_password
|
||||
self.flowise_api_url: str = flowise_api_url
|
||||
self.flowise_api_key: str = flowise_api_key
|
||||
self.image_generation_endpoint: str = image_generation_endpoint
|
||||
self.image_generation_backend: str = image_generation_backend
|
||||
|
||||
self.timeout: float = timeout or 120.0
|
||||
|
||||
self.base_path = Path(os.path.dirname(__file__)).parent
|
||||
|
||||
if not os.path.exists(self.base_path / "images"):
|
||||
os.mkdir(self.base_path / "images")
|
||||
|
||||
self.httpx_client = httpx.AsyncClient(
|
||||
follow_redirects=True,
|
||||
timeout=self.timeout,
|
||||
|
@ -270,6 +287,23 @@ class Bot:
|
|||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
# !pic command
|
||||
p = self.pic_prog.match(content_body)
|
||||
if p:
|
||||
prompt = p.group(1)
|
||||
try:
|
||||
asyncio.create_task(
|
||||
self.pic(
|
||||
room_id,
|
||||
prompt,
|
||||
reply_to_event_id,
|
||||
sender_id,
|
||||
raw_user_message,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
# help command
|
||||
h = self.help_prog.match(content_body)
|
||||
if h:
|
||||
|
@ -523,9 +557,7 @@ class Bot:
|
|||
logger.info(estr)
|
||||
|
||||
# !chat command
|
||||
async def chat(
|
||||
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
|
||||
):
|
||||
async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message):
|
||||
try:
|
||||
await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
|
||||
content = await self.chatbot.ask_async(
|
||||
|
@ -538,16 +570,17 @@ class Bot:
|
|||
reply_message=content,
|
||||
reply_to_event_id=reply_to_event_id,
|
||||
sender_id=sender_id,
|
||||
user_message=raw_user_message,
|
||||
user_message=user_message,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
await self.send_general_error_message(
|
||||
room_id, reply_to_event_id, sender_id, raw_user_message
|
||||
room_id, reply_to_event_id, sender_id, user_message
|
||||
)
|
||||
|
||||
# !gpt command
|
||||
async def gpt(
|
||||
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
|
||||
self, room_id, reply_to_event_id, prompt, sender_id, user_message
|
||||
) -> None:
|
||||
try:
|
||||
# sending typing state, seconds to milliseconds
|
||||
|
@ -562,17 +595,17 @@ class Bot:
|
|||
reply_message=responseMessage.strip(),
|
||||
reply_to_event_id=reply_to_event_id,
|
||||
sender_id=sender_id,
|
||||
user_message=raw_user_message,
|
||||
user_message=user_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
await self.send_general_error_message(
|
||||
room_id, reply_to_event_id, sender_id, raw_user_message
|
||||
room_id, reply_to_event_id, sender_id, user_message
|
||||
)
|
||||
logger.error(e)
|
||||
|
||||
# !lc command
|
||||
async def lc(
|
||||
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
|
||||
self, room_id, reply_to_event_id, prompt, sender_id, user_message
|
||||
) -> None:
|
||||
try:
|
||||
# sending typing state
|
||||
|
@ -592,11 +625,12 @@ class Bot:
|
|||
reply_message=responseMessage.strip(),
|
||||
reply_to_event_id=reply_to_event_id,
|
||||
sender_id=sender_id,
|
||||
user_message=raw_user_message,
|
||||
user_message=user_message,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
await self.send_general_error_message(
|
||||
room_id, reply_to_event_id, sender_id, raw_user_message
|
||||
room_id, reply_to_event_id, sender_id, user_message
|
||||
)
|
||||
|
||||
# !new command
|
||||
|
@ -605,7 +639,7 @@ class Bot:
|
|||
room_id,
|
||||
reply_to_event_id,
|
||||
sender_id,
|
||||
raw_user_message,
|
||||
user_message,
|
||||
new_command,
|
||||
) -> None:
|
||||
try:
|
||||
|
@ -623,32 +657,58 @@ class Bot:
|
|||
reply_message=content,
|
||||
reply_to_event_id=reply_to_event_id,
|
||||
sender_id=sender_id,
|
||||
user_message=raw_user_message,
|
||||
user_message=user_message,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
await self.send_general_error_message(
|
||||
room_id, reply_to_event_id, sender_id, raw_user_message
|
||||
room_id, reply_to_event_id, sender_id, user_message
|
||||
)
|
||||
|
||||
# !pic command
|
||||
async def pic(self, room_id, prompt, replay_to_event_id, sender_id, user_message):
    """Handle the !pic command: generate image(s) for *prompt* and post them.

    Sends a typing notification while generating, uploads each produced
    image to the room, removes the local file afterwards, and reports a
    generic error message to the room on any failure.

    Args:
        room_id: Matrix room to post into.
        prompt: Text prompt forwarded to the image-generation backend.
        replay_to_event_id: Event id this response replies to.
            # NOTE(review): parameter name looks like a typo for
            # "reply_to_event_id"; kept for interface compatibility.
        sender_id: Matrix id of the requesting user.
        user_message: Raw user message, echoed into the reply context.
    """
    try:
        if self.image_generation_endpoint is not None:
            # seconds -> milliseconds for the Matrix typing timeout
            await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
            # generate image
            b64_datas = await imagegen.get_images(
                self.httpx_client,
                self.image_generation_endpoint,
                prompt,
                self.image_generation_backend,
                # fixed: keyword was misspelled "timeount", so get_images
                # silently fell back to its 120.0 default instead of the
                # configured self.timeout
                timeout=self.timeout,
                api_key=self.openai_api_key,
                n=1,
                size="256x256",
            )
            # decoding/saving is blocking PIL work — keep it off the event loop
            image_path_list = await asyncio.to_thread(
                imagegen.save_images,
                b64_datas,
                self.base_path / "images",
            )
            # send image
            for image_path in image_path_list:
                await send_room_image(self.client, room_id, image_path)
                # images are one-shot uploads; delete the temp file once sent
                await aiofiles.os.remove(image_path)
            await self.client.room_typing(room_id, typing_state=False)
        else:
            await send_room_message(
                self.client,
                room_id,
                reply_message="Image generation endpoint not provided",
                reply_to_event_id=replay_to_event_id,
                sender_id=sender_id,
                user_message=user_message,
            )
    except Exception as e:
        logger.error(e, exc_info=True)
        await send_room_message(
            self.client,
            room_id,
            reply_message="Image generation failed",
            reply_to_event_id=replay_to_event_id,
            user_message=user_message,
            sender_id=sender_id,
        )
||||
|
||||
# !help command
|
||||
|
|
69
src/imagegen.py
Normal file
69
src/imagegen.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
import httpx
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
import base64
|
||||
import io
|
||||
from PIL import Image
|
||||
|
||||
|
||||
async def get_images(
    aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs
) -> list[str]:
    """Request image generation from *url* and return base64-encoded images.

    Args:
        aclient: Shared httpx async client used for the POST request.
        url: Image-generation endpoint (OpenAI-compatible or sd-webui API).
        backend_type: Either ``"openai"`` or ``"sdwui"``.
        **kwargs: Optional settings — ``timeout`` (float, default 120.0),
            ``n`` (image count, default 1), ``size`` (default ``"256x256"``),
            ``api_key`` (openai backend only), ``sampler_name`` and ``steps``
            (sdwui backend only).

    Returns:
        List of base64-encoded image payloads.

    Raises:
        Exception: Non-200 HTTP response from the backend.
        ValueError: Unknown *backend_type*.
    """
    timeout = kwargs.get("timeout", 120.0)
    if backend_type == "openai":
        resp = await aclient.post(
            url,
            headers={
                "Content-Type": "application/json",
                "Authorization": "Bearer " + kwargs.get("api_key"),
            },
            json={
                "prompt": prompt,
                "n": kwargs.get("n", 1),
                "size": kwargs.get("size", "256x256"),
                # ask for inline base64 instead of URLs so no second fetch is needed
                "response_format": "b64_json",
            },
            timeout=timeout,
        )
        if resp.status_code == 200:
            return [data["b64_json"] for data in resp.json()["data"]]
        raise Exception(
            f"{resp.status_code} {resp.reason_phrase} {resp.text}",
        )
    elif backend_type == "sdwui":
        # default fixes TypeError ('in' against None) when size is omitted
        size = kwargs.get("size", "256x256")
        dimension = 256 if "256" in size else 512
        resp = await aclient.post(
            url,
            headers={
                "Content-Type": "application/json",
            },
            json={
                "prompt": prompt,
                "sampler_name": kwargs.get("sampler_name", "Euler a"),
                "batch_size": kwargs.get("n", 1),
                "steps": kwargs.get("steps", 20),
                "width": dimension,
                "height": dimension,
            },
            timeout=timeout,
        )
        if resp.status_code == 200:
            # sd-webui already returns a list of base64 strings
            return resp.json()["images"]
        raise Exception(
            f"{resp.status_code} {resp.reason_phrase} {resp.text}",
        )
    # previously fell through and returned None silently
    raise ValueError(f"Unsupported image generation backend: {backend_type}")
|
||||
|
||||
|
||||
def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[Path]:
    """Decode base64 image payloads and write them as JPEG files under *path*.

    Args:
        b64_datas: Base64-encoded image payloads (as produced by get_images).
        path: Existing directory that receives the files; each file gets a
            fresh ``uuid4`` name with a ``.jpeg`` suffix.
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        List of the written file paths (annotation fixed: these are Path
        objects, not str).
    """
    images = []
    for b64_data in b64_datas:
        image_path = path / (str(uuid.uuid4()) + ".jpeg")
        img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
        # JPEG cannot store an alpha channel or palette; without this,
        # img.save() raises OSError on e.g. RGBA/P PNGs from sd-webui
        if img.mode != "RGB":
            img = img.convert("RGB")
        img.save(image_path)
        images.append(image_path)
    return images
|
|
@ -42,6 +42,8 @@ async def main():
|
|||
temperature=float(config.get("temperature")),
|
||||
flowise_api_url=config.get("flowise_api_url"),
|
||||
flowise_api_key=config.get("flowise_api_key"),
|
||||
image_generation_endpoint=config.get("image_generation_endpoint"),
|
||||
image_generation_backend=config.get("image_generation_backend"),
|
||||
timeout=float(config.get("timeout")),
|
||||
)
|
||||
if (
|
||||
|
@ -71,6 +73,8 @@ async def main():
|
|||
temperature=float(os.environ.get("TEMPERATURE")),
|
||||
flowise_api_url=os.environ.get("FLOWISE_API_URL"),
|
||||
flowise_api_key=os.environ.get("FLOWISE_API_KEY"),
|
||||
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
|
||||
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
|
||||
timeout=float(os.environ.get("TIMEOUT")),
|
||||
)
|
||||
if (
|
||||
|
|
Loading…
Reference in a new issue