api.py 28.4 KB
Newer Older
1 2
import base64
import io
3
import time
V
Vladimir Mandic 已提交
4
import datetime
5
import uvicorn
B
Bruno Seoane 已提交
6
from threading import Lock
S
Sena 已提交
7
from io import BytesIO
S
Sena 已提交
8
from gradio.processing_utils import decode_base64_to_file
V
Vladimir Mandic 已提交
9
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
10 11 12
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

13
import modules.shared as shared
14
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
15
from modules.api.models import *
16
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
V
Vladimir Mandic 已提交
17 18 19
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
S
Sena 已提交
20
from PIL import PngImagePlugin,Image
21 22
from modules.sd_models import checkpoints_list
from modules.sd_models_config import find_checkpoint_config_near_filename
B
Bruno Seoane 已提交
23
from modules.realesrgan_model import get_realesrgan_models
V
Vladimir Mandic 已提交
24
from modules import devices
B
Bruno Seoane 已提交
25
from typing import List
V
Vladimir Mandic 已提交
26 27
import piexif
import piexif.helper
A
arcticfaded 已提交
28

B
Bruno Seoane 已提交
29 30 31 32
def upscaler_to_index(name: str):
    """Return the position of the upscaler called *name* in ``shared.sd_upscalers``.

    The comparison is case-insensitive.

    Raises:
        HTTPException: 400 when no registered upscaler has that name (or *name* is
            not a string); the detail lists the valid names.
    """
    try:
        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
    except (AttributeError, ValueError):
        # List the names from shared.sd_upscalers (the list searched above) rather than
        # a bare ``sd_upscalers`` global that is only reachable via a wildcard import.
        valid_names = ' , '.join(x.name for x in shared.sd_upscalers)
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {valid_names}") from None
34

N
noodleanon 已提交
35 36 37 38 39
def script_name_to_index(name, scripts):
    """Return the index of the entry in *scripts* whose ``title()`` matches *name*.

    The comparison is case-insensitive.

    Raises:
        HTTPException: 422 when no script has that title (or *name* is not a string).
    """
    try:
        return [script.title().lower() for script in scripts].index(name.lower())
    except (AttributeError, ValueError):
        # Only "not found" / bad-name failures become a 422; any other error raised by a
        # script's title() is a real bug and is no longer masked as "not found".
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from None
40

41 42 43 44
def validate_sampler_name(name):
    """Confirm *name* is a registered sampler and hand it back unchanged.

    Raises:
        HTTPException: 404 when ``sd_samplers.all_samplers_map`` has no entry
            (or a ``None`` entry) for *name*.
    """
    if sd_samplers.all_samplers_map.get(name) is not None:
        return name
    raise HTTPException(status_code=404, detail="Sampler not found")
47

B
Bruno Seoane 已提交
48 49
def setUpscalers(req):
    """Rename an extras request's upscaler fields to the names ``run_extras`` expects.

    *req* is a request object, not a dict: ``vars()`` is applied to it, so it must
    have a ``__dict__``. The mapping returned IS that ``__dict__``, so the renaming
    also mutates *req* in place.

    Returns:
        The request's attribute dict with ``upscaler_1``/``upscaler_2`` moved to
        ``extras_upscaler_1``/``extras_upscaler_2`` (``None`` when a field is absent).
    """
    reqDict = vars(req)
    for index in (1, 2):
        reqDict[f'extras_upscaler_{index}'] = reqDict.pop(f'upscaler_{index}', None)
    return reqDict
R
Roy Shilkrot 已提交
53

S
Sena 已提交
54 55 56
def decode_base64_to_image(encoding):
    """Decode a base64 string, optionally wrapped as a ``data:image/...`` URL, into a PIL image.

    Raises:
        HTTPException: 500 when the data URL is malformed, the payload is not valid
            base64, or PIL cannot open the decoded bytes as an image.
    """
    try:
        if encoding.startswith("data:image/"):
            # "data:image/png;base64,<payload>" -> "<payload>". Done inside the try so a
            # malformed prefix is reported like any other bad input, not as an IndexError.
            encoding = encoding.split(";")[1].split(",")[1]
        return Image.open(BytesIO(base64.b64decode(encoding)))
    except Exception as err:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from err
62

63
def encode_pil_to_base64(image):
    """Serialize *image* in the configured ``samples_format`` and base64-encode it.

    PNG output carries the image's string ``info`` entries as text chunks. JPEG and
    WEBP output carry ``info['parameters']`` (if any) as an EXIF UserComment and use
    ``jpeg_quality``.

    Returns:
        The base64-encoded image as ``bytes``.

    Raises:
        HTTPException: 500 when ``samples_format`` is not png, jpg, jpeg or webp.
    """
    # shared.opts instead of a bare ``opts`` that is only reachable via a wildcard import.
    image_format = shared.opts.samples_format.lower()

    with io.BytesIO() as output_bytes:
        if image_format == 'png':
            use_metadata = False
            metadata = PngImagePlugin.PngInfo()
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            # No ``quality=`` here: it has no effect on PNG output.
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None))

        elif image_format in ("jpg", "jpeg", "webp"):
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
            })
            if image_format == "webp":
                pil_format = "WEBP"
            else:
                pil_format = "JPEG"
                # JPEG cannot store alpha or palette modes; PIL raises on save without this.
                if image.mode in ("RGBA", "LA", "P"):
                    image = image.convert("RGB")
            image.save(output_bytes, format=pil_format, exif=exif_bytes, quality=shared.opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = output_bytes.getvalue()

    return base64.b64encode(bytes_data)
91

V
Vladimir Mandic 已提交
92
def api_middleware(app: FastAPI):
    """Register an HTTP middleware on *app* that times every request.

    Each response gets an ``X-Process-Time`` header (seconds, rounded to 4 places).
    When ``shared.cmd_opts.api_log`` is set, requests whose path starts with
    ``/sdapi`` are also printed as a single log line.
    """
    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            # ASGI allows scope["client"] to be present but None, so fall back on any
            # falsy value, not only on a missing key (indexing None would raise).
            client = req.scope.get('client') or ('0:0.0.0', 0)
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code = res.status_code,
                ver = req.scope.get('http_version', '0.0'),
                cli = client[0],
                prot = req.scope.get('scheme', 'err'),
                method = req.scope.get('method', 'err'),
                endpoint = endpoint,
                duration = duration,
            ))
        return res

113

114
class Api:
    """HTTP API for webui, registered on a FastAPI app under ``/sdapi/v1``.

    When ``shared.cmd_opts.api_auth`` is set (comma-separated ``user:password``
    pairs), every route added through ``add_api_route`` requires HTTP Basic
    authentication. Generation, extras and interrogation run while holding
    *queue_lock*.
    """

    def __init__(self, app: FastAPI, queue_lock: Lock):
        # username -> password. Filled only when api_auth is configured, but always
        # defined so auth() never hits a missing attribute.
        self.credentials = dict()
        if shared.cmd_opts.api_auth:
            for auth in shared.cmd_opts.api_auth.split(","):
                # Split on the first ":" only, so passwords may themselves contain ":".
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register *endpoint* on the app, adding the auth dependency when api_auth is configured."""
        if shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
        return self.app.add_api_route(path, endpoint, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """Dependency that accepts the request only for a configured username/password pair.

        Raises:
            HTTPException: 401 with a ``WWW-Authenticate: Basic`` header otherwise.
        """
        if credentials.username in self.credentials:
            # compare_digest accepts only ASCII str; comparing UTF-8 bytes keeps it
            # constant-time without raising TypeError on non-ASCII passwords.
            expected = self.credentials[credentials.username].encode("utf-8")
            if compare_digest(credentials.password.encode("utf-8"), expected):
                return True

        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

    def get_selectable_script(self, script_name, script_runner):
        """Return ``(script, index)`` for the selectable script titled *script_name*.

        Returns ``(None, None)`` when no name is given.

        Raises:
            HTTPException: 422 (from script_name_to_index) for an unknown title.
        """
        if script_name is None or script_name == "":
            return None, None

        script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
        script = script_runner.selectable_scripts[script_idx]
        return script, script_idx

    def get_scripts_list(self):
        """List the lower-cased script titles available to txt2img and img2img."""
        t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
        i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]

        return ScriptsList(txt2img=t2ilist, img2img=i2ilist)

    def get_script(self, script_name, script_runner):
        """Return the script (of any kind) titled *script_name*, or None when no name is given.

        Raises:
            HTTPException: 422 (from script_name_to_index) for an unknown title.
        """
        if script_name is None or script_name == "":
            # A single None (not a tuple), so callers' ``is None`` checks work.
            return None

        script_idx = script_name_to_index(script_name, script_runner.scripts)
        return script_runner.scripts[script_idx]

    def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
        """Build the flat positional argument list handed to the scripts pipeline.

        Slot 0 holds ``selectable_idx + 1`` for the selected script, or 0 when none is
        selected. Each script's request arguments are written into its
        ``args_from:args_to`` slice; slots nothing writes to stay None.

        Raises:
            HTTPException: 422 if an ``alwayson_scripts`` key names an unknown script or
                a script that is not always-on.
        """
        # Size the list to cover the largest args_to of any script (at least 1, for slot 0).
        last_arg_index = max([1] + [script.args_to for script in script_runner.scripts])
        script_args = [None] * last_arg_index

        # NOTE(review): slice assignment replaces the slice with whatever length the request
        # sends; a length mismatch shifts every later script's slots. Confirm callers send
        # exactly args_to - args_from values.
        if selectable_scripts:
            script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
            script_args[0] = selectable_idx + 1
        else:
            # 0 in slot 0 means no selectable script runs.
            script_args[0] = 0

        if request.alwayson_scripts:
            for alwayson_script_name, alwayson_params in request.alwayson_scripts.items():
                alwayson_script = self.get_script(alwayson_script_name, script_runner)
                if alwayson_script is None:
                    raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                if not alwayson_script.alwayson:
                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                # Always-on scripts sent without "args" keep their None slots.
                if "args" in alwayson_params:
                    script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_params["args"]
        return script_args

    @staticmethod
    def _run_processing(p, script_runner, selectable_scripts, script_args):
        """Run *p* through the selected script when there is one, else through process_images."""
        if selectable_scripts is not None:
            p.script_args = script_args
            return script_runner.run(p, *p.script_args)  # Need to pass args as list here
        p.script_args = tuple(script_args)  # Need to pass args as tuple here
        return process_images(p)

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Run txt2img for *txt2imgreq* and return the results (base64 images unless send_images is false)."""
        script_runner = scripts.scripts_txt2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(False)
            ui.create_ui()
        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
            "do_not_save_samples": not txt2imgreq.save_images,
            "do_not_save_grid": not txt2imgreq.save_images,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('script_name', None)
        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)

        script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
            p.scripts = script_runner
            p.outpath_grids = shared.opts.outdir_txt2img_grids
            p.outpath_samples = shared.opts.outdir_txt2img_samples

            shared.state.begin()
            try:
                processed = self._run_processing(p, script_runner, selectable_scripts, script_args)
            finally:
                # End the job even when generation raises, so state is not left "busy".
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Run img2img for *img2imgreq* and return the results (base64 images unless send_images is false).

        Raises:
            HTTPException: 404 when no init images are given; 500 for undecodable images.
        """
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        script_runner = scripts.scripts_img2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(True)
            ui.create_ui()
        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
            "do_not_save_samples": not img2imgreq.save_images,
            "do_not_save_grid": not img2imgreq.save_images,
            "mask": mask,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
        args.pop('script_name', None)
        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)

        script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        # Decode before taking the queue lock so bad input fails without blocking other jobs.
        decoded_init_images = [decode_base64_to_image(x) for x in init_images]

        with self.queue_lock:
            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
            p.init_images = decoded_init_images
            p.scripts = script_runner
            p.outpath_grids = shared.opts.outdir_img2img_grids
            p.outpath_samples = shared.opts.outdir_img2img_samples

            shared.state.begin()
            try:
                processed = self._run_processing(p, script_runner, selectable_scripts, script_args)
            finally:
                # End the job even when generation raises, so state is not left "busy".
                shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())

    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
        """Run the extras (postprocessing) pipeline on one base64 image."""
        reqDict = setUpscalers(req)

        reqDict['image'] = decode_base64_to_image(reqDict['image'])

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        """Run the extras (postprocessing) pipeline on a list of base64 files."""
        reqDict = setUpscalers(req)

        def prepareFiles(file):
            # Decode each uploaded entry to a file object and keep its original name.
            file = decode_base64_to_file(file.data, file_path=file.name)
            file.orig_name = file.name
            return file

        reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
        reqDict.pop('imageList')

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

    def pnginfoapi(self, req: PNGInfoRequest):
        """Return the generation parameters and other text metadata read from an image."""
        if not req.image.strip():
            return PNGInfoResponse(info="")

        image = decode_base64_to_image(req.image.strip())
        if image is None:
            return PNGInfoResponse(info="")

        geninfo, items = images.read_info_from_image(image)
        if geninfo is None:
            geninfo = ""

        items = {'parameters': geninfo, **items}

        return PNGInfoResponse(info=geninfo, items=items)

    def progressapi(self, req: ProgressRequest = Depends()):
        """Report progress, ETA and (unless skipped) the current preview image of the running job."""
        # copy from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

        # avoid dividing zero
        progress = 0.01

        # Both terms divide by job_count, so both are guarded by job_count > 0; a negative
        # job_count could otherwise drive progress to zero or below and break the ETA.
        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
            if shared.state.sampling_steps > 0:
                progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        time_since_start = time.time() - shared.state.time_start
        eta = (time_since_start / progress)
        eta_relative = eta - time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)

    def interrogateapi(self, interrogatereq: InterrogateRequest):
        """Caption an image with the "clip" or "deepdanbooru" model.

        Raises:
            HTTPException: 404 when the image is missing or the model name is unknown.
        """
        image_b64 = interrogatereq.image
        if image_b64 is None:
            raise HTTPException(status_code=404, detail="Image not found")

        img = decode_base64_to_image(image_b64)
        img = img.convert('RGB')

        # Override object param
        with self.queue_lock:
            if interrogatereq.model == "clip":
                processed = shared.interrogator.interrogate(img)
            elif interrogatereq.model == "deepdanbooru":
                processed = deepbooru.model.tag(img)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return InterrogateResponse(caption=processed)

    def interruptapi(self):
        """Ask the current job to stop."""
        shared.state.interrupt()

        return {}

    def skip(self):
        """Ask the current job to skip ahead (delegates to ``shared.state.skip``)."""
        shared.state.skip()

    def get_config(self):
        """Return a shallow copy of the stored option values, keyed by option name."""
        # Keys come from opts.data itself, so the data_labels defaults the old lookup
        # supplied could never be used; copying the stored values is equivalent.
        return dict(shared.opts.data)

    def set_config(self, req: Dict[str, Any]):
        """Apply every ``name: value`` pair in *req* to the options and save them to the config file."""
        for k, v in req.items():
            shared.opts.set(k, v)

        shared.opts.save(shared.config_filename)
        return

    def get_cmd_flags(self):
        """Return the parsed command-line options as a dict."""
        return vars(shared.cmd_opts)

    def get_samplers(self):
        """List every sampler with its aliases and options."""
        return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]

    def get_upscalers(self):
        """List the registered upscalers."""
        return [
            {
                "name": upscaler.name,
                "model_name": upscaler.scaler.model_name,
                "model_path": upscaler.data_path,
                "model_url": None,
                "scale": upscaler.scale,
            }
            for upscaler in shared.sd_upscalers
        ]

    def get_sd_models(self):
        """List the known checkpoints with their hashes and any config file found next to them."""
        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

    def get_hypernetworks(self):
        """List hypernetwork names and paths."""
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

    def get_face_restorers(self):
        """List face restorers with their ``cmd_dir`` (None when absent)."""
        return [{"name": x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]

    def get_realesrgan_models(self):
        """List the Real-ESRGAN models."""
        return [{"name": x.name, "path": x.data_path, "scale": x.scale} for x in get_realesrgan_models(None)]

    def get_prompt_styles(self):
        """List saved prompt styles as name / prompt / negative_prompt dicts."""
        return [
            {"name": style[0], "prompt": style[1], "negative_prompt": style[2]}
            for style in shared.prompt_styles.styles.values()
        ]

    def get_embeddings(self):
        """Describe the loaded and skipped textual-inversion embeddings."""
        db = sd_hijack.model_hijack.embedding_db

        def convert_embedding(embedding):
            return {
                "step": embedding.step,
                "sd_checkpoint": embedding.sd_checkpoint,
                "sd_checkpoint_name": embedding.sd_checkpoint_name,
                "shape": embedding.shape,
                "vectors": embedding.vectors,
            }

        def convert_embeddings(embeddings):
            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}

        return {
            "loaded": convert_embeddings(db.word_embeddings),
            "skipped": convert_embeddings(db.skipped_embeddings),
        }

    def refresh_checkpoints(self):
        """Rescan the checkpoint list."""
        shared.refresh_checkpoints()

    def create_embedding(self, args: dict):
        """Create an empty embedding and reload embeddings; AssertionErrors are reported in the response."""
        try:
            shared.state.begin()
            filename = create_embedding(**args)  # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()  # reload embeddings so new one can be immediately used
            return CreateResponse(info="create embedding filename: {filename}".format(filename=filename))
        except AssertionError as e:
            # This route is declared with response_model=CreateResponse, so errors use it too.
            return CreateResponse(info="create embedding error: {error}".format(error=e))
        finally:
            shared.state.end()

    def create_hypernetwork(self, args: dict):
        """Create an empty hypernetwork; AssertionErrors are reported in the response."""
        try:
            shared.state.begin()
            filename = create_hypernetwork(**args)  # create empty hypernetwork
            return CreateResponse(info="create hypernetwork filename: {filename}".format(filename=filename))
        except AssertionError as e:
            # This route is declared with response_model=CreateResponse, so errors use it too.
            return CreateResponse(info="create hypernetwork error: {error}".format(error=e))
        finally:
            shared.state.end()

    def preprocess(self, args: dict):
        """Run dataset preprocessing; known error types are reported in the response text."""
        try:
            shared.state.begin()
            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
            return PreprocessResponse(info='preprocess complete')
        except KeyError as e:
            return PreprocessResponse(info="preprocess error: invalid token: {error}".format(error=e))
        except AssertionError as e:
            return PreprocessResponse(info="preprocess error: {error}".format(error=e))
        except FileNotFoundError as e:
            return PreprocessResponse(info='preprocess error: {error}'.format(error=e))
        finally:
            # Always end the job, including for exceptions not handled above.
            shared.state.end()

    def train_embedding(self, args: dict):
        """Train an embedding; any training exception is reported in the response text."""
        try:
            shared.state.begin()
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _embedding, filename = train_embedding(**args)  # can take a long time to complete
            except Exception as e:
                error = e
            finally:
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            shared.state.end()
            return TrainResponse(info="train embedding error: {msg}".format(msg=msg))

    def train_hypernetwork(self, args: dict):
        """Train a hypernetwork; any training exception is reported in the response text."""
        try:
            shared.state.begin()
            shared.loaded_hypernetworks = []
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _hypernetwork, filename = train_hypernetwork(**args)
            except Exception as e:
                error = e
            finally:
                # Move these sub-models back to the main device (presumably training may
                # leave them elsewhere — confirm against train_hypernetwork).
                shared.sd_model.cond_stage_model.to(devices.device)
                shared.sd_model.first_stage_model.to(devices.device)
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
                shared.state.end()
            return TrainResponse(info="train hypernetwork complete: filename: {filename} error: {error}".format(filename=filename, error=error))
        except AssertionError as msg:
            # Report the caught exception; ``error`` may be unbound or None at this point.
            shared.state.end()
            return TrainResponse(info="train hypernetwork error: {msg}".format(msg=msg))

    def get_memory(self):
        """Report process RAM usage and, when CUDA is available, CUDA memory statistics.

        Each section degrades to ``{'error': ...}`` instead of failing the request.
        """
        try:
            import os
            import psutil
            process = psutil.Process(os.getpid())
            res = process.memory_info()  # only rss is cross-platform guaranteed so we dont rely on other values
            ram_total = 100 * res.rss / process.memory_percent()  # and total memory is calculated as actual value is not cross-platform safe
            ram = {'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total}
        except Exception as err:
            ram = {'error': f'{err}'}
        try:
            import torch
            if torch.cuda.is_available():
                s = torch.cuda.mem_get_info()
                system = {'free': s[0], 'used': s[1] - s[0], 'total': s[1]}
                s = dict(torch.cuda.memory_stats(shared.device))
                allocated = {'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak']}
                reserved = {'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak']}
                active = {'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak']}
                inactive = {'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak']}
                warnings = {'retries': s['num_alloc_retries'], 'oom': s['num_ooms']}
                cuda = {
                    'system': system,
                    'active': active,
                    'allocated': allocated,
                    'reserved': reserved,
                    'inactive': inactive,
                    'events': warnings,
                }
            else:
                cuda = {'error': 'unavailable'}
        except Exception as err:
            cuda = {'error': f'{err}'}
        return MemoryResponse(ram=ram, cuda=cuda)

    def launch(self, server_name, port):
        """Mount the router and serve the app with uvicorn (blocks)."""
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)