ui.py 91.1 KB
Newer Older
1 2
import html
import json
A
AUTOMATIC 已提交
3
import math
4 5
import mimetypes
import os
D
discus0434 已提交
6
import platform
A
AUTOMATIC 已提交
7
import random
8
import sys
9
import tempfile
10 11
import time
import traceback
12
from functools import partial, reduce
A
AUTOMATIC 已提交
13
import warnings
14

D
discus0434 已提交
15 16 17
import gradio as gr
import gradio.routes
import gradio.utils
A
AUTOMATIC 已提交
18
import numpy as np
19
from PIL import Image, PngImagePlugin
20
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
21

22
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
A
AUTOMATIC 已提交
23
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
24
from modules.paths import script_path, data_path
25

26
from modules.shared import opts, cmd_opts, restricted_opts
M
init  
MalumaDev 已提交
27

D
discus0434 已提交
28
import modules.codeformer_model
Y
yfszzx 已提交
29
import modules.generation_parameters_copypaste as parameters_copypaste
D
discus0434 已提交
30 31
import modules.gfpgan_model
import modules.hypernetworks.ui
A
AUTOMATIC 已提交
32
import modules.scripts
D
discus0434 已提交
33
import modules.shared as shared
A
AUTOMATIC 已提交
34
import modules.styles
D
discus0434 已提交
35
import modules.textual_inversion.ui
A
AUTOMATIC 已提交
36
from modules import prompt_parser
M
Milly 已提交
37
from modules.images import save_image
D
discus0434 已提交
38 39
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
40
from modules.textual_inversion import textual_inversion
41
import modules.hypernetworks.ui
Y
yfszzx 已提交
42
from modules.generation_parameters_copypaste import image_from_url_text
43
import modules.extras
44

A
AUTOMATIC 已提交
45 46
# Show UserWarnings with Python's default policy when the show_warnings option is enabled; otherwise ignore them.
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

# Only when the UI is neither shared (--share) nor listening on the network (--listen):
# replace gradio's version check with a no-op and pin its local IP lookup to loopback.
if not cmd_opts.share and not cmd_opts.listen:
    # fix gradio phoning home
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

Y
Yuval Aboulafia 已提交
56
# When --ngrok was given, pass its value (the authtoken), the UI port (7860 unless --port is set)
# and --ngrok-region to modules.ngrok.connect. The module is imported only in this case.
if cmd_opts.ngrok is not None:
    import modules.ngrok as ngrok
    print('ngrok authtoken detected, trying to connect...')
    ngrok.connect(
        cmd_opts.ngrok,
        cmd_opts.port if cmd_opts.port is not None else 7860,
        cmd_opts.ngrok_region
        )
J
JamnedZ 已提交
64

65 66 67 68 69 70 71 72 73 74

def gr_show(visible=True):
    """Build a raw gradio update dict that only sets a component's visibility."""
    update = {"visible": visible}
    update["__type__"] = "update"
    return update


sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None

css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
75
.wrap .m-12::before { content:"Loading..." }
D
dtlnor 已提交
76 77
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
A
AUTOMATIC 已提交
78
.wrap.cover-bg .z-20::before { content:"" }
79 80
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
D
dtlnor 已提交
81
.meta-text-center { display:none!important; }
82 83
"""

84 85 86 87
# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
reuse_symbol = '\u267b\ufe0f'  # ♻️
paste_symbol = '\u2199\ufe0f'  # ↙
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
apply_style_symbol = '\U0001f4cb'  # 📋
clear_prompt_symbol = '\U0001f5d1\ufe0f'  # 🗑️
extra_networks_symbol = '\U0001F3B4'  # 🎴
switch_values_symbol = '\U000021C5'  # ⇅

96

97
def plaintext_to_html(text):
    """Delegate to ui_common.plaintext_to_html and return its result unchanged."""
    converted = ui_common.plaintext_to_html(text)
    return converted
99

100 101 102 103 104 105

def send_gradio_gallery_to_image(x):
    """Return the first image of a gallery value, or None when there is none.

    Args:
        x: gallery value; a sequence whose entries image_from_url_text understands.
           May also be None (assumed to happen for an empty gallery - treated as empty).

    Returns:
        image_from_url_text(x[0]) for a non-empty gallery, otherwise None.
    """
    # `not x` covers both an empty sequence and None; `len(x) == 0` raised TypeError on None.
    if not x:
        return None
    return image_from_url_text(x[0])

A
AUTOMATIC 已提交
106 107 108 109 110 111 112
def visit(x, func, path=""):
    """Depth-first walk over a component tree, calling func(label_path, component) on labelled leaves.

    Any object with a `children` attribute is treated as a container and recursed into
    (its own label is ignored); other objects are reported only if their label is not None.
    """
    if not hasattr(x, 'children'):
        if x.label is not None:
            func("/".join((path, str(x.label))), x)
        return

    for child in x.children:
        visit(child, func, path)

113

114 115
def add_style(name: str, prompt: str, negative_prompt: str):
    """Save `prompt`/`negative_prompt` as a style called `name` and refresh the style dropdowns.

    Returns:
        A list of two gradio updates, one per styles dropdown. When `name` is None nothing
        is saved, and two plain visibility updates are returned instead.
    """
    if name is None:
        # Must produce the same number of values as the success path below (two);
        # the previous four updates did not match the handler's outputs.
        return [gr_show() for _ in range(2)]

    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
    shared.prompt_styles.styles[style.name] = style
    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
    # reserialize all styles every time we save them
    shared.prompt_styles.save_styles(shared.styles_filename)

    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
A
AUTOMATIC 已提交
125

126 127 128 129 130 131 132 133 134 135 136 137

def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
    """Describe the hires-fix resize as an HTML snippet for the txt2img UI.

    Returns "" when `enable` is falsy. Otherwise a throwaway
    StableDiffusionProcessingTxt2Img is built from the arguments and its init() is run
    under devices.autocast(). The snippet shows the base size and the target size;
    the explicit hr_resize_x/y values are used when truthy, else hr_upscale_to_x/y.
    """
    from modules import processing, devices

    if enable:
        proc = processing.StableDiffusionProcessingTxt2Img(
            width=width,
            height=height,
            enable_hr=True,
            hr_scale=hr_scale,
            hr_resize_x=hr_resize_x,
            hr_resize_y=hr_resize_y,
        )

        with devices.autocast():
            proc.init([""], [0], [0])

        base_w, base_h = proc.width, proc.height
        target_w = proc.hr_resize_x or proc.hr_upscale_to_x
        target_h = proc.hr_resize_y or proc.hr_upscale_to_y
        return f"resize: from <span class='resolution'>{base_w}x{base_h}</span> to <span class='resolution'>{target_w}x{target_h}</span>"

    return ""
139

A
AUTOMATIC 已提交
140

141 142 143
def apply_styles(prompt, prompt_neg, styles):
    """Bake the selected `styles` into both prompts and clear the style selection.

    Returns:
        [prompt textbox update, negative prompt textbox update, styles dropdown update (value=[])].
    """
    styled_prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
    styled_negative = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)

    prompt_update = gr.Textbox.update(value=styled_prompt)
    negative_update = gr.Textbox.update(value=styled_negative)
    cleared_styles = gr.Dropdown.update(value=[])
    return [prompt_update, negative_update, cleared_styles]
A
AUTOMATIC 已提交
146 147


148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
    """Run `interrogation_function` on the input(s) selected by `mode`.

    Modes 0, 1, 3 and 4 pass ii_singles[mode] directly. Mode 2 passes ii_singles[2]["image"].
    Mode 5 is batch mode: every file returned by shared.listfiles(ii_input_dir) is opened
    with PIL and interrogated. The result is appended as a line to "<stem>.txt" in
    ii_output_dir, or in ii_input_dir when ii_output_dir is "".

    Returns:
        [result, None] for single modes, [gr.update(), None] after batch mode,
        and None for any other mode.

    Raises:
        AssertionError: batch mode when launched with --hide-ui-dir-config.
    """
    if mode in {0, 1, 3, 4}:
        return [interrogation_function(ii_singles[mode]), None]
    elif mode == 2:
        return [interrogation_function(ii_singles[mode]["image"]), None]
    elif mode == 5:
        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
        images = shared.listfiles(ii_input_dir)
        print(f"Will process {len(images)} images.")
        if ii_output_dir != "":
            os.makedirs(ii_output_dir, exist_ok=True)
        else:
            ii_output_dir = ii_input_dir

        for image in images:
            # Close the source image and the caption file every iteration instead of
            # leaving both handles open until garbage collection.
            with Image.open(image) as img:
                caption = interrogation_function(img)

            left, _ = os.path.splitext(os.path.basename(image))
            with open(os.path.join(ii_output_dir, left + ".txt"), 'a') as caption_file:
                print(caption, file=caption_file)

        return [gr.update(), None]
169 170


A
AUTOMATIC 已提交
171
def interrogate(image):
    """Caption `image` (converted to RGB) with shared.interrogator.

    Returns the caption, or a no-op gradio update when the interrogator returns None.
    """
    rgb_image = image.convert("RGB")
    prompt = shared.interrogator.interrogate(rgb_image)
    if prompt is None:
        return gr.update()
    return prompt
A
AUTOMATIC 已提交
174

A
AUTOMATIC 已提交
175

G
Greendayle 已提交
176
def interrogate_deepbooru(image):
    """Tag `image` with deepbooru.model.

    Returns the tag string, or a no-op gradio update when the tagger returns None.
    """
    tags = deepbooru.model.tag(image)
    if tags is None:
        return gr.update()
    return tags
G
Greendayle 已提交
179 180


181
def create_seed_inputs(target_interface):
    """Build the seed controls for one generation tab and wire their local events.

    Args:
        target_interface: tab prefix (e.g. 'txt2img') prepended to every elem_id.

    Returns:
        (seed, reuse_seed, subseed, reuse_subseed, subseed_strength,
         seed_resize_from_h, seed_resize_from_w, seed_checkbox).
        Note the resize sliders are returned height first, width second.
    """
    with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
        # --use-textbox-seed swaps the numeric seed field for a free-text one.
        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
        seed.style(container=False)
        random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
        reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')

        seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)

    # Components to show/hide based on the 'Extra' checkbox
    seed_extras = []

    # Both extra rows start hidden; change_visibility below toggles them.
    with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
        seed_extras.append(seed_extra_row_1)
        subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
        subseed.style(container=False)
        random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
        reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')

    with FormRow(visible=False) as seed_extra_row_2:
        seed_extras.append(seed_extra_row_2)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')

    # The dice buttons reset their field to -1 (presumably meaning "random seed" downstream - confirm in processing).
    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
    random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])

    def change_visibility(show):
        # One visibility update per extra row, keyed by component.
        return {comp: gr_show(show) for comp in seed_extras}

    seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)

    # The reuse buttons are returned unwired; callers connect them (e.g. via connect_reuse_seed).
    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
215 216


217

P
papuSpartan 已提交
218
def connect_clear_prompt(button):
    """Bind `button` to the client-side `clear_prompt` JavaScript function.

    No Python callback runs (fn=None), and no gradio inputs or outputs are attached.
    """
    button.click(fn=None, _js="clear_prompt", inputs=[], outputs=[])
226 227


228 229 230
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
    """ Connects a 'reuse (sub)seed' button's click event so that it copies the last used
        (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
        was 0, i.e. no variation seed was used, it copies the normal seed value instead.

        The gallery index comes from the `selected_gallery_index()` JS helper and is made relative
        to `index_of_first_image`; an out-of-range index falls back to the first seed. When the
        generation info is missing, unparseable, or not a JSON object, -1 is written."""

    def copy_seed(gen_info_string: str, index):
        res = -1

        try:
            gen_info = json.loads(gen_info_string)
        except (json.decoder.JSONDecodeError, TypeError):
            # TypeError covers a None/non-string value; only report content worth showing.
            if gen_info_string:
                print("Error parsing JSON generation info:", file=sys.stderr)
                print(gen_info_string, file=sys.stderr)
            return [res, gr_show(False)]

        if not isinstance(gen_info, dict):
            # Valid JSON that is not an object (e.g. a bare number) carries no seed data.
            return [res, gr_show(False)]

        index -= gen_info.get('index_of_first_image', 0)

        if is_subseed and gen_info.get('subseed_strength', 0) > 0:
            seeds = gen_info.get('all_subseeds', [-1])
        else:
            seeds = gen_info.get('all_seeds', [-1])

        # Guard against an empty list, which would otherwise raise IndexError.
        if seeds:
            res = seeds[index if 0 <= index < len(seeds) else 0]

        return [res, gr_show(False)]

    reuse_seed.click(
        fn=copy_seed,
        _js="(x, y) => [x, selected_gallery_index()]",
        show_progress=False,
        inputs=[generation_info, dummy_component],
        outputs=[seed, dummy_component]
    )

261

L
Liam 已提交
262
def update_token_counter(text, steps):
    """Return token-count HTML "<count>/<max>" for the longest scheduled variant of `text`.

    Extra-network tags are stripped, then the prompt is expanded into its per-step schedule.
    Each scheduled prompt text is measured with model_hijack.get_prompt_lengths, and the
    (count, max_length) pair with the highest count is reported. When parsing fails, the raw
    text is measured as a single schedule entry.
    """
    try:
        text, _ = extra_networks.parse_prompt(text)

        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)

    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        prompt_schedules = [[[steps, text]]]

    # Flatten all schedules in one pass; reduce(list1 + list2) re-copied the accumulator each step.
    # Each schedule entry is a (step, prompt_text) pair; only the text is measured.
    prompts = [prompt_text for schedule in prompt_schedules for _, prompt_text in schedule]
    token_count, max_length = max((model_hijack.get_prompt_lengths(prompt) for prompt in prompts), key=lambda args: args[0])
    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
A
AUTOMATIC 已提交
278

279

A
AUTOMATIC 已提交
280
def create_toprow(is_img2img):
    """Build the prompt row shared by the txt2img and img2img tabs.

    Args:
        is_img2img: picks the "img2img" or "txt2img" elem_id prefix and, when true,
            adds the two interrogate buttons.

    Returns:
        A 14-tuple: (prompt, prompt_styles, negative_prompt, submit, button_interrogate,
        button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button,
        token_counter, token_button, negative_token_counter, negative_token_button).
        button_interrogate and button_deepbooru are None when is_img2img is false.
    """
    id_part = "img2img" if is_img2img else "txt2img"

    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
        # Left side: positive and negative prompt textboxes.
        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")

            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")

        # Interrogate buttons exist only on the img2img tab; their click handlers are bound by the caller.
        button_interrogate = None
        button_deepbooru = None
        if is_img2img:
            with gr.Column(scale=1, elem_classes="interrogate-col"):
                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")

        with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
            with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
                skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')

                # Skip/Interrupt only flip flags on shared.state; no gradio inputs/outputs.
                skip.click(
                    fn=lambda: shared.state.skip(),
                    inputs=[],
                    outputs=[],
                )

                interrupt.click(
                    fn=lambda: shared.state.interrupt(),
                    inputs=[],
                    outputs=[],
                )

            with gr.Row(elem_id=f"{id_part}_tools"):
                # NOTE(review): elem_id "paste" is not prefixed with id_part, so both tabs use the same id.
                paste = ToolButton(value=paste_symbol, elem_id="paste")
                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
                extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")

                token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
                # Hidden buttons; presumably clicked from JavaScript to refresh the counters (handlers are bound by the caller).
                token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
                negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")

                # confirm_clear_prompt runs client-side first; the Python fn just echoes back the values it returns.
                clear_prompt_button.click(
                    fn=lambda *x: x,
                    _js="confirm_clear_prompt",
                    inputs=[prompt, negative_prompt],
                    outputs=[prompt, negative_prompt],
                )

            with gr.Row(elem_id=f"{id_part}_styles_row"):
                prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
                create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")

    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
A
AUTOMATIC 已提交
344 345


346
def setup_progressbar(*args, **kwargs):
    """Do nothing; accepts and ignores any positional and keyword arguments."""
    return None
A
AUTOMATIC 已提交
348 349


350 351 352 353
def apply_setting(key, value):
    """Set option `key` to `value`, run its onchange callback if it changed, and save the config.

    A no-op gradio update is returned, and nothing is applied, in these cases:
    - `value` is None;
    - settings are frozen (--freeze-settings);
    - the checkpoint would be swapped while opts.disable_weights_auto_swap is set;
    - no checkpoint matches `value`;
    - the option's component is configured as hidden.

    For "sd_model_checkpoint", `value` is first resolved to the closest matching checkpoint title.
    The value is converted to the type of the option's default (unless that default is None).

    Returns:
        gr.update() when nothing was applied, otherwise the option's new value.
    """
    if value is None:
        return gr.update()

    if shared.cmd_opts.freeze_settings:
        return gr.update()

    if key == "sd_model_checkpoint":
        # dont allow model to be swapped when model hash exists in prompt
        if opts.disable_weights_auto_swap:
            return gr.update()

        ckpt_info = sd_models.get_closet_checkpoint_match(value)
        if ckpt_info is None:
            return gr.update()
        value = ckpt_info.title

    info = opts.data_labels[key]
    comp_args = info.component_args
    if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
        # Hidden settings are not applied; report "no change" like every other early exit
        # (a bare None would be sent to gradio as a value).
        return gr.update()

    valtype = type(info.default)
    oldval = opts.data.get(key, None)
    newval = valtype(value) if valtype is not type(None) else value
    opts.data[key] = newval
    # Compare the converted value so e.g. "7" vs a stored 7 does not fire a spurious onchange.
    if oldval != newval and info.onchange is not None:
        info.onchange()

    opts.save(shared.config_filename)
    return getattr(opts, key)
381

382

383 384 385 386
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    """Create a refresh ToolButton that reloads data and pushes new attributes to a component.

    On click: `refresh_method()` runs, then `refreshed_args` (a dict, or a callable returning one)
    is set as attributes on `refresh_component` and also sent to it as a gradio update.

    Returns:
        The created ToolButton.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # Normalize None to {} before iterating; the fallback used to come after .items()
        # had already been called, so a None result raised AttributeError.
        args = args or {}

        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
400 401


402
def create_output_panel(tabname, outdir):
    """Delegate to ui_common.create_output_panel and return its result unchanged."""
    panel = ui_common.create_output_panel(tabname, outdir)
    return panel
J
Justin Maier 已提交
404

A
aoirusann 已提交
405

406 407
def create_sampler_and_steps_selection(choices, tabname):
    """Create the sampler selector and the sampling-steps slider for `tabname`.

    With opts.samplers_in_dropdown the sampler is a Dropdown placed before the slider in a
    FormRow; otherwise it is a Radio placed after the slider in a FormGroup. In both cases
    the sampler uses type="index", so its value is an index into `choices`, defaulting to
    the first entry.

    Returns:
        (steps, sampler_index)
    """
    def make_sampler(component_cls):
        # Component creation order matters for layout, so this is called inside each container.
        return component_cls(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")

    def make_steps():
        return gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)

    if not opts.samplers_in_dropdown:
        with FormGroup(elem_id=f"sampler_selection_{tabname}"):
            steps = make_steps()
            sampler_index = make_sampler(gr.Radio)
        return steps, sampler_index

    with FormRow(elem_id=f"sampler_selection_{tabname}"):
        sampler_index = make_sampler(gr.Dropdown)
        steps = make_steps()
    return steps, sampler_index
417

418

419
def ordered_ui_categories():
    """Yield shared.ui_reorder_categories in the order requested by opts.ui_reorder.

    Categories named in the comma-separated opts.ui_reorder get sort key 2*i + 1 (i = position
    in that setting). Unnamed ones get 2*j (j = their default position). Named and unnamed
    categories therefore interleave; the sort is stable.
    """
    requested = {}
    for position, name in enumerate(shared.opts.ui_reorder.split(",")):
        requested[name.strip()] = position * 2 + 1

    def sort_key(entry):
        default_position, category_name = entry
        return requested.get(category_name, default_position * 2 + 0)

    for _, category in sorted(enumerate(shared.ui_reorder_categories), key=sort_key):
        yield category


426 427 428 429 430 431 432 433 434 435
def get_value_for_setting(key):
    """Return a gradio update carrying option `key`'s current value and component args.

    The component args come from the option's `component_args` (a dict, a callable returning
    one, or a falsy value meaning "none"). 'precision' is dropped from them before they are
    passed to gr.update.
    """
    value = getattr(opts, key)

    info = opts.data_labels[key]
    args = info.component_args() if callable(info.component_args) else info.component_args
    # Apply the `or {}` fallback to both branches. Previously operator precedence limited it to
    # the non-callable case, so a callable returning None crashed on .items().
    args = {k: v for k, v in (args or {}).items() if k != 'precision'}

    return gr.update(value=value, **args)


436 437 438 439 440 441 442 443 444 445 446 447
def create_override_settings_dropdown(tabname, row):
    """Create the initially hidden, multiselect "Override settings" dropdown for `tabname`.

    On every change the dropdown is shown if it holds at least one entry and hidden otherwise.
    `row` is not used by this function.
    """
    dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)

    def show_if_nonempty(selected):
        return gr.Dropdown.update(visible=len(selected) > 0)

    dropdown.change(fn=show_if_nonempty, inputs=[dropdown], outputs=[dropdown])

    return dropdown


448
def create_ui():
449 450
    import modules.img2img
    import modules.txt2img
451

452 453
    reload_javascript()

454
    parameters_copypaste.reset()
455

456 457 458
    modules.scripts.scripts_current = modules.scripts.scripts_txt2img
    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)

459
    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
A
AUTOMATIC 已提交
460
        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
P
papuSpartan 已提交
461

462
        dummy_component = gr.Label(visible=False)
A
AUTOMATIC 已提交
463
        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
464

A
AUTOMATIC 已提交
465 466 467 468
        with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
            from modules import ui_extra_networks
            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')

469
        with gr.Row().style(equal_height=False):
A
AUTOMATIC 已提交
470
            with gr.Column(variant='compact', elem_id="txt2img_settings"):
471 472 473
                for category in ordered_ui_categories():
                    if category == "sampler":
                        steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
A
AUTOMATIC 已提交
474

475 476 477 478 479 480
                    elif category == "dimensions":
                        with FormRow():
                            with gr.Column(elem_id="txt2img_column_size", scale=4):
                                width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
                                height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")

A
AUTOMATIC 已提交
481 482 483
                            with gr.Column(elem_id="txt2img_dimensions_row", scale=1):
                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")

484 485 486 487 488 489 490 491 492 493 494 495
                            if opts.dimensions_and_batch_together:
                                with gr.Column(elem_id="txt2img_column_batch"):
                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")

                    elif category == "cfg":
                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")

                    elif category == "seed":
                        seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')

                    elif category == "checkboxes":
A
AUTOMATIC 已提交
496
                        with FormRow(elem_classes="checkboxes-row", variant="compact"):
497 498 499
                            restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
                            tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
                            enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
500
                            hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
501 502

                    elif category == "hires_fix":
503
                        with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
504
                            with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
505 506 507 508
                                hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
                                hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
                                denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")

509
                            with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
510 511 512
                                hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
                                hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
                                hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
513 514 515 516 517 518 519

                    elif category == "batch":
                        if not opts.dimensions_and_batch_together:
                            with FormRow(elem_id="txt2img_column_batch"):
                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")

520 521 522 523
                    elif category == "override_settings":
                        with FormRow(elem_id="txt2img_override_settings_row") as row:
                            override_settings = create_override_settings_dropdown('txt2img', row)

524 525 526
                    elif category == "scripts":
                        with FormGroup(elem_id="txt2img_script_container"):
                            custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
527

528 529
            hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
            for input in hr_resolution_preview_inputs:
530 531 532 533 534 535 536 537 538 539 540 541 542
                input.change(
                    fn=calc_resolution_hires,
                    inputs=hr_resolution_preview_inputs,
                    outputs=[hr_final_resolution],
                    show_progress=False,
                )
                input.change(
                    None,
                    _js="onCalcResolutionHires",
                    inputs=hr_resolution_preview_inputs,
                    outputs=[],
                    show_progress=False,
                )
543

544
            txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
545

546 547
            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
548

549
            txt2img_args = dict(
550
                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
A
AUTOMATIC 已提交
551
                _js="submit",
552
                inputs=[
553
                    dummy_component,
A
AUTOMATIC 已提交
554
                    txt2img_prompt,
555
                    txt2img_negative_prompt,
556
                    txt2img_prompt_styles,
557 558
                    steps,
                    sampler_index,
A
AUTOMATIC 已提交
559
                    restore_faces,
560
                    tiling,
561 562 563 564
                    batch_count,
                    batch_size,
                    cfg_scale,
                    seed,
565
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
566 567
                    height,
                    width,
A
AUTOMATIC 已提交
568 569
                    enable_hr,
                    denoising_strength,
A
AUTOMATIC 已提交
570 571
                    hr_scale,
                    hr_upscaler,
572 573 574
                    hr_second_pass_steps,
                    hr_resize_x,
                    hr_resize_y,
575
                    override_settings,
A
AUTOMATIC 已提交
576
                ] + custom_inputs,
577

578 579 580
                outputs=[
                    txt2img_gallery,
                    generation_info,
581 582
                    html_info,
                    html_log,
583 584
                ],
                show_progress=False,
585 586
            )

A
AUTOMATIC 已提交
587
            txt2img_prompt.submit(**txt2img_args)
588
            submit.click(**txt2img_args)
A
AUTOMATIC 已提交
589

A
AUTOMATIC 已提交
590
            res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
591

D
d8ahazard 已提交
592 593 594 595 596 597 598 599 600 601 602
            txt_prompt_img.change(
                fn=modules.images.image_data,
                inputs=[
                    txt_prompt_img
                ],
                outputs=[
                    txt2img_prompt,
                    txt_prompt_img
                ]
            )

A
AUTOMATIC 已提交
603 604 605 606
            enable_hr.change(
                fn=lambda x: gr_show(x),
                inputs=[enable_hr],
                outputs=[hr_options],
607
                show_progress = False,
A
AUTOMATIC 已提交
608 609
            )

610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627
            txt2img_paste_fields = [
                (txt2img_prompt, "Prompt"),
                (txt2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
                (enable_hr, lambda d: "Denoising strength" in d),
                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
A
AUTOMATIC 已提交
628 629
                (hr_scale, "Hires upscale"),
                (hr_upscaler, "Hires upscaler"),
630 631 632
                (hr_second_pass_steps, "Hires steps"),
                (hr_resize_x, "Hires resize-1"),
                (hr_resize_y, "Hires resize-2"),
633
                *modules.scripts.scripts_txt2img.infotext_fields
634
            ]
635
            parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
636
            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
637
                paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
638
            ))
639 640 641 642 643 644 645 646 647 648 649 650

            txt2img_preview_params = [
                txt2img_prompt,
                txt2img_negative_prompt,
                steps,
                sampler_index,
                cfg_scale,
                seed,
                width,
                height,
            ]

651
            token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
652
            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
653

A
AUTOMATIC 已提交
654 655
            ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)

656 657
    modules.scripts.scripts_current = modules.scripts.scripts_img2img
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
658

659
    with gr.Blocks(analytics_enabled=False) as img2img_interface:
A
AUTOMATIC 已提交
660
        img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
661

A
AUTOMATIC 已提交
662
        img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
663

A
AUTOMATIC 已提交
664 665 666 667
        with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
            from modules import ui_extra_networks
            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')

668
        with FormRow().style(equal_height=False):
A
AUTOMATIC 已提交
669
            with gr.Column(variant='compact', elem_id="img2img_settings"):
670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685
                copy_image_buttons = []
                copy_image_destinations = {}

                def add_copy_image_controls(tab_name, elem):
                    """Build a "Copy image to:" row of buttons under the image component of one img2img tab.

                    tab_name: internal name of the tab being built ('img2img', 'sketch', 'inpaint' or 'inpaint_sketch').
                    elem: that tab's image component.

                    The button for the tab itself is rendered disabled, and `elem` is registered in
                    `copy_image_destinations` as the target other tabs copy into. Every other button is
                    recorded in `copy_image_buttons` as (button, target_tab_name, source_elem), so its click
                    handlers can be wired once all tabs exist.
                    """
                    tab_titles = ['img2img', 'sketch', 'inpaint', 'inpaint sketch']
                    tab_names = ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']

                    with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
                        gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")

                        for title, target in zip(tab_titles, tab_names):
                            if target != tab_name:
                                copy_image_buttons.append((gr.Button(title), target, elem))
                            else:
                                # Copying a tab onto itself makes no sense: show the button greyed out and
                                # remember this tab's component as the destination for the other tabs.
                                gr.Button(title, interactive=False)
                                copy_image_destinations[target] = elem

686 687 688
                with gr.Tabs(elem_id="mode_img2img"):
                    with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
689
                        add_copy_image_controls('img2img', init_img)
690

691 692
                    with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
                        sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
693
                        add_copy_image_controls('sketch', sketch)
694

695 696
                    with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
                        init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
697
                        add_copy_image_controls('inpaint', init_img_with_mask)
698

699 700 701
                    with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
                        inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
                        inpaint_color_sketch_orig = gr.State(None)
702
                        add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
703

704 705 706 707 708 709
                        def update_orig(image, state):
                            """Decide which image to keep as the stored "original" for the inpaint-sketch tab.

                            image: the current value of the color-sketch component (PIL-like: has `.size` and
                                converts via `np.array`), or None when the component is cleared.
                            state: the previously stored original, or None.

                            Returns:
                                None when `image` is None.
                                `state` when `image` has the same size and array shape as `state` and at least one
                                    pixel identical across all channels, i.e. it looks like a drawn-on copy of the
                                    stored original.
                                `image` otherwise (no stored original yet, or a different picture was loaded).
                            """
                            if image is None:
                                return None

                            if state is None or state.size != image.size:
                                return image

                            image_arr = np.array(image)
                            state_arr = np.array(state)
                            # Same size but different channel layout (e.g. RGB vs RGBA) cannot be compared
                            # element-wise; NumPy would raise on the broadcast, so treat it as a new image.
                            if image_arr.shape != state_arr.shape:
                                return image

                            # True if any single pixel matches on every channel.
                            has_exact_match = np.any(np.all(image_arr == state_arr, axis=-1))
                            return state if has_exact_match else image
710

711
                        inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
712

713 714 715
                    with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
                        init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
                        init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
716

717
                    with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
718
                        hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
A
Andrii Skaliuk 已提交
719 720 721 722 723 724
                        gr.HTML(
                            f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
                            f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
                            f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
                            f"{hidden}</p>"
                        )
725 726
                        img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
                        img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
A
Andrii Skaliuk 已提交
727
                        img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
728

729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747
                def copy_image(img):
                    """Return the plain image to copy into another img2img tab.

                    Mask/sketch editors deliver a dict carrying the picture under the 'image' key; in that
                    case only that entry is returned. Any other value is passed through unchanged.
                    """
                    is_editor_payload = isinstance(img, dict) and 'image' in img
                    return img['image'] if is_editor_payload else img

                for button, name, elem in copy_image_buttons:
                    button.click(
                        fn=copy_image,
                        inputs=[elem],
                        outputs=[copy_image_destinations[name]],
                    )
                    button.click(
                        fn=lambda: None,
                        _js="switch_to_"+name.replace(" ", "_"),
                        inputs=[],
                        outputs=[],
                    )

748 749
                with FormRow():
                    resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
750

751 752 753
                for category in ordered_ui_categories():
                    if category == "sampler":
                        steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
A
AUTOMATIC 已提交
754

755 756 757 758 759
                    elif category == "dimensions":
                        with FormRow():
                            with gr.Column(elem_id="img2img_column_size", scale=4):
                                width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
                                height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
760

A
AUTOMATIC 已提交
761 762 763
                            with gr.Column(elem_id="img2img_dimensions_row", scale=1):
                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")

764 765 766 767
                            if opts.dimensions_and_batch_together:
                                with gr.Column(elem_id="img2img_column_batch"):
                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
768

769 770
                    elif category == "cfg":
                        with FormGroup():
771 772 773
                            with FormRow():
                                cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
                                image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
774
                            denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
775

776 777
                    elif category == "seed":
                        seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
778

779
                    elif category == "checkboxes":
A
AUTOMATIC 已提交
780
                        with FormRow(elem_classes="checkboxes-row", variant="compact"):
781 782
                            restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
                            tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
783

784 785 786 787 788
                    elif category == "batch":
                        if not opts.dimensions_and_batch_together:
                            with FormRow(elem_id="img2img_column_batch"):
                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
A
AUTOMATIC 已提交
789

790 791 792 793
                    elif category == "override_settings":
                        with FormRow(elem_id="img2img_override_settings_row") as row:
                            override_settings = create_override_settings_dropdown('img2img', row)

794 795 796
                    elif category == "scripts":
                        with FormGroup(elem_id="img2img_script_container"):
                            custom_inputs = modules.scripts.scripts_img2img.setup_ui()
A
AUTOMATIC 已提交
797

A
AUTOMATIC 已提交
798
                    elif category == "inpaint":
799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826
                        with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
                            with FormRow():
                                mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
                                mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")

                            with FormRow():
                                inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")

                            with FormRow():
                                inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")

                            with FormRow():
                                with gr.Column():
                                    inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")

                                with gr.Column(scale=4):
                                    inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")

                            def select_img2img_tab(tab):
                                """Return visibility updates for (inpaint_controls, mask_alpha) given a tab index.

                                Indices follow the tab list enumerated right after this function: 2, 3 and 4 are
                                the inpaint, inpaint sketch and inpaint upload tabs, which show the inpaint
                                controls. Only index 3 (inpaint sketch) also shows the mask transparency slider.
                                """
                                shows_inpaint_controls = tab in [2, 3, 4]
                                shows_mask_alpha = tab == 3
                                return gr.update(visible=shows_inpaint_controls), gr.update(visible=shows_mask_alpha)

                            for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
                                elem.select(
                                    fn=lambda tab=i: select_img2img_tab(tab),
                                    inputs=[],
                                    outputs=[inpaint_controls, mask_alpha],
                                )

827
            img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
828

829 830
            connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
            connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
831

D
d8ahazard 已提交
832 833 834
            img2img_prompt_img.change(
                fn=modules.images.image_data,
                inputs=[
A
AUTOMATIC 已提交
835
                    img2img_prompt_img
D
d8ahazard 已提交
836 837 838 839 840 841 842
                ],
                outputs=[
                    img2img_prompt,
                    img2img_prompt_img
                ]
            )

843
            img2img_args = dict(
844
                fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
845
                _js="submit_img2img",
846
                inputs=[
847
                    dummy_component,
848
                    dummy_component,
A
AUTOMATIC 已提交
849
                    img2img_prompt,
850
                    img2img_negative_prompt,
851
                    img2img_prompt_styles,
852
                    init_img,
853
                    sketch,
854
                    init_img_with_mask,
855 856
                    inpaint_color_sketch,
                    inpaint_color_sketch_orig,
857 858
                    init_img_inpaint,
                    init_mask_inpaint,
859 860 861
                    steps,
                    sampler_index,
                    mask_blur,
862
                    mask_alpha,
863
                    inpainting_fill,
A
AUTOMATIC 已提交
864
                    restore_faces,
865
                    tiling,
866 867 868
                    batch_count,
                    batch_size,
                    cfg_scale,
K
Kyle 已提交
869
                    image_cfg_scale,
870 871
                    denoising_strength,
                    seed,
872
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
873 874 875 876
                    height,
                    width,
                    resize_mode,
                    inpaint_full_res,
877
                    inpaint_full_res_padding,
A
AUTOMATIC 已提交
878
                    inpainting_mask_invert,
879 880
                    img2img_batch_input_dir,
                    img2img_batch_output_dir,
881 882
                    img2img_batch_inpaint_mask_dir,
                    override_settings,
A
AUTOMATIC 已提交
883
                ] + custom_inputs,
884 885 886
                outputs=[
                    img2img_gallery,
                    generation_info,
887 888
                    html_info,
                    html_log,
889 890
                ],
                show_progress=False,
891 892
            )

893 894 895 896 897 898 899 900 901 902 903 904 905 906 907
            interrogate_args = dict(
                _js="get_img2img_tab_index",
                inputs=[
                    dummy_component,
                    img2img_batch_input_dir,
                    img2img_batch_output_dir,
                    init_img,
                    sketch,
                    init_img_with_mask,
                    inpaint_color_sketch,
                    init_img_inpaint,
                ],
                outputs=[img2img_prompt, dummy_component],
            )

A
AUTOMATIC 已提交
908
            img2img_prompt.submit(**img2img_args)
909
            submit.click(**img2img_args)
A
AUTOMATIC 已提交
910
            res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
911

A
AUTOMATIC 已提交
912
            img2img_interrogate.click(
913
                fn=lambda *args: process_interrogate(interrogate, *args),
914
                **interrogate_args,
A
AUTOMATIC 已提交
915 916
            )

A
AUTOMATIC 已提交
917
            img2img_deepbooru.click(
918
                fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
919
                **interrogate_args,
A
AUTOMATIC 已提交
920 921 922
            )

            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
923
            style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
L
Liam 已提交
924
            style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
A
AUTOMATIC 已提交
925 926

            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
A
AUTOMATIC 已提交
927 928 929
                button.click(
                    fn=add_style,
                    _js="ask_for_style_name",
930 931 932
                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
                    inputs=[dummy_component, prompt, negative_prompt],
933
                    outputs=[txt2img_prompt_styles, img2img_prompt_styles],
A
AUTOMATIC 已提交
934 935
                )

936
            for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
A
AUTOMATIC 已提交
937 938
                button.click(
                    fn=apply_styles,
939
                    _js=js_func,
940 941
                    inputs=[prompt, negative_prompt, styles],
                    outputs=[prompt, negative_prompt, styles],
A
AUTOMATIC 已提交
942 943
                )

Y
yfszzx 已提交
944
            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
X
xSinStarx 已提交
945
            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
Y
yfszzx 已提交
946

A
AUTOMATIC 已提交
947 948
            ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)

949 950 951 952 953 954 955
            img2img_paste_fields = [
                (img2img_prompt, "Prompt"),
                (img2img_negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (cfg_scale, "CFG scale"),
K
Kyle 已提交
956
                (image_cfg_scale, "Image CFG scale"),
957 958 959 960 961 962 963 964 965
                (seed, "Seed"),
                (width, "Size-1"),
                (height, "Size-2"),
                (batch_size, "Batch size"),
                (subseed, "Variation seed"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (denoising_strength, "Denoising strength"),
A
AUTOMATIC 已提交
966
                (mask_blur, "Mask blur"),
967
                *modules.scripts.scripts_img2img.infotext_fields
968
            ]
969 970
            parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
            parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
971
            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
972
                paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
973
            ))
974

975
    modules.scripts.scripts_current = None
976

977
    with gr.Blocks(analytics_enabled=False) as extras_interface:
978
        ui_postprocessing.create_ui()
979

980 981 982 983 984 985 986
    with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")

            with gr.Column(variant='panel'):
                html = gr.HTML()
987
                generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
988 989
                html2 = gr.HTML()
                with gr.Row():
Y
yfszzx 已提交
990
                    buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
991 992 993 994 995

                for tabname, button in buttons.items():
                    parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
                        paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
                    ))
996 997

        image.change(
998
            fn=wrap_gradio_call(modules.extras.run_pnginfo),
999 1000 1001
            inputs=[image],
            outputs=[html, generation_info, html2],
        )
1002

1003 1004 1005 1006 1007 1008 1009 1010 1011
    def update_interp_description(value):
        """Return the HTML paragraph describing the checkpoint-merger interpolation method `value`.

        `value` must be one of the choices of the "Interpolation Method" radio; any other key
        raises KeyError.
        """
        descriptions = {
            "No interpolation": "No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking.",
            "Weighted sum": "A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M",
            "Add difference": "The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M",
        }
        text = descriptions[value]
        return f"<p style='margin-bottom: 2.5em'>{text}</p>"

1012
    with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
1013
        with gr.Row().style(equal_height=False):
A
AUTOMATIC 已提交
1014
            with gr.Column(variant='compact'):
1015
                interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
1016

1017
                with FormRow(elem_id="modelmerger_models"):
1018
                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
1019 1020
                    create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")

1021
                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
1022 1023
                    create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")

1024
                    tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
1025 1026
                    create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")

1027 1028
                custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
1029
                interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
1030
                interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
1031

1032
                with FormRow():
1033 1034
                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
                    save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
1035

1036 1037 1038 1039 1040 1041 1042 1043
                with FormRow():
                    with gr.Column():
                        config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")

                    with gr.Column():
                        with FormRow():
                            bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
                            create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
1044

1045 1046 1047
                with FormRow():
                    discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")

A
AUTOMATIC 已提交
1048 1049
                with gr.Row():
                    modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
1050

A
AUTOMATIC 已提交
1051 1052 1053
            with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
                with gr.Group(elem_id="modelmerger_results_panel"):
                    modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
1054

1055
    with gr.Blocks(analytics_enabled=False) as train_interface:
1056
        with gr.Row().style(equal_height=False):
A
AUTOMATIC 已提交
1057
            gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
1058

A
AUTOMATIC 已提交
1059
        with gr.Row(variant="compact").style(equal_height=False):
A
AUTOMATIC 已提交
1060
            with gr.Tabs(elem_id="train_tabs"):
1061

A
AUTOMATIC 已提交
1062
                with gr.Tab(label="Create embedding"):
1063 1064 1065 1066
                    new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
                    initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
                    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
                    overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
1067 1068 1069 1070 1071 1072

                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
1073
                            create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
1074

A
AUTOMATIC 已提交
1075
                with gr.Tab(label="Create hypernetwork"):
1076 1077 1078 1079 1080 1081 1082
                    new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
                    new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
                    new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
                    new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                    new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                    new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
A
aria1th 已提交
1083
                    new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
1084
                    overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
A
AUTOMATIC 已提交
1085 1086 1087 1088 1089 1090

                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
1091
                            create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
1092

A
AUTOMATIC 已提交
1093
                with gr.Tab(label="Preprocess images"):
1094 1095 1096 1097 1098
                    process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
                    process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
                    process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
                    process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
                    preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
1099 1100

                    with gr.Row():
1101 1102 1103
                        process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
                        process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
                        process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
D
dan 已提交
1104
                        process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
1105 1106
                        process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
1107

1108
                    with gr.Row(visible=False) as process_split_extra_row:
1109 1110
                        process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
                        process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
1111

C
captin411 已提交
1112
                    with gr.Row(visible=False) as process_focal_crop_row:
1113 1114 1115 1116
                        process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
                        process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
                        process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
                        process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
D
dan 已提交
1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129
                    
                    with gr.Column(visible=False) as process_multicrop_col:
                        gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
                        with gr.Row():
                            process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
                            process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
                        with gr.Row():
                            process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
                            process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
                        with gr.Row():
                            process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
                            process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
   
1130 1131 1132 1133 1134
                    with gr.Row():
                        with gr.Column(scale=3):
                            gr.HTML(value="")

                        with gr.Column():
S
space-nuko 已提交
1135
                            with gr.Row():
1136 1137
                                interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
                            run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
1138

1139 1140 1141 1142 1143 1144
                    process_split.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_split],
                        outputs=[process_split_extra_row],
                    )

C
captin411 已提交
1145 1146 1147 1148 1149 1150
                    process_focal_crop.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_focal_crop],
                        outputs=[process_focal_crop_row],
                    )

D
dan 已提交
1151 1152 1153 1154 1155 1156
                    process_multicrop.change(
                        fn=lambda show: gr_show(show),
                        inputs=[process_multicrop],
                        outputs=[process_multicrop_col],
                    )

1157 1158 1159
                def get_textual_inversion_template_names():
                    """Return the entries of ``textual_inversion.textual_inversion_templates`` in sorted order.

                    Used as the choices for the "Prompt template" dropdown and by its refresh button.
                    If the collection is a mapping, iterating it yields its keys.
                    """
                    # sorted() already returns a new list; no intermediate comprehension needed.
                    return sorted(textual_inversion.textual_inversion_templates)

A
AUTOMATIC 已提交
1160
                with gr.Tab(label="Train"):
D
DepFA 已提交
1161
                    gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
A
AUTOMATIC 已提交
1162
                    with FormRow():
A
AUTOMATIC 已提交
1163
                        train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
1164
                        create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
A
AUTOMATIC 已提交
1165

A
AUTOMATIC 已提交
1166
                        train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
1167
                        create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
A
AUTOMATIC 已提交
1168 1169

                    with FormRow():
1170 1171
                        embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
                        hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
1172
                    
A
AUTOMATIC 已提交
1173
                    with FormRow():
1174
                        clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
1175
                        clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
1176

A
AUTOMATIC 已提交
1177 1178 1179 1180
                    with FormRow():
                        batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
                        gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")

1181 1182
                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
                    log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
1183 1184 1185 1186 1187

                    with FormRow():
                        template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
                        create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")

1188 1189
                    training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
                    training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
1190
                    varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
1191
                    steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
A
AUTOMATIC 已提交
1192 1193 1194 1195 1196

                    with FormRow():
                        create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
                        save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")

1197 1198
                    use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")

1199 1200
                    save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
                    preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
A
AUTOMATIC 已提交
1201 1202 1203 1204 1205

                    shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
                    tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")

                    latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
1206 1207

                    with gr.Row():
A
AUTOMATIC 已提交
1208
                        train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
1209 1210
                        interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
                        train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
1211

1212 1213 1214 1215
                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)

                script_callbacks.ui_train_tabs_callback(params)

1216
            with gr.Column(elem_id='ti_gallery_container'):
1217 1218 1219 1220 1221 1222 1223 1224 1225
                ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
                ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
                ti_progress = gr.HTML(elem_id="ti_progress", value="")
                ti_outcome = gr.HTML(elem_id="ti_error", value="")

        create_embedding.click(
            fn=modules.textual_inversion.ui.create_embedding,
            inputs=[
                new_embedding_name,
1226
                initialization_text,
1227
                nvpt,
D
DepFA 已提交
1228
                overwrite_old_embedding,
1229 1230 1231 1232 1233 1234 1235 1236
            ],
            outputs=[
                train_embedding_name,
                ti_output,
                ti_outcome,
            ]
        )

A
AUTOMATIC 已提交
1237
        create_hypernetwork.click(
A
AUTOMATIC 已提交
1238
            fn=modules.hypernetworks.ui.create_hypernetwork,
A
AUTOMATIC 已提交
1239 1240
            inputs=[
                new_hypernetwork_name,
1241
                new_hypernetwork_sizes,
D
DepFA 已提交
1242
                overwrite_old_hypernetwork,
1243
                new_hypernetwork_layer_structure,
D
update  
discus0434 已提交
1244
                new_hypernetwork_activation_func,
1245
                new_hypernetwork_initialization_option,
1246
                new_hypernetwork_add_layer_norm,
A
aria1th 已提交
1247 1248
                new_hypernetwork_use_dropout,
                new_hypernetwork_dropout_structure
A
AUTOMATIC 已提交
1249 1250 1251 1252 1253 1254 1255 1256
            ],
            outputs=[
                train_hypernetwork_name,
                ti_output,
                ti_outcome,
            ]
        )

1257 1258 1259 1260
        run_preprocess.click(
            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
            _js="start_training_textual_inversion",
            inputs=[
1261
                dummy_component,
1262 1263
                process_src,
                process_dst,
A
alg-wiki 已提交
1264 1265
                process_width,
                process_height,
D
DepFA 已提交
1266
                preprocess_txt_action,
1267 1268 1269
                process_flip,
                process_split,
                process_caption,
1270 1271 1272
                process_caption_deepbooru,
                process_split_threshold,
                process_overlap_ratio,
C
captin411 已提交
1273 1274 1275 1276 1277
                process_focal_crop,
                process_focal_crop_face_weight,
                process_focal_crop_entropy_weight,
                process_focal_crop_edges_weight,
                process_focal_crop_debug,
D
dan 已提交
1278 1279 1280 1281 1282 1283 1284
                process_multicrop,
                process_multicrop_mindim,
                process_multicrop_maxdim,
                process_multicrop_minarea,
                process_multicrop_maxarea,
                process_multicrop_objective,
                process_multicrop_threshold,
1285 1286 1287 1288 1289 1290 1291
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ],
        )

1292 1293 1294 1295
        train_embedding.click(
            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
            _js="start_training_textual_inversion",
            inputs=[
1296
                dummy_component,
1297
                train_embedding_name,
D
DepFA 已提交
1298
                embedding_learn_rate,
1299
                batch_size,
1300
                gradient_step,
1301 1302
                dataset_directory,
                log_directory,
A
alg-wiki 已提交
1303 1304
                training_width,
                training_height,
D
dan 已提交
1305
                varsize,
1306
                steps,
1307 1308
                clip_grad_mode,
                clip_grad_value,
1309 1310 1311
                shuffle_tags,
                tag_drop_out,
                latent_sampling_method,
1312
                use_weight,
1313 1314 1315
                create_image_every,
                save_embedding_every,
                template_file,
D
DepFA 已提交
1316
                save_image_with_stored_embedding,
1317 1318
                preview_from_txt2img,
                *txt2img_preview_params,
1319 1320 1321 1322 1323 1324 1325
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )

A
AUTOMATIC 已提交
1326
        train_hypernetwork.click(
A
AUTOMATIC 已提交
1327
            fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
A
AUTOMATIC 已提交
1328 1329
            _js="start_training_textual_inversion",
            inputs=[
1330
                dummy_component,
A
AUTOMATIC 已提交
1331
                train_hypernetwork_name,
D
DepFA 已提交
1332
                hypernetwork_learn_rate,
1333
                batch_size,
1334
                gradient_step,
A
AUTOMATIC 已提交
1335 1336
                dataset_directory,
                log_directory,
1337 1338
                training_width,
                training_height,
D
dan 已提交
1339
                varsize,
1340
                steps,
1341 1342
                clip_grad_mode,
                clip_grad_value,
1343 1344 1345
                shuffle_tags,
                tag_drop_out,
                latent_sampling_method,
1346
                use_weight,
1347 1348 1349
                create_image_every,
                save_embedding_every,
                template_file,
1350 1351
                preview_from_txt2img,
                *txt2img_preview_params,
1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364
            ],
            outputs=[
                ti_output,
                ti_outcome,
            ]
        )

        interrupt_training.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

S
space-nuko 已提交
1365 1366 1367 1368 1369 1370
        interrupt_preprocessing.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

1371
    def create_setting_component(key, is_quicksettings=False):
        """Create the Gradio component for the setting ``key`` and return it.

        The component class comes from ``info.component`` if one is set. Otherwise
        it is inferred from the type of the setting's default value (str/int/bool).
        When the setting defines ``info.refresh``, a refresh button is created next
        to the component. In the Settings tab both are wrapped in a ``FormRow``;
        with ``is_quicksettings=True`` (the quicksettings row) they are not.

        Args:
            key: name of the option in ``opts.data_labels``.
            is_quicksettings: True when building the component for the quicksettings row.

        Returns:
            The created Gradio component.

        Raises:
            Exception: if no component is configured and the default's type is not
                str, int or bool.
        """
        def fun():
            # Current stored value, falling back to the declared default.
            return opts.data[key] if key in opts.data else opts.data_labels[key].default

        info = opts.data_labels[key]
        t = type(info.default)

        # component_args may be a dict or a zero-argument callable producing one.
        args = info.component_args() if callable(info.component_args) else info.component_args

        if info.component is not None:
            comp = info.component
        elif t is str:
            comp = gr.Textbox
        elif t is int:
            comp = gr.Number
        elif t is bool:
            comp = gr.Checkbox
        else:
            raise Exception(f'bad options item type: {str(t)} for key {key}')

        elem_id = "setting_" + key

        def make_component():
            # Single place where the component is instantiated, shared by all layouts below.
            return comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))

        if info.refresh is None:
            return make_component()

        if is_quicksettings:
            res = make_component()
            create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
        else:
            with FormRow():
                res = make_component()
                create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)

        return res
1405

A
AUTOMATIC 已提交
1406
    components = []
1407
    component_dict = {}
1408
    shared.settings_components = component_dict
A
AUTOMATIC 已提交
1409

1410 1411 1412
    script_callbacks.ui_settings_callback()
    opts.reorder()

1413
    def run_settings(*args):
        """Validate and apply the values submitted from the Settings tab.

        ``args`` are positional, in the same order as ``opts.data_labels`` and
        ``components``. Entries whose component is ``dummy_component`` (quicksettings
        and skipped sections) are ignored.

        Every value is type-checked before any setting is changed, so one bad value
        leaves all settings untouched.

        Returns:
            ``(settings_json, status_message)``. If ``opts.save`` raises
            ``RuntimeError``, the message reports that the changes were not saved.

        Raises:
            ValueError: if a submitted value does not match the type of the setting's
                default. This is raised explicitly rather than via ``assert`` so the
                check is not stripped under ``python -O``.
        """
        entries = list(zip(opts.data_labels.keys(), args, components))

        for key, value, comp in entries:
            if comp == dummy_component:
                continue
            expected_default = opts.data_labels[key].default
            if not opts.same_type(value, expected_default):
                raise ValueError(f"Bad value for setting {key}: {value}; expecting {type(expected_default).__name__}")

        changed = []
        for key, value, comp in entries:
            if comp == dummy_component:
                continue

            if opts.set(key, value):
                changed.append(key)

        changed_names = ", ".join(changed)
        try:
            opts.save(shared.config_filename)
        except RuntimeError:
            return opts.dumpjson(), f'{len(changed)} settings changed without save: {changed_names}.'
        return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{changed_names}.'
1431

1432 1433 1434 1435
    def run_settings_single(value, key):
        """Apply and save one setting changed from a quicksettings control.

        Returns a pair ``(component_update, settings_json)``:
          * wrong type for the setting -> ``gr.update(visible=True)``, nothing changed;
          * ``opts.set`` reports no change -> ``gr.update`` with the current option value;
          * otherwise the options are saved and ``get_value_for_setting(key)`` is returned.
        """
        if opts.same_type(value, opts.data_labels[key].default):
            if opts.set(key, value):
                opts.save(shared.config_filename)
                return get_value_for_setting(key), opts.dumpjson()

            return gr.update(value=getattr(opts, key)), opts.dumpjson()

        return gr.update(visible=True), opts.dumpjson()
1442

A
AUTOMATIC 已提交
1443
    with gr.Blocks(analytics_enabled=False) as settings_interface:
1444
        with gr.Row():
A
AUTOMATIC 已提交
1445 1446 1447 1448
            with gr.Column(scale=6):
                settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
            with gr.Column():
                restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
A
AUTOMATIC 已提交
1449

1450
        result = gr.HTML(elem_id="settings_result")
A
AUTOMATIC 已提交
1451

1452
        quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
1453
        quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
1454

1455 1456
        quicksettings_list = []

1457
        previous_section = None
1458
        current_tab = None
1459
        current_row = None
1460
        with gr.Tabs(elem_id="settings"):
1461
            for i, (k, item) in enumerate(opts.data_labels.items()):
1462
                section_must_be_skipped = item.section[0] is None
D
DepFA 已提交
1463

1464
                if previous_section != item.section and not section_must_be_skipped:
1465
                    elem_id, text = item.section
D
DepFA 已提交
1466

1467
                    if current_tab is not None:
1468
                        current_row.__exit__()
1469
                        current_tab.__exit__()
A
AUTOMATIC 已提交
1470

1471
                    gr.Group()
1472 1473
                    current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
                    current_tab.__enter__()
1474 1475
                    current_row = gr.Column(variant='compact')
                    current_row.__enter__()
1476 1477 1478

                    previous_section = item.section

1479
                if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
1480 1481
                    quicksettings_list.append((i, k, item))
                    components.append(dummy_component)
1482 1483
                elif section_must_be_skipped:
                    components.append(dummy_component)
1484 1485 1486 1487
                else:
                    component = create_setting_component(k)
                    component_dict[k] = component
                    components.append(component)
1488

1489
            if current_tab is not None:
1490
                current_row.__exit__()
1491
                current_tab.__exit__()
A
AUTOMATIC 已提交
1492

1493 1494 1495 1496
            with gr.TabItem("Actions"):
                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
1497

A
AUTOMATIC 已提交
1498 1499
            with gr.TabItem("Licenses"):
                gr.HTML(shared.html("licenses.html"), elem_id="licenses")
A
AUTOMATIC 已提交
1500

1501
            gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
1502

1503 1504 1505 1506
        request_notifications.click(
            fn=lambda: None,
            inputs=[],
            outputs=[],
1507
            _js='function(){}'
1508 1509
        )

A
AUTOMATIC 已提交
1510 1511 1512 1513 1514 1515 1516
        download_localization.click(
            fn=lambda: None,
            inputs=[],
            outputs=[],
            _js='download_localization'
        )

D
DepFA 已提交
1517
        def reload_scripts():
            """Reload only the bodies of custom scripts, then call ``reload_javascript()``.

            Bound to the "Reload custom script bodies" button.
            """
            modules.scripts.reload_script_body_only()

            # The html page has to be refreshed after the script bodies change.
            reload_javascript()
D
DepFA 已提交
1520 1521 1522 1523

        reload_script_bodies.click(
            fn=reload_scripts,
            inputs=[],
A
AUTOMATIC 已提交
1524
            outputs=[]
D
DepFA 已提交
1525
        )
1526 1527

        def request_restart():
            """Interrupt the current job and set ``shared.state.need_restart``.

            Bound to the "Reload UI" button.
            """
            state = shared.state
            state.interrupt()
            state.need_restart = True
1530 1531 1532

        restart_gradio.click(
            fn=request_restart,
1533
            _js='restart_reload',
1534 1535 1536
            inputs=[],
            outputs=[],
        )
J
Justin Maier 已提交
1537

1538
    interfaces = [
A
AUTOMATIC 已提交
1539 1540 1541 1542
        (txt2img_interface, "txt2img", "txt2img"),
        (img2img_interface, "img2img", "img2img"),
        (extras_interface, "Extras", "extras"),
        (pnginfo_interface, "PNG Info", "pnginfo"),
1543
        (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
A
AUTOMATIC 已提交
1544
        (train_interface, "Train", "ti"),
1545 1546
    ]

A
AUTOMATIC 已提交
1547 1548 1549
    css = ""

    for cssfile in modules.scripts.list_files_with_name("style.css"):
A
AUTOMATIC 已提交
1550 1551 1552
        if not os.path.isfile(cssfile):
            continue

A
AUTOMATIC 已提交
1553 1554
        with open(cssfile, "r", encoding="utf8") as file:
            css += file.read() + "\n"
1555

1556 1557
    if os.path.exists(os.path.join(data_path, "user.css")):
        with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
A
AUTOMATIC 已提交
1558
            css += file.read() + "\n"
A
AUTOMATIC 已提交
1559

1560 1561 1562
    if not cmd_opts.no_progressbar_hiding:
        css += css_hide_progressbar

1563 1564 1565
    interfaces += script_callbacks.ui_tabs_callback()
    interfaces += [(settings_interface, "Settings", "settings")]

1566 1567 1568
    extensions_interface = ui_extensions.create_ui()
    interfaces += [(extensions_interface, "Extensions", "extensions")]

1569 1570 1571 1572
    shared.tab_names = []
    for _interface, label, _ifid in interfaces:
        shared.tab_names.append(label)

A
AUTOMATIC 已提交
1573
    with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
1574
        with gr.Row(elem_id="quicksettings", variant="compact"):
1575
            for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
1576
                component = create_setting_component(k, is_quicksettings=True)
1577 1578
                component_dict[k] = component

1579
        parameters_copypaste.connect_paste_params_buttons()
1580

1581
        with gr.Tabs(elem_id="tabs") as tabs:
A
AUTOMATIC 已提交
1582
            for interface, label, ifid in interfaces:
1583
                if label in shared.opts.hidden_tabs:
V
Vladimir Mandic 已提交
1584
                    continue
1585
                with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
A
AUTOMATIC 已提交
1586
                    interface.render()
J
Justin Maier 已提交
1587

1588 1589
        if os.path.exists(os.path.join(script_path, "notification.mp3")):
            audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
A
AUTOMATIC 已提交
1590

A
AUTOMATIC 已提交
1591 1592 1593
        footer = shared.html("footer.html")
        footer = footer.format(versions=versions_html())
        gr.HTML(footer, elem_id="footer")
A
AUTOMATIC 已提交
1594

1595
        text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
1596
        settings_submit.click(
1597
            fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
1598
            inputs=components,
1599
            outputs=[text_settings, result],
1600
        )
1601 1602 1603

        for i, k, item in quicksettings_list:
            component = component_dict[k]
1604
            info = opts.data_labels[k]
1605 1606 1607 1608 1609

            component.change(
                fn=lambda value, k=k: run_settings_single(value, key=k),
                inputs=[component],
                outputs=[component, text_settings],
1610
                show_progress=info.refresh is not None,
1611 1612
            )

1613 1614 1615 1616 1617 1618
        text_settings.change(
            fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
            inputs=[],
            outputs=[image_cfg_scale],
        )

1619 1620 1621 1622 1623 1624 1625 1626
        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
        button_set_checkpoint.click(
            fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
            inputs=[component_dict['sd_model_checkpoint'], dummy_component],
            outputs=[component_dict['sd_model_checkpoint'], text_settings],
        )

1627 1628 1629
        component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

        def get_settings_values():
            """Return the current value of each rendered setting, in ``component_keys`` order.

            Used by ``demo.load`` to populate the setting components on page load.
            """
            return list(map(get_value_for_setting, component_keys))
1631 1632 1633 1634 1635 1636 1637

        demo.load(
            fn=get_settings_values,
            inputs=[],
            outputs=[component_dict[k] for k in component_keys],
        )

S
safentisAuth 已提交
1638 1639
        def modelmerger(*args):
            """Run a checkpoint merge and return the values for the merger's outputs.

            On success, returns whatever ``modules.extras.run_modelmerger`` returns.

            On any exception:
              * the traceback is printed to stderr;
              * the checkpoint list is rescanned;
              * four dropdown updates with the refreshed choices are returned, one for
                each of primary/secondary/tertiary/``sd_model_checkpoint``, followed by
                an error message for ``modelmerger_result``.
            """
            try:
                return modules.extras.run_modelmerger(*args)
            except Exception as e:
                print("Error loading/saving model file:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                modules.sd_models.list_models()  # to remove the potentially missing models from the list
                # Compute the refreshed choices once instead of once per dropdown.
                choices = modules.sd_models.checkpoint_tiles()
                return [*[gr.Dropdown.update(choices=choices) for _ in range(4)], f"Error merging checkpoints: {e}"]
1647

1648
        modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
1649
        modelmerger_merge.click(
A
AUTOMATIC 已提交
1650 1651
            fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
            _js='modelmerger',
1652
            inputs=[
A
AUTOMATIC 已提交
1653
                dummy_component,
1654 1655
                primary_model_name,
                secondary_model_name,
1656
                tertiary_model_name,
1657 1658 1659
                interp_method,
                interp_amount,
                save_as_half,
S
safentisAuth 已提交
1660
                custom_name,
1661
                checkpoint_format,
1662
                config_source,
1663
                bake_in_vae,
1664
                discard_weights,
1665 1666 1667 1668
            ],
            outputs=[
                primary_model_name,
                secondary_model_name,
1669
                tertiary_model_name,
1670
                component_dict['sd_model_checkpoint'],
A
AUTOMATIC 已提交
1671
                modelmerger_result,
1672 1673
            ]
        )
1674

1675
    ui_config_file = cmd_opts.ui_config_file
A
AUTOMATIC 已提交
1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689
    ui_settings = {}
    settings_count = len(ui_settings)
    error_loading = False

    try:
        if os.path.exists(ui_config_file):
            with open(ui_config_file, "r", encoding="utf8") as file:
                ui_settings = json.load(file)
    except Exception:
        error_loading = True
        print("Error loading settings:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    def loadsave(path, x):
        """Reconcile component ``x`` with the ui-config values held in ``ui_settings``.

        Keys have the form ``"<path>/<field>"``. They are prefixed with
        ``"customscript/<source>/"`` when the component carries a
        ``custom_script_source`` attribute.

        For each tracked field:
        - A missing (``None``) entry is filled in from the component's current
          attribute.
        - A stored entry that passes the optional validity check is written back
          onto the component.
        - A stored entry that fails the check is ignored, and the component
          keeps its own value.
        """
        def apply_field(obj, field, condition=None, init_field=None):
            setting_key = path + "/" + field
            if getattr(obj, 'custom_script_source', None) is not None:
                setting_key = 'customscript/' + obj.custom_script_source + '/' + setting_key

            # components may opt out of persistence entirely
            if getattr(obj, 'do_not_save_to_config', False):
                return

            stored = ui_settings.get(setting_key, None)
            if stored is None:
                # nothing stored yet: record the component's current value
                ui_settings[setting_key] = getattr(obj, field)
                return

            if condition and not condition(stored):
                # stored value rejected by the check (e.g. no longer among the choices);
                # keep the component's own value. A warning here was judged not useful.
                return

            setattr(obj, field, stored)
            if init_field is not None:
                init_field(stored)

        kind = type(x)

        if kind in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
            apply_field(x, 'visible')

        # exact type comparisons, so at most one branch below can match
        if kind == gr.Slider:
            for slider_field in ('value', 'minimum', 'maximum', 'step'):
                apply_field(x, slider_field)
        elif kind == gr.Radio:
            apply_field(x, 'value', lambda val: val in x.choices)
        elif kind in (gr.Checkbox, gr.Textbox, gr.Number):
            apply_field(x, 'value')
        elif kind == gr.Dropdown:
            def check_dropdown(val):
                # multiselect dropdowns store a collection: every element must be a valid choice
                if not getattr(x, 'multiselect', False):
                    return val in x.choices
                return all([value in x.choices for value in val])

            apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
1741

A
AUTOMATIC 已提交
1742 1743
    visit(txt2img_interface, loadsave, "txt2img")
    visit(img2img_interface, loadsave, "img2img")
1744
    visit(extras_interface, loadsave, "extras")
1745
    visit(modelmerger_interface, loadsave, "modelmerger")
A
AUTOMATIC 已提交
1746
    visit(train_interface, loadsave, "train")
A
AUTOMATIC 已提交
1747 1748 1749 1750 1751

    if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
        with open(ui_config_file, "w", encoding="utf8") as file:
            json.dump(ui_settings, file, indent=4)

1752 1753 1754
    # Required as a workaround for change() event not triggering when loading values from ui-config.json
    interp_description.value = update_interp_description(interp_method.value)

1755 1756 1757
    return demo


1758
def reload_javascript():
    """Patch gradio's page template so that served pages load the webui JavaScript.

    Builds a block of ``<script>`` tags in this order:
    1. ``script.js`` from ``script_path``;
    2. every ``.js`` file from ``modules.scripts.list_scripts("javascript", ".js")``;
    3. every ``.mjs`` file, loaded as an ES module;
    4. an inline script holding the localization data and the optional theme.

    It then replaces ``gradio.routes.templates.TemplateResponse`` with a wrapper
    that inserts this block before ``</head>`` in each rendered template. Each
    external script URL carries the file's mtime as a query string, so the URL
    changes whenever the file is modified (presumably to defeat browser caching).
    """
    def script_tag(src_path, script_type="text/javascript"):
        # Attribute-escape the URL. The browser decodes the entities again, so
        # the requested URL is unchanged, but quotes/ampersands in paths can't break the tag.
        src = html.escape(f"file={src_path}?{os.path.getmtime(src_path)}", quote=True)
        return f'<script type="{script_type}" src="{src}"></script>\n'

    inline = f"{localization.localization_js(shared.opts.localization)};"
    if cmd_opts.theme is not None:
        # json.dumps produces a correctly quoted/escaped JS string literal
        inline += f"set_theme({json.dumps(cmd_opts.theme)});"

    head_parts = [script_tag(os.path.abspath(os.path.join(script_path, "script.js")))]
    head_parts += [script_tag(script.path) for script in modules.scripts.list_scripts("javascript", ".js")]
    head_parts += [script_tag(script.path, "module") for script in modules.scripts.list_scripts("javascript", ".mjs")]
    head_parts.append(f'<script type="text/javascript">{inline}</script>\n')
    head = "".join(head_parts)

    # encode once here instead of on every response
    head_replacement = f'{head}</head>'.encode("utf8")

    def template_response(*args, **kwargs):
        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
        res.body = res.body.replace(b'</head>', head_replacement)
        res.init_headers()  # body length changed, so the headers are rebuilt
        return res

    gradio.routes.templates.TemplateResponse = template_response
Y
yfszzx 已提交
1781

1782

1783 1784
try:
    # Probe only. If an original TemplateResponse was already stored on `shared`
    # (presumably by an earlier import/reload of this module), keep it.
    shared.GradioTemplateResponseOriginal
except AttributeError:
    # Store gradio's unpatched TemplateResponse once; reload_javascript() wraps this stored class.
    shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
A
AUTOMATIC 已提交
1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803


def versions_html():
    """Return an HTML fragment listing the runtime versions and the current git commit.

    Shows, in order:
    - python (full ``sys.version`` as a tooltip);
    - torch (``__long_version__`` when present, otherwise ``__version__``);
    - xformers (``"N/A"`` unless ``shared.xformers_available``);
    - gradio;
    - the commit, linked on GitHub;
    - a checkpoint-hash placeholder element with id ``sd_checkpoint_hash``.

    Every interpolated value is HTML-escaped because it is inserted into markup.
    """
    import torch
    import launch

    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
    commit = launch.commit_hash()
    short_commit = commit[0:8]

    if shared.xformers_available:
        import xformers
        xformers_version = xformers.__version__
    else:
        xformers_version = "N/A"

    torch_version = getattr(torch, '__long_version__', torch.__version__)

    return f"""
python: <span title="{html.escape(sys.version)}">{python_version}</span>
 • 
torch: {html.escape(str(torch_version))}
 • 
xformers: {html.escape(str(xformers_version))}
 • 
gradio: {html.escape(str(gr.__version__))}
 • 
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{html.escape(str(commit))}">{html.escape(str(short_commit))}</a>
 • 
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""