Fix localai v2.0+ image generation
Support specific output image format(jpeg, png) and size
This commit is contained in:
parent
4dd62d3940
commit
e0aff19905
6 changed files with 107 additions and 31 deletions
|
@ -15,5 +15,7 @@ REPLY_COUNT=1
|
|||
SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respond conversationally"
|
||||
TEMPERATURE=0.8
|
||||
IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img"
|
||||
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui
|
||||
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai
|
||||
IMAGE_GENERATION_SIZE="512x512"
|
||||
IMAGE_FORMAT="jpeg"
|
||||
TIMEOUT=120.0
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
"temperature": 0.8,
|
||||
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||
"image_generation_endpoint": "http://localai:8080/v1/images/generations",
|
||||
"image_generation_backend": "openai",
|
||||
"image_generation_backend": "localai",
|
||||
"image_generation_size": "512x512",
|
||||
"image_format": "webp",
|
||||
"timeout": 120.0
|
||||
}
|
||||
|
|
|
@ -2,4 +2,5 @@ httpx
|
|||
Pillow
|
||||
tiktoken
|
||||
tenacity
|
||||
aiofiles
|
||||
mattermostdriver @ git+https://github.com/hibobmaster/python-mattermost-driver
|
||||
|
|
49
src/bot.py
49
src/bot.py
|
@ -1,3 +1,5 @@
|
|||
import sys
|
||||
import aiofiles.os
|
||||
from mattermostdriver import AsyncDriver
|
||||
from typing import Optional
|
||||
import json
|
||||
|
@ -34,6 +36,8 @@ class Bot:
|
|||
temperature: Optional[float] = None,
|
||||
image_generation_endpoint: Optional[str] = None,
|
||||
image_generation_backend: Optional[str] = None,
|
||||
image_generation_size: Optional[str] = None,
|
||||
image_format: Optional[str] = None,
|
||||
timeout: Optional[float] = 120.0,
|
||||
) -> None:
|
||||
if server_url is None:
|
||||
|
@ -54,6 +58,19 @@ class Bot:
|
|||
raise ValueError("scheme must be either http or https")
|
||||
self.scheme = scheme
|
||||
|
||||
if image_generation_endpoint and image_generation_backend not in [
|
||||
"openai",
|
||||
"sdwui",
|
||||
"localai",
|
||||
None,
|
||||
]:
|
||||
logger.error("image_generation_backend must be openai or sdwui or localai")
|
||||
sys.exit(1)
|
||||
|
||||
if image_format not in ["jpeg", "png", None]:
|
||||
logger.error("image_format should be jpeg or png, leave blank for jpeg")
|
||||
sys.exit(1)
|
||||
|
||||
# @chatgpt
|
||||
if username is None:
|
||||
raise ValueError("username must be provided")
|
||||
|
@ -78,6 +95,21 @@ class Bot:
|
|||
)
|
||||
self.image_generation_endpoint: str = image_generation_endpoint
|
||||
self.image_generation_backend: str = image_generation_backend
|
||||
|
||||
if image_format:
|
||||
self.image_format: str = image_format
|
||||
else:
|
||||
self.image_format = "jpeg"
|
||||
|
||||
if image_generation_size is None:
|
||||
self.image_generation_size = "512x512"
|
||||
self.image_generation_width = 512
|
||||
self.image_generation_height = 512
|
||||
else:
|
||||
self.image_generation_size = image_generation_size
|
||||
self.image_generation_width = self.image_generation_size.split("x")[0]
|
||||
self.image_generation_height = self.image_generation_size.split("x")[1]
|
||||
|
||||
self.timeout = timeout or 120.0
|
||||
|
||||
self.bot_id = None
|
||||
|
@ -251,20 +283,19 @@ class Bot:
|
|||
"channel_id": channel_id,
|
||||
},
|
||||
)
|
||||
b64_datas = await imagegen.get_images(
|
||||
image_path_list = await imagegen.get_images(
|
||||
self.httpx_client,
|
||||
self.image_generation_endpoint,
|
||||
prompt,
|
||||
self.image_generation_backend,
|
||||
timeount=self.timeout,
|
||||
api_key=self.openai_api_key,
|
||||
output_path=self.base_path / "images",
|
||||
n=1,
|
||||
size="256x256",
|
||||
)
|
||||
image_path_list = await asyncio.to_thread(
|
||||
imagegen.save_images,
|
||||
b64_datas,
|
||||
self.base_path / "images",
|
||||
size=self.image_generation_size,
|
||||
width=self.image_generation_width,
|
||||
height=self.image_generation_height,
|
||||
image_format=self.image_format,
|
||||
)
|
||||
# send image
|
||||
for image_path in image_path_list:
|
||||
|
@ -274,6 +305,7 @@ class Bot:
|
|||
image_path,
|
||||
root_id,
|
||||
)
|
||||
await aiofiles.os.remove(image_path)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
@ -321,8 +353,7 @@ class Bot:
|
|||
"root_id": root_id,
|
||||
}
|
||||
)
|
||||
# remove image after posting
|
||||
os.remove(filepath)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise Exception(e)
|
||||
|
|
|
@ -7,9 +7,14 @@ from PIL import Image
|
|||
|
||||
|
||||
async def get_images(
|
||||
aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs
|
||||
aclient: httpx.AsyncClient,
|
||||
url: str,
|
||||
prompt: str,
|
||||
backend_type: str,
|
||||
output_path: str,
|
||||
**kwargs,
|
||||
) -> list[str]:
|
||||
timeout = kwargs.get("timeout", 120.0)
|
||||
timeout = kwargs.get("timeout", 180.0)
|
||||
if backend_type == "openai":
|
||||
resp = await aclient.post(
|
||||
url,
|
||||
|
@ -20,7 +25,7 @@ async def get_images(
|
|||
json={
|
||||
"prompt": prompt,
|
||||
"n": kwargs.get("n", 1),
|
||||
"size": kwargs.get("size", "256x256"),
|
||||
"size": kwargs.get("size", "512x512"),
|
||||
"response_format": "b64_json",
|
||||
},
|
||||
timeout=timeout,
|
||||
|
@ -29,7 +34,7 @@ async def get_images(
|
|||
b64_datas = []
|
||||
for data in resp.json()["data"]:
|
||||
b64_datas.append(data["b64_json"])
|
||||
return b64_datas
|
||||
return save_images_b64(b64_datas, output_path, **kwargs)
|
||||
else:
|
||||
raise Exception(
|
||||
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
||||
|
@ -45,25 +50,56 @@ async def get_images(
|
|||
"sampler_name": kwargs.get("sampler_name", "Euler a"),
|
||||
"batch_size": kwargs.get("n", 1),
|
||||
"steps": kwargs.get("steps", 20),
|
||||
"width": 256 if "256" in kwargs.get("size") else 512,
|
||||
"height": 256 if "256" in kwargs.get("size") else 512,
|
||||
"width": kwargs.get("width", 512),
|
||||
"height": kwargs.get("height", 512),
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
b64_datas = resp.json()["images"]
|
||||
return b64_datas
|
||||
return save_images_b64(b64_datas, output_path, **kwargs)
|
||||
else:
|
||||
raise Exception(
|
||||
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
||||
)
|
||||
elif backend_type == "localai":
|
||||
resp = await aclient.post(
|
||||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {kwargs.get('api_key')}",
|
||||
},
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"size": kwargs.get("size", "512x512"),
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
image_url = resp.json()["data"][0]["url"]
|
||||
return await save_image_url(image_url, aclient, output_path, **kwargs)
|
||||
|
||||
|
||||
def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
|
||||
images = []
|
||||
def save_images_b64(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
|
||||
images_path_list = []
|
||||
for b64_data in b64_datas:
|
||||
image_path = path / (str(uuid.uuid4()) + ".jpeg")
|
||||
image_path = path / (
|
||||
str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg")
|
||||
)
|
||||
img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
|
||||
img.save(image_path)
|
||||
images.append(image_path)
|
||||
return images
|
||||
images_path_list.append(image_path)
|
||||
return images_path_list
|
||||
|
||||
|
||||
async def save_image_url(
|
||||
url: str, aclient: httpx.AsyncClient, path: Path, **kwargs
|
||||
) -> list[str]:
|
||||
images_path_list = []
|
||||
r = await aclient.get(url)
|
||||
image_path = path / (str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg"))
|
||||
if r.status_code == 200:
|
||||
img = Image.open(io.BytesIO(r.content))
|
||||
img.save(image_path)
|
||||
images_path_list.append(image_path)
|
||||
return images_path_list
|
||||
|
|
20
src/main.py
20
src/main.py
|
@ -39,6 +39,8 @@ async def main():
|
|||
temperature=config.get("temperature"),
|
||||
image_generation_endpoint=config.get("image_generation_endpoint"),
|
||||
image_generation_backend=config.get("image_generation_backend"),
|
||||
image_generation_size=config.get("image_generation_size"),
|
||||
image_format=config.get("image_format"),
|
||||
timeout=config.get("timeout"),
|
||||
)
|
||||
|
||||
|
@ -48,21 +50,23 @@ async def main():
|
|||
email=os.environ.get("EMAIL"),
|
||||
password=os.environ.get("PASSWORD"),
|
||||
username=os.environ.get("USERNAME"),
|
||||
port=os.environ.get("PORT"),
|
||||
port=int(os.environ.get("PORT", 443)),
|
||||
scheme=os.environ.get("SCHEME"),
|
||||
openai_api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
gpt_api_endpoint=os.environ.get("GPT_API_ENDPOINT"),
|
||||
gpt_model=os.environ.get("GPT_MODEL"),
|
||||
max_tokens=os.environ.get("MAX_TOKENS"),
|
||||
top_p=os.environ.get("TOP_P"),
|
||||
presence_penalty=os.environ.get("PRESENCE_PENALTY"),
|
||||
frequency_penalty=os.environ.get("FREQUENCY_PENALTY"),
|
||||
reply_count=os.environ.get("REPLY_COUNT"),
|
||||
max_tokens=int(os.environ.get("MAX_TOKENS", 4000)),
|
||||
top_p=float(os.environ.get("TOP_P", 1.0)),
|
||||
presence_penalty=float(os.environ.get("PRESENCE_PENALTY", 0.0)),
|
||||
frequency_penalty=float(os.environ.get("FREQUENCY_PENALTY", 0.0)),
|
||||
reply_count=int(os.environ.get("REPLY_COUNT", 1)),
|
||||
system_prompt=os.environ.get("SYSTEM_PROMPT"),
|
||||
temperature=os.environ.get("TEMPERATURE"),
|
||||
temperature=float(os.environ.get("TEMPERATURE", 0.8)),
|
||||
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
|
||||
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
|
||||
timeout=os.environ.get("TIMEOUT"),
|
||||
image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"),
|
||||
image_format=os.environ.get("IMAGE_FORMAT"),
|
||||
timeout=float(os.environ.get("TIMEOUT", 120.0)),
|
||||
)
|
||||
|
||||
await mattermost_bot.login()
|
||||
|
|
Loading…
Reference in a new issue