gradio_extensons.py 5.3 KB
Newer Older
W
w-e-w 已提交
1 2
from inspect import signature
from functools import wraps
A
AUTOMATIC1111 已提交
3 4
import gradio as gr

5 6
from modules import scripts, ui_tempdir, patches

A
AUTOMATIC1111 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45

def add_classes_to_gradio_component(comp):
    """
    Add webui's CSS hook classes to a gradio block.

    Prepends ``gradio-<block name>`` (e.g. ``gradio-button`` for gr.Button) to ``comp.elem_classes`` and,
    when the block has a truthy ``multiselect`` attribute, appends ``multiselect``.

    ``comp.elem_classes`` may be None, a list of class names, or a single class-name string (gradio's
    ``elem_classes`` argument accepts ``list[str] | str | None``). A string is kept as one class instead of
    being unpacked into one class per character.
    """
    existing_classes = comp.elem_classes or []
    if isinstance(existing_classes, str):
        existing_classes = [existing_classes]

    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *existing_classes]

    if getattr(comp, 'multiselect', False):
        comp.elem_classes.append('multiselect')


def IOComponent_init(self, *args, **kwargs):
    """
    Replacement for gradio's ``Component.__init__``, installed with patches.patch when this module is imported.

    Removes the webui-specific ``tooltip`` keyword from ``kwargs`` and stores it on ``self.webui_tooltip``
    (None when absent), so neither the hooks nor the wrapped initializer receive it. Then runs, in order:

    1. ``scripts.scripts_current.before_component`` if ``scripts.scripts_current`` is not None,
    2. ``scripts.script_callbacks.before_component_callback``,
    3. the initializer that was in place before this patch (``original_IOComponent_init``),
    4. add_classes_to_gradio_component,
    5. ``scripts.script_callbacks.after_component_callback``,
    6. ``scripts.scripts_current.after_component`` if ``scripts.scripts_current`` is not None.

    The hooks receive the component and the keyword arguments only; positional arguments are passed
    solely to the original initializer. Returns whatever the original initializer returned.

    NOTE(review): component classes' own ``__init__`` methods are wrapped by
    gradio_component_compatibility_layer, which drops keywords their signature does not declare; a
    ``tooltip`` passed to such a class may therefore never reach this base initializer — confirm
    against the installed gradio.
    """
    self.webui_tooltip = kwargs.pop('tooltip', None)

    # scripts_current is read here and read again after initialization, not cached.
    if scripts.scripts_current is not None:
        scripts.scripts_current.before_component(self, **kwargs)

    scripts.script_callbacks.before_component_callback(self, **kwargs)

    res = original_IOComponent_init(self, *args, **kwargs)

    # Classes are added after the original init so they are prepended to whatever elem_classes it set.
    add_classes_to_gradio_component(self)

    scripts.script_callbacks.after_component_callback(self, **kwargs)

    if scripts.scripts_current is not None:
        scripts.scripts_current.after_component(self, **kwargs)

    return res


def Block_get_config(self):
    """
    Replacement for gradio's ``Block.get_config``, installed with patches.patch when this module is imported.

    Starts from the config dict built by the original method. When the block has a non-empty
    ``webui_tooltip`` attribute (set by IOComponent_init from the ``tooltip`` keyword), it is added under
    the ``"webui_tooltip"`` key. Any ``"example_inputs"`` entry is removed. Returns the same dict object.
    """
    config = original_Block_get_config(self)

    webui_tooltip = getattr(self, 'webui_tooltip', None)
    if webui_tooltip:
        config["webui_tooltip"] = webui_tooltip

    # NOTE(review): the reason for dropping example_inputs is not visible here; Blocks_get_config_file
    # likewise neutralizes example_inputs in the per-component config entries.
    config.pop('example_inputs', None)

    return config


def BlockContext_init(self, *args, **kwargs):
    """
    Replacement for gradio's ``BlockContext.__init__``.

    Runs the original initializer with the arguments untouched, then applies webui's ``gradio-*`` CSS
    classes to the new block via add_classes_to_gradio_component. The original's return value is passed through.
    """
    outcome = original_BlockContext_init(self, *args, **kwargs)
    add_classes_to_gradio_component(self)
    return outcome


59 60 61 62 63 64 65 66 67 68
def Blocks_get_config_file(self, *args, **kwargs):
    """
    Replacement for gradio's ``Blocks.get_config_file``.

    Calls the original, then, for every entry of ``config["components"]`` that contains an
    ``"example_inputs"`` key, replaces that value with a fresh ``{"serialized": []}``. Returns the
    (mutated in place) config produced by the original.
    """
    file_config = original_Blocks_get_config_file(self, *args, **kwargs)

    entries_with_examples = [entry for entry in file_config["components"] if "example_inputs" in entry]
    for entry in entries_with_examples:
        entry["example_inputs"] = {"serialized": []}

    return file_config


W
w-e-w 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
def gradio_component_compatibility_layer(component_function):
    """
    Wrap ``component_function`` so keyword arguments it cannot accept are dropped instead of raising.

    Applied to gradio component/layout ``__init__`` methods so code written against other gradio
    releases can keep passing keywords the installed gradio does not declare.

    Only keywords naming a regular or keyword-only parameter are forwarded. Keywords naming
    positional-only or ``*args`` parameters are dropped, because passing them would raise. If
    ``component_function`` declares ``**kwargs``, every keyword is forwarded unchanged.

    The signature is inspected once, when the wrapper is built, rather than on every call.
    ``inspect.signature`` follows ``__wrapped__``, so wrapping an already-wrapped function still sees
    the real parameters.
    """
    from inspect import Parameter

    parameters = signature(component_function).parameters.values()
    accepts_any_keyword = any(p.kind is Parameter.VAR_KEYWORD for p in parameters)
    keyword_names = frozenset(
        p.name for p in parameters if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    )

    @wraps(component_function)
    def patched_function(*args, **kwargs):
        if not accepts_any_keyword:
            kwargs = {k: v for k, v in kwargs.items() if k in keyword_names}
        return component_function(*args, **kwargs)

    return patched_function


# Names of the chaining methods (e.g. ``.then(...)``) looked up on the object returned by an event-listener
# call; the event compatibility layers wrap these too so chained calls also accept the legacy ``_js`` keyword.
sub_events = ['then', 'success']


def gradio_component_events_compatibility_layer(component_function):
    """
    Wrap a gradio event-listener function (a name from a component class's ``EVENTS``, e.g. ``click``)
    so call sites written for other gradio releases keep working.

    On each call the wrapper:

    - maps the legacy ``_js`` keyword to ``js``. An explicitly passed ``js`` takes precedence, and ``js``
      is always supplied, as None when neither was given.
    - drops keywords ``component_function`` cannot accept. Everything is forwarded if it declares
      ``**kwargs``; names of positional-only / ``*args`` parameters are never forwarded as keywords.
    - replaces each method named in ``sub_events`` on the returned object (when present and truthy) with
      a gradio_component_sub_events_compatibility_layer wrapper, so chained calls accept ``_js`` too.

    The signature is inspected once, when the wrapper is built, rather than on every call.
    """
    from inspect import Parameter

    parameters = signature(component_function).parameters.values()
    accepts_any_keyword = any(p.kind is Parameter.VAR_KEYWORD for p in parameters)
    keyword_names = frozenset(
        p.name for p in parameters if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    )

    @wraps(component_function)
    def patched_function(*args, **kwargs):
        # _js is popped unconditionally; it only becomes js when no explicit js was passed.
        kwargs['js'] = kwargs.get('js', kwargs.pop('_js', None))
        if not accepts_any_keyword:
            kwargs = {k: v for k, v in kwargs.items() if k in keyword_names}

        result = component_function(*args, **kwargs)

        for sub_event in sub_events:
            sub_event_function = getattr(result, sub_event, None)
            if sub_event_function:
                setattr(result, sub_event, gradio_component_sub_events_compatibility_layer(sub_event_function))

        return result

    return patched_function


def gradio_component_sub_events_compatibility_layer(component_function):
    """
    Wrap a chaining method such as ``.then`` / ``.success`` found on the object an event listener returns.

    Each call maps the legacy ``_js`` keyword to ``js``. An explicit ``js`` wins, and ``js`` is always
    supplied, as None when neither was given. The call then drops keywords ``component_function`` cannot
    accept: everything is forwarded if it declares ``**kwargs``, and names of positional-only / ``*args``
    parameters are never forwarded as keywords.

    The object returned by the wrapped call has its own ``sub_events`` methods wrapped as well. This means
    longer chains such as ``btn.click(...).then(...).then(..., _js=...)`` accept ``_js`` at every link;
    otherwise only the first link after the event would be wrapped.

    The signature is inspected once, when the wrapper is built, rather than on every call.
    """
    from inspect import Parameter

    parameters = signature(component_function).parameters.values()
    accepts_any_keyword = any(p.kind is Parameter.VAR_KEYWORD for p in parameters)
    keyword_names = frozenset(
        p.name for p in parameters if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    )

    @wraps(component_function)
    def patched_function(*args, **kwargs):
        kwargs['js'] = kwargs.get('js', kwargs.pop('_js', None))
        if not accepts_any_keyword:
            kwargs = {k: v for k, v in kwargs.items() if k in keyword_names}

        result = component_function(*args, **kwargs)

        # Each chained call returns a fresh object whose sub-event methods are unwrapped; wrap them lazily.
        for sub_event in sub_events:
            chained_function = getattr(result, sub_event, None)
            if chained_function:
                setattr(result, sub_event, gradio_component_sub_events_compatibility_layer(chained_function))

        return result

    return patched_function


# Make every public gradio component/layout class tolerant of call sites written for other gradio releases.
# Each class's __init__ is wrapped with gradio_component_compatibility_layer, which drops unknown keywords.
# Each event listener named in the class's EVENTS is wrapped with gradio_component_events_compatibility_layer,
# which accepts the legacy `_js` keyword. Every patch is registered with patches.patch under its own key.
#
# Names are sorted so the patching order is identical on every run instead of following set iteration order,
# which can vary between runs for strings. The order matters when a class inherits __init__ or a listener from
# another listed class, because getattr then returns whatever has already been patched.
for component_name in sorted(set(gr.components.__all__ + gr.layouts.__all__)):
    try:
        component = getattr(gr, component_name)
        patched_component_init = gradio_component_compatibility_layer(component.__init__)
        # The replaced __init__ returned here is not needed; original_IOComponent_init is bound separately,
        # by the patches.patch call on gr.components.base.Component further below.
        patches.patch(f'{__name__}.{component_name}', obj=component, field="__init__", replacement=patched_component_init)
        # Classes without EVENTS simply have no listeners to patch; dict.fromkeys de-duplicates in order.
        component_events = list(dict.fromkeys(getattr(component, 'EVENTS', ())))
    except Exception as e:
        # Best effort: a name that is missing or cannot be wrapped in the installed gradio is reported and skipped.
        print(f"Could not apply gradio compatibility patch to {component_name}: {e}")
        continue

    for component_event in component_events:
        try:
            patched_component_event_function = gradio_component_events_compatibility_layer(getattr(component, component_event))
            patches.patch(f'{__name__}.{component_name}.{component_event}', obj=component, field=component_event, replacement=patched_component_event_function)
        except Exception as e:
            # One failing listener does not prevent the class's remaining listeners from being patched.
            print(f"Could not apply gradio compatibility patch to {component_name}.{component_event}: {e}")

# Keep gr.Box usable as an alias of gr.Group for code that still refers to it
# (presumably because the installed gradio no longer provides Box — confirm).
gr.Box = gr.Group


# Install webui's replacements on gradio's base classes. The value each patches.patch call returns is the
# method it replaced. Those are kept in module globals because IOComponent_init, Block_get_config,
# BlockContext_init and Blocks_get_config_file call them.
original_IOComponent_init = patches.patch(__name__, obj=gr.components.base.Component, field="__init__", replacement=IOComponent_init)
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)


ui_tempdir.install_ui_tempdir_override()
W
w-e-w 已提交
142