refactor: image generation
This commit is contained in:
parent
0197e8b3d2
commit
bf95dc0f42
6 changed files with 172 additions and 34 deletions
|
@ -17,4 +17,6 @@ SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respo
|
||||||
TEMPERATURE=0.8
|
TEMPERATURE=0.8
|
||||||
FLOWISE_API_URL="http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21"
|
FLOWISE_API_URL="http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21"
|
||||||
FLOWISE_API_KEY="U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A="
|
FLOWISE_API_KEY="U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A="
|
||||||
|
IMAGE_GENERATION_ENDPOINT="http://localai:8080/v1/images/generations"
|
||||||
|
IMAGE_GENERATION_BACKEND="openai" # openai or sdwui (the LocalAI /v1/images/generations endpoint above is OpenAI-compatible)
|
||||||
TIMEOUT=120.0
|
TIMEOUT=120.0
|
||||||
|
|
|
@ -5,10 +5,11 @@ This is a simple Matrix bot that support using OpenAI API, Langchain to generate
|
||||||
|
|
||||||
## Feature
|
## Feature
|
||||||
|
|
||||||
1. Support official openai api and self host models([LocalAI](https://github.com/go-skynet/LocalAI))
|
1. Support the official OpenAI API and self-hosted models ([LocalAI](https://localai.io/model-compatibility/))
|
||||||
2. Support E2E Encrypted Room
|
2. Support E2E Encrypted Room
|
||||||
3. Colorful code blocks
|
3. Colorful code blocks
|
||||||
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
|
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
|
||||||
|
5. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
|
||||||
|
|
||||||
|
|
||||||
## Installation and Setup
|
## Installation and Setup
|
||||||
|
@ -67,7 +68,7 @@ python src/main.py
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
To interact with the bot, simply send a message to the bot in the Matrix room with one of the two prompts:<br>
|
To interact with the bot, simply send a message to the bot in the Matrix room with one of the following prompts:<br>
|
||||||
- `!help` help message
|
- `!help` help message
|
||||||
|
|
||||||
- `!gpt` To generate a one time response:
|
- `!gpt` To generate a one time response:
|
||||||
|
@ -95,8 +96,8 @@ To interact with the bot, simply send a message to the bot in the Matrix room wi
|
||||||
|
|
||||||
|
|
||||||
## Image Generation
|
## Image Generation
|
||||||
|
![demo1](https://i.imgur.com/voeomsF.jpg)
|
||||||
|
![demo2](https://i.imgur.com/BKZktWd.jpg)
|
||||||
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/ <br>
|
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/ <br>
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,5 +18,7 @@
|
||||||
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||||
"flowise_api_url": "http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21",
|
"flowise_api_url": "http://flowise:3000/api/v1/prediction/6deb3c89-45bf-4ac4-a0b0-b2d5ef249d21",
|
||||||
"flowise_api_key": "U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=",
|
"flowise_api_key": "U3pe0bbVDWOyoJtsDzFJjRvHKTP3FRjODwuM78exC3A=",
|
||||||
|
"image_generation_endpoint": "http://localai:8080/v1/images/generations",
|
||||||
|
"image_generation_backend": "openai",
|
||||||
"timeout": 120.0
|
"timeout": 120.0
|
||||||
}
|
}
|
||||||
|
|
120
src/bot.py
120
src/bot.py
|
@ -5,6 +5,7 @@ import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
|
import aiofiles.os
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -33,6 +34,7 @@ from send_image import send_room_image
|
||||||
from send_message import send_room_message
|
from send_message import send_room_message
|
||||||
from flowise import flowise_query
|
from flowise import flowise_query
|
||||||
from gptbot import Chatbot
|
from gptbot import Chatbot
|
||||||
|
import imagegen
|
||||||
|
|
||||||
logger = getlogger()
|
logger = getlogger()
|
||||||
DEVICE_NAME = "MatrixChatGPTBot"
|
DEVICE_NAME = "MatrixChatGPTBot"
|
||||||
|
@ -61,6 +63,8 @@ class Bot:
|
||||||
temperature: Union[float, None] = None,
|
temperature: Union[float, None] = None,
|
||||||
flowise_api_url: Optional[str] = None,
|
flowise_api_url: Optional[str] = None,
|
||||||
flowise_api_key: Optional[str] = None,
|
flowise_api_key: Optional[str] = None,
|
||||||
|
image_generation_endpoint: Optional[str] = None,
|
||||||
|
image_generation_backend: Optional[str] = None,
|
||||||
timeout: Union[float, None] = None,
|
timeout: Union[float, None] = None,
|
||||||
):
|
):
|
||||||
if homeserver is None or user_id is None or device_id is None:
|
if homeserver is None or user_id is None or device_id is None:
|
||||||
|
@ -71,6 +75,14 @@ class Bot:
|
||||||
logger.warning("password is required")
|
logger.warning("password is required")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
if image_generation_endpoint and image_generation_backend not in [
|
||||||
|
"openai",
|
||||||
|
"sdwui",
|
||||||
|
None,
|
||||||
|
]:
|
||||||
|
logger.warning("image_generation_backend must be openai or sdwui")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
self.homeserver: str = homeserver
|
self.homeserver: str = homeserver
|
||||||
self.user_id: str = user_id
|
self.user_id: str = user_id
|
||||||
self.password: str = password
|
self.password: str = password
|
||||||
|
@ -98,11 +110,16 @@ class Bot:
|
||||||
self.import_keys_password: str = import_keys_password
|
self.import_keys_password: str = import_keys_password
|
||||||
self.flowise_api_url: str = flowise_api_url
|
self.flowise_api_url: str = flowise_api_url
|
||||||
self.flowise_api_key: str = flowise_api_key
|
self.flowise_api_key: str = flowise_api_key
|
||||||
|
self.image_generation_endpoint: str = image_generation_endpoint
|
||||||
|
self.image_generation_backend: str = image_generation_backend
|
||||||
|
|
||||||
self.timeout: float = timeout or 120.0
|
self.timeout: float = timeout or 120.0
|
||||||
|
|
||||||
self.base_path = Path(os.path.dirname(__file__)).parent
|
self.base_path = Path(os.path.dirname(__file__)).parent
|
||||||
|
|
||||||
|
if not os.path.exists(self.base_path / "images"):
|
||||||
|
os.mkdir(self.base_path / "images")
|
||||||
|
|
||||||
self.httpx_client = httpx.AsyncClient(
|
self.httpx_client = httpx.AsyncClient(
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
|
@ -270,6 +287,23 @@ class Bot:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e, exc_info=True)
|
logger.error(e, exc_info=True)
|
||||||
|
|
||||||
|
# !pic command
|
||||||
|
p = self.pic_prog.match(content_body)
|
||||||
|
if p:
|
||||||
|
prompt = p.group(1)
|
||||||
|
try:
|
||||||
|
asyncio.create_task(
|
||||||
|
self.pic(
|
||||||
|
room_id,
|
||||||
|
prompt,
|
||||||
|
reply_to_event_id,
|
||||||
|
sender_id,
|
||||||
|
raw_user_message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
|
||||||
# help command
|
# help command
|
||||||
h = self.help_prog.match(content_body)
|
h = self.help_prog.match(content_body)
|
||||||
if h:
|
if h:
|
||||||
|
@ -523,9 +557,7 @@ class Bot:
|
||||||
logger.info(estr)
|
logger.info(estr)
|
||||||
|
|
||||||
# !chat command
|
# !chat command
|
||||||
async def chat(
|
async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message):
|
||||||
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
|
await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
|
||||||
content = await self.chatbot.ask_async(
|
content = await self.chatbot.ask_async(
|
||||||
|
@ -538,16 +570,17 @@ class Bot:
|
||||||
reply_message=content,
|
reply_message=content,
|
||||||
reply_to_event_id=reply_to_event_id,
|
reply_to_event_id=reply_to_event_id,
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
user_message=raw_user_message,
|
user_message=user_message,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
await self.send_general_error_message(
|
await self.send_general_error_message(
|
||||||
room_id, reply_to_event_id, sender_id, raw_user_message
|
room_id, reply_to_event_id, sender_id, user_message
|
||||||
)
|
)
|
||||||
|
|
||||||
# !gpt command
|
# !gpt command
|
||||||
async def gpt(
|
async def gpt(
|
||||||
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
|
self, room_id, reply_to_event_id, prompt, sender_id, user_message
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
# sending typing state, seconds to milliseconds
|
# sending typing state, seconds to milliseconds
|
||||||
|
@ -562,17 +595,17 @@ class Bot:
|
||||||
reply_message=responseMessage.strip(),
|
reply_message=responseMessage.strip(),
|
||||||
reply_to_event_id=reply_to_event_id,
|
reply_to_event_id=reply_to_event_id,
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
user_message=raw_user_message,
|
user_message=user_message,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
await self.send_general_error_message(
|
await self.send_general_error_message(
|
||||||
room_id, reply_to_event_id, sender_id, raw_user_message
|
room_id, reply_to_event_id, sender_id, user_message
|
||||||
)
|
)
|
||||||
logger.error(e)
|
|
||||||
|
|
||||||
# !lc command
|
# !lc command
|
||||||
async def lc(
|
async def lc(
|
||||||
self, room_id, reply_to_event_id, prompt, sender_id, raw_user_message
|
self, room_id, reply_to_event_id, prompt, sender_id, user_message
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
# sending typing state
|
# sending typing state
|
||||||
|
@ -592,11 +625,12 @@ class Bot:
|
||||||
reply_message=responseMessage.strip(),
|
reply_message=responseMessage.strip(),
|
||||||
reply_to_event_id=reply_to_event_id,
|
reply_to_event_id=reply_to_event_id,
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
user_message=raw_user_message,
|
user_message=user_message,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
await self.send_general_error_message(
|
await self.send_general_error_message(
|
||||||
room_id, reply_to_event_id, sender_id, raw_user_message
|
room_id, reply_to_event_id, sender_id, user_message
|
||||||
)
|
)
|
||||||
|
|
||||||
# !new command
|
# !new command
|
||||||
|
@ -605,7 +639,7 @@ class Bot:
|
||||||
room_id,
|
room_id,
|
||||||
reply_to_event_id,
|
reply_to_event_id,
|
||||||
sender_id,
|
sender_id,
|
||||||
raw_user_message,
|
user_message,
|
||||||
new_command,
|
new_command,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -623,32 +657,58 @@ class Bot:
|
||||||
reply_message=content,
|
reply_message=content,
|
||||||
reply_to_event_id=reply_to_event_id,
|
reply_to_event_id=reply_to_event_id,
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
user_message=raw_user_message,
|
user_message=user_message,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
await self.send_general_error_message(
|
await self.send_general_error_message(
|
||||||
room_id, reply_to_event_id, sender_id, raw_user_message
|
room_id, reply_to_event_id, sender_id, user_message
|
||||||
)
|
)
|
||||||
|
|
||||||
# !pic command
|
# !pic command
|
||||||
async def pic(self, room_id, prompt, replay_to_event_id, sender_id, user_message):
    """Handle the ``!pic`` command: generate images for *prompt* and post them.

    When no image generation endpoint is configured, a notice is sent to the
    room instead. Any failure is logged and reported to the room with a
    generic "Image generation failed" reply.

    Args:
        room_id: Room the command was issued in.
        prompt: Text prompt forwarded to the image generation backend.
        replay_to_event_id: Event id that replies are threaded under.
        sender_id: User who issued the command.
        user_message: Original message text, forwarded to the reply helper.
    """
    try:
        if self.image_generation_endpoint is None:
            await send_room_message(
                self.client,
                room_id,
                reply_message="Image generation endpoint not provided",
                reply_to_event_id=replay_to_event_id,
                sender_id=sender_id,
                user_message=user_message,
            )
            return

        # Typing state timeout is expressed in milliseconds.
        await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
        try:
            # generate image
            b64_datas = await imagegen.get_images(
                self.httpx_client,
                self.image_generation_endpoint,
                prompt,
                self.image_generation_backend,
                # get_images reads the "timeout" key; a misspelled keyword is
                # silently ignored and the default would be used instead.
                timeout=self.timeout,
                api_key=self.openai_api_key,
                n=1,
                size="256x256",
            )
            # Decoding and writing files is blocking work; keep it off the loop.
            image_path_list = await asyncio.to_thread(
                imagegen.save_images,
                b64_datas,
                self.base_path / "images",
            )
            # send image
            for image_path in image_path_list:
                try:
                    await send_room_image(self.client, room_id, image_path)
                finally:
                    # Remove the temporary file even when the upload fails.
                    await aiofiles.os.remove(image_path)
        finally:
            # Always clear the typing indicator, on success and on failure.
            await self.client.room_typing(room_id, typing_state=False)
    except Exception as e:
        logger.error(e, exc_info=True)
        await send_room_message(
            self.client,
            room_id,
            reply_message="Image generation failed",
            reply_to_event_id=replay_to_event_id,
            user_message=user_message,
            sender_id=sender_id,
        )
|
||||||
|
|
||||||
# !help command
|
# !help command
|
||||||
|
|
69
src/imagegen.py
Normal file
69
src/imagegen.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import httpx
|
||||||
|
from pathlib import Path
|
||||||
|
import uuid
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
async def get_images(
    aclient: "httpx.AsyncClient", url: str, prompt: str, backend_type: str, **kwargs
) -> list[str]:
    """Request generated images and return them as base64-encoded strings.

    Args:
        aclient: Async HTTP client used to POST the request.
        url: Image generation endpoint.
        prompt: Text prompt describing the image.
        backend_type: ``"openai"`` or ``"sdwui"``; selects the request payload
            and the response field the images are read from.
        **kwargs: Optional settings:
            timeout: Request timeout in seconds (default ``120.0``).
            api_key: Bearer token, ``"openai"`` backend only; the
                Authorization header is omitted when it is missing or empty.
            n: Number of images to request (default ``1``).
            size: Size string such as ``"256x256"`` (default ``"256x256"``).
            sampler_name: ``"sdwui"`` sampler (default ``"Euler a"``).
            steps: ``"sdwui"`` sampling steps (default ``20``).

    Returns:
        The base64 image payloads found in the response.

    Raises:
        ValueError: If *backend_type* is not ``"openai"`` or ``"sdwui"``.
        Exception: If the endpoint answers with a non-200 status code.
    """
    timeout = kwargs.get("timeout", 120.0)
    count = kwargs.get("n", 1)
    # Fall back to the default when size is absent or explicitly None.
    size = kwargs.get("size") or "256x256"
    headers = {"Content-Type": "application/json"}

    if backend_type == "openai":
        api_key = kwargs.get("api_key")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        payload = {
            "prompt": prompt,
            "n": count,
            "size": size,
            "response_format": "b64_json",
        }
    elif backend_type == "sdwui":
        # Only two square resolutions are supported: 256 when requested, else 512.
        side = 256 if "256" in size else 512
        payload = {
            "prompt": prompt,
            "sampler_name": kwargs.get("sampler_name", "Euler a"),
            "batch_size": count,
            "steps": kwargs.get("steps", 20),
            "width": side,
            "height": side,
        }
    else:
        # Fail loudly instead of implicitly returning None to the caller.
        raise ValueError(
            f"unsupported image generation backend: {backend_type!r}"
        )

    resp = await aclient.post(url, headers=headers, json=payload, timeout=timeout)
    if resp.status_code != 200:
        raise Exception(
            f"{resp.status_code} {resp.reason_phrase} {resp.text}",
        )

    body = resp.json()
    if backend_type == "openai":
        return [item["b64_json"] for item in body["data"]]
    return body["images"]
|
||||||
|
|
||||||
|
|
||||||
|
def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[Path]:
    """Decode base64 image payloads and write each one under *path* as JPEG.

    Args:
        b64_datas: Base64-encoded image payloads.
        path: Existing directory that receives the files.
        **kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        Paths of the written files, one per payload, in input order. Each
        file is named with a random UUID and a ``.jpeg`` suffix.
    """
    # Modes taken to be writable by Pillow's JPEG encoder as-is.
    # NOTE(review): list follows Pillow's JPEG plugin; confirm on upgrade.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    saved = []
    for b64_data in b64_datas:
        image_path = path / f"{uuid.uuid4()}.jpeg"
        with Image.open(io.BytesIO(base64.b64decode(b64_data))) as source:
            # Backends may return PNG data with alpha or a palette, which
            # cannot be written as JPEG; flatten those to RGB first.
            writable = (
                source if source.mode in jpeg_modes else source.convert("RGB")
            )
            writable.save(image_path)
        saved.append(image_path)
    return saved
|
|
@ -42,6 +42,8 @@ async def main():
|
||||||
temperature=float(config.get("temperature")),
|
temperature=float(config.get("temperature")),
|
||||||
flowise_api_url=config.get("flowise_api_url"),
|
flowise_api_url=config.get("flowise_api_url"),
|
||||||
flowise_api_key=config.get("flowise_api_key"),
|
flowise_api_key=config.get("flowise_api_key"),
|
||||||
|
image_generation_endpoint=config.get("image_generation_endpoint"),
|
||||||
|
image_generation_backend=config.get("image_generation_backend"),
|
||||||
timeout=float(config.get("timeout")),
|
timeout=float(config.get("timeout")),
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
|
@ -71,6 +73,8 @@ async def main():
|
||||||
temperature=float(os.environ.get("TEMPERATURE")),
|
temperature=float(os.environ.get("TEMPERATURE")),
|
||||||
flowise_api_url=os.environ.get("FLOWISE_API_URL"),
|
flowise_api_url=os.environ.get("FLOWISE_API_URL"),
|
||||||
flowise_api_key=os.environ.get("FLOWISE_API_KEY"),
|
flowise_api_key=os.environ.get("FLOWISE_API_KEY"),
|
||||||
|
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
|
||||||
|
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
|
||||||
timeout=float(os.environ.get("TIMEOUT")),
|
timeout=float(os.environ.get("TIMEOUT")),
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
|
|
Loading…
Reference in a new issue