2023-09-17 04:27:16 +00:00
|
|
|
import httpx
|
|
|
|
from pathlib import Path
|
|
|
|
import uuid
|
|
|
|
import base64
|
|
|
|
import io
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
async def get_images(
    aclient: "httpx.AsyncClient", url: str, prompt: str, backend_type: str, **kwargs
) -> list[str]:
    """Request generated images from a backend and return them base64-encoded.

    Args:
        aclient: HTTP client used to POST the generation request.
        url: Full endpoint URL of the backend.
        prompt: Text prompt sent to the backend.
        backend_type: Either ``"openai"`` (response parsed from ``data[*].b64_json``)
            or ``"sdwui"`` (response parsed from ``images``).
        **kwargs: Optional settings:
            ``timeout`` (default ``120.0``) - request timeout passed to ``post``;
            ``api_key`` - sent as ``Authorization: Bearer <api_key>`` (openai only);
            ``n`` (default ``1``) - number of images (``batch_size`` for sdwui);
            ``size`` (default ``"256x256"``) - for sdwui, any size string containing
            ``"256"`` maps to 256x256, anything else to 512x512;
            ``sampler_name`` (default ``"Euler a"``) and ``steps`` (default ``20``) -
            sdwui only.

    Returns:
        The base64-encoded image strings, in the order the backend returned them.

    Raises:
        ValueError: If ``backend_type`` is not ``"openai"`` or ``"sdwui"``.
        Exception: If the backend answers with a status other than 200; the message
            is ``"<status> <reason> <body>"``.
    """
    timeout = kwargs.get("timeout", 120.0)
    n = kwargs.get("n", 1)

    if backend_type == "openai":
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {kwargs.get('api_key')}",
        }
        payload = {
            "prompt": prompt,
            "n": n,
            "size": kwargs.get("size", "256x256"),
            "response_format": "b64_json",
        }
    elif backend_type == "sdwui":
        # Default matches the openai branch; previously a missing size crashed
        # with TypeError ("256" in None).
        size = kwargs.get("size") or "256x256"
        side = 256 if "256" in size else 512
        headers = {"Content-Type": "application/json"}
        payload = {
            "prompt": prompt,
            "sampler_name": kwargs.get("sampler_name", "Euler a"),
            "batch_size": n,
            "steps": kwargs.get("steps", 20),
            "width": side,
            "height": side,
        }
    else:
        # Previously fell through and returned None despite the list[str] contract.
        raise ValueError(f"Unsupported backend_type: {backend_type!r}")

    resp = await aclient.post(url, headers=headers, json=payload, timeout=timeout)
    if resp.status_code != 200:
        raise Exception(
            f"{resp.status_code} {resp.reason_phrase} {resp.text}",
        )

    body = resp.json()
    if backend_type == "openai":
        return [item["b64_json"] for item in body["data"]]
    return body["images"]
|
|
|
|
|
|
|
|
|
|
|
|
def save_images(b64_datas: list[str], path: Path, **kwargs) -> list[Path]:
    """Decode base64 image payloads and save each one as a JPEG file in *path*.

    Args:
        b64_datas: Base64-encoded image payloads (e.g. the output of ``get_images``).
        path: Directory to write into. It is not created here, so it must already
            exist. Each file gets a random UUID name with a ``.jpeg`` suffix.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        The paths of the written files, in the same order as ``b64_datas``.
        These are ``Path`` objects; the former ``list[str]`` annotation did not
        match what was actually returned.
    """
    images: list[Path] = []
    for b64_data in b64_datas:
        image_path = path / f"{uuid.uuid4()}.jpeg"
        raw = base64.b64decode(b64_data.encode("utf-8"))
        with Image.open(io.BytesIO(raw)) as img:
            # JPEG cannot hold alpha or palette data. If a payload decodes to such a
            # mode (e.g. an RGBA PNG), Pillow refuses to write it as JPEG, so convert first.
            out = img if img.mode in ("RGB", "L", "CMYK") else img.convert("RGB")
            out.save(image_path, format="JPEG")
        images.append(image_path)
    return images
|