scripts.py 17.5 KB
Newer Older
1
import os
2
import re
3 4
import sys
import traceback
5
from collections import namedtuple
6 7 8

import gradio as gr

9
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
10 11 12

# Sentinel that Script.show() returns to request that a script's UI be shown
# at all times rather than only when it is picked in the scripts dropdown.
# ScriptRunner.initialize_scripts() compares the result of show() against it.
AlwaysVisible = object()

A
AUTOMATIC 已提交
13

14 15 16 17 18
class PostprocessImageArgs:
    """Container handed to Script.postprocess_image() for every generated image.

    It wraps a single ``image`` attribute so that one mutable object can be
    passed to each script's callback in turn.
    """

    def __init__(self, image):
        # The image being post-processed; kept as a plain, writable attribute.
        self.image = image


19 20
class Script:
    """Base class for scripts loaded from the ``scripts`` directories.

    Subclasses override the hooks below. ScriptRunner creates one instance per
    tab for every subclass registered by load_scripts() and calls the hooks at
    the corresponding stages.
    """

    # Path of the file the script was loaded from; assigned by ScriptRunner.
    filename = None

    # Half-open range [args_from, args_to) of this script's component values in
    # the combined argument list; assigned by ScriptRunner.setup_ui().
    args_from = None
    args_to = None

    # True when show() returned AlwaysVisible for the current tab.
    alwayson = False

    # Which tab this instance belongs to; assigned by ScriptRunner.initialize_scripts().
    is_txt2img = False
    is_img2img = False

    # A gr.Group component that has all of the script's UI inside it.
    group = None

    # If set in ui(), this is a list of pairs of gradio component + text; the text
    # will be used when parsing infotext to set the value for the component; see
    # ui.py's txt2img_paste_fields for an example.
    infotext_fields = None

    def title(self):
        """Return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """Create gradio UI elements. See https://gradio.app/docs/#components

        The return value should be a list of all components that are used in
        processing. Values of those returned components will be passed to the
        run() and process() functions. Returning None means the script has no UI.
        """

        pass

    def show(self, is_img2img):
        """Decide whether and how the script is shown in the UI.

        is_img2img is True if this function is called for the img2img interface, and False otherwise.

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - AlwaysVisible (the sentinel defined in this module) if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        pass

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have the same items as for process_batch (including batch_number), and also:
          - images - torch tensor with all generated images, with values ranging from 0 to 1
        """

        pass

    def postprocess_image(self, p, pp: "PostprocessImageArgs", *args):
        """
        Called for every image after it has been generated; pp.image holds the image.
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as above.
        """

        pass

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """Build an id for an HTML element out of the script's title, the tab name and the user-supplied item_id.

        Non-alphanumeric characters are stripped from the lower-cased title, and
        whitespace becomes '_'. The tab prefix ('txt2img_'/'img2img_') is added
        only when show() answers the same for both tabs. In that case the
        script may be present on both tabs, and the prefix keeps the ids distinct.
        """

        need_tabname = self.show(True) == self.show(False)
        if need_tabname:
            # Was 'txt2txt' before: the txt2img tab is named 'txt2img'.
            tabname = ('img2img' if self.is_img2img else 'txt2img') + "_"
        else:
            tabname = ""
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'

152

153 154 155 156 157 158 159 160 161 162 163 164
# Value reported by basedir(). load_scripts() points it at the base directory of
# the script file being loaded and resets it to paths.script_path afterwards.
current_basedir = paths.script_path


def basedir():
    """Return the base directory of the script that is currently being loaded.

    For scripts in the main scripts directory this is the main directory (where
    webui.py resides). For scripts in an extension's scripts directory
    (e.g. extensions/aesthetic/scripts/aesthetic.py) it is the extension's
    directory (extensions/aesthetic).

    The value is only set while load_scripts() is executing the script's module.
    Outside of that window it is paths.script_path. Call this at module level
    of the script and store the result, instead of calling it later from
    ui()/run().
    """
    return current_basedir


# One script file found on disk: the directory it belongs to (main dir or an
# extension's dir), its bare file name, and its full path.
ScriptFile = namedtuple("ScriptFile", "basedir filename path")

# Registries filled by load_scripts(): Script subclasses, and
# scripts_postprocessing.ScriptPostprocessing subclasses, respectively.
scripts_data = list()
postprocessing_scripts_data = list()

# A discovered script class together with the file, base dir and module it came from.
ScriptClassData = namedtuple("ScriptClassData", "script_class path basedir module")
169 170 171 172 173 174 175 176 177 178


def list_scripts(scriptdirname, extension):
    """Collect ScriptFile entries from the main directory and all active extensions.

    Entries from ``<paths.script_path>/<scriptdirname>`` come first, in sorted
    file-name order. They are followed by whatever each active extension
    reports via ``list_files(scriptdirname, extension)``. Only regular files
    whose lower-cased extension equals ``extension`` are kept.
    """
    main_dir = os.path.join(paths.script_path, scriptdirname)

    candidates = []
    if os.path.exists(main_dir):
        candidates.extend(
            ScriptFile(paths.script_path, name, os.path.join(main_dir, name))
            for name in sorted(os.listdir(main_dir))
        )

    for ext in extensions.active():
        candidates += ext.list_files(scriptdirname, extension)

    def wanted(entry):
        return os.path.splitext(entry.path)[1].lower() == extension and os.path.isfile(entry.path)

    return [entry for entry in candidates if wanted(entry)]
185

186

A
AUTOMATIC 已提交
187 188 189
def list_files_with_name(filename):
    """Return the paths of all existing files named ``filename`` that sit
    directly in the main directory or in the root of an active extension.

    The main directory is checked first, then extensions in the order
    extensions.active() yields them. Directories that do not exist are skipped.
    """
    search_roots = [paths.script_path]
    search_roots.extend(ext.path for ext in extensions.active())

    found = []
    for root in search_roots:
        if os.path.isdir(root):
            candidate = os.path.join(root, filename)
            if os.path.isfile(candidate):
                found.append(candidate)

    return found


203 204 205
def load_scripts():
    """Discover and import all ``.py`` scripts, refilling the module-level registries.

    Clears scripts_data, postprocessing_scripts_data and the script callbacks.
    Then it imports every file from list_scripts("scripts", ".py") in sorted
    order. Classes in a loaded module whose exact type is ``type`` and that
    derive from Script or scripts_postprocessing.ScriptPostprocessing are
    recorded as ScriptClassData. A script that fails to load is reported on
    stderr and skipped.
    """
    global current_basedir

    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py")

    # Remember the original sys.path list object; it is re-bound after every script.
    saved_syspath = sys.path

    def register_scripts_from_module(module, source):
        for candidate in module.__dict__.values():
            # Only objects whose type is exactly `type`: instances and classes
            # built by a custom metaclass are skipped.
            if type(candidate) != type:
                continue

            if issubclass(candidate, Script):
                registry = scripts_data
            elif issubclass(candidate, scripts_postprocessing.ScriptPostprocessing):
                registry = postprocessing_scripts_data
            else:
                continue

            registry.append(ScriptClassData(candidate, source.path, source.basedir, module))

    for scriptfile in sorted(scripts_list):
        try:
            if scriptfile.basedir != paths.script_path:
                # Put the script's base directory first so its imports resolve from there.
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            script_module = script_loading.load_module(scriptfile.path)
            register_scripts_from_module(script_module, scriptfile)

        except Exception:
            print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            sys.path = saved_syspath
            current_basedir = paths.script_path

240 241 242

def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call ``func(*args, **kwargs)`` and shield the caller from its exceptions.

    On success the call's own result is returned (even if it is None). If it
    raises an Exception, the line "Error calling: <filename>/<funcname>" and the
    traceback are printed to stderr, and ``default`` is returned instead.
    """
    try:
        outcome = func(*args, **kwargs)
    except Exception:
        print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return default
    else:
        return outcome


A
AUTOMATIC 已提交
252 253 254
class ScriptRunner:
    """Holds the Script instances of one tab, builds their UI and dispatches the
    processing hooks to them."""

    def __init__(self):
        # Every script shown on this tab (always-visible and selectable ones).
        self.scripts = []
        # Scripts picked via the "Script" dropdown; entry i corresponds to dropdown index i + 1.
        self.selectable_scripts = []
        # Scripts whose show() returned AlwaysVisible; the process*/postprocess* hooks below iterate these.
        self.alwayson_scripts = []
        # Dropdown titles, parallel to selectable_scripts.
        self.titles = []
        # (component, value-or-function) pairs used when parsing pasted infotext.
        self.infotext_fields = []
        # Position of the next selectable-script group onload_script_visibility() answers for.
        self.script_load_ctr = 0

    def initialize_scripts(self, is_img2img):
        """Instantiate every registered script class for this tab and sort the
        instances into always-visible and selectable lists according to show()."""
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
            script = script_class()
            script.filename = path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

            visibility = script.show(script.is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def setup_ui(self):
        """Create the UI of every script and return the list of input components.

        inputs[0] is the "Script" dropdown (type="index", 0 meaning "None"). Each
        script's controls occupy inputs[script.args_from:script.args_to].
        """
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        inputs = [None]  # placeholder for the dropdown, filled in below
        inputs_alwayson = [True]

        def create_script_ui(script, inputs, inputs_alwayson):
            script.args_from = len(inputs)
            script.args_to = len(inputs)

            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

            if controls is None:
                return

            for control in controls:
                control.custom_script_source = os.path.basename(script.filename)

            if script.infotext_fields is not None:
                self.infotext_fields += script.infotext_fields

            inputs += controls
            inputs_alwayson += [script.alwayson for _ in controls]
            script.args_to = len(inputs)

        for script in self.alwayson_scripts:
            with gr.Group() as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        inputs[0] = dropdown

        for script in self.selectable_scripts:
            with gr.Group(visible=False) as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        def select_script(script_index):
            selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            if title == 'None':
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        self.script_load_ctr = 0

        def onload_script_visibility(params):
            # One entry per selectable group is registered below, so this is
            # presumably invoked once per group, in order; script_load_ctr tracks
            # which group the current call is for.
            title = params.get('Script', None)
            if title and title in self.titles:
                title_index = self.titles.index(title)
                visibility = title_index == self.script_load_ctr
                self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
                return gr.update(visible=visibility)
            else:
                # No script named, or one that is not installed on this tab:
                # hide the group instead of raising ValueError from list.index().
                return gr.update(visible=False)

        self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
        self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )

        return inputs

    def run(self, p, *args):
        """Run the script selected in the dropdown; args[0] is its dropdown index (0 = None).

        Returns whatever the script's run() returned, or None when no script is selected.
        """
        script_index = args[0]

        if script_index == 0:
            return None

        script = self.selectable_scripts[script_index-1]

        if script is None:
            return None

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def process(self, p):
        """Call process() of every always-visible script; errors are printed and skipped."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
                print(f"Error running process: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def process_batch(self, p, **kwargs):
        """Call process_batch() of every always-visible script; errors are printed and skipped."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process_batch(p, *script_args, **kwargs)
            except Exception:
                print(f"Error running process_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess(self, p, processed):
        """Call postprocess() of every always-visible script; errors are printed and skipped."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess_batch(self, p, images, **kwargs):
        """Call postprocess_batch() of every always-visible script; errors are printed and skipped."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch(p, *script_args, images=images, **kwargs)
            except Exception:
                print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess_image(self, p, pp: "PostprocessImageArgs"):
        """Call postprocess_image() of every always-visible script; errors are printed and skipped."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_image(p, pp, *script_args)
            except Exception:
                print(f"Error running postprocess_image: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def before_component(self, component, **kwargs):
        """Forward before_component() to every script of this tab; errors are printed and skipped."""
        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                print(f"Error running before_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def after_component(self, component, **kwargs):
        """Forward after_component() to every script of this tab; errors are printed and skipped."""
        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                print(f"Error running after_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def reload_sources(self, cache):
        """Re-import each script's module and swap in a fresh instance of its class.

        ``cache`` maps filename -> module so a file shared by several runners is
        loaded only once. The new instance keeps the argument slice, tab flags
        and UI group of the one it replaces. It also replaces that instance in
        alwayson_scripts/selectable_scripts, which run()/process()/... use.
        """
        for si, script in list(enumerate(self.scripts)):
            args_from = script.args_from
            args_to = script.args_to
            filename = script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(script.filename)
                cache[filename] = module

            for script_class in module.__dict__.values():
                if type(script_class) == type and issubclass(script_class, Script):
                    self.scripts[si] = script_class()
                    self.scripts[si].filename = filename
                    self.scripts[si].args_from = args_from
                    self.scripts[si].args_to = args_to

            new_script = self.scripts[si]
            if new_script is script:
                continue

            # Carry over the per-tab state set by initialize_scripts()/setup_ui();
            # the fresh instance would otherwise fall back to the class defaults.
            new_script.is_txt2img = script.is_txt2img
            new_script.is_img2img = script.is_img2img
            new_script.alwayson = script.alwayson
            new_script.group = script.group

            # The hooks dispatch through these lists, so they must point at the new
            # instance as well; otherwise the reload has no effect on processing.
            for bucket in (self.alwayson_scripts, self.selectable_scripts):
                for bi, existing in enumerate(bucket):
                    if existing is script:
                        bucket[bi] = new_script
A
AUTOMATIC 已提交
460

461

A
AUTOMATIC 已提交
462 463
# One runner per tab, plus the runner for the postprocessing scripts.
scripts_txt2img, scripts_img2img = ScriptRunner(), ScriptRunner()
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
# Runner that IOComponent_init forwards before/after_component to. It is set
# elsewhere, presumably while a tab's UI is being built; None disables forwarding.
scripts_current: ScriptRunner = None
D
DepFA 已提交
466

467

D
DepFA 已提交
468
def reload_script_body_only():
    """Reload the code of the already-instantiated scripts of both tab runners.

    The UI is not rebuilt. One module cache is shared by both runners, so
    each script file is loaded only once.
    """
    shared_cache = {}
    for runner in (scripts_txt2img, scripts_img2img):
        runner.reload_sources(shared_cache)
D
DepFA 已提交
472

D
DepFA 已提交
473

474
def reload_scripts():
    """Re-discover all scripts and replace the module-level runners with fresh ones.

    The new runners have not had initialize_scripts() called on them yet.
    Code that bound the old runner objects directly (e.g. via
    ``from ... import scripts_txt2img``) keeps referring to the old ones.
    """
    global scripts_txt2img, scripts_img2img, scripts_postproc

    load_scripts()

    scripts_txt2img, scripts_img2img = ScriptRunner(), ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
482

483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501

def IOComponent_init(self, *args, **kwargs):
    """Replacement for gradio's IOComponent.__init__ that lets scripts and
    callbacks observe every component as it is created.

    Order of calls: current runner's before_component, the before-component
    callbacks, gradio's original constructor, the after-component callbacks,
    and finally the current runner's after_component. Returns what the
    original constructor returned.
    """
    def notify_runner(hook_name):
        # scripts_current is looked up at each call, not captured once.
        if scripts_current is not None:
            getattr(scripts_current, hook_name)(self, **kwargs)

    notify_runner("before_component")
    script_callbacks.before_component_callback(self, **kwargs)

    result = original_IOComponent_init(self, *args, **kwargs)

    script_callbacks.after_component_callback(self, **kwargs)
    notify_runner("after_component")

    return result


# Keep gradio's own constructor around, then install the wrapper above in its place.
original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init