api.py 28.2 KB
Newer Older
1 2
import base64
import io
3
import time
V
Vladimir Mandic 已提交
4
import datetime
5
import uvicorn
B
Bruno Seoane 已提交
6
from threading import Lock
S
Sena 已提交
7
from io import BytesIO
S
Sena 已提交
8
from gradio.processing_utils import decode_base64_to_file
V
Vladimir Mandic 已提交
9
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
10 11 12
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

13
import modules.shared as shared
14
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
15
from modules.api.models import *
16
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
V
Vladimir Mandic 已提交
17 18 19
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
S
Sena 已提交
20
from PIL import PngImagePlugin,Image
21 22
from modules.sd_models import checkpoints_list
from modules.sd_models_config import find_checkpoint_config_near_filename
B
Bruno Seoane 已提交
23
from modules.realesrgan_model import get_realesrgan_models
V
Vladimir Mandic 已提交
24
from modules import devices
B
Bruno Seoane 已提交
25
from typing import List
V
Vladimir Mandic 已提交
26 27
import piexif
import piexif.helper
A
arcticfaded 已提交
28

B
Bruno Seoane 已提交
29 30 31 32
def upscaler_to_index(name: str):
    """Return the index of the upscaler called *name* in ``shared.sd_upscalers``.

    The comparison is case-insensitive.

    Raises:
        HTTPException: 400 when *name* is not a string or matches no upscaler;
            the detail lists the valid upscaler names.
    """
    # Always read the live list from `shared`; a name imported by value elsewhere
    # would keep pointing at the old list if `shared.sd_upscalers` is rebound.
    names = [x.name for x in shared.sd_upscalers]
    lowered = [n.lower() for n in names]
    key = name.lower() if isinstance(name, str) else None
    if key not in lowered:
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join(names)}")
    return lowered.index(key)


def script_name_to_index(name, scripts):
    """Return the position in *scripts* of the script whose title matches *name*.

    Titles are obtained from each script's ``title()`` and compared
    case-insensitively.

    Raises:
        HTTPException: 422 when *name* is not a string or no script has that title.
    """
    titles = [script.title().lower() for script in scripts]
    key = name.lower() if isinstance(name, str) else None
    if key not in titles:
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
    return titles.index(key)


def validate_sampler_name(name):
    """Check that *name* is a registered sampler and hand it back unchanged.

    Raises:
        HTTPException: 404 ("Sampler not found") when
            ``sd_samplers.all_samplers_map`` has no entry for *name*.
    """
    known_samplers = sd_samplers.all_samplers_map
    if known_samplers.get(name, None) is None:
        raise HTTPException(status_code=404, detail="Sampler not found")
    return name


def setUpscalers(req) -> dict:
    """Build keyword arguments for ``postprocessing.run_extras`` from an extras request.

    *req* is an object whose attributes are the request fields (``vars(req)`` must
    work, e.g. a pydantic model). Its ``upscaler_1`` / ``upscaler_2`` fields are
    renamed to ``extras_upscaler_1`` / ``extras_upscaler_2``. A missing field
    becomes ``None``.

    A shallow copy of the attributes is used, so *req* itself is left unmodified.

    Returns:
        A new dict of the request's fields with the upscaler keys renamed.
    """
    # vars() returns the object's own __dict__; copy it so the pops below do not
    # strip fields off the caller's request object.
    reqDict = dict(vars(req))
    reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
    reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
    return reqDict


def decode_base64_to_image(encoding):
    """Decode a base64 string, optionally wrapped as a ``data:image/...`` URL, into a PIL image.

    For a data URL, everything after the first ``,`` is taken as the base64
    payload. That covers ``data:image/png;base64,<payload>`` and variants that
    carry extra ``;param=value`` parts before it.

    Raises:
        HTTPException: 500 ("Invalid encoded image") when the input cannot be
            split, base64-decoded or opened as an image.
    """
    try:
        if encoding.startswith("data:image/"):
            # Parsing is inside the try so a malformed data URL is reported the
            # same way as any other undecodable payload.
            encoding = encoding.split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as err:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from err


def encode_pil_to_base64(image):
    """Serialize a PIL image in the configured sample format and return it base64-encoded.

    The container is chosen by ``shared.opts.samples_format`` (case-insensitive):

    * ``png``: every ``image.info`` entry whose key and value are both strings
      is written as a PNG text chunk.
    * ``jpg`` / ``jpeg`` / ``webp``: ``image.info['parameters']`` (or ``""``) is
      stored in the EXIF UserComment tag, and ``shared.opts.jpeg_quality`` sets
      the quality.

    Returns:
        The base64 encoding of the saved file, as ``bytes``.

    Raises:
        HTTPException: 500 ("Invalid image format") for any other format.
    """
    image_format = shared.opts.samples_format.lower()

    with io.BytesIO() as output_bytes:
        if image_format == 'png':
            metadata = PngImagePlugin.PngInfo()
            use_metadata = False
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None))

        elif image_format in ("jpg", "jpeg", "webp"):
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
            })
            pil_format = "WEBP" if image_format == "webp" else "JPEG"
            image.save(output_bytes, format=pil_format, exif=exif_bytes, quality=shared.opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = output_bytes.getvalue()

    return base64.b64encode(bytes_data)


def api_middleware(app: FastAPI):
    """Attach an HTTP middleware to *app* that times, and optionally logs, every request.

    The elapsed wall-clock time, in seconds rounded to 4 places, is added to
    each response as the ``X-Process-Time`` header. When
    ``shared.cmd_opts.api_log`` is set, requests whose path starts with
    ``/sdapi`` are also printed to stdout as a single line.
    """
    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            # scope["client"] may be present but None (the ASGI spec makes it
            # optional), so fall back to the placeholder. Indexing None would
            # fail the request while merely logging it.
            client = req.scope.get('client') or ('0:0.0.0', 0)
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code=res.status_code,
                ver=req.scope.get('http_version', '0.0'),
                cli=client[0],
                prot=req.scope.get('scheme', 'err'),
                method=req.scope.get('method', 'err'),
                endpoint=endpoint,
                duration=duration,
            ))
        return res


class Api:
    """HTTP API of the web UI, served with FastAPI.

    The constructor registers every ``/sdapi/v1/*`` route on the given app. The
    routes sit behind HTTP Basic auth when ``shared.cmd_opts.api_auth`` is set.
    Handlers that build or run processing jobs, postprocessing or interrogation
    do so while holding ``queue_lock``.
    """

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Register the logging middleware and all API routes on *app*.

        Args:
            app: the FastAPI application to extend.
            queue_lock: lock held by handlers while they run model work.
        """
        if shared.cmd_opts.api_auth:
            # api_auth is a comma-separated list of "user:password" pairs. Split
            # each pair on the first ':' only, so a password may contain ':'.
            self.credentials = {}
            for auth in shared.cmd_opts.api_auth.split(","):
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register *endpoint* on the app, adding the Basic-auth dependency when api_auth is configured."""
        if shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
        return self.app.add_api_route(path, endpoint, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """FastAPI dependency that checks HTTP Basic credentials against ``self.credentials``.

        Returns:
            True when the username is known and the password matches.

        Raises:
            HTTPException: 401 with a ``WWW-Authenticate: Basic`` header otherwise.
        """
        expected = self.credentials.get(credentials.username)
        # compare_digest rejects non-ASCII str arguments with TypeError; comparing
        # the UTF-8 bytes lets non-ASCII passwords be checked instead of erroring.
        if expected is not None and compare_digest(credentials.password.encode("utf-8"), expected.encode("utf-8")):
            return True

        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

    def get_selectable_script(self, script_name, script_runner):
        """Look up a selectable script by title.

        Returns:
            ``(script, index)`` from ``script_runner.selectable_scripts``, or
            ``(None, None)`` when *script_name* is None or empty.
        """
        if script_name is None or script_name == "":
            return None, None

        script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
        script = script_runner.selectable_scripts[script_idx]
        return script, script_idx

    def get_script(self, script_name, script_runner):
        """Look up any script (selectable or always-on) in ``script_runner.scripts`` by title.

        Returns:
            The script object, or None when *script_name* is None or empty.

        Raises:
            HTTPException: 422 (from script_name_to_index) for an unknown title.
        """
        if script_name is None or script_name == "":
            # A single None (not a tuple) so callers' `is None` checks work.
            return None

        script_idx = script_name_to_index(script_name, script_runner.scripts)
        return script_runner.scripts[script_idx]

    def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
        """Build the flat positional argument list handed to the script runner.

        Each script owns the slice ``[args_from:args_to]`` of the list. Slot 0
        holds ``selectable_idx + 1`` for the selected script, or 0 when none is
        selected. Unfilled positions stay None.

        Raises:
            HTTPException: 422 for inconsistent always-on script names/args, an
                unknown always-on script, or a selectable script listed as always-on.
        """
        # The list must be long enough for the highest args_to of any script (minimum 1).
        last_arg_index = max([1] + [script.args_to for script in script_runner.scripts])
        script_args = [None] * last_arg_index

        # position 0 is the idx+1 of the selectable script run by scripts.scripts_*2img.run(); 0 means none
        if selectable_scripts:
            script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
            script_args[0] = selectable_idx + 1
        else:
            script_args[0] = 0

        # Always-on scripts run even without arguments; a client that names one
        # must still send an arg list for it (an empty list keeps its defaults).
        if request.alwayson_script_name:
            if not request.alwayson_script_args:
                raise HTTPException(status_code=422, detail=f"Script {request.alwayson_script_name} has no arg list")
            if len(request.alwayson_script_name) != len(request.alwayson_script_args):
                raise HTTPException(status_code=422, detail="Number of script names and number of script arg lists doesn't match")

            for alwayson_script_name, alwayson_script_args in zip(request.alwayson_script_name, request.alwayson_script_args):
                alwayson_script = self.get_script(alwayson_script_name, script_runner)
                if alwayson_script is None:
                    raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                if not alwayson_script.alwayson:
                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                if alwayson_script_args:
                    script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args
        return script_args

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Run txt2img for the request and return the generated images base64-encoded.

        Samples and grids are never saved to disk by this endpoint.
        """
        script_runner = scripts.scripts_txt2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(False)
            ui.create_ui()
        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True,
        })

        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        # Script fields are not processing-constructor arguments; the script args
        # are rebuilt by init_script_args and handed to the pipeline separately.
        for key in ('script_name', 'script_args', 'alwayson_script_name', 'alwayson_script_args'):
            args.pop(key, None)

        script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)

        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
            p.scripts = script_runner

            shared.state.begin()
            try:
                if selectable_scripts is not None:
                    p.script_args = script_args
                    p.outpath_grids = shared.opts.outdir_txt2img_grids
                    p.outpath_samples = shared.opts.outdir_txt2img_samples
                    processed = script_runner.run(p, *p.script_args)  # Need to pass args as list here
                else:
                    p.script_args = tuple(script_args)  # Need to pass args as tuple here
                    processed = process_images(p)
            finally:
                # Reset the shared job state even when generation raises.
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))

        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Run img2img on the request's base64 init images (and optional mask) and return base64 results.

        Unless ``include_init_images`` is set, the init images and mask are
        blanked out of the echoed ``parameters``.

        Raises:
            HTTPException: 404 when ``init_images`` is missing.
        """
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        script_runner = scripts.scripts_img2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(True)
            ui.create_ui()
        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask,
        })

        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        # Meant to be handled by "exclude": True on the model field, but that does not take effect here.
        args.pop('include_init_images', None)
        # Script args are rebuilt by init_script_args and handed to the pipeline separately.
        for key in ('script_name', 'script_args', 'alwayson_script_name', 'alwayson_script_args'):
            args.pop(key, None)

        script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)

        with self.queue_lock:
            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
            p.init_images = [decode_base64_to_image(x) for x in init_images]
            p.scripts = script_runner

            shared.state.begin()
            try:
                if selectable_scripts is not None:
                    p.script_args = script_args
                    p.outpath_grids = shared.opts.outdir_img2img_grids
                    p.outpath_samples = shared.opts.outdir_img2img_samples
                    processed = script_runner.run(p, *p.script_args)  # Need to pass args as list here
                else:
                    p.script_args = tuple(script_args)  # Need to pass args as tuple here
                    processed = process_images(p)
            finally:
                # Reset the shared job state even when generation raises.
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))

        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())

    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
        """Run the extras (postprocessing) tab on one base64 image and return it base64-encoded."""
        reqDict = setUpscalers(req)

        reqDict['image'] = decode_base64_to_image(reqDict['image'])

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        """Run the extras (postprocessing) tab on a list of base64 files and return all results base64-encoded."""
        reqDict = setUpscalers(req)

        def prepareFiles(file):
            # Decode to a temporary file object and expose its name as orig_name.
            file = decode_base64_to_file(file.data, file_path=file.name)
            file.orig_name = file.name
            return file

        reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
        reqDict.pop('imageList')

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

    def pnginfoapi(self, req: PNGInfoRequest):
        """Return the generation parameters and other metadata embedded in a base64 image."""
        if not req.image.strip():
            return PNGInfoResponse(info="")

        image = decode_base64_to_image(req.image.strip())
        if image is None:
            return PNGInfoResponse(info="")

        geninfo, items = images.read_info_from_image(image)
        if geninfo is None:
            geninfo = ""

        items = {**{'parameters': geninfo}, **items}

        return PNGInfoResponse(info=geninfo, items=items)

    def progressapi(self, req: ProgressRequest = Depends()):
        """Report progress and ETA of the current job, optionally with a base64 preview image.

        The calculation is copied from check_progress_call in ui.py.
        """
        if shared.state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

        # Start slightly above zero so the ETA division below cannot divide by zero.
        progress = 0.01

        # job_count is only known to be non-zero here. Derive fractions from it
        # only when it is positive: a negative count would make the per-step term
        # negative, and it could cancel the 0.01 offset exactly.
        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
            if shared.state.sampling_steps > 0:
                progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        time_since_start = time.time() - shared.state.time_start
        eta = time_since_start / progress
        eta_relative = eta - time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)

    def interrogateapi(self, interrogatereq: InterrogateRequest):
        """Caption a base64 image with the "clip" interrogator or "deepdanbooru" tagger.

        Raises:
            HTTPException: 404 when the image is missing or the model name is unknown.
        """
        image_b64 = interrogatereq.image
        if image_b64 is None:
            raise HTTPException(status_code=404, detail="Image not found")

        img = decode_base64_to_image(image_b64)
        img = img.convert('RGB')

        # Override object param
        with self.queue_lock:
            if interrogatereq.model == "clip":
                processed = shared.interrogator.interrogate(img)
            elif interrogatereq.model == "deepdanbooru":
                processed = deepbooru.model.tag(img)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return InterrogateResponse(caption=processed)

    def interruptapi(self):
        """Ask the current job to stop."""
        shared.state.interrupt()

        return {}

    def skip(self):
        """Ask the current job to skip ahead (delegates to shared.state.skip())."""
        shared.state.skip()

    def get_config(self):
        """Return every key stored in ``shared.opts.data`` with its value.

        For keys with a registered option label, that label's default is used as
        the lookup fallback.
        """
        options = {}
        for key in shared.opts.data.keys():
            metadata = shared.opts.data_labels.get(key)
            default = metadata.default if metadata is not None else None
            options[key] = shared.opts.data.get(key, default)

        return options

    def set_config(self, req: Dict[str, Any]):
        """Set each given option and persist the options to ``shared.config_filename``."""
        for k, v in req.items():
            shared.opts.set(k, v)

        shared.opts.save(shared.config_filename)
        return

    def get_cmd_flags(self):
        """Return the parsed command-line options as a dict."""
        return vars(shared.cmd_opts)

    def get_samplers(self):
        """List every sampler with its aliases and options."""
        return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]

    def get_upscalers(self):
        """List the available upscalers."""
        return [
            {
                "name": upscaler.name,
                "model_name": upscaler.scaler.model_name,
                "model_path": upscaler.data_path,
                "model_url": None,
                "scale": upscaler.scale,
            }
            for upscaler in shared.sd_upscalers
        ]

    def get_sd_models(self):
        """List the known checkpoints with their hashes and nearby config file."""
        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

    def get_hypernetworks(self):
        """List the hypernetworks known to ``shared.hypernetworks`` with their paths."""
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

    def get_face_restorers(self):
        """List the face restorers and, where present, their ``cmd_dir``."""
        return [{"name": x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]

    def get_realesrgan_models(self):
        """List the RealESRGAN models with their path and scale."""
        return [{"name": x.name, "path": x.data_path, "scale": x.scale} for x in get_realesrgan_models(None)]

    def get_prompt_styles(self):
        """List the saved prompt styles as name / prompt / negative_prompt records."""
        styles = shared.prompt_styles.styles
        return [{"name": styles[k][0], "prompt": styles[k][1], "negative_prompt": styles[k][2]} for k in styles]

    def get_embeddings(self):
        """Describe the loaded and the skipped textual-inversion embeddings, keyed by name."""
        db = sd_hijack.model_hijack.embedding_db

        def convert_embedding(embedding):
            return {
                "step": embedding.step,
                "sd_checkpoint": embedding.sd_checkpoint,
                "sd_checkpoint_name": embedding.sd_checkpoint_name,
                "shape": embedding.shape,
                "vectors": embedding.vectors,
            }

        def convert_embeddings(embeddings):
            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}

        return {
            "loaded": convert_embeddings(db.word_embeddings),
            "skipped": convert_embeddings(db.skipped_embeddings),
        }

    def refresh_checkpoints(self):
        """Rescan the checkpoint list."""
        shared.refresh_checkpoints()

    def create_embedding(self, args: dict):
        """Create an empty textual-inversion embedding, then reload embeddings so it is usable at once.

        Assertion failures from the creation step are reported in ``info``.
        """
        try:
            shared.state.begin()
            filename = create_embedding(**args)  # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()  # reload embeddings so new one can be immediately used
            return CreateResponse(info="create embedding filename: {filename}".format(filename=filename))
        except AssertionError as e:
            return CreateResponse(info="create embedding error: {error}".format(error=e))
        finally:
            shared.state.end()

    def create_hypernetwork(self, args: dict):
        """Create an empty hypernetwork; assertion failures are reported in ``info``."""
        try:
            shared.state.begin()
            filename = create_hypernetwork(**args)  # create empty hypernetwork
            return CreateResponse(info="create hypernetwork filename: {filename}".format(filename=filename))
        except AssertionError as e:
            return CreateResponse(info="create hypernetwork error: {error}".format(error=e))
        finally:
            shared.state.end()

    def preprocess(self, args: dict):
        """Run dataset preprocessing; known failure types are reported in ``info``."""
        try:
            shared.state.begin()
            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
            return PreprocessResponse(info='preprocess complete')
        except KeyError as e:
            return PreprocessResponse(info="preprocess error: invalid token: {error}".format(error=e))
        except (AssertionError, FileNotFoundError) as e:
            return PreprocessResponse(info="preprocess error: {error}".format(error=e))
        finally:
            shared.state.end()

    def train_embedding(self, args: dict):
        """Train a textual-inversion embedding (can take a long time).

        Cross-attention optimizations are undone for the run unless
        ``shared.opts.training_xattention_optimizations`` is set, and restored
        afterwards. Any training exception is reported in ``info``, not raised.
        """
        try:
            shared.state.begin()
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_embedding(**args)  # can take a long time to complete
            except Exception as e:
                error = e
            finally:
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            shared.state.end()
            return TrainResponse(info="train embedding error: {msg}".format(msg=msg))

    def train_hypernetwork(self, args: dict):
        """Train a hypernetwork (can take a long time).

        Loaded hypernetworks are cleared first. Cross-attention optimizations are
        handled as in train_embedding, and the model's cond/first stage models are
        moved back to ``devices.device`` afterwards. Any training exception is
        reported in ``info``, not raised.
        """
        try:
            shared.state.begin()
            shared.loaded_hypernetworks = []
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_hypernetwork(**args)
            except Exception as e:
                error = e
            finally:
                shared.sd_model.cond_stage_model.to(devices.device)
                shared.sd_model.first_stage_model.to(devices.device)
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return TrainResponse(info="train hypernetwork complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            shared.state.end()
            # Report the caught assertion; `error` may not be bound yet at this point.
            return TrainResponse(info="train hypernetwork error: {error}".format(error=msg))

    def get_memory(self):
        """Report process RAM and, when CUDA is available, CUDA memory statistics.

        Each section degrades to ``{'error': ...}`` on failure, so the endpoint
        itself does not fail.
        """
        try:
            import os
            import psutil
            process = psutil.Process(os.getpid())
            res = process.memory_info()  # only rss is cross-platform guaranteed so we dont rely on other values
            ram_total = 100 * res.rss / process.memory_percent()  # and total memory is calculated as actual value is not cross-platform safe
            ram = {'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total}
        except Exception as err:
            ram = {'error': f'{err}'}
        try:
            import torch
            if torch.cuda.is_available():
                s = torch.cuda.mem_get_info()
                system = {'free': s[0], 'used': s[1] - s[0], 'total': s[1]}
                s = dict(torch.cuda.memory_stats(shared.device))
                allocated = {'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak']}
                reserved = {'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak']}
                active = {'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak']}
                inactive = {'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak']}
                warnings = {'retries': s['num_alloc_retries'], 'oom': s['num_ooms']}
                cuda = {
                    'system': system,
                    'active': active,
                    'allocated': allocated,
                    'reserved': reserved,
                    'inactive': inactive,
                    'events': warnings,
                }
            else:
                cuda = {'error': 'unavailable'}
        except Exception as err:
            cuda = {'error': f'{err}'}
        return MemoryResponse(ram=ram, cuda=cuda)

    def launch(self, server_name, port):
        """Mount the router and serve the app with uvicorn (blocking)."""
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)