api.py 24.6 KB
Newer Older
1 2
import base64
import io
3
import time
V
Vladimir Mandic 已提交
4
import datetime
5
import uvicorn
B
Bruno Seoane 已提交
6
from threading import Lock
S
Sena 已提交
7
from io import BytesIO
S
Sena 已提交
8
from gradio.processing_utils import decode_base64_to_file
V
Vladimir Mandic 已提交
9
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
10 11 12
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

13
import modules.shared as shared
14
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
15
from modules.api.models import *
16
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
V
Vladimir Mandic 已提交
17 18 19
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
S
Sena 已提交
20
from PIL import PngImagePlugin,Image
21 22
from modules.sd_models import checkpoints_list
from modules.sd_models_config import find_checkpoint_config_near_filename
B
Bruno Seoane 已提交
23
from modules.realesrgan_model import get_realesrgan_models
V
Vladimir Mandic 已提交
24
from modules import devices
B
Bruno Seoane 已提交
25
from typing import List
V
Vladimir Mandic 已提交
26 27
import piexif
import piexif.helper
A
arcticfaded 已提交
28

B
Bruno Seoane 已提交
29 30 31 32
def upscaler_to_index(name: str):
    """Return the position of the upscaler named ``name`` in ``shared.sd_upscalers``.

    The comparison is case-insensitive.

    Raises:
        HTTPException: 400, listing the valid upscaler names, when ``name`` is unknown
            (or is not a string).
    """
    # Take the names from the same list used for the lookup, so the error message lists
    # exactly the upscalers that were searched. A bare ``sd_upscalers`` would depend on
    # whatever the module-level star import happened to bind.
    names = [x.name for x in shared.sd_upscalers]
    try:
        return [n.lower() for n in names].index(name.lower())
    except (AttributeError, ValueError):
        # ValueError: no match. AttributeError: ``name`` is not a str. Both are client errors.
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join(names)}") from None
34

N
noodleanon 已提交
35 36 37 38 39
def script_name_to_index(name, scripts):
    """Return the index of the first script whose ``title()`` matches ``name``.

    The comparison is case-insensitive.

    Args:
        name: script title as sent by the client.
        scripts: sequence of script objects exposing a ``title()`` method. Note that inside
            this function the parameter shadows the module-level ``scripts`` import.

    Raises:
        HTTPException: 422 when no script has that title.
    """
    wanted = name.lower()
    # Explicit search instead of ``list.index`` inside a bare ``except:``. Errors raised
    # by a script's ``title()`` now surface instead of being reported as "not found".
    for index, script in enumerate(scripts):
        if script.title().lower() == wanted:
            return index
    raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
40

41 42 43 44
def validate_sampler_name(name):
    """Ensure ``name`` is a registered sampler and hand it back unchanged.

    A sampler counts as registered when ``sd_samplers.all_samplers_map`` maps
    ``name`` to a non-None config.

    Raises:
        HTTPException: 404 when the sampler is not registered.
    """
    if sd_samplers.all_samplers_map.get(name) is not None:
        return name
    raise HTTPException(status_code=404, detail="Sampler not found")
47

B
Bruno Seoane 已提交
48 49
def setUpscalers(req: object) -> dict:
    """Rename a request's ``upscaler_1``/``upscaler_2`` to ``extras_upscaler_1``/``extras_upscaler_2``.

    The renamed keys are the names the callers in this module pass on to
    ``postprocessing.run_extras(**...)``. Missing keys become ``None``.

    ``req`` must be an object with a ``__dict__`` (such as the request models used here),
    not a plain ``dict``: ``vars()`` raises ``TypeError`` for a dict. Because ``vars()``
    returns the object's own ``__dict__``, the rename happens in place on ``req``, and the
    returned mapping *is* that ``__dict__``.
    """
    reqDict = vars(req)
    reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
    reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
    return reqDict
R
Roy Shilkrot 已提交
53

S
Sena 已提交
54 55 56
def decode_base64_to_image(encoding):
    """Decode a base64 string, or a ``data:image/...`` URL, into a PIL image.

    Raises:
        HTTPException: 500 "Invalid encoded image" when the payload cannot be decoded.
            This also covers malformed data-URL headers, which previously escaped as a
            bare IndexError.
    """
    try:
        if encoding.startswith("data:image/"):
            # The payload follows the first comma, whatever parameters the header carries
            # (e.g. "data:image/png;base64,..." or "data:image/png;charset=x;base64,...").
            encoding = encoding.split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as err:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from err
62

63
def encode_pil_to_base64(image):
    """Serialize a PIL image in the format given by ``shared.opts.samples_format`` and base64-encode it.

    Format details:
    - PNG: the image's string-valued ``info`` entries are written as text chunks.
    - JPEG/WEBP: ``info['parameters']`` (or an empty string) is stored as an EXIF
      UserComment, with quality ``shared.opts.jpeg_quality``.

    Returns:
        bytes: base64-encoded image data.

    Raises:
        HTTPException: 500 when ``samples_format`` is not png/jpg/jpeg/webp.
    """
    samples_format = shared.opts.samples_format.lower()

    with io.BytesIO() as output_bytes:
        if samples_format == 'png':
            use_metadata = False
            metadata = PngImagePlugin.PngInfo()
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None))

        elif samples_format in ("jpg", "jpeg", "webp"):
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
            })
            if samples_format in ("jpg", "jpeg"):
                # JPEG cannot store alpha or palettes; Pillow raises OSError for these modes,
                # so flatten to RGB first. This only rebinds the local name, not the caller's image.
                if image.mode in ("RGBA", "LA", "P"):
                    image = image.convert("RGB")
                image.save(output_bytes, format="JPEG", exif=exif_bytes, quality=shared.opts.jpeg_quality)
            else:
                image.save(output_bytes, format="WEBP", exif=exif_bytes, quality=shared.opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = output_bytes.getvalue()

    return base64.b64encode(bytes_data)
91

V
Vladimir Mandic 已提交
92
def api_middleware(app: FastAPI):
    """Install an HTTP middleware on ``app`` that times every request.

    The elapsed seconds (rounded to 4 places) are returned in the ``X-Process-Time``
    response header. When ``shared.cmd_opts.api_log`` is set, each request whose path
    starts with ``/sdapi`` is also printed as a single log line.
    """
    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        # perf_counter is monotonic, so the measurement is immune to wall-clock adjustments.
        ts = time.perf_counter()
        res: Response = await call_next(req)
        duration = str(round(time.perf_counter() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            # The ASGI spec allows scope["client"] to be present but None. Fall back to the
            # placeholder instead of failing the whole request on ``None[0]``.
            client = req.scope.get('client') or ('0:0.0.0', 0)
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code=res.status_code,
                ver=req.scope.get('http_version', '0.0'),
                cli=client[0],
                prot=req.scope.get('scheme', 'err'),
                method=req.scope.get('method', 'err'),
                endpoint=endpoint,
                duration=duration,
            ))
        return res

113

114
class Api:
    """HTTP API for the webui, with every route under ``/sdapi/v1``.

    Routes are registered directly on the FastAPI ``app`` given to the constructor.
    When ``shared.cmd_opts.api_auth`` is set (comma-separated ``user:password`` pairs),
    every route additionally requires matching HTTP Basic credentials. Work that
    touches the model runs while holding ``queue_lock``.
    """

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Parse API credentials, install the timing middleware and register every route.

        Args:
            app: FastAPI application that receives the middleware and routes.
            queue_lock: lock held around model work. Presumably shared with the web UI
                so API and UI jobs do not overlap — TODO confirm at the call site.
        """
        if shared.cmd_opts.api_auth:
            self.credentials = dict()
            for auth in shared.cmd_opts.api_auth.split(","):
                # Split on the first ':' only, so passwords may themselves contain ':'.
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register ``endpoint`` on the app, adding the Basic-auth dependency when API auth is configured."""
        if shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
        return self.app.add_api_route(path, endpoint, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """FastAPI dependency that checks HTTP Basic credentials against the configured users.

        Returns:
            True when the username exists and the password matches.

        Raises:
            HTTPException: 401 with a ``WWW-Authenticate: Basic`` header otherwise.
        """
        if credentials.username in self.credentials:
            # compare_digest only accepts ASCII ``str``. Compare UTF-8 bytes so that
            # non-ASCII passwords are checked instead of raising TypeError.
            expected = self.credentials[credentials.username].encode("utf-8")
            if compare_digest(credentials.password.encode("utf-8"), expected):
                return True

        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

    def get_script(self, script_name, script_runner):
        """Resolve ``script_name`` to ``(script, index)`` within ``script_runner.selectable_scripts``.

        Returns ``(None, None)`` when no script was requested. If the runner has no
        scripts loaded yet, it is initialized and ``ui.create_ui()`` is called first.
        Presumably this is so the scripts' argument slots get set up — TODO confirm.

        Raises:
            HTTPException: 422 (from ``script_name_to_index``) for an unknown script.
        """
        if script_name is None:
            return None, None

        if not script_runner.scripts:
            script_runner.initialize_scripts(False)
            ui.create_ui()

        script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
        script = script_runner.selectable_scripts[script_idx]
        return script, script_idx

    def _run_processing(self, p, script, script_idx, script_runner, outpath_grids, outpath_samples):
        """Run ``p`` through the selected ``script``, or through ``process_images`` when no script was chosen.

        The caller must hold ``self.queue_lock``. ``shared.state`` is always ended, even
        when processing raises, so a failed request cannot leave a job marked as running.
        """
        shared.state.begin()
        try:
            if script is None:
                return process_images(p)
            p.outpath_grids = outpath_grids
            p.outpath_samples = outpath_samples
            # Positional script args: slot 0 holds script_idx + 1, then None padding up to
            # this script's ``args_from`` offset, then the client's own script args.
            p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
            return script_runner.run(p, *p.script_args)
        finally:
            shared.state.end()

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Generate images from a text prompt and return them base64-encoded along with the request and generation info."""
        script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)

        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('script_name', None)

        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
            processed = self._run_processing(
                p, script, script_idx, scripts.scripts_txt2img,
                shared.opts.outdir_txt2img_grids, shared.opts.outdir_txt2img_samples,
            )

        b64images = list(map(encode_pil_to_base64, processed.images))

        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Generate images from base64 init images (plus an optional mask) and return them base64-encoded.

        Raises:
            HTTPException: 404 when ``init_images`` is missing.
        """
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        # ``include_init_images`` only controls the response, not processing. Ideally the
        # model would exclude it ("exclude": True), but that did not take effect, so drop it here.
        args.pop('include_init_images', None)
        args.pop('script_name', None)

        with self.queue_lock:
            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
            p.init_images = [decode_base64_to_image(x) for x in init_images]
            processed = self._run_processing(
                p, script, script_idx, scripts.scripts_img2img,
                shared.opts.outdir_img2img_grids, shared.opts.outdir_img2img_samples,
            )

        b64images = list(map(encode_pil_to_base64, processed.images))

        # Keep the echoed parameters small unless the client asked for its inputs back.
        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())

    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
        """Run the extras (postprocessing) pipeline on one base64 image and return the result base64-encoded."""
        reqDict = setUpscalers(req)
        reqDict['image'] = decode_base64_to_image(reqDict['image'])

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        """Run the extras (postprocessing) pipeline on a list of base64 files and return every result base64-encoded."""
        reqDict = setUpscalers(req)

        def prepareFiles(file):
            # Materialize the uploaded payload as a file object for run_extras' image_folder input.
            file = decode_base64_to_file(file.data, file_path=file.name)
            file.orig_name = file.name
            return file

        reqDict['image_folder'] = [prepareFiles(file) for file in reqDict.pop('imageList')]

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

    def pnginfoapi(self, req: PNGInfoRequest):
        """Extract generation parameters and other metadata from a base64 image.

        A blank image string yields an empty ``info``.
        """
        if not req.image.strip():
            return PNGInfoResponse(info="")

        image = decode_base64_to_image(req.image.strip())
        if image is None:
            return PNGInfoResponse(info="")

        geninfo, items = images.read_info_from_image(image)
        if geninfo is None:
            geninfo = ""

        items = {**{'parameters': geninfo}, **items}

        return PNGInfoResponse(info=geninfo, items=items)

    def progressapi(self, req: ProgressRequest = Depends()):
        """Report current job progress (0..1), a relative ETA in seconds, and optionally the live preview image."""
        # copy from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

        # Start slightly above zero so ``time_since_start / progress`` never divides by zero.
        progress = 0.01

        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        time_since_start = time.time() - shared.state.time_start
        eta = (time_since_start / progress)
        eta_relative = eta - time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)

    def interrogateapi(self, interrogatereq: InterrogateRequest):
        """Caption a base64 image with the CLIP interrogator or DeepDanbooru.

        Raises:
            HTTPException: 404 when the image is missing or the model name is unknown.
        """
        image_b64 = interrogatereq.image
        if image_b64 is None:
            raise HTTPException(status_code=404, detail="Image not found")

        img = decode_base64_to_image(image_b64)
        img = img.convert('RGB')

        # Override object param
        with self.queue_lock:
            if interrogatereq.model == "clip":
                processed = shared.interrogator.interrogate(img)
            elif interrogatereq.model == "deepdanbooru":
                processed = deepbooru.model.tag(img)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return InterrogateResponse(caption=processed)

    def interruptapi(self):
        """Request interruption of the running job via ``shared.state.interrupt()``."""
        shared.state.interrupt()

        return {}

    def skip(self):
        """Request skipping of the current item via ``shared.state.skip()``."""
        shared.state.skip()

    def get_config(self):
        """Return a shallow copy of every option stored in ``shared.opts.data``.

        Only keys already present in ``data`` are reported, so their stored values are
        always used. The registered defaults in ``data_labels`` never apply here.
        """
        return dict(shared.opts.data)

    def set_config(self, req: Dict[str, Any]):
        """Apply each ``key: value`` pair to ``shared.opts`` and persist the options to ``shared.config_filename``."""
        for k, v in req.items():
            shared.opts.set(k, v)

        shared.opts.save(shared.config_filename)
        return

    def get_cmd_flags(self):
        """Return the parsed command-line options as a dict."""
        return vars(shared.cmd_opts)

    def get_samplers(self):
        """List every sampler with its aliases and options."""
        return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]

    def get_upscalers(self):
        """List the loaded upscalers with model name, path and scale (``model_url`` is always None)."""
        return [
            {
                "name": upscaler.name,
                "model_name": upscaler.scaler.model_name,
                "model_path": upscaler.data_path,
                "model_url": None,
                "scale": upscaler.scale,
            }
            for upscaler in shared.sd_upscalers
        ]

    def get_sd_models(self):
        """List known checkpoints with their hashes, filename and any config file found next to them."""
        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

    def get_hypernetworks(self):
        """List available hypernetworks as name/path pairs."""
        return [{"name": name, "path": path} for name, path in shared.hypernetworks.items()]

    def get_face_restorers(self):
        """List face restorers with their ``cmd_dir`` attribute when present."""
        return [{"name": x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]

    def get_realesrgan_models(self):
        """List RealESRGAN models with their path and scale."""
        return [{"name": x.name, "path": x.data_path, "scale": x.scale} for x in get_realesrgan_models(None)]

    def get_prompt_styles(self):
        """List saved prompt styles; each style is read positionally as (name, prompt, negative_prompt)."""
        return [
            {"name": style[0], "prompt": style[1], "negative_prompt": style[2]}
            for style in shared.prompt_styles.styles.values()
        ]

    def get_embeddings(self):
        """Describe loaded and skipped textual-inversion embeddings, keyed by embedding name."""
        db = sd_hijack.model_hijack.embedding_db

        def convert_embedding(embedding):
            return {
                "step": embedding.step,
                "sd_checkpoint": embedding.sd_checkpoint,
                "sd_checkpoint_name": embedding.sd_checkpoint_name,
                "shape": embedding.shape,
                "vectors": embedding.vectors,
            }

        def convert_embeddings(embeddings):
            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}

        return {
            "loaded": convert_embeddings(db.word_embeddings),
            "skipped": convert_embeddings(db.skipped_embeddings),
        }

    def refresh_checkpoints(self):
        """Rescan the checkpoint directory via ``shared.refresh_checkpoints()``."""
        shared.refresh_checkpoints()

    def create_embedding(self, args: dict):
        """Create an empty embedding, then reload embeddings so it is usable right away.

        Assertion failures are reported in ``info`` rather than raised.
        """
        shared.state.begin()
        try:
            filename = create_embedding(**args)  # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()  # reload embeddings so new one can be immediately used
            return CreateResponse(info="create embedding filename: {filename}".format(filename=filename))
        except AssertionError as e:
            # Match the route's response_model (CreateResponse), not TrainResponse.
            return CreateResponse(info="create embedding error: {error}".format(error=e))
        finally:
            shared.state.end()

    def create_hypernetwork(self, args: dict):
        """Create an empty hypernetwork.

        Assertion failures are reported in ``info`` rather than raised.
        """
        shared.state.begin()
        try:
            filename = create_hypernetwork(**args)  # create empty hypernetwork
            return CreateResponse(info="create hypernetwork filename: {filename}".format(filename=filename))
        except AssertionError as e:
            # Match the route's response_model (CreateResponse), not TrainResponse.
            return CreateResponse(info="create hypernetwork error: {error}".format(error=e))
        finally:
            shared.state.end()

    def preprocess(self, args: dict):
        """Run dataset preprocessing.

        Known failures (KeyError, AssertionError, FileNotFoundError) are reported in
        ``info`` rather than raised.
        """
        shared.state.begin()
        try:
            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
            return PreprocessResponse(info='preprocess complete')
        except KeyError as e:
            return PreprocessResponse(info="preprocess error: invalid token: {error}".format(error=e))
        except (AssertionError, FileNotFoundError) as e:
            return PreprocessResponse(info="preprocess error: {error}".format(error=e))
        finally:
            shared.state.end()

    def train_embedding(self, args: dict):
        """Train a textual-inversion embedding; this can take a long time.

        Errors raised by training itself are reported in ``info`` together with the
        filename. Cross-attention optimizations are disabled during training unless
        ``training_xattention_optimizations`` is set, and are always restored afterwards.
        """
        shared.state.begin()
        try:
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_embedding(**args)  # can take a long time to complete
            except Exception as e:
                error = e
            finally:
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
            return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            return TrainResponse(info="train embedding error: {msg}".format(msg=msg))
        finally:
            shared.state.end()

    def train_hypernetwork(self, args: dict):
        """Train a hypernetwork; this can take a long time.

        Errors raised by training itself are reported in ``info`` together with the
        filename. After training, the model's cond/first stage models are moved back to
        ``devices.device`` and cross-attention optimizations are restored if they were
        disabled.
        """
        shared.state.begin()
        try:
            shared.loaded_hypernetworks = []
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_hypernetwork(**args)
            except Exception as e:
                error = e
            finally:
                # Presumably training may leave these submodels on another device
                # (e.g. offloaded to save VRAM) — TODO confirm; move them back.
                shared.sd_model.cond_stage_model.to(devices.device)
                shared.sd_model.first_stage_model.to(devices.device)
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
            return TrainResponse(info="train hypernetwork complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            # Report the caught assertion. The previous code formatted ``error`` here,
            # which is None (or unbound) on this path.
            return TrainResponse(info="train hypernetwork error: {msg}".format(msg=msg))
        finally:
            shared.state.end()

    def get_memory(self):
        """Report process RAM usage and, when CUDA is available, torch CUDA memory statistics.

        Any failure is reported as an ``error`` entry in the corresponding section.
        """
        try:
            import os, psutil
            process = psutil.Process(os.getpid())
            res = process.memory_info()  # only rss is cross-platform guaranteed so we dont rely on other values
            ram_total = 100 * res.rss / process.memory_percent()  # and total memory is calculated as actual value is not cross-platform safe
            ram = {'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total}
        except Exception as err:
            ram = {'error': f'{err}'}
        try:
            import torch
            if torch.cuda.is_available():
                s = torch.cuda.mem_get_info()
                system = {'free': s[0], 'used': s[1] - s[0], 'total': s[1]}
                s = dict(torch.cuda.memory_stats(shared.device))
                allocated = {'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak']}
                reserved = {'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak']}
                active = {'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak']}
                inactive = {'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak']}
                warnings = {'retries': s['num_alloc_retries'], 'oom': s['num_ooms']}
                cuda = {
                    'system': system,
                    'active': active,
                    'allocated': allocated,
                    'reserved': reserved,
                    'inactive': inactive,
                    'events': warnings,
                }
            else:
                cuda = {'error': 'unavailable'}
        except Exception as err:
            cuda = {'error': f'{err}'}
        return MemoryResponse(ram=ram, cuda=cuda)

    def launch(self, server_name, port):
        """Attach ``self.router`` and serve the app with uvicorn (blocks until the server stops)."""
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)