import base64
import io
import json

import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field, Json

from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared


def sampler_to_index(name):
    """Find a sampler by name, ignoring case.

    Returns:
        The ``(index, sampler)`` pair from ``enumerate(all_samplers)`` for the
        first sampler whose ``.name`` equals ``name`` case-insensitively, or
        ``None`` when no sampler matches.
    """
    wanted = name.lower()
    return next(
        ((index, sampler) for index, sampler in enumerate(all_samplers) if sampler.name.lower() == wanted),
        None,
    )


def encode_pil_to_base64(image):
    """Serialize a PIL image as PNG and return the bytes base64-encoded as ``str``."""
    with io.BytesIO() as buffer:
        image.save(buffer, format="png")
        # Decode explicitly so the response model receives ``str``, not ``bytes``.
        return base64.b64encode(buffer.getvalue()).decode("ascii")


def decode_base64_to_image(encoding):
    """Decode a base64 string (optionally carrying a ``data:...;base64,`` prefix) into a PIL image.

    Raises:
        HTTPException: status 400 when the payload is not valid base64 or is not
            an image format PIL can identify.
    """
    if "," in encoding:
        # Drop a data-URL prefix such as "data:image/png;base64,".
        encoding = encoding.split(",", 1)[1]
    try:
        return Image.open(io.BytesIO(base64.b64decode(encoding)))
    except (ValueError, OSError) as err:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an OSError.
        raise HTTPException(status_code=400, detail="Invalid init image") from err


class TextToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: Json
    info: Json


class ImageToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: Json
    info: Json


class Api:
    """HTTP API exposing txt2img and img2img generation.

    Routes are registered directly on ``app``. Every call into
    ``process_images`` is serialized through ``queue_lock``.
    """

    def __init__(self, app, queue_lock):
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Run txt2img for the request and return the generated images as base64 PNGs.

        Raises:
            HTTPException: 404 if the requested sampler does not exist.
        """
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
        })
        p = StableDiffusionProcessingTxt2Img(**vars(populate))

        with self.queue_lock:
            processed = process_images(p)

        b64images = [encode_pil_to_base64(image) for image in processed.images]
        return TextToImageResponse(
            images=b64images,
            parameters=json.dumps(vars(txt2imgreq)),
            info=json.dumps(processed.info),
        )

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Run img2img for the request and return the generated images as base64 PNGs.

        Raises:
            HTTPException: 404 if the sampler does not exist or no init image is
                supplied; 400 if a mask is supplied (unsupported) or an init image
                cannot be decoded.
        """
        sampler_index = sampler_to_index(img2imgreq.sampler_index)
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        init_images = img2imgreq.init_images
        if not init_images:
            # Covers both a missing field (None) and an empty list.
            raise HTTPException(status_code=404, detail="Init image not found")

        if img2imgreq.mask:
            raise HTTPException(status_code=400, detail="Mask not supported yet")

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
        })
        p = StableDiffusionProcessingImg2Img(**vars(populate))

        # Decode every init image. The previous loop overwrote its result on each
        # iteration, so only the last supplied image was ever used.
        imgs = [decode_base64_to_image(img) for img in init_images]
        if len(imgs) == 1:
            # Single init image: replicate it across the batch (unchanged behavior).
            imgs = imgs * p.batch_size
        # NOTE(review): with several init images they are passed through as-is;
        # assumes StableDiffusionProcessingImg2Img accepts a per-batch-item list —
        # confirm against modules.processing.
        p.init_images = imgs  # Override object param

        with self.queue_lock:
            processed = process_images(p)

        b64images = [encode_pil_to_base64(image) for image in processed.images]
        return ImageToImageResponse(
            images=b64images,
            parameters=json.dumps(vars(img2imgreq)),
            info=json.dumps(processed.info),
        )

    def extrasapi(self):
        raise NotImplementedError

    def pnginfoapi(self):
        raise NotImplementedError

    def launch(self, server_name, port):
        """Attach ``self.router`` to the app and serve it with uvicorn (blocking)."""
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)