ui.py 88.7 KB
Newer Older
1 2
import html
import json
A
AUTOMATIC 已提交
3
import math
4 5
import mimetypes
import os
D
discus0434 已提交
6
import platform
A
AUTOMATIC 已提交
7
import random
D
discus0434 已提交
8
import subprocess as sp
9
import sys
10
import tempfile
11 12
import time
import traceback
13
from functools import partial, reduce
14

D
discus0434 已提交
15 16 17
import gradio as gr
import gradio.routes
import gradio.utils
A
AUTOMATIC 已提交
18
import numpy as np
19
from PIL import Image, PngImagePlugin
20
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
21

22 23
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules.ui_components import FormRow, FormGroup, ToolButton
24
from modules.paths import script_path
25

26
from modules.shared import opts, cmd_opts, restricted_opts
M
init  
MalumaDev 已提交
27

D
discus0434 已提交
28
import modules.codeformer_model
Y
yfszzx 已提交
29
import modules.generation_parameters_copypaste as parameters_copypaste
D
discus0434 已提交
30 31
import modules.gfpgan_model
import modules.hypernetworks.ui
A
AUTOMATIC 已提交
32
import modules.scripts
D
discus0434 已提交
33
import modules.shared as shared
A
AUTOMATIC 已提交
34
import modules.styles
D
discus0434 已提交
35
import modules.textual_inversion.ui
A
AUTOMATIC 已提交
36
from modules import prompt_parser
M
Milly 已提交
37
from modules.images import save_image
D
discus0434 已提交
38 39
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
40
import modules.textual_inversion.ui
41
import modules.hypernetworks.ui
Y
yfszzx 已提交
42
from modules.generation_parameters_copypaste import image_from_url_text
43

A
Aidan Holland 已提交
44
# This is a fix for Windows users. Without it, javascript files will be served with
# text/html content-type and the browser will not show any UI.
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

if not cmd_opts.share and not cmd_opts.listen:
    # Fix gradio phoning home: when the UI is not shared/listening publicly, disable
    # gradio's version check and report the loopback address as the local IP.
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

if cmd_opts.ngrok is not None:
    # cmd_opts.ngrok carries the ngrok authtoken; tunnel to the UI port
    # (--port if given, otherwise 7860) in the requested ngrok region.
    import modules.ngrok as ngrok
    print('ngrok authtoken detected, trying to connect...')
    ngrok.connect(
        cmd_opts.ngrok,
        cmd_opts.port if cmd_opts.port is not None else 7860,
        cmd_opts.ngrok_region
        )
J
JamnedZ 已提交
61

62 63 64 65 66 67 68 69 70 71

def gr_show(visible=True):
    """Build a raw gradio update payload that only changes a component's visibility."""
    payload = {"visible": visible}
    payload["__type__"] = "update"
    return payload


# Sample input image for img2img; None when the asset is not present on disk.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None

# CSS that replaces gradio's built-in progress indicators with a plain "Loading..." label.
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text-center { display:none!important; }
"""

# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
reuse_symbol = '\u267b\ufe0f'  # ♻️
paste_symbol = '\u2199\ufe0f'  # ↙
folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
apply_style_symbol = '\U0001f4cb'  # 📋
clear_prompt_symbol = '\U0001F5D1'  # 🗑️
90

91

92
def plaintext_to_html(text):
    """Render plain text as one HTML paragraph.

    Every line is HTML-escaped, and lines are joined with ``<br>`` so the line breaks
    survive rendering.
    """
    return "<p>" + "<br>\n".join(html.escape(line) for line in text.split('\n')) + "</p>"

def send_gradio_gallery_to_image(x):
    """Return the first image of a gallery value, or None when there is none.

    ``x`` is the gallery value (a list of image entries). Both an empty list and a
    missing (None) value yield None instead of raising.
    """
    if not x:
        return None
    return image_from_url_text(x[0])

A
aoirusann 已提交
101
def save_files(js_data, images, do_make_zip, index):
    """Save gallery images to ``opts.outdir_save`` and append an entry to its ``log.csv``.

    Args:
        js_data: JSON string with the generation info of the gallery (prompt, seeds,
            infotexts, ``index_of_first_image``, ...).
        images: gallery entries, each decodable with ``image_from_url_text``.
        do_make_zip: when true, also bundle every saved file into ``images.zip``.
        index: gallery index of the selected image, or -1 when nothing is selected.

    Returns:
        A ``gr.File`` update listing the saved files (the zip first, if one was made)
        and an HTML message naming the first saved file.
    """
    import csv
    from zipfile import ZipFile

    filenames = []
    fullfns = []

    # quick dictionary to class object conversion. It's necessary because apply_filename_pattern requires it
    class MyObject:
        def __init__(self, d=None):
            if d is not None:
                for key, value in d.items():
                    setattr(self, key, value)

    data = json.loads(js_data)

    p = MyObject(data)
    path = opts.outdir_save
    save_to_dirs = opts.use_save_to_dirs_for_ui
    extension: str = opts.samples_format
    start_index = 0

    if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]):  # ensures we are looking at a specific non-grid picture, and we have save_selected_only
        images = [images[index]]
        start_index = index

    if not images:
        # Nothing to save: bail out before touching log.csv; the code below
        # indexes filenames[0] and would otherwise raise IndexError.
        return gr.File.update(value=None, visible=False), plaintext_to_html("Nothing to save")

    os.makedirs(path, exist_ok=True)

    with open(os.path.join(path, "log.csv"), "a", encoding="utf8", newline='') as file:
        at_start = file.tell() == 0
        writer = csv.writer(file)
        if at_start:
            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])

        for image_index, filedata in enumerate(images, start_index):
            image = image_from_url_text(filedata)

            # Entries before index_of_first_image are grids; they use the first seed/prompt.
            is_grid = image_index < p.index_of_first_image
            i = 0 if is_grid else (image_index - p.index_of_first_image)

            fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)

            filename = os.path.relpath(fullfn, path)
            filenames.append(filename)
            fullfns.append(fullfn)
            if txt_fullfn:
                filenames.append(os.path.basename(txt_fullfn))
                fullfns.append(txt_fullfn)

        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

    # Make Zip
    if do_make_zip:
        zip_filepath = os.path.join(path, "images.zip")

        with ZipFile(zip_filepath, "w") as zip_file:
            # Stream each file into the archive under its display name instead of
            # reading it fully into memory first.
            for arcname, fullfn in zip(filenames, fullfns):
                zip_file.write(fullfn, arcname=arcname)
        fullfns.insert(0, zip_filepath)

    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
163 164


165
def calc_time_left(progress, threshold, label, force_display, show_eta):
    """Format the estimated time remaining for the current job.

    Args:
        progress: fraction of the job done so far; 0 means no estimate is possible.
        threshold: minimum remaining seconds before the ETA is shown (unless forced).
        label: prefix put in front of the formatted time, e.g. " ETA: ".
        force_display: show the ETA regardless of ``threshold`` and ``show_eta``.
        show_eta: whether enough work was done for the estimate to be shown.

    Returns:
        ``label`` followed by ``HH:MM:SS``, ``MM:SS`` or ``SSs`` depending on the
        magnitude, or an empty string when nothing should be displayed.
    """
    if progress == 0:
        return ""

    time_since_start = time.time() - shared.state.time_start
    eta = time_since_start / progress
    eta_relative = eta - time_since_start

    if not ((eta_relative > threshold and show_eta) or force_display):
        return ""

    # A caller may pass progress above 1, which makes eta_relative negative; clamp it
    # (time.gmtime() raises for negative values on Windows). Splitting the seconds by
    # hand also keeps hours from wrapping after 24h the way '%H' does.
    seconds_left = max(int(eta_relative), 0)
    hours, remainder = divmod(seconds_left, 3600)
    minutes, seconds = divmod(remainder, 60)

    if eta_relative > 3600:
        return f"{label}{hours:02d}:{minutes:02d}:{seconds:02d}"
    elif eta_relative > 60:
        return f"{label}{minutes:02d}:{seconds:02d}"
    else:
        return f"{label}{seconds:02d}s"
A
Anastasius 已提交
181 182


183
def check_progress_call(id_part):
    """Poll the shared job state and build the progress widgets for the ``id_part`` tab.

    Returns a 4-tuple of outputs:
    - the progress HTML;
    - the preview visibility update;
    - the preview image, or an update clearing it;
    - the text-info update.

    Everything is hidden when no job is running (``job_count == 0``).
    """
    if shared.state.job_count == 0:
        return "", gr_show(False), gr_show(False), gr_show(False)

    progress = 0

    # check_progress_call_initial sets job_count to -1 before the job reports its
    # size. Only count progress once job_count is positive; otherwise the per-step
    # term would be divided by -1 and go negative.
    if shared.state.job_count > 0:
        progress += shared.state.job_no / shared.state.job_count

        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

    # Show progress percentage and time left at the same moment, and base it also on steps done
    show_eta = progress >= 0.01 or shared.state.sampling_step >= 10

    time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
    if time_left != "":
        # Once an ETA has been shown, keep forcing it for the rest of this job.
        shared.state.time_left_force_display = True

    progress = min(progress, 1)

    progressbar = ""
    if opts.show_progressbar:
        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""

    image = gr_show(False)
    preview_visibility = gr_show(False)

    if opts.show_progress_every_n_steps != 0:
        shared.state.set_current_image()
        image = shared.state.current_image

        if image is None:
            image = gr.update(value=None)
        else:
            preview_visibility = gr_show(True)

    if shared.state.textinfo is not None:
        textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
    else:
        textinfo_result = gr_show(False)

    return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
A
AUTOMATIC 已提交
225 226


227
def check_progress_call_initial(id_part):
    """Reset the shared progress/preview state for a new job, then poll it once.

    ``job_count`` is set to -1 rather than 0 so that ``check_progress_call`` does not
    treat the state as idle; presumably the job sets the real count once it starts.
    The ETA clock restarts and forced ETA display is cleared.
    """
    shared.state.job_count = -1
    shared.state.current_latent = None
    shared.state.current_image = None
    shared.state.textinfo = None
    shared.state.time_start = time.time()
    shared.state.time_left_force_display = False

    return check_progress_call(id_part)
236 237


A
AUTOMATIC 已提交
238 239 240 241 242 243 244
def visit(x, func, path=""):
    """Walk a component tree, reporting each labelled leaf to ``func``.

    Objects with a ``children`` attribute are containers: they are descended into
    but not reported themselves. A leaf is reported as
    ``func(path + "/" + str(label), leaf)`` unless its ``label`` is None. ``path``
    is handed unchanged to every level of the recursion.
    """
    if not hasattr(x, 'children'):
        if x.label is not None:
            func(path + "/" + str(x.label), x)
        return

    for child in x.children:
        visit(child, func, path)

245

246 247
def add_style(name: str, prompt: str, negative_prompt: str):
    """Store the given prompts as a named style, save all styles, and refresh the dropdowns.

    When ``name`` is None nothing is saved and four plain visibility updates are
    returned instead.

    Returns:
        A list of four gradio updates; presumably one per style dropdown wired to
        this callback.
    """
    if name is None:
        return [gr_show() for _ in range(4)]

    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
    shared.prompt_styles.styles[style.name] = style
    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
    # reserialize all styles every time we save them
    shared.prompt_styles.save_styles(shared.styles_filename)

    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
A
AUTOMATIC 已提交
257 258 259 260 261 262 263


def apply_styles(prompt, prompt_neg, style1_name, style2_name):
    """Merge the two selected styles into the prompts and reset both style dropdowns.

    Returns four updates, in this order:
    - the styled prompt textbox;
    - the styled negative prompt textbox;
    - two dropdown resets to the "None" style.
    """
    prompt_styles = shared.prompt_styles

    styled_prompt = prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
    styled_negative = prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])

    textbox_updates = [gr.Textbox.update(value=text) for text in (styled_prompt, styled_negative)]
    dropdown_resets = [gr.Dropdown.update(value="None") for _ in range(2)]
    return textbox_updates + dropdown_resets
A
AUTOMATIC 已提交
264 265


A
AUTOMATIC 已提交
266
def interrogate(image):
    """Caption ``image`` with the shared CLIP interrogator.

    Returns the generated prompt, or ``gr_show(True)`` when the interrogator
    returns None.
    """
    # Converted to RGB first, presumably so RGBA/palette uploads are accepted.
    prompt = shared.interrogator.interrogate(image.convert("RGB"))

    return gr_show(True) if prompt is None else prompt

A
AUTOMATIC 已提交
271

G
Greendayle 已提交
272
def interrogate_deepbooru(image):
    """Tag ``image`` with the DeepBooru model.

    Returns the tag prompt, or ``gr_show(True)`` when the model returns None.
    """
    prompt = deepbooru.model.tag(image)
    return gr_show(True) if prompt is None else prompt


277
def create_seed_inputs(target_interface):
    """Create the seed controls; ``target_interface`` is used as the element id prefix.

    Builds the seed box with random/reuse buttons and an 'Extra' checkbox. The
    checkbox toggles two hidden rows:
    - the variation seed and its strength;
    - the "resize seed from" width/height sliders.

    The random buttons put -1 into their box.

    Returns:
        (seed, reuse_seed, subseed, reuse_subseed, subseed_strength,
        seed_resize_from_h, seed_resize_from_w, seed_checkbox) -- note height
        comes before width.
    """
    with FormRow(elem_id=target_interface + '_seed_row'):
        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
        seed.style(container=False)
        random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
        reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')

        with gr.Group(elem_id=target_interface + '_subseed_show_box'):
            seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)

    # Components to show/hide based on the 'Extra' checkbox
    seed_extras = []

    with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
        seed_extras.append(seed_extra_row_1)
        subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
        subseed.style(container=False)
        random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
        reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')

    with FormRow(visible=False) as seed_extra_row_2:
        seed_extras.append(seed_extra_row_2)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')

    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
    random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])

    def change_visibility(show):
        return {comp: gr_show(show) for comp in seed_extras}

    seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)

    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
312 313


314

P
papuSpartan 已提交
315
def connect_clear_prompt(button):
    """Given a clear-prompt button, make its click event run the ``clear_prompt`` javascript function.

    No Python callback is attached (``fn=None``) and no components are passed in or
    out; all the work happens in the browser.
    """
    button.click(
        _js="clear_prompt",
        fn=None,
        inputs=[],
        outputs=[],
    )
323 324


325 326 327
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
    """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
        (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
        was 0, i.e. no variation seed was used, it copies the normal seed value instead.

        The seed of the image currently selected in the gallery is used; -1 is written
        when the generation info is empty, unparsable, or has no seeds."""
    def copy_seed(gen_info_string: str, index):
        res = -1

        try:
            gen_info = json.loads(gen_info_string)
        except json.decoder.JSONDecodeError:
            if gen_info_string != '':
                print("Error parsing JSON generation info:", file=sys.stderr)
                print(gen_info_string, file=sys.stderr)
            gen_info = None

        # Valid JSON that is not an object carries no seeds; keep the -1 default.
        if isinstance(gen_info, dict):
            # Convert the gallery index to an index into the per-image seed lists
            # (gallery entries before index_of_first_image are grids).
            index -= gen_info.get('index_of_first_image', 0)

            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
                all_seeds = gen_info.get('all_subseeds', [-1])
            else:
                all_seeds = gen_info.get('all_seeds', [-1])

            if all_seeds:
                # Out-of-range selections (e.g. a grid) fall back to the first seed.
                res = all_seeds[index if 0 <= index < len(all_seeds) else 0]

        return [res, gr_show(False)]

    reuse_seed.click(
        fn=copy_seed,
        _js="(x, y) => [x, selected_gallery_index()]",
        show_progress=False,
        inputs=[generation_info, dummy_component],
        outputs=[seed, dummy_component]
    )

358

L
Liam 已提交
359
def update_token_counter(text, steps):
    """Return HTML showing the token count of the longest scheduled variant of ``text``.

    The prompt is expanded into per-step ``[step, text]`` schedules; the variants
    presumably differ when prompt-editing syntax is used. The largest count among
    them is shown as ``count/max_length`` (from ``model_hijack.tokenize``). The span
    gets the "red" class when the count exceeds ``max_length``.
    """
    try:
        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)

    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        prompt_schedules = [[[steps, text]]]

    # Flatten all schedules in a single pass (reduce with list + list was quadratic).
    prompts = [prompt_text for schedule in prompt_schedules for _, prompt_text in schedule]
    _, token_count, max_length = max((model_hijack.tokenize(prompt) for prompt in prompts), key=lambda args: args[1])

    style_class = ' class="red"' if (token_count > max_length) else ""
    return f"<span {style_class}>{token_count}/{max_length}</span>"
A
AUTOMATIC 已提交
374

375

A
AUTOMATIC 已提交
376
def create_toprow(is_img2img):
    """Build the prompt row shared by the txt2img and img2img tabs.

    Creates:
    - the prompt and negative prompt boxes;
    - the tool column (paste, style save/apply, clear prompt, token counter with
      its hidden trigger button);
    - the interrogate buttons (img2img only);
    - the Skip/Interrupt/Generate buttons;
    - the two style dropdowns.

    Element ids are prefixed with "img2img" or "txt2img".

    Returns:
        (prompt, prompt_style, negative_prompt, prompt_style2, submit,
        button_interrogate, button_deepbooru, prompt_style_apply, save_style,
        paste, token_counter, token_button); both interrogate buttons are None
        for txt2img.
    """
    id_part = "img2img" if is_img2img else "txt2img"

    with gr.Row(elem_id="toprow"):
        with gr.Column(scale=6):
            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
                            placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
                        )

            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
                            placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
                        )

        with gr.Column(scale=1, elem_id="roll_col"):
            paste = gr.Button(value=paste_symbol, elem_id="paste")
            save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
            prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
            clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
            token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
            token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")

            # The confirm_clear_prompt javascript produces the new prompt values;
            # the Python fn only passes them through to the two boxes.
            clear_prompt_button.click(
                fn=lambda *x: x,
                _js="confirm_clear_prompt",
                inputs=[prompt, negative_prompt],
                outputs=[prompt, negative_prompt],
            )

        button_interrogate = None
        button_deepbooru = None
        if is_img2img:
            with gr.Column(scale=1, elem_id="interrogate_col"):
                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")

        with gr.Column(scale=1):
            with gr.Row():
                skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')

                skip.click(
                    fn=lambda: shared.state.skip(),
                    inputs=[],
                    outputs=[],
                )

                interrupt.click(
                    fn=lambda: shared.state.interrupt(),
                    inputs=[],
                    outputs=[],
                )

            with gr.Row():
                with gr.Column(scale=1, elem_id="style_pos_col"):
                    prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=list(shared.prompt_styles.styles), value=next(iter(shared.prompt_styles.styles)))
                    prompt_style.save_to_config = True

                with gr.Column(scale=1, elem_id="style_neg_col"):
                    prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=list(shared.prompt_styles.styles), value=next(iter(shared.prompt_styles.styles)))
                    prompt_style2.save_to_config = True

    return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
A
AUTOMATIC 已提交
445 446


447 448 449
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
    """Create the two hidden "Check progress" buttons that poll the job state.

    Both buttons write to ``progressbar``, ``preview`` and ``textinfo``. ``preview``
    is listed twice because the poll returns a visibility update followed by the
    image. The "(first)" button resets the progress state before polling. When
    ``textinfo`` is not given, a hidden HTML component is created for it.
    """
    if textinfo is None:
        textinfo = gr.HTML(visible=False)

    check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
    check_progress.click(
        fn=lambda: check_progress_call(id_part),
        show_progress=False,
        inputs=[],
        outputs=[progressbar, preview, preview, textinfo],
    )

    check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
    check_progress_initial.click(
        fn=lambda: check_progress_call_initial(id_part),
        show_progress=False,
        inputs=[],
        outputs=[progressbar, preview, preview, textinfo],
    )
A
AUTOMATIC 已提交
466 467


468 469 470 471
def apply_setting(key, value):
    """Apply a setting value (e.g. pasted from generation parameters) and persist it.

    Returns ``gr.update()`` (a no-op update) in these cases:
    - the value is None;
    - settings are frozen;
    - checkpoint swapping is disabled;
    - no matching checkpoint exists;
    - the setting's component is hidden.

    Otherwise the value is converted to the type of the setting's default and stored
    in ``opts``. The setting's ``onchange`` callback runs if the stored value
    changed. The config is then saved and ``value`` is returned.
    """
    if value is None:
        return gr.update()

    if shared.cmd_opts.freeze_settings:
        return gr.update()

    # dont allow model to be swapped when model hash exists in prompt
    if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
        return gr.update()

    if key == "sd_model_checkpoint":
        ckpt_info = sd_models.get_closet_checkpoint_match(value)

        if ckpt_info is not None:
            value = ckpt_info.title
        else:
            return gr.update()

    comp_args = opts.data_labels[key].component_args
    if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
        # Hidden settings are left alone. Return the same no-op update as the other
        # early exits rather than None, which gradio may treat as a new (empty) value.
        return gr.update()

    valtype = type(opts.data_labels[key].default)
    oldval = opts.data.get(key, None)
    newval = valtype(value) if valtype is not type(None) else value
    opts.data[key] = newval
    # Compare the converted value, so that e.g. "7" pasted for an int setting already
    # equal to 7 does not fire a spurious onchange.
    if oldval != newval and opts.data_labels[key].onchange is not None:
        opts.data_labels[key].onchange()

    opts.save(shared.config_filename)
    return value


501 502 503 504
def update_generation_info(args):
    """Return the infotext HTML for the image currently selected in the gallery.

    ``args`` is ``(generation_info_json, html_info, img_index)``. The current
    ``html_info`` is returned unchanged in two cases:
    - the index is outside the ``infotexts`` list;
    - anything fails, including JSON parsing (deliberate best-effort).
    """
    generation_info, html_info, img_index = args
    try:
        generation_info = json.loads(generation_info)
        if img_index < 0 or img_index >= len(generation_info["infotexts"]):
            return html_info
        return plaintext_to_html(generation_info["infotexts"][img_index])
    except Exception:
        pass
    # if the json parse or anything else fails, just return the old html_info
    return html_info
512

513

514 515 516 517
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    """Create a refresh tool button that reloads data and then updates ``refresh_component``.

    Args:
        refresh_component: the gradio component to update.
        refresh_method: called with no arguments on every click, before the update.
        refreshed_args: dict of component properties to apply, or a callable
            returning such a dict (evaluated on every click); None means "no changes".
        elem_id: element id of the created button.

    Returns:
        The created ``ToolButton``.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # refreshed_args may produce None; treat it as "nothing to change" instead of
        # crashing on .items().
        args = args or {}

        # Also set the new values on the Python component object, presumably so the
        # server-side state matches what the browser shows.
        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
531 532


533
def create_output_panel(tabname, outdir):
    """Build the output panel (gallery, action buttons, info boxes) for ``tabname``.

    The folder button opens ``opts.outdir_samples``, or ``outdir`` when that option is
    empty. The "extras" tab gets no Save/Zip buttons and no generation info textbox.

    Returns:
        (result_gallery, generation_info, html_info, html_log). For "extras" the
        second element is an extra HTML component instead of the generation info
        textbox.
    """
    def open_folder(f):
        if not os.path.exists(f):
            print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
            return
        elif not os.path.isdir(f):
            print(f"""
WARNING
An open_folder request was made with an argument that is not a folder.
This could be an error or a malicious attempt to run code on your computer.
Requested path was: {f}
""", file=sys.stderr)
            return

        if not shared.cmd_opts.hide_ui_dir_config:
            path = os.path.normpath(f)
            if platform.system() == "Windows":
                os.startfile(path)
            elif platform.system() == "Darwin":
                sp.Popen(["open", path])
            else:
                sp.Popen(["xdg-open", path])

    with gr.Column(variant='panel'):
        with gr.Group():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)

        generation_info = None
        with gr.Column():
            with gr.Row(elem_id=f"image_buttons_{tabname}"):
                open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder')

                if tabname != "extras":
                    save = gr.Button('Save', elem_id=f'save_{tabname}')
                    save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')

                buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])

            open_folder_button.click(
                fn=lambda: open_folder(opts.outdir_samples or outdir),
                inputs=[],
                outputs=[],
            )

            if tabname != "extras":
                with gr.Row():
                    download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)

                with gr.Group():
                    html_info = gr.HTML()
                    html_log = gr.HTML()

                    generation_info = gr.Textbox(visible=False)
                    if tabname == 'txt2img' or tabname == 'img2img':
                        generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
                        generation_info_button.click(
                            fn=update_generation_info,
                            _js="(x, y) => [x, y, selected_gallery_index()]",
                            inputs=[generation_info, html_info],
                            outputs=[html_info],
                            preprocess=False
                        )

                    # The two trailing html_info inputs are placeholders. The javascript
                    # replaces them with the zip flag and the selected gallery index
                    # that save_files(js_data, images, do_make_zip, index) expects.
                    save.click(
                        fn=wrap_gradio_call(save_files),
                        _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
                        inputs=[
                            generation_info,
                            result_gallery,
                            html_info,
                            html_info,
                        ],
                        outputs=[
                            download_files,
                            html_log,
                        ]
                    )

                    save_zip.click(
                        fn=wrap_gradio_call(save_files),
                        _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
                        inputs=[
                            generation_info,
                            result_gallery,
                            html_info,
                            html_info,
                        ],
                        outputs=[
                            download_files,
                            html_log,
                        ]
                    )

            else:
                html_info_x = gr.HTML()
                html_info = gr.HTML()
                html_log = gr.HTML()

            parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
            return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
J
Justin Maier 已提交
633

A
aoirusann 已提交
634

635 636
def create_sampler_and_steps_selection(choices, tabname):
    """Create the sampling-method selector and the sampling-steps slider for ``tabname``.

    With ``opts.samplers_in_dropdown`` the sampler is a dropdown created before the
    slider in one row. Otherwise it is a radio group created after the slider. In
    both cases the selector reports the chosen sampler's index (``type="index"``) and
    defaults to the first of ``choices``.

    Returns:
        (steps, sampler_index)
    """
    if opts.samplers_in_dropdown:
        with FormRow(elem_id=f"sampler_selection_{tabname}"):
            sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
            sampler_index.save_to_config = True
            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
    else:
        with FormGroup(elem_id=f"sampler_selection_{tabname}"):
            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
            sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")

    return steps, sampler_index
647

648

649 650 651 652 653 654 655
def ordered_ui_categories():
    """Yield ``shared.ui_reorder_categories`` in the order requested by ``opts.ui_reorder``.

    ``opts.ui_reorder`` is a comma-separated list of category names. Listed
    categories are ranked by their position in it. Unlisted ones are ranked by their
    default position plus 1000, so they follow in their original order.
    """
    requested = {}
    for position, name in enumerate(shared.opts.ui_reorder.split(",")):
        requested[name.strip()] = position

    def sort_key(entry):
        default_position, category = entry
        return requested.get(category, default_position + 1000)

    ranked = sorted(enumerate(shared.ui_reorder_categories), key=sort_key)
    for _, category in ranked:
        yield category


656
def create_ui():
657 658
    import modules.img2img
    import modules.txt2img
659

660 661
    reload_javascript()

662
    parameters_copypaste.reset()
663

664 665 666
    modules.scripts.scripts_current = modules.scripts.scripts_txt2img
    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)

667
    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
668
        txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
P
papuSpartan 已提交
669

670
        dummy_component = gr.Label(visible=False)
A
AUTOMATIC 已提交
671
        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
672

673 674 675
        with gr.Row(elem_id='txt2img_progress_row'):
            with gr.Column(scale=1):
                pass
676

677 678
            with gr.Column(scale=1):
                progressbar = gr.HTML(elem_id="txt2img_progressbar")
679
                txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
680
                setup_progressbar(progressbar, txt2img_preview, 'txt2img')
681

682
        with gr.Row().style(equal_height=False):
683
            with gr.Column(variant='panel', elem_id="txt2img_settings"):
684 685 686
                for category in ordered_ui_categories():
                    if category == "sampler":
                        steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
A
AUTOMATIC 已提交
687

688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725
                    elif category == "dimensions":
                        with FormRow():
                            with gr.Column(elem_id="txt2img_column_size", scale=4):
                                width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
                                height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")

                            if opts.dimensions_and_batch_together:
                                with gr.Column(elem_id="txt2img_column_batch"):
                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")

                    elif category == "cfg":
                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")

                    elif category == "seed":
                        seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')

                    elif category == "checkboxes":
                        with FormRow(elem_id="txt2img_checkboxes"):
                            restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
                            tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
                            enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")

                    elif category == "hires_fix":
                        with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
                            hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
                            hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
                            denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")

                    elif category == "batch":
                        if not opts.dimensions_and_batch_together:
                            with FormRow(elem_id="txt2img_column_batch"):
                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")

                    elif category == "scripts":
                        with FormGroup(elem_id="txt2img_script_container"):
                            custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
726

727
            txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
728
            parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
729

730 731
            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
732

733
            txt2img_args = dict(
734
                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
A
AUTOMATIC 已提交
735
                _js="submit",
736
                inputs=[
A
AUTOMATIC 已提交
737
                    txt2img_prompt,
738
                    txt2img_negative_prompt,
A
AUTOMATIC 已提交
739
                    txt2img_prompt_style,
A
AUTOMATIC 已提交
740
                    txt2img_prompt_style2,
741 742
                    steps,
                    sampler_index,
A
AUTOMATIC 已提交
743
                    restore_faces,
744
                    tiling,
745 746 747 748
                    batch_count,
                    batch_size,
                    cfg_scale,
                    seed,
749
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
750 751
                    height,
                    width,
A
AUTOMATIC 已提交
752 753
                    enable_hr,
                    denoising_strength,
A
AUTOMATIC 已提交
754 755
                    hr_scale,
                    hr_upscaler,
A
AUTOMATIC 已提交
756
                ] + custom_inputs,
757

758 759 760
                outputs=[
                    txt2img_gallery,
                    generation_info,
761 762
                    html_info,
                    html_log,
763 764
                ],
                show_progress=False,
765 766
            )

A
AUTOMATIC 已提交
767
            txt2img_prompt.submit(**txt2img_args)
768 769
            submit.click(**txt2img_args)

D
d8ahazard 已提交
770 771 772 773 774 775 776 777 778 779 780
            txt_prompt_img.change(
                fn=modules.images.image_data,
                inputs=[
                    txt_prompt_img
                ],
                outputs=[
                    txt2img_prompt,
                    txt_prompt_img
                ]
            )

A
AUTOMATIC 已提交
781 782 783 784 785 786
            enable_hr.change(
                fn=lambda x: gr_show(x),
                inputs=[enable_hr],
                outputs=[hr_options],
            )

787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804
            txt2img_paste_fields = [
                (txt2img_prompt, "Prompt"),
                (txt2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
                (enable_hr, lambda d: "Denoising strength" in d),
                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
A
AUTOMATIC 已提交
805 806
                (hr_scale, "Hires upscale"),
                (hr_upscaler, "Hires upscaler"),
807
                *modules.scripts.scripts_txt2img.infotext_fields
808
            ]
809
            parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
810 811 812 813 814 815 816 817 818 819 820 821

            txt2img_preview_params = [
                txt2img_prompt,
                txt2img_negative_prompt,
                steps,
                sampler_index,
                cfg_scale,
                seed,
                width,
                height,
            ]

822
            token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
823

824 825
    modules.scripts.scripts_current = modules.scripts.scripts_img2img
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
826

827
    with gr.Blocks(analytics_enabled=False) as img2img_interface:
828
        img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
829

830
        with gr.Row(elem_id='img2img_progress_row'):
A
AUTOMATIC 已提交
831
            img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
D
d8ahazard 已提交
832

833 834
            with gr.Column(scale=1):
                pass
835

836 837
            with gr.Column(scale=1):
                progressbar = gr.HTML(elem_id="img2img_progressbar")
838
                img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
839
                setup_progressbar(progressbar, img2img_preview, 'img2img')
840

841
        with FormRow().style(equal_height=False):
842
            with gr.Column(variant='panel', elem_id="img2img_settings"):
843

844
                with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
845
                    with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
846
                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
847

848
                    with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
K
kavorite 已提交
849
                        init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
850
                        init_img_with_mask_orig = gr.State(None)
851

852 853 854 855 856 857 858 859
                        use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
                        if use_color_sketch:
                            def update_orig(image, state):
                                if image is not None:
                                    same_size = state is not None and state.size == image.size
                                    has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
                                    edited = same_size and has_exact_match
                                    return image if not edited or state is None else state
860

861
                            init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
862

863 864
                        init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
                        init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
865

866
                        with FormRow():
867 868
                            mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
                            mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
869

870 871 872
                        with FormRow():
                            mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
                            inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
873

874 875
                        with FormRow():
                            inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
876

877 878 879
                        with FormRow():
                            with gr.Column():
                                inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
880

881 882
                            with gr.Column(scale=4):
                                inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
883

884
                    with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
885
                        hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
886
                        gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
887 888
                        img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
                        img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
889

890 891
                with FormRow():
                    resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
892

893 894 895
                for category in ordered_ui_categories():
                    if category == "sampler":
                        steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
A
AUTOMATIC 已提交
896

897 898 899 900 901
                    elif category == "dimensions":
                        with FormRow():
                            with gr.Column(elem_id="img2img_column_size", scale=4):
                                width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
                                height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
902

903 904 905 906
                            if opts.dimensions_and_batch_together:
                                with gr.Column(elem_id="img2img_column_batch"):
                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
907

908 909 910 911
                    elif category == "cfg":
                        with FormGroup():
                            cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
                            denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
912

913 914
                    elif category == "seed":
                        seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
915

916 917 918 919
                    elif category == "checkboxes":
                        with FormRow(elem_id="img2img_checkboxes"):
                            restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
                            tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
920

921 922 923 924 925
                    elif category == "batch":
                        if not opts.dimensions_and_batch_together:
                            with FormRow(elem_id="img2img_column_batch"):
                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
A
AUTOMATIC 已提交
926

927 928 929
                    elif category == "scripts":
                        with FormGroup(elem_id="img2img_script_container"):
                            custom_inputs = modules.scripts.scripts_img2img.setup_ui()
A
AUTOMATIC 已提交
930

931
            img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
932
            parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
933

934 935
            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
936

D
d8ahazard 已提交
937 938 939
            img2img_prompt_img.change(
                fn=modules.images.image_data,
                inputs=[
A
AUTOMATIC 已提交
940
                    img2img_prompt_img
D
d8ahazard 已提交
941 942 943 944 945 946 947
                ],
                outputs=[
                    img2img_prompt,
                    img2img_prompt_img
                ]
            )

948
            mask_mode.change(
949
                lambda mode, img: {
950
                    init_img_with_mask: gr_show(mode == 0),
951 952
                    init_img_inpaint: gr_show(mode == 1),
                    init_mask_inpaint: gr_show(mode == 1),
953
                },
954
                inputs=[mask_mode, init_img_with_mask],
955 956
                outputs=[
                    init_img_with_mask,
957 958
                    init_img_inpaint,
                    init_mask_inpaint,
959 960 961
                ],
            )

962
            img2img_args = dict(
963
                fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
964
                _js="submit_img2img",
965
                inputs=[
966
                    dummy_component,
A
AUTOMATIC 已提交
967
                    img2img_prompt,
968
                    img2img_negative_prompt,
A
AUTOMATIC 已提交
969
                    img2img_prompt_style,
A
AUTOMATIC 已提交
970
                    img2img_prompt_style2,
971 972
                    init_img,
                    init_img_with_mask,
973
                    init_img_with_mask_orig,
974 975
                    init_img_inpaint,
                    init_mask_inpaint,
976
                    mask_mode,
977 978 979
                    steps,
                    sampler_index,
                    mask_blur,
980
                    mask_alpha,
981
                    inpainting_fill,
A
AUTOMATIC 已提交
982
                    restore_faces,
983
                    tiling,
984 985 986 987 988
                    batch_count,
                    batch_size,
                    cfg_scale,
                    denoising_strength,
                    seed,
989
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
990 991 992 993
                    height,
                    width,
                    resize_mode,
                    inpaint_full_res,
994
                    inpaint_full_res_padding,
A
AUTOMATIC 已提交
995
                    inpainting_mask_invert,
996 997
                    img2img_batch_input_dir,
                    img2img_batch_output_dir,
A
AUTOMATIC 已提交
998
                ] + custom_inputs,
999 1000 1001
                outputs=[
                    img2img_gallery,
                    generation_info,
1002 1003
                    html_info,
                    html_log,
1004 1005
                ],
                show_progress=False,
1006 1007
            )

A
AUTOMATIC 已提交
1008
            img2img_prompt.submit(**img2img_args)
1009 1010
            submit.click(**img2img_args)

A
AUTOMATIC 已提交
1011 1012 1013 1014 1015 1016
            img2img_interrogate.click(
                fn=interrogate,
                inputs=[init_img],
                outputs=[img2img_prompt],
            )

A
AUTOMATIC 已提交
1017 1018 1019 1020
            img2img_deepbooru.click(
                fn=interrogate_deepbooru,
                inputs=[init_img],
                outputs=[img2img_prompt],
A
AUTOMATIC 已提交
1021 1022 1023 1024
            )

            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
            style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
L
Liam 已提交
1025
            style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
A
AUTOMATIC 已提交
1026 1027

            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
A
AUTOMATIC 已提交
1028 1029 1030
                button.click(
                    fn=add_style,
                    _js="ask_for_style_name",
1031 1032 1033
                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
                    inputs=[dummy_component, prompt, negative_prompt],
A
AUTOMATIC 已提交
1034 1035 1036
                    outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
                )

1037
            for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
A
AUTOMATIC 已提交
1038 1039
                button.click(
                    fn=apply_styles,
1040
                    _js=js_func,
A
AUTOMATIC 已提交
1041 1042
                    inputs=[prompt, negative_prompt, style1, style2],
                    outputs=[prompt, negative_prompt, style1, style2],
A
AUTOMATIC 已提交
1043 1044
                )

Y
yfszzx 已提交
1045 1046
            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])

1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062
            img2img_paste_fields = [
                (img2img_prompt, "Prompt"),
                (img2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
A
AUTOMATIC 已提交
1063
                (mask_blur, "Mask blur"),
1064
                *modules.scripts.scripts_img2img.infotext_fields
1065
            ]
1066
            parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
Y
yfszzx 已提交
1067
            parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
1068

1069
    modules.scripts.scripts_current = None
1070

1071 1072 1073
    with gr.Blocks(analytics_enabled=False) as extras_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
1074
                with gr.Tabs(elem_id="mode_extras"):
1075 1076
                    with gr.TabItem('Single Image', elem_id="extras_single_tab"):
                        extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
A
ArrowM 已提交
1077

1078 1079
                    with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
                        image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
A
AUTOMATIC 已提交
1080

1081 1082 1083 1084
                    with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
                        extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
                        extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
                        show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
1085

1086 1087
                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')

J
Justin Maier 已提交
1088
                with gr.Tabs(elem_id="extras_resize_mode"):
1089 1090 1091
                    with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
                        upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
                    with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
J
Justin Maier 已提交
1092 1093
                        with gr.Group():
                            with gr.Row():
1094 1095 1096
                                upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
                                upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
                            upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
1097

A
AUTOMATIC 已提交
1098
                with gr.Group():
1099
                    extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
A
AUTOMATIC 已提交
1100 1101

                with gr.Group():
M
Mykeehu 已提交
1102
                    extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
1103
                    extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
A
AUTOMATIC 已提交
1104 1105

                with gr.Group():
1106
                    gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
1107 1108

                with gr.Group():
1109 1110
                    codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
                    codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
1111

1112
                with gr.Group():
1113
                    upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
1114

1115
            result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
D
d8ahazard 已提交
1116

1117
        submit.click(
1118
            fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
1119
            _js="get_extras_tab_index",
1120
            inputs=[
J
Justin Maier 已提交
1121
                dummy_component,
1122
                dummy_component,
1123
                extras_image,
A
ArrowM 已提交
1124
                image_batch,
1125 1126 1127
                extras_batch_input_dir,
                extras_batch_output_dir,
                show_extras_results,
1128 1129 1130
                gfpgan_visibility,
                codeformer_visibility,
                codeformer_weight,
A
AUTOMATIC 已提交
1131
                upscaling_resize,
J
Justin Maier 已提交
1132 1133 1134
                upscaling_resize_w,
                upscaling_resize_h,
                upscaling_crop,
A
AUTOMATIC 已提交
1135 1136 1137
                extras_upscaler_1,
                extras_upscaler_2,
                extras_upscaler_2_visibility,
1138
                upscale_before_face_fix,
1139 1140
            ],
            outputs=[
A
ArrowM 已提交
1141
                result_images,
1142 1143 1144 1145
                html_info_x,
                html_info,
            ]
        )
1146
        parameters_copypaste.add_paste_fields("extras", extras_image, None)
J
Justin Maier 已提交
1147

C
Chris OBryan 已提交
1148 1149 1150
        extras_image.change(
            fn=modules.extras.clear_cache,
            inputs=[], outputs=[]
S
Seki 已提交
1151
        )
1152

1153 1154 1155 1156 1157 1158 1159
    with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")

            with gr.Column(variant='panel'):
                html = gr.HTML()
1160
                generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
1161 1162
                html2 = gr.HTML()
                with gr.Row():
Y
yfszzx 已提交
1163
                    buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
1164
                parameters_copypaste.bind_buttons(buttons, image, generation_info)
1165 1166

        image.change(
1167
            fn=wrap_gradio_call(modules.extras.run_pnginfo),
1168 1169 1170
            inputs=[image],
            outputs=[html, generation_info, html2],
        )
1171

1172
    # "Checkpoint Merger" tab: pick up to three checkpoints, choose how to
    # combine them, and write the result into the checkpoint directory.
    with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")

                # Model selectors, each paired with a button that rescans the
                # checkpoint list and refreshes the dropdown choices.
                with gr.Row():
                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
                    create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")

                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
                    create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")

                    # NOTE(review): presumably only relevant for "Add difference"; confirm in run_modelmerger.
                    tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
                    create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")

                custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
                interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")

                with gr.Row():
                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
                    save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")

                # A Button's caption is its `value`; passing `label=` leaves the
                # caption at gradio's default text, so the caption goes in `value`.
                modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')

            with gr.Column(variant='panel'):
                submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
1199

1200
    with gr.Blocks(analytics_enabled=False) as train_interface:
1201
        with gr.Row().style(equal_height=False):
A
AUTOMATIC 已提交
1202
            gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
1203

A
AUTOMATIC 已提交
1204 1205
        with gr.Row().style(equal_height=False):
            with gr.Tabs(elem_id="train_tabs"):
1206

A
AUTOMATIC 已提交
1207
                with gr.Tab(label="Create embedding"):
1208 1209 1210 1211
                    new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
                    initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
                    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
                    overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
1212 1213 1214 1215 1216 1217

                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
1218
                            create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
1219

A
AUTOMATIC 已提交
1220
                with gr.Tab(label="Create hypernetwork"):
1221 1222 1223 1224 1225 1226 1227 1228
                    new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
                    new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
                    new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
                    new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                    new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                    new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
                    overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
A
AUTOMATIC 已提交
1229 1230 1231 1232 1233 1234

                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
1235
                            create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
1236

A
AUTOMATIC 已提交
1237
                with gr.Tab(label="Preprocess images"):
1238 1239 1240 1241 1242
                    process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
                    process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
                    process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
                    process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
                    preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
1243 1244

                    with gr.Row():
1245 1246 1247 1248 1249
                        process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
                        process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
                        process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
                        process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
1250

1251
                    with gr.Row(visible=False) as process_split_extra_row:
1252 1253
                        process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
                        process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
1254

C
captin411 已提交
1255
                    with gr.Row(visible=False) as process_focal_crop_row:
1256 1257 1258 1259
                        process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
                        process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
                        process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
                        process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
C
captin411 已提交
1260

1261 1262 1263 1264 1265
                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
S
space-nuko 已提交
1266
                            with gr.Row():
1267 1268
                                interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
                            run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
1269

1270 1271 1272 1273 1274 1275
                    process_split.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_split],
                        outputs=[process_split_extra_row],
                    )

C
captin411 已提交
1276 1277 1278 1279 1280 1281
                    process_focal_crop.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_focal_crop],
                        outputs=[process_focal_crop_row],
                    )

A
AUTOMATIC 已提交
1282
                with gr.Tab(label="Train"):
D
DepFA 已提交
1283
                    gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
1284
                    with gr.Row():
A
AUTOMATIC 已提交
1285
                        train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
1286 1287
                        create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
                    with gr.Row():
A
AUTOMATIC 已提交
1288
                        train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
1289
                        create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
D
DepFA 已提交
1290
                    with gr.Row():
1291 1292
                        embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
                        hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
1293
                    
1294 1295
                    with gr.Row():
                        clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
1296
                        clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309

                    batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
                    gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
                    log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
                    template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
                    training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
                    training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
                    steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
                    create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
                    save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
                    preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
1310
                    with gr.Row():
1311 1312
                        shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
                        tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
1313
                    with gr.Row():
1314
                        latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
1315 1316

                    with gr.Row():
1317 1318 1319
                        interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
                        train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
                        train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
1320

1321 1322 1323 1324
                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)

                script_callbacks.ui_train_tabs_callback(params)

1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338
            with gr.Column():
                progressbar = gr.HTML(elem_id="ti_progressbar")
                ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)

                ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
                ti_preview = gr.Image(elem_id='ti_preview', visible=False)
                ti_progress = gr.HTML(elem_id="ti_progress", value="")
                ti_outcome = gr.HTML(elem_id="ti_error", value="")
                setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)

        create_embedding.click(
            fn=modules.textual_inversion.ui.create_embedding,
            inputs=[
                new_embedding_name,
1339
                initialization_text,
1340
                nvpt,
D
DepFA 已提交
1341
                overwrite_old_embedding,
1342 1343 1344 1345 1346 1347 1348 1349
            ],
            outputs=[
                train_embedding_name,
                ti_output,
                ti_outcome,
            ]
        )

A
AUTOMATIC 已提交
1350
        create_hypernetwork.click(
A
AUTOMATIC 已提交
1351
            fn=modules.hypernetworks.ui.create_hypernetwork,
A
AUTOMATIC 已提交
1352 1353
            inputs=[
                new_hypernetwork_name,
1354
                new_hypernetwork_sizes,
D
DepFA 已提交
1355
                overwrite_old_hypernetwork,
1356
                new_hypernetwork_layer_structure,
D
update  
discus0434 已提交
1357
                new_hypernetwork_activation_func,
1358
                new_hypernetwork_initialization_option,
1359
                new_hypernetwork_add_layer_norm,
D
discus0434 已提交
1360
                new_hypernetwork_use_dropout
A
AUTOMATIC 已提交
1361 1362 1363 1364 1365 1366 1367 1368
            ],
            outputs=[
                train_hypernetwork_name,
                ti_output,
                ti_outcome,
            ]
        )

1369 1370 1371 1372 1373 1374
        run_preprocess.click(
            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
            _js="start_training_textual_inversion",
            inputs=[
                process_src,
                process_dst,
A
alg-wiki 已提交
1375 1376
                process_width,
                process_height,
D
DepFA 已提交
1377
                preprocess_txt_action,
1378 1379 1380
                process_flip,
                process_split,
                process_caption,
1381 1382 1383
                process_caption_deepbooru,
                process_split_threshold,
                process_overlap_ratio,
C
captin411 已提交
1384 1385 1386 1387 1388
                process_focal_crop,
                process_focal_crop_face_weight,
                process_focal_crop_entropy_weight,
                process_focal_crop_edges_weight,
                process_focal_crop_debug,
1389 1390 1391 1392 1393 1394 1395
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ],
        )

1396 1397 1398 1399 1400
        train_embedding.click(
            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
            _js="start_training_textual_inversion",
            inputs=[
                train_embedding_name,
D
DepFA 已提交
1401
                embedding_learn_rate,
1402
                batch_size,
1403
                gradient_step,
1404 1405
                dataset_directory,
                log_directory,
A
alg-wiki 已提交
1406 1407
                training_width,
                training_height,
1408
                steps,
1409 1410
                clip_grad_mode,
                clip_grad_value,
1411 1412 1413
                shuffle_tags,
                tag_drop_out,
                latent_sampling_method,
1414 1415 1416
                create_image_every,
                save_embedding_every,
                template_file,
D
DepFA 已提交
1417
                save_image_with_stored_embedding,
1418 1419
                preview_from_txt2img,
                *txt2img_preview_params,
1420 1421 1422 1423 1424 1425 1426
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )

A
AUTOMATIC 已提交
1427
        train_hypernetwork.click(
A
AUTOMATIC 已提交
1428
            fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
A
AUTOMATIC 已提交
1429 1430 1431
            _js="start_training_textual_inversion",
            inputs=[
                train_hypernetwork_name,
D
DepFA 已提交
1432
                hypernetwork_learn_rate,
1433
                batch_size,
1434
                gradient_step,
A
AUTOMATIC 已提交
1435 1436
                dataset_directory,
                log_directory,
1437 1438
                training_width,
                training_height,
1439
                steps,
1440 1441
                clip_grad_mode,
                clip_grad_value,
1442 1443 1444
                shuffle_tags,
                tag_drop_out,
                latent_sampling_method,
1445 1446 1447
                create_image_every,
                save_embedding_every,
                template_file,
1448 1449
                preview_from_txt2img,
                *txt2img_preview_params,
1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )

        interrupt_training.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

S
space-nuko 已提交
1463 1464 1465 1466 1467 1468
        interrupt_preprocessing.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

1469
    def create_setting_component(key, is_quicksettings=False):
        """Create the gradio component used to edit the option ``key``.

        The component class comes from the option's ``component`` field when it
        is set. Otherwise it is derived from the type of the option's default
        value: str -> Textbox, int -> Number, bool -> Checkbox. Options that
        define a ``refresh`` callback also get a refresh button. On the regular
        settings page the component and its button share a ``FormRow``. In the
        quicksettings row they are created without that wrapper.

        Args:
            key: name of the option in ``opts.data_labels``.
            is_quicksettings: True when building the component for the
                quicksettings row.

        Returns:
            The created gradio component.

        Raises:
            Exception: when no component is configured and the default value's
                type is not str, int or bool.
        """
        option = opts.data_labels[key]
        default_type = type(option.default)

        def current_value():
            # The user's stored value if there is one, otherwise the default.
            if key in opts.data:
                return opts.data[key]
            return opts.data_labels[key].default

        # component_args may be given directly or as a factory callable.
        extra_args = option.component_args
        if callable(extra_args):
            extra_args = extra_args()

        component_cls = option.component
        if component_cls is None:
            component_cls = {str: gr.Textbox, int: gr.Number, bool: gr.Checkbox}.get(default_type)
            if component_cls is None:
                raise Exception(f'bad options item type: {str(default_type)} for key {key}')

        elem_id = "setting_" + key

        def build():
            return component_cls(label=option.label, value=current_value(), elem_id=elem_id, **(extra_args or {}))

        if option.refresh is None:
            return build()

        if is_quicksettings:
            component = build()
            create_refresh_button(component, option.refresh, option.component_args, "refresh_" + key)
            return component

        with FormRow():
            component = build()
            create_refresh_button(component, option.refresh, option.component_args, "refresh_" + key)
        return component
1503

A
AUTOMATIC 已提交
1504
    components = []
1505
    component_dict = {}
A
AUTOMATIC 已提交
1506

1507 1508 1509
    script_callbacks.ui_settings_callback()
    opts.reorder()

1510
    def run_settings(*args):
        """Validate and apply all values submitted from the Settings tab.

        ``args`` holds one value per entry of ``opts.data_labels``, in the same
        order as ``components``. Entries whose component is ``dummy_component``
        (quicksettings and skipped sections) are ignored.

        Every value is type-checked before any is applied, so one bad value
        leaves all settings untouched.

        Returns:
            A pair of ``opts.dumpjson()`` and a human-readable summary listing
            the keys that changed.

        Raises:
            AssertionError: when a submitted value does not have the same type
                as the option's default.
        """
        entries = list(zip(opts.data_labels.keys(), args, components))

        for key, value, comp in entries:
            if comp == dummy_component:
                continue
            default = opts.data_labels[key].default
            if not opts.same_type(value, default):
                # Explicit raise instead of `assert`: asserts are stripped under
                # `python -O`, which would silently skip validation. The
                # exception type stays AssertionError for existing callers.
                raise AssertionError(f"Bad value for setting {key}: {value}; expecting {type(default).__name__}")

        changed = []
        for key, value, comp in entries:
            if comp == dummy_component:
                continue

            if opts.set(key, value):
                changed.append(key)

        try:
            opts.save(shared.config_filename)
        except RuntimeError:
            # Saving failed; the in-memory changes are kept and reported as unsaved.
            return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
        return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
1528

1529 1530 1531 1532
    def run_settings_single(value, key):
        """Apply a single quicksetting change made in the UI.

        Args:
            value: the new value coming from the component.
            key: name of the option in ``opts.data_labels``.

        Returns:
            A pair ``(update, settings_json)``. On success the value is saved
            and echoed back. If ``opts.set`` reports nothing was set, the
            component is reset to the option's current value. A value of the
            wrong type only yields ``gr.update(visible=True)``.
        """
        if opts.same_type(value, opts.data_labels[key].default):
            if opts.set(key, value):
                opts.save(shared.config_filename)
                component_update = gr.update(value=value)
            else:
                component_update = gr.update(value=getattr(opts, key))
        else:
            component_update = gr.update(visible=True)

        return component_update, opts.dumpjson()

A
AUTOMATIC 已提交
1540
    with gr.Blocks(analytics_enabled=False) as settings_interface:
1541
        with gr.Row():
A
AUTOMATIC 已提交
1542 1543 1544 1545
            with gr.Column(scale=6):
                settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
            with gr.Column():
                restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
A
AUTOMATIC 已提交
1546

1547
        result = gr.HTML(elem_id="settings_result")
A
AUTOMATIC 已提交
1548

1549
        quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
1550
        quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
1551

1552 1553
        quicksettings_list = []

1554
        previous_section = None
1555 1556
        current_tab = None
        with gr.Tabs(elem_id="settings"):
1557
            for i, (k, item) in enumerate(opts.data_labels.items()):
1558
                section_must_be_skipped = item.section[0] is None
D
DepFA 已提交
1559

1560
                if previous_section != item.section and not section_must_be_skipped:
1561
                    elem_id, text = item.section
D
DepFA 已提交
1562

1563 1564
                    if current_tab is not None:
                        current_tab.__exit__()
A
AUTOMATIC 已提交
1565

1566 1567
                    current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
                    current_tab.__enter__()
1568 1569 1570

                    previous_section = item.section

1571
                if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
1572 1573
                    quicksettings_list.append((i, k, item))
                    components.append(dummy_component)
1574 1575
                elif section_must_be_skipped:
                    components.append(dummy_component)
1576 1577 1578 1579
                else:
                    component = create_setting_component(k)
                    component_dict[k] = component
                    components.append(component)
1580

1581 1582
            if current_tab is not None:
                current_tab.__exit__()
A
AUTOMATIC 已提交
1583

1584 1585 1586 1587
            with gr.TabItem("Actions"):
                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
1588

A
AUTOMATIC 已提交
1589 1590 1591 1592 1593
            if os.path.exists("html/licenses.html"):
                with open("html/licenses.html", encoding="utf8") as file:
                    with gr.TabItem("Licenses"):
                        gr.HTML(file.read(), elem_id="licenses")

1594
            gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
1595

1596 1597 1598 1599
        request_notifications.click(
            fn=lambda: None,
            inputs=[],
            outputs=[],
1600
            _js='function(){}'
1601 1602
        )

A
AUTOMATIC 已提交
1603 1604 1605 1606 1607 1608 1609
        download_localization.click(
            fn=lambda: None,
            inputs=[],
            outputs=[],
            _js='download_localization'
        )

D
DepFA 已提交
1610
        def reload_scripts():
            """Reload custom script bodies (no UI rebuild), then the page javascript."""
            scripts_module = modules.scripts
            scripts_module.reload_script_body_only()
            # The html page has to be refreshed as well.
            reload_javascript()
D
DepFA 已提交
1613 1614 1615 1616

        reload_script_bodies.click(
            fn=reload_scripts,
            inputs=[],
A
AUTOMATIC 已提交
1617
            outputs=[]
D
DepFA 已提交
1618
        )
1619 1620

        def request_restart():
            """Interrupt the current job and set the state's ``need_restart`` flag."""
            state = shared.state
            state.interrupt()
            state.need_restart = True
1623 1624 1625

        restart_gradio.click(
            fn=request_restart,
1626
            _js='restart_reload',
1627 1628 1629
            inputs=[],
            outputs=[],
        )
J
Justin Maier 已提交
1630

1631
    interfaces = [
A
AUTOMATIC 已提交
1632 1633 1634 1635
        (txt2img_interface, "txt2img", "txt2img"),
        (img2img_interface, "img2img", "img2img"),
        (extras_interface, "Extras", "extras"),
        (pnginfo_interface, "PNG Info", "pnginfo"),
1636
        (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
A
AUTOMATIC 已提交
1637
        (train_interface, "Train", "ti"),
1638 1639
    ]

A
AUTOMATIC 已提交
1640 1641 1642
    css = ""

    for cssfile in modules.scripts.list_files_with_name("style.css"):
A
AUTOMATIC 已提交
1643 1644 1645
        if not os.path.isfile(cssfile):
            continue

A
AUTOMATIC 已提交
1646 1647
        with open(cssfile, "r", encoding="utf8") as file:
            css += file.read() + "\n"
1648

A
typo  
AUTOMATIC 已提交
1649
    if os.path.exists(os.path.join(script_path, "user.css")):
A
AUTOMATIC 已提交
1650
        with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
A
AUTOMATIC 已提交
1651
            css += file.read() + "\n"
A
AUTOMATIC 已提交
1652

1653 1654 1655
    if not cmd_opts.no_progressbar_hiding:
        css += css_hide_progressbar

1656 1657 1658
    interfaces += script_callbacks.ui_tabs_callback()
    interfaces += [(settings_interface, "Settings", "settings")]

1659 1660 1661
    extensions_interface = ui_extensions.create_ui()
    interfaces += [(extensions_interface, "Extensions", "extensions")]

A
AUTOMATIC 已提交
1662
    with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
1663
        with gr.Row(elem_id="quicksettings"):
1664
            for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
1665
                component = create_setting_component(k, is_quicksettings=True)
1666 1667
                component_dict[k] = component

1668 1669 1670
        parameters_copypaste.integrate_settings_paste_fields(component_dict)
        parameters_copypaste.run_bind()

1671
        with gr.Tabs(elem_id="tabs") as tabs:
A
AUTOMATIC 已提交
1672
            for interface, label, ifid in interfaces:
1673
                with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
A
AUTOMATIC 已提交
1674
                    interface.render()
J
Justin Maier 已提交
1675

1676 1677
        if os.path.exists(os.path.join(script_path, "notification.mp3")):
            audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
A
AUTOMATIC 已提交
1678

A
AUTOMATIC 已提交
1679 1680 1681 1682
        if os.path.exists("html/footer.html"):
            with open("html/footer.html", encoding="utf8") as file:
                gr.HTML(file.read(), elem_id="footer")

1683
        text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
1684
        settings_submit.click(
1685
            fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
1686
            inputs=components,
1687
            outputs=[text_settings, result],
1688
        )
1689 1690 1691 1692 1693 1694 1695 1696 1697 1698

        for i, k, item in quicksettings_list:
            component = component_dict[k]

            component.change(
                fn=lambda value, k=k: run_settings_single(value, key=k),
                inputs=[component],
                outputs=[component, text_settings],
            )

1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709
        component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

        def get_settings_values():
            """Read the current value of each option in ``component_keys``, in order."""
            return [getattr(opts, setting_name) for setting_name in component_keys]

        demo.load(
            fn=get_settings_values,
            inputs=[],
            outputs=[component_dict[k] for k in component_keys],
        )

S
safentisAuth 已提交
1710 1711
        def modelmerger(*args):
            """Run the checkpoint merger with the merge-tab inputs.

            On success, whatever ``modules.extras.run_modelmerger`` returns is passed
            straight through to the bound outputs. On any exception, the traceback is
            printed to stderr and the checkpoint list is re-scanned. The function then
            returns an error message followed by four dropdown updates carrying the
            refreshed checkpoint choices. This matches the status box plus the four
            checkpoint dropdowns wired as outputs of ``modelmerger_merge.click``.
            """
            try:
                return modules.extras.run_modelmerger(*args)
            except Exception as e:
                print("Error loading/saving model file:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                modules.sd_models.list_models()  # to remove the potentially missing models from the list
                message = f"Error merging checkpoints: {e}"
                refreshed_dropdowns = [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
                return [message] + refreshed_dropdowns
1719

1720
        modelmerger_merge.click(
S
safentisAuth 已提交
1721
            fn=modelmerger,
1722 1723 1724
            inputs=[
                primary_model_name,
                secondary_model_name,
1725
                tertiary_model_name,
1726 1727 1728
                interp_method,
                interp_amount,
                save_as_half,
S
safentisAuth 已提交
1729
                custom_name,
1730
                checkpoint_format,
1731 1732 1733 1734 1735
            ],
            outputs=[
                submit_result,
                primary_model_name,
                secondary_model_name,
1736
                tertiary_model_name,
1737 1738 1739
                component_dict['sd_model_checkpoint'],
            ]
        )
1740

1741
    ui_config_file = cmd_opts.ui_config_file
A
AUTOMATIC 已提交
1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755
    ui_settings = {}
    settings_count = len(ui_settings)
    error_loading = False

    try:
        if os.path.exists(ui_config_file):
            with open(ui_config_file, "r", encoding="utf8") as file:
                ui_settings = json.load(file)
    except Exception:
        error_loading = True
        print("Error loading settings:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    def loadsave(path, x):
        """Synchronise the persisted fields of one UI component with ``ui_settings``.

        Each relevant field of ``x`` is keyed as ``"<path>/<field>"``. If the
        component has a ``custom_script_source``, the key is prefixed with
        ``customscript/<source>/``. For every field:

        - If a value is already stored under that key, it is applied to the
          component, unless a validity ``condition`` rejects it.
        - Otherwise the component's current value is recorded in ``ui_settings``.

        Components with a truthy ``do_not_save_to_config`` attribute are skipped.
        Component types are matched exactly with ``type(x)``, so subclasses are
        not matched.
        """

        def apply_field(obj, field, condition=None, init_field=None):
            setting_key = path + "/" + field

            script_source = getattr(obj, 'custom_script_source', None)
            if script_source is not None:
                setting_key = 'customscript/' + script_source + '/' + setting_key

            # Components that opt out of persistence are neither read nor recorded.
            if getattr(obj, 'do_not_save_to_config', False):
                return

            stored = ui_settings.get(setting_key, None)

            if stored is None:
                # Nothing persisted yet: remember the component's current value.
                ui_settings[setting_key] = getattr(obj, field)
                return

            if condition and not condition(stored):
                # Stored value is invalid (e.g. no longer among the choices): keep the default.
                print(f'Warning: Bad ui setting value: {setting_key}: {stored}; Default value "{getattr(obj, field)}" will be used instead.')
                return

            setattr(obj, field, stored)
            if init_field is not None:
                init_field(stored)

        kind = type(x)

        if kind in (gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number) and x.visible:
            apply_field(x, 'visible')

        if kind == gr.Slider:
            for field in ('value', 'minimum', 'maximum', 'step'):
                apply_field(x, field)
        elif kind == gr.Radio:
            apply_field(x, 'value', lambda val: val in x.choices)
        elif kind in (gr.Checkbox, gr.Textbox, gr.Number):
            apply_field(x, 'value')
        elif kind == gr.Dropdown and getattr(x, 'save_to_config', False):
            # Many dropdowns shouldn't be persisted, so only those explicitly
            # marked with save_to_config are saved.
            apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
            apply_field(x, 'visible')
1801

A
AUTOMATIC 已提交
1802 1803
    visit(txt2img_interface, loadsave, "txt2img")
    visit(img2img_interface, loadsave, "img2img")
1804
    visit(extras_interface, loadsave, "extras")
1805
    visit(modelmerger_interface, loadsave, "modelmerger")
A
AUTOMATIC 已提交
1806 1807 1808 1809 1810

    if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
        with open(ui_config_file, "w", encoding="utf8") as file:
            json.dump(ui_settings, file, indent=4)

1811 1812 1813
    return demo


1814
def reload_javascript():
    """Assemble the UI's <script> tags and patch gradio to inject them into every page.

    The injected payload is built from, in order:

    1. ``script.js`` from ``script_path``.
    2. Every ``.js`` file returned by ``modules.scripts.list_scripts("javascript", ".js")``,
       each preceded by an HTML comment with its file name.
    3. A ``set_theme(...)`` call when ``--theme`` was given.
    4. The localization script for the configured locale.

    ``gradio.routes.templates.TemplateResponse`` is then replaced by a wrapper
    that calls the saved original (``shared.GradioTemplateResponseOriginal``)
    and inserts the payload right before ``</head>``.
    """
    with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as main_js:
        parts = [f'<script>{main_js.read()}</script>']

    for _basedir, filename, js_path in modules.scripts.list_scripts("javascript", ".js"):
        with open(js_path, "r", encoding="utf8") as extra_js:
            parts.append(f"\n<!-- {filename} --><script>{extra_js.read()}</script>")

    if cmd_opts.theme is not None:
        parts.append(f"\n<script>set_theme('{cmd_opts.theme}');</script>\n")

    parts.append(f"\n<script>{localization.localization_js(shared.opts.localization)}</script>")

    javascript = "".join(parts)

    def template_response(*args, **kwargs):
        response = shared.GradioTemplateResponseOriginal(*args, **kwargs)
        injected = f'{javascript}</head>'.encode("utf8")
        response.body = response.body.replace(b'</head>', injected)
        # The body length changed, so the response headers must be recomputed.
        response.init_headers()
        return response

    gradio.routes.templates.TemplateResponse = template_response
Y
yfszzx 已提交
1837

1838

1839 1840
# Capture gradio's unpatched TemplateResponse only the first time this runs.
# reload_javascript() replaces gradio.routes.templates.TemplateResponse with a
# wrapper that calls this saved reference, so it must never point at an
# already-patched wrapper.
try:
    shared.GradioTemplateResponseOriginal
except AttributeError:
    shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse