scripts.py 33.0 KB
Newer Older
1
import os
2
import re
3
import sys
H
huchenlei 已提交
4
import inspect
5
from graphlib import TopologicalSorter, CycleError
6
from collections import namedtuple
7
from dataclasses import dataclass
8 9 10

import gradio as gr

11
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
12 13 14

# Sentinel value a Script.show() can return to have its UI shown at all times instead of being
# picked from the "Script" dropdown; ScriptRunner compares show()'s result against this object.
AlwaysVisible = object()

A
AUTOMATIC 已提交
15

16 17 18 19 20
class PostprocessImageArgs:
    """Mutable holder passed to Script.postprocess_image(); scripts may replace `image` in place."""

    def __init__(self, image):
        # the generated image handed to postprocess_image() hooks
        self.image = image


L
ljleb 已提交
21 22 23 24 25
class PostprocessBatchListArgs:
    """Mutable holder passed to Script.postprocess_batch_list(); scripts may edit or replace `images`."""

    def __init__(self, images):
        # list of the batch's images, shared by every postprocess_batch_list() hook
        self.images = images


26 27 28 29 30
@dataclass
class OnComponent:
    """Argument given to callbacks registered via Script.on_before_component / on_after_component."""

    # the gradio component (block) the callback fires for
    component: gr.blocks.Block


31
class Script:
    """Base class for webui scripts.

    A script is either offered in the "Script" dropdown (its run() then does the processing) or,
    when show() returns AlwaysVisible, shown at all times and takes part in processing through the
    hook methods below (setup, before_process, process, process_batch, postprocess, ...).
    Every hook receives the values of the script's own ui() components as *args.
    """

    name = None
    """script's internal name derived from title"""

    section = None
    """name of UI section that the script's controls will be placed into"""

    # path of the file this script's class was loaded from
    filename = None

    # half-open range [args_from, args_to) of this script's ui() components in the runner's inputs;
    # the same slice of the processing arguments is what gets passed to the script's hooks
    args_from = None
    args_to = None

    # True when show() returned AlwaysVisible for the current tab
    alwayson = False

    is_txt2img = False
    is_img2img = False
    # "txt2img" or "img2img"
    tabname = None

    group = None
    """A gr.Group component that has all script's UI inside it."""

    create_group = True
    """If False, for alwayson scripts, a group component will not be created."""

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    """

    paste_field_names = None
    """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
    various "Send to <X>" buttons when clicked
    """

    api_info = None
    """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""

    on_before_component_elem_id = None
    """list of (elem_id, callback) pairs; each callback is called before the component with that elem_id is created"""

    on_after_component_elem_id = None
    """list of (elem_id, callback) pairs; each callback is called after the component with that elem_id is created"""

    setup_for_ui_only = False
    """If true, the script setup will only be run in Gradio UI, not in API"""

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """

        pass

    def show(self, is_img2img):
        """
        is_img2img is True if this function is called for the img2img interface, and False otherwise

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - script.AlwaysVisible if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        pass

    def setup(self, p, *args):
        """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
        args contains all values returned by components from ui().
        """
        pass

    def before_process(self, p, *args):
        """
        This function is called very early, before processing begins, for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def before_process_batch(self, p, *args, **kwargs):
        """
        Called before extra networks are parsed from the prompt, so you can add
        new extra network keywords to the prompt with this callback.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def after_extra_networks_activate(self, p, *args, **kwargs):
        """
        Called after extra networks activation, before conds calculation.
        Allows modification of the network after extra networks activation has been applied.
        Won't be called if p.disable_extra_networks is set.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
          - extra_network_data - list of ExtraNetworkParams for current stage
        """
        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
          - images - torch tensor with all generated images, with values ranging from 0 to 1;
        """

        pass

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
        """
        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
        This is useful when you want to update the entire batch instead of individual images.

        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
        If the number of images is different from the batch size when returning,
        then the script has the responsibility to also update the following attributes in the processing object (p):
          - p.prompts
          - p.negative_prompts
          - p.seeds
          - p.subseeds

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
        """

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as above.
        """

        pass

    def on_before_component(self, callback, *, elem_id):
        """
        Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.

        May be called in show() or ui() - but it may be too late in the latter as some components may already be created.

        This function is an alternative to before_component in that it also allows to run before a component is created, but
        it doesn't require to be called for every created component - just for the one you need.
        """
        # the list is created lazily per instance so the class-level None is never shared/mutated
        if self.on_before_component_elem_id is None:
            self.on_before_component_elem_id = []

        self.on_before_component_elem_id.append((elem_id, callback))

    def on_after_component(self, callback, *, elem_id):
        """
        Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
        """
        # the list is created lazily per instance so the class-level None is never shared/mutated
        if self.on_after_component_elem_id is None:
            self.on_after_component_elem_id = []

        self.on_after_component_elem_id.append((elem_id, callback))

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""

        # the tab prefix is only needed when the script is shown the same way on both tabs
        need_tabname = self.show(True) == self.show(False)
        tabkind = 'img2img' if self.is_img2img else 'txt2img'
        tabname = f"{tabkind}_" if need_tabname else ""
        # lowercase title, whitespace -> "_", then drop anything that is not [a-z_0-9]
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'

    def before_hr(self, p, *args):
        """
        This function is called before hires fix start.
        """
        pass
284

285

A
AUTOMATIC1111 已提交
286 287
class ScriptBuiltinUI(Script):
    """Script whose setup only runs in the Gradio UI and whose element ids omit the script title.

    Unlike Script.elem_id(), ids are built only from the tab name (when show() gives the same
    result for both tabs) and the caller-supplied item_id.
    """

    setup_for_ui_only = True

    def elem_id(self, item_id):
        """helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""

        same_on_both_tabs = self.show(True) == self.show(False)
        if not same_on_both_tabs:
            return f'{item_id}'

        tab = 'img2img' if self.is_img2img else 'txt2img'
        return f'{tab}_{item_id}'


298 299 300 301 302 303 304 305 306 307 308 309
# load_scripts() points this at the base directory of the script file it is currently importing
# and resets it to paths.script_path afterwards
current_basedir = paths.script_path


def basedir():
    """returns the base directory for the current script. For scripts in the main scripts directory,
    this is the main directory (where webui.py resides), and for scripts in extensions directory
    (ie extensions/aesthetic/scripts/aesthetic.py), this is extension's directory (extensions/aesthetic)

    The value is only script-specific while load_scripts() is importing that script's module;
    at any other time it is paths.script_path.
    """
    return current_basedir


# one script file to load: the base directory it belongs to, its file name and its full path
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])

# registries filled by load_scripts(): Script subclasses and ScriptPostprocessing subclasses found in loaded modules
scripts_data = []
postprocessing_scripts_data = []

# a registered script class together with the file, base directory and module it was loaded from
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
314 315


316
def list_scripts(scriptdirname, extension, *, include_extensions=True):
    """List the script files in `scriptdirname`, ordered by their declared load dependencies.

    Scripts are collected from <paths.script_path>/<scriptdirname> and, when include_extensions is
    True, from that subdirectory of every active extension. An extension's metadata can declare,
    per script file, "Requires", "Before" and "After" entries: comma/whitespace separated,
    case-insensitive lists of script canonical names ("<extension>/<filename>", prefixed with
    "builtin/" for built-in extensions) or of extension canonical names.

    Returns a list of ScriptFile tuples restricted to regular files whose lowercased suffix equals
    `extension`. Scripts with an unmet "Requires" entry are reported and left out. If the
    dependencies contain a cycle, it is reported and discovery order is used instead.
    """

    def parse_name_list(value):
        # metadata values look like "a, b c"; entries are compared in lowercase
        return list(filter(None, re.split(r"[,\s]+", value.lower()))) if value else []

    script_dependency_map = {}

    # build script dependency map

    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(root_script_basedir):
        for filename in sorted(os.listdir(root_script_basedir)):
            script_dependency_map[filename] = {
                "extension": None,
                "extension_dirname": None,
                "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
                "requires": [],
                "load_before": [],
                "load_after": [],
            }

    if include_extensions:
        for ext in extensions.active():
            for extension_script in ext.list_files(scriptdirname, extension):
                script_canonical_name = ext.canonical_name + "/" + extension_script.filename
                if ext.is_builtin:
                    script_canonical_name = "builtin/" + script_canonical_name
                relative_path = scriptdirname + "/" + extension_script.filename

                requires = None
                load_before = None
                load_after = None

                if ext.metadata is not None:
                    requires = ext.metadata.get(relative_path, "Requires", fallback=None)
                    load_before = ext.metadata.get(relative_path, "Before", fallback=None)
                    load_after = ext.metadata.get(relative_path, "After", fallback=None)

                script_dependency_map[script_canonical_name] = {
                    "extension": ext.canonical_name,
                    "extension_dirname": ext.name,
                    "script_file": extension_script,
                    "requires": parse_name_list(requires),
                    "load_before": parse_name_list(load_before),
                    "load_after": parse_name_list(load_after),
                }

    # resolve dependencies

    loaded_extensions = {ext.canonical_name for ext in extensions.active()}

    def scripts_of_extension(extension_name, exclude):
        # canonical names of every script of an extension, minus `exclude` (avoids a self-dependency)
        return [
            name for name, data in script_dependency_map.items()
            if data['extension'] == extension_name and name != exclude
        ]

    for script_canonical_name, script_data in script_dependency_map.items():
        # "Before" is an inverse dependency: append this script's name to the load_after list of
        # each target - the named script, or every script of the named extension
        for load_before_script in script_data['load_before']:
            if load_before_script in script_dependency_map:
                targets = [load_before_script]
            elif load_before_script in loaded_extensions:
                targets = scripts_of_extension(load_before_script, script_canonical_name)
            else:
                targets = []

            for target in targets:
                script_dependency_map[target]['load_after'].append(script_canonical_name)

        # replace extension names in load_after with every script of that extension
        # (previously this called .remove() on names that were never in the list -> ValueError)
        for load_after_script in list(script_data['load_after']):
            if load_after_script not in script_dependency_map and load_after_script in loaded_extensions:
                script_data['load_after'].remove(load_after_script)
                script_data['load_after'].extend(scripts_of_extension(load_after_script, script_canonical_name))

    # build the DAG
    sorter = TopologicalSorter()
    accepted_scripts = set()
    for script_canonical_name, script_data in script_dependency_map.items():
        requirement_met = True
        for required_script in script_data['requires']:
            # a requirement may name an individual script or a whole extension
            if required_script not in script_dependency_map and required_script not in loaded_extensions:
                errors.report(f"Script \"{script_canonical_name}\" "
                              f"requires \"{required_script}\" to "
                              f"be loaded, but it is not. Skipping.",
                              exc_info=False)
                requirement_met = False
                break
        if not requirement_met:
            continue

        accepted_scripts.add(script_canonical_name)
        sorter.add(script_canonical_name, *script_data['load_after'])

    # sort the scripts
    try:
        # static_order() is a generator: CycleError is only raised once it is iterated, so it must
        # be consumed inside the try block for the fallback below to ever take effect
        ordered_script = list(sorter.static_order())
    except CycleError:
        errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
        ordered_script = list(script_dependency_map)

    # load_after may name scripts that were skipped above or that do not exist at all; the sorter
    # still emits those as nodes, so keep only the scripts that were actually accepted
    scripts_list = [
        script_dependency_map[name]['script_file']
        for name in ordered_script
        if name in accepted_scripts
    ]

    return [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
425

426

A
AUTOMATIC 已提交
427 428 429
def list_files_with_name(filename):
    """Return the paths of every existing file named `filename` directly inside paths.script_path
    or inside an active extension's directory.

    Missing directories are skipped. Results follow the search order: paths.script_path first,
    then extensions in extensions.active() order.
    """
    res = []

    dirs = [paths.script_path] + [ext.path for ext in extensions.active()]

    for dirpath in dirs:
        if not os.path.isdir(dirpath):
            continue

        path = os.path.join(dirpath, filename)
        if os.path.isfile(path):
            res.append(path)

    return res


443 444 445
def load_scripts():
    """(Re)load every script module and rebuild the script registries and runners.

    Clears scripts_data / postprocessing_scripts_data and the registered script callbacks, imports
    each file returned by list_scripts() (extension scripts with their base directory prepended to
    sys.path and exposed through basedir()), registers the Script / ScriptPostprocessing subclasses
    found in each module, and finally replaces the module-level scripts_txt2img, scripts_img2img and
    scripts_postproc runners. A script that fails to import is reported and skipped.
    """
    global current_basedir, scripts_txt2img, scripts_img2img, scripts_postproc

    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)

    # NOTE: this keeps the list object itself, not a copy; it is restored after every script below
    syspath = sys.path

    def register_scripts_from_module(module, scriptfile):
        # scriptfile is passed explicitly rather than read from the enclosing loop variable
        for script_class in module.__dict__.values():
            if not inspect.isclass(script_class):
                continue

            if issubclass(script_class, Script):
                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

    # here the scripts_list is already ordered
    # processing_script is not considered though
    for scriptfile in scripts_list:
        try:
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            script_module = script_loading.load_module(scriptfile.path)
            register_scripts_from_module(script_module, scriptfile)

        except Exception:
            errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)

        finally:
            sys.path = syspath
            current_basedir = paths.script_path
            timer.startup_timer.record(scriptfile.filename)

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()

488 489 490

def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call func(*args, **kwargs), reporting any Exception instead of propagating it.

    filename and funcname only identify the failing call in the error report.
    Returns func's result, or `default` if func raised an Exception.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)

    return default


A
AUTOMATIC 已提交
498 499 500
class ScriptRunner:
    def __init__(self):
        self.scripts = []
501 502
        self.selectable_scripts = []
        self.alwayson_scripts = []
ふぁ 已提交
503
        self.titles = []
504
        self.title_map = {}
505
        self.infotext_fields = []
506
        self.paste_field_names = []
507
        self.inputs = [None]
A
AUTOMATIC 已提交
508

509 510 511 512 513 514
        self.on_before_component_elem_id = {}
        """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""

        self.on_after_component_elem_id = {}
        """dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""

515
    def initialize_scripts(self, is_img2img):
        """Instantiate every registered script class for one tab (txt2img or img2img).

        Scripts whose show() returns AlwaysVisible go to alwayson_scripts, scripts returning any other
        truthy value go to selectable_scripts, and both kinds are added to self.scripts. Scripts whose
        show() returns a falsy value are dropped. Finally the scripts' pending component callbacks are
        moved into this runner.
        """
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_data in auto_processing_scripts + scripts_data:
            script = script_data.script_class()
            script.filename = script_data.path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img
            script.tabname = "img2img" if is_img2img else "txt2img"

            visibility = script.show(script.is_img2img)

            # AlwaysVisible is a sentinel object, so compare by identity
            if visibility is AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

        self.apply_on_before_component_callbacks()

    def apply_on_before_component_callbacks(self):
        for script in self.scripts:
            on_before = script.on_before_component_elem_id or []
            on_after = script.on_after_component_elem_id or []

            for elem_id, callback in on_before:
                if elem_id not in self.on_before_component_elem_id:
                    self.on_before_component_elem_id[elem_id] = []

                self.on_before_component_elem_id[elem_id].append((callback, script))

            for elem_id, callback in on_after:
                if elem_id not in self.on_after_component_elem_id:
                    self.on_after_component_elem_id[elem_id] = []

                self.on_after_component_elem_id[elem_id].append((callback, script))

            on_before.clear()
            on_after.clear()

564
    def create_script_ui(self, script):
        """Build one script's UI and register its components with this runner.

        Calls script.ui(). The returned components are appended to self.inputs, and script.args_from /
        args_to are set to their index range there. Also fills script.name and script.api_info
        (a modules.api.models.ScriptInfo describing each component), and collects the script's
        infotext_fields / paste_field_names. If ui() returns None (or raises), only args_from/args_to
        are set, to an empty range.
        """
        import modules.api.models as api_models

        script.args_from = len(self.inputs)
        script.args_to = len(self.inputs)

        controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

        if controls is None:
            return

        # fall back to the filename both when title() raises and when it returns an empty value,
        # so .lower() never runs on None (matches how setup_ui() builds its titles)
        script.name = (wrap_call(script.title, script.filename, "title", default=script.filename) or script.filename).lower()
        api_args = []

        for control in controls:
            control.custom_script_source = os.path.basename(script.filename)

            arg_info = api_models.ScriptArg(label=control.label or "")

            for field in ("value", "minimum", "maximum", "step"):
                v = getattr(control, field, None)
                if v is not None:
                    setattr(arg_info, field, v)

            # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
            choices = getattr(control, 'choices', None)
            if choices is not None:
                arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]

            api_args.append(arg_info)

        script.api_info = api_models.ScriptInfo(
            name=script.name,
            is_img2img=script.is_img2img,
            is_alwayson=script.alwayson,
            args=api_args,
        )

        if script.infotext_fields is not None:
            self.infotext_fields += script.infotext_fields

        if script.paste_field_names is not None:
            self.paste_field_names += script.paste_field_names

        self.inputs += controls
        script.args_to = len(self.inputs)
A
AUTOMATIC 已提交
609

610 611 612
    def setup_ui_for_section(self, section, scriptlist=None):
        """Create the UI of the scripts belonging to `section`.

        scriptlist defaults to the always-on scripts. Always-on scripts whose `section` differs are
        skipped; scripts that are not always-on are always built. When script.create_group is set,
        the script's UI is wrapped in a gr.Group stored on script.group, which is initially visible
        only for always-on scripts.
        """
        if scriptlist is None:
            scriptlist = self.alwayson_scripts

        for script in scriptlist:
            if script.alwayson and script.section != section:
                continue

            if script.create_group:
                with gr.Group(visible=script.alwayson) as group:
                    self.create_script_ui(script)

                script.group = group
            else:
                self.create_script_ui(script)
625

626 627
    def prepare_ui(self):
        self.inputs = [None]
628

629
    def setup_ui(self):
        """Build the Gradio UI for this runner's scripts and return the list of input components.

        Creates the always-on scripts' UI, the "Script" dropdown (stored in inputs[0]) and the UI of
        every selectable script. Selecting an entry in the dropdown shows only that script's group.
        Also registers infotext handlers so pasted parameters restore the selected script.
        """
        all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
        self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        self.setup_ui_for_section(None)

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        self.inputs[0] = dropdown

        self.setup_ui_for_section(None, self.selectable_scripts)

        def select_script(script_index):
            # index 0 is the "None" entry; a missing selection (None) also means no script
            selected_script = self.selectable_scripts[script_index - 1] if script_index is not None and script_index > 0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            # a stale ui-config may name a script that no longer exists; ignore it instead of raising
            if title == 'None' or title not in self.titles:
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        self.script_load_ctr = 0

        def onload_script_visibility(params):
            # registered once per selectable script's group (in selectable_scripts order), presumably
            # invoked in that same order for one paste; script_load_ctr tracks which group this call is for
            title = params.get('Script', None)
            # unknown titles (e.g. infotext from a missing script) hide every group without advancing
            # the counter, same as when no script is named at all
            if not title or title not in self.titles:
                return gr.update(visible=False)

            title_index = self.titles.index(title)
            visibility = title_index == self.script_load_ctr
            self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
            return gr.update(visible=visibility)

        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])

        self.apply_on_before_component_callbacks()

        return self.inputs
A
AUTOMATIC 已提交
681

682
    def run(self, p, *args):
A
AUTOMATIC 已提交
683
        script_index = args[0]
A
AUTOMATIC 已提交
684

A
AUTOMATIC 已提交
685 686
        if script_index == 0:
            return None
A
AUTOMATIC 已提交
687

688
        script = self.selectable_scripts[script_index-1]
A
AUTOMATIC 已提交
689

A
AUTOMATIC 已提交
690 691
        if script is None:
            return None
A
AUTOMATIC 已提交
692

A
AUTOMATIC 已提交
693 694
        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)
A
AUTOMATIC 已提交
695

696 697
        shared.total_tqdm.clear()

A
AUTOMATIC 已提交
698
        return processed
A
AUTOMATIC 已提交
699

700 701 702 703 704 705 706 707
    def before_process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process(p, *script_args)
            except Exception:
                errors.report(f"Error running before_process: {script.filename}", exc_info=True)

A
AUTOMATIC 已提交
708
    def process(self, p):
709 710 711 712 713
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
714
                errors.report(f"Error running process: {script.filename}", exc_info=True)
A
AUTOMATIC 已提交
715

716 717 718 719 720 721
    def before_process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process_batch(p, *script_args, **kwargs)
            except Exception:
722
                errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
723

724 725 726 727 728 729 730 731
    def after_extra_networks_activate(self, p, **kwargs):
        """Notify each always-on script via ``after_extra_networks_activate``.

        The call receives ``p``, the script's own ``p.script_args`` window and
        every keyword argument forwarded unchanged. A failing script is
        reported through ``errors.report``; the others still run.
        """
        for alwayson in self.alwayson_scripts:
            try:
                start, stop = alwayson.args_from, alwayson.args_to
                window = p.script_args[start:stop]
                alwayson.after_extra_networks_activate(p, *window, **kwargs)
            except Exception:
                errors.report(f"Error running after_extra_networks_activate: {alwayson.filename}", exc_info=True)

732
    def process_batch(self, p, **kwargs):
        """Invoke ``process_batch`` on every always-on script.

        Each script receives ``p``, its own slice of ``p.script_args`` and all
        keyword arguments given here. Exceptions are reported via
        ``errors.report`` and do not interrupt the remaining scripts.
        """
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process_batch(p, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
A
Artem Zagidulin 已提交
739

A
AUTOMATIC 已提交
740 741 742 743 744 745
    def postprocess(self, p, processed):
        """Invoke ``postprocess`` on every always-on script.

        Each script is called as ``postprocess(p, processed, *own_args)``. Here
        ``own_args`` is its slice of ``p.script_args``. Exceptions are reported
        via ``errors.report`` and the loop continues.
        """
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
747

748 749 750 751 752 753
    def postprocess_batch(self, p, images, **kwargs):
        """Invoke ``postprocess_batch`` on every always-on script.

        Each script receives ``p``, its own slice of ``p.script_args``, the
        ``images`` keyword argument and any extra keyword arguments given
        here. Exceptions are reported via ``errors.report`` without stopping
        the other scripts.
        """
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch(p, *script_args, images=images, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
755

L
ljleb 已提交
756 757 758 759 760 761 762 763
    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch_list(p, pp, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

764 765 766 767 768 769
    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_image(p, pp, *script_args)
            except Exception:
770
                errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
771

772
    def before_component(self, component, **kwargs):
        """Run the "before component" hooks for a component about to be created.

        First, callbacks registered in ``self.on_before_component_elem_id`` under
        ``kwargs["elem_id"]`` are called with an ``OnComponent`` wrapper. Then
        every script's ``before_component(component, **kwargs)`` is called.
        Exceptions in either phase are reported via ``errors.report`` and do
        not stop later hooks.
        """
        # The elem_id comes from kwargs here, unlike after_component, which
        # reads it from the component itself.
        for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
            try:
                callback(OnComponent(component=component))
            except Exception:
                errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)

        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running before_component: {script.filename}", exc_info=True)
784 785

    def after_component(self, component, **kwargs):
        """Run the "after component" hooks for a component that was just created.

        First, callbacks registered in ``self.on_after_component_elem_id`` under
        ``component.elem_id`` are called with an ``OnComponent`` wrapper. Then
        every script's ``after_component(component, **kwargs)`` is called.
        Exceptions in either phase are reported via ``errors.report`` and do
        not stop later hooks.
        """
        for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):
            try:
                callback(OnComponent(component=component))
            except Exception:
                errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)

        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running after_component: {script.filename}", exc_info=True)
797

798 799 800
    def script(self, title):
        """Return the script registered under ``title``, or None if there is none.

        The title is lower-cased before the lookup. NOTE(review): this assumes
        ``self.title_map`` is keyed by lower-cased titles; confirm where it is
        built.
        """
        normalized = title.lower()
        return self.title_map.get(normalized)

A
AUTOMATIC 已提交
801
    def reload_sources(self, cache):
        """Replace every entry of ``self.scripts`` with a fresh instance from reloaded source.

        ``cache`` maps a script filename to its loaded module, so each file is
        loaded at most once per cache. For example, ``reload_script_body_only``
        shares one cache between the txt2img and img2img runners. The new
        instance inherits the old instance's ``filename``, ``args_from`` and
        ``args_to``.

        If the module defines several ``Script`` subclasses, the class whose
        name matches the current instance's class is used. Otherwise the last
        ``Script`` subclass found in the module is used, as before. Entries
        whose module defines no ``Script`` subclass are left unchanged.

        NOTE(review): only ``self.scripts`` is updated here; other lists that
        hold script instances (e.g. ``alwayson_scripts``,
        ``selectable_scripts``) still reference the old objects. Confirm
        whether that is intended.
        """
        for si, script in enumerate(self.scripts):
            args_from = script.args_from
            args_to = script.args_to
            filename = script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(script.filename)
                cache[filename] = module

            candidates = [
                script_class
                for script_class in module.__dict__.values()
                if type(script_class) == type and issubclass(script_class, Script)
            ]
            if not candidates:
                continue

            # A file may define several Script subclasses. Previously every slot
            # loaded from such a file ended up as the last class found; prefer
            # the class this slot was originally created from.
            old_class_name = type(script).__name__
            same_name = [c for c in candidates if c.__name__ == old_class_name]
            script_class = same_name[-1] if same_name else candidates[-1]

            new_script = script_class()
            new_script.filename = filename
            new_script.args_from = args_from
            new_script.args_to = args_to
            self.scripts[si] = new_script
A
AUTOMATIC 已提交
818

819 820 821 822 823 824 825 826
    def before_hr(self, p):
        """Call ``before_hr`` on each always-on script.

        Each script receives ``p`` and the portion of ``p.script_args`` between
        its ``args_from`` and ``args_to``. Failures are reported through
        ``errors.report`` and never abort the loop.
        """
        for hook_owner in self.alwayson_scripts:
            try:
                span = slice(hook_owner.args_from, hook_owner.args_to)
                args_for_owner = p.script_args[span]
                hook_owner.before_hr(p, *args_for_owner)
            except Exception:
                errors.report(f"Error running before_hr: {hook_owner.filename}", exc_info=True)

A
AUTOMATIC1111 已提交
827
    def setup_scrips(self, p, *, is_ui=True):
        """Call ``setup`` on every always-on script.

        The name is misspelled ("scrips"); it is kept as-is so existing callers
        keep working.

        Args:
            p: object whose ``script_args`` is sliced per script
                (``args_from``..``args_to``); it is passed to ``setup`` first.
            is_ui: when False, scripts whose ``setup_for_ui_only`` is true are
                skipped.

        Exceptions from a script's ``setup`` are reported via ``errors.report``
        and do not stop the remaining scripts.
        """
        for script in self.alwayson_scripts:
            if not is_ui and script.setup_for_ui_only:
                continue

            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.setup(p, *script_args)
            except Exception:
                errors.report(f"Error running setup: {script.filename}", exc_info=True)

838

839 840 841
# Module-level runner instances, initialised to None here.
# NOTE(review): presumably assigned when scripts are loaded; confirm where they are set.
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
scripts_current: ScriptRunner = None
D
DepFA 已提交
843

844

D
DepFA 已提交
845
def reload_script_body_only():
    """Reload the source of every script held by the txt2img and img2img runners.

    Both runners share one module cache, so a script file used by both is
    loaded only once. Postprocessing scripts (``scripts_postproc``) are not
    reloaded here.
    """
    cache = {}
    scripts_txt2img.reload_sources(cache)
    scripts_img2img.reload_sources(cache)
D
DepFA 已提交
849

D
DepFA 已提交
850

851
# Keeps the old name working: reload_scripts is the very same function object
# as load_scripts (which must already be bound at this point in the module).
reload_scripts = load_scripts  # compatibility alias