import argparse
import os
import sys

script_path = os.path.dirname(os.path.realpath(__file__))

# use current directory as SD dir if it has related files, otherwise use the parent dir of the script, as stated in the guide
sd_path = os.path.abspath('.') if os.path.exists('./ldm/models/diffusion/ddpm.py') else os.path.dirname(script_path)

# add parent directory to path; this is where Stable diffusion repo should be
path_dirs = [
    (sd_path, 'ldm', 'Stable Diffusion'),
    (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers')
]
for d, must_exist, what in path_dirs:
    must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
    if not os.path.exists(must_exist_path):
        print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
    else:
        sys.path.append(os.path.join(script_path, d))

import torch
import torch.nn as nn
import numpy as np
import gradio as gr
import gradio.utils
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin, ImageFilter, ImageOps
from torch import autocast
import mimetypes
import random
import math
import html
import time
import json
import traceback
from collections import namedtuple
from contextlib import nullcontext
import signal
import tqdm
import re
import threading
import base64
import io

import k_diffusion.sampling
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

# some of those options should not be changed at all because they would break the model, so I removed them from options.
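# opt_C is the number of channels in the latent space and opt_f the downscale factor of the autoencoder: a 512x512 image becomes a 4x64x64 latent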
opt_C = 4
opt_f = 8

LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
invalid_filename_chars = '<>:"/\\|?*\n'
config_filename = "config.json"
sd_model_file = os.path.join(script_path, 'model.ckpt')
if not os.path.exists(sd_model_file):
    sd_model_file = "models/ldm/stable-diffusion-v1/model.ckpt"

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, sd_model_file), help="path to checkpoint of model",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRAM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRAM usage")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a workaround test; may help with speed if you use --lowvram")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
cmd_opts = parser.parse_args()

cpu = torch.device("cpu")
gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
queue_lock = threading.Lock()


def gr_show(visible=True):
    return {"visible": visible, "__type__": "update"}


class State:
    interrupted = False
    job = ""

    def interrupt(self):
        self.interrupted = True


state = State()

if not cmd_opts.share:
    # fix gradio phoning home
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""

SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
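# build the sampler list from k-diffusion, keeping only functions present in the installed version; funcname is bound as a lambda default argument so each entry captures its own function name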
samplers = [
    *[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
        ('Euler a', 'sample_euler_ancestral'),
        ('Euler', 'sample_euler'),
        ('LMS', 'sample_lms'),
        ('Heun', 'sample_heun'),
        ('DPM2', 'sample_dpm_2'),
        ('DPM2 a', 'sample_dpm_2_ancestral'),
    ] if hasattr(k_diffusion.sampling, x[1])],
    SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
    SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']

RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])

try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    from realesrgan.archs.srvgg_arch import SRVGGNetCompact

    realesrgan_models = [
        RealesrganModelInfo(
            name="Real-ESRGAN 4x plus",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 4x plus anime 6B",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
            netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 2x plus",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        ),
    ]
    have_realesrgan = True
except Exception:
    print("Error importing Real-ESRGAN:", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)

    realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
    have_realesrgan = False

sd_upscalers = {
    "RealESRGAN": lambda img: upscale_with_realesrgan(img, 2, 0),
    "Lanczos": lambda img: img.resize((img.width*2, img.height*2), resample=LANCZOS),
    "None": lambda img: img
}


def gfpgan_model_path():
    places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
    files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
    found = [x for x in files if os.path.exists(x)]

    if len(found) == 0:
        raise Exception("GFPGAN model not found in paths: " + ", ".join(files))

    return found[0]


def gfpgan():
    return GFPGANer(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)

def gfpgan_fix_faces(gfpgan_model, np_image):
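    # GFPGAN expects BGR channel order (OpenCV convention), so reverse the channels on the way in and back on the way out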
    np_image_bgr = np_image[:, :, ::-1]
    cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan_model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
    np_image = gfpgan_output_bgr[:, :, ::-1]

    return np_image

have_gfpgan = False
try:
    model_path = gfpgan_model_path()

    if os.path.exists(cmd_opts.gfpgan_dir):
        sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
    from gfpgan import GFPGANer

    have_gfpgan = True
except Exception:
    print("Error setting up GFPGAN:", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)


class Options:
    class OptionInfo:
        def __init__(self, default=None, label="", component=None, component_args=None):
            self.default = default
            self.label = label
            self.component = component
            self.component_args = component_args

    data = None
    data_labels = {
        "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to the two directories below"),
        "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images'),
        "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images'),
        "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab'),
        "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to the two directories below"),
        "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids'),
        "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids'),
        "save_to_dirs": OptionInfo(False, "When writing images/grids, create a directory with name derived from the prompt"),
        "save_to_dirs_prompt_len": OptionInfo(10, "When using above, how many words from prompt to put into directory name", gr.Slider, {"minimum": 1, "maximum": 32, "step": 1}),
        "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button"),
        "samples_save": OptionInfo(True, "Save individual samples"),
        "samples_format": OptionInfo('png', 'File format for individual samples'),
        "grid_save": OptionInfo(True, "Save image grids"),
        "return_grid": OptionInfo(True, "Show grid in results for web"),
        "grid_format": OptionInfo('png', 'File format for grids'),
        "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
        "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
        "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
        "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
        "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
        "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
        "font": OptionInfo("arial.ttf", "Font for image grids that have text"),
        "prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
        "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
        "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),

    }

    def __init__(self):
        self.data = {k: v.default for k, v in self.data_labels.items()}

    def __setattr__(self, key, value):
        if self.data is not None:
            if key in self.data:
                self.data[key] = value

        return super(Options, self).__setattr__(key, value)

    def __getattr__(self, item):
        if self.data is not None:
            if item in self.data:
                return self.data[item]

        if item in self.data_labels:
            return self.data_labels[item].default

        return super(Options, self).__getattribute__(item)

    def save(self, filename):
        with open(filename, "w", encoding="utf8") as file:
            json.dump(self.data, file)

    def load(self, filename):
        with open(filename, "r", encoding="utf8") as file:
            self.data = json.load(file)


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.eval()
    return model


module_in_gpu = None


def setup_for_low_vram(sd_model):
    parents = {}

    def send_me_to_gpu(module, _):
        """send this module to GPU; send whatever tracked module was previously in GPU to CPU;
        we add this as forward_pre_hook to a lot of modules, and this way all but one of them will
        stay on CPU
        """
        global module_in_gpu

        module = parents.get(module, module)

        if module_in_gpu == module:
            return

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)

        module.to(gpu)
        module_in_gpu = module

    # see below for register_forward_pre_hook;
    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
    # useless here, and we just replace those methods
    def first_stage_model_encode_wrap(self, encoder, x):
        send_me_to_gpu(self, None)
        return encoder(x)

    def first_stage_model_decode_wrap(self, decoder, z):
        send_me_to_gpu(self, None)
        return decoder(z)

    # remove the three big modules (cond, first_stage, and unet) from the model and then
    # send the model to GPU. Then put the modules back; they will remain on CPU.
    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
    sd_model.to(device)
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored

    # register hooks for the first two models
    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
    sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if cmd_opts.medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
    else:
        diff_model = sd_model.model.diffusion_model

        # the third remaining model is still too big for 4GB, so we also do the same for its submodules
        # so that only one of them is in GPU at a time
        stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
        sd_model.model.to(device)
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

        # install hooks for bits of third model
        diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.input_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)
        diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.output_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)


def create_random_tensors(shape, seeds):
    xs = []
    for seed in seeds:
        torch.manual_seed(seed)

        # randn results depend on device; gpu and cpu get different results for the same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets the same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        xs.append(torch.randn(shape, device=device))
    x = torch.stack(xs)
    return x


def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False):
    if short_filename or prompt is None or seed is None:
        file_decoration = ""
    elif opts.save_to_dirs:
        file_decoration = f"-{seed}"
    else:
        file_decoration = f"-{seed}-{sanitize_filename_part(prompt)[:128]}"

    if extension == 'png' and opts.enable_pnginfo and info is not None:
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info)
    else:
        pnginfo = None

    if opts.save_to_dirs and not no_prompt:
        words = re.findall(r'\w+', prompt or "")
        if len(words) == 0:
            words = ["empty"]

        dirname = " ".join(words[0:opts.save_to_dirs_prompt_len])
        path = os.path.join(path, dirname)

    os.makedirs(path, exist_ok=True)

    filecount = len(os.listdir(path))
    fullfn = "a.png"
    fullfn_without_extension = "a"
    for i in range(100):
        fn = f"{filecount + i:05}" if basename == '' else f"{basename}-{filecount + i:04}"
        fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
        fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
        if not os.path.exists(fullfn):
            break

    image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)

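    # if the saved PNG breaks 4chan's limits (4MB or 4000px on a side), additionally save a JPG copy, downscaled when needed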
    target_side_length = 4000
    oversize = image.width > target_side_length or image.height > target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
        ratio = image.width / image.height

        if oversize and ratio > 1:
            image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
        elif oversize:
            image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)

        image.save(f"{fullfn_without_extension}.jpg", quality=opts.jpeg_quality, pnginfo=pnginfo)

    if opts.save_txt and info is not None:
        with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
            file.write(info + "\n")



def sanitize_filename_part(text):
    return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]


def plaintext_to_html(text):
    text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
    return text


def image_grid(imgs, batch_size=1, rows=None):
    if rows is None:
        if opts.n_rows > 0:
            rows = opts.n_rows
        elif opts.n_rows == 0:
            rows = batch_size
        else:
            rows = math.sqrt(len(imgs))
            rows = round(rows)

    cols = math.ceil(len(imgs) / rows)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid


Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    w = image.width
    h = image.height

    now = tile_w - overlap  # non-overlap width
    noh = tile_h - overlap

    cols = math.ceil((w - overlap) / now)
    rows = math.ceil((h - overlap) / noh)

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = row * noh

        if y + tile_h >= h:
            y = h - tile_h

        for col in range(cols):
            x = col * now

            if x+tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid


def combine_grid(grid):
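    # overlapping tile edges are blended together using linear alpha ramps built by make_mask_image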
    def make_mask_image(r):
        r = r * 255 / grid.overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))

    return combined_image


class GridAnnotation:
    def __init__(self, text='', is_active=True):
        self.text = text
        self.is_active = is_active
        self.size = None


def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
    def wrap(drawing, text, font, line_length):
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def draw_texts(drawing, draw_x, draw_y, lines):
        for i, line in enumerate(lines):
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")

            if not line.is_active:
                drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)

            draw_y += line.size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2
    fnt = ImageFont.truetype(opts.font, fontsize)
    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    pad_left = width * 3 // 4 if len(ver_texts) > 0 else 0

    cols = im.width // width
    rows = im.height // height

    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    calc_img = Image.new("RGB", (1, 1), "white")
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
        items = [] + texts
        texts.clear()

        for line in items:
            wrapped = wrap(calc_d, line.text, fnt, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]

        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])

    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]

    pad_top = max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))

    d = ImageDraw.Draw(result)

    for col in range(cols):
        x = pad_left + width * col + width / 2
        y = pad_top / 2 - hor_text_heights[col] / 2

        draw_texts(d, x, y, hor_texts[col])

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2

        draw_texts(d, x, y, ver_texts[row])

    return result


def draw_prompt_matrix(im, width, height, all_prompts):
    prompts = all_prompts[1:]
    boundary = math.ceil(len(prompts) / 2)

    prompts_horiz = prompts[:boundary]
    prompts_vert = prompts[boundary:]

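    # every cell of the matrix corresponds to a bitmask: bit i of pos decides whether prompt part i is active in that cell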
    hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
    ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]

    return draw_grid_annotations(im, width, height, hor_texts, ver_texts)


def draw_xy_grid(xs, ys, x_label, y_label, cell):
    res = []

    ver_texts = [[GridAnnotation(y_label(y))] for y in ys]
    hor_texts = [[GridAnnotation(x_label(x))] for x in xs]

    for y in ys:
        for x in xs:
            res.append(cell(x, y))


    grid = image_grid(res, rows=len(ys))
    grid = draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)

    return grid


def resize_image(resize_mode, im, width, height):
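    # resize_mode 0: stretch to the target size; 1: scale to cover the target and crop the excess; 2: scale to fit inside and fill the empty bands by stretching the edge rows/columns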
    if resize_mode == 0:
        res = im.resize((width, height), resample=LANCZOS)
    elif resize_mode == 1:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = im.resize((src_w, src_h), resample=LANCZOS)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
    else:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

        resized = im.resize((src_w, src_h), resample=LANCZOS)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            fill_height = height // 2 - src_h // 2
            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            fill_width = width // 2 - src_w // 2
            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res


def wrap_gradio_gpu_call(func):
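    # queue_lock serializes GPU work so that only one generation job runs at a time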
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)

        return res

    return wrap_gradio_call(f)


def wrap_gradio_call(func):
    def f(*args, **kwargs):
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            print("Error completing request", file=sys.stderr)
            print("Arguments:", args, kwargs, file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

            res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]

        elapsed = time.perf_counter() - t

        # last item is always HTML
        res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"

        state.interrupted = False

        return tuple(res)

    return f


class StableDiffusionModelHijack:
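    # replaces parts of the CLIP text encoder so that custom textual inversion embeddings and the (emphasis)/[de-emphasis] syntax take effect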
    ids_lookup = {}
    word_embeddings = {}
    word_embeddings_checksums = {}
    fixes = None
    comments = []
    dir_mtime = None

    def load_textual_inversion_embeddings(self, dirname, model):
        mt = os.path.getmtime(dirname)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()

        tokenizer = model.cond_stage_model.tokenizer

        def const_hash(a):
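            # simple deterministic hash over the embedding values, shown to the user as a short checksum per embedding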
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]

            data = torch.load(path)
            param_dict = data['string_to_param']
            if hasattr(param_dict, '_parameters'):
                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1].reshape(768)
            self.word_embeddings[name] = emb
            self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'

            ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]

            first_id = ids[0]
            if first_id not in self.ids_lookup:
                self.ids_lookup[first_id] = []
            self.ids_lookup[first_id].append((ids, name))

        for fn in os.listdir(dirname):
            try:
                process_file(os.path.join(dirname, fn), fn)
            except Exception:
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")

    def hijack(self, m):
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings

        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)


class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        self.hijack = hijack
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length
        self.token_mults = {}

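        # precompute attention multipliers for vocabulary tokens containing parentheses or brackets: each '(' multiplies by 1.1, each '[' divides by 1.1, and the closing characters undo it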
        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

    def forward(self, text):
        self.hijack.fixes = []
        self.hijack.comments = []
        remade_batch_tokens = []
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length - 2
        used_custom_terms = []

        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
                multipliers = []
                mult = 1.0

                i = 0
                while i < len(tokens):
                    token = tokens[i]

                    possible_matches = self.hijack.ids_lookup.get(token, None)

                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                    if mult_change is not None:
                        mult *= mult_change
                    elif possible_matches is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                    else:
                        found = False
                        for ids, word in possible_matches:
                            if tokens[i:i+len(ids)] == ids:
                                fixes.append((len(remade_tokens), word))
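                                # 777 is an arbitrary placeholder token id; EmbeddingsWithFixes later swaps its embedding for the custom one recorded in fixes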
                                remade_tokens.append(777)
                                multipliers.append(mult)
                                i += len(ids) - 1
                                found = True
                                used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
                                break

                        if not found:
                            remade_tokens.append(token)
                            multipliers.append(mult)

                    i += 1

                if len(remade_tokens) > maxlen - 2:
                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                    ovf = remade_tokens[maxlen - 2:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))

                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            self.hijack.fixes.append(fixes)
            batch_multipliers.append(multipliers)

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z


class EmbeddingsWithFixes(nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is not None:
            for fixes, tensor in zip(batch_fixes, inputs_embeds):
                for offset, word in fixes:
                    tensor[offset] = self.embeddings.word_embeddings[word]

        return inputs_embeds


class StableDiffusionProcessing:
    def __init__(self, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.negative_prompt: str = (negative_prompt or "")
        self.seed: int = seed
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.prompt_matrix: bool = prompt_matrix
        self.use_GFPGAN: bool = use_GFPGAN
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        self.paste_to = None

    def init(self):
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        raise NotImplementedError()


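# hook for inpainting with the DDIM sampler: at every step, re-noise the original latent to the current timestep and composite it with the sample according to the mask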
def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
    if sampler_wrapper.mask is not None:
        img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
        x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec

    return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)


class VanillaStableDiffusionSampler:
    def __init__(self, constructor):
        self.sampler = constructor(sd_model)
        self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else None
        self.mask = None
        self.nmask = None
        self.init_latent = None

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)

        # existing code fails with certain step counts, like 9
        try:
            self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
        except Exception:
            self.sampler.make_schedule(ddim_num_steps=p.steps+1, verbose=False)

        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(device), noise=noise)

        self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
        self.mask = p.mask
        self.nmask = p.nmask
        self.init_latent = p.init_latent

        samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)

        return samples


    def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
        samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
        return samples_ddim


class CFGDenoiser(nn.Module):
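    # implements classifier-free guidance: denoised = uncond + (cond - uncond) * cond_scale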
    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        self.mask = None
        self.nmask = None
        self.init_latent = None

    def forward(self, x, sigma, uncond, cond, cond_scale):
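        # running cond and uncond through the model in a single batch is faster but uses more VRAM; --medvram/--lowvram disable it unless --always-batch-cond-uncond is set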
        if batch_cond_uncond:
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)
            cond_in = torch.cat([uncond, cond])
            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
            denoised = uncond + (cond - uncond) * cond_scale
        else:
            uncond = self.inner_model(x, sigma, cond=uncond)
            cond = self.inner_model(x, sigma, cond=cond)
            denoised = uncond + (cond - uncond) * cond_scale

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        return denoised


def extended_trange(*args, **kwargs):
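    # wrapper around tqdm.trange that labels the progress bar with the current job and stops early when the user interrupts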
    for x in tqdm.trange(*args, desc=state.job, **kwargs):
        if state.interrupted:
            break

        yield x


class KDiffusionSampler:
    def __init__(self, funcname):
        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
        self.funcname = funcname
        self.func = getattr(k_diffusion.sampling, self.funcname)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
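        # img2img enters the noise schedule partway: noise is scaled to the sigma of the starting step, added to the image latent, and only the remaining sigmas are denoised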
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
        sigmas = self.model_wrap.get_sigmas(p.steps)
        noise = noise * sigmas[p.steps - t_enc - 1]

        xi = x + noise

        sigma_sched = sigmas[p.steps - t_enc - 1:]

        self.model_wrap_cfg.mask = p.mask
        self.model_wrap_cfg.nmask = p.nmask
        self.model_wrap_cfg.init_latent = p.init_latent

        if hasattr(k_diffusion.sampling, 'trange'):
            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)

        return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)

    def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
        sigmas = self.model_wrap.get_sigmas(p.steps)
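        # txt2img starts from pure noise scaled to the largest sigma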
        x = x * sigmas[0]

        if hasattr(k_diffusion.sampling, 'trange'):
            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)

        def cb(d):
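            # leftover debug callback, not passed to the sampler below: it would decode and save each step's denoised image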
            n = d['i']
            img = d['denoised']

            x_samples_ddim = sd_model.decode_first_stage(img)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)
                image.save(f'a/{n}.png')

        samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
        return samples_ddim


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images, seed, info):
        self.images = images
        self.prompt = p.prompt
        self.seed = seed
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps

    def js(self):
        obj = {
            "prompt": self.prompt,
            "seed": int(self.seed),
            "width": self.width,
            "height": self.height,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
        }

        return json.dumps(obj)


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls p.init() once inside all the scopes and p.sample() once per batch"""

    prompt = p.prompt
    model = sd_model

    assert p.prompt is not None
    torch_gc()

    seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    comments = []

    prompt_matrix_parts = []
    if p.prompt_matrix:
        all_prompts = []
        prompt_matrix_parts = prompt.split("|")
        combination_count = 2 ** (len(prompt_matrix_parts) - 1)
        for combination_num in range(combination_count):
            selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]

            if opts.prompt_matrix_add_to_start:
                selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
            else:
                selected_prompts = [prompt_matrix_parts[0]] + selected_prompts

            all_prompts.append(", ".join(selected_prompts))

        p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
        all_seeds = len(all_prompts) * [seed]

        print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
    else:
        all_prompts = p.batch_size * p.n_iter * [prompt]
        all_seeds = [seed + x for x in range(len(all_prompts))]

    generation_params = {
        "Steps": p.steps,
        "Sampler": samplers[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": seed,
        "GFPGAN": ("GFPGAN" if p.use_GFPGAN else None)
    }

    if p.extra_generation_params is not None:
        generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

    def infotext():
        return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)

    output_images = []
    precision_scope = autocast if cmd_opts.precision == "autocast" else nullcontext
    ema_scope = (nullcontext if cmd_opts.lowvram else model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init()

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]

            uc = model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            c = model.get_learned_conditioning(prompts)

            if len(model_hijack.comments) > 0:
                comments += model_hijack.comments

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)

            if p.n_iter > 1:
                state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.use_GFPGAN:
                    torch_gc()

                    gfpgan_model = gfpgan()
                    x_sample = gfpgan_fix_faces(gfpgan_model, x_sample)

                image = Image.fromarray(x_sample)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext())

                output_images.append(image)

        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            return_grid = opts.return_grid

1213
            if p.prompt_matrix:
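                # the prompt matrix produces one image per combination of prompt
                # parts; lay them out with a power-of-two number of rows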
                grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))

                try:
                    grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
                except Exception:
                    print("Error creating prompt_matrix text:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

                return_grid = True
            else:
                grid = image_grid(output_images, p.batch_size)

            if return_grid:
                output_images.insert(0, grid)

            if opts.grid_save:
                save_image(grid, p.outpath_grids, "grid", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)

    torch_gc()
    return Processed(p, output_images, seed, infotext())


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def init(self):
        self.sampler = samplers[self.sampler_index].constructor()

    def sample(self, x, conditioning, unconditional_conditioning):
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim


def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, code: str):
    p = StableDiffusionProcessingTxt2Img(
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=prompt,
        negative_prompt=negative_prompt,
        seed=seed,
        sampler_index=sampler_index,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN
    )

    if code != '' and cmd_opts.allow_code:
        p.do_not_save_grid = True
        p.do_not_save_samples = True

        display_result_data = [[], -1, ""]

        def display(imgs, s=display_result_data[1], i=display_result_data[2]):
            display_result_data[0] = imgs
            display_result_data[1] = s
            display_result_data[2] = i

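        # run the user-supplied script in a scratch module that inherits this
        # file's globals; the script hands results back via display() above
        # (reachable only when --allow-code is enabled)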
        from types import ModuleType
        compiled = compile(code, '', 'exec')
        module = ModuleType("testmodule")
        module.__dict__.update(globals())
        module.p = p
        module.display = display
        exec(compiled, module.__dict__)

        processed = Processed(p, *display_result_data)
    else:
        processed = process_images(p)

    return processed.images, processed.js(), plaintext_to_html(processed.info)


def image_from_url_text(filedata):
    if filedata.startswith("data:image/png;base64,"):
        filedata = filedata[len("data:image/png;base64,"):]

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = Image.open(io.BytesIO(filedata))
    return image


def send_gradio_gallery_to_image(x):
    if len(x) == 0:
        return None

    return image_from_url_text(x[0])


def save_files(js_data, images):
    import csv

    os.makedirs(opts.outdir_save, exist_ok=True)
    os.makedirs("log", exist_ok=True)  # log.csv lives in log/, which may not exist yet

    filenames = []

    data = json.loads(js_data)

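    # append one row per saved image batch to a shared CSV log, writing the
    # header only when the file is starting out empty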
    with open("log/log.csv", "a", encoding="utf8", newline='') as file:
        at_start = file.tell() == 0
        writer = csv.writer(file)
        if at_start:
            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename"])
A
AUTOMATIC 已提交
1321

A
AUTOMATIC 已提交
1322 1323 1324 1325
        filename_base = str(int(time.time() * 1000))
        for i, filedata in enumerate(images):
            filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png"
            filepath = os.path.join(opts.outdir_save, filename)
A
AUTOMATIC 已提交
1326

A
AUTOMATIC 已提交
1327 1328
            if filedata.startswith("data:image/png;base64,"):
                filedata = filedata[len("data:image/png;base64,"):]
A
AUTOMATIC 已提交
1329

A
AUTOMATIC 已提交
1330 1331
            with open(filepath, "wb") as imgfile:
                imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
A
AUTOMATIC 已提交
1332

A
AUTOMATIC 已提交
1333
            filenames.append(filename)
A
AUTOMATIC 已提交
1334

A
AUTOMATIC 已提交
1335
        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0]])
A
AUTOMATIC 已提交
1336

A
AUTOMATIC 已提交
1337
    return '', '', plaintext_to_html(f"Saved: {filenames[0]}")


with gr.Blocks(analytics_enabled=False) as txt2img_interface:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1)
        negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1, visible=False)
        submit = gr.Button('Generate', elem_id="txt2img_generate", variant='primary')

    with gr.Row().style(equal_height=False):
        with gr.Column(variant='panel'):
            steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
            sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")

            with gr.Row():
                use_GFPGAN = gr.Checkbox(label='GFPGAN', value=False, visible=have_gfpgan)
                prompt_matrix = gr.Checkbox(label='Prompt matrix', value=False)

            with gr.Row():
                batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

            cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0)

            with gr.Group():
                height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)

            seed = gr.Number(label='Seed', value=-1)

            code = gr.Textbox(label="Python script", visible=cmd_opts.allow_code, lines=1)

        with gr.Column(variant='panel'):
            with gr.Group():
                txt2img_gallery = gr.Gallery(label='Output')

            with gr.Group():
                with gr.Row():
                    save = gr.Button('Save')
                    send_to_img2img = gr.Button('Send to img2img')
                    send_to_inpaint = gr.Button('Send to inpaint')
                    send_to_extras = gr.Button('Send to extras')
                    interrupt = gr.Button('Interrupt')

            with gr.Group():
                html_info = gr.HTML()
                generation_info = gr.Textbox(visible=False)

        txt2img_args = dict(
            fn=wrap_gradio_gpu_call(txt2img),
            inputs=[
                prompt,
                negative_prompt,
                steps,
                sampler_index,
                use_GFPGAN,
                prompt_matrix,
                batch_count,
                batch_size,
                cfg_scale,
                seed,
                height,
                width,
                code
            ],
            outputs=[
                txt2img_gallery,
                generation_info,
                html_info
            ]
        )

        prompt.submit(**txt2img_args)
        submit.click(**txt2img_args)

        interrupt.click(
            fn=lambda: state.interrupt(),
            inputs=[],
            outputs=[],
        )

        save.click(
            fn=wrap_gradio_call(save_files),
            inputs=[
                generation_info,
                txt2img_gallery,
            ],
            outputs=[
                html_info,
                html_info,
                html_info,
            ]
        )


def get_crop_region(mask, pad=0):
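    """Return a rectangle (x1, y1, x2, y2) that bounds the non-zero region of a
    2D numpy mask, expanded by pad pixels and clamped to the mask bounds."""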
    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:,i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:,i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )


def fill(image, mask):
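    """Fill the masked region with colors from the surrounding image by
    compositing progressively sharper Gaussian-blurred copies of the
    unmasked content."""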
    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.mask = None
        self.nmask = None

    def init(self):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor()
        crop_region = None

        if self.image_mask is not None:
            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = get_crop_region(np.array(mask), 64)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                self.mask_for_overlay = self.image_mask

            self.overlay_images = []

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = fill(image, self.mask_for_overlay)

                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = resize_image(2, image, self.width, self.height)

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(device)

        self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))

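        # build a latent-space mask: downscale the image mask to the latent grid
        # and tile it over the 4 latent channels; mask keeps the original latents
        # where the image is unmasked, nmask marks the region the sampler may rewrite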
        if self.image_mask is not None:
            latmask = self.image_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)

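            # masked-content modes 2 and 3 work on the latents directly:
            # 2 fills the masked area with seeded noise, 3 zeroes it out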
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [self.seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

    def sample(self, x, conditioning, unconditional_conditioning):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples


def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_name: str, upscale_overlap: int, inpaint_full_res: bool):
    is_classic = mode == 0
    is_inpaint = mode == 1
    is_loopback = mode == 2
    is_upscale = mode == 3

    if is_inpaint:
        image = init_img_with_mask['image']
        mask = init_img_with_mask['mask']
    else:
        image = init_img
        mask = None

    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=prompt,
        seed=seed,
        sampler_index=sampler_index,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN,
        init_images=[image],
        mask=mask,
        mask_blur=mask_blur,
        inpainting_fill=inpainting_fill,
        resize_mode=resize_mode,
        denoising_strength=denoising_strength,
        inpaint_full_res=inpaint_full_res,
        extra_generation_params={"Denoising Strength": denoising_strength}
    )

    if is_loopback:
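        # loopback: feed each output back in as the next init image, stepping the
        # seed and decaying denoising strength so later passes change less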
        history = []
        initial_seed = None
        initial_info = None

        for i in range(n_iter):
            p.n_iter = 1
            p.batch_size = 1
            p.do_not_save_grid = True

            state.job = f"Batch {i + 1} out of {n_iter}"
            processed = process_images(p)

            if initial_seed is None:
                initial_seed = processed.seed
                initial_info = processed.info

            p.init_images = [processed.images[0]]
            p.seed = processed.seed + 1
            p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
            history.append(processed.images[0])

        grid = image_grid(history, batch_size, rows=1)

        save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)

        processed = Processed(p, history, initial_seed, initial_info)

    elif is_upscale:
        initial_seed = None
        initial_info = None

        upscaler = sd_upscalers.get(upscaler_name, next(iter(sd_upscalers.values())))
        img = upscaler(init_img)

        torch_gc()

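        # split the upscaled image into overlapping tiles and run img2img over
        # each tile to add detail, then stitch them back together below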
        grid = split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)

        p.n_iter = 1
        p.do_not_save_grid = True
        p.do_not_save_samples = True

        work = []
        work_results = []

        for y, h, row in grid.tiles:
            for tiledata in row:
                work.append(tiledata[2])

        batch_count = math.ceil(len(work) / p.batch_size)
        print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")

        for i in range(batch_count):
            p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]

            state.job = f"Batch {i + 1} out of {batch_count}"
            processed = process_images(p)

            if initial_seed is None:
                initial_seed = processed.seed
                initial_info = processed.info

            p.seed = processed.seed + 1
            work_results += processed.images

        image_index = 0
        for y, h, row in grid.tiles:
            for tiledata in row:
                tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
                image_index += 1

        combined_image = combine_grid(grid)

        save_image(combined_image, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)

        processed = Processed(p, [combined_image], initial_seed, initial_info)

    else:
        processed = process_images(p)

    return processed.images, processed.js(), plaintext_to_html(processed.info)


sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None


with gr.Blocks(analytics_enabled=False) as img2img_interface:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1)
        submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary')

    with gr.Row().style(equal_height=False):

        with gr.Column(variant='panel'):
            with gr.Group():
                switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'Loopback', 'SD upscale'], value='Redraw whole image', type="index", show_label=False)
                init_img = gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil")
                init_img_with_mask = gr.Image(label="Image for inpainting with mask", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", visible=False)
                resize_mode = gr.Radio(label="Resize mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")

            steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
            sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
            mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
            inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)

            with gr.Row():
                use_GFPGAN = gr.Checkbox(label='GFPGAN', value=False, visible=have_gfpgan)
                prompt_matrix = gr.Checkbox(label='Prompt matrix', value=False)
                inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=True, visible=False)

            with gr.Row():
                sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(sd_upscalers.keys()), value=list(sd_upscalers.keys())[0], visible=False)
                sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)

            with gr.Row():
                batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

            with gr.Group():
                cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0)
                denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75)

            with gr.Group():
                height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)

            seed = gr.Number(label='Seed', value=-1)

        with gr.Column(variant='panel'):
            with gr.Group():
                img2img_gallery = gr.Gallery(label='Output')

            with gr.Group():
                with gr.Row():
                    interrupt = gr.Button('Interrupt')
                    save = gr.Button('Save')
                    img2img_send_to_extras = gr.Button('Send to extras')

            with gr.Group():
                html_info = gr.HTML()
                generation_info = gr.Textbox(visible=False)

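        # show or hide img2img controls to match the selected mode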
        def apply_mode(mode):
            is_classic = mode == 0
            is_inpaint = mode == 1
            is_loopback = mode == 2
            is_upscale = mode == 3

            return {
                init_img: gr_show(not is_inpaint),
                init_img_with_mask: gr_show(is_inpaint),
                mask_blur: gr_show(is_inpaint),
                inpainting_fill: gr_show(is_inpaint),
                prompt_matrix: gr_show(is_classic),
                batch_count: gr_show(not is_upscale),
                batch_size: gr_show(not is_loopback),
                sd_upscale_upscaler_name: gr_show(is_upscale),
                sd_upscale_overlap: gr_show(is_upscale),
                inpaint_full_res: gr_show(is_inpaint),
            }

        switch_mode.change(
            apply_mode,
            inputs=[switch_mode],
            outputs=[
                init_img,
                init_img_with_mask,
                mask_blur,
                inpainting_fill,
                prompt_matrix,
                batch_count,
                batch_size,
                sd_upscale_upscaler_name,
                sd_upscale_overlap,
                inpaint_full_res,
            ]
        )

        img2img_args = dict(
            fn=wrap_gradio_gpu_call(img2img),
            inputs=[
                prompt,
                init_img,
                init_img_with_mask,
                steps,
                sampler_index,
                mask_blur,
                inpainting_fill,
                use_GFPGAN,
                prompt_matrix,
                switch_mode,
                batch_count,
                batch_size,
                cfg_scale,
                denoising_strength,
                seed,
                height,
                width,
                resize_mode,
                sd_upscale_upscaler_name,
                sd_upscale_overlap,
                inpaint_full_res,
            ],
            outputs=[
                img2img_gallery,
                generation_info,
                html_info
            ]
        )

        prompt.submit(**img2img_args)
        submit.click(**img2img_args)

        interrupt.click(
            fn=lambda: state.interrupt(),
            inputs=[],
            outputs=[],
        )

        save.click(
            fn=wrap_gradio_call(save_files),
            inputs=[
                generation_info,
                img2img_gallery,
            ],
            outputs=[
                html_info,
                html_info,
                html_info,
            ]
        )

        send_to_img2img.click(
            fn=send_gradio_gallery_to_image,
            inputs=[txt2img_gallery],
            outputs=[init_img],
        )

        send_to_inpaint.click(
            fn=send_gradio_gallery_to_image,
            inputs=[txt2img_gallery],
            outputs=[init_img_with_mask],
        )

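# upscale a PIL image with the selected pretrained Real-ESRGAN model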
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
    info = realesrgan_models[RealESRGAN_model_index]

    model = info.model()
    upsampler = RealESRGANer(
        scale=info.netscale,
        model_path=info.location,
        model=model,
        half=True
    )

    upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]

    image = Image.fromarray(upsampled)
    return image


def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
    torch_gc()

    image = image.convert("RGB")

    outpath = opts.outdir_samples or opts.outdir_extras_samples

    if have_gfpgan and GFPGAN_strength > 0:
        gfpgan_model = gfpgan()

        restored_img = gfpgan_fix_faces(gfpgan_model, np.array(image, dtype=np.uint8))
        res = Image.fromarray(restored_img)

        if GFPGAN_strength < 1.0:
            res = Image.blend(image, res, GFPGAN_strength)

        image = res

    if have_realesrgan and RealESRGAN_upscaling != 1.0:
        image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)

    save_image(image, outpath, "", None, '', opts.samples_format, short_filename=True, no_prompt=True)

    return image, '', ''


with gr.Blocks(analytics_enabled=False) as extras_interface:
    with gr.Row().style(equal_height=False):
        with gr.Column(variant='panel'):
            with gr.Group():
                image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
                gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan)
                realesrgan_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan)
                realesrgan_model = gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan)

            submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')

        with gr.Column(variant='panel'):
            result_image = gr.Image(label="Result")
            html_info_x = gr.HTML()
            html_info = gr.HTML()

    extras_args = dict(
        fn=wrap_gradio_gpu_call(run_extras),
        inputs=[
            image,
            gfpgan_strength,
            realesrgan_resize,
            realesrgan_model,
        ],
        outputs=[
            result_image,
            html_info_x,
            html_info,
        ]
    )

    submit.click(**extras_args)

    send_to_extras.click(
        fn=send_gradio_gallery_to_image,
        inputs=[txt2img_gallery],
        outputs=[image],
    )

    img2img_send_to_extras.click(
        fn=send_gradio_gallery_to_image,
        inputs=[img2img_gallery],
        outputs=[image],
    )


def run_pnginfo(image):
    info = ''
    for key, text in image.info.items():
        info += f"""
<div>
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip()+"\n"

    if len(info) == 0:
        message = "Nothing found in the image."
        info = f"<div><p>{message}</p></div>"

    return '', '', info


pnginfo_interface = gr.Interface(
    wrap_gradio_call(run_pnginfo),
    inputs=[
        gr.Image(label="Source", source="upload", interactive=True, type="pil"),
    ],
    outputs=[
        gr.HTML(),
        gr.HTML(),
        gr.HTML(),
    ],
    allow_flagging="never",
    analytics_enabled=False,
)


opts = Options()
if os.path.exists(config_filename):
    opts.load(config_filename)


def run_settings(*args):
    # values arrive in the same order as the settings components were created
    for key, value in zip(opts.data_labels.keys(), args):
        opts.data[key] = value

    opts.save(config_filename)

    return 'Settings saved.', '', ''


def create_setting_component(key):
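    """Create a gradio input for one settings entry, using the option's explicit
    component if given and otherwise picking a widget from its default's type."""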
    def fun():
        return opts.data[key] if key in opts.data else opts.data_labels[key].default

    info = opts.data_labels[key]
    t = type(info.default)

    if info.component is not None:
        item = info.component(label=info.label, value=fun, **(info.component_args or {}))
    elif t == str:
        item = gr.Textbox(label=info.label, value=fun, lines=1)
    elif t == int:
        item = gr.Number(label=info.label, value=fun)
    elif t == bool:
        item = gr.Checkbox(label=info.label, value=fun)
    else:
        raise Exception(f'bad options item type: {str(t)} for key {key}')

    return item


settings_interface = gr.Interface(
    run_settings,
    inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
    outputs=[
        gr.Textbox(label='Result'),
        gr.HTML(),
        gr.HTML(),
    ],
    title=None,
    description=None,
    allow_flagging="never",
    analytics_enabled=False,
)

interfaces = [
    (txt2img_interface, "txt2img"),
    (img2img_interface, "img2img"),
    (extras_interface, "Extras"),
    (pnginfo_interface, "PNG Info"),
    (settings_interface, "Settings"),
]

try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.

    from transformers import logging

    logging.set_verbosity_error()
except Exception:
    pass

sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())

if cmd_opts.lowvram or cmd_opts.medvram:
    setup_for_low_vram(sd_model)
else:
    sd_model = sd_model.to(device)

model_hijack = StableDiffusionModelHijack()
model_hijack.hijack(sd_model)

with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
    css = file.read()

if not cmd_opts.no_progressbar_hiding:
    css += css_hide_progressbar

with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as file:
    javascript = file.read()


# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(signal, frame):
    print('Interrupted')
    os._exit(0)


signal.signal(signal.SIGINT, sigint_handler)

demo = gr.TabbedInterface(
    interface_list=[x[0] for x in interfaces],
    tab_names=[x[1] for x in interfaces],
    analytics_enabled=False,
    css=css,
)


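# wrap gradio's TemplateResponse so script.js is injected into each page's <head>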
def inject_gradio_html(javascript):
    import gradio.routes

    def template_response(*args, **kwargs):
        res = gradio_routes_templates_response(*args, **kwargs)
        res.body = res.body.replace(b'</head>', f'<script>{javascript}</script></head>'.encode("utf8"))
        res.init_headers()
        return res

    gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
    gradio.routes.templates.TemplateResponse = template_response


inject_gradio_html(javascript)

demo.launch(share=cmd_opts.share)