From 5fb07155373d281e2c1dd20c2b49d0e67b4ad832 Mon Sep 17 00:00:00 2001 From: hibobmaster <32976627+hibobmaster@users.noreply.github.com> Date: Thu, 4 Jan 2024 21:35:28 +0800 Subject: [PATCH] feat: Expose more stable diffusion webui api parameters --- .full-env.example | 3 +++ CHANGELOG.md | 3 +++ full-config.json.example | 3 +++ src/bot.py | 10 ++++++++++ src/imagegen.py | 1 + src/main.py | 6 ++++++ 6 files changed, 26 insertions(+) diff --git a/.full-env.example b/.full-env.example index 50ae9fe..7ec4aee 100644 --- a/.full-env.example +++ b/.full-env.example @@ -21,4 +21,7 @@ IMAGE_GENERATION_ENDPOINT="http://127.0.0.1:7860/sdapi/v1/txt2img" IMAGE_GENERATION_BACKEND="sdwui" # openai or sdwui or localai IMAGE_GENERATION_SIZE="512x512" IMAGE_FORMAT="webp" +SDWUI_STEPS=20 +SDWUI_SAMPLER_NAME="Euler a" +SDWUI_CFG_SCALE=7 TIMEOUT=120.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index 982abe9..d9d4726 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 1.5.2 +- Expose more stable diffusion webui api parameters + ## 1.5.1 - fix: set timeout not work in image generation diff --git a/full-config.json.example b/full-config.json.example index 77e6213..efbe4c0 100644 --- a/full-config.json.example +++ b/full-config.json.example @@ -21,6 +21,9 @@ "image_generation_endpoint": "http://localai:8080/v1/images/generations", "image_generation_backend": "localai", "image_generation_size": "512x512", + "sdwui_steps": 20, + "sdwui_sampler_name": "Euler a", + "sdwui_cfg_scale": 7, "image_format": "webp", "timeout": 120.0 } diff --git a/src/bot.py b/src/bot.py index c89d8a7..118d475 100644 --- a/src/bot.py +++ b/src/bot.py @@ -70,6 +70,9 @@ class Bot: image_generation_backend: Optional[str] = None, image_generation_size: Optional[str] = None, image_format: Optional[str] = None, + sdwui_steps: Optional[int] = None, + sdwui_sampler_name: Optional[str] = None, + sdwui_cfg_scale: Optional[float] = None, timeout: Union[float, None] = None, ): if homeserver is None or user_id is None or device_id is None: @@ -138,6 +141,10 @@ class Bot: self.image_generation_width = self.image_generation_size.split("x")[0] self.image_generation_height = self.image_generation_size.split("x")[1] + self.sdwui_steps = sdwui_steps + self.sdwui_sampler_name = sdwui_sampler_name + self.sdwui_cfg_scale = sdwui_cfg_scale + self.timeout: float = timeout or 120.0 self.base_path = Path(os.path.dirname(__file__)).parent @@ -1368,6 +1375,9 @@ class Bot: size=self.image_generation_size, width=self.image_generation_width, height=self.image_generation_height, + steps=self.sdwui_steps, + sampler_name=self.sdwui_sampler_name, + cfg_scale=self.sdwui_cfg_scale, image_format=self.image_format, ) # send image diff --git a/src/imagegen.py b/src/imagegen.py index 8f059d9..d32356d 100644 --- a/src/imagegen.py +++ b/src/imagegen.py @@ -48,6 +48,7 @@ async def get_images( json={ "prompt": prompt, "sampler_name": kwargs.get("sampler_name", "Euler a"), + "cfg_scale": kwargs.get("cfg_scale", 7), "batch_size": kwargs.get("n", 1), "steps": kwargs.get("steps", 20), "width": kwargs.get("width", 512), diff --git a/src/main.py b/src/main.py index 48bbceb..6d37154 100644 --- a/src/main.py +++ b/src/main.py @@ -45,6 +45,9 @@ async def main(): image_generation_endpoint=config.get("image_generation_endpoint"), image_generation_backend=config.get("image_generation_backend"), image_generation_size=config.get("image_generation_size"), + sdwui_steps=config.get("sdwui_steps"), + sdwui_sampler_name=config.get("sdwui_sampler_name"), + sdwui_cfg_scale=config.get("sdwui_cfg_scale"), image_format=config.get("image_format"), timeout=config.get("timeout"), ) @@ -78,6 +81,9 @@ async def main(): image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"), image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"), image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"), + sdwui_steps=int(os.environ.get("SDWUI_STEPS", 20)), + sdwui_sampler_name=os.environ.get("SDWUI_SAMPLER_NAME"), + sdwui_cfg_scale=float(os.environ.get("SDWUI_CFG_SCALE", 7)), image_format=os.environ.get("IMAGE_FORMAT"), timeout=float(os.environ.get("TIMEOUT", 120.0)), )