# api.py (5.4 KB) - captured from a blame view
# blame: Roy Shilkrot committed line 1
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI, InterrogateAPI
# blame gutter: line 2
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
# blame: arcticfaded committed lines 3-4
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
# blame gutter: lines 5-6
import modules.shared as shared
import uvicorn
# blame: arcticfaded committed line 7
from fastapi import Body, APIRouter, HTTPException
# blame gutter: lines 8-9
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
# blame gutter: line 10
from typing import List
# blame gutter: lines 11-13
import json
import io
import base64
# blame gutter: line 14
from PIL import Image
# blame gutter: line 15

# blame: arcticfaded committed line 16
def sampler_to_index(name):
    """Look up a sampler by name, ignoring case.

    Returns the ``(index, sampler)`` pair from ``enumerate(all_samplers)`` for
    the first sampler whose ``.name`` matches *name* case-insensitively, or
    ``None`` when nothing matches.
    """
    for position, candidate in enumerate(all_samplers):
        if name.lower() == candidate.name.lower():
            return position, candidate
    return None
# blame: arcticfaded committed line 17

# blame gutter: line 18
class TextToImageResponse(BaseModel):
    """Response body returned by the ``/sdapi/v1/txt2img`` endpoint."""

    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    # JSON string of the request parameters; pydantic's `Json` type parses it.
    parameters: Json
    # JSON string describing the processing result; parsed the same way.
    info: Json

# blame gutter: line 23
class ImageToImageResponse(BaseModel):
    """Response body returned by the ``/sdapi/v1/img2img`` endpoint."""

    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    # JSON string of the request parameters; pydantic's `Json` type parses it.
    parameters: Json
    # JSON string describing the processing result; parsed the same way.
    info: Json

# blame: Roy Shilkrot committed lines 28-32
class InterrogateResponse(BaseModel):
    """Response body returned by the ``/sdapi/v1/interrogate`` endpoint."""

    caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
    # JSON string of the request parameters; pydantic's `Json` type parses it.
    parameters: Json
    # Interrogation has no processing info to report, so this field is
    # explicitly optional: omit it or pass a JSON string such as "null".
    # NOTE(review): pydantic parses `Json` input before its None check, so an
    # explicit None may be rejected - verify against the installed version.
    info: Json = Field(default=None, title="Info", description="Processing information, if any.")

# blame gutter: lines 33-34

class Api:
    """HTTP API exposing txt2img, img2img and interrogate endpoints.

    Routes are registered directly on the given FastAPI ``app``. Every call
    into the model runs while holding ``queue_lock``.
    """

    def __init__(self, app, queue_lock):
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])

    def __base64_to_image(self, base64_string):
        """Decode a base64 string into a PIL image.

        A data-URL prefix (``data:image/png;base64,...``) is stripped first.

        Raises:
            HTTPException: 400 when the payload is not valid base64 or is not
                an image format PIL can identify.
        """
        # If there is a comma, drop the data-URL prefix in front of it.
        if "," in base64_string:
            base64_string = base64_string.split(",")[1]
        try:
            imgdata = base64.b64decode(base64_string)
            return Image.open(io.BytesIO(imgdata))
        except (ValueError, OSError) as err:
            # binascii.Error (bad base64) subclasses ValueError and PIL's
            # UnidentifiedImageError subclasses OSError. Both are client
            # errors, not server failures.
            raise HTTPException(status_code=400, detail="Invalid base64 image") from err

    @staticmethod
    def _encode_pil_to_base64(images):
        """Encode PIL images as PNG and return them as base64 ASCII strings."""
        encoded = []
        for image in images:
            with io.BytesIO() as buffer:
                image.save(buffer, format="png")
                # b64encode returns bytes; the response model declares List[str].
                encoded.append(base64.b64encode(buffer.getvalue()).decode("ascii"))
        return encoded

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Handle ``POST /sdapi/v1/txt2img``.

        Raises:
            HTTPException: 404 when the requested sampler is unknown.
        """
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        # Override __init__ params of the processing object.
        populate = txt2imgreq.copy(update={
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
        })
        p = StableDiffusionProcessingTxt2Img(**vars(populate))

        with self.queue_lock:
            processed = process_images(p)

        b64images = self._encode_pil_to_base64(processed.images)
        return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Handle ``POST /sdapi/v1/img2img``.

        A single init image is repeated ``batch_size`` times; several init
        images are all passed through (earlier ones are no longer discarded).

        Raises:
            HTTPException: 404 when the sampler is unknown or no init image
                was supplied, 400 when an image payload cannot be decoded.
        """
        sampler_index = sampler_to_index(img2imgreq.sampler_index)
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        init_images = img2imgreq.init_images
        # Also reject an empty list, which would otherwise reach processing.
        if not init_images:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = self.__base64_to_image(mask)

        # Override __init__ params of the processing object.
        populate = img2imgreq.copy(update={
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask,
        })
        p = StableDiffusionProcessingImg2Img(**vars(populate))

        decoded = [self.__base64_to_image(img) for img in init_images]
        p.init_images = decoded * p.batch_size if len(decoded) == 1 else decoded

        with self.queue_lock:
            processed = process_images(p)

        b64images = self._encode_pil_to_base64(processed.images)

        # Build the echoed parameters without mutating the request object.
        parameters = dict(vars(img2imgreq))
        if not img2imgreq.include_init_images:
            parameters["init_images"] = None
            parameters["mask"] = None

        return ImageToImageResponse(images=b64images, parameters=json.dumps(parameters), info=processed.js())

    def interrogateapi(self, interrogatereq: InterrogateAPI):
        """Handle ``POST /sdapi/v1/interrogate`` and return the image caption.

        Raises:
            HTTPException: 404 when no image was supplied, 400 when it cannot
                be decoded.
        """
        image_b64 = interrogatereq.image
        if image_b64 is None:
            raise HTTPException(status_code=404, detail="Image not found")

        img = self.__base64_to_image(image_b64)

        with self.queue_lock:
            processed = shared.interrogator.interrogate(img)

        # `info` is a Json field that expects a JSON string. A bare None is
        # run through the JSON parser and rejected, so send JSON "null"
        # to mean "no info".
        return InterrogateResponse(caption=processed, parameters=json.dumps(vars(interrogatereq)), info=json.dumps(None))

    def extrasapi(self):
        """Placeholder for an extras endpoint; not implemented."""
        raise NotImplementedError

    def pnginfoapi(self):
        """Placeholder for a PNG-info endpoint; not implemented."""
        raise NotImplementedError

    def launch(self, server_name, port):
        """Attach ``self.router`` to the app and serve it with uvicorn."""
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)