Fix localai v2.0+ image generation

Support specific output image format (jpeg, png) and size
This commit is contained in:
hibobmaster 2023-12-23 22:30:11 +08:00
parent 4dd62d3940
commit e0aff19905
Signed by: bobmaster
SSH key fingerprint: SHA256:5ZYgd8fg+PcNZNy4SzcSKu5JtqZyBF8kUhY7/k2viDk
6 changed files with 107 additions and 31 deletions

View file

@@ -15,5 +15,7 @@ REPLY_COUNT=1
SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respond conversationally" SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respond conversationally"
TEMPERATURE=0.8 TEMPERATURE=0.8
IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img" IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img"
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai
IMAGE_GENERATION_SIZE="512x512"
IMAGE_FORMAT="jpeg"
TIMEOUT=120.0 TIMEOUT=120.0

View file

@@ -16,6 +16,8 @@
"temperature": 0.8, "temperature": 0.8,
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally", "system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
"image_generation_endpoint": "http://localai:8080/v1/images/generations", "image_generation_endpoint": "http://localai:8080/v1/images/generations",
"image_generation_backend": "openai", "image_generation_backend": "localai",
"image_generation_size": "512x512",
"image_format": "webp",
"timeout": 120.0 "timeout": 120.0
} }

View file

@@ -2,4 +2,5 @@ httpx
Pillow Pillow
tiktoken tiktoken
tenacity tenacity
aiofiles
mattermostdriver @ git+https://github.com/hibobmaster/python-mattermost-driver mattermostdriver @ git+https://github.com/hibobmaster/python-mattermost-driver

View file

@@ -1,3 +1,5 @@
import sys
import aiofiles.os
from mattermostdriver import AsyncDriver from mattermostdriver import AsyncDriver
from typing import Optional from typing import Optional
import json import json
@@ -34,6 +36,8 @@ class Bot:
temperature: Optional[float] = None, temperature: Optional[float] = None,
image_generation_endpoint: Optional[str] = None, image_generation_endpoint: Optional[str] = None,
image_generation_backend: Optional[str] = None, image_generation_backend: Optional[str] = None,
image_generation_size: Optional[str] = None,
image_format: Optional[str] = None,
timeout: Optional[float] = 120.0, timeout: Optional[float] = 120.0,
) -> None: ) -> None:
if server_url is None: if server_url is None:
@@ -54,6 +58,19 @@ class Bot:
raise ValueError("scheme must be either http or https") raise ValueError("scheme must be either http or https")
self.scheme = scheme self.scheme = scheme
if image_generation_endpoint and image_generation_backend not in [
"openai",
"sdwui",
"localai",
None,
]:
logger.error("image_generation_backend must be openai or sdwui or localai")
sys.exit(1)
if image_format not in ["jpeg", "png", None]:
logger.error("image_format should be jpeg or png, leave blank for jpeg")
sys.exit(1)
# @chatgpt # @chatgpt
if username is None: if username is None:
raise ValueError("username must be provided") raise ValueError("username must be provided")
@@ -78,6 +95,21 @@ class Bot:
) )
self.image_generation_endpoint: str = image_generation_endpoint self.image_generation_endpoint: str = image_generation_endpoint
self.image_generation_backend: str = image_generation_backend self.image_generation_backend: str = image_generation_backend
if image_format:
self.image_format: str = image_format
else:
self.image_format = "jpeg"
if image_generation_size is None:
self.image_generation_size = "512x512"
self.image_generation_width = 512
self.image_generation_height = 512
else:
self.image_generation_size = image_generation_size
self.image_generation_width = self.image_generation_size.split("x")[0]
self.image_generation_height = self.image_generation_size.split("x")[1]
self.timeout = timeout or 120.0 self.timeout = timeout or 120.0
self.bot_id = None self.bot_id = None
@@ -251,20 +283,19 @@ class Bot:
"channel_id": channel_id, "channel_id": channel_id,
}, },
) )
b64_datas = await imagegen.get_images( image_path_list = await imagegen.get_images(
self.httpx_client, self.httpx_client,
self.image_generation_endpoint, self.image_generation_endpoint,
prompt, prompt,
self.image_generation_backend, self.image_generation_backend,
timeount=self.timeout, timeount=self.timeout,
api_key=self.openai_api_key, api_key=self.openai_api_key,
output_path=self.base_path / "images",
n=1, n=1,
size="256x256", size=self.image_generation_size,
) width=self.image_generation_width,
image_path_list = await asyncio.to_thread( height=self.image_generation_height,
imagegen.save_images, image_format=self.image_format,
b64_datas,
self.base_path / "images",
) )
# send image # send image
for image_path in image_path_list: for image_path in image_path_list:
@@ -274,6 +305,7 @@
image_path, image_path,
root_id, root_id,
) )
await aiofiles.os.remove(image_path)
except Exception as e: except Exception as e:
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
raise Exception(e) raise Exception(e)
@@ -321,8 +353,7 @@
"root_id": root_id, "root_id": root_id,
} }
) )
# remove image after posting
os.remove(filepath)
except Exception as e: except Exception as e:
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
raise Exception(e) raise Exception(e)

View file

@@ -7,9 +7,14 @@ from PIL import Image
async def get_images( async def get_images(
aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs aclient: httpx.AsyncClient,
url: str,
prompt: str,
backend_type: str,
output_path: str,
**kwargs,
) -> list[str]: ) -> list[str]:
timeout = kwargs.get("timeout", 120.0) timeout = kwargs.get("timeout", 180.0)
if backend_type == "openai": if backend_type == "openai":
resp = await aclient.post( resp = await aclient.post(
url, url,
@@ -20,7 +25,7 @@ async def get_images(
json={ json={
"prompt": prompt, "prompt": prompt,
"n": kwargs.get("n", 1), "n": kwargs.get("n", 1),
"size": kwargs.get("size", "256x256"), "size": kwargs.get("size", "512x512"),
"response_format": "b64_json", "response_format": "b64_json",
}, },
timeout=timeout, timeout=timeout,
@@ -29,7 +34,7 @@ async def get_images(
b64_datas = [] b64_datas = []
for data in resp.json()["data"]: for data in resp.json()["data"]:
b64_datas.append(data["b64_json"]) b64_datas.append(data["b64_json"])
return b64_datas return save_images_b64(b64_datas, output_path, **kwargs)
else: else:
raise Exception( raise Exception(
f"{resp.status_code} {resp.reason_phrase} {resp.text}", f"{resp.status_code} {resp.reason_phrase} {resp.text}",
@@ -45,25 +50,56 @@ async def get_images(
"sampler_name": kwargs.get("sampler_name", "Euler a"), "sampler_name": kwargs.get("sampler_name", "Euler a"),
"batch_size": kwargs.get("n", 1), "batch_size": kwargs.get("n", 1),
"steps": kwargs.get("steps", 20), "steps": kwargs.get("steps", 20),
"width": 256 if "256" in kwargs.get("size") else 512, "width": kwargs.get("width", 512),
"height": 256 if "256" in kwargs.get("size") else 512, "height": kwargs.get("height", 512),
}, },
timeout=timeout, timeout=timeout,
) )
if resp.status_code == 200: if resp.status_code == 200:
b64_datas = resp.json()["images"] b64_datas = resp.json()["images"]
return b64_datas return save_images_b64(b64_datas, output_path, **kwargs)
else: else:
raise Exception( raise Exception(
f"{resp.status_code} {resp.reason_phrase} {resp.text}", f"{resp.status_code} {resp.reason_phrase} {resp.text}",
) )
elif backend_type == "localai":
resp = await aclient.post(
url,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {kwargs.get('api_key')}",
},
json={
"prompt": prompt,
"size": kwargs.get("size", "512x512"),
},
timeout=timeout,
)
if resp.status_code == 200:
image_url = resp.json()["data"][0]["url"]
return await save_image_url(image_url, aclient, output_path, **kwargs)
def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[str]: def save_images_b64(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
images = [] images_path_list = []
for b64_data in b64_datas: for b64_data in b64_datas:
image_path = path / (str(uuid.uuid4()) + ".jpeg") image_path = path / (
str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg")
)
img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8")))) img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
img.save(image_path) img.save(image_path)
images.append(image_path) images_path_list.append(image_path)
return images return images_path_list
async def save_image_url(
url: str, aclient: httpx.AsyncClient, path: Path, **kwargs
) -> list[str]:
images_path_list = []
r = await aclient.get(url)
image_path = path / (str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg"))
if r.status_code == 200:
img = Image.open(io.BytesIO(r.content))
img.save(image_path)
images_path_list.append(image_path)
return images_path_list

View file

@@ -39,6 +39,8 @@ async def main():
temperature=config.get("temperature"), temperature=config.get("temperature"),
image_generation_endpoint=config.get("image_generation_endpoint"), image_generation_endpoint=config.get("image_generation_endpoint"),
image_generation_backend=config.get("image_generation_backend"), image_generation_backend=config.get("image_generation_backend"),
image_generation_size=config.get("image_generation_size"),
image_format=config.get("image_format"),
timeout=config.get("timeout"), timeout=config.get("timeout"),
) )
@@ -48,21 +50,23 @@ async def main():
email=os.environ.get("EMAIL"), email=os.environ.get("EMAIL"),
password=os.environ.get("PASSWORD"), password=os.environ.get("PASSWORD"),
username=os.environ.get("USERNAME"), username=os.environ.get("USERNAME"),
port=os.environ.get("PORT"), port=int(os.environ.get("PORT", 443)),
scheme=os.environ.get("SCHEME"), scheme=os.environ.get("SCHEME"),
openai_api_key=os.environ.get("OPENAI_API_KEY"), openai_api_key=os.environ.get("OPENAI_API_KEY"),
gpt_api_endpoint=os.environ.get("GPT_API_ENDPOINT"), gpt_api_endpoint=os.environ.get("GPT_API_ENDPOINT"),
gpt_model=os.environ.get("GPT_MODEL"), gpt_model=os.environ.get("GPT_MODEL"),
max_tokens=os.environ.get("MAX_TOKENS"), max_tokens=int(os.environ.get("MAX_TOKENS", 4000)),
top_p=os.environ.get("TOP_P"), top_p=float(os.environ.get("TOP_P", 1.0)),
presence_penalty=os.environ.get("PRESENCE_PENALTY"), presence_penalty=float(os.environ.get("PRESENCE_PENALTY", 0.0)),
frequency_penalty=os.environ.get("FREQUENCY_PENALTY"), frequency_penalty=float(os.environ.get("FREQUENCY_PENALTY", 0.0)),
reply_count=os.environ.get("REPLY_COUNT"), reply_count=int(os.environ.get("REPLY_COUNT", 1)),
system_prompt=os.environ.get("SYSTEM_PROMPT"), system_prompt=os.environ.get("SYSTEM_PROMPT"),
temperature=os.environ.get("TEMPERATURE"), temperature=float(os.environ.get("TEMPERATURE", 0.8)),
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"), image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"), image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
timeout=os.environ.get("TIMEOUT"), image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"),
image_format=os.environ.get("IMAGE_FORMAT"),
timeout=float(os.environ.get("TIMEOUT", 120.0)),
) )
await mattermost_bot.login() await mattermost_bot.login()