ui.py 85.8 KB
Newer Older
1 2 3 4
import base64
import html
import io
import json
A
AUTOMATIC 已提交
5
import math
6 7
import mimetypes
import os
A
AUTOMATIC 已提交
8
import random
9
import sys
10
import tempfile
11 12
import time
import traceback
M
Michoko 已提交
13 14
import platform
import subprocess as sp
15
from functools import partial, reduce
16

A
AUTOMATIC 已提交
17 18
import numpy as np
import torch
19
from PIL import Image, PngImagePlugin
20
import piexif
21 22 23

import gradio as gr
import gradio.utils
A
AUTOMATIC 已提交
24
import gradio.routes
25

A
AUTOMATIC 已提交
26
from modules import sd_hijack, sd_models, localization
27
from modules.paths import script_path
28 29

from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings
M
init  
MalumaDev 已提交
30

31 32
if cmd_opts.deepdanbooru:
    from modules.deepbooru import get_deepbooru_tags
33 34
import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img
35
from modules.sd_hijack import model_hijack
36
import modules.ldsr_model
A
AUTOMATIC 已提交
37
import modules.scripts
38 39
import modules.gfpgan_model
import modules.codeformer_model
A
AUTOMATIC 已提交
40
import modules.styles
41
import modules.generation_parameters_copypaste
A
AUTOMATIC 已提交
42
from modules import prompt_parser
M
Milly 已提交
43
from modules.images import save_image
44
import modules.textual_inversion.ui
45
import modules.hypernetworks.ui
46

M
MalumaDev 已提交
47
import modules.aesthetic_clip as aesthetic_clip
Y
yfszzx 已提交
48
import modules.images_history as img_his
49 50


A
Aidan Holland 已提交
51
# Windows fix: without an explicit mapping, .js files may be served with a text/html
# content-type, in which case the browser does not run the UI scripts and shows nothing.
mimetypes.init()
mimetypes.add_type(type='application/javascript', ext='.js')


56
if not (cmd_opts.share or cmd_opts.listen):
    # Only served locally: stub out gradio's version check and local-IP lookup so
    # nothing is fetched over the network (both presumably contact external
    # services in gradio -- confirm against the installed gradio version).
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

J
JamnedZ 已提交
61 62 63
if cmd_opts.ngrok is not None:
    # An ngrok authtoken was supplied: open a tunnel to the UI port.
    import modules.ngrok as ngrok
    print('ngrok authtoken detected, trying to connect...')
    ngrok.connect(
        cmd_opts.ngrok,
        # 7860 is used when no --port is given (presumably gradio's default port).
        cmd_opts.port if cmd_opts.port is not None else 7860,
        cmd_opts.ngrok_region,
    )
J
JamnedZ 已提交
65

66 67 68 69 70 71 72 73 74 75

def gr_show(visible=True):
    """Return a gradio update dict that only changes a component's visibility."""
    return dict(visible=visible, __type__="update")


# Bundled sample image path (presumably offered as an img2img example);
# None when the asset is not present on disk.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
if not os.path.exists(sample_img2img):
    sample_img2img = None

# CSS that hides the spinner SVG inside `.wrap .m-12` (showing a "Loading..." label in
# its place) and hides the `.progress-bar` and `.meta-text` elements entirely.
# The stray numeric lines that had crept into this literal were invalid CSS and are removed.
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""

81 82 83 84
# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
# Values ending in \ufe0f carry the emoji variation selector; the others do not.
random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
reuse_symbol = '\u267b\ufe0f'  # ♻️
art_symbol = '\U0001f3a8'  # 🎨
paste_symbol = '\u2199\ufe0f'  # ↙️
folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
apply_style_symbol = '\U0001f4cb'  # 📋
clear_prompt_symbol = '\U0001F5D1'  # 🗑 (no variation selector)
92

93

94
def plaintext_to_html(text):
    """Render plain text as one HTML paragraph.

    Each line is HTML-escaped, and lines are joined with ``<br>`` followed by a newline.
    """
    escaped = (html.escape(line) for line in text.split('\n'))
    return "<p>" + "<br>\n".join(escaped) + "</p>"


def _is_in_directory(path, directory):
    """Return True if *path* (made absolute) lies inside *directory*."""
    directory = os.path.normcase(os.path.abspath(directory))
    path = os.path.normcase(os.path.abspath(path))
    try:
        # commonpath compares whole path components, so "/tmpevil/x" is not treated
        # as being inside "/tmp" the way a plain startswith() check would.
        return os.path.commonpath([directory, path]) == directory
    except ValueError:
        # e.g. paths on different drives on Windows
        return False


def image_from_url_text(filedata):
    """Decode an image sent by the browser.

    ``filedata`` may be:
      * a list -- only its first element is used; an empty list gives ``None``;
      * a dict ``{"is_file": True, "name": path}`` naming a temporary file;
      * a base64 string, optionally prefixed by a ``data:<mime>;base64,`` header.

    Raises:
        ValueError: if a file reference points outside the system temp directory.
    """
    # Unwrap lists first so a list whose first entry is a file dict is handled by
    # the file branch below (it previously fell through to dict.startswith()).
    if isinstance(filedata, list):
        if len(filedata) == 0:
            return None
        filedata = filedata[0]

    if isinstance(filedata, dict) and filedata.get("is_file"):
        filename = filedata["name"]
        # Only files inside the temp directory may be opened. This is an explicit
        # check rather than `assert`, which is stripped when running with -O.
        if not _is_in_directory(filename, tempfile.gettempdir()):
            raise ValueError('trying to open image file not in temporary directory')
        return Image.open(filename)

    # Drop a data-URL header such as "data:image/png;base64," for any image type;
    # base64 payloads never contain a comma.
    header, sep, payload = filedata.partition(",")
    if sep and header.startswith("data:") and header.endswith(";base64"):
        filedata = payload

    decoded = base64.decodebytes(filedata.encode('utf-8'))
    return Image.open(io.BytesIO(decoded))


def send_gradio_gallery_to_image(x):
    """Decode the first gallery entry as an image; return None for an empty gallery."""
    return None if len(x) == 0 else image_from_url_text(x[0])

J
jtkelm2 已提交
128

A
aoirusann 已提交
129
def save_files(js_data, images, do_make_zip, index):
    """Save gallery images to ``opts.outdir_save`` and append a row to its ``log.csv``.

    Args:
        js_data: JSON string of generation info. The keys read here are prompt, seed,
            width, height, sampler, cfg_scale, steps, negative_prompt,
            index_of_first_image, all_seeds, all_prompts and infotexts.
        images: gallery entries accepted by ``image_from_url_text``.
        do_make_zip: if true, also bundle every written file into ``images.zip``.
        index: selected gallery index, or -1 when nothing is selected.

    Returns:
        ``(file list update, '', '', status HTML)``. When a zip is made it is listed
        first. With nothing to save, an empty, hidden file list is returned.
    """
    import csv
    import types
    from zipfile import ZipFile

    data = json.loads(js_data)
    # save_image()/apply_filename_pattern need attribute access (p.seed, ...),
    # so expose the parsed dict as a namespace object.
    p = types.SimpleNamespace(**data)

    path = opts.outdir_save
    save_to_dirs = opts.use_save_to_dirs_for_ui
    extension: str = opts.samples_format
    start_index = 0

    # With save_selected_only, keep just the selected picture -- but only when the
    # selection is an individual image, not one of the leading grid entries.
    if index > -1 and opts.save_selected_only and index >= data["index_of_first_image"]:
        images = [images[index]]
        start_index = index

    os.makedirs(path, exist_ok=True)

    filenames = []  # paths relative to `path`; also used as the names inside the zip
    fullfns = []    # full paths of every written file, parallel to `filenames`
    with open(os.path.join(path, "log.csv"), "a", encoding="utf8", newline='') as file:
        at_start = file.tell() == 0
        writer = csv.writer(file)
        if at_start:
            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])

        for image_index, filedata in enumerate(images, start_index):
            image = image_from_url_text(filedata)

            # Entries before index_of_first_image are grids and use the first seed/prompt.
            is_grid = image_index < p.index_of_first_image
            i = 0 if is_grid else (image_index - p.index_of_first_image)

            fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)

            filenames.append(os.path.relpath(fullfn, path))
            fullfns.append(fullfn)
            if txt_fullfn:
                # Same relative form as the image so a sub-directory is kept in the
                # zip (basename() used to flatten it).
                filenames.append(os.path.relpath(txt_fullfn, path))
                fullfns.append(txt_fullfn)

        # Nothing written means there is no filename to log (filenames[0] would raise).
        if filenames:
            writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

    if not filenames:
        return gr.File.update(value=None, visible=False), '', '', plaintext_to_html("Nothing to save")

    if do_make_zip:
        zip_filepath = os.path.join(path, "images.zip")
        with ZipFile(zip_filepath, "w") as zip_file:
            for fullfn, arcname in zip(fullfns, filenames):
                zip_file.write(fullfn, arcname=arcname)
        fullfns.insert(0, zip_filepath)

    return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
191 192


193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
def save_pil_to_file(pil_image, dir=None):
    """Write ``pil_image`` to a new ``.png`` temp file, keeping its text metadata.

    Every entry of ``pil_image.info`` whose key and value are both strings is written
    as a PNG text chunk. When there are none, no ``pnginfo`` is passed.

    Args:
        pil_image: the image to save.
        dir: directory for the temp file; ``None`` uses the system default.

    Returns:
        The ``NamedTemporaryFile`` object, already closed. The file itself is kept
        (``delete=False``); presumably gradio only reads its ``.name`` -- confirm.
    """
    text_items = [
        (key, value)
        for key, value in pil_image.info.items()
        if isinstance(key, str) and isinstance(value, str)
    ]

    metadata = None
    if text_items:
        metadata = PngImagePlugin.PngInfo()
        for key, value in text_items:
            metadata.add_text(key, value)

    # Close the handle once the image is written (it used to be returned open and
    # leaked); delete=False keeps the file on disk after closing.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) as file_obj:
        pil_image.save(file_obj, pnginfo=metadata)
    return file_obj


# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file


210 211
def wrap_gradio_call(func, extra_outputs=None):
    """Wrap a gradio event handler with error reporting and timing information.

    The returned callable runs ``func`` and returns its outputs as a tuple. A
    "Time taken" footer, plus VRAM peaks when the memory monitor is active, is appended
    to the last output, which must be a string.

    If ``func`` raises, the wrapper does the following:
      * prints the traceback and a truncated dump of the arguments to stderr;
      * clears ``shared.state.job`` and ``job_count``;
      * uses ``extra_outputs`` (``[None, '']`` by default) followed by an HTML error
        message as the outputs.

    In every case ``skipped``, ``interrupted`` and ``job_count`` of ``shared.state``
    are reset before returning.
    """
    # At most this many characters of the argument dump are printed on error:
    # 131072 == (1024*1024)/8, i.e. 128 Ki characters -- not a full megabyte.
    debug_args_limit = 131072

    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        monitor_memory = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
        if monitor_memory:
            shared.mem_mon.monitor()
        started_at = time.perf_counter()

        try:
            outputs = list(func(*args, **kwargs))
        except Exception as exc:
            print("Error completing request", file=sys.stderr)
            arguments_text = f"Arguments: {str(args)} {str(kwargs)}"
            print(arguments_text[:debug_args_limit], file=sys.stderr)
            if len(arguments_text) > debug_args_limit:
                print(f"(Argument list truncated at {debug_args_limit}/{len(arguments_text)} characters)", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

            shared.state.job = ""
            shared.state.job_count = 0

            fallback = [None, ''] if extra_outputs_array is None else extra_outputs_array
            message = plaintext_to_html(type(exc).__name__ + ': ' + str(exc))
            outputs = fallback + [f"<div class='error'>{message}</div>"]

        minutes, seconds = divmod(time.perf_counter() - started_at, 60)
        elapsed_text = f"{seconds:.2f}s"
        if minutes > 0:
            elapsed_text = f"{int(minutes)}m " + elapsed_text

        vram_html = ''
        if monitor_memory:
            # Ceil-divide every counter by 1024*1024 (presumably bytes -> MiB).
            mib = {name: -(amount // -(1024 * 1024)) for name, amount in shared.mem_mon.stop().items()}
            sys_pct = round(mib['system_peak'] / max(mib['total'], 1) * 100, 2)
            vram_html = f"<p class='vram'>Torch active/reserved: {mib['active_peak']}/{mib['reserved_peak']} MiB, <wbr>Sys VRAM: {mib['system_peak']}/{mib['total']} MiB ({sys_pct}%)</p>"

        # The last output is the HTML info area, so the footer is appended there.
        outputs[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        return tuple(outputs)

    return f


270
def calc_time_left(progress, threshold, label, force_display, time_start=None):
    """Return ``label`` followed by the estimated remaining time, or ``""``.

    The ETA extrapolates the elapsed time since ``time_start`` linearly over
    ``progress`` (a 0..1 fraction). It is shown when it exceeds ``threshold`` seconds
    and progress is above 2%, or whenever ``force_display`` is true.

    Format:
      * ``HH:MM:SS`` above an hour (hours do not wrap at 24);
      * ``MM:SS`` above a minute;
      * ``SSs`` otherwise.

    Args:
        progress: completed fraction of the job.
        threshold: minimum ETA in seconds before it is shown unforced.
        label: text prepended to the formatted time.
        force_display: show the ETA regardless of ``threshold``.
        time_start: start timestamp (``time.time()`` based); defaults to
            ``shared.state.time_start``.
    """
    if progress == 0:
        return ""

    if time_start is None:
        time_start = shared.state.time_start
    time_since_start = time.time() - time_start
    eta = time_since_start / progress
    # Progress above 1 would give a negative ETA, which time.gmtime() rejected on some
    # platforms and rendered as garbage on others; clamp it to zero.
    eta_relative = max(eta - time_since_start, 0)

    if not ((eta_relative > threshold and progress > 0.02) or force_display):
        return ""

    # Whole seconds, truncated like the previous gmtime()-based formatting. Computed
    # by hand so the hour count does not wrap after 24h as strftime('%H') did.
    minutes, seconds = divmod(int(eta_relative), 60)
    hours, minutes = divmod(minutes, 60)
    if eta_relative > 3600:
        return f"{label}{hours:02d}:{minutes:02d}:{seconds:02d}"
    if eta_relative > 60:
        return f"{label}{minutes:02d}:{seconds:02d}"
    return f"{label}{seconds:02d}s"
A
Anastasius 已提交
286 287


288
def check_progress_call(id_part):
    """Build the progress bar HTML, preview image and status text for the running job.

    Returns a 4-tuple:
      * progress HTML (a hidden timestamp span plus the bar);
      * preview visibility update;
      * preview image or update;
      * textinfo update.

    With ``job_count == 0`` (no job) everything is hidden.
    """
    state = shared.state
    if state.job_count == 0:
        return "", gr_show(False), gr_show(False), gr_show(False)

    progress = 0
    # job_count is -1 right after check_progress_call_initial; only compute a fraction
    # once it is positive (the step term used to divide by -1 and go negative).
    if state.job_count > 0:
        progress += state.job_no / state.job_count
        if state.sampling_steps > 0:
            progress += 1 / state.job_count * state.sampling_step / state.sampling_steps

    # Clamp before estimating: out-of-range progress gives negative ETAs and bar widths.
    progress = min(max(progress, 0), 1)

    time_left = calc_time_left(progress, 1, " ETA: ", state.time_left_force_display)
    # Once an ETA has been shown, keep showing it for the rest of the job.
    if time_left != "":
        state.time_left_force_display = True

    progressbar = ""
    if opts.show_progressbar:
        bar_label = "&nbsp;" * 2 + str(int(progress * 100)) + "%" + time_left if progress > 0.01 else ""
        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{bar_label}</div></div>"""

    image = gr_show(False)
    preview_visibility = gr_show(False)

    if opts.show_progress_every_n_steps > 0:
        if shared.parallel_processing_allowed:
            # Refresh the preview only every N sampling steps and only when a latent exists.
            if state.sampling_step - state.current_image_sampling_step >= opts.show_progress_every_n_steps and state.current_latent is not None:
                state.current_image = modules.sd_samplers.sample_to_image(state.current_latent)
                state.current_image_sampling_step = state.sampling_step

        image = state.current_image

        if image is None:
            image = gr.update(value=None)
        else:
            preview_visibility = gr_show(True)

    if state.textinfo is not None:
        textinfo_result = gr.HTML.update(value=state.textinfo, visible=True)
    else:
        textinfo_result = gr_show(False)

    return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
A
AUTOMATIC 已提交
332 333


334
def check_progress_call_initial(id_part):
    """Reset the shared progress state for a new job, then report its first status.

    Steps:
      * Set ``job_count`` to -1. It is non-zero, so ``check_progress_call`` does not
        treat the job as finished.
      * Clear the latent, preview image and text info.
      * Restart the ETA clock and hide the ETA until it is first computed.
    """
    state = shared.state
    state.job_count = -1
    state.current_latent = None
    state.current_image = None
    state.textinfo = None
    state.time_start = time.time()
    state.time_left_force_display = False

    return check_progress_call(id_part)
343 344


A
AUTOMATIC 已提交
345 346 347 348 349 350 351
def roll_artist(prompt):
    """Append a randomly chosen artist name to ``prompt`` (comma-separated).

    Candidates are limited to ``opts.random_artist_categories``, or all categories
    when that list is empty. If no artist qualifies, ``prompt`` is returned unchanged;
    previously ``random.choice`` raised IndexError.
    """
    wanted = opts.random_artist_categories  # read once instead of once per category
    allowed_cats = {c for c in shared.artist_db.categories() if len(wanted) == 0 or c in wanted}
    candidates = [a for a in shared.artist_db.artists if a.category in allowed_cats]
    if not candidates:
        return prompt

    artist = random.choice(candidates)
    return prompt + ", " + artist.name if prompt != '' else artist.name


A
AUTOMATIC 已提交
352 353 354 355 356 357 358
def visit(x, func, path=""):
    """Call ``func(path + "/" + str(label), leaf)`` for every labelled leaf under ``x``.

    Anything with a ``children`` attribute is treated as a container and walked
    recursively. A container's own label is ignored and does not extend ``path``.
    Leaves whose label is None, or that have no ``label`` attribute at all, are skipped.
    """
    if hasattr(x, 'children'):
        for c in x.children:
            visit(c, func, path)
    else:
        # getattr(): a leaf without a `label` attribute is skipped instead of
        # raising AttributeError.
        label = getattr(x, 'label', None)
        if label is not None:
            func(path + "/" + str(label), x)

359

360 361
def add_style(name: str, prompt: str, negative_prompt: str):
    """Create (or overwrite) a prompt style, save all styles, and refresh the dropdowns.

    Returns four gradio updates, one per style dropdown. When ``name`` is None, four
    plain "visible" updates are returned and nothing is stored.
    """
    if name is None:
        return [gr_show() for _ in range(4)]

    new_style = modules.styles.PromptStyle(name, prompt, negative_prompt)
    shared.prompt_styles.styles[new_style.name] = new_style

    # Re-serialize every loaded style, not just the new one, so a future change of the
    # storage format only needs to happen in one place.
    shared.prompt_styles.save_styles(shared.styles_filename)

    updates = []
    for _ in range(4):
        updates.append(gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)))
    return updates
A
AUTOMATIC 已提交
371 372 373 374 375 376 377


def apply_styles(prompt, prompt_neg, style1_name, style2_name):
    """Merge the two selected styles into the prompt boxes and reset both selectors.

    Returns updates for (prompt, negative prompt, style 1 dropdown, style 2 dropdown).
    Both dropdowns are set back to "None".
    """
    styles = shared.prompt_styles
    merged_prompt = styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
    merged_negative = styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])

    text_updates = [gr.Textbox.update(value=merged_prompt), gr.Textbox.update(value=merged_negative)]
    dropdown_resets = [gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
    return text_updates + dropdown_resets
A
AUTOMATIC 已提交
378 379


A
AUTOMATIC 已提交
380 381 382 383 384
def interrogate(image):
    """Caption ``image`` with the CLIP interrogator.

    Returns the generated prompt, or a "visible" update when the interrogator returns None.
    """
    caption = shared.interrogator.interrogate(image)
    if caption is None:
        return gr_show(True)
    return caption

A
AUTOMATIC 已提交
385

G
Greendayle 已提交
386 387 388 389 390
def interrogate_deepbooru(image):
    """Tag ``image`` with DeepBooru.

    Returns the tag prompt, or a "visible" update when no tags are returned (None).
    """
    tags = get_deepbooru_tags(image)
    if tags is None:
        return gr_show(True)
    return tags


391 392 393 394
def create_seed_inputs():
    """Build the seed controls, including the hidden "Extra" variation-seed rows.

    Returns ``(seed, reuse_seed, subseed, reuse_subseed, subseed_strength,
    seed_resize_from_h, seed_resize_from_w, seed_checkbox)``. Note that height comes
    before width.
    """
    def seed_field_row(field_factory, label, row_id, random_id, reuse_id):
        # Box > Row holding a seed field followed by its random and reuse buttons.
        with gr.Box():
            with gr.Row(elem_id=row_id):
                field = field_factory(label=label, value=-1)
                field.style(container=False)
                random_button = gr.Button(random_symbol, elem_id=random_id)
                reuse_button = gr.Button(reuse_symbol, elem_id=reuse_id)
        return field, random_button, reuse_button

    seed_field_cls = gr.Textbox if cmd_opts.use_textbox_seed else gr.Number

    with gr.Row():
        seed, random_seed, reuse_seed = seed_field_row(seed_field_cls, 'Seed', 'seed_row', 'random_seed', 'reuse_seed')

        with gr.Box(elem_id='subseed_show_box'):
            seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False)

    # Rows shown/hidden by the 'Extra' checkbox.
    extra_rows = []

    with gr.Row(visible=False) as variation_row:
        extra_rows.append(variation_row)
        subseed, random_subseed, reuse_subseed = seed_field_row(gr.Number, 'Variation seed', 'subseed_row', 'random_subseed', 'reuse_subseed')
        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01)

    with gr.Row(visible=False) as resize_row:
        extra_rows.append(resize_row)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)

    # The random buttons simply set their field back to -1.
    for dice_button, target_field in ((random_seed, seed), (random_subseed, subseed)):
        dice_button.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[target_field])

    def toggle_extras(show):
        return {row: gr_show(show) for row in extra_rows}

    seed_checkbox.change(toggle_extras, show_progress=False, inputs=[seed_checkbox], outputs=extra_rows)

    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
430 431


432

433
def clear_prompt(_prompt, confirmed, _token_counter):
    """Clear the prompt when ``confirmed`` is truthy.

    Returns values for (prompt, confirmation flag, token counter HTML). If not
    confirmed, the inputs are passed back unchanged. The original fell through and
    returned None, which does not match the three outputs this handler is wired to.
    """
    if not confirmed:
        return [_prompt, confirmed, _token_counter]
    return ["", confirmed, update_token_counter("", 1)]
436

437
def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter):
    """Wire ``button`` so a click runs the ``clear_prompt`` JS handler, then ``clear_prompt``.

    The prompt, confirmation dummy and token counter are both the inputs and the
    outputs of the handler.
    """
    wired = [prompt, _dummy_confirmed, token_counter]
    button.click(
        fn=clear_prompt,
        _js="clear_prompt",
        inputs=wired,
        outputs=list(wired),
    )
444 445


446 447 448
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
    """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
        (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
        was 0, i.e. no variation seed was used, it copies the normal seed value instead.

        The seed of the selected gallery image is used. An out-of-range selection falls
        back to the first entry, and -1 is produced when no seed is available or the
        generation info cannot be parsed."""
    def pick(values, position):
        # Out-of-range positions fall back to entry 0. A missing/empty list yields -1
        # (indexing [0] of an empty list used to raise IndexError).
        if not values:
            return -1
        return values[position if 0 <= position < len(values) else 0]

    def copy_seed(gen_info_string: str, index):
        try:
            gen_info = json.loads(gen_info_string)
        except (json.decoder.JSONDecodeError, TypeError):
            # TypeError covers a None / non-string value coming from the textbox.
            if gen_info_string:
                print("Error parsing JSON generation info:", file=sys.stderr)
                print(gen_info_string, file=sys.stderr)
            return [-1, gr_show(False)]

        # Valid JSON that is not an object (e.g. "null") carries no seeds.
        if not isinstance(gen_info, dict):
            return [-1, gr_show(False)]

        # Gallery indices include leading grid images; seed lists do not.
        index -= gen_info.get('index_of_first_image', 0)

        if is_subseed and gen_info.get('subseed_strength', 0) > 0:
            res = pick(gen_info.get('all_subseeds', [-1]), index)
        else:
            res = pick(gen_info.get('all_seeds', [-1]), index)

        return [res, gr_show(False)]

    reuse_seed.click(
        fn=copy_seed,
        _js="(x, y) => [x, selected_gallery_index()]",
        show_progress=False,
        inputs=[generation_info, dummy_component],
        outputs=[seed, dummy_component]
    )

479

L
Liam 已提交
480
def update_token_counter(text, steps):
    """Return HTML reporting the token count of the longest scheduled form of ``text``.

    ``text`` is expanded through the prompt-editing schedule for ``steps`` steps, and
    every resulting prompt is tokenized. The largest count is shown as
    ``count/max_length``. The span gets ``class="red"`` when the count exceeds the limit.
    """
    try:
        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)

    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        prompt_schedules = [[[steps, text]]]

    # Flatten the per-prompt schedules of [step, text] pairs into plain texts. The
    # previous reduce(list1 + list2) was quadratic and raised TypeError on an empty list.
    prompts = [prompt_text for schedule in prompt_schedules for _step, prompt_text in schedule]
    if not prompts:
        prompts = [text]

    _, token_count, max_length = max((model_hijack.tokenize(prompt) for prompt in prompts), key=lambda result: result[1])
    style_class = ' class="red"' if (token_count > max_length) else ""
    return f"<span {style_class}>{token_count}/{max_length}</span>"
A
AUTOMATIC 已提交
495

496

A
AUTOMATIC 已提交
497
def create_toprow(is_img2img):
    """Build the prompt / toolbar row shared by the txt2img and img2img tabs.

    Element ids are prefixed with "img2img" or "txt2img". The interrogate buttons are
    only created for img2img; the DeepBooru button additionally requires
    ``cmd_opts.deepdanbooru``. Otherwise they are returned as None.

    Returns a 14-tuple, in this order:
      prompt, roll, prompt_style, negative_prompt, prompt_style2, submit,
      button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste,
      token_counter, token_button, clear_prompt_button.
    """
    id_part = "img2img" if is_img2img else "txt2img"

    def prompt_box(label, elem_suffix, placeholder):
        # Label-less two-line textbox, nested Row > Column(scale=80) > Row.
        with gr.Row():
            with gr.Column(scale=80):
                with gr.Row():
                    box = gr.Textbox(
                        label=label,
                        elem_id=f"{id_part}_{elem_suffix}",
                        show_label=False,
                        lines=2,
                        placeholder=placeholder,
                    )
        return box

    def style_dropdown(column_id, label, elem_suffix):
        # One style selector in its own column, defaulting to the first known style.
        with gr.Column(scale=1, elem_id=column_id):
            dropdown = gr.Dropdown(
                label=label,
                elem_id=f"{id_part}_{elem_suffix}",
                choices=list(shared.prompt_styles.styles),
                value=next(iter(shared.prompt_styles.styles)),
            )
            dropdown.save_to_config = True
        return dropdown

    with gr.Row(elem_id="toprow"):
        with gr.Column(scale=6):
            prompt = prompt_box("Prompt", "prompt", "Prompt (press Ctrl+Enter or Alt+Enter to generate)")
            negative_prompt = prompt_box("Negative prompt", "neg_prompt", "Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")

        with gr.Column(scale=1, elem_id="roll_col"):
            roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
            paste = gr.Button(value=paste_symbol, elem_id="paste")
            save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
            prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
            clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible)

            token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
            token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")

        button_interrogate = button_deepbooru = None
        if is_img2img:
            with gr.Column(scale=1, elem_id="interrogate_col"):
                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                if cmd_opts.deepdanbooru:
                    button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")

        with gr.Column(scale=1):
            with gr.Row():
                skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')

                # The lambdas look up shared.state when clicked rather than binding it now.
                skip.click(fn=lambda: shared.state.skip(), inputs=[], outputs=[])
                interrupt.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])

            with gr.Row():
                prompt_style = style_dropdown("style_pos_col", "Style 1", "style_index")
                prompt_style2 = style_dropdown("style_neg_col", "Style 2", "style2_index")

    return (
        prompt, roll, prompt_style, negative_prompt, prompt_style2, submit,
        button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste,
        token_counter, token_button, clear_prompt_button,
    )
A
AUTOMATIC 已提交
563 564


565 566 567
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
    """Create the hidden progress-polling buttons for one tab.

    Two invisible buttons are created: ``{id_part}_check_progress`` and
    ``{id_part}_check_progress_initial`` (presumably clicked from the frontend JS).
    Each updates ``progressbar``, ``preview`` (visibility, then image) and
    ``textinfo``. A hidden HTML component is created for ``textinfo`` if none is given.
    """
    if textinfo is None:
        textinfo = gr.HTML(visible=False)

    def add_poll_button(caption, suffix, handler):
        poll_button = gr.Button(caption, elem_id=f"{id_part}_{suffix}", visible=False)
        poll_button.click(
            fn=handler,
            show_progress=False,
            inputs=[],
            outputs=[progressbar, preview, preview, textinfo],
        )

    add_poll_button('Check progress', 'check_progress', lambda: check_progress_call(id_part))
    add_poll_button('Check progress (first)', 'check_progress_initial', lambda: check_progress_call_initial(id_part))
A
AUTOMATIC 已提交
584 585


586 587 588 589
def apply_setting(key, value):
    """Apply one setting parsed from pasted generation parameters.

    Returns ``gr.update()`` (no change) in these cases:
      * ``value`` is None;
      * the checkpoint may not be swapped or has no match;
      * the setting's component is hidden;
      * ``value`` cannot be converted to the setting's type.

    Otherwise the value is stored in ``opts`` and saved to the config file, the
    setting's ``onchange`` callback runs if the stored value changed, and ``value``
    is returned.
    """
    if value is None:
        return gr.update()

    if key == "sd_model_checkpoint":
        # dont allow model to be swapped when model hash exists in prompt
        if opts.disable_weights_auto_swap:
            return gr.update()

        ckpt_info = sd_models.get_closet_checkpoint_match(value)
        if ckpt_info is None:
            return gr.update()
        value = ckpt_info.title

    info = opts.data_labels[key]
    comp_args = info.component_args
    if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
        # Hidden settings are not applied. Return a no-op update like the other
        # early exits (this used to return a bare None).
        return gr.update()

    valtype = type(info.default)
    try:
        newval = valtype(value) if valtype is not type(None) else value
    except (TypeError, ValueError):
        return gr.update()

    oldval = opts.data.get(key, None)
    opts.data[key] = newval
    # Compare converted values, so e.g. "20" vs 20 does not trigger onchange spuriously.
    if oldval != newval and info.onchange is not None:
        info.onchange()

    opts.save(shared.config_filename)
    return value


616 617 618 619
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    """Create a refresh button that reloads data and updates ``refresh_component``.

    Args:
        refresh_component: component updated on click.
        refresh_method: zero-argument callable that reloads the underlying data.
        refreshed_args: dict of component properties to apply after refreshing, or a
            zero-argument callable returning such a dict. ``None`` means "no changes".
        elem_id: HTML element id of the button.

    Returns:
        The created button.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # Normalize None before iterating; previously args.items() ran before the
        # `args or {}` fallback and crashed.
        args = args or {}

        # Also set the new values as attributes on the component object itself.
        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
633 634


635 636 637
def create_ui(wrap_gradio_gpu_call):
    import modules.img2img
    import modules.txt2img
638 639


640
    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
P
papuSpartan 已提交
641 642
        txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\
        txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\
643
        token_button, clear_prompt_button = create_toprow(is_img2img=False)
P
papuSpartan 已提交
644

645
        dummy_component = gr.Label(visible=False)
A
AUTOMATIC 已提交
646
        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
647 648 649



650

651 652 653
        with gr.Row(elem_id='txt2img_progress_row'):
            with gr.Column(scale=1):
                pass
654

655 656
            with gr.Column(scale=1):
                progressbar = gr.HTML(elem_id="txt2img_progressbar")
657
                txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
658
                setup_progressbar(progressbar, txt2img_preview, 'txt2img')
659

660 661 662 663 664
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")

665 666 667 668
                with gr.Group():
                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)

669
                with gr.Row():
A
AUTOMATIC 已提交
670
                    restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
671
                    tiling = gr.Checkbox(label='Tiling', value=False)
A
AUTOMATIC 已提交
672 673 674
                    enable_hr = gr.Checkbox(label='Highres. fix', value=False)

                with gr.Row(visible=False) as hr_options:
675 676
                    firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
                    firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
677
                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
678 679

                with gr.Row(equal_height=True):
R
RW21 已提交
680
                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
681 682
                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

683
                cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
684

685
                seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
686

A
AUTOMATIC 已提交
687 688
                aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()

A
AUTOMATIC 已提交
689
                with gr.Group():
A
AUTOMATIC 已提交
690
                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
691 692

            with gr.Column(variant='panel'):
A
AUTOMATIC 已提交
693

694
                with gr.Group():
A
AUTOMATIC 已提交
695
                    txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
A
AUTOMATIC 已提交
696
                    txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
697

R
ruocaled 已提交
698
                with gr.Column():
699 700 701 702 703
                    with gr.Row():
                        save = gr.Button('Save')
                        send_to_img2img = gr.Button('Send to img2img')
                        send_to_inpaint = gr.Button('Send to inpaint')
                        send_to_extras = gr.Button('Send to extras')
M
Michoko 已提交
704
                        button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
M
Michoko 已提交
705
                        open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
706

A
aoirusann 已提交
707 708
                    with gr.Row():
                        do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
J
Justin Maier 已提交
709

A
aoirusann 已提交
710
                    with gr.Row():
711
                        download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
A
aoirusann 已提交
712

B
Ben 已提交
713 714 715
                    with gr.Group():
                        html_info = gr.HTML()
                        generation_info = gr.Textbox(visible=False)
716

717 718
            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
719
            connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter)
720

721
            txt2img_args = dict(
722
                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
A
AUTOMATIC 已提交
723
                _js="submit",
724
                inputs=[
A
AUTOMATIC 已提交
725
                    txt2img_prompt,
726
                    txt2img_negative_prompt,
A
AUTOMATIC 已提交
727
                    txt2img_prompt_style,
A
AUTOMATIC 已提交
728
                    txt2img_prompt_style2,
729 730
                    steps,
                    sampler_index,
A
AUTOMATIC 已提交
731
                    restore_faces,
732
                    tiling,
733 734 735 736
                    batch_count,
                    batch_size,
                    cfg_scale,
                    seed,
737
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
738 739
                    height,
                    width,
A
AUTOMATIC 已提交
740 741
                    enable_hr,
                    denoising_strength,
742 743
                    firstphase_width,
                    firstphase_height,
744 745 746 747 748 749 750 751
                    aesthetic_lr,
                    aesthetic_weight,
                    aesthetic_steps,
                    aesthetic_imgs,
                    aesthetic_slerp,
                    aesthetic_imgs_text,
                    aesthetic_slerp_angle,
                    aesthetic_text_negative
A
AUTOMATIC 已提交
752
                ] + custom_inputs,
753

754 755 756 757
                outputs=[
                    txt2img_gallery,
                    generation_info,
                    html_info
758 759
                ],
                show_progress=False,
760 761
            )

A
AUTOMATIC 已提交
762
            txt2img_prompt.submit(**txt2img_args)
763 764
            submit.click(**txt2img_args)

D
d8ahazard 已提交
765 766 767 768 769 770 771 772 773 774 775
            txt_prompt_img.change(
                fn=modules.images.image_data,
                inputs=[
                    txt_prompt_img
                ],
                outputs=[
                    txt2img_prompt,
                    txt_prompt_img
                ]
            )

A
AUTOMATIC 已提交
776 777 778 779 780 781
            enable_hr.change(
                fn=lambda x: gr_show(x),
                inputs=[enable_hr],
                outputs=[hr_options],
            )

782 783
            save.click(
                fn=wrap_gradio_call(save_files),
A
aoirusann 已提交
784
                _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
785 786 787
                inputs=[
                    generation_info,
                    txt2img_gallery,
A
aoirusann 已提交
788
                    do_make_zip,
789
                    html_info,
790 791
                ],
                outputs=[
A
aoirusann 已提交
792
                    download_files,
793 794 795 796 797 798
                    html_info,
                    html_info,
                    html_info,
                ]
            )

A
AUTOMATIC 已提交
799 800
            roll.click(
                fn=roll_artist,
L
Liam 已提交
801
                _js="update_txt2img_tokens",
A
AUTOMATIC 已提交
802
                inputs=[
A
AUTOMATIC 已提交
803
                    txt2img_prompt,
A
AUTOMATIC 已提交
804 805
                ],
                outputs=[
A
AUTOMATIC 已提交
806
                    txt2img_prompt,
A
AUTOMATIC 已提交
807 808 809
                ]
            )

810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827
            txt2img_paste_fields = [
                (txt2img_prompt, "Prompt"),
                (txt2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
                (enable_hr, lambda d: "Denoising strength" in d),
                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
828 829
                (firstphase_width, "First pass size-1"),
                (firstphase_height, "First pass size-2"),
830 831 832 833 834 835 836 837
                (aesthetic_lr, "Aesthetic LR"),
                (aesthetic_weight, "Aesthetic weight"),
                (aesthetic_steps, "Aesthetic steps"),
                (aesthetic_imgs, "Aesthetic embedding"),
                (aesthetic_slerp, "Aesthetic slerp"),
                (aesthetic_imgs_text, "Aesthetic text"),
                (aesthetic_text_negative, "Aesthetic text negative"),
                (aesthetic_slerp_angle, "Aesthetic slerp angle"),
838
            ]
839 840 841 842 843 844 845 846 847 848 849 850

            txt2img_preview_params = [
                txt2img_prompt,
                txt2img_negative_prompt,
                steps,
                sampler_index,
                cfg_scale,
                seed,
                width,
                height,
            ]

L
Liam 已提交
851
            token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
852

853
    with gr.Blocks(analytics_enabled=False) as img2img_interface:
854 855
        img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\
        img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\
856
        token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True)
857

858

859
        with gr.Row(elem_id='img2img_progress_row'):
A
AUTOMATIC 已提交
860
            img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
D
d8ahazard 已提交
861

862 863
            with gr.Column(scale=1):
                pass
864

865 866
            with gr.Column(scale=1):
                progressbar = gr.HTML(elem_id="img2img_progressbar")
867
                img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
868
                setup_progressbar(progressbar, img2img_preview, 'img2img')
869

870 871
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
872

873
                with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
874
                    with gr.TabItem('img2img', id='img2img'):
875
                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
876

877
                    with gr.TabItem('Inpaint', id='inpaint'):
878
                        init_img_with_mask = gr.Image(label="Image for inpainting with mask",  show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
879

880 881
                        init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
                        init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
882 883 884 885

                        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)

                        with gr.Row():
886
                            mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
887 888
                            inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")

889
                        inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
890 891 892 893 894

                        with gr.Row():
                            inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
                            inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)

895
                    with gr.TabItem('Batch img2img', id='batch'):
896
                        hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
897
                        gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
898 899
                        img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
                        img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
900 901

                with gr.Row():
902 903 904 905
                    resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")

                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
A
AUTOMATIC 已提交
906

907
                with gr.Group():
908 909
                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
910

A
AUTOMATIC 已提交
911
                with gr.Row():
A
AUTOMATIC 已提交
912
                    restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
913
                    tiling = gr.Checkbox(label='Tiling', value=False)
914 915

                with gr.Row():
R
RW21 已提交
916
                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
917 918 919
                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

                with gr.Group():
920
                    cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
921
                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
922

923
                seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
924

A
AUTOMATIC 已提交
925 926
                aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()

A
AUTOMATIC 已提交
927
                with gr.Group():
A
AUTOMATIC 已提交
928
                    custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
A
AUTOMATIC 已提交
929

930
            with gr.Column(variant='panel'):
A
AUTOMATIC 已提交
931

932
                with gr.Group():
A
AUTOMATIC 已提交
933
                    img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
A
AUTOMATIC 已提交
934
                    img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
935

R
ruocaled 已提交
936
                with gr.Column():
937 938
                    with gr.Row():
                        save = gr.Button('Save')
A
AUTOMATIC 已提交
939 940
                        img2img_send_to_img2img = gr.Button('Send to img2img')
                        img2img_send_to_inpaint = gr.Button('Send to inpaint')
941
                        img2img_send_to_extras = gr.Button('Send to extras')
M
Michoko 已提交
942
                        button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
M
Michoko 已提交
943
                        open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
944

A
aoirusann 已提交
945 946
                    with gr.Row():
                        do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
J
Justin Maier 已提交
947

A
aoirusann 已提交
948
                    with gr.Row():
949
                        download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
A
aoirusann 已提交
950

B
Ben 已提交
951 952 953
                    with gr.Group():
                        html_info = gr.HTML()
                        generation_info = gr.Textbox(visible=False)
954

955 956
            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
957
            connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter)
958

D
d8ahazard 已提交
959 960 961
            img2img_prompt_img.change(
                fn=modules.images.image_data,
                inputs=[
A
AUTOMATIC 已提交
962
                    img2img_prompt_img
D
d8ahazard 已提交
963 964 965 966 967 968 969
                ],
                outputs=[
                    img2img_prompt,
                    img2img_prompt_img
                ]
            )

970
            mask_mode.change(
971
                lambda mode, img: {
972
                    init_img_with_mask: gr_show(mode == 0),
973 974
                    init_img_inpaint: gr_show(mode == 1),
                    init_mask_inpaint: gr_show(mode == 1),
975
                },
976
                inputs=[mask_mode, init_img_with_mask],
977 978
                outputs=[
                    init_img_with_mask,
979 980
                    init_img_inpaint,
                    init_mask_inpaint,
981 982 983
                ],
            )

984
            img2img_args = dict(
985
                fn=wrap_gradio_gpu_call(modules.img2img.img2img),
986
                _js="submit_img2img",
987
                inputs=[
988
                    dummy_component,
A
AUTOMATIC 已提交
989
                    img2img_prompt,
990
                    img2img_negative_prompt,
A
AUTOMATIC 已提交
991
                    img2img_prompt_style,
A
AUTOMATIC 已提交
992
                    img2img_prompt_style2,
993 994
                    init_img,
                    init_img_with_mask,
995 996
                    init_img_inpaint,
                    init_mask_inpaint,
997
                    mask_mode,
998 999 1000 1001
                    steps,
                    sampler_index,
                    mask_blur,
                    inpainting_fill,
A
AUTOMATIC 已提交
1002
                    restore_faces,
1003
                    tiling,
1004 1005 1006 1007 1008
                    batch_count,
                    batch_size,
                    cfg_scale,
                    denoising_strength,
                    seed,
1009
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
1010 1011 1012 1013
                    height,
                    width,
                    resize_mode,
                    inpaint_full_res,
1014
                    inpaint_full_res_padding,
A
AUTOMATIC 已提交
1015
                    inpainting_mask_invert,
1016 1017
                    img2img_batch_input_dir,
                    img2img_batch_output_dir,
M
MalumaDev 已提交
1018 1019 1020 1021 1022 1023 1024 1025
                    aesthetic_lr_im,
                    aesthetic_weight_im,
                    aesthetic_steps_im,
                    aesthetic_imgs_im,
                    aesthetic_slerp_im,
                    aesthetic_imgs_text_im,
                    aesthetic_slerp_angle_im,
                    aesthetic_text_negative_im,
A
AUTOMATIC 已提交
1026
                ] + custom_inputs,
1027 1028 1029 1030
                outputs=[
                    img2img_gallery,
                    generation_info,
                    html_info
1031 1032
                ],
                show_progress=False,
1033 1034
            )

A
AUTOMATIC 已提交
1035
            img2img_prompt.submit(**img2img_args)
1036 1037
            submit.click(**img2img_args)

A
AUTOMATIC 已提交
1038 1039 1040 1041 1042 1043
            img2img_interrogate.click(
                fn=interrogate,
                inputs=[init_img],
                outputs=[img2img_prompt],
            )

1044 1045 1046 1047 1048 1049
            if cmd_opts.deepdanbooru:
                img2img_deepbooru.click(
                    fn=interrogate_deepbooru,
                    inputs=[init_img],
                    outputs=[img2img_prompt],
                )
G
Greendayle 已提交
1050

1051 1052
            save.click(
                fn=wrap_gradio_call(save_files),
A
aoirusann 已提交
1053
                _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
1054 1055 1056
                inputs=[
                    generation_info,
                    img2img_gallery,
A
aoirusann 已提交
1057 1058
                    do_make_zip,
                    html_info,
1059 1060
                ],
                outputs=[
A
aoirusann 已提交
1061
                    download_files,
1062 1063 1064 1065 1066 1067
                    html_info,
                    html_info,
                    html_info,
                ]
            )

A
AUTOMATIC 已提交
1068 1069
            roll.click(
                fn=roll_artist,
L
Liam 已提交
1070
                _js="update_img2img_tokens",
A
AUTOMATIC 已提交
1071 1072 1073 1074 1075 1076 1077 1078 1079 1080
                inputs=[
                    img2img_prompt,
                ],
                outputs=[
                    img2img_prompt,
                ]
            )

            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
            style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
L
Liam 已提交
1081
            style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
A
AUTOMATIC 已提交
1082 1083

            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
A
AUTOMATIC 已提交
1084 1085 1086
                button.click(
                    fn=add_style,
                    _js="ask_for_style_name",
1087 1088 1089
                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
                    inputs=[dummy_component, prompt, negative_prompt],
A
AUTOMATIC 已提交
1090 1091 1092
                    outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
                )

1093
            for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
A
AUTOMATIC 已提交
1094 1095
                button.click(
                    fn=apply_styles,
1096
                    _js=js_func,
A
AUTOMATIC 已提交
1097 1098
                    inputs=[prompt, negative_prompt, style1, style2],
                    outputs=[prompt, negative_prompt, style1, style2],
A
AUTOMATIC 已提交
1099 1100
                )

1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116
            img2img_paste_fields = [
                (img2img_prompt, "Prompt"),
                (img2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
1117 1118 1119 1120 1121 1122 1123 1124
                (aesthetic_lr_im, "Aesthetic LR"),
                (aesthetic_weight_im, "Aesthetic weight"),
                (aesthetic_steps_im, "Aesthetic steps"),
                (aesthetic_imgs_im, "Aesthetic embedding"),
                (aesthetic_slerp_im, "Aesthetic slerp"),
                (aesthetic_imgs_text_im, "Aesthetic text"),
                (aesthetic_text_negative_im, "Aesthetic text negative"),
                (aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
1125
            ]
L
Liam 已提交
1126
            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
1127

1128 1129 1130
    with gr.Blocks(analytics_enabled=False) as extras_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
1131
                with gr.Tabs(elem_id="mode_extras"):
A
ArrowM 已提交
1132
                    with gr.TabItem('Single Image'):
1133
                        extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
A
ArrowM 已提交
1134 1135

                    with gr.TabItem('Batch Process'):
1136
                        image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
A
AUTOMATIC 已提交
1137

1138 1139 1140 1141 1142 1143 1144 1145 1146
                    with gr.TabItem('Batch from Directory'):
                        extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
                            placeholder="A directory on the same machine where the server is running."
                        )
                        extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
                            placeholder="Leave blank to save images to the default path."
                        )
                        show_extras_results = gr.Checkbox(label='Show result images', value=True)

J
Justin Maier 已提交
1147 1148 1149 1150 1151 1152
                with gr.Tabs(elem_id="extras_resize_mode"):
                    with gr.TabItem('Scale by'):
                        upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
                    with gr.TabItem('Scale to'):
                        with gr.Group():
                            with gr.Row():
J
Justin Maier 已提交
1153 1154
                                upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
                                upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
J
Justin Maier 已提交
1155
                            upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
A
AUTOMATIC 已提交
1156 1157

                with gr.Group():
A
AUTOMATIC 已提交
1158
                    extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
A
AUTOMATIC 已提交
1159 1160

                with gr.Group():
M
Mykeehu 已提交
1161
                    extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
A
AUTOMATIC 已提交
1162 1163 1164
                    extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)

                with gr.Group():
1165 1166 1167 1168
                    gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan)

                with gr.Group():
                    codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
1169
                    codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
1170 1171 1172 1173

                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')

            with gr.Column(variant='panel'):
A
AUTOMATIC 已提交
1174
                result_images = gr.Gallery(label="Result", show_label=False)
1175 1176
                html_info_x = gr.HTML()
                html_info = gr.HTML()
S
Seki 已提交
1177
                extras_send_to_img2img = gr.Button('Send to img2img')
S
Seki 已提交
1178
                extras_send_to_inpaint = gr.Button('Send to inpaint')
M
Michoko 已提交
1179
                button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
M
Michoko 已提交
1180
                open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
1181

D
d8ahazard 已提交
1182

1183
        submit.click(
1184
            fn=wrap_gradio_gpu_call(modules.extras.run_extras),
1185
            _js="get_extras_tab_index",
1186
            inputs=[
J
Justin Maier 已提交
1187
                dummy_component,
1188
                dummy_component,
1189
                extras_image,
A
ArrowM 已提交
1190
                image_batch,
1191 1192 1193
                extras_batch_input_dir,
                extras_batch_output_dir,
                show_extras_results,
1194 1195 1196
                gfpgan_visibility,
                codeformer_visibility,
                codeformer_weight,
A
AUTOMATIC 已提交
1197
                upscaling_resize,
J
Justin Maier 已提交
1198 1199 1200
                upscaling_resize_w,
                upscaling_resize_h,
                upscaling_crop,
A
AUTOMATIC 已提交
1201 1202 1203
                extras_upscaler_1,
                extras_upscaler_2,
                extras_upscaler_2_visibility,
1204 1205
            ],
            outputs=[
A
ArrowM 已提交
1206
                result_images,
1207 1208 1209 1210
                html_info_x,
                html_info,
            ]
        )
J
Justin Maier 已提交
1211

S
Seki 已提交
1212 1213 1214 1215 1216 1217
        extras_send_to_img2img.click(
            fn=lambda x: image_from_url_text(x),
            _js="extract_image_from_gallery_img2img",
            inputs=[result_images],
            outputs=[init_img],
        )
J
Justin Maier 已提交
1218

S
Seki 已提交
1219 1220
        extras_send_to_inpaint.click(
            fn=lambda x: image_from_url_text(x),
1221
            _js="extract_image_from_gallery_inpaint",
S
Seki 已提交
1222 1223 1224
            inputs=[result_images],
            outputs=[init_img_with_mask],
        )
1225

1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240
    with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")

            with gr.Column(variant='panel'):
                html = gr.HTML()
                generation_info = gr.Textbox(visible=False)
                html2 = gr.HTML()

                with gr.Row():
                    pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
                    pnginfo_send_to_img2img = gr.Button('Send to img2img')

        image.change(
1241
            fn=wrap_gradio_call(modules.extras.run_pnginfo),
1242 1243 1244
            inputs=[image],
            outputs=[html, generation_info, html2],
        )
Y
yfszzx 已提交
1245 1246 1247 1248 1249 1250
    #images history
    images_history_switch_dict = {
        "fn":modules.generation_parameters_copypaste.connect_paste,
        "t2i":txt2img_paste_fields,
        "i2i":img2img_paste_fields
    }
1251

1252
    images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
1253

1254 1255 1256
    with gr.Blocks() as modelmerger_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
1257
                gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
1258

1259
                with gr.Row():
1260 1261 1262
                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
                    tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
S
safentisAuth 已提交
1263
                custom_name = gr.Textbox(label="Custom Name (Optional)")
1264 1265
                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
                interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
M
Milly 已提交
1266
                save_as_half = gr.Checkbox(value=False, label="Save as float16")
1267
                modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
1268

1269
            with gr.Column(variant='panel'):
1270
                submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
1271

1272 1273
    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()

A
AUTOMATIC 已提交
1274
    with gr.Blocks() as train_interface:
1275
        with gr.Row().style(equal_height=False):
A
AUTOMATIC 已提交
1276
            gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
1277

A
AUTOMATIC 已提交
1278 1279
        with gr.Row().style(equal_height=False):
            with gr.Tabs(elem_id="train_tabs"):
1280

A
AUTOMATIC 已提交
1281
                with gr.Tab(label="Create embedding"):
1282
                    new_embedding_name = gr.Textbox(label="Name")
1283
                    initialization_text = gr.Textbox(label="Initialization text", value="*")
1284
                    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
D
DepFA 已提交
1285
                    overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
1286 1287 1288 1289 1290 1291

                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
A
AUTOMATIC 已提交
1292
                            create_embedding = gr.Button(value="Create embedding", variant='primary')
1293

M
MalumaDev 已提交
1294 1295
                with gr.Tab(label="Create aesthetic images embedding"):

1296 1297 1298 1299 1300 1301 1302 1303 1304 1305
                    new_embedding_name_ae = gr.Textbox(label="Name")
                    process_src_ae = gr.Textbox(label='Source directory')
                    batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
                            create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')

A
AUTOMATIC 已提交
1306
                with gr.Tab(label="Create hypernetwork"):
A
AUTOMATIC 已提交
1307
                    new_hypernetwork_name = gr.Textbox(label="Name")
1308
                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
1309
                    new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
1310
                    new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
D
DepFA 已提交
1311
                    overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
1312
                    new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
A
AUTOMATIC 已提交
1313 1314 1315 1316 1317 1318

                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
A
AUTOMATIC 已提交
1319
                            create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
1320

A
AUTOMATIC 已提交
1321
                with gr.Tab(label="Preprocess images"):
1322 1323
                    process_src = gr.Textbox(label='Source directory')
                    process_dst = gr.Textbox(label='Destination directory')
A
alg-wiki 已提交
1324 1325
                    process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
                    process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
D
DepFA 已提交
1326
                    preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
1327 1328

                    with gr.Row():
D
DepFA 已提交
1329
                        process_flip = gr.Checkbox(label='Create flipped copies')
1330
                        process_split = gr.Checkbox(label='Split oversized images')
1331 1332
                        process_caption = gr.Checkbox(label='Use BLIP for caption')
                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
1333

1334 1335 1336 1337
                    with gr.Row(visible=False) as process_split_extra_row:
                        process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
                        process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)

1338 1339 1340 1341 1342 1343 1344
                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
                            run_preprocess = gr.Button(value="Preprocess", variant='primary')

1345 1346 1347 1348 1349 1350
                    process_split.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_split],
                        outputs=[process_split_extra_row],
                    )

A
AUTOMATIC 已提交
1351
                with gr.Tab(label="Train"):
D
DepFA 已提交
1352
                    gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
1353
                    with gr.Row():
A
AUTOMATIC 已提交
1354
                        train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
1355 1356
                        create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
                    with gr.Row():
A
AUTOMATIC 已提交
1357
                        train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
1358
                        create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
D
DepFA 已提交
1359 1360 1361 1362
                    with gr.Row():
                        embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
                        hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
                    
1363
                    batch_size = gr.Number(label='Batch size', value=1, precision=0)
1364 1365 1366
                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
                    log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
                    template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
A
alg-wiki 已提交
1367 1368
                    training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
                    training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
1369
                    steps = gr.Number(label='Max steps', value=100000, precision=0)
1370 1371
                    create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
D
DepFA 已提交
1372
                    save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
1373
                    preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
1374 1375

                    with gr.Row():
A
AUTOMATIC 已提交
1376 1377 1378
                        interrupt_training = gr.Button(value="Interrupt")
                        train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
                        train_embedding = gr.Button(value="Train Embedding", variant='primary')
1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393

            with gr.Column():
                progressbar = gr.HTML(elem_id="ti_progressbar")
                ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)

                ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
                ti_preview = gr.Image(elem_id='ti_preview', visible=False)
                ti_progress = gr.HTML(elem_id="ti_progress", value="")
                ti_outcome = gr.HTML(elem_id="ti_error", value="")
                setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)

        create_embedding.click(
            fn=modules.textual_inversion.ui.create_embedding,
            inputs=[
                new_embedding_name,
1394
                initialization_text,
1395
                nvpt,
D
DepFA 已提交
1396
                overwrite_old_embedding,
1397 1398 1399 1400 1401 1402 1403 1404
            ],
            outputs=[
                train_embedding_name,
                ti_output,
                ti_outcome,
            ]
        )

1405
        create_embedding_ae.click(
M
MalumaDev 已提交
1406
            fn=aesthetic_clip.generate_imgs_embd,
1407 1408 1409 1410 1411 1412 1413
            inputs=[
                new_embedding_name_ae,
                process_src_ae,
                batch_ae
            ],
            outputs=[
                aesthetic_imgs,
M
MalumaDev 已提交
1414
                aesthetic_imgs_im,
1415 1416 1417 1418 1419
                ti_output,
                ti_outcome,
            ]
        )

A
AUTOMATIC 已提交
1420
        create_hypernetwork.click(
A
AUTOMATIC 已提交
1421
            fn=modules.hypernetworks.ui.create_hypernetwork,
A
AUTOMATIC 已提交
1422 1423
            inputs=[
                new_hypernetwork_name,
1424
                new_hypernetwork_sizes,
D
DepFA 已提交
1425
                overwrite_old_hypernetwork,
1426 1427
                new_hypernetwork_layer_structure,
                new_hypernetwork_add_layer_norm,
D
update  
discus0434 已提交
1428
                new_hypernetwork_activation_func,
A
AUTOMATIC 已提交
1429 1430 1431 1432 1433 1434 1435 1436
            ],
            outputs=[
                train_hypernetwork_name,
                ti_output,
                ti_outcome,
            ]
        )

1437 1438 1439 1440 1441 1442
        run_preprocess.click(
            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
            _js="start_training_textual_inversion",
            inputs=[
                process_src,
                process_dst,
A
alg-wiki 已提交
1443 1444
                process_width,
                process_height,
D
DepFA 已提交
1445
                preprocess_txt_action,
1446 1447 1448
                process_flip,
                process_split,
                process_caption,
1449 1450 1451
                process_caption_deepbooru,
                process_split_threshold,
                process_overlap_ratio,
1452 1453 1454 1455 1456 1457 1458
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ],
        )

1459 1460 1461 1462 1463
        train_embedding.click(
            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
            _js="start_training_textual_inversion",
            inputs=[
                train_embedding_name,
D
DepFA 已提交
1464
                embedding_learn_rate,
1465
                batch_size,
1466 1467
                dataset_directory,
                log_directory,
A
alg-wiki 已提交
1468 1469
                training_width,
                training_height,
1470 1471 1472 1473
                steps,
                create_image_every,
                save_embedding_every,
                template_file,
D
DepFA 已提交
1474
                save_image_with_stored_embedding,
1475 1476
                preview_from_txt2img,
                *txt2img_preview_params,
1477 1478 1479 1480 1481 1482 1483
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )

A
AUTOMATIC 已提交
1484
        train_hypernetwork.click(
A
AUTOMATIC 已提交
1485
            fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
A
AUTOMATIC 已提交
1486 1487 1488
            _js="start_training_textual_inversion",
            inputs=[
                train_hypernetwork_name,
D
DepFA 已提交
1489
                hypernetwork_learn_rate,
1490
                batch_size,
A
AUTOMATIC 已提交
1491 1492
                dataset_directory,
                log_directory,
1493 1494
                training_width,
                training_height,
1495 1496 1497 1498
                steps,
                create_image_every,
                save_embedding_every,
                template_file,
1499 1500
                preview_from_txt2img,
                *txt2img_preview_params,
1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )

        interrupt_training.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

1514
    def create_setting_component(key, is_quicksettings=False):
        """Create the Gradio component used to edit the option ``key``.

        The component class is the option's own ``component`` when one is set.
        Otherwise it is chosen from the type of the option's default value:
        ``str`` -> ``gr.Textbox``, ``int`` -> ``gr.Number``, ``bool`` -> ``gr.Checkbox``.
        Options with a ``refresh`` callback also get a refresh button next to the
        component. Outside the quicksettings row the two are grouped in a compact row.

        Args:
            key: option name, looked up in ``opts.data_labels``.
            is_quicksettings: True when the component is created for the
                quicksettings row rather than for the settings tab.

        Returns:
            The newly created Gradio component.

        Raises:
            Exception: when no component is configured and the default value's
                type is not ``str``, ``int`` or ``bool``.
        """
        option = opts.data_labels[key]
        default_type = type(option.default)

        # component_args may be given directly or as a callable that builds them;
        # it is resolved before the component type is validated below.
        if callable(option.component_args):
            component_args = option.component_args()
        else:
            component_args = option.component_args

        def current_value():
            # Handed to the component as a callable rather than a snapshot, so the
            # value is read from opts whenever Gradio evaluates it.
            if key in opts.data:
                return opts.data[key]
            return opts.data_labels[key].default

        component_cls = option.component
        if component_cls is None:
            component_cls = {str: gr.Textbox, int: gr.Number, bool: gr.Checkbox}.get(default_type)
            if component_cls is None:
                raise Exception(f'bad options item type: {str(default_type)} for key {key}')

        elem_id = "setting_" + key

        def build():
            return component_cls(label=option.label, value=current_value, elem_id=elem_id, **(component_args or {}))

        if option.refresh is None:
            return build()

        if is_quicksettings:
            widget = build()
            create_refresh_button(widget, option.refresh, option.component_args, "refresh_" + key)
        else:
            with gr.Row(variant="compact"):
                widget = build()
                create_refresh_button(widget, option.refresh, option.component_args, "refresh_" + key)

        return widget
1549

A
AUTOMATIC 已提交
1550
    components = []
1551
    component_dict = {}
A
AUTOMATIC 已提交
1552

M
Michoko 已提交
1553
    def open_folder(f):
1554
        if not os.path.exists(f):
C
CookieHCl 已提交
1555
            print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
1556 1557
            return
        elif not os.path.isdir(f):
1558 1559 1560 1561 1562 1563 1564 1565
            print(f"""
WARNING
An open_folder request was made with an argument that is not a folder.
This could be an error or a malicious attempt to run code on your computer.
Requested path was: {f}
""", file=sys.stderr)
            return

M
Michoko 已提交
1566 1567 1568 1569 1570 1571 1572 1573
        if not shared.cmd_opts.hide_ui_dir_config:
            path = os.path.normpath(f)
            if platform.system() == "Windows":
                os.startfile(path)
            elif platform.system() == "Darwin":
                sp.Popen(["open", path])
            else:
                sp.Popen(["xdg-open", path])
M
Michoko 已提交
1574

1575
    def run_settings(*args):
        """Validate, then apply, all values submitted from the settings tab.

        ``args`` holds one value per entry of ``components``, in the same order as
        ``opts.data_labels``. Entries whose component is ``dummy_component`` (the
        options placed in the quicksettings row) are neither validated nor written.

        Every value is type-checked before anything is written, so one bad value
        aborts the whole submission. Options whose dict ``component_args`` mark
        them ``visible: False`` are not written. Options in ``restricted_opts`` are
        also not written while ``cmd_opts.hide_ui_dir_config`` is set.

        Returns:
            ``(message, settings_json)``. ``message`` reports either the offending
            value or how many settings changed.
        """
        def submitted():
            # Fresh iterator on each call; pairs option keys, values and components by position.
            return zip(opts.data_labels.keys(), args, components)

        for key, value, comp in submitted():
            if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
                expected = type(opts.data_labels[key].default).__name__
                return f"Bad value for setting {key}: {value}; expecting {expected}", opts.dumpjson()

        changed = 0
        for key, value, comp in submitted():
            if comp == dummy_component:
                continue

            comp_args = opts.data_labels[key].component_args
            hidden = comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False
            if hidden:
                continue

            if cmd_opts.hide_ui_dir_config and key in restricted_opts:
                continue

            previous = opts.data.get(key, None)
            opts.data[key] = value

            if previous != value:
                on_change = opts.data_labels[key].onchange
                if on_change is not None:
                    on_change()
                changed += 1

        opts.save(shared.config_filename)

        return f'{changed} settings changed.', opts.dumpjson()
1605

1606 1607 1608 1609
    def run_settings_single(value, key):
        """Validate and apply a single quicksettings value as soon as it changes.

        Args:
            value: new value coming from the quicksettings component.
            key: option name in ``opts.data_labels``.

        Returns:
            ``(component_update, settings_json)``:
            - type mismatch: ``gr.update(visible=True)``; nothing is stored.
            - option in ``restricted_opts`` while ``cmd_opts.hide_ui_dir_config``
              is set: the component is reset to the stored value; nothing is
              stored.
            - otherwise: the value is stored, ``onchange`` runs if the value
              differs, options are saved, and the component echoes ``value``.
        """
        if not opts.same_type(value, opts.data_labels[key].default):
            return gr.update(visible=True), opts.dumpjson()

        # Read the stored value before the restricted-option check: that branch
        # needs it to reset the component. Reading it later caused UnboundLocalError.
        oldval = opts.data.get(key, None)

        if cmd_opts.hide_ui_dir_config and key in restricted_opts:
            # Restricted options are read-only in this mode; push the stored value back.
            return gr.update(value=oldval), opts.dumpjson()

        opts.data[key] = value

        if oldval != value:
            onchange = opts.data_labels[key].onchange
            if onchange is not None:
                onchange()

        opts.save(shared.config_filename)

        return gr.update(value=value), opts.dumpjson()

A
AUTOMATIC 已提交
1624
    with gr.Blocks(analytics_enabled=False) as settings_interface:
1625
        settings_submit = gr.Button(value="Apply settings", variant='primary')
A
AUTOMATIC 已提交
1626 1627
        result = gr.HTML()

1628 1629
        settings_cols = 3
        items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
A
AUTOMATIC 已提交
1630

1631 1632 1633
        quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
        quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')

1634 1635
        quicksettings_list = []

1636 1637 1638 1639 1640 1641
        cols_displayed = 0
        items_displayed = 0
        previous_section = None
        column = None
        with gr.Row(elem_id="settings").style(equal_height=False):
            for i, (k, item) in enumerate(opts.data_labels.items()):
D
DepFA 已提交
1642

1643 1644 1645 1646
                if previous_section != item.section:
                    if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
                        if column is not None:
                            column.__exit__()
D
DepFA 已提交
1647

1648 1649
                        column = gr.Column(variant='panel')
                        column.__enter__()
A
AUTOMATIC 已提交
1650

1651 1652 1653 1654 1655 1656 1657
                        items_displayed = 0
                        cols_displayed += 1

                    previous_section = item.section

                    gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))

1658
                if k in quicksettings_names:
1659 1660 1661 1662 1663 1664 1665
                    quicksettings_list.append((i, k, item))
                    components.append(dummy_component)
                else:
                    component = create_setting_component(k)
                    component_dict[k] = component
                    components.append(component)
                    items_displayed += 1
1666

1667 1668
        with gr.Row():
            request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
A
AUTOMATIC 已提交
1669 1670 1671
            download_localization = gr.Button(value='Download localization template', elem_id="download_localization")

        with gr.Row():
1672 1673 1674
            reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
            restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')

1675 1676 1677 1678
        request_notifications.click(
            fn=lambda: None,
            inputs=[],
            outputs=[],
1679
            _js='function(){}'
1680 1681
        )

A
AUTOMATIC 已提交
1682 1683 1684 1685 1686 1687 1688
        download_localization.click(
            fn=lambda: None,
            inputs=[],
            outputs=[],
            _js='download_localization'
        )

D
DepFA 已提交
1689
        def reload_scripts():
            """Reload custom script bodies, then reload the page javascript.

            Bound to the "Reload custom script bodies" button. Per that button's
            label, the UI layout is not rebuilt and Gradio is not restarted.
            """
            scripts = modules.scripts
            scripts.reload_script_body_only()
            # The html page needs refreshing after the script bodies change.
            reload_javascript()
D
DepFA 已提交
1692 1693 1694 1695 1696 1697 1698

        reload_script_bodies.click(
            fn=reload_scripts,
            inputs=[],
            outputs=[],
            _js='function(){}'
        )
1699 1700

        def request_restart():
            """Interrupt the running job and flag the top-level UI for a restart.

            ``settings_interface.gradio_ref`` is the outer ``demo`` Blocks (assigned
            further down in this file). Setting ``do_restart`` on it is the restart
            request. Presumably the launch loop polls this flag; the reader is not
            visible here.
            """
            shared.state.interrupt()
            top_level = settings_interface.gradio_ref
            top_level.do_restart = True
1703 1704 1705 1706 1707 1708 1709

        restart_gradio.click(
            fn=request_restart,
            inputs=[],
            outputs=[],
            _js='function(){restart_reload()}'
        )
J
Justin Maier 已提交
1710

1711 1712 1713
        if column is not None:
            column.__exit__()

1714
    interfaces = [
A
AUTOMATIC 已提交
1715 1716 1717 1718
        (txt2img_interface, "txt2img", "txt2img"),
        (img2img_interface, "img2img", "img2img"),
        (extras_interface, "Extras", "extras"),
        (pnginfo_interface, "PNG Info", "pnginfo"),
1719
        (images_history, "History", "images_history"),
1720
        (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
A
AUTOMATIC 已提交
1721
        (train_interface, "Train", "ti"),
A
AUTOMATIC 已提交
1722
        (settings_interface, "Settings", "settings"),
1723 1724 1725 1726 1727
    ]

    with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
        css = file.read()

A
typo  
AUTOMATIC 已提交
1728
    if os.path.exists(os.path.join(script_path, "user.css")):
A
AUTOMATIC 已提交
1729 1730 1731 1732
        with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
            usercss = file.read()
            css += usercss

1733 1734 1735
    if not cmd_opts.no_progressbar_hiding:
        css += css_hide_progressbar

A
AUTOMATIC 已提交
1736
    with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
1737 1738
        with gr.Row(elem_id="quicksettings"):
            for i, k, item in quicksettings_list:
1739
                component = create_setting_component(k, is_quicksettings=True)
1740 1741
                component_dict[k] = component

D
DepFA 已提交
1742
        settings_interface.gradio_ref = demo
J
Justin Maier 已提交
1743

1744
        with gr.Tabs(elem_id="tabs") as tabs:
A
AUTOMATIC 已提交
1745
            for interface, label, ifid in interfaces:
1746
                with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
A
AUTOMATIC 已提交
1747
                    interface.render()
J
Justin Maier 已提交
1748

1749 1750
        if os.path.exists(os.path.join(script_path, "notification.mp3")):
            audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
A
AUTOMATIC 已提交
1751

1752
        text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
1753
        settings_submit.click(
1754 1755 1756
            fn=run_settings,
            inputs=components,
            outputs=[result, text_settings],
1757
        )
1758 1759 1760 1761 1762 1763 1764 1765 1766 1767

        for i, k, item in quicksettings_list:
            component = component_dict[k]

            component.change(
                fn=lambda value, k=k: run_settings_single(value, key=k),
                inputs=[component],
                outputs=[component, text_settings],
            )

S
safentisAuth 已提交
1768 1769
        def modelmerger(*args):
            """Run the checkpoint merge and report the outcome to the merger UI.

            ``args`` are forwarded unchanged to ``modules.extras.run_modelmerger``.
            On success, its results are returned as-is.

            On failure:
            - the traceback is printed to stderr;
            - the checkpoint list is rescanned;
            - an error message is returned together with refreshed dropdown choices.

            The click handler is bound to five outputs: the result textbox, the
            A/B/C model dropdowns, and the ``sd_model_checkpoint`` setting
            component. The error path must therefore return one message plus four
            dropdown updates. It previously returned only three, which broke the
            error report itself.
            """
            # Dropdowns refreshed on error: primary, secondary, tertiary, sd_model_checkpoint.
            dropdown_outputs = 4
            try:
                results = modules.extras.run_modelmerger(*args)
            except Exception:
                print("Error loading/saving model file:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                modules.sd_models.list_models()  # to remove the potentially missing models from the list
                error_message = "Error loading/saving model file. It doesn't exist or the name contains illegal characters"
                return [error_message] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(dropdown_outputs)]
            return results
1777

1778
        modelmerger_merge.click(
S
safentisAuth 已提交
1779
            fn=modelmerger,
1780 1781 1782
            inputs=[
                primary_model_name,
                secondary_model_name,
1783
                tertiary_model_name,
1784 1785 1786
                interp_method,
                interp_amount,
                save_as_half,
S
safentisAuth 已提交
1787
                custom_name,
1788 1789 1790 1791 1792
            ],
            outputs=[
                submit_result,
                primary_model_name,
                secondary_model_name,
1793
                tertiary_model_name,
1794 1795 1796
                component_dict['sd_model_checkpoint'],
            ]
        )
1797 1798 1799
        paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
        txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
        img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
A
AUTOMATIC 已提交
1800
        send_to_img2img.click(
1801 1802 1803 1804
            fn=lambda img, *args: (image_from_url_text(img),*args),
            _js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
            inputs=[txt2img_gallery] + txt2img_fields,
            outputs=[init_img] + img2img_fields,
A
AUTOMATIC 已提交
1805 1806 1807
        )

        send_to_inpaint.click(
1808 1809 1810 1811
            fn=lambda x, *args: (image_from_url_text(x), *args),
            _js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
            inputs=[txt2img_gallery] + txt2img_fields,
            outputs=[init_img_with_mask] + img2img_fields,
A
AUTOMATIC 已提交
1812 1813 1814 1815
        )

        img2img_send_to_img2img.click(
            fn=lambda x: image_from_url_text(x),
1816
            _js="extract_image_from_gallery_img2img",
A
AUTOMATIC 已提交
1817 1818 1819 1820 1821 1822
            inputs=[img2img_gallery],
            outputs=[init_img],
        )

        img2img_send_to_inpaint.click(
            fn=lambda x: image_from_url_text(x),
1823
            _js="extract_image_from_gallery_inpaint",
A
AUTOMATIC 已提交
1824 1825 1826 1827 1828 1829
            inputs=[img2img_gallery],
            outputs=[init_img_with_mask],
        )

        send_to_extras.click(
            fn=lambda x: image_from_url_text(x),
S
Seki 已提交
1830
            _js="extract_image_from_gallery_extras",
A
AUTOMATIC 已提交
1831
            inputs=[txt2img_gallery],
1832
            outputs=[extras_image],
A
AUTOMATIC 已提交
1833 1834
        )

M
Michoko 已提交
1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852
        open_txt2img_folder.click(
            fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
            inputs=[],
            outputs=[],
        )

        open_img2img_folder.click(
            fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
            inputs=[],
            outputs=[],
        )

        open_extras_folder.click(
            fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
            inputs=[],
            outputs=[],
        )

A
AUTOMATIC 已提交
1853 1854
        img2img_send_to_extras.click(
            fn=lambda x: image_from_url_text(x),
S
Seki 已提交
1855
            _js="extract_image_from_gallery_extras",
A
AUTOMATIC 已提交
1856
            inputs=[img2img_gallery],
1857
            outputs=[extras_image],
A
AUTOMATIC 已提交
1858
        )
1859

1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875
        settings_map = {
            'sd_hypernetwork': 'Hypernet',
            'CLIP_stop_at_last_layers': 'Clip skip',
            'sd_model_checkpoint': 'Model hash',
        }

        settings_paste_fields = [
            (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
            for k, v in settings_map.items()
        ]

        modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
        modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)

        modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
        modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
1876

1877
    ui_config_file = cmd_opts.ui_config_file
A
AUTOMATIC 已提交
1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891
    ui_settings = {}
    settings_count = len(ui_settings)
    error_loading = False

    try:
        if os.path.exists(ui_config_file):
            with open(ui_config_file, "r", encoding="utf8") as file:
                ui_settings = json.load(file)
    except Exception:
        error_loading = True
        print("Error loading settings:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    def loadsave(path, x):
        """Synchronise one gradio component with the persisted ``ui_settings``.

        For every tracked attribute of *x*, the key ``"<path>/<field>"`` is
        looked up in ``ui_settings``. When the component has a
        ``custom_script_source``, the key is prefixed with
        ``"customscript/<source>/"``.

        * No stored value (missing or ``None``): the component's current
          attribute value is recorded in ``ui_settings``.
        * Stored value rejected by the field's validity check: a warning is
          printed and the component keeps its current value.
        * Otherwise: the stored value is assigned to the component and passed
          to the optional ``init_field`` callback.

        Components carrying a truthy ``do_not_save_to_config`` are skipped.
        """

        def apply_field(obj, field, condition=None, init_field=None):
            key = "/".join((path, field))

            script_source = getattr(obj, 'custom_script_source', None)
            if script_source is not None:
                # Settings of components contributed by a custom script live
                # under their own namespace.
                key = 'customscript/' + script_source + '/' + key

            if getattr(obj, 'do_not_save_to_config', False):
                return

            saved_value = ui_settings.get(key, None)

            if saved_value is None:
                # Nothing stored yet: remember the component's current value.
                ui_settings[key] = getattr(obj, field)
                return

            if condition and not condition(saved_value):
                print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
                return

            setattr(obj, field, saved_value)
            if init_field is not None:
                init_field(saved_value)

        def value_in_choices(val):
            # A stored selection is only valid if the component still offers it.
            return val in x.choices

        # Exact type() comparisons (not isinstance) so gradio subclasses of
        # these components are not matched.
        component_type = type(x)

        if component_type in (gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number) and x.visible:
            apply_field(x, 'visible')

        if component_type == gr.Slider:
            for slider_field in ('value', 'minimum', 'maximum', 'step'):
                apply_field(x, slider_field)
        elif component_type == gr.Radio:
            apply_field(x, 'value', value_in_choices)
        elif component_type in (gr.Checkbox, gr.Textbox, gr.Number):
            apply_field(x, 'value')
        elif component_type == gr.Dropdown and getattr(x, 'save_to_config', False):
            # Many dropdowns must not be persisted, so only those explicitly
            # marked with ``save_to_config`` are handled.
            apply_field(x, 'value', value_in_choices, getattr(x, 'init_field', None))
            apply_field(x, 'visible')
1937

A
AUTOMATIC 已提交
1938 1939
    visit(txt2img_interface, loadsave, "txt2img")
    visit(img2img_interface, loadsave, "img2img")
1940
    visit(extras_interface, loadsave, "extras")
1941
    visit(modelmerger_interface, loadsave, "modelmerger")
A
AUTOMATIC 已提交
1942 1943 1944 1945 1946

    if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
        with open(ui_config_file, "w", encoding="utf8") as file:
            json.dump(ui_settings, file, indent=4)

1947 1948 1949
    return demo


1950 1951 1952
def load_javascript(raw_response):
    """Patch gradio's page template so every page ships the web UI's JavaScript.

    Collects the following into one HTML snippet:

    * ``script.js`` from ``script_path``;
    * every regular ``*.js`` file in ``<script_path>/javascript``, in sorted
      filename order;
    * an optional ``set_theme(...)`` call when ``--theme`` was given;
    * the localization script.

    It then replaces ``gradio.routes.templates.TemplateResponse`` with a
    wrapper that inserts the snippet just before ``</head>``.

    Args:
        raw_response: The TemplateResponse factory to wrap. The wrapper closes
            over this argument rather than re-reading the module attribute.
            Passing the original, unpatched factory (as ``reload_javascript``
            does) therefore makes repeated calls replace the wrapper instead of
            nesting it.
    """
    with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
        parts = [f'<script>{jsfile.read()}</script>']

    jsdir = os.path.join(script_path, "javascript")
    for filename in sorted(os.listdir(jsdir)):
        filepath = os.path.join(jsdir, filename)
        # Only regular *.js files are scripts. Anything else in the directory
        # (sub-folders, editor backups, OS metadata such as .DS_Store) would be
        # pasted into a <script> tag. Worse, it could make open()/read() raise
        # (IsADirectoryError / PermissionError / UnicodeDecodeError) and abort
        # UI start-up.
        if not filename.endswith(".js") or not os.path.isfile(filepath):
            continue
        with open(filepath, "r", encoding="utf8") as jsfile:
            parts.append(f"\n<!-- {filename} --><script>{jsfile.read()}</script>")

    if cmd_opts.theme is not None:
        parts.append(f"\n<script>set_theme('{cmd_opts.theme}');</script>\n")

    parts.append(f"\n<script>{localization.localization_js(shared.opts.localization)}</script>")

    javascript = "".join(parts)
    # The injected bytes never change, so encode them once here rather than
    # on every page request.
    head_injection = f'{javascript}</head>'.encode("utf8")

    def template_response(*args, **kwargs):
        res = raw_response(*args, **kwargs)
        res.body = res.body.replace(b'</head>', head_injection)
        # The body changed size, so regenerate the headers derived from it
        # (presumably Content-Length, via Starlette's Response.init_headers).
        res.init_headers()
        return res

    gradio.routes.templates.TemplateResponse = template_response
Y
yfszzx 已提交
1972

1973 1974 1975 1976

# Capture gradio's original TemplateResponse factory once, at import time.
# Every later reload_javascript() call re-reads the JavaScript files and wraps
# that pristine factory again, rather than stacking a wrapper on top of the
# already-patched attribute.
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)

# Install the patched template as soon as this module is imported.
reload_javascript()