ui.py 95.5 KB
Newer Older
1 2 3 4 5
import json
import mimetypes
import os
import sys
import traceback
A
AUTOMATIC 已提交
6
from functools import reduce
A
AUTOMATIC 已提交
7
import warnings
8

D
discus0434 已提交
9 10 11
import gradio as gr
import gradio.routes
import gradio.utils
A
AUTOMATIC 已提交
12
import numpy as np
A
AUTOMATIC 已提交
13
from PIL import Image, PngImagePlugin  # noqa: F401
14
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
15

A
AUTOMATIC 已提交
16
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
A
AUTOMATIC 已提交
17
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
18
from modules.paths import script_path, data_path
19

A
AUTOMATIC 已提交
20
from modules.shared import opts, cmd_opts
M
init  
MalumaDev 已提交
21

D
discus0434 已提交
22
import modules.codeformer_model
Y
yfszzx 已提交
23
import modules.generation_parameters_copypaste as parameters_copypaste
D
discus0434 已提交
24 25
import modules.gfpgan_model
import modules.hypernetworks.ui
A
AUTOMATIC 已提交
26
import modules.scripts
D
discus0434 已提交
27
import modules.shared as shared
A
AUTOMATIC 已提交
28
import modules.styles
D
discus0434 已提交
29
import modules.textual_inversion.ui
A
AUTOMATIC 已提交
30
from modules import prompt_parser
D
discus0434 已提交
31 32
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
33
from modules.textual_inversion import textual_inversion
34
import modules.hypernetworks.ui
Y
yfszzx 已提交
35
from modules.generation_parameters_copypaste import image_from_url_text
36
import modules.extras
37

A
AUTOMATIC 已提交
38 39
# Show UserWarnings only when the "show_warnings" option is enabled; otherwise silence them.
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

# Only when the server is local-only (neither --share nor --listen): stub out
# gradio's version check and make its local-IP lookup return loopback.
if not cmd_opts.share and not cmd_opts.listen:
    # fix gradio phoning home
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

# When --ngrok is given, open an ngrok tunnel at import time. cmd_opts.ngrok is
# passed as the first argument (the log line calls it the authtoken); the port
# falls back to 7860 when --port was not given.
if cmd_opts.ngrok is not None:
    import modules.ngrok as ngrok
    print('ngrok authtoken detected, trying to connect...')
    ngrok.connect(
        cmd_opts.ngrok,
        cmd_opts.port if cmd_opts.port is not None else 7860,
        cmd_opts.ngrok_options
        )
J
JamnedZ 已提交
57

58 59 60 61 62 63 64 65

def gr_show(visible=True):
    """Build a raw gradio update dict that only toggles component visibility.

    Args:
        visible: desired visibility of the target component.

    Returns:
        ``{"visible": visible, "__type__": "update"}``
    """
    update = dict(visible=visible, __type__="update")
    return update


# Path to a bundled sample input image; replaced by None when the file is not present on disk.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None

# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
# These strings are used as ToolButton labels throughout this module.
random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
reuse_symbol = '\u267b\ufe0f'  # ♻️
paste_symbol = '\u2199\ufe0f'  # ↙
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
apply_style_symbol = '\U0001f4cb'  # 📋
clear_prompt_symbol = '\U0001f5d1\ufe0f'  # 🗑️
extra_networks_symbol = '\U0001F3B4'  # 🎴
switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀
detect_image_size_symbol = '\U0001F4D0'  # 📐
79

80

81
def plaintext_to_html(text):
    """Convert plain *text* to HTML.

    Thin wrapper that delegates to ``ui_common.plaintext_to_html``.
    """
    html = ui_common.plaintext_to_html(text)
    return html
83

84 85 86 87 88 89 90

def send_gradio_gallery_to_image(x):
    """Return the first gallery entry decoded via ``image_from_url_text``.

    Args:
        x: the gallery value (a sized sequence); only its first element is used.

    Returns:
        The decoded first entry, or None when the gallery is empty.
    """
    if len(x):
        first_entry = x[0]
        return image_from_url_text(first_entry)
    return None


91 92
def add_style(name: str, prompt: str, negative_prompt: str):
    """Save a prompt style and refresh the style dropdowns.

    A style with the same name replaces the existing one. After adding it, all
    loaded styles are written back to ``shared.styles_filename``.

    Args:
        name: style name; when None nothing is saved.
        prompt: positive prompt text stored in the style.
        negative_prompt: negative prompt text stored in the style.

    Returns:
        A list of two updates, one per output component. Both branches must
        return the same number of values, because the event's outputs are fixed.
    """
    if name is None:
        # One no-op update per output, matching the two-element success path below.
        return [gr_show() for _ in range(2)]

    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
    shared.prompt_styles.styles[style.name] = style
    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
    # reserialize all styles every time we save them
    shared.prompt_styles.save_styles(shared.styles_filename)

    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
A
AUTOMATIC 已提交
102

103 104 105 106 107 108 109 110 111 112 113 114

def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
    """Describe the hires-fix resize as an HTML snippet.

    Builds a throwaway txt2img processing object with hires fix enabled and runs
    its ``init`` under autocast, so that it computes the target size.

    Returns:
        ``""`` when *enable* is falsy. Otherwise a string of the form
        "resize: from <WxH> to <WxH>", with each size wrapped in a
        ``resolution`` span.
    """
    from modules import processing, devices

    if not enable:
        return ""

    p = processing.StableDiffusionProcessingTxt2Img(
        width=width,
        height=height,
        enable_hr=True,
        hr_scale=hr_scale,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
    )

    with devices.autocast():
        p.init([""], [0], [0])

    # An explicit resize target wins over the computed upscale size.
    final_width = p.hr_resize_x or p.hr_upscale_to_x
    final_height = p.hr_resize_y or p.hr_upscale_to_y
    return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{final_width}x{final_height}</span>"
116

A
AUTOMATIC 已提交
117

118 119 120 121 122 123 124 125 126 127
def resize_from_to_html(width, height, scale_by):
    """Describe scaling ``width x height`` by *scale_by* as an HTML snippet.

    The target dimensions are truncated to int. If either target dimension is 0,
    the text "no image selected" is returned instead.
    """
    new_width, new_height = int(width * scale_by), int(height * scale_by)

    if new_width and new_height:
        return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{new_width}x{new_height}</span>"

    return "no image selected"


128 129 130
def apply_styles(prompt, prompt_neg, styles):
    """Merge the selected *styles* into both prompts and clear the style selection.

    Returns:
        Three updates: the new positive prompt, the new negative prompt, and an
        empty value for the styles dropdown.
    """
    style_db = shared.prompt_styles
    styled_prompt = style_db.apply_styles_to_prompt(prompt, styles)
    styled_negative = style_db.apply_negative_styles_to_prompt(prompt_neg, styles)

    return [
        gr.Textbox.update(value=styled_prompt),
        gr.Textbox.update(value=styled_negative),
        gr.Dropdown.update(value=[]),
    ]
A
AUTOMATIC 已提交
133 134


135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
    """Run *interrogation_function* on the input chosen by *mode*.

    Args:
        interrogation_function: callable that takes an image and returns a prompt.
        mode: selects the input source.
            * 0, 1, 3, 4: ``ii_singles[mode]`` is passed directly.
            * 2: ``ii_singles[mode]["image"]`` is passed.
            * 5: batch mode. Every file listed in *ii_input_dir* is interrogated,
              and the result is appended to ``<basename>.txt`` in *ii_output_dir*
              (or in *ii_input_dir* when *ii_output_dir* is "").
        ii_input_dir: batch input directory (mode 5 only).
        ii_output_dir: batch output directory (mode 5 only).
        *ii_singles: per-mode single inputs, indexed by *mode*.

    Returns:
        A two-element list ``[prompt_or_update, None]``. Batch mode and unknown
        modes return a no-op ``gr.update()`` as the first element.
    """
    if mode in {0, 1, 3, 4}:
        return [interrogation_function(ii_singles[mode]), None]
    elif mode == 2:
        return [interrogation_function(ii_singles[mode]["image"]), None]
    elif mode == 5:
        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
        images = shared.listfiles(ii_input_dir)
        print(f"Will process {len(images)} images.")
        if ii_output_dir != "":
            os.makedirs(ii_output_dir, exist_ok=True)
        else:
            ii_output_dir = ii_input_dir

        for image in images:
            left, _ = os.path.splitext(os.path.basename(image))
            # Close both the source image and the caption file promptly. Previously
            # each iteration leaked an open file handle. Write UTF-8 so tags with
            # non-ASCII characters don't fail under a non-UTF-8 locale encoding.
            with Image.open(image) as img:
                caption = interrogation_function(img)
            with open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf8') as caption_file:
                print(caption, file=caption_file)

    # Batch mode and any unrecognised mode: leave the outputs unchanged rather
    # than returning None for a two-output event.
    return [gr.update(), None]
156 157


A
AUTOMATIC 已提交
158
def interrogate(image):
    """Caption *image* with the shared interrogator, after converting it to RGB.

    Returns the prompt. When the interrogator yields None, returns a no-op
    ``gr.update()`` instead.
    """
    rgb_image = image.convert("RGB")
    caption = shared.interrogator.interrogate(rgb_image)
    if caption is None:
        return gr.update()
    return caption
A
AUTOMATIC 已提交
161

A
AUTOMATIC 已提交
162

G
Greendayle 已提交
163
def interrogate_deepbooru(image):
    """Tag *image* with ``deepbooru.model``.

    Returns the tag string. When the model yields None, returns a no-op
    ``gr.update()`` instead.
    """
    tags = deepbooru.model.tag(image)
    if tags is None:
        return gr.update()
    return tags
G
Greendayle 已提交
166 167


168
def create_seed_inputs(target_interface):
    """Build the seed and variation-seed controls for one tab.

    Args:
        target_interface: tab prefix (e.g. 'txt2img'). It is used in every
            element id and in the JS ``setRandomSeed`` target ids.

    Returns:
        (seed, reuse_seed, subseed, reuse_subseed, subseed_strength,
         seed_resize_from_h, seed_resize_from_w, seed_checkbox).
        Note that height comes before width here.
    """
    with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
        # --use-textbox-seed swaps the numeric input for a text box.
        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
        seed.style(container=False)
        random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
        reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')

        seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)

    # Components to show/hide based on the 'Extra' checkbox
    seed_extras = []

    # Both extra rows start hidden; the 'Extra' checkbox toggles them below.
    with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
        seed_extras.append(seed_extra_row_1)
        subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
        subseed.style(container=False)
        random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
        reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")

    with FormRow(visible=False) as seed_extra_row_2:
        seed_extras.append(seed_extra_row_2)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")

    # Randomising is done client-side by the JS setRandomSeed helper; fn=None means no Python callback runs.
    random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[])
    random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[])

    def change_visibility(show):
        # Map every extra row to the same visibility update.
        return {comp: gr_show(show) for comp in seed_extras}

    seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)

    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
202 203


204

P
papuSpartan 已提交
205
def connect_clear_prompt(button):
    """Wire *button*'s click event to the client-side ``clear_prompt`` JS function.

    No Python callback runs (``fn=None``), and the event has no inputs or outputs.
    Any clearing happens entirely in the browser.

    Args:
        button: the component whose ``click`` event is wired.
    """
    button.click(
        _js="clear_prompt",
        fn=None,
        inputs=[],
        outputs=[],
    )
213 214


215 216 217
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
    """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
        (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
        was 0, i.e. no variation seed was used, it copies the normal seed value instead."""

    def pick(values, i):
        # Use the selected entry when i is in range, otherwise the first entry.
        # An empty list yields -1 instead of raising IndexError.
        if not values:
            return -1
        return values[i if 0 <= i < len(values) else 0]

    def copy_seed(gen_info_string: str, index):
        # -1 is returned when generation info is missing or unparsable.
        res = -1

        try:
            gen_info = json.loads(gen_info_string)
            # Make the gallery index relative to the first generated image. The
            # gallery presumably can hold extra leading images; verify against
            # the producer of 'index_of_first_image'.
            index -= gen_info.get('index_of_first_image', 0)

            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
                res = pick(gen_info.get('all_subseeds', [-1]), index)
            else:
                res = pick(gen_info.get('all_seeds', [-1]), index)

        except json.decoder.JSONDecodeError:
            if gen_info_string != '':
                print("Error parsing JSON generation info:", file=sys.stderr)
                print(gen_info_string, file=sys.stderr)

        # The second value hides the dummy component that carried the index.
        return [res, gr_show(False)]

    reuse_seed.click(
        fn=copy_seed,
        # The JS swaps the dummy input for the currently selected gallery index.
        _js="(x, y) => [x, selected_gallery_index()]",
        show_progress=False,
        inputs=[generation_info, dummy_component],
        outputs=[seed, dummy_component]
    )

248

L
Liam 已提交
249
def update_token_counter(text, steps):
    """Return an HTML badge "<tokens>/<limit>" for the longest prompt variant in *text*.

    Extra-network syntax is stripped first. The prompt is then expanded into its
    per-step schedules, and the variant with the highest token count is reported.

    Args:
        text: raw prompt text as typed by the user.
        steps: number of sampling steps used to expand prompt schedules.
    """
    try:
        text, _ = extra_networks.parse_prompt(text)

        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)

    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        prompt_schedules = [[[steps, text]]]

    # Flatten the list of schedules (each a list of [step, prompt] pairs) into
    # one list of prompt texts. This avoids the quadratic list concatenation of
    # reduce(list1 + list2).
    prompts = [prompt_text for schedule in prompt_schedules for _step, prompt_text in schedule]
    token_count, max_length = max((model_hijack.get_prompt_lengths(prompt) for prompt in prompts), key=lambda args: args[0])
    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
A
AUTOMATIC 已提交
265

266

A
AUTOMATIC 已提交
267
def create_toprow(is_img2img):
    """Build the top row of a generation tab.

    The row holds the prompt boxes, an optional interrogate column, the
    generate/skip/interrupt buttons, the tool buttons, the token counters and
    the styles dropdown.

    Args:
        is_img2img: True for the img2img tab. It selects the element-id prefix
            and adds the CLIP/DeepBooru interrogate buttons.

    Returns:
        A 15-tuple: (prompt, prompt_styles, negative_prompt, submit,
        button_interrogate, button_deepbooru, prompt_style_apply, save_style,
        paste, extra_networks_button, token_counter, token_button,
        negative_token_counter, negative_token_button, restore_progress_button).
        The two interrogate buttons are None on txt2img.
    """
    id_part = "img2img" if is_img2img else "txt2img"

    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")

            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")

        # Interrogate buttons only exist on img2img; txt2img returns None for them.
        button_interrogate = None
        button_deepbooru = None
        if is_img2img:
            with gr.Column(scale=1, elem_classes="interrogate-col"):
                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")

        with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
            with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
                skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')

                # Skip / Interrupt just flip flags on the shared job state.
                skip.click(
                    fn=lambda: shared.state.skip(),
                    inputs=[],
                    outputs=[],
                )

                interrupt.click(
                    fn=lambda: shared.state.interrupt(),
                    inputs=[],
                    outputs=[],
                )

            with gr.Row(elem_id=f"{id_part}_tools"):
                paste = ToolButton(value=paste_symbol, elem_id="paste")
                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
                extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
                restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)

                # Token counters start at "0/75"; the hidden buttons are returned to the caller (wiring not visible here).
                token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
                token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
                negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")

                # The JS hook confirm_clear_prompt runs first; the Python fn echoes its
                # returned values back into both prompt boxes (presumably cleared or unchanged).
                clear_prompt_button.click(
                    fn=lambda *x: x,
                    _js="confirm_clear_prompt",
                    inputs=[prompt, negative_prompt],
                    outputs=[prompt, negative_prompt],
                )

            with gr.Row(elem_id=f"{id_part}_styles_row"):
                prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
                create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")

    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
A
AUTOMATIC 已提交
332 333


334
def setup_progressbar(*args, **kwargs):
    """No-op: accepts any positional/keyword arguments, does nothing, and returns None.

    Presumably retained so that older callers keep working; verify before removing.
    """
    return None
A
AUTOMATIC 已提交
336 337


338 339 340 341
def apply_setting(key, value):
    """Store *value* into the setting *key* and save the config.

    Args:
        key: name of the option in ``opts.data_labels``.
        value: new value. It is converted to the type of the option's default,
            unless that default is None.

    Returns:
        ``gr.update()`` (no change) when the value is not applied. That happens
        when *value* is None, settings are frozen, checkpoint auto-swap is
        disabled, the checkpoint is not found, or the setting's component is
        hidden. Otherwise returns the setting's new value as read back from
        ``opts``.
    """
    if value is None:
        return gr.update()

    if shared.cmd_opts.freeze_settings:
        return gr.update()

    # dont allow model to be swapped when model hash exists in prompt
    if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
        return gr.update()

    if key == "sd_model_checkpoint":
        ckpt_info = sd_models.get_closet_checkpoint_match(value)

        if ckpt_info is not None:
            value = ckpt_info.title
        else:
            return gr.update()

    comp_args = opts.data_labels[key].component_args
    if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
        # Hidden settings are not applied. Return a no-op update like the other
        # early exits rather than a bare None.
        return gr.update()

    valtype = type(opts.data_labels[key].default)
    oldval = opts.data.get(key, None)
    newval = valtype(value) if valtype != type(None) else value
    opts.data[key] = newval
    # Compare against the converted value, so that e.g. "5" vs 5 does not fire
    # onchange for a setting that did not actually change.
    if oldval != newval and opts.data_labels[key].onchange is not None:
        opts.data_labels[key].onchange()

    opts.save(shared.config_filename)
    return getattr(opts, key)
369

370

371 372 373 374
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    """Create a refresh ToolButton that reloads data and updates *refresh_component*.

    Args:
        refresh_component: component whose properties are refreshed.
        refresh_method: zero-argument callable run first (e.g. a reload).
        refreshed_args: a dict of properties to apply, or a zero-argument
            callable returning one. None (or an empty result) means "nothing
            to update".
        elem_id: element id of the new button.

    Returns:
        The created refresh button.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # Normalise before iterating. Previously a None result crashed on
        # .items() before the later `or {}` fallback could apply.
        args = args or {}

        # Mirror the new properties onto the Python-side component object as well
        # as returning them as an update.
        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
388 389


390
def create_output_panel(tabname, outdir):
    """Create the output panel for *tabname*.

    This is a thin wrapper over ``ui_common.create_output_panel``. The txt2img
    call site unpacks the result as (gallery, generation_info, html_info, html_log).
    """
    panel = ui_common.create_output_panel(tabname, outdir)
    return panel
J
Justin Maier 已提交
392

A
aoirusann 已提交
393

394 395
def create_sampler_and_steps_selection(choices, tabname):
    """Create the sampler picker and the sampling-steps slider for *tabname*.

    The layout depends on ``opts.samplers_in_dropdown``:

    * True: the sampler Dropdown comes before the steps slider, inside a FormRow.
    * False: the sampler Radio comes after the steps slider, inside a FormGroup.

    Both pickers use ``type="index"``, list ``x.name`` for each entry of
    *choices*, and default to the first entry.

    Returns:
        (steps, sampler_index)
    """
    if opts.samplers_in_dropdown:
        with FormRow(elem_id=f"sampler_selection_{tabname}"):
            sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
    else:
        with FormGroup(elem_id=f"sampler_selection_{tabname}"):
            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
            sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")

    return steps, sampler_index
405

406

407
def ordered_ui_categories():
    """Yield the names in ``shared.ui_reorder_categories`` in display order.

    Sort keys work as follows:

    * A category at built-in position j gets the even key 2*j.
    * A category named at position i of the comma-separated
      ``opts.ui_reorder`` setting gets the odd key 2*i + 1 instead.

    The sort is stable, so ties keep their built-in order.
    """
    user_order = {}
    for position, name in enumerate(shared.opts.ui_reorder.split(",")):
        user_order[name.strip()] = position * 2 + 1

    def sort_key(indexed_category):
        builtin_position, name = indexed_category
        return user_order.get(name, builtin_position * 2)

    ranked = sorted(enumerate(shared.ui_reorder_categories), key=sort_key)
    for _, category in ranked:
        yield category


414 415 416 417 418 419 420 421 422 423
def get_value_for_setting(key):
    """Return a gradio update for the setting *key*.

    The update carries the setting's current value plus its component args. The
    args come from a dict or from a zero-argument factory, and ``precision`` is
    dropped from them.
    """
    value = getattr(opts, key)

    info = opts.data_labels[key]
    args = info.component_args() if callable(info.component_args) else info.component_args
    # Normalise both branches. Previously a factory returning None crashed on
    # .items(), because the `or {}` fallback only applied to the non-callable case.
    args = args or {}
    args = {k: v for k, v in args.items() if k not in {'precision'}}

    return gr.update(value=value, **args)


424 425 426 427 428 429 430 431 432 433 434 435
def create_override_settings_dropdown(tabname, row):
    """Create the hidden multiselect "Override settings" dropdown for *tabname*.

    The dropdown starts hidden, and its change handler makes it visible exactly
    when at least one entry is selected. *row* is accepted but not used here.
    """
    def show_when_nonempty(selected):
        return gr.Dropdown.update(visible=len(selected) > 0)

    dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
    dropdown.change(fn=show_when_nonempty, inputs=[dropdown], outputs=[dropdown])

    return dropdown


436
def create_ui():
437 438
    import modules.img2img
    import modules.txt2img
439

440 441
    reload_javascript()

442
    parameters_copypaste.reset()
443

444 445 446
    modules.scripts.scripts_current = modules.scripts.scripts_txt2img
    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)

447
    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
448
        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
P
papuSpartan 已提交
449

450
        dummy_component = gr.Label(visible=False)
A
AUTOMATIC 已提交
451
        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
452

A
AUTOMATIC 已提交
453 454 455 456
        with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
            from modules import ui_extra_networks
            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')

457
        with gr.Row().style(equal_height=False):
A
AUTOMATIC 已提交
458
            with gr.Column(variant='compact', elem_id="txt2img_settings"):
459 460 461
                for category in ordered_ui_categories():
                    if category == "sampler":
                        steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
A
AUTOMATIC 已提交
462

463 464 465 466 467 468
                    elif category == "dimensions":
                        with FormRow():
                            with gr.Column(elem_id="txt2img_column_size", scale=4):
                                width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
                                height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")

A
AUTOMATIC 已提交
469
                            with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
470
                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
A
AUTOMATIC 已提交
471

472 473 474 475 476 477 478 479 480 481 482 483
                            if opts.dimensions_and_batch_together:
                                with gr.Column(elem_id="txt2img_column_batch"):
                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")

                    elif category == "cfg":
                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")

                    elif category == "seed":
                        seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')

                    elif category == "checkboxes":
A
AUTOMATIC 已提交
484
                        with FormRow(elem_classes="checkboxes-row", variant="compact"):
485 486 487
                            restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
                            tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
                            enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
488
                            hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
489 490

                    elif category == "hires_fix":
491
                        with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
492
                            with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
493 494 495 496
                                hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
                                hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
                                denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")

497
                            with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
498 499 500
                                hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
                                hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
                                hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
501 502 503 504 505 506 507

                    elif category == "batch":
                        if not opts.dimensions_and_batch_together:
                            with FormRow(elem_id="txt2img_column_batch"):
                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")

508 509 510 511
                    elif category == "override_settings":
                        with FormRow(elem_id="txt2img_override_settings_row") as row:
                            override_settings = create_override_settings_dropdown('txt2img', row)

512 513 514
                    elif category == "scripts":
                        with FormGroup(elem_id="txt2img_script_container"):
                            custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
515

516 517
            hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
            for input in hr_resolution_preview_inputs:
518 519 520 521 522 523 524 525 526 527 528 529 530
                input.change(
                    fn=calc_resolution_hires,
                    inputs=hr_resolution_preview_inputs,
                    outputs=[hr_final_resolution],
                    show_progress=False,
                )
                input.change(
                    None,
                    _js="onCalcResolutionHires",
                    inputs=hr_resolution_preview_inputs,
                    outputs=[],
                    show_progress=False,
                )
531

532
            txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
533

534 535
            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
536

537
            txt2img_args = dict(
538
                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
A
AUTOMATIC 已提交
539
                _js="submit",
540
                inputs=[
541
                    dummy_component,
A
AUTOMATIC 已提交
542
                    txt2img_prompt,
543
                    txt2img_negative_prompt,
544
                    txt2img_prompt_styles,
545 546
                    steps,
                    sampler_index,
A
AUTOMATIC 已提交
547
                    restore_faces,
548
                    tiling,
549 550 551 552
                    batch_count,
                    batch_size,
                    cfg_scale,
                    seed,
553
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
554 555
                    height,
                    width,
A
AUTOMATIC 已提交
556 557
                    enable_hr,
                    denoising_strength,
A
AUTOMATIC 已提交
558 559
                    hr_scale,
                    hr_upscaler,
560 561 562
                    hr_second_pass_steps,
                    hr_resize_x,
                    hr_resize_y,
563
                    override_settings,
A
AUTOMATIC 已提交
564
                ] + custom_inputs,
565

566 567 568
                outputs=[
                    txt2img_gallery,
                    generation_info,
569 570
                    html_info,
                    html_log,
571 572
                ],
                show_progress=False,
573 574
            )

A
AUTOMATIC 已提交
575
            txt2img_prompt.submit(**txt2img_args)
576
            submit.click(**txt2img_args)
A
AUTOMATIC 已提交
577

578
            res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
579

580 581 582 583 584 585 586 587 588 589 590 591 592
            restore_progress_button.click(
                fn=progress.restore_progress,
                _js="restoreProgressTxt2img",
                inputs=[dummy_component],
                outputs=[
                    txt2img_gallery,
                    generation_info,
                    html_info,
                    html_log,
                ],
                show_progress=False,
            )

D
d8ahazard 已提交
593 594 595 596 597 598 599 600 601 602 603
            txt_prompt_img.change(
                fn=modules.images.image_data,
                inputs=[
                    txt_prompt_img
                ],
                outputs=[
                    txt2img_prompt,
                    txt_prompt_img
                ]
            )

A
AUTOMATIC 已提交
604 605 606 607
            enable_hr.change(
                fn=lambda x: gr_show(x),
                inputs=[enable_hr],
                outputs=[hr_options],
608
                show_progress = False,
A
AUTOMATIC 已提交
609 610
            )

611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628
            txt2img_paste_fields = [
                (txt2img_prompt, "Prompt"),
                (txt2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
                (enable_hr, lambda d: "Denoising strength" in d),
                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
A
AUTOMATIC 已提交
629 630
                (hr_scale, "Hires upscale"),
                (hr_upscaler, "Hires upscaler"),
631 632 633
                (hr_second_pass_steps, "Hires steps"),
                (hr_resize_x, "Hires resize-1"),
                (hr_resize_y, "Hires resize-2"),
634
                *modules.scripts.scripts_txt2img.infotext_fields
635
            ]
636
            parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
637
            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
638
                paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
639
            ))
640 641 642 643 644 645 646 647 648 649 650 651

            txt2img_preview_params = [
                txt2img_prompt,
                txt2img_negative_prompt,
                steps,
                sampler_index,
                cfg_scale,
                seed,
                width,
                height,
            ]

652
            token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
653
            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
654

A
AUTOMATIC 已提交
655 656
            ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)

657 658
    modules.scripts.scripts_current = modules.scripts.scripts_img2img
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
659

660
    with gr.Blocks(analytics_enabled=False) as img2img_interface:
661
        img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
662

A
AUTOMATIC 已提交
663
        img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
664

A
AUTOMATIC 已提交
665 666 667 668
        with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
            from modules import ui_extra_networks
            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')

669
        with FormRow().style(equal_height=False):
A
AUTOMATIC 已提交
670
            with gr.Column(variant='compact', elem_id="img2img_settings"):
671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686
                copy_image_buttons = []
                copy_image_destinations = {}

                def add_copy_image_controls(tab_name, elem):
                    """Render a "Copy image to:" row of buttons beneath an img2img image input.

                    One button is created per img2img image tab, in a fixed order. The button for
                    ``tab_name`` itself is disabled, and ``elem`` is recorded in
                    ``copy_image_destinations`` as that tab's target component. Every other button
                    is queued in ``copy_image_buttons`` as ``(button, target_tab, elem)`` so its
                    click handler can be connected once all tabs exist.
                    """
                    # (button label, internal tab name) pairs, in display order.
                    destinations = (
                        ('img2img', 'img2img'),
                        ('sketch', 'sketch'),
                        ('inpaint', 'inpaint'),
                        ('inpaint sketch', 'inpaint_sketch'),
                    )

                    with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
                        gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")

                        for label, target in destinations:
                            if target != tab_name:
                                copy_image_buttons.append((gr.Button(label), target, elem))
                            else:
                                # Copying onto the current tab is meaningless: show it disabled.
                                gr.Button(label, interactive=False)
                                copy_image_destinations[target] = elem

687
                with gr.Tabs(elem_id="mode_img2img"):
688 689
                    img2img_selected_tab = gr.State(0)

690
                    with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
691
                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
692
                        add_copy_image_controls('img2img', init_img)
693

694
                    with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
695
                        sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
696
                        add_copy_image_controls('sketch', sketch)
697

698
                    with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
699
                        init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
700
                        add_copy_image_controls('inpaint', init_img_with_mask)
701

702
                    with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
703
                        inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
704
                        inpaint_color_sketch_orig = gr.State(None)
705
                        add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
706

707 708 709 710 711 712
                        def update_orig(image, state):
                            """Choose which image to keep as the stored original for the inpaint sketch tab.

                            ``state`` is the previously stored image (or None), and ``image`` is the
                            editor's new value. The new image counts as an edit of the stored one
                            when both have the same size and at least one pixel matches exactly in
                            every channel. In that case the stored ``state`` is kept; otherwise
                            ``image`` becomes the new stored original.

                            Returns None when ``image`` is None.
                            """
                            if image is None:
                                return None

                            # Comparing pixel arrays of different shapes fails in numpy (broadcast
                            # error, or a scalar False that breaks ``axis=-1``), so a size mismatch
                            # must be resolved before any per-pixel comparison.
                            if state is None or state.size != image.size:
                                return image

                            # axis=-1 assumes (H, W, C) arrays as produced from PIL images.
                            # A pixel identical in all channels presumably means the new value
                            # is a painted-over version of the stored image.
                            has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
                            return state if has_exact_match else image
713

714
                        inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
715

716 717 718
                    with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
                        init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
                        init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
719

720
                    with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
721
                        hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
A
Andrii Skaliuk 已提交
722
                        gr.HTML(
A
AUTOMATIC 已提交
723 724
                            "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
                            "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
A
Andrii Skaliuk 已提交
725 726 727
                            f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
                            f"{hidden}</p>"
                        )
728 729
                        img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
                        img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
A
Andrii Skaliuk 已提交
730
                        img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
731

732 733 734 735 736
                    img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]

                    for i, tab in enumerate(img2img_tabs):
                        tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])

737 738 739 740 741 742 743 744 745 746 747 748 749 750
                def copy_image(img):
                    """Extract the image to copy into another img2img tab.

                    When the source component delivers a dict holding the picture under
                    ``'image'``, that entry is returned; any other value is returned unchanged.
                    """
                    is_editor_payload = isinstance(img, dict) and 'image' in img
                    return img['image'] if is_editor_payload else img

                for button, name, elem in copy_image_buttons:
                    button.click(
                        fn=copy_image,
                        inputs=[elem],
                        outputs=[copy_image_destinations[name]],
                    )
                    button.click(
                        fn=lambda: None,
751
                        _js=f"switch_to_{name.replace(' ', '_')}",
752 753 754 755
                        inputs=[],
                        outputs=[],
                    )

756
                with FormRow():
757
                    resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
758

759 760 761
                for category in ordered_ui_categories():
                    if category == "sampler":
                        steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
A
AUTOMATIC 已提交
762

763 764 765
                    elif category == "dimensions":
                        with FormRow():
                            with gr.Column(elem_id="img2img_column_size", scale=4):
766 767 768 769
                                selected_scale_tab = gr.State(value=0)

                                with gr.Tabs():
                                    with gr.Tab(label="Resize to") as tab_scale_to:
A
AUTOMATIC1111 已提交
770 771 772 773 774 775
                                        with FormRow():
                                            with gr.Column(elem_id="img2img_column_size", scale=4):
                                                width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
                                                height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
                                            with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
                                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
776
                                                detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
777 778 779 780 781 782 783

                                    with gr.Tab(label="Resize by") as tab_scale_by:
                                        scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")

                                        with FormRow():
                                            scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview")
                                            gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider")
784
                                            button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to")
785

786
                                    on_change_args = dict(
787 788 789 790 791 792 793
                                        fn=resize_from_to_html,
                                        _js="currentImg2imgSourceResolution",
                                        inputs=[dummy_component, dummy_component, scale_by],
                                        outputs=scale_by_html,
                                        show_progress=False,
                                    )

794 795 796
                                    scale_by.release(**on_change_args)
                                    button_update_resize_to.click(**on_change_args)

797 798 799 800
                                    # the code below is meant to update the resolution label after the image in the image selection UI has changed.
                                    # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
                                    # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
                                    for component in [init_img, sketch]:
801 802
                                        component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)

803 804
                            tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
                            tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
A
AUTOMATIC 已提交
805

806 807 808 809
                            if opts.dimensions_and_batch_together:
                                with gr.Column(elem_id="img2img_column_batch"):
                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
810

811 812
                    elif category == "cfg":
                        with FormGroup():
813 814
                            with FormRow():
                                cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
815
                                image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
816
                            denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
817

818 819
                    elif category == "seed":
                        seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
820

821
                    elif category == "checkboxes":
A
AUTOMATIC 已提交
822
                        with FormRow(elem_classes="checkboxes-row", variant="compact"):
823 824
                            restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
                            tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
825

826 827 828 829 830
                    elif category == "batch":
                        if not opts.dimensions_and_batch_together:
                            with FormRow(elem_id="img2img_column_batch"):
                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
A
AUTOMATIC 已提交
831

832 833 834 835
                    elif category == "override_settings":
                        with FormRow(elem_id="img2img_override_settings_row") as row:
                            override_settings = create_override_settings_dropdown('img2img', row)

836 837 838
                    elif category == "scripts":
                        with FormGroup(elem_id="img2img_script_container"):
                            custom_inputs = modules.scripts.scripts_img2img.setup_ui()
A
AUTOMATIC 已提交
839

A
AUTOMATIC 已提交
840
                    elif category == "inpaint":
841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861
                        with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
                            with FormRow():
                                mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
                                mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")

                            with FormRow():
                                inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")

                            with FormRow():
                                inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")

                            with FormRow():
                                with gr.Column():
                                    inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")

                                with gr.Column(scale=4):
                                    inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")

                            def select_img2img_tab(tab):
                                """Return visibility updates for (inpaint_controls, mask_alpha) after a tab switch.

                                ``tab`` is an index into ``img2img_tabs``. The inpaint, inpaint sketch and
                                inpaint upload tabs (2, 3, 4) show the inpaint controls, and only the
                                inpaint sketch tab (3) shows the mask transparency slider.
                                """
                                show_inpaint_controls = tab in (2, 3, 4)
                                show_mask_alpha = tab == 3
                                return gr.update(visible=show_inpaint_controls), gr.update(visible=show_mask_alpha)

862
                            for i, elem in enumerate(img2img_tabs):
863 864 865 866 867 868
                                elem.select(
                                    fn=lambda tab=i: select_img2img_tab(tab),
                                    inputs=[],
                                    outputs=[inpaint_controls, mask_alpha],
                                )

869
            img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
870

871 872
            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
873

D
d8ahazard 已提交
874 875 876
            img2img_prompt_img.change(
                fn=modules.images.image_data,
                inputs=[
A
AUTOMATIC 已提交
877
                    img2img_prompt_img
D
d8ahazard 已提交
878 879 880 881 882 883 884
                ],
                outputs=[
                    img2img_prompt,
                    img2img_prompt_img
                ]
            )

885
            img2img_args = dict(
886
                fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
887
                _js="submit_img2img",
888
                inputs=[
889
                    dummy_component,
890
                    dummy_component,
A
AUTOMATIC 已提交
891
                    img2img_prompt,
892
                    img2img_negative_prompt,
893
                    img2img_prompt_styles,
894
                    init_img,
895
                    sketch,
896
                    init_img_with_mask,
897 898
                    inpaint_color_sketch,
                    inpaint_color_sketch_orig,
899 900
                    init_img_inpaint,
                    init_mask_inpaint,
901 902 903
                    steps,
                    sampler_index,
                    mask_blur,
904
                    mask_alpha,
905
                    inpainting_fill,
A
AUTOMATIC 已提交
906
                    restore_faces,
907
                    tiling,
908 909 910
                    batch_count,
                    batch_size,
                    cfg_scale,
K
Kyle 已提交
911
                    image_cfg_scale,
912 913
                    denoising_strength,
                    seed,
914
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
915
                    selected_scale_tab,
916 917
                    height,
                    width,
918
                    scale_by,
919 920
                    resize_mode,
                    inpaint_full_res,
921
                    inpaint_full_res_padding,
A
AUTOMATIC 已提交
922
                    inpainting_mask_invert,
923 924
                    img2img_batch_input_dir,
                    img2img_batch_output_dir,
925 926
                    img2img_batch_inpaint_mask_dir,
                    override_settings,
A
AUTOMATIC 已提交
927
                ] + custom_inputs,
928 929 930
                outputs=[
                    img2img_gallery,
                    generation_info,
931 932
                    html_info,
                    html_log,
933 934
                ],
                show_progress=False,
935 936
            )

937 938 939 940 941 942 943 944 945 946 947 948 949 950 951
            interrogate_args = dict(
                _js="get_img2img_tab_index",
                inputs=[
                    dummy_component,
                    img2img_batch_input_dir,
                    img2img_batch_output_dir,
                    init_img,
                    sketch,
                    init_img_with_mask,
                    inpaint_color_sketch,
                    init_img_inpaint,
                ],
                outputs=[img2img_prompt, dummy_component],
            )

A
AUTOMATIC 已提交
952
            img2img_prompt.submit(**img2img_args)
953
            submit.click(**img2img_args)
954

955
            res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
956

957 958 959 960 961 962 963
            detect_image_size_btn.click(
                fn=lambda w, h, _: (w or gr.update(), h or gr.update()),
                _js="currentImg2imgSourceResolution",
                inputs=[dummy_component, dummy_component, dummy_component],
                outputs=[width, height],
                show_progress=False,
            )
964

965 966 967 968 969 970 971 972 973 974 975 976
            restore_progress_button.click(
                fn=progress.restore_progress,
                _js="restoreProgressImg2img",
                inputs=[dummy_component],
                outputs=[
                    img2img_gallery,
                    generation_info,
                    html_info,
                    html_log,
                ],
                show_progress=False,
            )
A
AUTOMATIC 已提交
977

A
AUTOMATIC 已提交
978
            img2img_interrogate.click(
979
                fn=lambda *args: process_interrogate(interrogate, *args),
980
                **interrogate_args,
A
AUTOMATIC 已提交
981 982
            )

A
AUTOMATIC 已提交
983
            img2img_deepbooru.click(
984
                fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
985
                **interrogate_args,
A
AUTOMATIC 已提交
986 987 988
            )

            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
989
            style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
L
Liam 已提交
990
            style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
A
AUTOMATIC 已提交
991 992

            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
A
AUTOMATIC 已提交
993 994 995
                button.click(
                    fn=add_style,
                    _js="ask_for_style_name",
996 997 998
                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
                    inputs=[dummy_component, prompt, negative_prompt],
999
                    outputs=[txt2img_prompt_styles, img2img_prompt_styles],
A
AUTOMATIC 已提交
1000 1001
                )

1002
            for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
A
AUTOMATIC 已提交
1003 1004
                button.click(
                    fn=apply_styles,
1005
                    _js=js_func,
1006 1007
                    inputs=[prompt, negative_prompt, styles],
                    outputs=[prompt, negative_prompt, styles],
A
AUTOMATIC 已提交
1008 1009
                )

Y
yfszzx 已提交
1010
            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
X
xSinStarx 已提交
1011
            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
Y
yfszzx 已提交
1012

A
AUTOMATIC 已提交
1013 1014
            ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)

1015 1016 1017 1018 1019 1020 1021
            img2img_paste_fields = [
                (img2img_prompt, "Prompt"),
                (img2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
K
Kyle 已提交
1022
                (image_cfg_scale, "Image CFG scale"),
1023 1024 1025 1026 1027 1028 1029 1030 1031
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
A
AUTOMATIC 已提交
1032
                (mask_blur, "Mask blur"),
1033
                *modules.scripts.scripts_img2img.infotext_fields
1034
            ]
1035 1036
            parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
            parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
1037
            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
1038
                paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
1039
            ))
1040

1041
    modules.scripts.scripts_current = None
1042

1043
    with gr.Blocks(analytics_enabled=False) as extras_interface:
1044
        ui_postprocessing.create_ui()
1045

1046 1047 1048 1049 1050 1051 1052
    with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")

            with gr.Column(variant='panel'):
                html = gr.HTML()
1053
                generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
1054 1055
                html2 = gr.HTML()
                with gr.Row():
Y
yfszzx 已提交
1056
                    buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
1057 1058 1059 1060 1061

                for tabname, button in buttons.items():
                    parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
                        paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
                    ))
1062 1063

        image.change(
1064
            fn=wrap_gradio_call(modules.extras.run_pnginfo),
1065 1066 1067
            inputs=[image],
            outputs=[html, generation_info, html2],
        )
1068

1069 1070 1071 1072 1073 1074 1075 1076 1077
    def update_interp_description(value):
        """Return the HTML paragraph describing a model-merge interpolation method.

        ``value`` must be "No interpolation", "Weighted sum" or "Add difference";
        any other value raises ``KeyError``.
        """
        descriptions = {
            "No interpolation": "No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking.",
            "Weighted sum": "A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M",
            "Add difference": "The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M"
        }
        # Look up first so an unknown method raises KeyError before any formatting.
        description = descriptions[value]
        return f"<p style='margin-bottom: 2.5em'>{description}</p>"

1078
    with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
1079
        with gr.Row().style(equal_height=False):
A
AUTOMATIC 已提交
1080
            with gr.Column(variant='compact'):
1081
                interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
1082

1083
                with FormRow(elem_id="modelmerger_models"):
1084
                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
1085 1086
                    create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")

1087
                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
1088 1089
                    create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")

1090
                    tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
1091 1092
                    create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")

1093 1094
                custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
1095
                interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
1096
                interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
1097

1098
                with FormRow():
1099
                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
1100
                    save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
1101
                    save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
1102

1103 1104 1105 1106 1107 1108 1109 1110
                with FormRow():
                    with gr.Column():
                        config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")

                    with gr.Column():
                        with FormRow():
                            bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
                            create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
1111

1112 1113 1114
                with FormRow():
                    discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")

A
AUTOMATIC 已提交
1115 1116
                with gr.Row():
                    modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
1117

A
AUTOMATIC 已提交
1118 1119 1120
            with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
                with gr.Group(elem_id="modelmerger_results_panel"):
                    modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
1121

1122
    with gr.Blocks(analytics_enabled=False) as train_interface:
1123
        with gr.Row().style(equal_height=False):
A
AUTOMATIC 已提交
1124
            gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
1125

A
AUTOMATIC 已提交
1126
        with gr.Row(variant="compact").style(equal_height=False):
A
AUTOMATIC 已提交
1127
            with gr.Tabs(elem_id="train_tabs"):
1128

1129
                with gr.Tab(label="Create embedding", id="create_embedding"):
1130 1131 1132 1133
                    new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
                    initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
                    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
                    overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
1134 1135 1136 1137 1138 1139

                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
1140
                            create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
1141

1142
                with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
1143 1144 1145 1146 1147 1148 1149
                    new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
                    new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
                    new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
                    new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                    new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                    new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
A
aria1th 已提交
1150
                    new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
1151
                    overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
A
AUTOMATIC 已提交
1152 1153 1154 1155 1156 1157

                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
1158
                            create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
1159

1160
                with gr.Tab(label="Preprocess images", id="preprocess_images"):
1161 1162 1163 1164 1165
                    process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
                    process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
                    process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
                    process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
                    preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
1166 1167

                    with gr.Row():
1168
                        process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
1169 1170 1171
                        process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
                        process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
                        process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
D
dan 已提交
1172
                        process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
1173 1174
                        process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
1175

1176
                    with gr.Row(visible=False) as process_split_extra_row:
1177 1178
                        process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
                        process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
1179

C
captin411 已提交
1180
                    with gr.Row(visible=False) as process_focal_crop_row:
1181 1182 1183 1184
                        process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
                        process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
                        process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
                        process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
1185

D
dan 已提交
1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196
                    with gr.Column(visible=False) as process_multicrop_col:
                        gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
                        with gr.Row():
                            process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
                            process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
                        with gr.Row():
                            process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
                            process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
                        with gr.Row():
                            process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
                            process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
1197

1198 1199 1200 1201 1202
                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
S
space-nuko 已提交
1203
                            with gr.Row():
1204 1205
                                interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
                            run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
1206

1207 1208 1209 1210 1211 1212
                    process_split.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_split],
                        outputs=[process_split_extra_row],
                    )

C
captin411 已提交
1213 1214 1215 1216 1217 1218
                    process_focal_crop.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_focal_crop],
                        outputs=[process_focal_crop_row],
                    )

D
dan 已提交
1219 1220 1221 1222 1223 1224
                    process_multicrop.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_multicrop],
                        outputs=[process_multicrop_col],
                    )

1225
                def get_textual_inversion_template_names():
                    """Return the textual-inversion prompt template names, sorted ascending.

                    Used as the choices of the "Prompt template" dropdown on the Train tab.
                    """
                    template_names = list(textual_inversion.textual_inversion_templates)
                    template_names.sort()
                    return template_names
1227

1228
                with gr.Tab(label="Train", id="train"):
D
DepFA 已提交
1229
                    gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
A
AUTOMATIC 已提交
1230
                    with FormRow():
A
AUTOMATIC 已提交
1231
                        train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
1232
                        create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
A
AUTOMATIC 已提交
1233

A
AUTOMATIC 已提交
1234 1235
                        train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
                        create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
A
AUTOMATIC 已提交
1236 1237

                    with FormRow():
1238 1239
                        embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
                        hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
1240

A
AUTOMATIC 已提交
1241
                    with FormRow():
1242
                        clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
1243
                        clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
1244

A
AUTOMATIC 已提交
1245 1246 1247 1248
                    with FormRow():
                        batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
                        gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")

1249 1250
                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
                    log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
1251 1252 1253 1254 1255

                    with FormRow():
                        template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
                        create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")

1256 1257
                    training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
                    training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
1258
                    varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
1259
                    steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
A
AUTOMATIC 已提交
1260 1261 1262 1263 1264

                    with FormRow():
                        create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
                        save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")

1265 1266
                    use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")

1267 1268
                    save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
                    preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
A
AUTOMATIC 已提交
1269 1270 1271 1272 1273

                    shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
                    tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")

                    latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
1274 1275

                    with gr.Row():
A
AUTOMATIC 已提交
1276
                        train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
1277 1278
                        interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
                        train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
1279

1280 1281 1282 1283
                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)

                script_callbacks.ui_train_tabs_callback(params)

1284
            with gr.Column(elem_id='ti_gallery_container'):
1285
                ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
A
AUTOMATIC 已提交
1286 1287
                gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
                gr.HTML(elem_id="ti_progress", value="")
1288 1289 1290 1291 1292 1293
                ti_outcome = gr.HTML(elem_id="ti_error", value="")

        create_embedding.click(
            fn=modules.textual_inversion.ui.create_embedding,
            inputs=[
                new_embedding_name,
1294
                initialization_text,
1295
                nvpt,
D
DepFA 已提交
1296
                overwrite_old_embedding,
1297 1298 1299 1300 1301 1302 1303 1304
            ],
            outputs=[
                train_embedding_name,
                ti_output,
                ti_outcome,
            ]
        )

A
AUTOMATIC 已提交
1305
        create_hypernetwork.click(
A
AUTOMATIC 已提交
1306
            fn=modules.hypernetworks.ui.create_hypernetwork,
A
AUTOMATIC 已提交
1307 1308
            inputs=[
                new_hypernetwork_name,
1309
                new_hypernetwork_sizes,
D
DepFA 已提交
1310
                overwrite_old_hypernetwork,
1311
                new_hypernetwork_layer_structure,
D
update  
discus0434 已提交
1312
                new_hypernetwork_activation_func,
1313
                new_hypernetwork_initialization_option,
1314
                new_hypernetwork_add_layer_norm,
A
aria1th 已提交
1315 1316
                new_hypernetwork_use_dropout,
                new_hypernetwork_dropout_structure
A
AUTOMATIC 已提交
1317 1318 1319 1320 1321 1322 1323 1324
            ],
            outputs=[
                train_hypernetwork_name,
                ti_output,
                ti_outcome,
            ]
        )

1325 1326 1327 1328
        run_preprocess.click(
            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
            _js="start_training_textual_inversion",
            inputs=[
1329
                dummy_component,
1330 1331
                process_src,
                process_dst,
A
alg-wiki 已提交
1332 1333
                process_width,
                process_height,
D
DepFA 已提交
1334
                preprocess_txt_action,
1335
                process_keep_original_size,
1336 1337 1338
                process_flip,
                process_split,
                process_caption,
1339 1340 1341
                process_caption_deepbooru,
                process_split_threshold,
                process_overlap_ratio,
C
captin411 已提交
1342 1343 1344 1345 1346
                process_focal_crop,
                process_focal_crop_face_weight,
                process_focal_crop_entropy_weight,
                process_focal_crop_edges_weight,
                process_focal_crop_debug,
D
dan 已提交
1347 1348 1349 1350 1351 1352 1353
                process_multicrop,
                process_multicrop_mindim,
                process_multicrop_maxdim,
                process_multicrop_minarea,
                process_multicrop_maxarea,
                process_multicrop_objective,
                process_multicrop_threshold,
1354 1355 1356 1357 1358 1359 1360
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ],
        )

1361 1362 1363 1364
        train_embedding.click(
            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
            _js="start_training_textual_inversion",
            inputs=[
1365
                dummy_component,
1366
                train_embedding_name,
D
DepFA 已提交
1367
                embedding_learn_rate,
1368
                batch_size,
1369
                gradient_step,
1370 1371
                dataset_directory,
                log_directory,
A
alg-wiki 已提交
1372 1373
                training_width,
                training_height,
D
dan 已提交
1374
                varsize,
1375
                steps,
1376 1377
                clip_grad_mode,
                clip_grad_value,
1378 1379 1380
                shuffle_tags,
                tag_drop_out,
                latent_sampling_method,
1381
                use_weight,
1382 1383 1384
                create_image_every,
                save_embedding_every,
                template_file,
D
DepFA 已提交
1385
                save_image_with_stored_embedding,
1386 1387
                preview_from_txt2img,
                *txt2img_preview_params,
1388 1389 1390 1391 1392 1393 1394
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )

A
AUTOMATIC 已提交
1395
        train_hypernetwork.click(
A
AUTOMATIC 已提交
1396
            fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
A
AUTOMATIC 已提交
1397 1398
            _js="start_training_textual_inversion",
            inputs=[
1399
                dummy_component,
A
AUTOMATIC 已提交
1400
                train_hypernetwork_name,
D
DepFA 已提交
1401
                hypernetwork_learn_rate,
1402
                batch_size,
1403
                gradient_step,
A
AUTOMATIC 已提交
1404 1405
                dataset_directory,
                log_directory,
1406 1407
                training_width,
                training_height,
D
dan 已提交
1408
                varsize,
1409
                steps,
1410 1411
                clip_grad_mode,
                clip_grad_value,
1412 1413 1414
                shuffle_tags,
                tag_drop_out,
                latent_sampling_method,
1415
                use_weight,
1416 1417 1418
                create_image_every,
                save_embedding_every,
                template_file,
1419 1420
                preview_from_txt2img,
                *txt2img_preview_params,
1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )

        interrupt_training.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

S
space-nuko 已提交
1434 1435 1436 1437 1438 1439
        interrupt_preprocessing.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

1440
    def create_setting_component(key, is_quicksettings=False):
        """Build the Gradio component that edits the setting ``key``.

        The component class is ``info.component`` when the option declares one.
        Otherwise it is chosen from the type of the option's default:
        str -> Textbox, int -> Number, bool -> Checkbox.
        Any other default type raises an Exception.

        When the option has a ``refresh`` callback, a refresh button is created
        next to the component. Outside quicksettings, the pair is wrapped in a
        ``FormRow``.

        Returns:
            The created Gradio component.
        """
        def current_value():
            # The stored value wins; otherwise fall back to the declared default.
            if key in opts.data:
                return opts.data[key]
            return opts.data_labels[key].default

        info = opts.data_labels[key]
        default_type = type(info.default)

        # component_args may be given either as a dict or as a callable producing one.
        if callable(info.component_args):
            args = info.component_args()
        else:
            args = info.component_args

        # Exact-type dispatch: bool defaults map to Checkbox, not Number.
        fallback_components = {
            str: gr.Textbox,
            int: gr.Number,
            bool: gr.Checkbox,
        }
        if info.component is not None:
            comp = info.component
        elif default_type in fallback_components:
            comp = fallback_components[default_type]
        else:
            raise Exception(f'bad options item type: {default_type} for key {key}')

        elem_id = f"setting_{key}"

        def build_component():
            return comp(label=info.label, value=current_value(), elem_id=elem_id, **(args or {}))

        if info.refresh is None:
            return build_component()

        # The unevaluated info.component_args is handed to the refresh button
        # (possibly a callable) — NOTE(review): presumably so it can be
        # re-evaluated on refresh; create_refresh_button is defined elsewhere.
        if is_quicksettings:
            res = build_component()
            create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
            return res

        with FormRow():
            res = build_component()
            create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
        return res
1474

A
AUTOMATIC 已提交
1475 1476
    loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)

A
AUTOMATIC 已提交
1477
    components = []
1478
    component_dict = {}
1479
    shared.settings_components = component_dict
A
AUTOMATIC 已提交
1480

1481 1482 1483
    script_callbacks.ui_settings_callback()
    opts.reorder()

1484
    def run_settings(*args):
        """Validate and apply the values submitted from the Settings tab.

        ``args`` holds one value per entry of ``opts.data_labels``, aligned
        positionally with ``components``. Entries whose component is the
        ``dummy_component`` placeholder (quicksettings and skipped sections) are
        neither validated nor applied.

        Returns:
            A tuple ``(opts.dumpjson(), status_message)``.

        Raises:
            AssertionError: if a submitted value does not match the type of its
                setting's default. No setting is applied in that case.
        """
        # Pass 1: validate everything before applying anything, so a single bad
        # value leaves the current settings untouched. An explicit raise replaces
        # the former ``assert``, which is stripped under ``python -O`` and would
        # have let mistyped values through. AssertionError is kept as the
        # exception type so existing error handling behaves the same.
        for key, value, comp in zip(opts.data_labels.keys(), args, components):
            if comp == dummy_component:
                continue
            default = opts.data_labels[key].default
            if not opts.same_type(value, default):
                raise AssertionError(f"Bad value for setting {key}: {value}; expecting {type(default).__name__}")

        # Pass 2: apply the values, recording which keys opts.set reports as changed.
        changed = []
        for key, value, comp in zip(opts.data_labels.keys(), args, components):
            if comp == dummy_component:
                continue

            if opts.set(key, value):
                changed.append(key)

        try:
            opts.save(shared.config_filename)
        except RuntimeError:
            # The new values are applied in memory but could not be persisted.
            return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
        return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
1502

1503 1504 1505 1506
    def run_settings_single(value, key):
        """Apply and save a single setting edited from a quicksettings component.

        Returns a pair ``(component_update, opts.dumpjson())``:
          * value of the wrong type -> ``gr.update(visible=True)``
          * ``opts.set`` reports no change -> ``gr.update`` carrying the current value
          * otherwise the setting is saved and ``get_value_for_setting(key)`` is returned
        """
        expected_default = opts.data_labels[key].default
        if not opts.same_type(value, expected_default):
            return gr.update(visible=True), opts.dumpjson()

        was_changed = opts.set(key, value)
        if was_changed:
            opts.save(shared.config_filename)
            return get_value_for_setting(key), opts.dumpjson()

        # Nothing changed: push the currently stored value back to the component.
        return gr.update(value=getattr(opts, key)), opts.dumpjson()
1513

A
AUTOMATIC 已提交
1514
    with gr.Blocks(analytics_enabled=False) as settings_interface:
1515
        with gr.Row():
A
AUTOMATIC 已提交
1516 1517 1518 1519
            with gr.Column(scale=6):
                settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
            with gr.Column():
                restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
A
AUTOMATIC 已提交
1520

1521
        result = gr.HTML(elem_id="settings_result")
A
AUTOMATIC 已提交
1522

1523
        quicksettings_names = opts.quicksettings_list
1524
        quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
1525

1526 1527
        quicksettings_list = []

1528
        previous_section = None
1529
        current_tab = None
1530
        current_row = None
1531
        with gr.Tabs(elem_id="settings"):
1532
            for i, (k, item) in enumerate(opts.data_labels.items()):
1533
                section_must_be_skipped = item.section[0] is None
D
DepFA 已提交
1534

1535
                if previous_section != item.section and not section_must_be_skipped:
1536
                    elem_id, text = item.section
D
DepFA 已提交
1537

1538
                    if current_tab is not None:
1539
                        current_row.__exit__()
1540
                        current_tab.__exit__()
A
AUTOMATIC 已提交
1541

1542
                    gr.Group()
1543
                    current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
1544
                    current_tab.__enter__()
1545 1546
                    current_row = gr.Column(variant='compact')
                    current_row.__enter__()
1547 1548 1549

                    previous_section = item.section

1550
                if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
1551 1552
                    quicksettings_list.append((i, k, item))
                    components.append(dummy_component)
1553 1554
                elif section_must_be_skipped:
                    components.append(dummy_component)
1555 1556 1557 1558
                else:
                    component = create_setting_component(k)
                    component_dict[k] = component
                    components.append(component)
1559

1560
            if current_tab is not None:
1561
                current_row.__exit__()
1562
                current_tab.__exit__()
A
AUTOMATIC 已提交
1563

A
AUTOMATIC 已提交
1564 1565 1566
            with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
                loadsave.create_ui()

1567
            with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
1568 1569 1570
                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
Φ
Φφ 已提交
1571 1572 1573
                with gr.Row():
                    unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
                    reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
1574

1575
            with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
A
AUTOMATIC 已提交
1576
                gr.HTML(shared.html("licenses.html"), elem_id="licenses")
A
AUTOMATIC 已提交
1577

1578
            gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
1579

Φ
Φφ 已提交
1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597

        def unload_sd_weights():
            """Click handler for the "Unload SD checkpoint to free VRAM" button (no inputs/outputs)."""
            checkpoints = modules.sd_models
            checkpoints.unload_model_weights()

        def reload_sd_weights():
            """Click handler for the "Reload the last SD checkpoint back into VRAM" button (no inputs/outputs)."""
            checkpoints = modules.sd_models
            checkpoints.reload_model_weights()

        unload_sd_model.click(
            fn=unload_sd_weights,
            inputs=[],
            outputs=[]
        )

        reload_sd_model.click(
            fn=reload_sd_weights,
            inputs=[],
            outputs=[]
        )
1598

1599 1600 1601 1602
        request_notifications.click(
            fn=lambda: None,
            inputs=[],
            outputs=[],
1603
            _js='function(){}'
1604 1605
        )

A
AUTOMATIC 已提交
1606 1607 1608 1609 1610 1611 1612
        download_localization.click(
            fn=lambda: None,
            inputs=[],
            outputs=[],
            _js='download_localization'
        )

D
DepFA 已提交
1613
        def reload_scripts():
D
DepFA 已提交
1614
            modules.scripts.reload_script_body_only()
1615
            reload_javascript()  # need to refresh the html page
D
DepFA 已提交
1616 1617 1618 1619

        reload_script_bodies.click(
            fn=reload_scripts,
            inputs=[],
A
AUTOMATIC 已提交
1620
            outputs=[]
D
DepFA 已提交
1621
        )
1622 1623

        restart_gradio.click(
1624
            fn=shared.state.request_restart,
1625
            _js='restart_reload',
1626 1627 1628
            inputs=[],
            outputs=[],
        )
J
Justin Maier 已提交
1629

1630
    interfaces = [
A
AUTOMATIC 已提交
1631 1632 1633 1634
        (txt2img_interface, "txt2img", "txt2img"),
        (img2img_interface, "img2img", "img2img"),
        (extras_interface, "Extras", "extras"),
        (pnginfo_interface, "PNG Info", "pnginfo"),
1635
        (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
A
AUTOMATIC 已提交
1636
        (train_interface, "Train", "train"),
1637 1638
    ]

1639 1640 1641
    interfaces += script_callbacks.ui_tabs_callback()
    interfaces += [(settings_interface, "Settings", "settings")]

1642 1643 1644
    extensions_interface = ui_extensions.create_ui()
    interfaces += [(extensions_interface, "Extensions", "extensions")]

1645 1646 1647 1648
    shared.tab_names = []
    for _interface, label, _ifid in interfaces:
        shared.tab_names.append(label)

1649
    with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
1650
        with gr.Row(elem_id="quicksettings", variant="compact"):
A
AUTOMATIC 已提交
1651
            for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
1652
                component = create_setting_component(k, is_quicksettings=True)
1653 1654
                component_dict[k] = component

1655
        parameters_copypaste.connect_paste_params_buttons()
1656

1657
        with gr.Tabs(elem_id="tabs") as tabs:
A
AUTOMATIC 已提交
1658 1659 1660 1661
            tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
            sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))

            for interface, label, ifid in sorted_interfaces:
1662
                if label in shared.opts.hidden_tabs:
V
Vladimir Mandic 已提交
1663
                    continue
1664
                with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
A
AUTOMATIC 已提交
1665
                    interface.render()
J
Justin Maier 已提交
1666

A
AUTOMATIC 已提交
1667 1668 1669 1670 1671 1672 1673 1674 1675 1676
            for interface, _label, ifid in interfaces:
                if ifid in ["extensions", "settings"]:
                    continue

                loadsave.add_block(interface, ifid)

            loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)

            loadsave.setup_ui()

1677
        if os.path.exists(os.path.join(script_path, "notification.mp3")):
A
AUTOMATIC 已提交
1678
            gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
A
AUTOMATIC 已提交
1679

A
AUTOMATIC 已提交
1680 1681 1682
        footer = shared.html("footer.html")
        footer = footer.format(versions=versions_html())
        gr.HTML(footer, elem_id="footer")
A
AUTOMATIC 已提交
1683

1684
        text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
1685
        settings_submit.click(
1686
            fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
1687
            inputs=components,
1688
            outputs=[text_settings, result],
1689
        )
1690

A
AUTOMATIC 已提交
1691
        for _i, k, _item in quicksettings_list:
1692
            component = component_dict[k]
1693
            info = opts.data_labels[k]
1694

1695 1696
            change_handler = component.release if hasattr(component, 'release') else component.change
            change_handler(
1697 1698 1699
                fn=lambda value, k=k: run_settings_single(value, key=k),
                inputs=[component],
                outputs=[component, text_settings],
1700
                show_progress=info.refresh is not None,
1701 1702
            )

1703 1704 1705
        update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
        text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
        demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
1706

1707 1708 1709 1710 1711 1712 1713 1714
        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
        button_set_checkpoint.click(
            fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
            inputs=[component_dict['sd_model_checkpoint'], dummy_component],
            outputs=[component_dict['sd_model_checkpoint'], text_settings],
        )

1715 1716 1717
        component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

        def get_settings_values():
1718
            return [get_value_for_setting(key) for key in component_keys]
1719 1720 1721 1722 1723

        demo.load(
            fn=get_settings_values,
            inputs=[],
            outputs=[component_dict[k] for k in component_keys],
1724
            queue=False,
1725 1726
        )

S
safentisAuth 已提交
1727 1728
        def modelmerger(*args):
            try:
1729
                results = modules.extras.run_modelmerger(*args)
S
safentisAuth 已提交
1730 1731 1732
            except Exception as e:
                print("Error loading/saving model file:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
1733
                modules.sd_models.list_models()  # to remove the potentially missing models from the list
A
AUTOMATIC 已提交
1734
                return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
S
safentisAuth 已提交
1735
            return results
1736

1737
        modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
1738
        modelmerger_merge.click(
A
AUTOMATIC 已提交
1739 1740
            fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
            _js='modelmerger',
1741
            inputs=[
A
AUTOMATIC 已提交
1742
                dummy_component,
1743 1744
                primary_model_name,
                secondary_model_name,
1745
                tertiary_model_name,
1746 1747 1748
                interp_method,
                interp_amount,
                save_as_half,
S
safentisAuth 已提交
1749
                custom_name,
1750
                checkpoint_format,
1751
                config_source,
1752
                bake_in_vae,
1753
                discard_weights,
1754
                save_metadata,
1755 1756 1757 1758
            ],
            outputs=[
                primary_model_name,
                secondary_model_name,
1759
                tertiary_model_name,
1760
                component_dict['sd_model_checkpoint'],
A
AUTOMATIC 已提交
1761
                modelmerger_result,
1762 1763
            ]
        )
1764

A
AUTOMATIC 已提交
1765 1766
    loadsave.dump_defaults()
    demo.ui_loadsave = loadsave
A
AUTOMATIC 已提交
1767

1768 1769 1770
    # Required as a workaround for change() event not triggering when loading values from ui-config.json
    interp_description.value = update_interp_description(interp_method.value)

1771 1772 1773
    return demo


A
AUTOMATIC 已提交
1774 1775 1776 1777 1778 1779 1780 1781 1782 1783
def webpath(fn):
    """Return a gradio ``file=`` URL for *fn* with its modification time as the query string.

    Files located under ``script_path`` are referenced by a forward-slash path relative
    to ``script_path``; any other file is referenced by its absolute path. The
    ``?<mtime>`` suffix presumably makes browsers re-fetch the file after it changes.
    """
    root = script_path.rstrip('/\\')
    # A plain startswith() would also match a sibling directory sharing the prefix
    # (e.g. "<root>2/x.js") and produce a "../" relative path; require a separator
    # (or end of string) immediately after the root so only real descendants match.
    if fn.startswith(root) and fn[len(root):len(root) + 1] in ('', '/', '\\'):
        web_path = os.path.relpath(fn, script_path).replace('\\', '/')
    else:
        web_path = os.path.abspath(fn)

    return f'file={web_path}?{os.path.getmtime(fn)}'


def javascript_html():
    """Return the ``<script>`` tags that ``reload_javascript`` splices in before ``</head>``.

    Tags are emitted in this order:

    1. the localization payload, inline;
    2. ``script.js`` from ``script_path``;
    3. every ``.js`` entry from ``modules.scripts.list_scripts("javascript", ...)``;
    4. every ``.mjs`` entry from the same listing, as ``type="module"``;
    5. an inline ``set_theme(...)`` call when ``cmd_opts.theme`` is set.

    Each tag ends with a newline.
    """
    localization_payload = localization.localization_js(shared.opts.localization)
    main_script = os.path.join(script_path, "script.js")

    # Localization is emitted first so it is already on `window` before any other script runs.
    parts = [
        f'<script type="text/javascript">{localization_payload}</script>\n',
        f'<script type="text/javascript" src="{webpath(main_script)}"></script>\n',
    ]

    parts.extend(
        f'<script type="text/javascript" src="{webpath(entry.path)}"></script>\n'
        for entry in modules.scripts.list_scripts("javascript", ".js")
    )
    parts.extend(
        f'<script type="module" src="{webpath(entry.path)}"></script>\n'
        for entry in modules.scripts.list_scripts("javascript", ".mjs")
    )

    if cmd_opts.theme:
        parts.append(f'<script type="text/javascript">set_theme(\"{cmd_opts.theme}\");</script>\n')

    return "".join(parts)


def css_html():
    """Return ``<link rel="stylesheet">`` tags for the page's stylesheets.

    The tags are produced in this order:

    - every ``style.css`` path from ``modules.scripts.list_files_with_name`` that is an
      existing regular file;
    - ``user.css`` from ``data_path``, when that path exists.

    The tags are concatenated with no separator between them.
    """
    def link_tag(path):
        return f'<link rel="stylesheet" property="stylesheet" href="{webpath(path)}">'

    tags = [
        link_tag(css_file)
        for css_file in modules.scripts.list_files_with_name("style.css")
        if os.path.isfile(css_file)
    ]

    user_stylesheet = os.path.join(data_path, "user.css")
    if os.path.exists(user_stylesheet):
        tags.append(link_tag(user_stylesheet))

    return "".join(tags)


def reload_javascript():
    """Replace gradio's ``TemplateResponse`` with a wrapper that injects the extra tags.

    The script tags (``javascript_html()``) and stylesheet links (``css_html()``) are
    generated once, when this function is called; call it again to regenerate them.

    The installed wrapper works in three steps:

    1. builds the response with the original saved in ``shared.GradioTemplateResponseOriginal``;
    2. inserts the scripts just before ``</head>`` and the stylesheets just before ``</body>``;
    3. calls ``init_headers()`` on the response.
    """
    head_tags = javascript_html()
    body_tags = css_html()
    # (closing marker, text to insert right before it)
    splices = ((b'</head>', head_tags), (b'</body>', body_tags))

    def template_response(*args, **kwargs):
        response = shared.GradioTemplateResponseOriginal(*args, **kwargs)
        for marker, tags in splices:
            response.body = response.body.replace(marker, tags.encode("utf8") + marker)
        # presumably recomputes headers such as Content-Length for the enlarged body
        response.init_headers()
        return response

    gradio.routes.templates.TemplateResponse = template_response
Y
yfszzx 已提交
1832

1833

1834 1835
# Keep a reference to gradio's unpatched TemplateResponse: reload_javascript() wraps this
# saved original instead of whatever gradio.routes.templates.TemplateResponse currently is.
# The hasattr guard stores it only the first time, presumably so that re-executing this
# module after the patch was installed does not capture the patched wrapper as "original".
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
    shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
A
AUTOMATIC 已提交
1836 1837 1838 1839 1840 1841 1842 1843


def versions_html():
    """Return the HTML version line rendered in the UI footer.

    The line contains:

    - the git tag, linked to its commit on GitHub;
    - the Python version, with the full ``sys.version`` as a tooltip;
    - the torch version, preferring ``__long_version__`` when present;
    - the xformers version, or ``N/A`` when ``shared.xformers_available`` is false;
    - the gradio version;
    - an anchor with id ``sd_checkpoint_hash`` whose initial text is ``N/A``.
    """
    import torch
    import launch

    python_version = ".".join(map(str, sys.version_info[:3]))
    commit = launch.commit_hash()
    tag = launch.git_tag()

    if shared.xformers_available:
        import xformers
        xformers_version = xformers.__version__
    else:
        xformers_version = "N/A"

    # Only real markup belongs in this literal; stray text lines here end up verbatim in the footer.
    return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
&#x2000;•&#x2000;
python: <span title="{sys.version}">{python_version}</span>
&#x2000;•&#x2000;
torch: {getattr(torch, '__long_version__', torch.__version__)}
&#x2000;•&#x2000;
xformers: {xformers_version}
&#x2000;•&#x2000;
gradio: {gr.__version__}
&#x2000;•&#x2000;
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878


def setup_ui_api(app):
    """Register internal helper routes used by the UI on *app*.

    - ``GET /internal/quicksettings-hint`` returns one name/label pair per entry in
      ``opts.data_labels``.
    - ``GET /internal/ping`` returns an empty object.
    """
    from typing import List

    from pydantic import BaseModel, Field

    class QuicksettingsHint(BaseModel):
        name: str = Field(title="Name of the quicksettings field")
        label: str = Field(title="Label of the quicksettings field")

    # Keep this function's name: the framework derives the route name from it.
    def quicksettings_hint():
        return [
            QuicksettingsHint(name=setting_key, label=setting_info.label)
            for setting_key, setting_info in opts.data_labels.items()
        ]

    app.add_api_route(
        "/internal/quicksettings-hint",
        quicksettings_hint,
        methods=["GET"],
        response_model=List[QuicksettingsHint],
    )

    app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])