scripts.py 12.0 KB
Newer Older
1 2 3
import os
import sys
import traceback
4
from collections import namedtuple
5

A
AUTOMATIC 已提交
6
import modules.ui as ui
7 8
import gradio as gr

A
AUTOMATIC 已提交
9
from modules.processing import StableDiffusionProcessing
10
from modules import shared, paths, script_callbacks, extensions
11 12 13

# Sentinel a Script.show() implementation returns to ask for its UI to be shown
# at all times rather than only when picked in the "Script" dropdown;
# ScriptRunner.setup_ui compares show()'s result against it with ==.
AlwaysVisible = object()

A
AUTOMATIC 已提交
14

15 16
class Script:
    """Base class for user scripts shown in the txt2img / img2img tabs.

    Subclasses override title() and ui(), plus either run() (for scripts picked
    from the "Script" dropdown) or process()/process_one()/postprocess() (for
    scripts whose show() returns AlwaysVisible).
    """

    # Path of the file the script was loaded from; assigned by ScriptRunner.
    filename = None

    # Half-open slice [args_from, args_to) of the runner's combined input list
    # that holds the values of this script's own UI components; assigned by
    # ScriptRunner.setup_ui.
    args_from = None
    args_to = None

    # True when show() returned AlwaysVisible; assigned by ScriptRunner.setup_ui.
    alwayson = False

    # A gr.Group component that has all of the script's UI inside it.
    group = None

    # If set in ui(), this is a list of pairs of gradio component + text; the
    # text will be used when parsing infotext to set the value for the
    # component; see ui.py's txt2img_paste_fields for an example.
    infotext_fields = None

    def title(self):
        """This function should return the title of the script. This is what will be displayed in the dropdown menu."""

        raise NotImplementedError()

    def ui(self, is_img2img):
        """This function should create gradio UI elements. See https://gradio.app/docs/#components

        The return value should be a list of all components that are used in processing.
        Values of those returned components will be passed to run(), process(),
        process_one() and postprocess().
        """

        pass

    def show(self, is_img2img):
        """Decide whether and how the script appears in the UI.

        is_img2img is True if this function is called for the img2img interface, and False otherwise.

        This function should return:
         - False if the script should not be shown in UI at all
         - True if the script should be shown in UI if it's selected in the scripts dropdown
         - AlwaysVisible (defined in this module) if the script should be shown in UI at all times
        """

        return True

    def run(self, p, *args):
        """
        This function is called if the script has been selected in the script dropdown.
        It must do all processing and return the Processed object with results, same as
        one returned by processing.process_images.

        Usually the processing is done by calling the processing.process_images function.

        args contains all values returned by components from ui()
        """

        raise NotImplementedError()

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process_one(self, p, n, *args):
        """
        Same as process(), but called for every iteration.
        n is the value passed by the caller (presumably the iteration index -- confirm at call site).
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
        args contains all values returned by components from ui()
        """

        pass

    def describe(self):
        """unused"""
        return ""

95

96 97 98 99 100 101 102 103 104 105 106
# Base directory of the script that is currently executing. load_scripts()
# points it at each script file's base directory while that file runs and
# resets it to paths.script_path afterwards.
current_basedir = paths.script_path


def basedir():
    """Return the base directory for the script that is currently executing.

    For scripts in the main scripts directory this is the main directory (where
    webui.py resides); for scripts shipped by an extension (e.g. one under
    extensions/aesthetic) it is that extension's directory (extensions/aesthetic).
    """
    directory = current_basedir
    return directory


A
AUTOMATIC 已提交
107
# ScriptClassData entries registered by load_scripts(), one per Script subclass found.
scripts_data = []

# A candidate script file: the base directory it belongs to, its file name and its full path.
ScriptFile = namedtuple("ScriptFile", "basedir filename path")

# A loaded Script subclass, the file it was defined in, and that file's base directory.
ScriptClassData = namedtuple("ScriptClassData", "script_class path basedir")


def list_scripts(scriptdirname, extension):
    """Collect ScriptFile entries for every file with the given extension.

    Files directly inside ``<paths.script_path>/<scriptdirname>`` come first,
    sorted by name, followed by whatever each active extension reports via
    ``ext.list_files(scriptdirname, extension)``. Only entries whose lower-cased
    suffix equals ``extension`` and whose path is an existing regular file are kept.
    """
    main_dir = os.path.join(paths.script_path, scriptdirname)

    candidates = []
    if os.path.exists(main_dir):
        candidates.extend(
            ScriptFile(paths.script_path, name, os.path.join(main_dir, name))
            for name in sorted(os.listdir(main_dir))
        )

    for ext in extensions.active():
        candidates += ext.list_files(scriptdirname, extension)

    def wanted(entry):
        suffix = os.path.splitext(entry.path)[1]
        return suffix.lower() == extension and os.path.isfile(entry.path)

    return [entry for entry in candidates if wanted(entry)]
126

127

A
AUTOMATIC 已提交
128 129 130
def list_files_with_name(filename):
    """Return the paths of every existing file named ``filename``.

    Looks in the main directory (paths.script_path) and in the root directory of
    each active extension, in that order; directories that do not exist are
    skipped.
    """
    res = []

    dirs = [paths.script_path] + [ext.path for ext in extensions.active()]

    for dirpath in dirs:
        if not os.path.isdir(dirpath):
            continue

        path = os.path.join(dirpath, filename)
        # Test the joined path: checking the bare ``filename`` would resolve it
        # against the current working directory instead of dirpath.
        if os.path.isfile(path):
            res.append(path)

    return res


144 145 146 147 148 149 150 151
def load_scripts():
    """Discover, execute and register every script under the "scripts" directories.

    Clears previously registered script classes and callbacks. Then, for each
    ``.py`` file returned by list_scripts() (in sorted order), it compiles and
    executes the file in a fresh module object and records every Script subclass
    the file defines in scripts_data. A failing file is reported on stderr and
    does not stop the remaining files from loading.
    """
    global current_basedir
    from types import ModuleType

    scripts_data.clear()
    script_callbacks.clear_callbacks()

    found = list_scripts("scripts", ".py")
    original_sys_path = sys.path

    for entry in sorted(found):
        try:
            # Files whose base directory is not the main directory (i.e. those
            # provided by extensions) get that directory prepended to sys.path,
            # presumably so they can import modules that sit next to them.
            if entry.basedir != paths.script_path:
                sys.path = [entry.basedir, *sys.path]
            current_basedir = entry.basedir

            with open(entry.path, "r", encoding="utf8") as handle:
                source = handle.read()

            code = compile(source, entry.path, 'exec')
            module = ModuleType(entry.filename)
            exec(code, module.__dict__)

            scripts_data.extend(
                ScriptClassData(obj, entry.path, entry.basedir)
                for obj in module.__dict__.values()
                if type(obj) == type and issubclass(obj, Script)
            )

        except Exception:
            print(f"Error loading script: {entry.filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            # Rebinding (rather than mutating) sys.path above means the saved
            # list object can simply be put back.
            sys.path = original_sys_path
            current_basedir = paths.script_path

179 180 181

def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call ``func(*args, **kwargs)`` and hand back its result.

    If the call raises any Exception, this function prints an error line naming
    ``filename/funcname`` and the traceback to stderr. It then returns
    ``default`` instead of propagating the exception.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return default


A
AUTOMATIC 已提交
191 192 193
class ScriptRunner:
    """Builds the script UI for one tab and dispatches generation hooks to scripts.

    Selectable scripts are chosen through a "Script" dropdown and executed by
    run(). AlwaysVisible scripts have their process()/process_one()/postprocess()
    hooks called by the methods of the same name.
    """

    def __init__(self):
        self.scripts = []
        self.selectable_scripts = []
        self.alwayson_scripts = []
        self.titles = []
        self.infotext_fields = []

    def setup_ui(self, is_img2img):
        """Instantiate every loaded script class and build its gradio UI.

        Returns the flat list of input components for the tab. Index 0 is the
        script dropdown; it is followed by each script's own components, and
        each script records its slice of that list in args_from/args_to.
        """
        for script_class, path, basedir in scripts_data:
            script = script_class()
            script.filename = path

            visibility = script.show(is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
                self.alwayson_scripts.append(script)
                script.alwayson = True

            elif visibility:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        # inputs[0] is a placeholder that is replaced by the dropdown below.
        inputs = [None]
        inputs_alwayson = [True]

        def create_script_ui(script, inputs, inputs_alwayson):
            script.args_from = len(inputs)
            script.args_to = len(inputs)

            controls = wrap_call(script.ui, script.filename, "ui", is_img2img)

            if controls is None:
                return

            for control in controls:
                control.custom_script_source = os.path.basename(script.filename)

            if script.infotext_fields is not None:
                self.infotext_fields += script.infotext_fields

            inputs += controls
            inputs_alwayson += [script.alwayson for _ in controls]
            script.args_to = len(inputs)

        for script in self.alwayson_scripts:
            with gr.Group() as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        dropdown.save_to_config = True
        inputs[0] = dropdown

        for script in self.selectable_scripts:
            with gr.Group(visible=False) as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group

        def select_script(script_index):
            # Dropdown index 0 is "None"; index i > 0 maps to selectable_scripts[i - 1].
            selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None

            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

        def init_field(title):
            """called when an initial value is set from ui-config.json to show script's UI components"""

            # A title saved in ui-config.json may belong to a script that has
            # since been removed or renamed; titles.index() would raise
            # ValueError for it, so such titles are ignored.
            if title == 'None' or title not in self.titles:
                return

            script_index = self.titles.index(title)
            self.selectable_scripts[script_index].group.visible = True

        dropdown.init_field = init_field

        dropdown.change(
            fn=select_script,
            inputs=[dropdown],
            outputs=[script.group for script in self.selectable_scripts]
        )

        return inputs

    def run(self, p: StableDiffusionProcessing, *args):
        """Run the script selected in the dropdown, if any.

        args[0] is the dropdown index (0 means "None"). The remaining values are
        the tab's UI inputs, which are sliced per script by args_from/args_to.
        Returns the selected script's run() result, or None when no script is
        selected.
        """
        script_index = args[0]

        if script_index == 0:
            return None

        script = self.selectable_scripts[script_index-1]

        script_args = args[script.args_from:script.args_to]
        processed = script.run(p, *script_args)

        shared.total_tqdm.clear()

        return processed

    def process(self, p):
        """Call process() on every AlwaysVisible script; errors are printed, not raised."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
                print(f"Error running process: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def process_one(self, p, n):
        """Call process_one() on every AlwaysVisible script; errors are printed, not raised."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process_one(p, n, *script_args)
            except Exception:
                print(f"Error running process_one: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def postprocess(self, p, processed):
        """Call postprocess() on every AlwaysVisible script; errors are printed, not raised."""
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    @staticmethod
    def _pick_script_class(module, class_name):
        """Choose the Script subclass in ``module`` that replaces an instance of ``class_name``.

        Prefers the class with the same name, so a file defining several scripts
        keeps each one mapped to its own class. It falls back to the last Script
        subclass in the module (the historical behaviour), or None if there is none.
        """
        candidates = [obj for obj in module.__dict__.values() if type(obj) == type and issubclass(obj, Script)]
        for candidate in candidates:
            if candidate.__name__ == class_name:
                return candidate
        return candidates[-1] if candidates else None

    def reload_sources(self, cache):
        """Re-execute each script's source file and swap in fresh instances without rebuilding the UI.

        ``cache`` maps a script file path to the module object created for it,
        so a file shared by several runners is only read and executed once.
        Each new instance inherits the replaced instance's args_from, args_to,
        alwayson and group. It replaces that instance in self.scripts and also in
        self.selectable_scripts / self.alwayson_scripts, which are the lists run()
        and the process hooks actually iterate.
        """
        from types import ModuleType

        for si, old_script in enumerate(self.scripts):
            filename = old_script.filename

            module = cache.get(filename, None)
            if module is None:
                with open(filename, "r", encoding="utf8") as file:
                    text = file.read()
                compiled = compile(text, filename, 'exec')
                module = ModuleType(filename)
                exec(compiled, module.__dict__)
                cache[filename] = module

            script_class = self._pick_script_class(module, type(old_script).__name__)
            if script_class is None:
                continue

            new_script = script_class()
            new_script.filename = filename
            new_script.args_from = old_script.args_from
            new_script.args_to = old_script.args_to
            new_script.alwayson = old_script.alwayson
            new_script.group = old_script.group

            self.scripts[si] = new_script
            for bucket in (self.selectable_scripts, self.alwayson_scripts):
                for bi, candidate in enumerate(bucket):
                    if candidate is old_script:
                        bucket[bi] = new_script
A
AUTOMATIC 已提交
347

348

A
AUTOMATIC 已提交
349 350
# One runner per tab (txt2img and img2img); reload_scripts() rebinds both to new runners.
scripts_txt2img, scripts_img2img = ScriptRunner(), ScriptRunner()
D
DepFA 已提交
351

352

D
DepFA 已提交
353
def reload_script_body_only():
    """Re-execute the source of the scripts already set up in both tabs.

    The txt2img runner goes first, then img2img. Both share one cache dict so
    that a script file used by both runners is executed only once.
    """
    shared_cache = {}
    for runner in (scripts_txt2img, scripts_img2img):
        runner.reload_sources(shared_cache)
D
DepFA 已提交
357

D
DepFA 已提交
358

359
def reload_scripts():
    """Reload every script from disk and replace both tab runners with fresh ones.

    The new runners hold no scripts until their setup_ui() is called.
    """
    global scripts_txt2img, scripts_img2img

    load_scripts()

    scripts_txt2img, scripts_img2img = ScriptRunner(), ScriptRunner()
366