diff --git a/.full-env.example b/.full-env.example
index 53c3020..0176496 100644
--- a/.full-env.example
+++ b/.full-env.example
@@ -15,5 +15,7 @@ REPLY_COUNT=1
 SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respond conversationally"
 TEMPERATURE=0.8
 IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img"
-IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui
+IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai
+IMAGE_GENERATION_SIZE="512x512"
+IMAGE_FORMAT="jpeg"
 TIMEOUT=120.0
diff --git a/full-config.json.example b/full-config.json.example
index 5215c90..d964c3a 100644
--- a/full-config.json.example
+++ b/full-config.json.example
@@ -16,6 +16,8 @@
     "temperature": 0.8,
     "system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
     "image_generation_endpoint": "http://localai:8080/v1/images/generations",
-    "image_generation_backend": "openai",
+    "image_generation_backend": "localai",
+    "image_generation_size": "512x512",
+    "image_format": "png",
     "timeout": 120.0
 }
diff --git a/requirements.txt b/requirements.txt
index 6314c78..1a4b9d9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ httpx
 Pillow
 tiktoken
 tenacity
+aiofiles
 mattermostdriver @ git+https://github.com/hibobmaster/python-mattermost-driver
diff --git a/src/bot.py b/src/bot.py
index e801715..5c1ca59 100644
--- a/src/bot.py
+++ b/src/bot.py
@@ -1,3 +1,5 @@
+import sys
+import aiofiles.os
 from mattermostdriver import AsyncDriver
 from typing import Optional
 import json
@@ -34,6 +36,8 @@ class Bot:
         temperature: Optional[float] = None,
         image_generation_endpoint: Optional[str] = None,
         image_generation_backend: Optional[str] = None,
+        image_generation_size: Optional[str] = None,
+        image_format: Optional[str] = None,
         timeout: Optional[float] = 120.0,
     ) -> None:
         if server_url is None:
@@ -54,6 +58,19 @@
             raise ValueError("scheme must be either http or https")
         self.scheme = scheme
 
+        if image_generation_endpoint and image_generation_backend not in [
+            "openai",
+            "sdwui",
+            "localai",
+            None,
+        ]:
+            logger.error("image_generation_backend must be openai or sdwui or localai")
+            sys.exit(1)
+
+        if image_format not in ["jpeg", "png", None]:
+            logger.error("image_format should be jpeg or png, leave blank for jpeg")
+            sys.exit(1)
+
         # @chatgpt
         if username is None:
             raise ValueError("username must be provided")
@@ -78,6 +95,21 @@ class Bot:
         )
         self.image_generation_endpoint: str = image_generation_endpoint
         self.image_generation_backend: str = image_generation_backend
+
+        if image_format:
+            self.image_format: str = image_format
+        else:
+            self.image_format = "jpeg"
+
+        if image_generation_size is None:
+            self.image_generation_size = "512x512"
+            self.image_generation_width = 512
+            self.image_generation_height = 512
+        else:
+            self.image_generation_size = image_generation_size
+            self.image_generation_width = self.image_generation_size.split("x")[0]
+            self.image_generation_height = self.image_generation_size.split("x")[1]
+
         self.timeout = timeout or 120.0
 
         self.bot_id = None
@@ -251,20 +283,19 @@
                     "channel_id": channel_id,
                 },
             )
-            b64_datas = await imagegen.get_images(
+            image_path_list = await imagegen.get_images(
                 self.httpx_client,
                 self.image_generation_endpoint,
                 prompt,
                 self.image_generation_backend,
                 timeount=self.timeout,
                 api_key=self.openai_api_key,
+                output_path=self.base_path / "images",
                 n=1,
-                size="256x256",
-            )
-            image_path_list = await asyncio.to_thread(
-                imagegen.save_images,
-                b64_datas,
-                self.base_path / "images",
+                size=self.image_generation_size,
+                width=self.image_generation_width,
+                height=self.image_generation_height,
+                image_format=self.image_format,
             )
             # send image
             for image_path in image_path_list:
@@ -274,6 +305,7 @@ class Bot:
                     image_path,
                     root_id,
                 )
+                await aiofiles.os.remove(image_path)
         except Exception as e:
             logger.error(e, exc_info=True)
             raise Exception(e)
@@ -321,8 +353,7 @@ class Bot:
                     "root_id": root_id,
                 }
             )
-            # remove image after posting
-            os.remove(filepath)
+
         except Exception as e:
             logger.error(e, exc_info=True)
             raise Exception(e)
diff --git a/src/imagegen.py b/src/imagegen.py
index 2214eac..8f059d9 100644
--- a/src/imagegen.py
+++ b/src/imagegen.py
@@ -7,9 +7,14 @@ from PIL import Image
 
 
 async def get_images(
-    aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs
+    aclient: httpx.AsyncClient,
+    url: str,
+    prompt: str,
+    backend_type: str,
+    output_path: str,
+    **kwargs,
 ) -> list[str]:
-    timeout = kwargs.get("timeout", 120.0)
+    timeout = kwargs.get("timeout", 180.0)
     if backend_type == "openai":
         resp = await aclient.post(
             url,
@@ -20,7 +25,7 @@
             json={
                 "prompt": prompt,
                 "n": kwargs.get("n", 1),
-                "size": kwargs.get("size", "256x256"),
+                "size": kwargs.get("size", "512x512"),
                 "response_format": "b64_json",
             },
             timeout=timeout,
@@ -29,7 +34,7 @@
             b64_datas = []
             for data in resp.json()["data"]:
                 b64_datas.append(data["b64_json"])
-            return b64_datas
+            return save_images_b64(b64_datas, output_path, **kwargs)
         else:
             raise Exception(
                 f"{resp.status_code} {resp.reason_phrase} {resp.text}",
             )
@@ -45,25 +50,56 @@
                 "sampler_name": kwargs.get("sampler_name", "Euler a"),
                 "batch_size": kwargs.get("n", 1),
                 "steps": kwargs.get("steps", 20),
-                "width": 256 if "256" in kwargs.get("size") else 512,
-                "height": 256 if "256" in kwargs.get("size") else 512,
+                "width": kwargs.get("width", 512),
+                "height": kwargs.get("height", 512),
             },
             timeout=timeout,
         )
         if resp.status_code == 200:
             b64_datas = resp.json()["images"]
-            return b64_datas
+            return save_images_b64(b64_datas, output_path, **kwargs)
         else:
             raise Exception(
                 f"{resp.status_code} {resp.reason_phrase} {resp.text}",
             )
+    elif backend_type == "localai":
+        resp = await aclient.post(
+            url,
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {kwargs.get('api_key')}",
+            },
+            json={
+                "prompt": prompt,
+                "size": kwargs.get("size", "512x512"),
+            },
+            timeout=timeout,
+        )
+        if resp.status_code == 200:
+            image_url = resp.json()["data"][0]["url"]
+            return await save_image_url(image_url, aclient, output_path, **kwargs)
 
 
-def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
-    images = []
+def save_images_b64(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
+    images_path_list = []
     for b64_data in b64_datas:
-        image_path = path / (str(uuid.uuid4()) + ".jpeg")
+        image_path = path / (
+            str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg")
+        )
         img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
         img.save(image_path)
-        images.append(image_path)
-    return images
+        images_path_list.append(image_path)
+    return images_path_list
+
+
+async def save_image_url(
+    url: str, aclient: httpx.AsyncClient, path: Path, **kwargs
+) -> list[str]:
+    images_path_list = []
+    r = await aclient.get(url)
+    image_path = path / (str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg"))
+    if r.status_code == 200:
+        img = Image.open(io.BytesIO(r.content))
+        img.save(image_path)
+        images_path_list.append(image_path)
+    return images_path_list
diff --git a/src/main.py b/src/main.py
index 0934842..37d3cb8 100644
--- a/src/main.py
+++ b/src/main.py
@@ -39,6 +39,8 @@ async def main():
             temperature=config.get("temperature"),
             image_generation_endpoint=config.get("image_generation_endpoint"),
             image_generation_backend=config.get("image_generation_backend"),
+            image_generation_size=config.get("image_generation_size"),
+            image_format=config.get("image_format"),
             timeout=config.get("timeout"),
         )
 
@@ -48,21 +50,23 @@
             email=os.environ.get("EMAIL"),
             password=os.environ.get("PASSWORD"),
             username=os.environ.get("USERNAME"),
-            port=os.environ.get("PORT"),
+            port=int(os.environ.get("PORT", 443)),
             scheme=os.environ.get("SCHEME"),
             openai_api_key=os.environ.get("OPENAI_API_KEY"),
             gpt_api_endpoint=os.environ.get("GPT_API_ENDPOINT"),
             gpt_model=os.environ.get("GPT_MODEL"),
-            max_tokens=os.environ.get("MAX_TOKENS"),
-            top_p=os.environ.get("TOP_P"),
-            presence_penalty=os.environ.get("PRESENCE_PENALTY"),
-            frequency_penalty=os.environ.get("FREQUENCY_PENALTY"),
-            reply_count=os.environ.get("REPLY_COUNT"),
+            max_tokens=int(os.environ.get("MAX_TOKENS", 4000)),
+            top_p=float(os.environ.get("TOP_P", 1.0)),
+            presence_penalty=float(os.environ.get("PRESENCE_PENALTY", 0.0)),
+            frequency_penalty=float(os.environ.get("FREQUENCY_PENALTY", 0.0)),
+            reply_count=int(os.environ.get("REPLY_COUNT", 1)),
             system_prompt=os.environ.get("SYSTEM_PROMPT"),
-            temperature=os.environ.get("TEMPERATURE"),
+            temperature=float(os.environ.get("TEMPERATURE", 0.8)),
             image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
             image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
-            timeout=os.environ.get("TIMEOUT"),
+            image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"),
+            image_format=os.environ.get("IMAGE_FORMAT"),
+            timeout=float(os.environ.get("TIMEOUT", 120.0)),
         )
 
     await mattermost_bot.login()