api.py 20.2 KB
Newer Older
1 2
import base64
import io
3
import time
V
Vladimir Mandic 已提交
4
import datetime
5
import uvicorn
B
Bruno Seoane 已提交
6
from threading import Lock
S
Sena 已提交
7
from io import BytesIO
S
Sena 已提交
8
from gradio.processing_utils import decode_base64_to_file
V
Vladimir Mandic 已提交
9
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
10 11 12
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

13
import modules.shared as shared
V
Vladimir Mandic 已提交
14
from modules import sd_samplers, deepbooru, sd_hijack
15
from modules.api.models import *
16
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
B
Bruno Seoane 已提交
17
from modules.extras import run_extras, run_pnginfo
V
Vladimir Mandic 已提交
18 19 20
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
S
Sena 已提交
21
from PIL import PngImagePlugin,Image
B
Bruno Seoane 已提交
22 23
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
V
Vladimir Mandic 已提交
24
from modules import devices
B
Bruno Seoane 已提交
25
from typing import List
A
arcticfaded 已提交
26

B
Bruno Seoane 已提交
27 28 29 30
def upscaler_to_index(name: str):
    """Return the index of the upscaler called *name* in ``shared.sd_upscalers``.

    The name comparison is case-insensitive.

    Raises:
        HTTPException: status 400 when the lookup fails, listing the
            available upscaler names in the detail message.
    """
    try:
        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
    except Exception as e:
        # The list of valid names must come from shared.sd_upscalers; a bare
        # `sd_upscalers` is not defined in this module and would turn the
        # intended 400 into a NameError.
        valid_names = ' , '.join([x.name for x in shared.sd_upscalers])
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {valid_names}") from e
32

33

34 35 36 37
def validate_sampler_name(name):
    """Return *name* unchanged if it is a key of ``sd_samplers.all_samplers_map``.

    Raises:
        HTTPException: status 404 when the map has no (non-None) entry for *name*.
    """
    sampler_config = sd_samplers.all_samplers_map.get(name)
    if sampler_config is not None:
        return name

    raise HTTPException(status_code=404, detail="Sampler not found")
40

B
Bruno Seoane 已提交
41 42 43 44 45 46 47
def setUpscalers(req: dict):
    """Replace the upscaler names in the request's attribute dict with indices.

    ``upscaler_1`` / ``upscaler_2`` are resolved through ``upscaler_to_index``
    and stored under ``extras_upscaler_1`` / ``extras_upscaler_2``; the
    original name entries are then removed. The dict returned is
    ``vars(req)`` itself, so the request object is modified in place.
    """
    params = vars(req)
    renames = (
        ('upscaler_1', 'extras_upscaler_1'),
        ('upscaler_2', 'extras_upscaler_2'),
    )

    # Resolve both names first, then drop the name keys, matching the
    # original statement order.
    for name_key, index_key in renames:
        params[index_key] = upscaler_to_index(getattr(req, name_key))
    for name_key, _ in renames:
        params.pop(name_key)

    return params
R
Roy Shilkrot 已提交
48

S
Sena 已提交
49 50 51 52
def decode_base64_to_image(encoding):
    """Decode a base64 string into a PIL image.

    *encoding* may be raw base64 or a ``data:image/...`` URI; for a URI only
    the payload after the first comma is decoded.

    Returns:
        The object produced by ``Image.open`` on the decoded bytes.
    """
    if encoding.startswith("data:image/"):
        # Everything before the first comma is the media type plus optional
        # parameters (e.g. ";charset=...;base64"). Splitting on ";" first
        # assumed exactly one parameter and failed on any other header shape.
        encoding = encoding.split(",", 1)[1]
    return Image.open(BytesIO(base64.b64decode(encoding)))
53

54
def encode_pil_to_base64(image):
    """Save *image* as PNG in memory and return the base64-encoded bytes.

    Entries of ``image.info`` whose key and value are both ``str`` are copied
    into the PNG as text metadata; if there are none, no metadata is passed.
    """
    text_chunks = PngImagePlugin.PngInfo()
    has_text = False
    for key, value in image.info.items():
        if not (isinstance(key, str) and isinstance(value, str)):
            continue
        text_chunks.add_text(key, value)
        has_text = True

    pnginfo = text_chunks if has_text else None
    with io.BytesIO() as buffer:
        image.save(buffer, "PNG", pnginfo=pnginfo)
        png_bytes = buffer.getvalue()

    return base64.b64encode(png_bytes)
70

V
Vladimir Mandic 已提交
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
def init_api_middleware(app: FastAPI):
    """Register an HTTP middleware on *app* that times every request.

    The elapsed time (seconds, rounded to 4 places) is added to the response
    as the ``X-Process-Time`` header. When ``shared.cmd_opts.api_log`` is
    truthy, one log line per request is printed as well.
    """
    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        started = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - started, 4))
        res.headers["X-Process-Time"] = duration

        if not shared.cmd_opts.api_log:
            return res

        scope = req.scope
        fields = {
            "t": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
            "code": res.status_code,
            "ver": scope.get('http_version', '0.0'),
            "cli": scope.get('client', ('0:0.0.0', 0))[0],
            "prot": scope.get('scheme', 'err'),
            "method": scope.get('method', 'err'),
            "p": scope.get('path', 'err'),
            "duration": duration,
        }
        print('API {t} {code} {prot}/{ver} {method} {p} {cli} {duration}'.format(**fields))
        return res

91

92
class Api:
B
Bruno Seoane 已提交
93
    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Attach the ``/sdapi/v1`` routes and the timing middleware to *app*.

        Args:
            app: FastAPI application the routes are registered on.
            queue_lock: lock held by the endpoints while they run processing jobs.

        When ``shared.cmd_opts.api_auth`` is set it is parsed as a
        comma-separated list of ``user:password`` pairs used by ``auth``.
        """
        # Always define the attribute so it exists even when auth is disabled.
        self.credentials = {}
        if shared.cmd_opts.api_auth:
            for auth in shared.cmd_opts.api_auth.split(","):
                # Split on the first colon only, so passwords may contain ":".
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        init_api_middleware(self.app)
        self.queue_lock = queue_lock
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
132 133 134 135 136 137

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register *endpoint* at *path* on the wrapped app.

        When ``shared.cmd_opts.api_auth`` is set, the route additionally
        depends on ``self.auth``. Extra keyword arguments are forwarded to
        ``app.add_api_route``.
        """
        if not shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, **kwargs)

        auth_dependency = Depends(self.auth)
        return self.app.add_api_route(path, endpoint, dependencies=[auth_dependency], **kwargs)

J
Jim Hays 已提交
138 139 140
    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """Check HTTP Basic credentials against ``self.credentials``.

        Returns:
            True when the username is known and the password matches.

        Raises:
            HTTPException: status 401 with a ``WWW-Authenticate: Basic``
                header otherwise.
        """
        expected_password = self.credentials.get(credentials.username)
        if expected_password is not None:
            # compare_digest only accepts ASCII when given str arguments, so
            # compare the UTF-8 encoded bytes; a non-ASCII password would
            # otherwise raise TypeError instead of yielding a 401.
            if compare_digest(credentials.password.encode("utf-8"), expected_password.encode("utf-8")):
                return True

        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
144

145
    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Run a txt2img job for the request and return the results.

        The sampler is taken from ``sampler_name`` (falling back to
        ``sampler_index``) and validated; saving of samples and grids is
        switched off for the job. Processing happens while holding
        ``self.queue_lock``.

        Returns:
            TextToImageResponse with the images as base64-encoded PNG data,
            the request parameters, and ``processed.js()`` as info.
        """
        populate = txt2imgreq.copy(update={ # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True
            }
        )
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))

            shared.state.begin()
            try:
                processed = process_images(p)
            finally:
                # Pair every begin() with an end(), even when processing raises.
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))

        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
166

167 168 169
    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Run an img2img job for the request and return the results.

        ``init_images`` (and ``mask``, when given) are base64 strings that
        are decoded to images before processing. Processing happens while
        holding ``self.queue_lock``. Unless ``include_init_images`` is set,
        ``init_images`` and ``mask`` are cleared on the request before it is
        echoed back in the response parameters.

        Raises:
            HTTPException: status 404 when ``init_images`` is None.
        """
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        populate = img2imgreq.copy(update={ # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask
            }
        )
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.

        with self.queue_lock:
            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
            p.init_images = [decode_base64_to_image(x) for x in init_images]

            shared.state.begin()
            try:
                processed = process_images(p)
            finally:
                # Pair every begin() with an end(), even when processing raises.
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))

        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
204

B
Bruno Seoane 已提交
205
    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
        """Run the extras pipeline (``extras_mode=0``) on one base64 image.

        Returns:
            ExtrasSingleImageResponse with the first result image encoded as
            base64 PNG data and the HTML info from ``run_extras``.
        """
        extras_args = setUpscalers(req)
        extras_args['image'] = decode_base64_to_image(extras_args['image'])

        with self.queue_lock:
            outcome = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **extras_args)

        first_image = outcome[0][0]
        return ExtrasSingleImageResponse(image=encode_pil_to_base64(first_image), html_info=outcome[1])
214 215

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        """Run the extras pipeline (``extras_mode=1``) on a list of base64 files.

        Each entry of ``imageList`` is decoded to a file object via
        ``decode_base64_to_file`` and passed to ``run_extras`` as
        ``image_folder``.

        Returns:
            ExtrasBatchImagesResponse with every result image encoded as
            base64 PNG data and the HTML info from ``run_extras``.
        """
        extras_args = setUpscalers(req)

        def to_decoded_file(entry):
            decoded = decode_base64_to_file(entry.data, file_path=entry.name)
            decoded.orig_name = decoded.name
            return decoded

        extras_args['image_folder'] = [to_decoded_file(entry) for entry in extras_args['imageList']]
        del extras_args['imageList']

        with self.queue_lock:
            outcome = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **extras_args)

        encoded_images = [encode_pil_to_base64(img) for img in outcome[0]]
        return ExtrasBatchImagesResponse(images=encoded_images, html_info=outcome[1])
230

B
Bruno Seoane 已提交
231
    def pnginfoapi(self, req: PNGInfoRequest):
        """Return the second element of ``run_pnginfo``'s result for a base64 image.

        A blank (whitespace-only) image string yields an empty ``info``.
        """
        payload = req.image.strip()
        if not payload:
            return PNGInfoResponse(info="")

        decoded = decode_base64_to_image(payload)
        result = run_pnginfo(decoded)
        return PNGInfoResponse(info=result[1])
238

239
    def progressapi(self, req: ProgressRequest = Depends()):
        """Report progress of the current job from ``shared.state``.

        Returns a zeroed ProgressResponse when no jobs are counted. Otherwise
        the progress fraction (capped at 1), the estimated remaining time,
        the state dict and, unless ``req.skip_current_image`` is set, the
        current preview image as base64 PNG data.
        """
        # copy from check_progress_call of ui.py
        state = shared.state

        if state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=state.dict())

        # start slightly above zero so the ETA division below never divides by zero
        progress = 0.01

        if state.job_count > 0:
            progress += state.job_no / state.job_count
        if state.sampling_steps > 0:
            progress += 1 / state.job_count * state.sampling_step / state.sampling_steps

        elapsed = time.time() - state.time_start
        eta_relative = (elapsed/progress) - elapsed

        # the ETA above uses the uncapped value; only the reported fraction is clamped
        progress = min(progress, 1)

        state.set_current_image()

        show_image = state.current_image and not req.skip_current_image
        current_image = encode_pil_to_base64(state.current_image) if show_image else None

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=state.dict(), current_image=current_image)
E
evshiron 已提交
266

267
    def interrogateapi(self, interrogatereq: InterrogateRequest):
        """Caption a base64 image with the requested model.

        Supported ``model`` values are ``"clip"`` and ``"deepdanbooru"``; the
        image is converted to RGB first and the model runs while holding
        ``self.queue_lock``.

        Raises:
            HTTPException: status 404 when the image is None or the model
                name is not recognised.
        """
        encoded = interrogatereq.image
        if encoded is None:
            raise HTTPException(status_code=404, detail="Image not found")

        rgb_image = decode_base64_to_image(encoded).convert('RGB')

        # Override object param
        with self.queue_lock:
            model = interrogatereq.model
            if model == "clip":
                caption = shared.interrogator.interrogate(rgb_image)
            elif model == "deepdanbooru":
                caption = deepbooru.model.tag(rgb_image)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return InterrogateResponse(caption=caption)
285

E
evshiron 已提交
286 287 288 289 290
    def interruptapi(self):
        """Call ``shared.state.interrupt()`` and return an empty JSON object."""
        shared.state.interrupt()
        empty_body = {}
        return empty_body

B
Bruno Seoane 已提交
291 292 293
    def skip(self):
        """Call ``shared.state.skip()``; the endpoint has no response body."""
        shared.state.skip()
        return None

B
Bruno Seoane 已提交
294 295 296 297 298 299 300 301
    def get_config(self):
        """Return the current option values as a ``{key: value}`` dict.

        Iterates the keys of ``shared.opts.data``; the lookup fallback is the
        label's ``default`` when the key has an entry in
        ``shared.opts.data_labels`` and ``None`` otherwise.
        """
        options = {}
        for key in shared.opts.data.keys():
            label = shared.opts.data_labels.get(key)
            fallback = None if label is None else label.default
            options[key] = shared.opts.data.get(key, fallback)

        return options
304

B
Bruno Seoane 已提交
305
    def set_config(self, req: Dict[str, Any]):
        """Apply every ``key: value`` pair in *req* to ``shared.opts`` and save.

        The options are written to ``shared.config_filename`` after all pairs
        have been set.
        """
        for option_name in req:
            shared.opts.set(option_name, req[option_name])

        shared.opts.save(shared.config_filename)
        return None

    def get_cmd_flags(self):
        """Return the attribute dict of ``shared.cmd_opts``."""
        flags = shared.cmd_opts
        return vars(flags)

    def get_samplers(self):
        """List every entry of ``sd_samplers.all_samplers``.

        Each item exposes element 0 as ``name``, element 2 as ``aliases`` and
        element 3 as ``options``.
        """
        samplers = []
        for sampler in sd_samplers.all_samplers:
            samplers.append({"name": sampler[0], "aliases": sampler[2], "options": sampler[3]})
        return samplers
B
Bruno Seoane 已提交
317 318 319

    def get_upscalers(self):
        """List name, model name, model path and model URL of every upscaler.

        The values are read from the ``scaler`` attribute of each entry in
        ``shared.sd_upscalers``.
        """
        scalers = (entry.scaler for entry in shared.sd_upscalers)
        return [
            {"name": s.name, "model_name": s.model_name, "model_path": s.model_path, "model_url": s.model_url}
            for s in scalers
        ]
326

B
Bruno Seoane 已提交
327 328 329 330 331 332 333 334 335 336 337
    def get_sd_models(self):
        """List title, model name, hash, filename and config of every checkpoint."""
        models = []
        for checkpoint in checkpoints_list.values():
            models.append({
                "title": checkpoint.title,
                "model_name": checkpoint.model_name,
                "hash": checkpoint.hash,
                "filename": checkpoint.filename,
                "config": checkpoint.config,
            })
        return models

    def get_hypernetworks(self):
        """List each key of ``shared.hypernetworks`` with its value as ``path``."""
        hypernetworks = []
        for name in shared.hypernetworks:
            hypernetworks.append({"name": name, "path": shared.hypernetworks[name]})
        return hypernetworks

    def get_face_restorers(self):
        """List every face restorer's ``name()`` and its ``cmd_dir`` (None if absent)."""
        restorers = []
        for restorer in shared.face_restorers:
            restorers.append({"name": restorer.name(), "cmd_dir": getattr(restorer, "cmd_dir", None)})
        return restorers

    def get_realesrgan_models(self):
        """List name, data path and scale of the models from ``get_realesrgan_models(None)``."""
        # Inside the method body this name resolves to the module-level import,
        # not to this method.
        models = []
        for model in get_realesrgan_models(None):
            models.append({"name": model.name, "path": model.data_path, "scale": model.scale})
        return models
338

J
Jim Hays 已提交
339
    def get_prompt_styles(self):
        """List every entry of ``shared.prompt_styles.styles``.

        Each item exposes element 0 as ``name``, element 1 as ``prompt`` and
        element 2 as ``negative_prompt``.
        """
        styles = shared.prompt_styles.styles

        def as_item(style):
            return {"name": style[0], "prompt": style[1], "negative_prompt": style[2]}

        return [as_item(styles[key]) for key in styles]

    def get_artists_categories(self):
        """Return ``shared.artist_db.cats``."""
        artist_db = shared.artist_db
        return artist_db.cats

    def get_artists(self):
        """List every entry of ``shared.artist_db.artists``.

        Each item exposes element 0 as ``name``, element 1 as ``score`` and
        element 2 as ``category``.
        """
        artists = []
        for artist in shared.artist_db.artists:
            artists.append({"name": artist[0], "score": artist[1], "category": artist[2]})
        return artists
352

P
Philpax 已提交
353 354
    def get_embeddings(self):
        """Describe the loaded and skipped textual-inversion embeddings.

        Returns:
            A dict with ``loaded`` and ``skipped`` keys, each mapping an
            embedding name to its step, checkpoint, checkpoint name, shape
            and vector count as stored on the embedding object.
        """
        db = sd_hijack.model_hijack.embedding_db

        def summarize(embeddings):
            summary = {}
            for embedding in embeddings.values():
                summary[embedding.name] = {
                    "step": embedding.step,
                    "sd_checkpoint": embedding.sd_checkpoint,
                    "sd_checkpoint_name": embedding.sd_checkpoint_name,
                    "shape": embedding.shape,
                    "vectors": embedding.vectors,
                }
            return summary

        loaded = summarize(db.word_embeddings)
        skipped = summarize(db.skipped_embeddings)
        return {"loaded": loaded, "skipped": skipped}

D
Dean Hopkins 已提交
373 374
    def refresh_checkpoints(self):
        """Call ``shared.refresh_checkpoints()``; the endpoint has no response body."""
        shared.refresh_checkpoints()
        return None
E
evshiron 已提交
375

V
Vladimir Mandic 已提交
376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458
    def create_embedding(self, args: dict):
        """Create an empty embedding from *args* and reload the embedding database.

        *args* is forwarded as keyword arguments to the module-level
        ``create_embedding``.

        Returns:
            CreateResponse whose ``info`` holds the created filename, or the
            error text when the creation raised AssertionError.
        """
        try:
            shared.state.begin()
            filename = create_embedding(**args) # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
            return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
        except AssertionError as e:
            # The route is registered with response_model=CreateResponse, so
            # the error is reported with that model as well.
            return CreateResponse(info = "create embedding error: {error}".format(error = e))
        finally:
            # Runs on success, on AssertionError, and on any other exception.
            shared.state.end()

    def create_hypernetwork(self, args: dict):
        """Create an empty hypernetwork from *args*.

        *args* is forwarded as keyword arguments to the module-level
        ``create_hypernetwork``.

        Returns:
            CreateResponse whose ``info`` holds the created filename, or the
            error text when the creation raised AssertionError.
        """
        try:
            shared.state.begin()
            filename = create_hypernetwork(**args) # create empty hypernetwork
            return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
        except AssertionError as e:
            # The route is registered with response_model=CreateResponse, so
            # the error is reported with that model as well.
            return CreateResponse(info = "create hypernetwork error: {error}".format(error = e))
        finally:
            # Runs on success, on AssertionError, and on any other exception.
            shared.state.end()

    def preprocess(self, args: dict):
        """Run the module-level ``preprocess`` with *args* as keyword arguments.

        Returns:
            PreprocessResponse whose ``info`` reports completion, or the
            error text for KeyError, AssertionError and FileNotFoundError.
            Other exceptions propagate.
        """
        try:
            shared.state.begin()
            preprocess(**args) # quick operation unless blip/booru interrogation is enabled
            return PreprocessResponse(info = 'preprocess complete')
        except KeyError as e:
            return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
        except (AssertionError, FileNotFoundError) as e:
            return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
        finally:
            # Runs on every path, including exceptions not handled above.
            shared.state.end()

    def train_embedding(self, args: dict):
        """Train an embedding with *args* forwarded as keyword arguments.

        When ``shared.opts.training_xattention_optimizations`` is falsy the
        hijack optimizations are undone before training and re-applied
        afterwards. Any exception raised by the training call is captured
        and reported in the response instead of propagating.

        Returns:
            TrainResponse whose ``info`` holds the resulting filename and the
            captured error (``None`` when training succeeded).
        """
        try:
            shared.state.begin()
            keep_optimizations = shared.opts.training_xattention_optimizations
            failure = None
            saved_as = ''
            if not keep_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, saved_as = train_embedding(**args) # can take a long time to complete
            except Exception as exc:
                failure = exc
            finally:
                if not keep_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = saved_as, error = failure))
        except AssertionError as msg:
            shared.state.end()
            return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))

    def train_hypernetwork(self, args: dict):
        """Train a hypernetwork with *args* forwarded as keyword arguments.

        The previously loaded hypernetwork is restored and the model's
        ``cond_stage_model`` / ``first_stage_model`` are moved back to
        ``devices.device`` after training, whether or not it succeeded. When
        ``shared.opts.training_xattention_optimizations`` is falsy the hijack
        optimizations are undone before training and re-applied afterwards.
        Any exception raised by the training call is captured and reported
        in the response instead of propagating.

        Returns:
            TrainResponse whose ``info`` holds the resulting filename and the
            captured error (``None`` when training succeeded).
        """
        try:
            shared.state.begin()
            initial_hypernetwork = shared.loaded_hypernetwork
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                # args is a dict of keyword arguments; `*args` would pass its
                # keys positionally instead of the values.
                hypernetwork, filename = train_hypernetwork(**args)
            except Exception as e:
                error = e
            finally:
                shared.loaded_hypernetwork = initial_hypernetwork
                shared.sd_model.cond_stage_model.to(devices.device)
                shared.sd_model.first_stage_model.to(devices.device)
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return TrainResponse(info = "train hypernetwork complete: filename: {filename} error: {error}".format(filename = filename, error = error))
        except AssertionError as msg:
            shared.state.end()
            # Report the exception that was actually caught; `error` may be
            # unset or stale at this point.
            return TrainResponse(info = "train hypernetwork error: {error}".format(error = msg))

459
    def launch(self, server_name, port):
        """Include ``self.router`` in the app and serve it with uvicorn.

        Args:
            server_name: host passed to ``uvicorn.run``.
            port: port passed to ``uvicorn.run``.
        """
        app = self.app
        app.include_router(self.router)
        uvicorn.run(app, host=server_name, port=port)