api.py 31.5 KB
Newer Older
1 2
import base64
import io
3
import time
V
Vladimir Mandic 已提交
4
import datetime
5
import uvicorn
6
import gradio as gr
B
Bruno Seoane 已提交
7
from threading import Lock
S
Sena 已提交
8
from io import BytesIO
V
Vladimir Mandic 已提交
9
from fastapi import APIRouter, Depends, FastAPI, Request, Response
10
from fastapi.security import HTTPBasic, HTTPBasicCredentials
V
Vladimir Mandic 已提交
11 12 13
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
14 15
from secrets import compare_digest

16
import modules.shared as shared
17
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
18
from modules.api.models import *
19
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
V
Vladimir Mandic 已提交
20 21 22
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
S
Sena 已提交
23
from PIL import PngImagePlugin,Image
Φ
Φφ 已提交
24
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
25
from modules.sd_models_config import find_checkpoint_config_near_filename
B
Bruno Seoane 已提交
26
from modules.realesrgan_model import get_realesrgan_models
V
Vladimir Mandic 已提交
27
from modules import devices
B
Bruno Seoane 已提交
28
from typing import List
V
Vladimir Mandic 已提交
29 30
import piexif
import piexif.helper
A
arcticfaded 已提交
31

B
Bruno Seoane 已提交
32 33 34 35
def upscaler_to_index(name: str):
    """Return the index of the upscaler called *name* (case-insensitive) in ``shared.sd_upscalers``.

    Raises:
        HTTPException: 400 listing the available upscaler names when *name* does not match.
    """
    # Read the list once so the lookup and the error message describe the same upscalers.
    upscalers = shared.sd_upscalers
    try:
        return [x.name.lower() for x in upscalers].index(name.lower())
    except (AttributeError, ValueError) as err:
        # ValueError: no match; AttributeError: *name* is not a string (e.g. None).
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in upscalers])}") from err
37

N
noodleanon 已提交
38 39 40 41 42
def script_name_to_index(name, scripts):
    """Return the position in *scripts* of the script whose ``title()`` matches *name* (case-insensitive).

    Raises:
        HTTPException: 422 when no script in *scripts* has that title.
    """
    try:
        return [script.title().lower() for script in scripts].index(name.lower())
    except (AttributeError, ValueError) as err:
        # ValueError: title not found; AttributeError: name/title is not a string.
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from err
43

44 45 46 47
def validate_sampler_name(name):
    """Check that *name* is a registered sampler and return it unchanged.

    Raises:
        HTTPException: 404 when ``sd_samplers.all_samplers_map`` has no entry for *name*.
    """
    if sd_samplers.all_samplers_map.get(name, None) is not None:
        return name
    raise HTTPException(status_code=404, detail="Sampler not found")
50

B
Bruno Seoane 已提交
51 52
def setUpscalers(req):
    """Build keyword arguments for ``postprocessing.run_extras`` from an extras request.

    *req* is an object with a ``__dict__`` (the extras request models passed by the
    endpoints below). Its ``upscaler_1``/``upscaler_2`` attributes are renamed to
    ``extras_upscaler_1``/``extras_upscaler_2`` (``None`` when absent).

    Returns:
        dict: a shallow copy of the request's attributes with the upscaler keys renamed.
    """
    # Copy: vars() returns the object's own __dict__, and popping from it would mutate the request.
    reqDict = dict(vars(req))
    reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
    reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
    return reqDict
R
Roy Shilkrot 已提交
56

S
Sena 已提交
57 58 59
def decode_base64_to_image(encoding):
    """Decode a base64 string, optionally wrapped in a ``data:image/...`` URL, into a PIL image.

    Raises:
        HTTPException: 500 when the payload cannot be base64-decoded or opened as an image.
    """
    if encoding.startswith("data:image/"):
        # The payload is everything after the first comma; this also copes with extra
        # media-type parameters and yields "" (rejected below) when no comma is present.
        encoding = encoding.partition(",")[2]
    try:
        image = Image.open(BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as err:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from err
65

66
def encode_pil_to_base64(image):
    """Serialize *image* in ``opts.samples_format`` and return the base64-encoded bytes.

    PNG output embeds every string key/value of ``image.info`` as a text chunk.
    JPEG/WEBP output stores ``image.info['parameters']`` (or "") as the EXIF UserComment.

    Raises:
        HTTPException: 500 when ``opts.samples_format`` is not png, jpg, jpeg or webp.
    """
    fmt = opts.samples_format.lower()
    with io.BytesIO() as buffer:
        if fmt == 'png':
            pnginfo = PngImagePlugin.PngInfo()
            has_text = False
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    pnginfo.add_text(key, value)
                    has_text = True
            image.save(buffer, format="PNG", pnginfo=pnginfo if has_text else None, quality=opts.jpeg_quality)
        elif fmt in ("jpg", "jpeg", "webp"):
            user_comment = piexif.helper.UserComment.dump(image.info.get('parameters', None) or "", encoding="unicode")
            exif_bytes = piexif.dump({"Exif": {piexif.ExifIFD.UserComment: user_comment}})
            pil_format = "JPEG" if fmt in ("jpg", "jpeg") else "WEBP"
            image.save(buffer, format=pil_format, exif=exif_bytes, quality=opts.jpeg_quality)
        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = buffer.getvalue()

    return base64.b64encode(bytes_data)
94

V
Vladimir Mandic 已提交
95
def api_middleware(app: FastAPI):
    """Install request timing/logging middleware and uniform JSON error handling on *app*.

    - Every response gets an ``X-Process-Time`` header (seconds, rounded to 4 places).
    - When ``shared.cmd_opts.api_log`` is set, each request under ``/sdapi`` is printed.
    - Exceptions are printed (with a traceback unless they are HTTPException) and turned
      into a JSON body with ``error``/``detail``/``body``/``errors`` keys; the status code
      is the exception's ``status_code`` attribute, or 500 when it has none.
    """
    rich_available = True
    try:
        import anyio # importing just so it can be placed on silent list
        import starlette # importing just so it can be placed on silent list
        from rich.console import Console
        console = Console()
    except Exception:
        import traceback
        rich_available = False

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            # ASGI allows scope["client"] to be None (e.g. unix sockets), so fall back explicitly
            # instead of relying on dict.get's default, which only applies to a missing key.
            client = req.scope.get('client') or ('0:0.0.0', 0)
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code = res.status_code,
                ver = req.scope.get('http_version', '0.0'),
                cli = client[0],
                prot = req.scope.get('scheme', 'err'),
                method = req.scope.get('method', 'err'),
                endpoint = endpoint,
                duration = duration,
            ))
        return res

    def handle_exception(request: Request, e: Exception):
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get('detail', ''),
            "body": vars(e).get('body', ''),
            "errors": str(e),
        }
        print(f"API error: {request.method}: {request.url} {err}")
        if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
            if rich_available:
                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
            else:
                traceback.print_exc()
        return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

156

157
class Api:
B
Bruno Seoane 已提交
158
    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Register all ``/sdapi/v1`` routes on *app*.

        Args:
            app: the FastAPI application to attach routes and middleware to.
            queue_lock: lock held while a generation/postprocessing job runs.

        When ``--api-auth`` is given it is parsed as comma-separated ``user:password``
        pairs into ``self.credentials``.
        """
        if shared.cmd_opts.api_auth:
            self.credentials = dict()
            for auth in shared.cmd_opts.api_auth.split(","):
                # Split on the first colon only so passwords may themselves contain ':'.
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
        self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)

        # Filled lazily on the first txt2img/img2img request.
        self.default_script_arg_txt2img = []
        self.default_script_arg_img2img = []

203 204 205 206 207
    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register *endpoint* at *path*, guarded by ``self.auth`` when ``--api-auth`` is set."""
        if not shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, **kwargs)
        return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)

J
Jim Hays 已提交
208 209 210
    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """FastAPI dependency that checks HTTP Basic credentials against ``self.credentials``.

        Returns:
            True when the username is known and its password matches.

        Raises:
            HTTPException: 401 with ``WWW-Authenticate: Basic`` otherwise.
        """
        expected = self.credentials.get(credentials.username)
        if expected is not None:
            # compare_digest rejects str arguments containing non-ASCII characters with a
            # TypeError; comparing UTF-8 bytes keeps the constant-time check for any password.
            if compare_digest(credentials.password.encode("utf-8"), expected.encode("utf-8")):
                return True

        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
214

215 216
    def get_selectable_script(self, script_name, script_runner):
        """Resolve *script_name* against ``script_runner.selectable_scripts``.

        Returns:
            ``(script, index)`` for the matching selectable script, or ``(None, None)``
            when *script_name* is None or empty.

        Raises:
            HTTPException: 422 (from ``script_name_to_index``) when no title matches.
        """
        if script_name in (None, ""):
            return None, None

        selectable = script_runner.selectable_scripts
        index = script_name_to_index(script_name, selectable)
        return selectable[index], index
Y
Yea chen 已提交
222 223
    
    def get_scripts_list(self):
        """Return the lower-cased titles of the txt2img and img2img scripts."""
        def lowered_titles(runner):
            return [str(title.lower()) for title in runner.titles]

        return ScriptsList(txt2img=lowered_titles(scripts.scripts_txt2img), img2img=lowered_titles(scripts.scripts_img2img))
228

229
    def get_script(self, script_name, script_runner):
        """Look up a script (selectable or always-on) by title in ``script_runner.scripts``.

        Returns:
            The matching script object, or ``None`` when *script_name* is None or empty.

        Raises:
            HTTPException: 422 (from ``script_name_to_index``) when no title matches.
        """
        if script_name is None or script_name == "":
            # A single None (not a tuple): callers test the result against None.
            return None

        script_idx = script_name_to_index(script_name, script_runner.scripts)
        return script_runner.scripts[script_idx]
235

236
    def init_default_script_args(self, script_runner):
        """Build the default positional ``script_args`` for every script in *script_runner*.

        The list length is the largest ``args_to`` among the scripts (at least 1).
        Position 0 (the selectable-script slot) is 0. Each script's
        ``[args_from:args_to]`` slice holds the ``.value`` of the components returned by
        its ``ui()``, and every other position is None.
        """
        # find max idx from the scripts in runner and generate a none array to init script_args
        last_arg_index = max([1] + [script.args_to for script in script_runner.scripts])
        # None everywhere except position 0 to initialize script args
        script_args = [None] * last_arg_index
        script_args[0] = 0

        # get default values
        with gr.Blocks(): # will throw errors calling ui function without this
            for script in script_runner.scripts:
                # Build each script's UI only once: every ui() call creates new gradio components.
                ui_elements = script.ui(script.is_img2img)
                if ui_elements:
                    script_args[script.args_from:script.args_to] = [elem.value for elem in ui_elements]
        return script_args

    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
        """Assemble the positional ``script_args`` for a txt2img/img2img request.

        Starts from a copy of *default_script_args*. When a selectable script was chosen,
        its slice is replaced by ``request.script_args`` and position 0 becomes
        ``selectable_idx + 1``. For each ``request.alwayson_scripts`` entry carrying an
        ``"args"`` list, that always-on script's slice is overwritten element by element,
        up to the shorter of the script's arg count and the supplied list.

        Raises:
            HTTPException: 422 when an always-on entry names an unknown script or a
                selectable (not always-on) script.
        """
        script_args = default_script_args.copy()
        # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
        if selectable_scripts:
            script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
            script_args[0] = selectable_idx + 1

        # Now check for always on scripts
        if request.alwayson_scripts:
            for alwayson_script_name, alwayson_params in request.alwayson_scripts.items():
                alwayson_script = self.get_script(alwayson_script_name, script_runner)
                if alwayson_script is None:
                    raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                # Selectable script in always on script param check
                if not alwayson_script.alwayson:
                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                # always on script with no arg should always run so you don't really need to add them to the requests
                if "args" in alwayson_params:
                    request_args = alwayson_params["args"]
                    # min between arg length in scriptrunner and arg length in the request
                    arg_count = min(alwayson_script.args_to - alwayson_script.args_from, len(request_args))
                    for idx in range(arg_count):
                        script_args[alwayson_script.args_from + idx] = request_args[idx]
        return script_args

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Run a txt2img job for *txt2imgreq* and return the images (base64) plus info.

        Scripts are initialized on first use. A selectable script (``script_name``) runs
        through the script runner; otherwise ``process_images`` runs directly. The job
        runs while holding ``self.queue_lock``.

        Raises:
            HTTPException: on an unknown sampler or script (see the helpers used below).
        """
        script_runner = scripts.scripts_txt2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(False)
            ui.create_ui()
        if not self.default_script_arg_txt2img:
            self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
            "do_not_save_samples": not txt2imgreq.save_images,
            "do_not_save_grid": not txt2imgreq.save_images,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('script_name', None)
        args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)

        script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
            p.scripts = script_runner
            p.outpath_grids = opts.outdir_txt2img_grids
            p.outpath_samples = opts.outdir_txt2img_samples

            shared.state.begin()
            try:
                if selectable_scripts is not None:
                    p.script_args = script_args
                    processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
                else:
                    p.script_args = tuple(script_args) # Need to pass args as tuple here
                    processed = process_images(p)
            finally:
                # End the job even when processing raises, so shared.state does not stay "running".
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
324

325 326 327
    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Run an img2img job for *img2imgreq* and return the images (base64) plus info.

        ``init_images`` (and the optional ``mask``) are base64 images. Unless
        ``include_init_images`` is set, they are blanked out of the echoed parameters.
        The job runs while holding ``self.queue_lock``.

        Raises:
            HTTPException: 404 when ``init_images`` is missing; others from the helpers used.
        """
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        script_runner = scripts.scripts_img2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(True)
            ui.create_ui()
        if not self.default_script_arg_img2img:
            self.default_script_arg_img2img = self.init_default_script_args(script_runner)
        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
            "do_not_save_samples": not img2imgreq.save_images,
            "do_not_save_grid": not img2imgreq.save_images,
            "mask": mask,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
        args.pop('script_name', None)
        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)

        script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        with self.queue_lock:
            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
            p.init_images = [decode_base64_to_image(x) for x in init_images]
            p.scripts = script_runner
            p.outpath_grids = opts.outdir_img2img_grids
            p.outpath_samples = opts.outdir_img2img_samples

            shared.state.begin()
            try:
                if selectable_scripts is not None:
                    p.script_args = script_args
                    processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
                else:
                    p.script_args = tuple(script_args) # Need to pass args as tuple here
                    processed = process_images(p)
            finally:
                # End the job even when processing raises, so shared.state does not stay "running".
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
385

B
Bruno Seoane 已提交
386
    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
        """Run the extras (postprocessing) pipeline on one base64 image under the queue lock."""
        extras_kwargs = setUpscalers(req)
        extras_kwargs['image'] = decode_base64_to_image(extras_kwargs['image'])

        with self.queue_lock:
            outcome = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **extras_kwargs)

        first_image = outcome[0][0]
        return ExtrasSingleImageResponse(image=encode_pil_to_base64(first_image), html_info=outcome[1])
395 396

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        """Run the extras (postprocessing) pipeline on a list of base64 images under the queue lock."""
        extras_kwargs = setUpscalers(req)
        decoded_images = [decode_base64_to_image(entry.data) for entry in extras_kwargs.pop('imageList', [])]

        with self.queue_lock:
            outcome = postprocessing.run_extras(extras_mode=1, image_folder=decoded_images, image="", input_dir="", output_dir="", save_output=False, **extras_kwargs)

        encoded = [encode_pil_to_base64(img) for img in outcome[0]]
        return ExtrasBatchImagesResponse(images=encoded, html_info=outcome[1])
406

B
Bruno Seoane 已提交
407
    def pnginfoapi(self, req: PNGInfoRequest):
        """Return the generation parameters and other metadata stored in a base64 image.

        An empty (whitespace-only) image string yields an empty ``info``.
        """
        payload = req.image.strip()
        if not payload:
            return PNGInfoResponse(info="")

        image = decode_base64_to_image(payload)
        if image is None:
            return PNGInfoResponse(info="")

        geninfo, items = images.read_info_from_image(image)
        geninfo = "" if geninfo is None else geninfo

        # 'parameters' comes first; keys read from the image may override it.
        return PNGInfoResponse(info=geninfo, items={'parameters': geninfo, **items})
422

423
    def progressapi(self, req: ProgressRequest = Depends()):
        """Report progress (capped at 1), relative ETA, state and an optional preview image.

        Mirrors ``check_progress_call`` in ui.py. When ``job_count`` is 0 it reports zero
        progress and ETA. ``current_image`` is the base64 of
        ``shared.state.current_image`` unless ``req.skip_current_image`` is set.
        """
        # copy from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

        # avoid dividing zero
        progress = 0.01

        # job_count is non-zero here; the "> 0" guard implies it may also be negative.
        # Both fractional terms divide by it, so only add them for a positive job_count,
        # otherwise progress could drop to <= 0 and break the ETA division below.
        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
            if shared.state.sampling_steps > 0:
                progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        time_since_start = time.time() - shared.state.time_start
        eta = (time_since_start/progress)
        eta_relative = eta-time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
E
evshiron 已提交
450

451
    def interrogateapi(self, interrogatereq: InterrogateRequest):
        """Caption a base64 image with the ``clip`` or ``deepdanbooru`` interrogator.

        Raises:
            HTTPException: 404 when no image is given or the model name is unknown.
        """
        image_b64 = interrogatereq.image
        if image_b64 is None:
            raise HTTPException(status_code=404, detail="Image not found")

        rgb_image = decode_base64_to_image(image_b64).convert('RGB')

        # Override object param
        with self.queue_lock:
            model_name = interrogatereq.model
            if model_name == "clip":
                caption = shared.interrogator.interrogate(rgb_image)
            elif model_name == "deepdanbooru":
                caption = deepbooru.model.tag(rgb_image)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return InterrogateResponse(caption=caption)
469

E
evshiron 已提交
470 471 472 473 474
    def interruptapi(self):
        """Call ``shared.state.interrupt()`` and reply with an empty JSON object."""
        shared.state.interrupt()
        return dict()

Φ
Φφ 已提交
475 476 477 478 479 480 481 482 483 484
    def unloadapi(self):
        """Call ``unload_model_weights()`` and reply with an empty JSON object."""
        unload_model_weights()
        return dict()

    def reloadapi(self):
        """Call ``reload_model_weights()`` and reply with an empty JSON object."""
        reload_model_weights()
        return dict()

B
Bruno Seoane 已提交
485 486 487
    def skip(self):
        """Delegate to ``shared.state.skip()``; the response body is null."""
        shared.state.skip()
        return None

B
Bruno Seoane 已提交
488 489 490 491 492 493 494 495
    def get_config(self):
        """Return a shallow copy of the stored options (``shared.opts.data``).

        Only keys present in ``shared.opts.data`` are returned, with their stored values.
        """
        # Iterating data's own keys means every lookup hits, so the data_labels default
        # fallbacks the old per-key loop computed could never be used; a copy is equivalent.
        return dict(shared.opts.data)
498

B
Bruno Seoane 已提交
499
    def set_config(self, req: Dict[str, Any]):
        """Apply every option in *req* via ``shared.opts.set`` and save to ``shared.config_filename``."""
        for option_name, option_value in req.items():
            shared.opts.set(option_name, option_value)

        shared.opts.save(shared.config_filename)
        return None

    def get_cmd_flags(self):
        """Return the parsed command-line options as a dict."""
        flags = vars(shared.cmd_opts)
        return flags

    def get_samplers(self):
        """List every sampler in ``sd_samplers.all_samplers`` with its aliases and options.

        Entries are indexed positionally: 0 = name, 2 = aliases, 3 = options.
        """
        return [
            {"name": entry[0], "aliases": entry[2], "options": entry[3]}
            for entry in sd_samplers.all_samplers
        ]
B
Bruno Seoane 已提交
511 512

    def get_upscalers(self):
        """Describe every upscaler in ``shared.sd_upscalers``; ``model_url`` is always None."""
        def describe(upscaler):
            return {
                "name": upscaler.name,
                "model_name": upscaler.scaler.model_name,
                "model_path": upscaler.data_path,
                "model_url": None,
                "scale": upscaler.scale,
            }

        return [describe(item) for item in shared.sd_upscalers]
523

B
Bruno Seoane 已提交
524
    def get_sd_models(self):
        """List every checkpoint in ``checkpoints_list`` with its names, hashes and file.

        ``config`` is whatever ``find_checkpoint_config_near_filename`` returns for it.
        """
        def describe(info):
            return {
                "title": info.title,
                "model_name": info.model_name,
                "hash": info.shorthash,
                "sha256": info.sha256,
                "filename": info.filename,
                "config": find_checkpoint_config_near_filename(info),
            }

        return [describe(info) for info in checkpoints_list.values()]
B
Bruno Seoane 已提交
526 527 528 529 530 531 532 533 534

    def get_hypernetworks(self):
        """List the known hypernetworks as name/path pairs from ``shared.hypernetworks``."""
        return [{"name": hn_name, "path": hn_path} for hn_name, hn_path in shared.hypernetworks.items()]

    def get_face_restorers(self):
        """List face restorers with their name and ``cmd_dir`` (None when the restorer has none)."""
        return [
            {"name": restorer.name(), "cmd_dir": getattr(restorer, "cmd_dir", None)}
            for restorer in shared.face_restorers
        ]

    def get_realesrgan_models(self):
        """List RealESRGAN models (name, path, scale) from the module-level ``get_realesrgan_models``."""
        return [
            {"name": model.name, "path": model.data_path, "scale": model.scale}
            for model in get_realesrgan_models(None)
        ]
535

J
Jim Hays 已提交
536
    def get_prompt_styles(self):
        """List saved prompt styles; each style is read positionally as (name, prompt, negative_prompt)."""
        return [
            {"name": style[0], "prompt": style[1], "negative_prompt": style[2]}
            for style in shared.prompt_styles.styles.values()
        ]

P
Philpax 已提交
544 545
    def get_embeddings(self):
        """Describe the loaded and skipped textual-inversion embeddings, keyed by name."""
        db = sd_hijack.model_hijack.embedding_db

        def summarize(collection):
            return {
                emb.name: {
                    "step": emb.step,
                    "sd_checkpoint": emb.sd_checkpoint,
                    "sd_checkpoint_name": emb.sd_checkpoint_name,
                    "shape": emb.shape,
                    "vectors": emb.vectors,
                }
                for emb in collection.values()
            }

        return {"loaded": summarize(db.word_embeddings), "skipped": summarize(db.skipped_embeddings)}

D
Dean Hopkins 已提交
564 565
    def refresh_checkpoints(self):
        """Delegate to ``shared.refresh_checkpoints()``; the response body is null."""
        shared.refresh_checkpoints()
        return None
E
evshiron 已提交
566

V
Vladimir Mandic 已提交
567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627
    def create_embedding(self, args: dict):
        """Create an empty textual-inversion embedding and reload the embedding database.

        *args* is forwarded as keyword arguments to ``create_embedding``.

        Returns:
            CreateResponse whose ``info`` names the new file, or carries the message of an
            AssertionError raised while creating it.
        """
        try:
            shared.state.begin()
            filename = create_embedding(**args) # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
            return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
        except AssertionError as e:
            # Same model as the success path / the route's declared response_model.
            return CreateResponse(info = "create embedding error: {error}".format(error = e))
        finally:
            # End the job on every path, including unexpected exceptions.
            shared.state.end()

    def create_hypernetwork(self, args: dict):
        """Create an empty hypernetwork file.

        *args* is forwarded as keyword arguments to ``create_hypernetwork``.

        Returns:
            CreateResponse whose ``info`` names the new file, or carries the message of an
            AssertionError raised while creating it.
        """
        try:
            shared.state.begin()
            filename = create_hypernetwork(**args) # create empty hypernetwork
            return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
        except AssertionError as e:
            # Same model as the success path / the route's declared response_model.
            return CreateResponse(info = "create hypernetwork error: {error}".format(error = e))
        finally:
            # End the job on every path, including unexpected exceptions.
            shared.state.end()

    def preprocess(self, args: dict):
        """Run dataset preprocessing with the given options.

        Args:
            args: keyword arguments forwarded verbatim to the module-level
                ``preprocess`` helper.

        Returns:
            PreprocessResponse with ``'preprocess complete'`` on success, or an
            error description for KeyError (reported as an invalid token),
            AssertionError and FileNotFoundError. Other exceptions propagate.

        ``shared.state.end()`` runs exactly once on every path.
        """
        try:
            shared.state.begin()
            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
            return PreprocessResponse(info='preprocess complete')
        except KeyError as e:
            return PreprocessResponse(info=f"preprocess error: invalid token: {e}")
        except (AssertionError, FileNotFoundError) as e:
            return PreprocessResponse(info=f"preprocess error: {e}")
        finally:
            shared.state.end()

    def train_embedding(self, args: dict):
        """Train a textual-inversion embedding synchronously and report the outcome.

        Cross-attention optimizations are undone for the duration of training
        unless ``shared.opts.training_xattention_optimizations`` is set; in that
        case they are re-applied in a ``finally`` after training.

        Any exception raised by ``train_embedding`` itself is captured and
        reported in ``info`` instead of propagating. An AssertionError raised
        outside the training call yields an error response. ``shared.state.end()``
        runs exactly once on every path.

        Args:
            args: keyword arguments forwarded verbatim to the module-level
                ``train_embedding`` helper.
        """
        try:
            shared.state.begin()
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_embedding(**args)  # can take a long time to complete
            except Exception as e:
                error = e
            finally:
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
            return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
        except AssertionError as msg:
            return TrainResponse(info=f"train embedding error: {msg}")
        finally:
            # single place that closes the job, so it can neither be skipped nor doubled
            shared.state.end()

    def train_hypernetwork(self, args: dict):
        """Train a hypernetwork synchronously and report the outcome.

        ``shared.loaded_hypernetworks`` is cleared first. Cross-attention
        optimizations are undone for the duration of training unless
        ``shared.opts.training_xattention_optimizations`` is set; in that case
        they are re-applied in a ``finally``. The model's ``cond_stage_model`` and
        ``first_stage_model`` are moved back to ``devices.device`` afterwards
        (presumably because training may move them elsewhere — TODO confirm).

        Any exception raised by ``train_hypernetwork`` itself is captured and
        reported in ``info`` instead of propagating. An AssertionError raised
        outside the training call yields an error response. ``shared.state.end()``
        runs exactly once on every path.

        Args:
            args: keyword arguments forwarded verbatim to the module-level
                ``train_hypernetwork`` helper.
        """
        try:
            shared.state.begin()
            shared.loaded_hypernetworks = []
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_hypernetwork(**args)
            except Exception as e:
                error = e
            finally:
                shared.sd_model.cond_stage_model.to(devices.device)
                shared.sd_model.first_stage_model.to(devices.device)
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
            return TrainResponse(info=f"train hypernetwork complete: filename: {filename} error: {error}")
        except AssertionError as msg:
            # report the assertion itself; `error` may not even be bound at this point
            return TrainResponse(info=f"train hypernetwork error: {msg}")
        finally:
            # single place that closes the job, so it can neither be skipped nor doubled
            shared.state.end()
V
Vladimir Mandic 已提交
648

V
Vladimir Mandic 已提交
649 650 651 652
    def get_memory(self):
        """Report this process's RAM usage and CUDA allocator statistics.

        The RAM and CUDA sections are probed independently. If a probe raises,
        that section becomes ``{'error': '<message>'}`` instead of failing the
        whole request. When CUDA is not available, the CUDA section is
        ``{'error': 'unavailable'}``.

        Returns:
            MemoryResponse built from the ``ram`` and ``cuda`` dictionaries.
        """
        # Host RAM: only rss is relied on (it is the cross-platform field), and the
        # total is derived from rss and memory_percent() rather than read directly.
        try:
            import os, psutil
            proc = psutil.Process(os.getpid())
            mem_info = proc.memory_info()
            total_bytes = 100 * mem_info.rss / proc.memory_percent()
            ram = {'free': total_bytes - mem_info.rss, 'used': mem_info.rss, 'total': total_bytes}
        except Exception as err:
            ram = {'error': f'{err}'}

        # CUDA: device-wide free/total from mem_get_info plus caching-allocator counters.
        try:
            import torch
            if not torch.cuda.is_available():
                cuda = {'error': 'unavailable'}
            else:
                free_and_total = torch.cuda.mem_get_info()
                device_free, device_total = free_and_total[0], free_and_total[1]
                counters = dict(torch.cuda.memory_stats(shared.device))

                def current_and_peak(prefix):
                    # the four counters read here share the '<prefix>.all.current' / '.all.peak' key layout
                    return {'current': counters[prefix + '.all.current'], 'peak': counters[prefix + '.all.peak']}

                # evaluated in this order so the first missing key decides the reported error
                allocated = current_and_peak('allocated_bytes')
                reserved = current_and_peak('reserved_bytes')
                active = current_and_peak('active_bytes')
                inactive = current_and_peak('inactive_split_bytes')
                events = {'retries': counters['num_alloc_retries'], 'oom': counters['num_ooms']}

                cuda = dict(
                    system={'free': device_free, 'used': device_total - device_free, 'total': device_total},
                    active=active,
                    allocated=allocated,
                    reserved=reserved,
                    inactive=inactive,
                    events=events,
                )
        except Exception as err:
            cuda = {'error': f'{err}'}
        return MemoryResponse(ram=ram, cuda=cuda)

683
    def launch(self, server_name, port):
        """Attach this API's router to the FastAPI app and serve it via ``uvicorn.run``.

        Args:
            server_name: host/interface passed to uvicorn as ``host``.
            port: TCP port passed to uvicorn.
        """
        app = self.app
        app.include_router(self.router)
        uvicorn.run(app, port=port, host=server_name)