ui.py 35.7 KB
Newer Older
1 2 3 4
import base64
import html
import io
import json
A
AUTOMATIC 已提交
5
import math
6 7
import mimetypes
import os
A
AUTOMATIC 已提交
8
import random
9 10 11 12
import sys
import time
import traceback

A
AUTOMATIC 已提交
13 14
import numpy as np
import torch
15 16 17 18
from PIL import Image

import gradio as gr
import gradio.utils
A
AUTOMATIC 已提交
19
import gradio.routes
20 21 22 23 24 25

from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img
import modules.realesrgan_model as realesrgan
A
AUTOMATIC 已提交
26
import modules.scripts
27 28
import modules.gfpgan_model
import modules.codeformer_model
A
AUTOMATIC 已提交
29
import modules.styles
30 31 32 33 34 35

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')


36
if not (cmd_opts.share or cmd_opts.listen):
    # Neither --share nor --listen was requested: stub out gradio's version check and
    # local-IP lookup so that gradio does not phone home
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'


def gr_show(visible=True):
    """Build a raw gradio update dict that only sets a component's visibility."""
    update = {"visible": visible}
    update["__type__"] = "update"
    return update


# Path of the sample img2img input picture; reset to None when the file is not on disk
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
if not os.path.exists(sample_img2img):
    sample_img2img = None

# CSS rules that hide the matched svg, .progress-bar and .meta-text elements and put
# a "Loading..." label in place of the hidden svg's container
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""

def plaintext_to_html(text):
    """Convert plain text into a single HTML paragraph.

    Each line of *text* is HTML-escaped, and the lines are joined with ``<br>``
    tags, so the text can be embedded in markup without being interpreted as HTML.

    Returns the resulting ``<p>...</p>`` string.
    """
    escaped_lines = (html.escape(line) for line in text.split('\n'))
    return "<p>" + "<br>\n".join(escaped_lines) + "</p>"


def image_from_url_text(filedata):
    """Decode base64-encoded image text into a PIL image.

    filedata: a base64 string, optionally prefixed with "data:image/png;base64,",
        or a list of such strings, of which only the first one is used.

    Returns the opened PIL image, or None when given an empty list.
    """
    # isinstance rather than an exact type comparison, so list subclasses are unwrapped too
    if isinstance(filedata, list):
        if not filedata:
            return None

        filedata = filedata[0]

    prefix = "data:image/png;base64,"
    if filedata.startswith(prefix):
        filedata = filedata[len(prefix):]

    raw = base64.decodebytes(filedata.encode('utf-8'))
    return Image.open(io.BytesIO(raw))


def send_gradio_gallery_to_image(x):
    """Return the first entry of gallery data *x* decoded as an image, or None if *x* is empty."""
    if len(x) > 0:
        return image_from_url_text(x[0])

    return None

J
jtkelm2 已提交
82

J
jtkelm2 已提交
83
def save_files(js_data, images, index):
    """Write gallery images into ``opts.outdir_save`` and log them in its ``log.csv``.

    js_data: JSON text with the generation parameters; the keys read here are
        prompt, seed, width, height, sampler, cfg_scale, steps and negative_prompt.
    images: list of base64-encoded images, each optionally prefixed with
        "data:image/png;base64,".
    index: index of the picture selected in the gallery; negative when none is selected.

    Returns a 3-tuple whose last item is an HTML message; the first two are empty strings.
    """
    import csv

    # Nothing to save (e.g. Save pressed on an empty gallery): report it instead of
    # failing later on JSON parsing or on filenames[0], and leave the log file untouched
    if not images:
        return '', '', plaintext_to_html("No images to save")

    os.makedirs(opts.outdir_save, exist_ok=True)

    filenames = []

    data = json.loads(js_data)

    # ensures we are looking at a specific non-grid picture, and we have save_selected_only
    if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid):
        images = [images[index]]
        # with return_grid the seed offset is one less than the gallery index
        # (presumably because the grid occupies the first gallery slot)
        data["seed"] += (index - 1 if opts.return_grid else index)

    prefix = "data:image/png;base64,"

    with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
        # the file is opened for appending: position 0 means it is empty and needs a header
        at_start = file.tell() == 0
        writer = csv.writer(file)
        if at_start:
            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])

        # millisecond timestamp; a "-<n>" suffix is added only when several images are saved
        filename_base = str(int(time.time() * 1000))
        for i, filedata in enumerate(images):
            filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png"
            filepath = os.path.join(opts.outdir_save, filename)

            if filedata.startswith(prefix):
                filedata = filedata[len(prefix):]

            with open(filepath, "wb") as imgfile:
                imgfile.write(base64.decodebytes(filedata.encode('utf-8')))

            filenames.append(filename)

        # one log row per save action, pointing at the first written file
        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

    return '', '', plaintext_to_html(f"Saved: {filenames[0]}")


def wrap_gradio_call(func):
    """Wrap *func* so it can be used as a gradio callback with error reporting and timing.

    The returned callable forwards its arguments to *func* and returns the results as
    a tuple. If *func* raises, the error is printed to stderr, the shared job state is
    reset, and the result becomes (None, '', <error HTML>). In both cases the elapsed
    time is appended to the last result item and the interrupted flag is cleared.
    """
    def f(*args, **kwargs):
        started = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            print("Error completing request", file=sys.stderr)
            print("Arguments:", args, kwargs, file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

            # the request failed, so no job is in progress any more
            shared.state.job = ""
            shared.state.job_count = 0

            error_html = plaintext_to_html(type(e).__name__+': '+str(e))
            res = [None, '', f"<div class='error'>{error_html}</div>"]

        elapsed = time.perf_counter() - started

        # last item is always HTML
        timing_html = f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
        res[-1] = res[-1] + timing_html

        shared.state.interrupted = False

        return tuple(res)

    return f


A
AUTOMATIC 已提交
148 149 150
def check_progress_call():
    """Report the progress of the current job from ``shared.state``.

    Returns a 3-tuple: the progress bar HTML, a visibility update for the preview
    component, and the preview image (or an update for it). When no job is counted,
    the HTML is empty and both updates hide the preview.
    """
    state = shared.state

    if state.job_count == 0:
        return "", gr_show(False), gr_show(False)

    # finished jobs plus the fraction of the current job given by its sampling step
    progress = 0
    if state.job_count > 0:
        progress += state.job_no / state.job_count
    if state.sampling_steps > 0:
        progress += 1 / state.job_count * state.sampling_step / state.sampling_steps
    progress = min(progress, 1)

    progressbar = ""
    if opts.show_progressbar:
        # the percentage label is left out until progress exceeds 1%
        percent_label = str(int(progress*100))+"%" if progress > 0.01 else ""
        progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{percent_label}</div></div>"""

    preview_visibility = gr_show(False)
    image = gr_show(False)

    if opts.show_progress_every_n_steps > 0:
        if shared.parallel_processing_allowed:
            # refresh the preview only after enough sampling steps have passed since the last one
            steps_since_preview = state.sampling_step - state.current_image_sampling_step
            if steps_since_preview >= opts.show_progress_every_n_steps and state.current_latent is not None:
                state.current_image = modules.sd_samplers.sample_to_image(state.current_latent)
                state.current_image_sampling_step = state.sampling_step

        image = state.current_image

        if image is None or progress >= 1:
            image = gr.update(value=None)
        else:
            preview_visibility = gr_show(True)

    # the hidden span carries the current time, so the returned HTML differs on every call
    return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
A
AUTOMATIC 已提交
184 185


A
AUTOMATIC 已提交
186 187 188 189 190 191 192
def roll_artist(prompt):
    """Append the name of a randomly chosen artist to *prompt*.

    Candidates are the artists of ``shared.artist_db`` whose category is listed in
    ``opts.random_artist_categories``; an empty setting allows every category.

    Returns the artist name alone when *prompt* is empty, otherwise the prompt followed
    by ", " and the name. When no artist is eligible, *prompt* is returned unchanged.
    """
    selected_cats = opts.random_artist_categories
    allowed_cats = {x for x in shared.artist_db.categories() if len(selected_cats) == 0 or x in selected_cats}
    candidates = [x for x in shared.artist_db.artists if x.category in allowed_cats]

    # random.choice raises IndexError on an empty sequence
    if not candidates:
        return prompt

    artist = random.choice(candidates)

    return prompt + ", " + artist.name if prompt != '' else artist.name


A
AUTOMATIC 已提交
193 194 195 196 197 198 199
def visit(x, func, path=""):
    """Walk the component tree rooted at *x* and call *func* for every labelled leaf.

    Objects with a ``children`` attribute are descended into; any other object whose
    ``label`` is not None is reported as ``func(path + "/" + label, obj)``. The same
    *path* is handed down to every level.
    """
    if not hasattr(x, 'children'):
        label = x.label
        if label is not None:
            func(path + "/" + str(label), x)
        return

    for child in x.children:
        visit(child, func, path)

200

201 202 203 204 205 206 207 208 209
def create_seed_inputs():
    """Create the seed controls and wire up the "Extra" checkbox.

    The variation seed, variation strength and both "resize seed from" sliders start
    hidden; changing the checkbox sets their visibility to the checkbox value.

    Returns (seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w).
    """
    with gr.Row():
        seed = gr.Number(label='Seed', value=-1)
        subseed = gr.Number(label='Variation seed', value=-1, visible=False)
        seed_checkbox = gr.Checkbox(label="Extra", elem_id="subseed_show", value=False)

    with gr.Row():
        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, visible=False)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0, visible=False)
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0, visible=False)

    # components whose visibility follows the checkbox
    extra_components = [subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w]

    def change_visibility(show):
        return {component: gr_show(show) for component in extra_components}

    seed_checkbox.change(
        change_visibility,
        inputs=[seed_checkbox],
        outputs=list(extra_components),
    )

    return seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w


235 236
def add_style(name: str, prompt: str, negative_prompt: str):
    """Store the given prompts as a named style and refresh the style dropdowns.

    name: name of the style; None means no name was supplied (presumably the
        JavaScript name prompt was cancelled), in which case nothing is saved.
    prompt, negative_prompt: texts stored in the style.

    Returns a list of four updates, one per style dropdown wired as an output of
    this handler.
    """
    if name is None:
        # the handler has four outputs, so four updates are needed here as well
        return [gr_show() for _ in range(4)]

    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
    shared.prompt_styles.styles[style.name] = style
    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
    # reserialize all styles every time we save them
    shared.prompt_styles.save_styles(shared.styles_filename)

    update = {"visible": True, "choices": list(shared.prompt_styles.styles), "__type__": "update"}
    return [update, update, update, update]


def apply_styles(prompt, prompt_neg, style1_name, style2_name):
    """Apply the two selected styles to the prompts and reset both style dropdowns.

    Returns updates for the prompt textbox, the negative prompt textbox and the two
    dropdowns; both dropdowns are set to "None".
    """
    style_names = (style1_name, style2_name)
    styles = shared.prompt_styles

    styled_prompt = styles.apply_styles_to_prompt(prompt, list(style_names))
    styled_negative = styles.apply_negative_styles_to_prompt(prompt_neg, list(style_names))

    return [
        gr.Textbox.update(value=styled_prompt),
        gr.Textbox.update(value=styled_negative),
        gr.Dropdown.update(value="None"),
        gr.Dropdown.update(value="None"),
    ]
A
AUTOMATIC 已提交
254 255


A
AUTOMATIC 已提交
256 257 258 259 260
def interrogate(image):
    """Run the shared interrogator on *image* and return the resulting prompt.

    When the interrogator returns None, a plain visibility update is returned instead.
    """
    prompt = shared.interrogator.interrogate(image)

    if prompt is None:
        return gr_show(True)

    return prompt

A
AUTOMATIC 已提交
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297

def create_toprow(is_img2img):
    """Create the top row of a tab: prompts, style dropdowns and action buttons.

    is_img2img: when true, an "Interrogate" button is created as well; otherwise None
        is returned in its place.

    Returns (prompt, roll, prompt_style, negative_prompt, prompt_style2, submit,
    interrogate, prompt_style_apply, save_style, check_progress).
    """
    def style_dropdown(label, elem_id):
        # lists every known style, preselects the first one, and stays hidden
        # unless there is more than one style
        styles = shared.prompt_styles.styles
        return gr.Dropdown(
            label=label,
            elem_id=elem_id,
            choices=list(styles.keys()),
            value=next(iter(styles.keys())),
            visible=len(styles) > 1,
        )

    with gr.Row(elem_id="toprow"):
        with gr.Column(scale=4):
            with gr.Row():
                with gr.Column(scale=8):
                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2)
                        roll = gr.Button('Roll', elem_id="roll", visible=len(shared.artist_db.artists) > 0)

                with gr.Column(scale=1, elem_id="style_pos_col"):
                    prompt_style = style_dropdown("Style 1", "style_index")

            with gr.Row():
                with gr.Column(scale=8):
                    negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)

                with gr.Column(scale=1, elem_id="style_neg_col"):
                    prompt_style2 = style_dropdown("Style 2", "style2_index")

        with gr.Column(scale=1):
            with gr.Row():
                submit = gr.Button('Generate', elem_id="generate", variant='primary')

            with gr.Row():
                interrogate_button = gr.Button('Interrogate', elem_id="interrogate") if is_img2img else None
                prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
                save_style = gr.Button('Create style', elem_id="style_create")

            # invisible button (visible=False) with elem_id "check_progress"
            check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)

    return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate_button, prompt_style_apply, save_style, check_progress


A
AUTOMATIC 已提交
298
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
299
    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
A
AUTOMATIC 已提交
300
        txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, check_progress = create_toprow(is_img2img=False)
301 302 303 304 305 306 307

        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")

                with gr.Row():
A
AUTOMATIC 已提交
308
                    restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
309
                    tiling = gr.Checkbox(label='Tiling', value=False)
310 311 312 313 314

                with gr.Row():
                    batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

315
                cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
316 317 318

                with gr.Group():
                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
319
                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
320

321
                seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs()
322

A
AUTOMATIC 已提交
323
                with gr.Group():
A
AUTOMATIC 已提交
324
                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
325 326 327

            with gr.Column(variant='panel'):
                with gr.Group():
A
AUTOMATIC 已提交
328
                    txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
329
                    txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery').style(grid=4)
330 331 332 333 334 335 336 337 338

                with gr.Group():
                    with gr.Row():
                        save = gr.Button('Save')
                        send_to_img2img = gr.Button('Send to img2img')
                        send_to_inpaint = gr.Button('Send to inpaint')
                        send_to_extras = gr.Button('Send to extras')
                        interrupt = gr.Button('Interrupt')

A
AUTOMATIC 已提交
339 340
                progressbar = gr.HTML(elem_id="progressbar")

341 342 343 344 345 346
                with gr.Group():
                    html_info = gr.HTML()
                    generation_info = gr.Textbox(visible=False)

            txt2img_args = dict(
                fn=txt2img,
A
AUTOMATIC 已提交
347
                _js="submit",
348
                inputs=[
A
AUTOMATIC 已提交
349
                    txt2img_prompt,
350
                    txt2img_negative_prompt,
A
AUTOMATIC 已提交
351
                    txt2img_prompt_style,
A
AUTOMATIC 已提交
352
                    txt2img_prompt_style2,
353 354
                    steps,
                    sampler_index,
A
AUTOMATIC 已提交
355
                    restore_faces,
356
                    tiling,
357 358 359 360
                    batch_count,
                    batch_size,
                    cfg_scale,
                    seed,
361
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
362 363
                    height,
                    width,
A
AUTOMATIC 已提交
364
                ] + custom_inputs,
365 366 367 368 369 370 371
                outputs=[
                    txt2img_gallery,
                    generation_info,
                    html_info
                ]
            )

A
AUTOMATIC 已提交
372
            txt2img_prompt.submit(**txt2img_args)
373 374
            submit.click(**txt2img_args)

A
AUTOMATIC 已提交
375 376
            check_progress.click(
                fn=check_progress_call,
A
AUTOMATIC 已提交
377
                show_progress=False,
A
AUTOMATIC 已提交
378
                inputs=[],
A
AUTOMATIC 已提交
379
                outputs=[progressbar, txt2img_preview, txt2img_preview],
A
AUTOMATIC 已提交
380 381
            )

382 383 384 385 386 387 388 389
            interrupt.click(
                fn=lambda: shared.state.interrupt(),
                inputs=[],
                outputs=[],
            )

            save.click(
                fn=wrap_gradio_call(save_files),
J
jtkelm2 已提交
390
                _js = "(x, y, z) => [x, y, selected_gallery_index()]",
391 392 393
                inputs=[
                    generation_info,
                    txt2img_gallery,
J
jtkelm2 已提交
394
                    html_info
395 396 397 398 399 400 401 402
                ],
                outputs=[
                    html_info,
                    html_info,
                    html_info,
                ]
            )

A
AUTOMATIC 已提交
403 404 405
            roll.click(
                fn=roll_artist,
                inputs=[
A
AUTOMATIC 已提交
406
                    txt2img_prompt,
A
AUTOMATIC 已提交
407 408
                ],
                outputs=[
A
AUTOMATIC 已提交
409
                    txt2img_prompt,
A
AUTOMATIC 已提交
410 411 412
                ]
            )

413
    with gr.Blocks(analytics_enabled=False) as img2img_interface:
A
AUTOMATIC 已提交
414
        img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, check_progress = create_toprow(is_img2img=True)
415 416 417 418

        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Group():
419
                    switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'SD upscale'], value='Redraw whole image', type="index", show_label=False)
420
                    init_img = gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil")
421
                    init_img_with_mask = gr.Image(label="Image for inpainting with mask", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", visible=False, image_mode="RGBA")
422
                    init_mask = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False)
423
                    init_img_with_mask_comment = gr.HTML(elem_id="mask_bug_info", value="<small>if the editor shows ERROR, switch to another tab and back, then to another img2img mode above and back</small>", visible=False)
424 425 426 427

                    with gr.Row():
                        resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
                        mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask")
428 429 430 431

                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
                mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
A
AUTOMATIC 已提交
432
                inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
433 434

                with gr.Row():
A
AUTOMATIC 已提交
435
                    inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, visible=False)
A
AUTOMATIC 已提交
436 437 438
                    inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False)

                with gr.Row():
A
AUTOMATIC 已提交
439
                    restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
440
                    tiling = gr.Checkbox(label='Tiling', value=False)
A
AUTOMATIC 已提交
441
                    sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
442 443

                with gr.Row():
A
AUTOMATIC 已提交
444
                    sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False)
445 446 447 448 449 450

                with gr.Row():
                    batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

                with gr.Group():
451
                    cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
452
                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
453 454 455

                with gr.Group():
                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
456
                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
457

458
                seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs()
459

A
AUTOMATIC 已提交
460
                with gr.Group():
A
AUTOMATIC 已提交
461
                    custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
A
AUTOMATIC 已提交
462

463 464
            with gr.Column(variant='panel'):
                with gr.Group():
A
AUTOMATIC 已提交
465
                    img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
466
                    img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery').style(grid=4)
467 468 469 470

                with gr.Group():
                    with gr.Row():
                        save = gr.Button('Save')
A
AUTOMATIC 已提交
471 472
                        img2img_send_to_img2img = gr.Button('Send to img2img')
                        img2img_send_to_inpaint = gr.Button('Send to inpaint')
473
                        img2img_send_to_extras = gr.Button('Send to extras')
A
AUTOMATIC 已提交
474
                        interrupt = gr.Button('Interrupt')
A
AUTOMATIC 已提交
475
                        img2img_save_style = gr.Button('Save prompt as style')
476

A
AUTOMATIC 已提交
477 478
                progressbar = gr.HTML(elem_id="progressbar")

479 480 481 482
                with gr.Group():
                    html_info = gr.HTML()
                    generation_info = gr.Textbox(visible=False)

483
            def apply_mode(mode, uploadmask):
484 485
                is_classic = mode == 0
                is_inpaint = mode == 1
486
                is_upscale = mode == 2
487 488

                return {
489 490
                    init_img: gr_show(not is_inpaint or (is_inpaint and uploadmask == 1)),
                    init_img_with_mask: gr_show(is_inpaint and uploadmask == 0),
A
AUTOMATIC 已提交
491
                    init_img_with_mask_comment: gr_show(is_inpaint and uploadmask == 0),
492 493
                    init_mask: gr_show(is_inpaint and uploadmask == 1),
                    mask_mode: gr_show(is_inpaint),
494 495 496
                    mask_blur: gr_show(is_inpaint),
                    inpainting_fill: gr_show(is_inpaint),
                    sd_upscale_upscaler_name: gr_show(is_upscale),
A
AUTOMATIC 已提交
497
                    sd_upscale_overlap: gr_show(is_upscale),
498
                    inpaint_full_res: gr_show(is_inpaint),
A
AUTOMATIC 已提交
499
                    inpainting_mask_invert: gr_show(is_inpaint),
A
AUTOMATIC 已提交
500
                    img2img_interrogate: gr_show(not is_inpaint),
501 502 503 504
                }

            switch_mode.change(
                apply_mode,
505
                inputs=[switch_mode, mask_mode],
506 507 508
                outputs=[
                    init_img,
                    init_img_with_mask,
A
AUTOMATIC 已提交
509
                    init_img_with_mask_comment,
510 511
                    init_mask,
                    mask_mode,
512 513 514 515 516
                    mask_blur,
                    inpainting_fill,
                    sd_upscale_upscaler_name,
                    sd_upscale_overlap,
                    inpaint_full_res,
A
AUTOMATIC 已提交
517
                    inpainting_mask_invert,
A
AUTOMATIC 已提交
518
                    img2img_interrogate,
519 520 521
                ]
            )

522 523 524 525 526 527 528 529 530 531 532 533 534 535
            mask_mode.change(
                lambda mode: {
                    init_img: gr_show(mode == 1),
                    init_img_with_mask: gr_show(mode == 0),
                    init_mask: gr_show(mode == 1),
                },
                inputs=[mask_mode],
                outputs=[
                    init_img,
                    init_img_with_mask,
                    init_mask,
                ],
            )

536 537
            img2img_args = dict(
                fn=img2img,
A
AUTOMATIC 已提交
538
                _js="submit",
539
                inputs=[
A
AUTOMATIC 已提交
540
                    img2img_prompt,
541
                    img2img_negative_prompt,
A
AUTOMATIC 已提交
542
                    img2img_prompt_style,
A
AUTOMATIC 已提交
543
                    img2img_prompt_style2,
544 545
                    init_img,
                    init_img_with_mask,
546 547
                    init_mask,
                    mask_mode,
548 549 550 551
                    steps,
                    sampler_index,
                    mask_blur,
                    inpainting_fill,
A
AUTOMATIC 已提交
552
                    restore_faces,
553
                    tiling,
554 555 556 557 558 559
                    switch_mode,
                    batch_count,
                    batch_size,
                    cfg_scale,
                    denoising_strength,
                    seed,
560
                    subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
561 562 563 564 565 566
                    height,
                    width,
                    resize_mode,
                    sd_upscale_upscaler_name,
                    sd_upscale_overlap,
                    inpaint_full_res,
A
AUTOMATIC 已提交
567
                    inpainting_mask_invert,
A
AUTOMATIC 已提交
568
                ] + custom_inputs,
569 570 571 572 573 574 575
                outputs=[
                    img2img_gallery,
                    generation_info,
                    html_info
                ]
            )

A
AUTOMATIC 已提交
576
            img2img_prompt.submit(**img2img_args)
577 578
            submit.click(**img2img_args)

A
AUTOMATIC 已提交
579 580 581 582 583 584
            img2img_interrogate.click(
                fn=interrogate,
                inputs=[init_img],
                outputs=[img2img_prompt],
            )

A
AUTOMATIC 已提交
585 586
            check_progress.click(
                fn=check_progress_call,
A
AUTOMATIC 已提交
587
                show_progress=False,
A
AUTOMATIC 已提交
588
                inputs=[],
A
AUTOMATIC 已提交
589
                outputs=[progressbar, img2img_preview, img2img_preview],
A
AUTOMATIC 已提交
590 591
            )

592 593 594 595 596 597 598 599
            interrupt.click(
                fn=lambda: shared.state.interrupt(),
                inputs=[],
                outputs=[],
            )

            save.click(
                fn=wrap_gradio_call(save_files),
J
jtkelm2 已提交
600
                _js = "(x, y, z) => [x, y, selected_gallery_index()]",
601 602 603
                inputs=[
                    generation_info,
                    img2img_gallery,
J
jtkelm2 已提交
604
                    html_info
605 606 607 608 609 610 611 612
                ],
                outputs=[
                    html_info,
                    html_info,
                    html_info,
                ]
            )

A
AUTOMATIC 已提交
613 614 615 616 617 618 619 620 621 622 623 624 625
            roll.click(
                fn=roll_artist,
                inputs=[
                    img2img_prompt,
                ],
                outputs=[
                    img2img_prompt,
                ]
            )

            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
            style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]

626
            dummy_component = gr.Label(visible=False)
A
AUTOMATIC 已提交
627
            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
A
AUTOMATIC 已提交
628 629 630
                button.click(
                    fn=add_style,
                    _js="ask_for_style_name",
631 632 633
                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
                    inputs=[dummy_component, prompt, negative_prompt],
A
AUTOMATIC 已提交
634 635 636 637 638 639 640 641
                    outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
                )

            for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns):
                button.click(
                    fn=apply_styles,
                    inputs=[prompt, negative_prompt, style1, style2],
                    outputs=[prompt, negative_prompt, style1, style2],
A
AUTOMATIC 已提交
642 643
                )

644 645 646 647 648
    with gr.Blocks(analytics_enabled=False) as extras_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Group():
                    image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
A
AUTOMATIC 已提交
649 650 651 652 653 654 655 656 657 658 659

                upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)

                with gr.Group():
                    extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")

                with gr.Group():
                    extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
                    extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)

                with gr.Group():
660 661 662 663
                    gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan)

                with gr.Group():
                    codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
664
                    codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
665 666 667 668 669 670 671 672 673 674 675 676

                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')

            with gr.Column(variant='panel'):
                result_image = gr.Image(label="Result")
                html_info_x = gr.HTML()
                html_info = gr.HTML()

        extras_args = dict(
            fn=run_extras,
            inputs=[
                image,
677 678 679
                gfpgan_visibility,
                codeformer_visibility,
                codeformer_weight,
A
AUTOMATIC 已提交
680 681 682 683
                upscaling_resize,
                extras_upscaler_1,
                extras_upscaler_2,
                extras_upscaler_2_visibility,
684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714
            ],
            outputs=[
                result_image,
                html_info_x,
                html_info,
            ]
        )

        submit.click(**extras_args)

    pnginfo_interface = gr.Interface(
        wrap_gradio_call(run_pnginfo),
        inputs=[
            gr.Image(label="Source", source="upload", interactive=True, type="pil"),
        ],
        outputs=[
            gr.HTML(),
            gr.HTML(),
            gr.HTML(),
        ],
        allow_flagging="never",
        analytics_enabled=False,
    )

    def create_setting_component(key):
        """Build the gradio component used to edit the option named `key`.

        The component class comes from the option's metadata in
        `opts.data_labels[key]` when one is given there; otherwise it is picked
        from the Python type of the option's default value (str -> Textbox,
        int -> Number, bool -> Checkbox).

        Raises:
            Exception: if no component is given and the default's type is not
                one of the three supported types.
        """
        option = opts.data_labels[key]
        default_type = type(option.default)

        def current_value():
            # Passed to the component as a callable, so the lookup happens
            # whenever it is invoked rather than once at construction time.
            if key in opts.data:
                return opts.data[key]
            return opts.data_labels[key].default

        # component_args may be a ready-made mapping or a factory producing one
        extra_kwargs = option.component_args
        if callable(extra_kwargs):
            extra_kwargs = extra_kwargs()

        if option.component is not None:
            component_cls = option.component
        else:
            components_by_type = {
                str: gr.Textbox,
                int: gr.Number,
                bool: gr.Checkbox,
            }
            component_cls = components_by_type.get(default_type)
            if component_cls is None:
                raise Exception(f'bad options item type: {str(default_type)} for key {key}')

        return component_cls(label=option.label, value=current_value, **(extra_kwargs or {}))
729

A
AUTOMATIC 已提交
730 731 732 733 734
    components = []
    keys = list(opts.data_labels.keys())
    settings_cols = 3
    items_per_col = math.ceil(len(keys) / settings_cols)

735 736 737
    def run_settings(*args):
        """Copy the submitted settings values into `opts` and save them to disk.

        `args` carries one value per settings component, in the same order as
        the keys of `opts.data_labels` (the components are created in that
        order and passed as the click handler's inputs).

        Options whose `component_args` is a dict with `visible` set to False
        are skipped, leaving their stored value unchanged.

        Returns:
            The status text displayed after the settings were applied.
        """
        for key, value in zip(opts.data_labels.keys(), args):
            comp_args = opts.data_labels[key].component_args
            # skip options whose component is declared invisible
            if isinstance(comp_args, dict) and comp_args.get('visible') is False:
                continue

            opts.data[key] = value

        opts.save(shared.config_filename)

        return 'Settings applied.'
749

A
AUTOMATIC 已提交
750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767
    with gr.Blocks(analytics_enabled=False) as settings_interface:
        submit = gr.Button(value="Apply settings", variant='primary')
        result = gr.HTML()

        with gr.Row(elem_id="settings").style(equal_height=False):
            for colno in range(settings_cols):
                with gr.Column(variant='panel'):
                    for rowno in range(items_per_col):
                        index = rowno + colno * items_per_col

                        if index < len(keys):
                            components.append(create_setting_component(keys[index]))

        submit.click(
            fn=run_settings,
            inputs=components,
            outputs=[result]
        )
768 769

    interfaces = [
A
AUTOMATIC 已提交
770 771 772 773 774
        (txt2img_interface, "txt2img", "txt2img"),
        (img2img_interface, "img2img", "img2img"),
        (extras_interface, "Extras", "extras"),
        (pnginfo_interface, "PNG Info", "pnginfo"),
        (settings_interface, "Settings", "settings"),
775 776 777 778 779 780 781 782
    ]

    with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
        css = file.read()

    if not cmd_opts.no_progressbar_hiding:
        css += css_hide_progressbar

A
AUTOMATIC 已提交
783 784 785 786 787 788 789 790 791 792 793 794 795 796
    with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
        with gr.Tabs() as tabs:
            for interface, label, ifid in interfaces:
                with gr.TabItem(label, id=ifid):
                    interface.render()

        tabs.change(
            fn=lambda x: x,
            inputs=[init_img_with_mask],
            outputs=[init_img_with_mask],
        )

        send_to_img2img.click(
            fn=lambda x: image_from_url_text(x),
S
Seki 已提交
797
            _js="extract_image_from_gallery_img2img",
A
AUTOMATIC 已提交
798 799 800 801 802 803
            inputs=[txt2img_gallery],
            outputs=[init_img],
        )

        send_to_inpaint.click(
            fn=lambda x: image_from_url_text(x),
S
Seki 已提交
804
            _js="extract_image_from_gallery_img2img",
A
AUTOMATIC 已提交
805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824
            inputs=[txt2img_gallery],
            outputs=[init_img_with_mask],
        )

        img2img_send_to_img2img.click(
            fn=lambda x: image_from_url_text(x),
            _js="extract_image_from_gallery",
            inputs=[img2img_gallery],
            outputs=[init_img],
        )

        img2img_send_to_inpaint.click(
            fn=lambda x: image_from_url_text(x),
            _js="extract_image_from_gallery",
            inputs=[img2img_gallery],
            outputs=[init_img_with_mask],
        )

        send_to_extras.click(
            fn=lambda x: image_from_url_text(x),
S
Seki 已提交
825
            _js="extract_image_from_gallery_extras",
A
AUTOMATIC 已提交
826 827 828 829 830 831
            inputs=[txt2img_gallery],
            outputs=[image],
        )

        img2img_send_to_extras.click(
            fn=lambda x: image_from_url_text(x),
S
Seki 已提交
832
            _js="extract_image_from_gallery_extras",
A
AUTOMATIC 已提交
833 834 835
            inputs=[img2img_gallery],
            outputs=[image],
        )
836

837
    ui_config_file = cmd_opts.ui_config_file
A
AUTOMATIC 已提交
838 839 840 841 842 843 844 845 846 847 848 849 850 851
    ui_settings = {}
    settings_count = len(ui_settings)
    error_loading = False

    try:
        if os.path.exists(ui_config_file):
            with open(ui_config_file, "r", encoding="utf8") as file:
                ui_settings = json.load(file)
    except Exception:
        error_loading = True
        print("Error loading settings:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    def loadsave(path, x):
        """Synchronise selected fields of component `x` with `ui_settings`.

        Settings are keyed as "<path>/<field>". A field with no stored value
        has the component's current value recorded in `ui_settings`; a field
        with a stored value has that value written onto the component.

        Only components whose exact type is `gr.Slider` (value, minimum,
        maximum, step) or `gr.Radio` (value) are handled. For a Radio the
        stored value is applied only if it is among the component's choices.
        """
        def apply_field(obj, field, condition=None):
            settings_key = "/".join((path, field))
            stored = ui_settings.get(settings_key, None)

            if stored is None:
                # nothing saved yet: remember the component's current value
                ui_settings[settings_key] = getattr(obj, field)
                return

            if condition is not None and not condition(stored):
                # a saved value exists but is rejected; leave the component as is
                return

            setattr(obj, field, stored)

        component_type = type(x)

        if component_type == gr.Slider:
            for slider_field in ('value', 'minimum', 'maximum', 'step'):
                apply_field(x, slider_field)

        if component_type == gr.Radio:
            apply_field(x, 'value', lambda val: val in x.choices)
A
AUTOMATIC 已提交
869 870 871

    visit(txt2img_interface, loadsave, "txt2img")
    visit(img2img_interface, loadsave, "img2img")
872
    visit(extras_interface, loadsave, "extras")
A
AUTOMATIC 已提交
873 874 875 876 877

    if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
        with open(ui_config_file, "w", encoding="utf8") as file:
            json.dump(ui_settings, file, indent=4)

878 879 880
    return demo


A
AUTOMATIC 已提交
881 882
# Read script.js from the project directory once at import time; the text is
# kept in `javascript` and embedded into served pages by template_response().
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
    javascript = jsfile.read()
883 884


A
AUTOMATIC 已提交
885 886 887 888 889
def template_response(*args, **kwargs):
    """Wrap gradio's original TemplateResponse and inject the UI script.

    All arguments are forwarded unchanged to the saved original
    (`gradio_routes_templates_response`). In the resulting body, `</head>` is
    replaced by a `<script>` element holding the contents of `javascript`
    followed by `</head>`.

    Returns:
        The response object produced by the original, with its body modified.
    """
    response = gradio_routes_templates_response(*args, **kwargs)

    injected_head_end = f'<script>{javascript}</script></head>'.encode("utf8")
    response.body = response.body.replace(b'</head>', injected_head_end)

    # NOTE(review): presumably recomputes headers (e.g. content length) for the
    # changed body -- confirm against the response class in use.
    response.init_headers()
    return response
890 891


A
AUTOMATIC 已提交
892 893
# Keep a reference to gradio's original TemplateResponse, then replace it with
# template_response so that responses built through it get the script injected.
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio.routes.templates.TemplateResponse = template_response