webui.py 17.2 KB
Newer Older
1 2
from __future__ import annotations

3
import os
4
import sys
D
DepFA 已提交
5
import time
D
DepFA 已提交
6
import importlib
7
import signal
8
import re
V
Vladimir Mandic 已提交
9
import warnings
10
import json
11
from threading import Thread
A
Aarni Koskela 已提交
12
from typing import Iterable
13

W
w-e-w 已提交
14
from fastapi import FastAPI
E
evshiron 已提交
15
from fastapi.middleware.cors import CORSMiddleware
D
DepFA 已提交
16
from fastapi.middleware.gzip import GZipMiddleware
17
from packaging import version
D
DepFA 已提交
18

19
import logging
20

A
Aarni Koskela 已提交
21 22 23 24 25 26 27 28 29 30 31
# We can't use cmd_opts for this because it will not have been initialized at this point.
log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
if log_level:
    # Map a level name such as "debug" onto logging's numeric constant. getattr() alone
    # also resolves non-level attributes (e.g. BASIC_FORMAT, which is a str) that can make
    # basicConfig() raise, so anything that is not a non-zero int falls back to INFO --
    # the same fallback unknown names (and NOTSET) already got.
    _resolved_log_level = getattr(logging, log_level.upper(), None)
    if isinstance(_resolved_log_level, int) and _resolved_log_level:
        log_level = _resolved_log_level
    else:
        log_level = logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...
# Drop xformers' "A matching Triton is not available" message; every other record passes.
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

34
from modules import timer
# Shared timer that records how long each startup phase takes.
startup_timer = timer.startup_timer
startup_timer.record("launcher")

import torch
# pytorch_lightning should be imported after torch, but it re-enables warnings on import,
# so import it once here and silence the noise right afterwards.
import pytorch_lightning   # noqa: F401
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch")

import gradio  # noqa: F401
startup_timer.record("import gradio")

# Imported for their side effects (paths setup, import hook) as much as for the names.
from modules import paths, timer, import_hook, errors, devices  # noqa: F401
startup_timer.record("setup paths")

import ldm.modules.encoders.modules  # noqa: F401
startup_timer.record("import ldm")

53
from modules import extra_networks
54
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401
A
AUTOMATIC 已提交
55

56 57
# Nightly/local PyTorch builds carry a ".dev..." or "+git..." suffix in torch.__version__,
# which causes exceptions with CodeFormer or Safetensors. Keep the full string in
# torch.__long_version__ and expose only the leading dotted numeric part as __version__.
if any(marker in torch.__version__ for marker in (".dev", "+git")):
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__long_version__).group(0)
60

A
AUTOMATIC 已提交
61
from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
D
d8ahazard 已提交
62 63 64
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
65
import modules.img2img
D
d8ahazard 已提交
66

D
d8ahazard 已提交
67 68 69
import modules.lowvram
import modules.scripts
import modules.sd_hijack
70
import modules.sd_hijack_optimizations
71
import modules.sd_models
M
Muhammad Rizqi Nur 已提交
72
import modules.sd_vae
A
AUTOMATIC 已提交
73
import modules.sd_unet
D
d8ahazard 已提交
74
import modules.txt2img
M
Maiko Tan 已提交
75
import modules.script_callbacks
76
import modules.textual_inversion.textual_inversion
77
import modules.progress
D
d8ahazard 已提交
78

D
d8ahazard 已提交
79
import modules.ui
D
d8ahazard 已提交
80
from modules import modelloader
D
d8ahazard 已提交
81
from modules.shared import cmd_opts
82
import modules.hypernetworks.hypernetwork
83

A
AUTOMATIC 已提交
84 85
startup_timer.record("other imports")

# Host handed to gradio's launch() in webui(): an explicit server name wins; otherwise
# --listen binds to all interfaces, and the fallback is None.
server_name = cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else None)
91

92

93 94 95 96 97 98 99 100 101
def fix_asyncio_event_loop_policy():
    """
    Install an asyncio event loop policy that creates loops on demand in any thread.

    With the default policy only the main thread gets an event loop created
    automatically; other threads must create one explicitly, or
    ``asyncio.get_event_loop`` (and therefore ``.IOLoop.current``) fails.
    The policy installed here restores the behavior of Tornado versions prior
    to 5.0: the first request for a loop in a thread creates and registers one.
    """

    import asyncio

    # "Any thread" and "selector" should be orthogonal, but there is no clean
    # interface for composing policies, so pick the right base class up front.
    use_windows_selector = sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy")
    base_policy = asyncio.WindowsSelectorEventLoopPolicy if use_windows_selector else asyncio.DefaultEventLoopPolicy  # type: ignore

    class AnyThreadEventLoopPolicy(base_policy):  # type: ignore
        """Event loop policy that lazily creates a loop in whichever thread asks for one.

        Usage::

            asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        """

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                return super().get_event_loop()
            except (RuntimeError, AssertionError):
                # Python 3.4.2 (shipped with debian jessie) raised AssertionError here;
                # 3.4.3+ raise RuntimeError("There is no current event loop in thread %r").
                fresh_loop = self.new_event_loop()
                self.set_event_loop(fresh_loop)
                return fresh_loop

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
131

132

133
def check_versions():
    """
    Print an explanation when installed torch or xformers is older than the tested version.

    Nothing is checked when shared.cmd_opts.skip_version_check is set; xformers
    is only inspected (and imported) when shared.xformers_available is true.
    """
    if shared.cmd_opts.skip_version_check:
        return

    expected_torch_version = "2.0.0"
    torch_outdated = version.parse(torch.__version__) < version.parse(expected_torch_version)
    if torch_outdated:
        errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, as well as
there are reports of issues with training tab on the latest version.

Use --skip-version-check commandline argument to disable this check.
        """.strip())

    expected_xformers_version = "0.0.20"
    if not shared.xformers_available:
        return

    import xformers

    if not (version.parse(xformers.__version__) < version.parse(expected_xformers_version)):
        return

    errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.

Use --skip-version-check commandline argument to disable this check.
    """.strip())


164
def restore_config_state_file():
    """
    Apply a pending extension config-state restore requested through the settings.

    shared.opts.restore_config_state_file holds the path of a config state file
    to restore. The option is cleared and the settings are saved *before* the
    restore is attempted, so the restore is attempted at most once even if it
    fails. A path that is not an existing file only produces a warning.
    """
    config_state_file = shared.opts.restore_config_state_file
    # Treat None like "": os.path.isfile(None) would raise TypeError further down,
    # after the settings had already been rewritten.
    if not config_state_file:
        return

    shared.opts.restore_config_state_file = ""
    shared.opts.save(shared.config_filename)

    if not os.path.isfile(config_state_file):
        print(f"!!! Config state backup not found: {config_state_file}")
        return

    print(f"*** About to restore extension state from file: {config_state_file}")
    with open(config_state_file, "r", encoding="utf-8") as f:
        config_state = json.load(f)

    # Restore after the file has been parsed and closed; the handle is no longer needed.
    config_states.restore_extension_config(config_state)
    startup_timer.record("restore extension config")
180

181

182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
def validate_tls_options():
    """
    Report on the cmd_opts.tls_keyfile / cmd_opts.tls_certfile pair before launch.

    Does nothing unless both options are set. Paths that do not exist are
    reported (and "Running with TLS" is then not claimed). If the values cannot
    be checked as paths at all (TypeError), both options are cleared and the
    webui runs without TLS.
    """
    if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
        return

    try:
        keyfile_exists = os.path.exists(cmd_opts.tls_keyfile)
        certfile_exists = os.path.exists(cmd_opts.tls_certfile)
    except TypeError:
        cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
        print("TLS setup invalid, running webui without TLS")
    else:
        if not keyfile_exists:
            print(f"Invalid path to TLS keyfile: '{cmd_opts.tls_keyfile}'")
        if not certfile_exists:
            print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
        # Only announce TLS when both files are actually present.
        if keyfile_exists and certfile_exists:
            print("Running with TLS")
    startup_timer.record("TLS")


A
Aarni Koskela 已提交
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
    """
    Yield credential tuples built from cmd_opts.gradio_auth and cmd_opts.gradio_auth_path.

    Both sources hold comma-separated ``username:password`` entries; the auth
    file may spread them over several lines. Blank entries are skipped, and each
    entry is split on its first colon only, so passwords may contain ':'.
    """
    def parse_entries(text):
        # Split one comma-separated chunk into stripped, non-empty credential tuples.
        for entry in text.split(','):
            entry = entry.strip()
            if entry:
                yield tuple(entry.split(':', 1))

    if cmd_opts.gradio_auth:
        yield from parse_entries(cmd_opts.gradio_auth)

    if cmd_opts.gradio_auth_path:
        with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
            for line in file.readlines():
                yield from parse_entries(line.strip())


225 226 227 228 229 230 231 232 233 234 235 236
def configure_sigint_handler():
    """
    Make Ctrl+C exit the program immediately, without waiting for anything.

    Skipped when the COVERAGE_RUN environment variable is set, because an
    immediate os._exit() would prevent the coverage report from being generated.
    """
    if os.environ.get("COVERAGE_RUN"):
        return

    def exit_immediately(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        os._exit(0)

    signal.signal(signal.SIGINT, exit_immediately)


237 238 239 240 241 242
def configure_opts_onchange():
    """
    Register callbacks on shared.opts that react to changes of specific settings.

    Model-related handlers are wrapped with wrap_queued_call and registered with
    call=False; the temp_dir and gradio_theme handlers use onchange's default
    ``call`` argument.
    """
    def reload_checkpoint():
        return modules.sd_models.reload_model_weights()

    def reload_vae():
        return modules.sd_vae.reload_vae_weights()

    def redo_hijack():
        return modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)

    register = shared.opts.onchange
    register("sd_model_checkpoint", wrap_queued_call(reload_checkpoint), call=False)
    register("sd_vae", wrap_queued_call(reload_vae), call=False)
    register("sd_vae_as_default", wrap_queued_call(reload_vae), call=False)
    register("temp_dir", ui_tempdir.on_tmpdir_changed)
    register("gradio_theme", shared.reload_gradio_theme)
    register("cross_attention_optimization", wrap_queued_call(redo_hijack), call=False)
    startup_timer.record("opts onchange")


247 248
def initialize():
    """
    Perform one-time startup: process-level setup, model bootstrap, then initialize_rest().
    """
    # Process-wide environment fixes and sanity checks, in this exact order.
    for setup_step in (fix_asyncio_event_loop_policy, validate_tls_options, configure_sigint_handler, check_versions):
        setup_step()

    modelloader.cleanup_models()
    configure_opts_onchange()

    modules.sd_models.setup_model()
    startup_timer.record("setup SD model")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    initialize_rest(reload_script_modules=False)

266

267 268 269 270 271
def initialize_rest(*, reload_script_modules=False):
    """
    Initialize the parts of the program that are rebuilt on startup and on UI reload.

    Called both from initialize() and when reloading the webui.

    Args:
        reload_script_modules: when True, re-import every already-loaded module
            whose name starts with "modules.ui" after the scripts are loaded
            (webui() passes True on its restart path).
    """
    sd_samplers.set_samplers()
    extensions.list_extensions()
    startup_timer.record("list extensions")

    restore_config_state_file()

    if cmd_opts.ui_debug_mode:
        # UI debug mode: only the Lanczos upscaler and the scripts are loaded; everything
        # below (model lists, upscalers, VAEs, model loading, ...) is skipped.
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modules.sd_models.list_models()
    startup_timer.record("list SD models")

    localization.list_localizations(cmd_opts.localizations_dir)

    with startup_timer.subcategory("load scripts"):
        modules.scripts.load_scripts()

    if reload_script_modules:
        # Snapshot the matching modules into a list first: reloading may mutate
        # sys.modules while we would otherwise still be iterating over it.
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
        startup_timer.record("reload script modules")

    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")
    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
    modules.sd_hijack.list_optimizers()
    startup_timer.record("scripts list_optimizers")

    modules.sd_unet.list_unets()
    startup_timer.record("scripts list_unets")

    def load_model():
        """
        Access the shared.sd_model property to load the model.

        If some extension already loaded the model before this access, its
        optimization may still be None because the list of optimizers had not
        been filled at that time, so optimizations are applied again here.
        """

        shared.sd_model  # noqa: B018

        if modules.sd_hijack.current_optimizer is None:
            modules.sd_hijack.apply_optimizations()

    # Model loading and the first-time device calculation run in background threads
    # (they are not joined here), so the remaining initialization continues meanwhile.
    Thread(target=load_model).start()

    Thread(target=devices.first_time_calculation).start()

    shared.reload_hypernetworks()
    startup_timer.record("reload hypernetworks")

    ui_extra_networks.initialize()
    ui_extra_networks.register_default_pages()

    extra_networks.initialize()
    extra_networks.register_default_extra_networks()
    startup_timer.record("initialize extra networks")
336

D
DepFA 已提交
337

V
Vladimir Mandic 已提交
338
def setup_middleware(app):
    """
    Add GZip (minimum_size=1000) and the configured CORS middleware to ``app``.

    The current middleware stack is reset first so the user-provided middleware
    list can be modified, and the stack is rebuilt at the end.
    """
    # Reset the current middleware so the user-provided list can be modified.
    app.middleware_stack = None
    app.add_middleware(GZipMiddleware, minimum_size=1000)
    configure_cors_middleware(app)
    # Rebuild the middleware stack on-the-fly.
    # NOTE(review): the returned stack is not assigned back to app.middleware_stack;
    # presumably the framework builds it lazily on the next request -- confirm.
    app.build_middleware_stack()


def configure_cors_middleware(app):
    """
    Add a CORSMiddleware to ``app`` configured from the commandline options.

    All methods and headers are allowed and credentials are permitted. Accepted
    origins come from cmd_opts.cors_allow_origins (a comma-separated list) and
    cmd_opts.cors_allow_origins_regex; an option that is not set is simply not
    passed to the middleware.
    """
    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_credentials": True,
    }
    if cmd_opts.cors_allow_origins:
        # Tolerate "a.com, b.com" and trailing commas: an entry with surrounding
        # whitespace, or an empty one, can never equal a request's Origin header.
        cors_options["allow_origins"] = [
            origin.strip()
            for origin in cmd_opts.cors_allow_origins.split(',')
            if origin.strip()
        ]
    if cmd_opts.cors_allow_origins_regex:
        cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
    app.add_middleware(CORSMiddleware, **cors_options)
E
evshiron 已提交
356 357


358
def create_api(app):
    """Construct the Api object for ``app``, sharing queue_lock with the UI, and return it."""
    from modules.api.api import Api

    return Api(app, queue_lock)

363

364 365
def api_only():
    """
    Run in API-only mode: initialize, then serve the API on a bare FastAPI app.

    Host: cmd_opts.server_name when given, otherwise "0.0.0.0" with --listen and
    "127.0.0.1" without. Port: cmd_opts.port, defaulting to 7861.
    """
    initialize()

    app = FastAPI()
    setup_middleware(app)
    api = create_api(app)

    modules.script_callbacks.app_started_callback(None, app)

    print(f"Startup time: {startup_timer.summary()}.")
    api.launch(
        # An explicit server name takes precedence, as it does for the UI mode.
        server_name=cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"),
        port=cmd_opts.port if cmd_opts.port else 7861,
        root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
    )
379

380

381 382
def webui():
    """
    Launch the gradio UI and keep serving it until a "stop" command or Ctrl+C.

    Each pass of the outer loop builds and launches the UI, replaces gradio's
    permissive CORS middleware, attaches the extra APIs and then blocks until a
    server command arrives. A "restart" command closes the UI, runs the reload
    and unload callbacks plus initialize_rest(), and goes around again.
    """
    launch_api = cmd_opts.api
    initialize()

    def wait_for_stop_or_restart():
        # Poll for server commands every 5 seconds; only "stop" and "restart" end the
        # wait. A KeyboardInterrupt is treated as "stop".
        try:
            while True:
                command = shared.state.wait_for_server_command(timeout=5)
                if not command:
                    continue
                if command in ("stop", "restart"):
                    return command
                print(f"Unknown server command: {command}")
        except KeyboardInterrupt:
            print('Caught KeyboardInterrupt, stopping...')
            return "stop"

    while True:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()
            startup_timer.record("cleanup temp dir")

        modules.script_callbacks.before_ui_callback()
        startup_timer.record("scripts before_ui_callback")

        shared.demo = modules.ui.create_ui()
        startup_timer.record("create ui")

        if not cmd_opts.no_gradio_queue:
            shared.demo.queue(64)

        auth_creds = list(get_gradio_auth_creds()) or None

        app, _local_url, _share_url = shared.demo.launch(
            share=cmd_opts.share,
            server_name=server_name,
            server_port=cmd_opts.port,
            ssl_keyfile=cmd_opts.tls_keyfile,
            ssl_certfile=cmd_opts.tls_certfile,
            ssl_verify=cmd_opts.disable_tls_verify,
            debug=cmd_opts.gradio_debug,
            auth=auth_creds,
            inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1',
            prevent_thread_lock=True,
            allowed_paths=cmd_opts.gradio_allowed_path,
            app_kwargs={
                "docs_url": "/docs",
                "redoc_url": "/redoc",
            },
            root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
        )

        # Auto-open the browser only on the initial launch, never after a restart.
        cmd_opts.autolaunch = False

        startup_timer.record("gradio launch")

        # gradio uses a very open CORS policy via app.user_middleware, which would let a
        # malicious HTML page make requests to the running web UI and do whatever the
        # attacker wants, including installing an extension and running its code.
        # Remove it and install our own middleware instead. Suggested by RyotaK.
        app.user_middleware = [mw for mw in app.user_middleware if mw.cls.__name__ != 'CORSMiddleware']

        setup_middleware(app)

        modules.progress.setup_progress_api(app)
        modules.ui.setup_ui_api(app)

        if launch_api:
            create_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        startup_timer.record("add APIs")

        with startup_timer.subcategory("app_started_callback"):
            modules.script_callbacks.app_started_callback(shared.demo, app)

        timer.startup_record = startup_timer.dump()
        print(f"Startup time: {startup_timer.summary()}.")

        if wait_for_stop_or_restart() == "stop":
            print("Stopping server...")
            shared.demo.close()
            break

        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        startup_timer.reset()
        modules.script_callbacks.app_reload_callback()
        startup_timer.record("app reload callback")
        modules.script_callbacks.script_unloaded_callback()
        startup_timer.record("scripts unloaded callback")
        initialize_rest(reload_script_modules=True)
A
AUTOMATIC 已提交
476

A
AUTOMATIC 已提交
477

478
if __name__ == "__main__":
    # --nowebui runs only the API server; otherwise start the full gradio UI.
    (api_only if cmd_opts.nowebui else webui)()