scripts.py 15.4 KB
Newer Older
1
import os
2
import re
3 4
import sys
import traceback
5
from collections import namedtuple
6 7 8

import gradio as gr

A
AUTOMATIC 已提交
9
from modules.processing import StableDiffusionProcessing
10
from modules import shared, paths, script_callbacks, extensions, script_loading
11 12 13

AlwaysVisible = object()

A
AUTOMATIC 已提交
14

15 16
class Script:
    """Base class for webui scripts.

    ScriptRunner.initialize_scripts() creates an instance for a tab and calls show() on it:
     - a script shown in the "Script" dropdown does its work in run();
     - a script for which show() returns AlwaysVisible has its UI shown at all times and hooks into
       processing through process(), process_batch(), postprocess_batch() and postprocess().
    """

    # path of the file this script class was loaded from; set by the runner
    filename = None

    # slice [args_from:args_to] of the runner's argument list that holds the values of this
    # script's ui() components; set by the runner while building the UI
    args_from = None
    args_to = None

    # True when show() returned AlwaysVisible for this instance's tab
    alwayson = False

    # which tab this instance was created for
    is_txt2img = False
    is_img2img = False

    # a gr.Group component that has all of the script's UI inside it
    group = None

    # if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
    # parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    infotext_fields = None

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """

        pass

    def show(self, is_img2img):
        """
        is_img2img is True if this function is called for the img2img interface, and False otherwise

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - script.AlwaysVisible if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        raise NotImplementedError()

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have the same items as in process_batch (batch_number, prompts, seeds, subseeds), and also:
          - images - torch tensor with all generated images, with values ranging from 0 to 1
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as before_component().
        """

        pass

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id

        The tab prefix is added only when show() gives the same answer for both tabs (presumably
        because only then can the two tabs create elements with the same id).
        """

        need_tabname = self.show(True) == self.show(False)
        # NOTE(review): 'txt2txt' looks like a typo for 'txt2img', but the generated ids may already be
        # referenced by CSS/JS, extensions or saved UI settings - confirm before changing it.
        tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
        # lowercase the title, turn whitespace into '_', then drop everything except [a-z_0-9]
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'

141

142 143 144 145 146 147 148 149 150 151 152
# Base directory of the script file currently being imported. load_scripts() points it at each
# script's basedir while importing that file and resets it to paths.script_path afterwards;
# scripts read it through basedir().
current_basedir = paths.script_path


def basedir():
    """Return the base directory of the script that is currently being loaded.

    Scripts in the main scripts directory get the main webui directory (where webui.py resides);
    scripts that come from an extension (e.g. extensions/aesthetic/scripts/aesthetic.py) get that
    extension's directory (extensions/aesthetic).
    """
    base = current_basedir
    return base


A
AUTOMATIC 已提交
153
# Registry filled by load_scripts(): one ScriptClassData per Script subclass found in a script file.
# It is cleared at the start of every load_scripts() call.
scripts_data = list()

# A discovered script file: the base directory it belongs to, its file name, and its full path.
ScriptFile = namedtuple("ScriptFile", "basedir filename path")

# A registered Script subclass together with the file, base directory and module it came from.
ScriptClassData = namedtuple("ScriptClassData", "script_class path basedir module")
156 157 158 159 160 161 162 163 164 165


def list_scripts(scriptdirname, extension):
    """Collect script files from the webui's own ``scriptdirname`` directory and from every active extension.

    Files from the webui directory come first, sorted by name; extension files follow in the
    order returned by each extension's ``list_files``. Only regular files whose suffix equals
    ``extension`` (compared case-insensitively) are kept.

    :param scriptdirname: name of the subdirectory to scan, e.g. "scripts"
    :param extension: file suffix including the dot, e.g. ".py"
    :returns: list of ScriptFile(basedir, filename, path)
    """
    scripts_list = []
    wanted_ext = extension.lower()

    # named scripts_dir (not basedir) so the module-level basedir() function is not shadowed
    scripts_dir = os.path.join(paths.script_path, scriptdirname)
    # isdir rather than exists: os.listdir would raise if a plain file had this name
    if os.path.isdir(scripts_dir):
        for filename in sorted(os.listdir(scripts_dir)):
            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(scripts_dir, filename)))

    for ext in extensions.active():
        scripts_list += ext.list_files(scriptdirname, extension)

    return [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == wanted_ext and os.path.isfile(x.path)]
172

173

A
AUTOMATIC 已提交
174 175 176
def list_files_with_name(filename):
    """Find a file called ``filename`` directly inside the webui directory and each active extension's directory.

    Directories that do not exist are skipped. Returns the full paths of the regular files that
    were found: the webui directory first, then extensions in the order extensions.active() yields them.
    """
    search_roots = [paths.script_path, *(ext.path for ext in extensions.active())]

    found = []
    for root in filter(os.path.isdir, search_roots):
        candidate = os.path.join(root, filename)
        if os.path.isfile(candidate):
            found.append(candidate)

    return found


190 191 192 193 194 195 196 197
def load_scripts():
    """Import every script file and register the Script subclasses it defines in ``scripts_data``.

    Previously registered scripts and script callbacks are cleared first. While a file is being
    imported, its basedir is prepended to ``sys.path`` (only for files outside the webui directory)
    and published through ``current_basedir`` so the module can call basedir() at import time.
    Both are restored afterwards, even on failure. A file that fails to import is reported on
    stderr and skipped.
    """
    global current_basedir
    scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py")

    syspath = sys.path

    for scriptfile in sorted(scripts_list):
        try:
            if scriptfile.basedir != paths.script_path:
                # let an extension's scripts import modules that live in the extension directory
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            module = script_loading.load_module(scriptfile.path)

            for script_class in module.__dict__.values():
                # Skip the Script base class itself: a module that does `from modules.scripts import Script`
                # has it in its namespace, and registering it would add a bogus, title-less dropdown entry.
                # The exact `type(...) == type` check is kept on purpose; it leaves out classes built with a
                # custom metaclass (e.g. abstract ABCMeta bases), which may not be instantiable.
                if type(script_class) == type and script_class is not Script and issubclass(script_class, Script):
                    scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

        except Exception:
            print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            sys.path = syspath
            current_basedir = paths.script_path

219 220 221

def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call ``func(*args, **kwargs)`` and return its result; never let an exception escape.

    If the call raises any Exception, a message naming ``filename``/``funcname`` and the traceback
    are printed to stderr, and ``default`` is returned instead.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return default


A
AUTOMATIC 已提交
231 232 233
class ScriptRunner:
    """Holds the script instances for one UI tab, builds their gradio UI and forwards processing hooks to them.

    A new runner is empty: initialize_scripts() fills it from the module-level scripts_data, and
    setup_ui() creates the components.
    """

    def __init__(self):
        self.scripts = []  # every script shown on this tab: always-on and selectable
        self.selectable_scripts = []  # scripts listed in the "Script" dropdown, in dropdown order
        self.alwayson_scripts = []  # scripts whose UI is always shown and whose hooks always run
        self.titles = []  # dropdown titles, parallel to selectable_scripts
        self.infotext_fields = []  # (component, text) pairs collected from the scripts' infotext_fields

    def initialize_scripts(self, is_img2img):
        """Instantiate every registered script class for this tab.

        Each instance is sorted by its show() result: AlwaysVisible means always-on, any other
        truthy value means selectable, and a falsy value means the script is dropped.
        """
        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        for script_class, path, _basedir, _module in scripts_data:
            script = script_class()
            script.filename = path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

            visibility = script.show(script.is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def setup_ui(self):
        """Create the UI of every script and return the list of input components used for processing.

        inputs[0] is the "Script" dropdown. Its value is the index of the selected script, where 0
        means "None". After it come the components of each always-on script, then those of each
        selectable script. Every script records its own slice as args_from:args_to. The groups of
        the selectable scripts start hidden and are toggled by the dropdown.
        """
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        inputs = [None]
        # NOTE(review): inputs_alwayson is filled in below but never returned or stored - confirm whether it is still needed
        inputs_alwayson = [True]

        def create_script_ui(script, inputs, inputs_alwayson):
            script.args_from = len(inputs)
            script.args_to = len(inputs)

            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

            if controls is None:
                return

            for control in controls:
                control.custom_script_source = os.path.basename(script.filename)

            if script.infotext_fields is not None:
                self.infotext_fields += script.infotext_fields

            # += mutates the caller's lists in place
            inputs += controls
            inputs_alwayson += [script.alwayson for _ in controls]
            script.args_to = len(inputs)

        for script in self.alwayson_scripts:
            with gr.Group() as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        inputs[0] = dropdown

        for script in self.selectable_scripts:
            with gr.Group(visible=False) as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        def select_script(script_index):
            selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            if title == 'None':
                return

            # a title saved in ui-config.json may belong to a script that no longer loads;
            # list.index would raise ValueError for it
            if title not in self.titles:
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        return inputs

    def run(self, p: StableDiffusionProcessing, *args):
        """Run the script selected in the dropdown, if any.

        args[0] is the dropdown index, where 0 means "None". The rest is presumably the component
        values in setup_ui()'s order. Returns the selected script's result, or None when no script
        is selected.
        """
        script_index = args[0]

        # None is treated like "None" selected instead of failing on None - 1
        if script_index is None or script_index == 0:
            return None

        script = self.selectable_scripts[script_index-1]

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def process(self, p):
        """Call process() of every always-on script with its slice of p.script_args.

        Errors are printed to stderr and do not stop the other scripts.
        """
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
                print(f"Error running process: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def process_batch(self, p, **kwargs):
        """Call process_batch() of every always-on script. kwargs are forwarded unchanged; errors are printed to stderr."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process_batch(p, *script_args, **kwargs)
            except Exception:
                print(f"Error running process_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess(self, p, processed):
        """Call postprocess() of every always-on script. Errors are printed to stderr."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess_batch(self, p, images, **kwargs):
        """Call postprocess_batch() of every always-on script, passing images as a keyword argument. Errors are printed to stderr."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch(p, *script_args, images=images, **kwargs)
            except Exception:
                print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def before_component(self, component, **kwargs):
        """Forward a before-component-creation event to every script on this tab. Errors are printed to stderr."""
        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                print(f"Error running before_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def after_component(self, component, **kwargs):
        """Forward an after-component-creation event to every script on this tab. Errors are printed to stderr."""
        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                print(f"Error running after_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def reload_sources(self, cache):
        """Re-import every script file and swap in fresh instances of its Script classes.

        cache maps a script filename to its freshly loaded module. It can be shared between runners
        so that each file is loaded only once per reload.

        Each new instance takes over the old one's filename, argument slice, tab flags, always-on
        flag, UI group and infotext fields. It replaces the old instance in self.scripts and also in
        self.selectable_scripts / self.alwayson_scripts, which run() and the process*() hooks read.
        When a module defines several Script subclasses, the one whose class name matches the old
        instance's class is used; otherwise the last one found is used. If none is found, the old
        instance is kept.
        """
        carried_attrs = ("filename", "args_from", "args_to", "is_txt2img", "is_img2img", "alwayson", "group", "infotext_fields")
        replacements = {}

        for si, script in enumerate(self.scripts):
            filename = script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(filename)
                cache[filename] = module

            candidates = [
                obj for obj in module.__dict__.values()
                if type(obj) == type and obj is not Script and issubclass(obj, Script)
            ]
            if not candidates:
                continue

            same_name = [c for c in candidates if c.__name__ == type(script).__name__]
            script_class = (same_name or candidates)[-1]

            new_script = script_class()
            for attr in carried_attrs:
                setattr(new_script, attr, getattr(script, attr))

            self.scripts[si] = new_script
            # keyed by id(): scripts may define __eq__ without __hash__
            replacements[id(script)] = new_script

        # slice assignment keeps the list objects themselves, in case something else holds a reference
        self.selectable_scripts[:] = [replacements.get(id(s), s) for s in self.selectable_scripts]
        self.alwayson_scripts[:] = [replacements.get(id(s), s) for s in self.alwayson_scripts]
A
AUTOMATIC 已提交
412

413

A
AUTOMATIC 已提交
414 415
# One runner per UI tab; reload_scripts() replaces both with fresh instances.
scripts_txt2img, scripts_img2img = ScriptRunner(), ScriptRunner()

# Runner whose before_component/after_component hooks IOComponent_init forwards to, or None.
# NOTE(review): it is presumably assigned by the UI-building code elsewhere - confirm.
scripts_current: "ScriptRunner | None" = None
D
DepFA 已提交
417

418

D
DepFA 已提交
419
def reload_script_body_only():
    """Re-import the script files used by both tab runners and swap in fresh script instances, keeping the existing UI.

    Both runners share one module cache, so each file is loaded only once.
    """
    shared_cache = {}
    for runner in (scripts_txt2img, scripts_img2img):
        runner.reload_sources(shared_cache)
D
DepFA 已提交
423

D
DepFA 已提交
424

425
def reload_scripts():
    """Rescan and re-import all script files, then replace both tab runners with new, empty ones.

    The new runners hold no script instances until initialize_scripts() is called on them.
    """
    global scripts_txt2img, scripts_img2img

    load_scripts()

    scripts_txt2img, scripts_img2img = ScriptRunner(), ScriptRunner()
432

433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451

def IOComponent_init(self, *args, **kwargs):
    """Replacement for gradio's IOComponent.__init__ that lets scripts and callbacks react to every component being created.

    Order of calls:
     1. the current runner's before_component;
     2. the global before_component callbacks;
     3. the real constructor;
     4. the global after_component callbacks;
     5. the current runner's after_component.

    Only keyword arguments are forwarded to the hooks. Returns whatever the original __init__ returns.
    """
    runner = scripts_current
    if runner is not None:
        runner.before_component(self, **kwargs)
    script_callbacks.before_component_callback(self, **kwargs)

    created = original_IOComponent_init(self, *args, **kwargs)

    script_callbacks.after_component_callback(self, **kwargs)
    # look the global up again after construction, as it is consulted at both points
    runner = scripts_current
    if runner is not None:
        runner.after_component(self, **kwargs)

    return created


# Keep a reference to gradio's own IOComponent constructor so IOComponent_init can delegate to it,
# then install IOComponent_init in its place. This runs when this module is imported.
# NOTE(review): if this module were ever re-imported, original_IOComponent_init would capture the
# previously installed wrapper and the component hooks could fire more than once - confirm it is not reloaded.
original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init