scripts.py 24.3 KB
Newer Older
1
import os
2
import re
3
import sys
H
huchenlei 已提交
4
import inspect
5
from collections import namedtuple
6 7 8

import gradio as gr

9
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
10 11 12

# Sentinel value a Script.show() implementation returns to request that the script's UI is shown at all times
# (instead of only when picked in the "Script" dropdown); ScriptRunner.initialize_scripts compares show()'s
# result against it to put the script into alwayson_scripts.
AlwaysVisible = object()

A
AUTOMATIC 已提交
13

14 15 16 17 18
class PostprocessImageArgs:
    """Mutable wrapper around one generated image.

    An instance is handed to every always-on script's postprocess_image() hook
    (see ScriptRunner.postprocess_image); hooks may reassign ``image``.
    """

    def __init__(self, image):
        # the image being post-processed; exposed as a plain attribute so hooks can replace it
        self.image = image


19 20 21 22 23
class PostprocessBatchListArgs:
    """Mutable wrapper around the list of images of one batch.

    Passed to every always-on script's postprocess_batch_list() hook; per that hook's
    contract a script may replace, remove or add entries in ``images``.
    """

    def __init__(self, images):
        # stored as-is (not copied), so callers and hooks share the same list object
        self.images = images


24
class Script:
    """Base class for webui scripts.

    Selectable scripts appear in the "Script" dropdown and do all their work in run().
    Always-visible scripts (show() returns AlwaysVisible) instead get the before_process/process/
    postprocess-style hooks called for every generation. In every hook, ``*args`` are the values of
    the components returned by ui().
    """

    name = None
    """script's internal name derived from title"""

    section = None
    """name of UI section that the script's controls will be placed into"""

    # path of the file this script class was loaded from (set by ScriptRunner)
    filename = None

    # this script's slice [args_from:args_to] of the runner's inputs / p.script_args (set by ScriptRunner)
    args_from = None
    args_to = None

    # True when show() returned AlwaysVisible for the tab this instance belongs to
    alwayson = False

    # which tab this instance was created for (set by ScriptRunner.initialize_scripts)
    is_txt2img = False
    is_img2img = False

    group = None
    """A gr.Group component that has all script's UI inside it"""

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    """

    paste_field_names = None
    """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
    various "Send to <X>" buttons when clicked
    """

    api_info = None
    """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.
        """

        pass

    def show(self, is_img2img):
        """
        is_img2img is True if this function is called for the img2img interface, and False otherwise

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - scripts.AlwaysVisible if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        pass

    def before_process(self, p, *args):
        """
        This function is called very early before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def before_process_batch(self, p, *args, **kwargs):
        """
        Called before extra networks are parsed from the prompt, so you can add
        new extra network keywords to the prompt with this callback.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def after_extra_networks_activate(self, p, *args, **kwargs):
        """
        Called after extra networks activation, before conds calculation.
        Allows modification of the network after extra networks activation has been applied.
        Not called if p.disable_extra_networks is set.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
          - extra_network_data - list of ExtraNetworkParams for current stage
        """

        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
        """

        pass

    def postprocess_batch(self, p, *args, **kwargs):
        """
        Same as process_batch(), but called for every batch after it has been generated.

        **kwargs will have same items as process_batch, and also:
          - images - torch tensor with all generated images, with values ranging from 0 to 1;
        """

        pass

    def postprocess_batch_list(self, p, pp: "PostprocessBatchListArgs", *args, **kwargs):
        """
        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
        This is useful when you want to update the entire batch instead of individual images.

        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
        If the number of images is different from the batch size when returning,
        then the script has the responsibility to also update the following attributes in the processing object (p):
          - p.prompts
          - p.negative_prompts
          - p.seeds
          - p.subseeds

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
        """

        pass

    def postprocess_image(self, p, pp: "PostprocessImageArgs", *args):
        """
        Called for every image after it has been generated.
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as above.
        """

        pass

    def describe(self):
        """unused"""
        return ""

    def elem_id(self, item_id):
        """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id

        The tab name is included only when show() answers the same for both tabs, i.e. when the
        same script could otherwise produce identical ids on the txt2img and img2img tabs.
        """

        need_tabname = self.show(True) == self.show(False)
        # NOTE: earlier versions emitted the typo 'txt2txt' here; ui-config entries saved under
        # those ids will not match the corrected 'txt2img' ids.
        tabkind = 'img2img' if self.is_img2img else 'txt2img'
        tabname = f"{tabkind}_" if need_tabname else ""
        # whitespace -> '_', then drop everything that is not a lowercase letter, digit or underscore
        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

        return f'script_{tabname}{title}_{item_id}'

    def before_hr(self, p, *args):
        """
        This function is called before hires fix start.
        """
        pass
234

235 236 237 238 239 240 241 242 243 244 245 246
# Base directory of the script file currently being imported. load_scripts() points it at each
# script's basedir for the duration of that file's import and resets it to paths.script_path afterwards.
current_basedir = paths.script_path


def basedir():
    """Return the base directory of the script that is currently being loaded.

    For scripts in webui's own scripts directory this is the main directory (where webui.py resides);
    for scripts inside an extension (ie extensions/aesthetic/scripts/aesthetic.py) it is the extension's
    directory (extensions/aesthetic).

    The value is read from the module-level current_basedir, which load_scripts() only sets while a
    script file is being imported and then resets to paths.script_path, so call this at import time
    (module level of the script), not from hooks that run later.
    """
    return current_basedir


# One script file found on disk: basedir is paths.script_path for webui's own scripts (see list_scripts),
# filename is the bare file name and path the full path to it.
ScriptFile = namedtuple("ScriptFile", "basedir filename path")

# Filled by load_scripts() with ScriptClassData entries for Script and ScriptPostprocessing subclasses.
scripts_data = list()
postprocessing_scripts_data = list()

# A script class discovered by load_scripts(), with the file/base dir it came from and its module object.
ScriptClassData = namedtuple("ScriptClassData", "script_class path basedir module")
251 252 253 254 255 256 257 258 259 260


def list_scripts(scriptdirname, extension):
    """Collect script files named ``*<extension>`` from webui's own directory and from active extensions.

    scriptdirname: name of the subdirectory to scan (e.g. "scripts").
    extension: file suffix including the dot (e.g. ".py"); compared case-insensitively.

    Returns a list of ScriptFile. Webui's own files come first, in sorted order, followed by whatever
    each active extension's list_files() returns. Entries that are not regular files are dropped.
    """
    scripts_list = []

    # named scripts_dir so it does not shadow the module-level basedir() function
    scripts_dir = os.path.join(paths.script_path, scriptdirname)
    # isdir rather than exists: os.listdir would raise if a plain file had this name
    if os.path.isdir(scripts_dir):
        for filename in sorted(os.listdir(scripts_dir)):
            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(scripts_dir, filename)))

    for ext in extensions.active():
        scripts_list += ext.list_files(scriptdirname, extension)

    # the suffix side is lowercased, so lowercase the wanted extension too (".PY" then behaves like ".py")
    wanted_ext = extension.lower()
    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == wanted_ext and os.path.isfile(x.path)]

    return scripts_list
267

268

A
AUTOMATIC 已提交
269 270 271
def list_files_with_name(filename):
    """Return the paths of every file called ``filename`` located directly in webui's main directory
    or in the directory of an active extension.

    Directories that do not exist are skipped. Results follow search order: webui first, then
    extensions in the order extensions.active() yields them.
    """
    search_dirs = [paths.script_path, *(ext.path for ext in extensions.active())]

    candidates = (os.path.join(dirpath, filename) for dirpath in search_dirs if os.path.isdir(dirpath))

    return [candidate for candidate in candidates if os.path.isfile(candidate)]


285 286 287
def load_scripts():
    """Import every script file and rebuild the module-level script runners.

    Clears scripts_data, postprocessing_scripts_data and all script_callbacks, then imports each file
    from list_scripts("scripts", ".py"). Order: webui's own scripts first, then extensions-builtin,
    then other extensions. Every Script / ScriptPostprocessing subclass found is recorded.

    While a file is imported, its basedir is prepended to sys.path (unless it is webui's own directory)
    and exposed through basedir(). Both are restored afterwards.

    An error importing one file is reported and that file is skipped. Finally, new
    scripts_txt2img / scripts_img2img / scripts_postproc runners are created.
    """
    global current_basedir, scripts_txt2img, scripts_img2img, scripts_postproc

    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py")

    syspath = sys.path

    def register_scripts_from_module(module, scriptfile):
        # record every script class defined (or imported) in module, tagged with the file it came from
        for script_class in module.__dict__.values():
            if not inspect.isclass(script_class):
                continue

            if issubclass(script_class, Script):
                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

    builtin_dir = os.path.join(paths.script_path, "extensions-builtin")

    def orderby(dirpath):
        # 1st webui, 2nd extensions-builtin, 3rd extensions.
        # webui must be an exact match: regular extensions typically live under <script_path>/extensions,
        # so a prefix test on script_path would rank them together with webui, ahead of extensions-builtin.
        if dirpath == paths.script_path:
            return 0
        if dirpath.startswith(builtin_dir):
            return 1
        return 2

    for scriptfile in sorted(scripts_list, key=lambda x: (orderby(x.basedir), x)):
        try:
            if scriptfile.basedir != paths.script_path:
                # a new list, so restoring the saved reference below undoes this
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            script_module = script_loading.load_module(scriptfile.path)
            register_scripts_from_module(script_module, scriptfile)

        except Exception:
            errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)

        finally:
            sys.path = syspath
            current_basedir = paths.script_path
            timer.startup_timer.record(scriptfile.filename)

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()

336 337 338

def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call ``func(*args, **kwargs)`` and report, rather than propagate, any Exception it raises.

    filename and funcname are only used to label the error report. Returns func's result on
    success, otherwise ``default``.
    """
    try:
        outcome = func(*args, **kwargs)
    except Exception:
        errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
        outcome = default

    return outcome


A
AUTOMATIC 已提交
346 347 348
class ScriptRunner:
    """Script instances for one tab (txt2img or img2img): builds their UI and dispatches processing hooks.

    self.inputs is the flat list of gradio components created for all scripts. Index 0 is the "Script"
    dropdown, and each script owns inputs[script.args_from:script.args_to]. The same slice of
    p.script_args holds that script's values when the hooks are called.
    """

    def __init__(self):
        self.scripts = []
        self.selectable_scripts = []
        self.alwayson_scripts = []
        self.titles = []
        self.infotext_fields = []
        self.paste_field_names = []
        self.inputs = [None]

    def initialize_scripts(self, is_img2img):
        """Instantiate every known script for this tab and sort them into always-on and selectable scripts.

        Scripts whose show() returns a falsy value for this tab are not kept.
        """
        from modules import scripts_auto_postprocessing

        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_data in auto_processing_scripts + scripts_data:
            script = script_data.script_class()
            script.filename = script_data.path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

            visibility = script.show(script.is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def create_script_ui(self, script):
        """Run script.ui(), append its components to self.inputs and record the script's argument slice and API info."""
        import modules.api.models as api_models

        # empty slice until ui() succeeds, so a script without UI gets no arguments
        script.args_from = len(self.inputs)
        script.args_to = len(self.inputs)

        controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

        if controls is None:
            return

        script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
        api_args = []

        for control in controls:
            control.custom_script_source = os.path.basename(script.filename)

            arg_info = api_models.ScriptArg(label=control.label or "")

            for field in ("value", "minimum", "maximum", "step", "choices"):
                v = getattr(control, field, None)
                if v is not None:
                    setattr(arg_info, field, v)

            api_args.append(arg_info)

        script.api_info = api_models.ScriptInfo(
            name=script.name,
            is_img2img=script.is_img2img,
            is_alwayson=script.alwayson,
            args=api_args,
        )

        if script.infotext_fields is not None:
            self.infotext_fields += script.infotext_fields

        if script.paste_field_names is not None:
            self.paste_field_names += script.paste_field_names

        self.inputs += controls
        script.args_to = len(self.inputs)

    def setup_ui_for_section(self, section, scriptlist=None):
        """Create each script's UI inside its own gr.Group (default list: the always-on scripts).

        Always-on scripts whose section differs from ``section`` are skipped; the groups of
        non-always-on scripts start hidden.
        """
        if scriptlist is None:
            scriptlist = self.alwayson_scripts

        for script in scriptlist:
            if script.alwayson and script.section != section:
                continue

            with gr.Group(visible=script.alwayson) as group:
                self.create_script_ui(script)

            script.group = group

    def prepare_ui(self):
        self.inputs = [None]

    def setup_ui(self):
        """Build the default-section UI plus the "Script" dropdown, wire up visibility, and return self.inputs."""
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        self.setup_ui_for_section(None)

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        self.inputs[0] = dropdown

        self.setup_ui_for_section(None, self.selectable_scripts)

        def select_script(script_index):
            # type="index": 0 is the "None" entry; gradio can also deliver None when nothing matches
            selected_script = self.selectable_scripts[script_index - 1] if script_index else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            # a stale title (e.g. the script's extension was removed) is ignored instead of raising ValueError
            if title == 'None' or title not in self.titles:
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        self.script_load_ctr = 0

        def onload_script_visibility(params):
            # called once per selectable script's group, in order; the counter tracks which group this call is for
            title = params.get('Script', None)
            if title and title in self.titles:
                title_index = self.titles.index(title)
                visibility = title_index == self.script_load_ctr
                self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
                return gr.update(visible=visibility)
            else:
                return gr.update(visible=False)

        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])

        return self.inputs

    def run(self, p, *args):
        """Run the script selected in the dropdown (args[0], 0 = "None") and return its result, or None if none is selected."""
        script_index = args[0]

        if script_index is None or script_index == 0:
            return None

        script = self.selectable_scripts[script_index - 1]

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def before_process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process(p, *script_args)
            except Exception:
                errors.report(f"Error running before_process: {script.filename}", exc_info=True)

    def process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
                errors.report(f"Error running process: {script.filename}", exc_info=True)

    def before_process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process_batch(p, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)

    def after_extra_networks_activate(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.after_extra_networks_activate(p, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)

    def process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process_batch(p, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running process_batch: {script.filename}", exc_info=True)

    def postprocess(self, p, processed):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                errors.report(f"Error running postprocess: {script.filename}", exc_info=True)

    def postprocess_batch(self, p, images, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch(p, *script_args, images=images, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)

    def postprocess_batch_list(self, p, pp: "PostprocessBatchListArgs", **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch_list(p, pp, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

    def postprocess_image(self, p, pp: "PostprocessImageArgs"):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_image(p, pp, *script_args)
            except Exception:
                errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

    def before_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running before_component: {script.filename}", exc_info=True)

    def after_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                errors.report(f"Error running after_component: {script.filename}", exc_info=True)

    def reload_sources(self, cache):
        """Re-import each script's module and replace the script with a fresh instance of its class.

        cache maps filename -> module and may be shared between runners, so each file is imported
        only once. The new instance inherits the old one's argument slice, tab flags, always-on flag and
        UI group. alwayson_scripts and selectable_scripts are updated as well, because run() and the
        processing hooks iterate those lists rather than self.scripts.
        """
        replaced = {}

        for si, old_script in enumerate(self.scripts):
            filename = old_script.filename

            module = cache.get(filename, None)
            if module is None:
                module = script_loading.load_module(filename)
                cache[filename] = module

            script_class = None
            for candidate in module.__dict__.values():
                # the last Script subclass in the module wins (as before); the imported base class is skipped
                if inspect.isclass(candidate) and issubclass(candidate, Script) and candidate is not Script:
                    script_class = candidate

            if script_class is None:
                continue

            new_script = script_class()
            new_script.filename = filename
            new_script.args_from = old_script.args_from
            new_script.args_to = old_script.args_to
            new_script.is_txt2img = old_script.is_txt2img
            new_script.is_img2img = old_script.is_img2img
            new_script.alwayson = old_script.alwayson
            new_script.group = old_script.group

            self.scripts[si] = new_script
            replaced[id(old_script)] = new_script

        # in-place so any holder of these list objects sees the new instances
        self.alwayson_scripts[:] = [replaced.get(id(s), s) for s in self.alwayson_scripts]
        self.selectable_scripts[:] = [replaced.get(id(s), s) for s in self.selectable_scripts]

    def before_hr(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_hr(p, *script_args)
            except Exception:
                errors.report(f"Error running before_hr: {script.filename}", exc_info=True)


621 622 623
# Runners (re)created by load_scripts(); None until scripts have been loaded.
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None

# When not None, IOComponent_init forwards component-creation events to this runner's
# before_component/after_component; presumably set by the UI code while building a tab.
scripts_current: ScriptRunner = None
D
DepFA 已提交
625

626

D
DepFA 已提交
627
def reload_script_body_only():
    """Re-import the script modules of both tabs and swap in fresh script instances.

    Existing argument slices are kept (see ScriptRunner.reload_sources). One module cache is shared
    by the txt2img and img2img runners, so each file is imported only once.
    """
    shared_cache = {}
    for runner in (scripts_txt2img, scripts_img2img):
        runner.reload_sources(shared_cache)
D
DepFA 已提交
631

D
DepFA 已提交
632

633
reload_scripts = load_scripts  # compatibility alias: kept so code calling reload_scripts() still works; it is the same function as load_scripts
634

635

636 637 638 639 640
def add_classes_to_gradio_component(comp):
    """
    Add css classes to a gradio component for styling.

    Prepends gradio-<block name> (ie gradio-button for gr.Button) to comp.elem_classes, keeping any
    classes already set, and appends "multiselect" when the component has a truthy multiselect attribute.
    """

    existing_classes = comp.elem_classes or []
    if isinstance(existing_classes, str):
        # elem_classes may be a single class name; unpacking a str would split it into characters
        existing_classes = [existing_classes]

    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *existing_classes]

    if getattr(comp, 'multiselect', False):
        comp.elem_classes.append('multiselect')



648 649 650 651 652 653 654 655
def IOComponent_init(self, *args, **kwargs):
    """Replacement for gradio's IOComponent.__init__ that lets scripts observe component creation.

    Call order: current script runner's before_component, global before_component callbacks,
    gradio's original initializer, css class tagging, global after_component callbacks, and
    finally the current script runner's after_component. Returns what the original initializer returned.
    """
    if scripts_current is not None:
        scripts_current.before_component(self, **kwargs)

    script_callbacks.before_component_callback(self, **kwargs)

    init_result = original_IOComponent_init(self, *args, **kwargs)
    add_classes_to_gradio_component(self)

    script_callbacks.after_component_callback(self, **kwargs)

    # scripts_current is looked up again: it is a module global and is not cached across the calls above
    if scripts_current is not None:
        scripts_current.after_component(self, **kwargs)

    return init_result


# Monkeypatch gradio: keep its own initializer for the wrapper to delegate to, then install the wrapper.
original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init
668 669 670 671 672 673 674 675 676 677 678 679


def BlockContext_init(self, *args, **kwargs):
    """Replacement for gradio's BlockContext.__init__ that also tags the block with gradio-* css classes."""
    init_result = original_BlockContext_init(self, *args, **kwargs)
    add_classes_to_gradio_component(self)
    return init_result


# Monkeypatch gradio: remember the original initializer for delegation, then install the wrapper.
original_BlockContext_init = gr.blocks.BlockContext.__init__
gr.blocks.BlockContext.__init__ = BlockContext_init