From ffbe6d520ef9b4d23c0d731de528b75f61fe472f Mon Sep 17 00:00:00 2001 From: hibobmaster <32976627+hibobmaster@users.noreply.github.com> Date: Thu, 4 Jan 2024 20:49:56 +0800 Subject: [PATCH] feat: Expose more stable diffusion webui api parameters --- .full-env.example | 3 +++ CHANGELOG.md | 4 ++++ full-config.json.example | 3 +++ src/bot.py | 10 ++++++++++ src/imagegen.py | 1 + src/main.py | 6 ++++++ 6 files changed, 27 insertions(+) diff --git a/.full-env.example b/.full-env.example index 0176496..9134ca2 100644 --- a/.full-env.example +++ b/.full-env.example @@ -18,4 +18,7 @@ IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img" IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai IMAGE_GENERATION_SIZE="512x512" IMAGE_FORMAT="jpeg" +SDWUI_STEPS=20 +SDWUI_SAMPLER_NAME="Euler a" +SDWUI_CFG_SCALE=7 TIMEOUT=120.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index ffb3bb3..2d7a3f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,8 @@ # Changelog + +## v1.3.1 +- Expose more stable diffusion webui api parameters + ## v1.3.0 - Fix localai v2.0+ image generation - Support specific output image format(jpeg, png) and size diff --git a/full-config.json.example b/full-config.json.example index cdbff48..def7963 100644 --- a/full-config.json.example +++ b/full-config.json.example @@ -18,6 +18,9 @@ "image_generation_endpoint": "http://localai:8080/v1/images/generations", "image_generation_backend": "localai", "image_generation_size": "512x512", + "sdwui_steps": 20, + "sdwui_sampler_name": "Euler a", + "sdwui_cfg_scale": 7, "image_format": "jpeg", "timeout": 120.0 } diff --git a/src/bot.py b/src/bot.py index 5c1ca59..5d1a2ac 100644 --- a/src/bot.py +++ b/src/bot.py @@ -37,6 +37,9 @@ class Bot: image_generation_endpoint: Optional[str] = None, image_generation_backend: Optional[str] = None, image_generation_size: Optional[str] = None, + sdwui_steps: Optional[int] = None, + sdwui_sampler_name: Optional[str] = None, + sdwui_cfg_scale: Optional[float] = None, image_format: Optional[str] = None, timeout: Optional[float] = 120.0, ) -> None: @@ -110,6 +113,10 @@ class Bot: self.image_generation_width = self.image_generation_size.split("x")[0] self.image_generation_height = self.image_generation_size.split("x")[1] + self.sdwui_steps = sdwui_steps + self.sdwui_sampler_name = sdwui_sampler_name + self.sdwui_cfg_scale = sdwui_cfg_scale + self.timeout = timeout or 120.0 self.bot_id = None @@ -295,6 +302,9 @@ class Bot: size=self.image_generation_size, width=self.image_generation_width, height=self.image_generation_height, + steps=self.sdwui_steps, + sampler_name=self.sdwui_sampler_name, + cfg_scale=self.sdwui_cfg_scale, image_format=self.image_format, ) # send image diff --git a/src/imagegen.py b/src/imagegen.py index 8f059d9..d32356d 100644 --- a/src/imagegen.py +++ b/src/imagegen.py @@ -48,6 +48,7 @@ async def get_images( json={ "prompt": prompt, "sampler_name": kwargs.get("sampler_name", "Euler a"), + "cfg_scale": kwargs.get("cfg_scale", 7), "batch_size": kwargs.get("n", 1), "steps": kwargs.get("steps", 20), "width": kwargs.get("width", 512), diff --git a/src/main.py b/src/main.py index 37d3cb8..ec6c2f5 100644 --- a/src/main.py +++ b/src/main.py @@ -40,6 +40,9 @@ async def main(): image_generation_endpoint=config.get("image_generation_endpoint"), image_generation_backend=config.get("image_generation_backend"), image_generation_size=config.get("image_generation_size"), + sdwui_steps=config.get("sdwui_steps"), + sdwui_sampler_name=config.get("sdwui_sampler_name"), + sdwui_cfg_scale=config.get("sdwui_cfg_scale"), image_format=config.get("image_format"), timeout=config.get("timeout"), ) @@ -65,6 +68,9 @@ async def main(): image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"), image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"), image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"), + sdwui_steps=int(os.environ.get("SDWUI_STEPS", 20)), + sdwui_sampler_name=os.environ.get("SDWUI_SAMPLER_NAME"), + sdwui_cfg_scale=float(os.environ.get("SDWUI_CFG_SCALE", 7)), image_format=os.environ.get("IMAGE_FORMAT"), timeout=float(os.environ.get("TIMEOUT", 120.0)), )