api.py 31.7 KB
Newer Older
1 2
import base64
import io
3
import time
V
Vladimir Mandic 已提交
4
import datetime
5
import uvicorn
6
import gradio as gr
B
Bruno Seoane 已提交
7
from threading import Lock
S
Sena 已提交
8
from io import BytesIO
S
Sena 已提交
9
from gradio.processing_utils import decode_base64_to_file
V
Vladimir Mandic 已提交
10
from fastapi import APIRouter, Depends, FastAPI, Request, Response
11
from fastapi.security import HTTPBasic, HTTPBasicCredentials
V
Vladimir Mandic 已提交
12 13 14
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
15 16
from secrets import compare_digest

17
import modules.shared as shared
18
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
19
from modules.api.models import *
20
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
V
Vladimir Mandic 已提交
21 22 23
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
S
Sena 已提交
24
from PIL import PngImagePlugin,Image
Φ
Φφ 已提交
25
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
26
from modules.sd_models_config import find_checkpoint_config_near_filename
B
Bruno Seoane 已提交
27
from modules.realesrgan_model import get_realesrgan_models
V
Vladimir Mandic 已提交
28
from modules import devices
B
Bruno Seoane 已提交
29
from typing import List
V
Vladimir Mandic 已提交
30 31
import piexif
import piexif.helper
A
arcticfaded 已提交
32

B
Bruno Seoane 已提交
33 34 35 36
def upscaler_to_index(name: str):
    """Return the position of the upscaler called *name* in ``shared.sd_upscalers``.

    The comparison is case-insensitive.

    Raises:
        HTTPException: 400 when no upscaler matches (or *name* is not a string);
            the detail lists the valid upscaler names.
    """
    upscaler_names = [x.name for x in shared.sd_upscalers]
    try:
        return [n.lower() for n in upscaler_names].index(name.lower())
    except (ValueError, AttributeError) as err:
        # ValueError: no match; AttributeError: name is not a str (e.g. None).
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join(upscaler_names)}") from err
38

N
noodleanon 已提交
39 40 41 42 43
def script_name_to_index(name, scripts):
    """Find the index of the script in *scripts* whose title equals *name*, ignoring case.

    Raises:
        HTTPException: 422 when no title matches or the lookup fails for any reason.
    """
    try:
        lowered_titles = [script.title().lower() for script in scripts]
        wanted = name.lower()
        return lowered_titles.index(wanted)
    except:  # noqa: E722 - every lookup failure is reported as "not found"
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
44

45 46 47 48
def validate_sampler_name(name):
    """Return *name* unchanged if ``sd_samplers.all_samplers_map`` knows it.

    Raises:
        HTTPException: 404 when the sampler is not registered.
    """
    if sd_samplers.all_samplers_map.get(name, None) is not None:
        return name
    raise HTTPException(status_code=404, detail="Sampler not found")
51

B
Bruno Seoane 已提交
52 53
def setUpscalers(req: dict):
    """Rename the ``upscaler_1``/``upscaler_2`` fields to ``extras_upscaler_1``/``_2``.

    Operates on ``vars(req)``, so the request object's own attribute dict is
    modified in place and returned. Missing upscaler fields become ``None``.
    """
    fields = vars(req)
    for slot in ('1', '2'):
        fields['extras_upscaler_' + slot] = fields.pop('upscaler_' + slot, None)
    return fields
R
Roy Shilkrot 已提交
57

S
Sena 已提交
58 59 60
def decode_base64_to_image(encoding):
    """Decode a base64 string, optionally wrapped as a ``data:image/...`` URL, into a PIL image.

    Raises:
        HTTPException: 500 "Invalid encoded image" when the payload cannot be
            extracted or decoded.
    """
    try:
        if encoding.startswith("data:image/"):
            # Keep everything after the first comma (the payload). This also copes
            # with data URLs carrying extra ';'-separated parameters.
            encoding = encoding.split(",", 1)[1]
        image = Image.open(BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as err:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from err
66

67
def encode_pil_to_base64(image):
    """Serialize *image* in the ``opts.samples_format`` format and base64-encode it.

    PNG output stores the image's string ``info`` entries as text chunks.
    JPEG/WebP output stores ``info['parameters']`` in the EXIF UserComment tag.

    Returns:
        bytes: the base64-encoded image file.

    Raises:
        HTTPException: 500 when ``opts.samples_format`` is not png/jpg/jpeg/webp.
    """
    samples_format = opts.samples_format.lower()

    with io.BytesIO() as output_bytes:

        if samples_format == 'png':
            use_metadata = False
            metadata = PngImagePlugin.PngInfo()
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)

        elif samples_format in ("jpg", "jpeg", "webp"):
            # Read the generation parameters before any mode conversion below.
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
            })
            if samples_format in ("jpg", "jpeg"):
                # The JPEG encoder rejects modes such as RGBA, LA or P (Pillow raises
                # OSError "cannot write mode ... as JPEG"), so flatten those to RGB.
                if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
                    image = image.convert("RGB")
                image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
            else:
                image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = output_bytes.getvalue()

    return base64.b64encode(bytes_data)
95

V
Vladimir Mandic 已提交
96
def api_middleware(app: FastAPI):
    """Install request timing/logging and unified error handling on *app*.

    Registers:
      * ``log_and_time``: an HTTP middleware that sets ``X-Process-Time`` on every
        response and, when ``shared.cmd_opts.api_log`` is set, prints one line per
        ``/sdapi`` request.
      * ``exception_handling`` plus handlers for ``Exception`` and ``HTTPException``.
        All of them route errors through ``handle_exception``, which returns a
        JSON body with error/detail/body/errors keys.
    """
    import traceback  # plain traceback printer, used when rich is unavailable

    rich_available = True
    try:
        import anyio # importing just so it can be placed on silent list
        import starlette # importing just so it can be placed on silent list
        from rich.console import Console
        console = Console()
    except Exception:
        # Pretty tracebacks are optional; fall back to the traceback module.
        rich_available = False

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            # The ASGI spec allows scope["client"] to be present but None (e.g. unix
            # sockets), so fall back on any falsy value, not only a missing key.
            client = req.scope.get('client') or ('0:0.0.0', 0)
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code = res.status_code,
                ver = req.scope.get('http_version', '0.0'),
                cli = client[0],
                prot = req.scope.get('scheme', 'err'),
                method = req.scope.get('method', 'err'),
                endpoint = endpoint,
                duration = duration,
            ))
        return res

    def handle_exception(request: Request, e: Exception):
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get('detail', ''),
            "body": vars(e).get('body', ''),
            "errors": str(e),
        }
        print(f"API error: {request.method}: {request.url} {err}")
        if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
            if rich_available:
                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
            else:
                traceback.print_exc()
        return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

157

158
class Api:
B
Bruno Seoane 已提交
159
    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Install the middleware and register every ``/sdapi/v1`` route on *app*.

        Args:
            app: FastAPI application that receives the middleware and routes.
            queue_lock: lock acquired by the handlers around model/processing work.

        When ``--api-auth`` is set, it is parsed as comma-separated ``user:password``
        pairs into ``self.credentials``.
        """
        if shared.cmd_opts.api_auth:
            self.credentials = dict()
            for auth in shared.cmd_opts.api_auth.split(","):
                # Split on the first ":" only, so passwords may themselves contain ":".
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
        self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)

        # Filled lazily on the first txt2img/img2img request.
        self.default_script_arg_txt2img = []
        self.default_script_arg_img2img = []

204 205 206 207 208
    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register *endpoint* at *path*, protected by ``self.auth`` when ``--api-auth`` is set."""
        if not shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, **kwargs)
        return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)

J
Jim Hays 已提交
209 210 211
    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """FastAPI dependency that checks HTTP Basic credentials against ``self.credentials``.

        Returns:
            True when the username is known and the password matches.

        Raises:
            HTTPException: 401 with a ``WWW-Authenticate: Basic`` header otherwise.
        """
        expected_password = self.credentials.get(credentials.username)
        if expected_password is not None:
            # compare_digest accepts only ASCII-only str arguments (it raises TypeError
            # otherwise), so compare UTF-8 bytes to allow any password text.
            if compare_digest(credentials.password.encode("utf-8"), expected_password.encode("utf-8")):
                return True

        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
215

216 217
    def get_selectable_script(self, script_name, script_runner):
        """Resolve *script_name* among ``script_runner.selectable_scripts``.

        Returns:
            ``(script, index)``, or ``(None, None)`` when *script_name* is None or "".
        """
        if script_name in (None, ""):
            return None, None

        index = script_name_to_index(script_name, script_runner.selectable_scripts)
        return script_runner.selectable_scripts[index], index
Y
Yea chen 已提交
223 224
    
    def get_scripts_list(self):
        """Return the lower-cased script titles known to the txt2img and img2img runners."""
        def lowered_titles(runner):
            return [str(title.lower()) for title in runner.titles]

        return ScriptsList(txt2img = lowered_titles(scripts.scripts_txt2img), img2img = lowered_titles(scripts.scripts_img2img))
229

230
    def get_script(self, script_name, script_runner):
        """Return the script in ``script_runner.scripts`` whose title matches *script_name*.

        Returns:
            The matching script, or ``None`` when *script_name* is None or "".

        Raises:
            HTTPException: 422 (from ``script_name_to_index``) when no title matches.
        """
        if script_name is None or script_name == "":
            # A single None (not a (None, None) pair): init_script_args compares the
            # result against None and then reads ``.alwayson`` from it.
            return None

        script_idx = script_name_to_index(script_name, script_runner.scripts)
        return script_runner.scripts[script_idx]
236

237
    def init_default_script_args(self, script_runner):
        """Build the default ``script_args`` list for every script in *script_runner*.

        The list length is the largest ``args_to`` of any script (at least 1).
        Index 0 is 0, meaning no selectable script is chosen. Each script's slice
        ``[args_from:args_to]`` holds the ``.value`` of the components its ``ui()``
        returns. Every other entry is None.
        """
        # find max idx from the scripts in runner and generate a none array to init script_args
        last_arg_index = max([1] + [script.args_to for script in script_runner.scripts])

        # None everywhere except position 0 to initialize script args
        script_args = [None] * last_arg_index
        script_args[0] = 0

        # get default values
        with gr.Blocks(): # will throw errors calling ui function without this
            for script in script_runner.scripts:
                # Build the script's UI once and reuse the result; calling ui() a
                # second time would construct the components again.
                ui_components = script.ui(script.is_img2img)
                if ui_components:
                    script_args[script.args_from:script.args_to] = [elem.value for elem in ui_components]
        return script_args

    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
        """Merge the request's script arguments into a copy of *default_script_args*.

        Args:
            request: txt2img/img2img request; ``script_args`` and ``alwayson_scripts`` are read.
            default_script_args: defaults from ``init_default_script_args`` (left unmodified).
            selectable_scripts: the selectable script to run, or None.
            selectable_idx: that script's index among the selectable scripts.
            script_runner: runner whose ``scripts`` are searched for always-on scripts.

        Raises:
            HTTPException: 422 if an always-on entry is unknown or names a selectable script.
        """
        merged = default_script_args.copy()

        # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
        if selectable_scripts:
            merged[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
            merged[0] = selectable_idx + 1

        # Now check for always on scripts
        if not (request.alwayson_scripts and (len(request.alwayson_scripts) > 0)):
            return merged

        for script_name in request.alwayson_scripts.keys():
            script = self.get_script(script_name, script_runner)
            if script is None:
                raise HTTPException(status_code=422, detail=f"always on script {script_name} not found")
            # Selectable script in always on script param check
            if script.alwayson == False:  # noqa: E712
                raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
            # always on script with no arg should always run so you don't really need to add them to the requests
            settings = request.alwayson_scripts[script_name]
            if "args" not in settings:
                continue
            # copy only as many values as both the script's slot and the request provide
            slot_size = script.args_to - script.args_from
            for offset, value in zip(range(slot_size), settings["args"]):
                merged[script.args_from + offset] = value
        return merged

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Run a txt2img job for *txt2imgreq* under the queue lock.

        Returns:
            TextToImageResponse with base64 images (empty when ``send_images`` is
            false), the request parameters and ``processed.js()`` as info.
        """
        script_runner = scripts.scripts_txt2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(False)
            ui.create_ui()
        if not self.default_script_arg_txt2img:
            self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
            "do_not_save_samples": not txt2imgreq.save_images,
            "do_not_save_grid": not txt2imgreq.save_images,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('script_name', None)
        args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)

        script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
            p.scripts = script_runner
            p.outpath_grids = opts.outdir_txt2img_grids
            p.outpath_samples = opts.outdir_txt2img_samples

            shared.state.begin()
            try:
                if selectable_scripts is not None:
                    p.script_args = script_args
                    processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
                else:
                    p.script_args = tuple(script_args) # Need to pass args as tuple here
                    processed = process_images(p)
            finally:
                # Keep begin()/end() paired even if processing raises.
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
325

326 327 328
    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Run an img2img job on the request's base64 init images (and optional mask).

        Returns:
            ImageToImageResponse with base64 images (empty when ``send_images`` is
            false), the request parameters (init images and mask cleared unless
            ``include_init_images``) and ``processed.js()`` as info.

        Raises:
            HTTPException: 404 when ``init_images`` is None.
        """
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        script_runner = scripts.scripts_img2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(True)
            ui.create_ui()
        if not self.default_script_arg_img2img:
            self.default_script_arg_img2img = self.init_default_script_args(script_runner)
        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
            "do_not_save_samples": not img2imgreq.save_images,
            "do_not_save_grid": not img2imgreq.save_images,
            "mask": mask,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
        args.pop('script_name', None)
        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)

        script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        # Decode the uploads before taking the queue lock so decoding does not
        # delay other queued jobs.
        decoded_init_images = [decode_base64_to_image(x) for x in init_images]

        with self.queue_lock:
            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
            p.init_images = decoded_init_images
            p.scripts = script_runner
            p.outpath_grids = opts.outdir_img2img_grids
            p.outpath_samples = opts.outdir_img2img_samples

            shared.state.begin()
            try:
                if selectable_scripts is not None:
                    p.script_args = script_args
                    processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
                else:
                    p.script_args = tuple(script_args) # Need to pass args as tuple here
                    processed = process_images(p)
            finally:
                # Keep begin()/end() paired even if processing raises.
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
386

B
Bruno Seoane 已提交
387
    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
        """Run the extras (postprocessing) pipeline on one base64 image and return it base64-encoded."""
        request_fields = setUpscalers(req)
        request_fields['image'] = decode_base64_to_image(request_fields['image'])

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **request_fields)

        first_image = result[0][0]
        return ExtrasSingleImageResponse(image=encode_pil_to_base64(first_image), html_info=result[1])
396 397

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        """Run the extras pipeline on a batch of base64-encoded files.

        Each ``imageList`` entry (with ``data`` and ``name``) is written to a file by
        gradio's ``decode_base64_to_file``. The resulting file objects are passed
        as ``image_folder``.
        """
        reqDict = setUpscalers(req)

        def prepareFiles(file):
            uploaded_name = file.name
            temp_file = decode_base64_to_file(file.data, file_path=uploaded_name)
            # Record the name the client sent, not the temporary file's path.
            temp_file.orig_name = uploaded_name
            return temp_file

        reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
        reqDict.pop('imageList')

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
412

B
Bruno Seoane 已提交
413
    def pnginfoapi(self, req: PNGInfoRequest):
        """Extract generation info and other metadata items from a base64 image.

        A blank ``req.image`` yields an empty ``info``.
        """
        encoded = req.image.strip()
        if not encoded:
            return PNGInfoResponse(info="")

        image = decode_base64_to_image(encoded)
        if image is None:
            return PNGInfoResponse(info="")

        geninfo, items = images.read_info_from_image(image)
        geninfo = "" if geninfo is None else geninfo

        return PNGInfoResponse(info=geninfo, items={'parameters': geninfo, **items})
428

429
    def progressapi(self, req: ProgressRequest = Depends()):
        """Report progress of the current job.

        Returns:
            ProgressResponse with a progress estimate capped at 1, the estimated
            remaining seconds, ``shared.state.dict()``, the text info and, unless
            ``req.skip_current_image``, the current image as base64.
        """
        # copy from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

        # avoid dividing zero
        progress = 0.01

        # Both fractions divide by job_count, so only add them when it is positive.
        # Otherwise the step term could push progress to zero or below, and the eta
        # division below would fail or go negative.
        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
            if shared.state.sampling_steps > 0:
                progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        time_since_start = time.time() - shared.state.time_start
        eta = (time_since_start/progress)
        eta_relative = eta-time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
E
evshiron 已提交
456

457
    def interrogateapi(self, interrogatereq: InterrogateRequest):
        """Caption an image with the CLIP interrogator or tag it with DeepDanbooru.

        Raises:
            HTTPException: 404 when no image is sent or the model name is unknown.
        """
        encoded = interrogatereq.image
        if encoded is None:
            raise HTTPException(status_code=404, detail="Image not found")

        rgb_image = decode_base64_to_image(encoded).convert('RGB')

        # Override object param
        with self.queue_lock:
            model_name = interrogatereq.model
            if model_name == "clip":
                caption = shared.interrogator.interrogate(rgb_image)
            elif model_name == "deepdanbooru":
                caption = deepbooru.model.tag(rgb_image)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return InterrogateResponse(caption=caption)
475

E
evshiron 已提交
476 477 478 479 480
    def interruptapi(self):
        """Call ``shared.state.interrupt()`` and respond with an empty JSON object."""
        shared.state.interrupt()
        empty_response = {}
        return empty_response

Φ
Φφ 已提交
481 482 483 484 485 486 487 488 489 490
    def unloadapi(self):
        """Unload the model weights via ``unload_model_weights()``; respond with ``{}``."""
        unload_model_weights()
        empty_response = {}
        return empty_response

    def reloadapi(self):
        """Reload the model weights via ``reload_model_weights()``; respond with ``{}``."""
        reload_model_weights()
        empty_response = {}
        return empty_response

B
Bruno Seoane 已提交
491 492 493
    def skip(self):
        """Call ``shared.state.skip()``; the endpoint returns no body (None)."""
        shared.state.skip()
        return None

B
Bruno Seoane 已提交
494 495 496 497 498 499 500 501
    def get_config(self):
        """Return a copy of the stored option values (``shared.opts.data``).

        Keys are iterated from ``shared.opts.data`` itself, so every lookup hits
        and a fallback to the option's declared default could never apply. A
        plain copy yields the same mapping.
        """
        return dict(shared.opts.data)
504

B
Bruno Seoane 已提交
505
    def set_config(self, req: Dict[str, Any]):
        """Apply each option/value pair in *req* to ``shared.opts`` and save the config file."""
        for option_name, option_value in req.items():
            shared.opts.set(option_name, option_value)
        shared.opts.save(shared.config_filename)

    def get_cmd_flags(self):
        """Return the parsed command-line options as their attribute dict."""
        flags = vars(shared.cmd_opts)
        return flags

    def get_samplers(self):
        """Describe every entry of ``sd_samplers.all_samplers`` (fields 0, 2 and 3)."""
        def describe(sampler):
            return {"name": sampler[0], "aliases": sampler[2], "options": sampler[3]}

        return [describe(sampler) for sampler in sd_samplers.all_samplers]
B
Bruno Seoane 已提交
517 518

    def get_upscalers(self):
        """Describe every entry of ``shared.sd_upscalers``; ``model_url`` is always None."""
        def describe(upscaler):
            return {
                "name": upscaler.name,
                "model_name": upscaler.scaler.model_name,
                "model_path": upscaler.data_path,
                "model_url": None,
                "scale": upscaler.scale,
            }

        return list(map(describe, shared.sd_upscalers))
529

B
Bruno Seoane 已提交
530
    def get_sd_models(self):
        """List every known checkpoint with its hashes, filename and nearby config file."""
        def describe(checkpoint):
            return {
                "title": checkpoint.title,
                "model_name": checkpoint.model_name,
                "hash": checkpoint.shorthash,
                "sha256": checkpoint.sha256,
                "filename": checkpoint.filename,
                "config": find_checkpoint_config_near_filename(checkpoint),
            }

        return list(map(describe, checkpoints_list.values()))
B
Bruno Seoane 已提交
532 533 534 535 536 537 538 539 540

    def get_hypernetworks(self):
        """List the known hypernetworks as name/path pairs."""
        hypernetworks = shared.hypernetworks
        return [dict(name=name, path=hypernetworks[name]) for name in hypernetworks]

    def get_face_restorers(self):
        """List face restorers by name, with their ``cmd_dir`` when they define one."""
        return [dict(name=restorer.name(), cmd_dir=getattr(restorer, "cmd_dir", None)) for restorer in shared.face_restorers]

    def get_realesrgan_models(self):
        """List RealESRGAN models (module-level ``get_realesrgan_models(None)``) with path and scale."""
        return [dict(name=model.name, path=model.data_path, scale=model.scale) for model in get_realesrgan_models(None)]
541

J
Jim Hays 已提交
542
    def get_prompt_styles(self):
        """List saved prompt styles as name/prompt/negative_prompt dicts (fields 0, 1, 2)."""
        def to_item(style):
            return {"name": style[0], "prompt": style[1], "negative_prompt": style[2]}

        styles = shared.prompt_styles.styles
        return [to_item(styles[key]) for key in styles]

P
Philpax 已提交
550 551
    def get_embeddings(self):
        """Describe the loaded and skipped textual-inversion embeddings, keyed by embedding name."""
        db = sd_hijack.model_hijack.embedding_db

        def describe(embedding):
            return dict(
                step=embedding.step,
                sd_checkpoint=embedding.sd_checkpoint,
                sd_checkpoint_name=embedding.sd_checkpoint_name,
                shape=embedding.shape,
                vectors=embedding.vectors,
            )

        def by_name(embeddings):
            described = {}
            for embedding in embeddings.values():
                described[embedding.name] = describe(embedding)
            return described

        loaded = by_name(db.word_embeddings)
        skipped = by_name(db.skipped_embeddings)
        return {"loaded": loaded, "skipped": skipped}

D
Dean Hopkins 已提交
570 571
    def refresh_checkpoints(self):
        """Delegate to ``shared.refresh_checkpoints()``; the endpoint returns no body."""
        shared.refresh_checkpoints()
        return None
E
evshiron 已提交
572

V
Vladimir Mandic 已提交
573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633
    def create_embedding(self, args: dict):
        """Create an empty textual-inversion embedding and reload the embedding DB.

        ``args`` is forwarded as keyword arguments to the module-level
        ``create_embedding``. An AssertionError is reported in the ``info`` text.
        ``shared.state`` is ended on every path.
        """
        try:
            shared.state.begin()
            filename = create_embedding(**args) # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
            return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
        except AssertionError as e:
            # Use the route's declared response model (CreateResponse) on errors too.
            return CreateResponse(info = "create embedding error: {error}".format(error = e))
        finally:
            shared.state.end()

    def create_hypernetwork(self, args: dict):
        """Create an empty hypernetwork.

        ``args`` is forwarded as keyword arguments to the module-level
        ``create_hypernetwork``. An AssertionError is reported in the ``info`` text.
        ``shared.state`` is ended on every path.
        """
        try:
            shared.state.begin()
            filename = create_hypernetwork(**args) # create empty hypernetwork
            return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
        except AssertionError as e:
            # Use the route's declared response model (CreateResponse) on errors too.
            return CreateResponse(info = "create hypernetwork error: {error}".format(error = e))
        finally:
            shared.state.end()

    def preprocess(self, args: dict):
        """Run training-image preprocessing.

        ``args`` is forwarded verbatim as keyword arguments to the module-level
        ``preprocess`` helper (imported from
        ``modules.textual_inversion.preprocess``).

        Returns a ``PreprocessResponse``:
          * ``'preprocess complete'`` on success;
          * ``"preprocess error: invalid token: ..."`` for a ``KeyError``;
          * ``"preprocess error: ..."`` for an ``AssertionError`` or
            ``FileNotFoundError``.
        Other exceptions propagate, but the shared job state is always ended.
        """
        try:
            shared.state.begin()
            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
            return PreprocessResponse(info = 'preprocess complete')
        except KeyError as e:
            return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
        except (AssertionError, FileNotFoundError) as e:
            return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
        finally:
            # single release point instead of one per handler; also covers
            # exceptions that are not handled above
            shared.state.end()

    def train_embedding(self, args: dict):
        """Train a textual-inversion embedding synchronously (may take a long time).

        ``args`` is forwarded verbatim as keyword arguments to the module-level
        ``train_embedding`` helper. When
        ``shared.opts.training_xattention_optimizations`` is false, the
        cross-attention optimizations are undone for the duration of training
        and re-applied afterwards.

        Any exception raised by training itself is caught and reported in the
        response text (``error: None`` on success; ``filename`` stays ``''`` if
        training did not return one). An ``AssertionError`` raised outside the
        training call yields a "train embedding error" response. The shared job
        state is ended exactly once on every path.
        """
        try:
            shared.state.begin()
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_embedding(**args)  # can take a long time to complete
            except Exception as e:
                error = e  # reported to the client rather than raised
            finally:
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
        except AssertionError as msg:
            return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
        finally:
            # one release point; also covers non-AssertionError failures during setup
            shared.state.end()

    def train_hypernetwork(self, args: dict):
        """Train a hypernetwork synchronously (may take a long time).

        ``args`` is forwarded verbatim as keyword arguments to the module-level
        ``train_hypernetwork`` helper (imported from
        ``modules.hypernetworks.hypernetwork``). ``shared.loaded_hypernetworks``
        is reset to an empty list before training. When
        ``shared.opts.training_xattention_optimizations`` is false, the
        cross-attention optimizations are undone for the duration of training
        and re-applied afterwards.

        Any exception raised by training itself is caught and reported in the
        response text (``error: None`` on success). An ``AssertionError`` raised
        outside the training call yields a "train hypernetwork error" response
        carrying that assertion's message. The shared job state is ended exactly
        once on every path.
        """
        try:
            shared.state.begin()
            shared.loaded_hypernetworks = []
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_hypernetwork(**args)
            except Exception as e:
                error = e  # reported to the client rather than raised
            finally:
                # move these sub-models back to the compute device after training;
                # presumably training may leave them elsewhere (TODO confirm)
                shared.sd_model.cond_stage_model.to(devices.device)
                shared.sd_model.first_stage_model.to(devices.device)
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
            return TrainResponse(info="train hypernetwork complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            # format the caught assertion; `error` may be unbound here if begin() failed
            return TrainResponse(info="train hypernetwork error: {error}".format(error=msg))
        finally:
            # one release point; also covers failures while restoring model devices
            shared.state.end()
V
Vladimir Mandic 已提交
654

V
Vladimir Mandic 已提交
655 656 657 658
    def get_memory(self):
        """Report process RAM usage and, if available, CUDA memory statistics.

        Each section is gathered independently and best-effort. Any exception
        (missing module, unsupported platform, missing stat key, ...) turns that
        section into ``{'error': '<message>'}`` instead of propagating.
        """
        def collect_ram():
            # RAM section: only rss is relied upon; total is derived from rss
            # and its percentage because a direct total is not portable
            try:
                import os, psutil
                proc = psutil.Process(os.getpid())
                info = proc.memory_info()
                total = 100 * info.rss / proc.memory_percent()
                return { 'free': total - info.rss, 'used': info.rss, 'total': total }
            except Exception as err:
                return { 'error': f'{err}' }

        def collect_cuda():
            # CUDA section: device-wide free/total plus allocator counters
            try:
                import torch
                if not torch.cuda.is_available():
                    return { 'error': 'unavailable' }
                free_bytes, total_bytes = torch.cuda.mem_get_info()
                system = { 'free': free_bytes, 'used': total_bytes - free_bytes, 'total': total_bytes }
                stats = dict(torch.cuda.memory_stats(shared.device))
                allocated = { 'current': stats['allocated_bytes.all.current'], 'peak': stats['allocated_bytes.all.peak'] }
                reserved = { 'current': stats['reserved_bytes.all.current'], 'peak': stats['reserved_bytes.all.peak'] }
                active = { 'current': stats['active_bytes.all.current'], 'peak': stats['active_bytes.all.peak'] }
                inactive = { 'current': stats['inactive_split_bytes.all.current'], 'peak': stats['inactive_split_bytes.all.peak'] }
                events = { 'retries': stats['num_alloc_retries'], 'oom': stats['num_ooms'] }
                return {
                    'system': system,
                    'active': active,
                    'allocated': allocated,
                    'reserved': reserved,
                    'inactive': inactive,
                    'events': events,
                }
            except Exception as err:
                return { 'error': f'{err}' }

        return MemoryResponse(ram = collect_ram(), cuda = collect_cuda())

689
    def launch(self, server_name, port):
        """Attach this API's router to the app and serve it with uvicorn.

        Blocks inside ``uvicorn.run`` for as long as the server runs.
        """
        app, router = self.app, self.router
        app.include_router(router)
        uvicorn.run(app, port=port, host=server_name)