Fix localai v2.0+ image generation
This commit is contained in:
parent
f4d7b9212a
commit
fac14a4244
5 changed files with 90 additions and 24 deletions
|
@ -18,5 +18,7 @@ SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Respo
|
||||||
TEMPERATURE=0.8
|
TEMPERATURE=0.8
|
||||||
LC_ADMIN="@admin:xxxxxx.xxx,@admin2:xxxxxx.xxx"
|
LC_ADMIN="@admin:xxxxxx.xxx,@admin2:xxxxxx.xxx"
|
||||||
IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img"
|
IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img"
|
||||||
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui
|
IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai
|
||||||
|
IMAGE_GENERATION_SIZE="512x512"
|
||||||
|
IMAGE_FORMAT="webp"
|
||||||
TIMEOUT=120.0
|
TIMEOUT=120.0
|
||||||
|
|
|
@ -19,6 +19,8 @@
|
||||||
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
"system_prompt": "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
|
||||||
"lc_admin": ["@admin:xxxxx.org"],
|
"lc_admin": ["@admin:xxxxx.org"],
|
||||||
"image_generation_endpoint": "http://localai:8080/v1/images/generations",
|
"image_generation_endpoint": "http://localai:8080/v1/images/generations",
|
||||||
"image_generation_backend": "openai",
|
"image_generation_backend": "localai",
|
||||||
|
"image_generation_size": "512x512",
|
||||||
|
"image_format": "webp",
|
||||||
"timeout": 120.0
|
"timeout": 120.0
|
||||||
}
|
}
|
||||||
|
|
42
src/bot.py
42
src/bot.py
|
@ -68,22 +68,31 @@ class Bot:
|
||||||
lc_admin: Optional[list[str]] = None,
|
lc_admin: Optional[list[str]] = None,
|
||||||
image_generation_endpoint: Optional[str] = None,
|
image_generation_endpoint: Optional[str] = None,
|
||||||
image_generation_backend: Optional[str] = None,
|
image_generation_backend: Optional[str] = None,
|
||||||
|
image_generation_size: Optional[str] = None,
|
||||||
|
image_format: Optional[str] = None,
|
||||||
timeout: Union[float, None] = None,
|
timeout: Union[float, None] = None,
|
||||||
):
|
):
|
||||||
if homeserver is None or user_id is None or device_id is None:
|
if homeserver is None or user_id is None or device_id is None:
|
||||||
logger.warning("homeserver && user_id && device_id is required")
|
logger.error("homeserver && user_id && device_id is required")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if password is None and access_token is None:
|
if password is None and access_token is None:
|
||||||
logger.warning("password is required")
|
logger.error("password is required")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if image_generation_endpoint and image_generation_backend not in [
|
if image_generation_endpoint and image_generation_backend not in [
|
||||||
"openai",
|
"openai",
|
||||||
"sdwui",
|
"sdwui",
|
||||||
|
"localai",
|
||||||
None,
|
None,
|
||||||
]:
|
]:
|
||||||
logger.warning("image_generation_backend must be openai or sdwui")
|
logger.error("image_generation_backend must be openai or sdwui or localai")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if image_format not in ["jpeg", "webp", "png", None]:
|
||||||
|
logger.error(
|
||||||
|
"image_format should be jpeg or webp or png, leave blank for jpeg"
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
self.homeserver: str = homeserver
|
self.homeserver: str = homeserver
|
||||||
|
@ -115,6 +124,20 @@ class Bot:
|
||||||
self.image_generation_endpoint: str = image_generation_endpoint
|
self.image_generation_endpoint: str = image_generation_endpoint
|
||||||
self.image_generation_backend: str = image_generation_backend
|
self.image_generation_backend: str = image_generation_backend
|
||||||
|
|
||||||
|
if image_format:
|
||||||
|
self.image_format: str = image_format
|
||||||
|
else:
|
||||||
|
self.image_format = "jpeg"
|
||||||
|
|
||||||
|
if image_generation_size is None:
|
||||||
|
self.image_generation_size = "512x512"
|
||||||
|
self.image_generation_width = 512
|
||||||
|
self.image_generation_height = 512
|
||||||
|
else:
|
||||||
|
self.image_generation_size = image_generation_size
|
||||||
|
self.image_generation_width = self.image_generation_size.split("x")[0]
|
||||||
|
self.image_generation_height = self.image_generation_size.split("x")[1]
|
||||||
|
|
||||||
self.timeout: float = timeout or 120.0
|
self.timeout: float = timeout or 120.0
|
||||||
|
|
||||||
self.base_path = Path(os.path.dirname(__file__)).parent
|
self.base_path = Path(os.path.dirname(__file__)).parent
|
||||||
|
@ -1333,20 +1356,19 @@ class Bot:
|
||||||
if self.image_generation_endpoint is not None:
|
if self.image_generation_endpoint is not None:
|
||||||
await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
|
await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
|
||||||
# generate image
|
# generate image
|
||||||
b64_datas = await imagegen.get_images(
|
image_path_list = await imagegen.get_images(
|
||||||
self.httpx_client,
|
self.httpx_client,
|
||||||
self.image_generation_endpoint,
|
self.image_generation_endpoint,
|
||||||
prompt,
|
prompt,
|
||||||
self.image_generation_backend,
|
self.image_generation_backend,
|
||||||
timeount=self.timeout,
|
timeount=self.timeout,
|
||||||
api_key=self.openai_api_key,
|
api_key=self.openai_api_key,
|
||||||
|
output_path=self.base_path / "images",
|
||||||
n=1,
|
n=1,
|
||||||
size="256x256",
|
size=self.image_generation_size,
|
||||||
)
|
width=self.image_generation_width,
|
||||||
image_path_list = await asyncio.to_thread(
|
height=self.image_generation_height,
|
||||||
imagegen.save_images,
|
image_format=self.image_format,
|
||||||
b64_datas,
|
|
||||||
self.base_path / "images",
|
|
||||||
)
|
)
|
||||||
# send image
|
# send image
|
||||||
for image_path in image_path_list:
|
for image_path in image_path_list:
|
||||||
|
|
|
@ -7,9 +7,14 @@ from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
async def get_images(
|
async def get_images(
|
||||||
aclient: httpx.AsyncClient, url: str, prompt: str, backend_type: str, **kwargs
|
aclient: httpx.AsyncClient,
|
||||||
|
url: str,
|
||||||
|
prompt: str,
|
||||||
|
backend_type: str,
|
||||||
|
output_path: str,
|
||||||
|
**kwargs,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
timeout = kwargs.get("timeout", 120.0)
|
timeout = kwargs.get("timeout", 180.0)
|
||||||
if backend_type == "openai":
|
if backend_type == "openai":
|
||||||
resp = await aclient.post(
|
resp = await aclient.post(
|
||||||
url,
|
url,
|
||||||
|
@ -20,7 +25,7 @@ async def get_images(
|
||||||
json={
|
json={
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"n": kwargs.get("n", 1),
|
"n": kwargs.get("n", 1),
|
||||||
"size": kwargs.get("size", "256x256"),
|
"size": kwargs.get("size", "512x512"),
|
||||||
"response_format": "b64_json",
|
"response_format": "b64_json",
|
||||||
},
|
},
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
@ -29,7 +34,7 @@ async def get_images(
|
||||||
b64_datas = []
|
b64_datas = []
|
||||||
for data in resp.json()["data"]:
|
for data in resp.json()["data"]:
|
||||||
b64_datas.append(data["b64_json"])
|
b64_datas.append(data["b64_json"])
|
||||||
return b64_datas
|
return save_images_b64(b64_datas, output_path, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
||||||
|
@ -45,25 +50,56 @@ async def get_images(
|
||||||
"sampler_name": kwargs.get("sampler_name", "Euler a"),
|
"sampler_name": kwargs.get("sampler_name", "Euler a"),
|
||||||
"batch_size": kwargs.get("n", 1),
|
"batch_size": kwargs.get("n", 1),
|
||||||
"steps": kwargs.get("steps", 20),
|
"steps": kwargs.get("steps", 20),
|
||||||
"width": 256 if "256" in kwargs.get("size") else 512,
|
"width": kwargs.get("width", 512),
|
||||||
"height": 256 if "256" in kwargs.get("size") else 512,
|
"height": kwargs.get("height", 512),
|
||||||
},
|
},
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
if resp.status_code == 200:
|
if resp.status_code == 200:
|
||||||
b64_datas = resp.json()["images"]
|
b64_datas = resp.json()["images"]
|
||||||
return b64_datas
|
return save_images_b64(b64_datas, output_path, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
f"{resp.status_code} {resp.reason_phrase} {resp.text}",
|
||||||
)
|
)
|
||||||
|
elif backend_type == "localai":
|
||||||
|
resp = await aclient.post(
|
||||||
|
url,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {kwargs.get('api_key')}",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"prompt": prompt,
|
||||||
|
"size": kwargs.get("size", "512x512"),
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
image_url = resp.json()["data"][0]["url"]
|
||||||
|
return await save_image_url(image_url, aclient, output_path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
|
def save_images_b64(b64_datas: list[str], path: Path, **kwargs) -> list[str]:
|
||||||
images = []
|
images_path_list = []
|
||||||
for b64_data in b64_datas:
|
for b64_data in b64_datas:
|
||||||
image_path = path / (str(uuid.uuid4()) + ".jpeg")
|
image_path = path / (
|
||||||
|
str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg")
|
||||||
|
)
|
||||||
img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
|
img = Image.open(io.BytesIO(base64.decodebytes(bytes(b64_data, "utf-8"))))
|
||||||
img.save(image_path)
|
img.save(image_path)
|
||||||
images.append(image_path)
|
images_path_list.append(image_path)
|
||||||
return images
|
return images_path_list
|
||||||
|
|
||||||
|
|
||||||
|
async def save_image_url(
|
||||||
|
url: str, aclient: httpx.AsyncClient, path: Path, **kwargs
|
||||||
|
) -> list[str]:
|
||||||
|
images_path_list = []
|
||||||
|
r = await aclient.get(url)
|
||||||
|
image_path = path / (str(uuid.uuid4()) + "." + kwargs.get("image_format", "jpeg"))
|
||||||
|
if r.status_code == 200:
|
||||||
|
img = Image.open(io.BytesIO(r.content))
|
||||||
|
img.save(image_path)
|
||||||
|
images_path_list.append(image_path)
|
||||||
|
return images_path_list
|
||||||
|
|
|
@ -44,6 +44,8 @@ async def main():
|
||||||
lc_admin=config.get("lc_admin"),
|
lc_admin=config.get("lc_admin"),
|
||||||
image_generation_endpoint=config.get("image_generation_endpoint"),
|
image_generation_endpoint=config.get("image_generation_endpoint"),
|
||||||
image_generation_backend=config.get("image_generation_backend"),
|
image_generation_backend=config.get("image_generation_backend"),
|
||||||
|
image_generation_size=config.get("image_generation_size"),
|
||||||
|
image_format=config.get("image_format"),
|
||||||
timeout=config.get("timeout"),
|
timeout=config.get("timeout"),
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
|
@ -75,6 +77,8 @@ async def main():
|
||||||
lc_admin=os.environ.get("LC_ADMIN"),
|
lc_admin=os.environ.get("LC_ADMIN"),
|
||||||
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
|
image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
|
||||||
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
|
image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
|
||||||
|
image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"),
|
||||||
|
image_format=os.environ.get("IMAGE_FORMAT"),
|
||||||
timeout=float(os.environ.get("TIMEOUT", 120.0)),
|
timeout=float(os.environ.get("TIMEOUT", 120.0)),
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
|
|
Loading…
Reference in a new issue