scripts.py 23.4 KB
Newer Older
1
import os
2
import re
3
import sys
H
huchenlei 已提交
4
import inspect
5
from collections import namedtuple
6 7 8

import gradio as gr

9
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
10 11 12

# Sentinel a Script.show() implementation returns to ask for its UI to be displayed at
# all times (an "always-on" script) instead of only when chosen in the scripts dropdown.
AlwaysVisible: object = object()

A
AUTOMATIC 已提交
13

14 15 16 17 18
class PostprocessImageArgs:
    """Wrapper around the single image handed to Script.postprocess_image().

    The image is held as an attribute so that a script can reassign ``image``.
    """

    def __init__(self, image):
        self.image = image


L
ljleb 已提交
19 20 21 22 23
class PostprocessBatchListArgs:
    """Wrapper around the list of batch images handed to Script.postprocess_batch_list().

    Scripts may edit ``images`` (replace, remove or add entries) to change the batch.
    """

    def __init__(self, images):
        self.images = images


24
class Script:
    """Base class for webui scripts.

    A ScriptRunner instantiates each subclass once per tab (txt2img / img2img) and sorts it
    by the result of show():
      - AlwaysVisible: an "always-on" script; its UI is always displayed and the runner
        forwards the before_process/process/.../postprocess hooks to it;
      - any other truthy value: a "selectable" script, listed in the "Script" dropdown;
        run() is called when it is the selected entry.
    """

    name = None
    """script's internal name derived from title"""

    section = None
    """name of UI section that the script's controls will be placed into"""

    # path of the file this script's class was loaded from (assigned by ScriptRunner)
    filename = None
    # [args_from:args_to] is this script's slice of the runner's UI inputs / p.script_args
    args_from = None
    args_to = None
    # set to True by ScriptRunner when show() returned AlwaysVisible
    alwayson = False

    # which tab this instance was created for (assigned by ScriptRunner)
    is_txt2img = False
    is_img2img = False

    group = None
    """A gr.Group component that has all script's UI inside it."""

    create_group = True
    """Intended for always-on scripts: if False, no gr.Group is created around the script's UI
    and `group` stays None. Selectable scripts need the group, because the "Script" dropdown
    toggles their visibility through it."""

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    """

    paste_field_names = None
    """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
    various "Send to <X>" buttons when clicked
    """

    api_info = None
    """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """

        pass

    def show(self, is_img2img):
        """
        is_img2img is True if this function is called for the img2img interface, and False otherwise

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - scripts.AlwaysVisible if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        pass

    def before_process(self, p, *args):
        """
        This function is called very early before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def before_process_batch(self, p, *args, **kwargs):
        """
        Called before extra networks are parsed from the prompt, so you can add
        new extra network keywords to the prompt with this callback.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def after_extra_networks_activate(self, p, *args, **kwargs):
        """
        Called after extra networks activation, before conds calculation.
        Allows modification of the network after extra networks activation has been applied.
        Won't be called if p.disable_extra_networks is set.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
          - extra_network_data - list of ExtraNetworkParams for current stage
        """
        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have the same items as process_batch (including batch_number), and also:
          - images - torch tensor with all generated images, with values ranging from 0 to 1;
        """

        pass

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
        """
        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
        This is useful when you want to update the entire batch instead of individual images.

        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
        If the number of images is different from the batch size when returning,
        then the script has the responsibility to also update the following attributes in the processing object (p):
          - p.prompts
          - p.negative_prompts
          - p.seeds
          - p.subseeds

        **kwargs will have the same items as process_batch, including:
          - batch_number - index of current batch, from 0 to number of batches-1
        """

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
        args contains all values returned by components from ui()
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as above.
        """

        pass

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id

        The tab prefix is only added when show() gives the same answer for both tabs, i.e. when the
        script can appear on both tabs and the ids would otherwise collide.
        """

        need_tabname = self.show(True) == self.show(False)
        # NOTE(review): 'txt2txt' (sic) is left unchanged on purpose: fixing it would change every
        # generated id, which saved UI settings or extension CSS/JS may reference — verify before changing.
        tabkind = 'img2img' if self.is_img2img else 'txt2txt'
        tabname = f"{tabkind}_" if need_tabname else ""
        # whitespace -> '_', then drop everything that is not [a-z_0-9]
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'

    def before_hr(self, p, *args):
        """
        This function is called before the hires fix pass starts, for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """
        pass
237

238

239 240 241 242 243 244 245 246 247 248 249 250
# Base directory of the script file currently being imported. load_scripts() points it at
# each file's basedir while importing that file and resets it to the webui root afterwards.
current_basedir = paths.script_path


def basedir():
    """Return the base directory of the script currently being loaded.

    For files in the main scripts directory this is the webui root (where webui.py resides);
    for files shipped with an extension (e.g. extensions/aesthetic/script/aesthetic.py) it is
    that extension's directory (extensions/aesthetic).
    """
    directory = current_basedir
    return directory


# A script file found on disk: the directory it belongs to, its bare file name and its full path.
ScriptFile = namedtuple("ScriptFile", "basedir filename path")

# Registries rebuilt by load_scripts(): Script subclasses, and ScriptPostprocessing subclasses.
scripts_data = []
postprocessing_scripts_data = []

# A discovered script class together with the file, base directory and module it came from.
ScriptClassData = namedtuple("ScriptClassData", "script_class path basedir module")
255 256


257
def list_scripts(scriptdirname, extension, *, include_extensions=True):
    """Collect the script files that live in ``scriptdirname``.

    Files from ``<webui root>/<scriptdirname>`` come first (sorted by name), followed, unless
    include_extensions is False, by whatever each active extension's list_files() reports.
    Only regular files whose lower-cased suffix equals ``extension`` are returned.

    Returns a list of ScriptFile tuples.
    """
    found = []

    core_dir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(core_dir):
        found.extend(
            ScriptFile(paths.script_path, name, os.path.join(core_dir, name))
            for name in sorted(os.listdir(core_dir))
        )

    if include_extensions:
        for ext in extensions.active():
            found += ext.list_files(scriptdirname, extension)

    def is_wanted(script_file):
        suffix = os.path.splitext(script_file.path)[1]
        return suffix.lower() == extension and os.path.isfile(script_file.path)

    return [script_file for script_file in found if is_wanted(script_file)]
272

273

A
AUTOMATIC 已提交
274 275 276
def list_files_with_name(filename):
    """Return every existing ``<dir>/<filename>`` where dir is the webui root or an active extension.

    Directories that don't exist and locations where ``filename`` is not a regular file are skipped.
    The webui root comes first, then extensions in extensions.active() order.
    """
    search_dirs = [paths.script_path, *(ext.path for ext in extensions.active())]

    candidates = (os.path.join(directory, filename) for directory in search_dirs if os.path.isdir(directory))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]


290 291 292
def load_scripts():
    """(Re)discover and import every script file, then rebuild the global script runners.

    Clears scripts_data, postprocessing_scripts_data and all script callbacks, then imports each
    file returned by list_scripts() in this order: webui scripts, extensions-builtin, other
    extensions. Every Script / ScriptPostprocessing subclass found in an imported module is
    registered. Errors while importing or registering a file are reported via errors.report and
    loading continues with the next file. Finally fresh scripts_txt2img / scripts_img2img /
    scripts_postproc runners are created.
    """
    global current_basedir, scripts_txt2img, scripts_img2img, scripts_postproc

    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)

    syspath = sys.path

    def register_scripts_from_module(module, scriptfile):
        """Record every Script / ScriptPostprocessing subclass found in the module's namespace."""
        for script_class in module.__dict__.values():
            if not inspect.isclass(script_class):
                continue

            if issubclass(script_class, Script):
                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

    def normalized(path):
        return os.path.normcase(os.path.abspath(path))

    webui_dir = normalized(paths.script_path)
    builtin_dir = normalized(os.path.join(paths.script_path, "extensions-builtin"))

    def orderby(basedir):
        """Load priority: 1st webui, 2nd extensions-builtin, 3rd extensions."""
        # Compare whole path components. A plain `basedir.startswith(paths.script_path)` also matched
        # <webui>/extensions/*, which put regular extensions on par with webui scripts and ahead of
        # extensions-builtin (and matched sibling dirs sharing the prefix, e.g. "<webui>-old").
        path = normalized(basedir)
        if path == webui_dir:
            return 0
        if path == builtin_dir or path.startswith(builtin_dir + os.sep):
            return 1
        return 9999

    for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
        try:
            # let extension scripts import modules from their own directory
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            script_module = script_loading.load_module(scriptfile.path)
            register_scripts_from_module(script_module, scriptfile)

        except Exception:
            errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)

        finally:
            # undo the per-file sys.path / basedir changes before the next file
            sys.path = syspath
            current_basedir = paths.script_path
            timer.startup_timer.record(scriptfile.filename)

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()

341 342 343

def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call ``func(*args, **kwargs)`` and return its result.

    If the call raises an Exception, it is reported (tagged ``filename/funcname``) through
    errors.report and ``default`` is returned instead of propagating the error.
    """
    try:
        result = func(*args, **kwargs)
    except Exception:
        errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
        result = default
    return result


A
AUTOMATIC 已提交
351 352 353
class ScriptRunner:
    """Owns the Script instances of one tab (txt2img or img2img): builds their UI and forwards
    processing hooks to them."""

    def __init__(self):
        self.scripts = []  # every script shown on this tab (always-on and selectable)
        self.selectable_scripts = []  # scripts chosen through the "Script" dropdown
        self.alwayson_scripts = []  # scripts whose show() returned AlwaysVisible
        self.titles = []  # titles of selectable_scripts, same order (without the "None" entry)
        self.infotext_fields = []
        self.paste_field_names = []
        self.inputs = [None]  # inputs[0] is reserved for the script dropdown (set in setup_ui)

    def initialize_scripts(self, is_img2img):
        """Instantiate every registered script for this tab and sort it into the always-on / selectable lists."""
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_data in auto_processing_scripts + scripts_data:
            script = script_data.script_class()
            script.filename = script_data.path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

            visibility = script.show(script.is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def create_script_ui(self, script):
        """Build the script's UI, append its components to self.inputs and record the script's
        args slice and API description.

        If ui() returns None (or raises), nothing is added and args_from == args_to.
        """
        import modules.api.models as api_models

        script.args_from = len(self.inputs)
        script.args_to = len(self.inputs)

        controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

        if controls is None:
            return

        script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
        api_args = []

        for control in controls:
            control.custom_script_source = os.path.basename(script.filename)

            arg_info = api_models.ScriptArg(label=control.label or "")

            for field in ("value", "minimum", "maximum", "step", "choices"):
                v = getattr(control, field, None)
                if v is not None:
                    setattr(arg_info, field, v)

            api_args.append(arg_info)

        script.api_info = api_models.ScriptInfo(
            name=script.name,
            is_img2img=script.is_img2img,
            is_alwayson=script.alwayson,
            args=api_args,
        )

        if script.infotext_fields is not None:
            self.infotext_fields += script.infotext_fields

        if script.paste_field_names is not None:
            self.paste_field_names += script.paste_field_names

        self.inputs += controls
        script.args_to = len(self.inputs)

    def setup_ui_for_section(self, section, scriptlist=None):
        """Create the UI of the scripts placed in ``section`` (default list: always-on scripts).

        Selectable scripts are not filtered by section. Each script's UI is wrapped in a gr.Group
        (initially visible only for always-on scripts) unless the script sets create_group = False.
        """
        if scriptlist is None:
            scriptlist = self.alwayson_scripts

        for script in scriptlist:
            if script.alwayson and script.section != section:
                continue

            if script.create_group:
                with gr.Group(visible=script.alwayson) as group:
                    self.create_script_ui(script)

                script.group = group
            else:
                self.create_script_ui(script)

    def prepare_ui(self):
        """Reset the collected inputs, keeping slot 0 for the script dropdown."""
        self.inputs = [None]

    def setup_ui(self):
        """Build the UI of section-less always-on scripts, the "Script" dropdown and the selectable
        scripts' UIs, and wire up dropdown visibility handling. Returns self.inputs."""
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        self.setup_ui_for_section(None)

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        self.inputs[0] = dropdown

        self.setup_ui_for_section(None, self.selectable_scripts)

        def select_script(script_index):
            # index 0 is the "None" entry; a missing index (None) is treated the same way
            selected_script = self.selectable_scripts[script_index - 1] if script_index is not None and script_index > 0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            if title == 'None':
                return

            # the saved title may belong to a script that is no longer installed
            if title not in self.titles:
                return

            script = self.selectable_scripts[self.titles.index(title)]
            if script.group is not None:
                script.group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        self.script_load_ctr = 0

        def onload_script_visibility(params):
            # Registered once per selectable script group (see infotext_fields.extend below); the
            # counter tracks which group is being asked — presumably called in selectable_scripts order.
            title = params.get('Script', None)
            # A title that is not one of ours (e.g. infotext from an install with other scripts)
            # is treated like no title: hide the group and leave the counter untouched.
            if title and title in self.titles:
                title_index = self.titles.index(title)
                visibility = title_index == self.script_load_ctr
                self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
                return gr.update(visible=visibility)
            else:
                return gr.update(visible=False)

        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])

        return self.inputs

    def run(self, p, *args):
        """Run the script selected in the dropdown (args[0] is its index, 0 meaning "None").

        Returns the selected script's result, or None when no script is selected.
        """
        script_index = args[0]

        # also tolerate a missing index instead of failing on `None - 1`
        if script_index == 0 or script_index is None:
            return None

        script = self.selectable_scripts[script_index-1]

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def _run_alwayson(self, hook_name, p, *leading_args, **kwargs):
        """Call ``script.<hook_name>(p, *leading_args, *script_args, **kwargs)`` on every always-on
        script, where script_args is that script's slice of p.script_args. An exception from one
        script is reported and does not stop the others."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                getattr(script, hook_name)(p, *leading_args, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running {hook_name}: {script.filename}", exc_info=True)

    def before_process(self, p):
        self._run_alwayson("before_process", p)

    def process(self, p):
        self._run_alwayson("process", p)

    def before_process_batch(self, p, **kwargs):
        self._run_alwayson("before_process_batch", p, **kwargs)

    def after_extra_networks_activate(self, p, **kwargs):
        self._run_alwayson("after_extra_networks_activate", p, **kwargs)

    def process_batch(self, p, **kwargs):
        self._run_alwayson("process_batch", p, **kwargs)

    def postprocess(self, p, processed):
        self._run_alwayson("postprocess", p, processed)

    def postprocess_batch(self, p, images, **kwargs):
        self._run_alwayson("postprocess_batch", p, images=images, **kwargs)

    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
        self._run_alwayson("postprocess_batch_list", p, pp, **kwargs)

    def postprocess_image(self, p, pp: PostprocessImageArgs):
        self._run_alwayson("postprocess_image", p, pp)

    def before_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running before_component: {script.filename}", exc_info=True)

    def after_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running after_component: {script.filename}", exc_info=True)

    def reload_sources(self, cache):
        """Re-import each script's file and replace the script with a fresh instance of its class.

        The UI is not rebuilt, so the new instance inherits the attributes this runner assigned to
        the old one (args slice, tab flags, group, name, API info) and replaces it in every list
        it was registered in. ``cache`` maps filename -> loaded module and may be shared between
        runners so each file is imported only once.
        """
        inherited_attrs = ("filename", "args_from", "args_to", "is_txt2img", "is_img2img", "alwayson", "group", "name", "api_info")

        for si, script in list(enumerate(self.scripts)):
            filename = script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(filename)
                cache[filename] = module

            class_name = type(script).__name__
            for script_class in module.__dict__.values():
                # only the class this script was created from; other Script subclasses in the same
                # module (including ones merely imported into it) must not take over its slot
                if not (inspect.isclass(script_class) and issubclass(script_class, Script) and script_class.__name__ == class_name):
                    continue

                new_script = script_class()
                for attr in inherited_attrs:
                    setattr(new_script, attr, getattr(script, attr))

                self.scripts[si] = new_script
                # processing hooks iterate these lists, which hold the same instance objects
                for registry in (self.alwayson_scripts, self.selectable_scripts):
                    for i, registered in enumerate(registry):
                        if registered is script:
                            registry[i] = new_script
                break

    def before_hr(self, p):
        self._run_alwayson("before_hr", p)


629 630 631
# Per-tab script runners plus the postprocessing runner; all (re)created by load_scripts().
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None

# Not assigned in this file; presumably set by callers to the runner currently in use — TODO confirm.
scripts_current: ScriptRunner = None
D
DepFA 已提交
633

634

D
DepFA 已提交
635
def reload_script_body_only():
    """Re-import script sources for the txt2img and img2img runners without rebuilding the UI.

    Both runners share one module cache, so each script file is imported only once.
    """
    shared_cache = {}
    for runner in (scripts_txt2img, scripts_img2img):
        runner.reload_sources(shared_cache)


# compatibility alias
reload_scripts = load_scripts