scripts.py 19.7 KB
Newer Older
1
import os
2
import re
3 4
import sys
import traceback
5
from collections import namedtuple
6 7 8

import gradio as gr

9
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
10 11 12

# Sentinel object: a Script.show() implementation returns it to ask for the script's UI to be shown at all times
# (ScriptRunner.initialize_scripts compares show()'s result against it).
AlwaysVisible = object()

A
AUTOMATIC 已提交
13

14 15 16 17 18
class PostprocessImageArgs:
    """Argument object handed to the postprocess_image() script hook; carries a single image."""

    def __init__(self, image):
        # kept as a plain attribute, so whoever receives this object can read it or assign a new value
        self.image = image


19 20
class Script:
    """Base class for scripts; subclasses override title(), ui(), show() and the processing hooks they need.

    Every hook defined here is a no-op (or returns a neutral value), so a subclass only has to
    implement the ones it cares about. title() is the exception: it must be overridden.
    """

    filename = None
    """path of the file the script was loaded from; assigned by ScriptRunner.initialize_scripts()"""

    args_from = None
    args_to = None
    """args_from/args_to are the bounds of this script's slice in the flat list of UI input values;
    assigned by ScriptRunner.setup_ui()
    """

    alwayson = False
    """set to True by ScriptRunner when show() returned AlwaysVisible"""

    is_txt2img = False
    is_img2img = False

    group = None
    """A gr.Group component that has all script's UI inside it"""

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    """

    paste_field_names = None
    """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
    various "Send to <X>" buttons when clicked
    """

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """

    def show(self, is_img2img):
        """
        is_img2img is True if this function is called for the img2img interface, and False otherwise

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - script.AlwaysVisible if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

    def before_process_batch(self, p, *args, **kwargs):
        """
        Called before extra networks are parsed from the prompt, so you can add
        new extra network keywords to the prompt with this callback.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
          - images - torch tensor with all generated images, with values ranging from 0 to 1;
        """

    def postprocess_image(self, p, pp: "PostprocessImageArgs", *args):
        """
        Called for every image after it has been generated.
        pp wraps the image (pp.image); args contains all values returned by components from ui()
        """

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as before_component().
        """

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id

        The tab part ("txt2img_" / "img2img_") is only included when show() gives the same answer for
        both tabs; a script whose show() differs between tabs gets no tab prefix.
        """

        need_tabname = self.show(True) == self.show(False)
        # the tab is called txt2img everywhere else in this module; the previous 'txt2txt' spelling was a typo
        tabkind = 'img2img' if self.is_img2img else 'txt2img'
        tabname = f"{tabkind}_" if need_tabname else ""

        # whitespace becomes "_", then everything outside [a-z_0-9] is dropped
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'

171

172 173 174 175 176 177 178 179 180 181 182 183
# Value returned by basedir(). load_scripts() points it at each script file's basedir while that file
# is being imported and resets it to paths.script_path afterwards.
current_basedir = paths.script_path


def basedir():
    """Return the base directory that applies to the script currently being loaded.

    A script living in the main scripts directory gets the main directory (the one where webui.py
    resides). A script shipped by an extension (e.g. extensions/aesthetic/script/aesthetic.py) gets
    that extension's own directory (extensions/aesthetic).
    """
    return current_basedir


# A script file found on disk: the base directory it belongs to, its bare file name, and its full path.
ScriptFile = namedtuple("ScriptFile", "basedir filename path")

# Registries of discovered script classes; cleared and refilled by load_scripts().
scripts_data = []
postprocessing_scripts_data = []

# A discovered script class together with the file path, base directory and module it came from.
ScriptClassData = namedtuple("ScriptClassData", "script_class path basedir module")
188 189 190 191 192 193 194 195 196 197


def list_scripts(scriptdirname, extension):
    """Collect ScriptFile entries for files with the given extension.

    Looks in <paths.script_path>/<scriptdirname> first (entries sorted by file name), then appends
    whatever every active extension reports through its list_files(scriptdirname, extension).
    Only existing regular files whose lower-cased suffix equals `extension` are returned.
    """
    found = []

    builtin_dir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(builtin_dir):
        found.extend(
            ScriptFile(paths.script_path, name, os.path.join(builtin_dir, name))
            for name in sorted(os.listdir(builtin_dir))
        )

    for ext in extensions.active():
        found += ext.list_files(scriptdirname, extension)

    def is_wanted(candidate):
        # the suffix comparison is done on the lower-cased suffix of the path
        suffix = os.path.splitext(candidate.path)[1].lower()
        return suffix == extension and os.path.isfile(candidate.path)

    return [candidate for candidate in found if is_wanted(candidate)]
204

205

A
AUTOMATIC 已提交
206 207 208
def list_files_with_name(filename):
    """Return the paths of all files called `filename` located directly in the main script
    directory or in the directory of any active extension (main directory first)."""
    search_roots = [paths.script_path]
    search_roots.extend(ext.path for ext in extensions.active())

    matches = []
    for root in search_roots:
        # roots that are not directories are skipped
        if os.path.isdir(root):
            candidate = os.path.join(root, filename)
            if os.path.isfile(candidate):
                matches.append(candidate)

    return matches


222 223 224
def load_scripts():
    """Discover and import every script file, filling scripts_data and postprocessing_scripts_data.

    Both registries and the registered script callbacks are cleared first. A failure while loading
    one file is printed to stderr and does not stop the remaining files from being loaded.
    """
    global current_basedir

    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py")

    # remembered so that sys.path can be put back after every file
    syspath = sys.path

    def register_scripts_from_module(module):
        # `scriptfile` is the loop variable of the loading loop below
        for candidate in module.__dict__.values():
            if type(candidate) != type:
                continue

            if issubclass(candidate, Script):
                registry = scripts_data
            elif issubclass(candidate, scripts_postprocessing.ScriptPostprocessing):
                registry = postprocessing_scripts_data
            else:
                continue

            registry.append(ScriptClassData(candidate, scriptfile.path, scriptfile.basedir, module))

    def orderby(script_basedir):
        # 1st webui, 2nd extensions-builtin, 3rd extensions
        # the more specific prefix has to be tested first because it also starts with script_path
        ranked_prefixes = (
            (os.path.join(paths.script_path, "extensions-builtin"), 1),
            (paths.script_path, 0),
        )
        for prefix, rank in ranked_prefixes:
            if script_basedir.startswith(prefix):
                return rank
        return 9999

    def load_order(entry):
        return [orderby(entry.basedir), entry]

    for scriptfile in sorted(scripts_list, key=load_order):
        try:
            # scripts outside the main directory get their own basedir in front of sys.path
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            register_scripts_from_module(script_loading.load_module(scriptfile.path))

        except Exception:
            print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            sys.path = syspath
            current_basedir = paths.script_path

267 268 269

def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call func(*args, **kwargs) and return its result.

    If the call raises an Exception, a message naming `filename`/`funcname` and the traceback are
    printed to stderr and `default` is returned instead.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return default


A
AUTOMATIC 已提交
279 280 281
class ScriptRunner:
    """Owns the Script instances of one UI tab: creates them, builds their UI and dispatches hooks.

    `scripts` holds every visible script; each of them is additionally listed either in
    `alwayson_scripts` (show() returned AlwaysVisible) or in `selectable_scripts` (offered in the
    "Script" dropdown).
    """

    def __init__(self):
        self.scripts = []
        self.selectable_scripts = []
        self.alwayson_scripts = []
        self.titles = []
        self.infotext_fields = []
        self.paste_field_names = []
        # rolling counter used by the infotext visibility handler created in setup_ui()
        self.script_load_ctr = 0

    def initialize_scripts(self, is_img2img):
        """Instantiate every registered script class for the given tab and sort the instances
        into always-on and selectable scripts according to what their show() returns."""
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
            script = script_class()
            script.filename = path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

            visibility = script.show(script.is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def setup_ui(self):
        """Create the gradio UI of all scripts plus the "Script" dropdown.

        Returns the flat list of input components: index 0 is the dropdown, followed by the
        components of every script. Each script's args_from/args_to record its slice of that list.
        """
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        # slot 0 is reserved for the dropdown, which is created after the always-on scripts
        inputs = [None]
        inputs_alwayson = [True]

        def create_script_ui(script, inputs, inputs_alwayson):
            # start with an empty slice; it stays empty if ui() fails or returns nothing
            script.args_from = len(inputs)
            script.args_to = len(inputs)

            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

            if controls is None:
                return

            for control in controls:
                control.custom_script_source = os.path.basename(script.filename)

            if script.infotext_fields is not None:
                self.infotext_fields += script.infotext_fields

            if script.paste_field_names is not None:
                self.paste_field_names += script.paste_field_names

            inputs += controls
            inputs_alwayson += [script.alwayson for _ in controls]
            script.args_to = len(inputs)

        for script in self.alwayson_scripts:
            with gr.Group() as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        inputs[0] = dropdown

        for script in self.selectable_scripts:
            with gr.Group(visible=False) as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        def select_script(script_index):
            # dropdown index 0 is "None"; selectable scripts start at index 1
            selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            if title == 'None':
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        self.script_load_ctr = 0

        def onload_script_visibility(params):
            # The same handler is registered once per selectable script group (see below); the
            # counter advances on every call so that call number i decides the visibility of group i.
            # NOTE(review): this relies on the handlers being invoked in registration order - confirm
            # against the code that applies infotext_fields.
            title = params.get('Script', None)

            # no script in the infotext, or a script that is not available here: hide the group
            # instead of failing with ValueError on the index lookup
            if not title or title not in self.titles:
                return gr.update(visible=False)

            title_index = self.titles.index(title)
            visibility = title_index == self.script_load_ctr
            self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
            return gr.update(visible=visibility)

        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])

        return inputs

    def run(self, p, *args):
        """Run the script selected in the dropdown.

        args is the flat list of UI values; args[0] is the dropdown index (0 means "None").
        Returns what the script's run() returned, or None when no script is selected.
        """
        script_index = args[0]

        if script_index == 0:
            return None

        script = self.selectable_scripts[script_index - 1]

        if script is None:
            return None

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def _run_alwayson_hook(self, hook_name, p, leading=(), kwargs=None):
        """Call the hook called `hook_name` on every always-on script as
        hook(p, *leading, *script_args, **kwargs), where script_args is the script's slice of
        p.script_args. An exception raised for one script is printed to stderr and the remaining
        scripts still run."""
        kwargs = {} if kwargs is None else kwargs

        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                getattr(script, hook_name)(p, *leading, *script_args, **kwargs)
            except Exception:
                print(f"Error running {hook_name}: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def _run_component_hook(self, hook_name, component, kwargs):
        """Call hook(component, **kwargs) on every script (always-on and selectable); errors are
        printed to stderr per script."""
        for script in self.scripts:
            try:
                getattr(script, hook_name)(component, **kwargs)
            except Exception:
                print(f"Error running {hook_name}: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def process(self, p):
        """Dispatch Script.process() to all always-on scripts."""
        self._run_alwayson_hook("process", p)

    def before_process_batch(self, p, **kwargs):
        """Dispatch Script.before_process_batch() to all always-on scripts."""
        self._run_alwayson_hook("before_process_batch", p, kwargs=kwargs)

    def process_batch(self, p, **kwargs):
        """Dispatch Script.process_batch() to all always-on scripts."""
        self._run_alwayson_hook("process_batch", p, kwargs=kwargs)

    def postprocess(self, p, processed):
        """Dispatch Script.postprocess() to all always-on scripts."""
        self._run_alwayson_hook("postprocess", p, leading=(processed,))

    def postprocess_batch(self, p, images, **kwargs):
        """Dispatch Script.postprocess_batch() to all always-on scripts; `images` is passed on as
        a keyword argument."""
        self._run_alwayson_hook("postprocess_batch", p, kwargs=dict(kwargs, images=images))

    def postprocess_image(self, p, pp: "PostprocessImageArgs"):
        """Dispatch Script.postprocess_image() to all always-on scripts."""
        self._run_alwayson_hook("postprocess_image", p, leading=(pp,))

    def before_component(self, component, **kwargs):
        """Dispatch Script.before_component() to all scripts."""
        self._run_component_hook("before_component", component, kwargs)

    def after_component(self, component, **kwargs):
        """Dispatch Script.after_component() to all scripts."""
        self._run_component_hook("after_component", component, kwargs)

    def reload_sources(self, cache):
        """Re-import the source file of every script and replace the instances in self.scripts.

        `cache` maps file name -> loaded module, so a file shared by several runners is imported
        only once. Only filename, args_from and args_to are carried over to the new instance.
        NOTE(review): alwayson_scripts / selectable_scripts are not updated here and keep
        referring to the old instances - confirm this is intended.
        """
        for si, script in list(enumerate(self.scripts)):
            args_from = script.args_from
            args_to = script.args_to
            filename = script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(script.filename)
                cache[filename] = module

            # if the module defines several Script subclasses, the last one found wins
            for script_class in module.__dict__.values():
                if type(script_class) == type and issubclass(script_class, Script):
                    self.scripts[si] = script_class()
                    self.scripts[si].filename = filename
                    self.scripts[si].args_from = args_from
                    self.scripts[si].args_to = args_to
A
AUTOMATIC 已提交
500

501

A
AUTOMATIC 已提交
502 503
# One runner per generation tab; reload_scripts() replaces both with fresh instances.
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
504
# Runner for postprocessing scripts; also recreated by reload_scripts().
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
505
# Runner whose before_component/after_component hooks IOComponent_init invokes; while it is None those
# calls are skipped. Nothing in this module assigns it - presumably set by the UI-building code; verify.
scripts_current: ScriptRunner = None
D
DepFA 已提交
506

507

D
DepFA 已提交
508
def reload_script_body_only():
    """Re-import the source of the scripts held by the txt2img and img2img runners.

    Both runners share one module cache, so a file used by both is loaded only once.
    """
    loaded_modules = {}

    scripts_txt2img.reload_sources(loaded_modules)
    scripts_img2img.reload_sources(loaded_modules)
D
DepFA 已提交
512

D
DepFA 已提交
513

514
def reload_scripts():
    """Run load_scripts() again and replace the three module-level runners with new, empty ones."""
    global scripts_txt2img, scripts_img2img, scripts_postproc

    load_scripts()

    scripts_txt2img, scripts_img2img = ScriptRunner(), ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
522

523

524 525 526 527 528 529 530 531 532 533 534 535
def add_classes_to_gradio_component(comp):
    """
    Put a "gradio-<block name>" class in front of the component's element classes for css styling
    (ie gradio-button for gr.Button), and add "multiselect" for components with a truthy
    `multiselect` attribute.
    """

    block_class = "gradio-" + comp.get_block_name()

    classes = [block_class]
    classes.extend(comp.elem_classes or [])
    comp.elem_classes = classes

    is_multiselect = getattr(comp, 'multiselect', False)
    if is_multiselect:
        comp.elem_classes.append('multiselect')



536 537 538 539 540 541 542 543
def IOComponent_init(self, *args, **kwargs):
    """Replacement for gr.components.IOComponent.__init__ that wraps the original constructor
    with the before/after component hooks and adds the gradio-* css classes.

    Hook order: current runner's before_component, global before_component callback, original
    __init__, css classes, global after_component callback, current runner's after_component.
    """
    # hooks that run before the component exists
    if scripts_current is not None:
        scripts_current.before_component(self, **kwargs)
    script_callbacks.before_component_callback(self, **kwargs)

    init_result = original_IOComponent_init(self, *args, **kwargs)

    add_classes_to_gradio_component(self)

    # hooks that run once the component has been created, in the reverse order of the ones above
    script_callbacks.after_component_callback(self, **kwargs)
    if scripts_current is not None:
        scripts_current.after_component(self, **kwargs)

    return init_result


# Keep a reference to gradio's own constructor (IOComponent_init calls it), then install the wrapper in its place.
original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init