import argparse
import os
import sys

script_path = os.path.dirname(os.path.realpath(__file__))
sd_path = os.path.dirname(script_path)

# add parent directory to path; this is where the Stable Diffusion repo should be
path_dirs = [(sd_path, 'ldm', 'Stable Diffusion'), ('../../taming-transformers', 'taming', 'Taming Transformers')]
for d, must_exist, what in path_dirs:
    must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
    if not os.path.exists(must_exist_path):
        print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
    else:
        sys.path.append(os.path.join(script_path, d))

import torch
import torch.nn as nn
import numpy as np
import gradio as gr
import gradio.utils
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin, ImageFilter, ImageOps
from torch import autocast
import mimetypes
import random
import math
import html
import time
import json
import traceback
from collections import namedtuple
from contextlib import nullcontext
import signal

import k_diffusion.sampling
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

# fix gradio phoning home
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

# this is a fix for Windows users. Without it, JavaScript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8

LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
invalid_filename_chars = '<>:"/\\|?*\n'
config_filename = "config.json"

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, "models/ldm/stable-diffusion-v1/model.ckpt"), help="path to checkpoint of model",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for low vram")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")

cmd_opts = parser.parse_args()
cpu = torch.device("cpu")
gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu

css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
    *[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
        ('Euler a', 'sample_euler_ancestral'),
        ('Euler', 'sample_euler'),
        ('LMS', 'sample_lms'),
        ('Heun', 'sample_heun'),
        ('DPM2', 'sample_dpm_2'),
        ('DPM2 a', 'sample_dpm_2_ancestral'),
    ] if hasattr(k_diffusion.sampling, x[1])],
    SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
    SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']

RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])

try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    from realesrgan.archs.srvgg_arch import SRVGGNetCompact

    realesrgan_models = [
        RealesrganModelInfo(
            name="Real-ESRGAN 4x plus",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 4x plus anime 6B",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
            netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 2x plus",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        ),
    ]
    have_realesrgan = True
except Exception:
    print("Error importing Real-ESRGAN:", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)

    realesrgan_models = [RealesrganModelInfo(name='None', location='', model=None, netscale=0)]
    have_realesrgan = False

sd_upscalers = {
    "RealESRGAN": lambda img: upscale_with_realesrgan(img, 2, 0),
    "Lanczos": lambda img: img.resize((img.width*2, img.height*2), resample=LANCZOS),
    "None": lambda img: img
}


def gfpgan_model_path():
    places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
    files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
    found = [x for x in files if os.path.exists(x)]

    if len(found) == 0:
        raise Exception("GFPGAN model not found in paths: " + ", ".join(files))
    return found[0]


def gfpgan():
    return GFPGANer(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)


have_gfpgan = False
try:
    model_path = gfpgan_model_path()

    if os.path.exists(cmd_opts.gfpgan_dir):
        sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
    from gfpgan import GFPGANer

    have_gfpgan = True
except Exception:
    print("Error setting up GFPGAN:", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)


class Options:
    class OptionInfo:
        def __init__(self, default=None, label="", component=None, component_args=None):
            self.default = default
            self.label = label
            self.component = component
            self.component_args = component_args

    data = None
    data_labels = {
        "outdir": OptionInfo("", "Output directory; if empty, defaults to 'outputs/*'"),
        "samples_save": OptionInfo(True, "Save individual samples"),
        "samples_format": OptionInfo('png', 'File format for individual samples'),
        "grid_save": OptionInfo(True, "Save image grids"),
        "return_grid": OptionInfo(True, "Show grid in results for web"),
        "grid_format": OptionInfo('png', 'File format for grids'),
        "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
        "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
        "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
        "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
        "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
        "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
        "font": OptionInfo("arial.ttf", "Font for image grids that have text"),
        "prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
    }

    def __init__(self):
        self.data = {k: v.default for k, v in self.data_labels.items()}

    def __setattr__(self, key, value):
        if self.data is not None:
            if key in self.data:
                self.data[key] = value

        return super(Options, self).__setattr__(key, value)

    def __getattr__(self, item):
        if self.data is not None:
            if item in self.data:
                return self.data[item]

        if item in self.data_labels:
            return self.data_labels[item].default

        return super(Options, self).__getattribute__(item)

    def save(self, filename):
        with open(filename, "w", encoding="utf8") as file:
            json.dump(self.data, file)

    def load(self, filename):
        with open(filename, "r", encoding="utf8") as file:
            self.data = json.load(file)
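
    # usage sketch (assuming an Options() instance named opts, as used throughout this file):
    #   opts.save(config_filename) writes the settings dict as JSON, opts.load(config_filename)
    #   restores it, and attribute access like opts.jpeg_quality reads through to the dict
    #   via __getattr__ above.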


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.eval()
    return model


module_in_gpu = None


def setup_for_low_vram(sd_model):
    parents = {}

    def send_me_to_gpu(module, _):
        """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
        we add this as forward_pre_hook to a lot of modules and this way all but one of them will
        be in CPU
        """
        global module_in_gpu

        module = parents.get(module, module)

        if module_in_gpu == module:
            return

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)

        module.to(gpu)
        module_in_gpu = module

    # see below for register_forward_pre_hook;
    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
    # useless here, and we just replace those methods
    def first_stage_model_encode_wrap(self, encoder, x):
        send_me_to_gpu(self, None)
        return encoder(x)

    def first_stage_model_decode_wrap(self, decoder, z):
        send_me_to_gpu(self, None)
        return decoder(z)

    # remove three big modules, cond, first_stage, and unet from the model and then
    # send the model to GPU. Then put the modules back; they will remain on the CPU.
    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
    sd_model.to(device)
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored

    # register hooks for the first two models
    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
    sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    # the third remaining model is still too big for 4GB, so we also do the same for its submodules
    # so that only one of them is in GPU at a time
    diff_model = sd_model.model.diffusion_model
    stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
    diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
    sd_model.model.to(device)
    diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

    # install hooks for bits of third model
    diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
    for block in diff_model.input_blocks:
        block.register_forward_pre_hook(send_me_to_gpu)
    diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
    for block in diff_model.output_blocks:
        block.register_forward_pre_hook(send_me_to_gpu)
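
# note: torch calls a forward-pre-hook as hook(module, input) just before forward(), which
# is why send_me_to_gpu takes and ignores a second argument. The same trick works for any
# nn.Module you want kept on the CPU until it is actually used, e.g. (hypothetical):
#   some_module.register_forward_pre_hook(send_me_to_gpu)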


def create_random_tensors(shape, seeds):
    xs = []
    for seed in seeds:
        torch.manual_seed(seed)

        # randn results depend on device; gpu and cpu get different results for the same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets the same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        xs.append(torch.randn(shape, device=device))
    x = torch.stack(xs)
    return x
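
# example (hypothetical call): create_random_tensors([4, 64, 64], seeds=[42, 42]) returns a
# (2, 4, 64, 64) tensor whose two batch items hold identical noise, because torch.manual_seed
# is re-applied before each torch.randn call above.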


def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False):
    if short_filename or prompt is None or seed is None:
        filename = f"{basename}"
    else:
        filename = f"{basename}-{seed}-{sanitize_filename_part(prompt)[:128]}"

    if extension == 'png' and opts.enable_pnginfo and info is not None:
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info)
    else:
        pnginfo = None

    os.makedirs(path, exist_ok=True)
    fullfn = os.path.join(path, f"{filename}.{extension}")
    image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
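
    # downscale oversized images to a 4chan-compatible JPG copy; e.g. a 5000x3000 PNG
    # (ratio > 1) additionally gets a 4000x2400 JPG saved next to the original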

    target_side_length = 4000
    oversize = image.width > target_side_length or image.height > target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
        ratio = image.width / image.height

        if oversize and ratio > 1:
            image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
        elif oversize:
            image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)

        image.save(os.path.join(path, f"{filename}.jpg"), quality=opts.jpeg_quality, pnginfo=pnginfo)


def sanitize_filename_part(text):
    return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]


def plaintext_to_html(text):
    text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
    return text

def image_grid(imgs, batch_size=1, rows=None):
    if rows is None:
        if opts.n_rows > 0:
            rows = opts.n_rows
        elif opts.n_rows == 0:
            rows = batch_size
        else:
            rows = math.sqrt(len(imgs))
            rows = round(rows)

    cols = math.ceil(len(imgs) / rows)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid


Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    w = image.width
    h = image.height

    now = tile_w - overlap  # non-overlap width
    noh = tile_h - overlap

    cols = math.ceil((w - overlap) / now)
    rows = math.ceil((h - overlap) / noh)

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = row * noh

        if y + tile_h >= h:
            y = h - tile_h

        for col in range(cols):
            x = col * now

            if x+tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid
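
# worked example: a 1024x768 image with the default tile_w=512, tile_h=512, overlap=64 has a
# non-overlap stride of 448, so cols = ceil((1024-64)/448) = 3 and rows = ceil((768-64)/448) = 2;
# tiles that would run past the edge are shifted back so they stay inside the image.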


def combine_grid(grid):
    def make_mask_image(r):
        r = r * 255 / grid.overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))

    return combined_image
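
# the overlap strips are blended using linear gradient masks (mask_w ramps horizontally,
# mask_h vertically), so adjacent tiles cross-fade into each other instead of leaving seams.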


class GridAnnotation:
    def __init__(self, text='', is_active=True):
        self.text = text
        self.is_active = is_active
        self.size = None


def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
    def wrap(drawing, text, font, line_length):
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines
    def draw_texts(drawing, draw_x, draw_y, lines):
        for i, line in enumerate(lines):
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
            if not line.is_active:
                drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
            draw_y += line.size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2
    fnt = ImageFont.truetype(opts.font, fontsize)
    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    pad_left = width * 3 // 4 if len(ver_texts) > 1 else 0

    cols = im.width // width
    rows = im.height // height

    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    calc_img = Image.new("RGB", (1, 1), "white")
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
        items = [] + texts
        texts.clear()

        for line in items:
            wrapped = wrap(calc_d, line.text, fnt, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]

        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])

    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]

    pad_top = max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))

    d = ImageDraw.Draw(result)

    for col in range(cols):
        x = pad_left + width * col + width / 2
        y = pad_top / 2 - hor_text_heights[col] / 2
        draw_texts(d, x, y, hor_texts[col])

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
        draw_texts(d, x, y, ver_texts[row])

    return result


def draw_prompt_matrix(im, width, height, all_prompts):
    prompts = all_prompts[1:]
    boundary = math.ceil(len(prompts) / 2)

    prompts_horiz = prompts[:boundary]
    prompts_vert = prompts[boundary:]

    hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
    ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]

    return draw_grid_annotations(im, width, height, hor_texts, ver_texts)


def draw_xy_grid(xs, ys, x_label, y_label, cell):
    res = []

    ver_texts = [[GridAnnotation(y_label(y))] for y in ys]
    hor_texts = [[GridAnnotation(x_label(x))] for x in xs]

    for y in ys:
        for x in xs:
            res.append(cell(x, y))


    grid = image_grid(res, rows=len(ys))
    grid = draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)

    return grid
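
# resize_mode 0 stretches to the exact target size; mode 1 scales to cover the target and
# center-crops the excess; mode 2 scales to fit inside and fills the leftover space by
# stretching the image's edge rows/columns outward.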


def resize_image(resize_mode, im, width, height):
    if resize_mode == 0:
        res = im.resize((width, height), resample=LANCZOS)
    elif resize_mode == 1:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = im.resize((src_w, src_h), resample=LANCZOS)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
    else:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

        resized = im.resize((src_w, src_h), resample=LANCZOS)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            fill_height = height // 2 - src_h // 2
            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            fill_width = width // 2 - src_w // 2
            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res


def wrap_gradio_call(func):
    def f(*p1, **p2):
        t = time.perf_counter()
        res = list(func(*p1, **p2))
        elapsed = time.perf_counter() - t

        # last item is always HTML
        res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"

        return tuple(res)

    return f


class StableDiffusionModelHijack:
    ids_lookup = {}
    word_embeddings = {}
    word_embeddings_checksums = {}
    fixes = None
    comments = None
    dir_mtime = None

    def load_textual_inversion_embeddings(self, dirname, model):
        mt = os.path.getmtime(dirname)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()

        tokenizer = model.cond_stage_model.tokenizer

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]

            data = torch.load(path)
            param_dict = data['string_to_param']
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1].reshape(768)
            self.word_embeddings[name] = emb
            self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'

            ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
            first_id = ids[0]
            if first_id not in self.ids_lookup:
                self.ids_lookup[first_id] = []
            self.ids_lookup[first_id].append((ids, name))

        for fn in os.listdir(dirname):
            try:
                process_file(os.path.join(dirname, fn), fn)
            except Exception:
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")

    def hijack(self, m):
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings

        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        self.hijack = hijack
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length
        self.token_mults = {}

        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult
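
        # attention syntax example: a vocab token containing "(" gets mult 1.1 and one
        # containing "[" gets 1/1.1; characters stack multiplicatively, so a token with
        # "((" maps to 1.1 * 1.1 = 1.21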

    def forward(self, text):
        self.hijack.fixes = []
        self.hijack.comments = []
        remade_batch_tokens = []
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length - 2
        used_custom_terms = []

        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
                multipliers = []
                mult = 1.0

                i = 0
                while i < len(tokens):
                    token = tokens[i]

                    possible_matches = self.hijack.ids_lookup.get(token, None)
                    mult_change = self.token_mults.get(token)
                    if mult_change is not None:
                        mult *= mult_change
                    elif possible_matches is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                    else:
                        found = False
                        for ids, word in possible_matches:
                            if tokens[i:i+len(ids)] == ids:
                                fixes.append((len(remade_tokens), word))
                                remade_tokens.append(777)
                                multipliers.append(mult)
                                i += len(ids) - 1
                                found = True
                                used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
                                break

                        if not found:
                            remade_tokens.append(token)
                            multipliers.append(mult)

                    i += 1

                if len(remade_tokens) > maxlen - 2:
                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                    ovf = remade_tokens[maxlen - 2:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))

                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            self.hijack.fixes.append(fixes)
            batch_multipliers.append(multipliers)
        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z


class EmbeddingsWithFixes(nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is not None:
            for fixes, tensor in zip(batch_fixes, inputs_embeds):
                for offset, word in fixes:
                    tensor[offset] = self.embeddings.word_embeddings[word]
        return inputs_embeds
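
    # textual inversion injection: the hijacked tokenizer emitted placeholder token 777 for
    # each custom word, and fixes holds (offset, word) pairs per batch item; here the matching
    # rows of the embedding output are overwritten with the learned vectors.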


class StableDiffusionProcessing:
    def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None):
        self.outpath: str = outpath
        self.prompt: str = prompt
        self.seed: int = seed
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.prompt_matrix: bool = prompt_matrix
        self.use_GFPGAN: bool = use_GFPGAN
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images

    def init(self):
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        raise NotImplementedError()


def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
    if sampler_wrapper.mask is not None:
        img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
        x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec

    return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)


class VanillaStableDiffusionSampler:
    def __init__(self, constructor):
        self.sampler = constructor(sd_model)
        self.orig_p_sample_ddim = self.sampler.p_sample_ddim
        self.mask = None
        self.nmask = None
        self.init_latent = None

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)

        # existing code fails with certain step counts, like 9
        try:
            self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
        except Exception:
            self.sampler.make_schedule(ddim_num_steps=p.steps+1, verbose=False)

        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(device), noise=noise)

        self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
        self.mask = p.mask
        self.nmask = p.nmask
        self.init_latent = p.init_latent

        samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)

        return samples


    def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
        samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
        return samples_ddim


class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond + (cond - uncond) * cond_scale
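
    # classifier-free guidance: the latent is denoised twice in a single batch (unconditional
    # first, then the real prompt) and the two predictions are extrapolated, so cond_scale > 1
    # pushes the result further toward the prompt.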


class KDiffusionSampler:
    def __init__(self, funcname):
        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
        self.funcname = funcname
        self.func = getattr(k_diffusion.sampling, self.funcname)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
        sigmas = self.model_wrap.get_sigmas(p.steps)
        noise = noise * sigmas[p.steps - t_enc - 1]

        xi = x + noise

        if p.mask is not None:
            if p.inpainting_fill == 2:
                xi = xi * p.mask + noise * p.nmask
            elif p.inpainting_fill == 3:
                xi = xi * p.mask

        sigma_sched = sigmas[p.steps - t_enc - 1:]

        def mask_cb(v):
            v["denoised"][:] = v["denoised"][:] * p.nmask + p.init_latent * p.mask

        return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=mask_cb if p.mask is not None else None)
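
    # img2img runs only the tail of the noise schedule: e.g. with denoising_strength=0.75 and
    # steps=50, t_enc = 37, so the 12 highest sigmas are skipped and sampling resumes from the
    # partially noised init latent.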


    def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
        sigmas = self.model_wrap.get_sigmas(p.steps)
        x = x * sigmas[0]

        samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
        return samples_ddim


Processed = namedtuple('Processed', ['images', 'seed', 'info'])


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
    prompt = p.prompt
    model = sd_model

    assert p.prompt is not None
    torch_gc()
    seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
    sample_path = os.path.join(p.outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(p.outpath)) - 1
    comments = []

    prompt_matrix_parts = []
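
    # prompt matrix: parts are separated by "|"; e.g. "a castle|at night|oil painting" expands
    # to 4 prompts, one for every subset of the optional parts joined onto the first part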
    if p.prompt_matrix:
        all_prompts = []
        prompt_matrix_parts = prompt.split("|")
        combination_count = 2 ** (len(prompt_matrix_parts) - 1)
        for combination_num in range(combination_count):
            selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
            if opts.prompt_matrix_add_to_start:
                selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
            else:
                selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
            all_prompts.append(", ".join(selected_prompts))
        p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
        all_seeds = len(all_prompts) * [seed]

        print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
    else:
        all_prompts = p.batch_size * p.n_iter * [prompt]
        all_seeds = [seed + x for x in range(len(all_prompts))]
    generation_params = {
        "Steps": p.steps,
        "Sampler": samplers[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": seed,
        "GFPGAN": ("GFPGAN" if p.use_GFPGAN else None)
    }

    if p.extra_generation_params is not None:
        generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

    def infotext():
        return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)

    output_images = []
    precision_scope = autocast if cmd_opts.precision == "autocast" else nullcontext
    ema_scope = (nullcontext if cmd_opts.lowvram else model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init()
        for n in range(p.n_iter):
            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            uc = model.get_learned_conditioning(len(prompts) * [""])
            c = model.get_learned_conditioning(prompts)

            if len(model_hijack.comments) > 0:
                comments += model_hijack.comments
            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
            if p.prompt_matrix or opts.samples_save or opts.grid_save:
                for i, x_sample in enumerate(x_samples_ddim):
                    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                    x_sample = x_sample.astype(np.uint8)

                    if p.use_GFPGAN:
                        torch_gc()

                        gfpgan_model = gfpgan()
                        cropped_faces, restored_faces, restored_img = gfpgan_model.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
                        x_sample = restored_img

                    image = Image.fromarray(x_sample)
                    if p.overlay_images is not None and i < len(p.overlay_images):
                        image = image.convert('RGBA')
                        image.alpha_composite(p.overlay_images[i])
                        image = image.convert('RGB')

                    if not p.do_not_save_samples:
                        save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())

                    output_images.append(image)
                    base_count += 1
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            return_grid = opts.return_grid

            if p.prompt_matrix:
                grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))

                try:
                    grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
                except Exception:
                    print("Error creating prompt_matrix text:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

                return_grid = True
            else:
                grid = image_grid(output_images, p.batch_size)
            if return_grid:
                output_images.insert(0, grid)

            save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
            grid_count += 1

    torch_gc()
    return Processed(output_images, seed, infotext())


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None
    def init(self):
        self.sampler = samplers[self.sampler_index].constructor()
    def sample(self, x, conditioning, unconditional_conditioning):
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim

def txt2img(prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, code: str):
    outpath = opts.outdir or "outputs/txt2img-samples"

    p = StableDiffusionProcessingTxt2Img(
        outpath=outpath,
        prompt=prompt,
        seed=seed,
        sampler_index=sampler_index,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN
    )

    if code != '' and cmd_opts.allow_code:
        p.do_not_save_grid = True
        p.do_not_save_samples = True

        display_result_data = [[], -1, ""]
        def display(imgs, s=display_result_data[1], i=display_result_data[2]):
            display_result_data[0] = imgs
            display_result_data[1] = s
            display_result_data[2] = i

        from types import ModuleType
        compiled = compile(code, '', 'exec')
        module = ModuleType("testmodule")
        module.__dict__.update(globals())
        module.p = p
        module.display = display
        exec(compiled, module.__dict__)

        processed = Processed(*display_result_data)
    else:
        processed = process_images(p)
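
    # note: --allow-code executes arbitrary Python typed into the UI; the script sees this
    # module's globals plus p and display(), so it is strictly a trusted, local-use feature.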
    return processed.images, processed.seed, plaintext_to_html(processed.info)


class Flagging(gr.FlaggingCallback):

    def setup(self, components, flagging_dir: str):
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        import csv

        os.makedirs("log/images", exist_ok=True)

        # these must match the "txt2img" function
        prompt, steps, sampler_index, use_gfpgan, prompt_matrix, n_iter, batch_size, cfg_scale, seed, height, width, code, images, seed, comment = flag_data

        filenames = []

        with open("log/log.csv", "a", encoding="utf8", newline='') as file:
            import time
            import base64

            at_start = file.tell() == 0
            writer = csv.writer(file)
            if at_start:
                writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])

            filename_base = str(int(time.time() * 1000))
            for i, filedata in enumerate(images):
                filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"

                if filedata.startswith("data:image/png;base64,"):
                    filedata = filedata[len("data:image/png;base64,"):]

                with open(filename, "wb") as imgfile:
                    imgfile.write(base64.decodebytes(filedata.encode('utf-8')))

                filenames.append(filename)

            writer.writerow([prompt, seed, width, height, cfg_scale, steps, filenames[0]])

        print("Logged:", filenames[0])

with gr.Blocks(analytics_enabled=False) as txt2img_interface:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1)
        submit = gr.Button('Generate', variant='primary')

    with gr.Row().style(equal_height=False):
        with gr.Column(variant='panel'):
            steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
            sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")

            with gr.Row():
                use_GFPGAN = gr.Checkbox(label='GFPGAN', value=False, visible=have_gfpgan)
                prompt_matrix = gr.Checkbox(label='Prompt matrix', value=False)

            with gr.Row():
                batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

            cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0)

            with gr.Group():
                height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)

            seed = gr.Number(label='Seed', value=-1)

            code = gr.Textbox(label="Python script", visible=cmd_opts.allow_code, lines=1)

        with gr.Column(variant='panel'):
            with gr.Group():
                gallery = gr.Gallery(label='Output')
                output_seed = gr.Number(label='Seed', visible=False)
                html_info = gr.HTML()

        txt2img_args = dict(
            fn=wrap_gradio_call(txt2img),
            inputs=[
                prompt,
                steps,
                sampler_index,
                use_GFPGAN,
                prompt_matrix,
                batch_count,
                batch_size,
                cfg_scale,
                seed,
                height,
                width,
                code
            ],
            outputs=[
                gallery,
                output_seed,
                html_info
            ]
        )

        prompt.submit(**txt2img_args)
        submit.click(**txt2img_args)



def fill(image, mask):
    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    image_masked = image_masked.convert('RGBa')

    for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")

class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.original_mask = mask
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.mask = None
        self.nmask = None

    def init(self):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor()

        if self.original_mask is not None:
            self.original_mask = resize_image(self.resize_mode, self.original_mask, self.width, self.height)
            self.overlay_images = []

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")
            image = resize_image(self.resize_mode, image, self.width, self.height)

            if self.original_mask is not None:
                if self.inpainting_fill != 1:
                    image = fill(image, self.original_mask)

                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.original_mask.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(device)

        self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))

        if self.original_mask is not None:
            if self.mask_blur > 0:
                self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')

            latmask = self.original_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
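            # keep one channel and tile it across all 4 latent channels;
            # self.mask weights the original latents, self.nmask the new ones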
            latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)

    def sample(self, x, conditioning, unconditional_conditioning):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
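        # inpainting: where self.mask is 1 keep the original latents, where
        # self.nmask is 1 take the freshly sampled ones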

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples


def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_name: str, upscale_overlap: int):
    outpath = opts.outdir or "outputs/img2img-samples"

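    # the mode radio in the UI maps to: 0 = redraw whole image, 1 = inpaint,
    # 2 = loopback, 3 = SD upscale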
    is_classic = mode == 0
    is_inpaint = mode == 1
    is_loopback = mode == 2
    is_upscale = mode == 3

    if is_inpaint:
        image = init_img_with_mask['image']
        mask = init_img_with_mask['mask']
    else:
        image = init_img
        mask = None

    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        outpath=outpath,
        prompt=prompt,
        seed=seed,
        sampler_index=sampler_index,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN,
        init_images=[image],
        mask=mask,
        mask_blur=mask_blur,
        inpainting_fill=inpainting_fill,
        resize_mode=resize_mode,
        denoising_strength=denoising_strength,
        extra_generation_params={"Denoising Strength": denoising_strength}
    )

    if is_loopback:
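        # loopback: run n_iter passes, feeding each result back in as the
        # next init image with a fresh seed and a decaying denoising strength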
        output_images, info = None, None
        history = []
        initial_seed = None
        initial_info = None

        for i in range(n_iter):
            p.n_iter = 1
            p.batch_size = 1
            p.do_not_save_grid = True

            processed = process_images(p)

            if initial_seed is None:
                initial_seed = processed.seed
                initial_info = processed.info

            p.init_images = [processed.images[0]]
            p.seed = processed.seed + 1
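            # decay denoising strength so each pass changes less, floored at 0.1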
            p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
            history.append(processed.images[0])

        grid_count = len(os.listdir(outpath)) - 1
        grid = image_grid(history, batch_size, rows=1)

        save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)

        processed = Processed(history, initial_seed, initial_info)

    elif is_upscale:
        initial_seed = None
        initial_info = None

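        # SD upscale: upscale the whole image first, then run img2img over
        # overlapping tiles to add detail, and recombine them below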
        upscaler = sd_upscalers[upscaler_name]
        img = upscaler(init_img)

        torch_gc()

        grid = split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)

        p.n_iter = 1
        p.do_not_save_grid = True
        p.do_not_save_samples = True

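        # flatten all tiles into one work queue so they can be run in batches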
        work = []
        work_results = []

        for y, h, row in grid.tiles:
            for tiledata in row:
                work.append(tiledata[2])

        batch_count = math.ceil(len(work) / p.batch_size)
        print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")

        for i in range(batch_count):
            p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]

            processed = process_images(p)

            if initial_seed is None:
                initial_seed = processed.seed
                initial_info = processed.info

            p.seed = processed.seed + 1
            work_results += processed.images

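        # write the processed tiles back into the grid in queue order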
        image_index = 0
        for y, h, row in grid.tiles:
            for tiledata in row:
                tiledata[2] = work_results[image_index]
                image_index += 1

        combined_image = combine_grid(grid)

        grid_count = len(os.listdir(outpath)) - 1
        save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)

        processed = Processed([combined_image], initial_seed, initial_info)

    else:
        processed = process_images(p)

    return processed.images, processed.seed, plaintext_to_html(processed.info)


sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None


with gr.Blocks(analytics_enabled=False) as img2img_interface:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1)
        submit = gr.Button('Generate', variant='primary')

    with gr.Row().style(equal_height=False):

        with gr.Column(variant='panel'):
            with gr.Group():
                switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'Loopback', 'SD upscale'], value='Redraw whole image', type="index", show_label=False)
                init_img = gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil")
                init_img_with_mask = gr.Image(label="Image for inpainting with mask", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", visible=False)
                resize_mode = gr.Radio(label="Resize mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")

            steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
            sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
            mask_blur = gr.Slider(label='Inpainting: mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
            inpainting_fill = gr.Radio(label='Inpainting: masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)

            with gr.Row():
                use_GFPGAN = gr.Checkbox(label='GFPGAN', value=False, visible=have_gfpgan)
                prompt_matrix = gr.Checkbox(label='Prompt matrix', value=False)

            with gr.Row():
                sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(sd_upscalers.keys()), value="RealESRGAN")
                sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)

            with gr.Row():
                batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

            with gr.Group():
                cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0)
                denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75)

            with gr.Group():
                height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)

            seed = gr.Number(label='Seed', value=-1)

        with gr.Column(variant='panel'):
            with gr.Group():
                gallery = gr.Gallery(label='Output')
                output_seed = gr.Number(label='Seed', visible=False)
                html_info = gr.HTML()

        def apply_mode(mode):
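            # show or hide the controls that only apply to the selected mode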
            is_classic = mode == 0
            is_inpaint = mode == 1
            is_loopback = mode == 2
            is_upscale = mode == 3

            return {
                init_img: gr.update(visible=not is_inpaint),
                init_img_with_mask: gr.update(visible=is_inpaint),
                mask_blur: gr.update(visible=is_inpaint),
                inpainting_fill: gr.update(visible=is_inpaint),
                prompt_matrix: gr.update(visible=is_classic),
                batch_count: gr.update(visible=not is_upscale),
                batch_size: gr.update(visible=not is_loopback),
                sd_upscale_upscaler_name: gr.update(visible=is_upscale),
                sd_upscale_overlap: gr.update(visible=is_upscale),
            }

        switch_mode.change(
            apply_mode,
            inputs=[switch_mode],
            outputs=[
                init_img,
                init_img_with_mask,
                mask_blur,
                inpainting_fill,
                prompt_matrix,
                batch_count,
                batch_size,
                sd_upscale_upscaler_name,
                sd_upscale_overlap,
            ]
        )

        img2img_args = dict(
            fn=wrap_gradio_call(img2img),
            inputs=[
                prompt,
                init_img,
                init_img_with_mask,
                steps,
                sampler_index,
                mask_blur,
                inpainting_fill,
                use_GFPGAN,
                prompt_matrix,
                switch_mode,
                batch_count,
                batch_size,
                cfg_scale,
                denoising_strength,
                seed,
                height,
                width,
                resize_mode,
                sd_upscale_upscaler_name,
                sd_upscale_overlap,
            ],
            outputs=[
                gallery,
                output_seed,
                html_info
            ]
        )

        prompt.submit(**img2img_args)
        submit.click(**img2img_args)


def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
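    # plain RealESRGANer usage: build the selected model in half precision
    # and enhance at the requested output scale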
    info = realesrgan_models[RealESRGAN_model_index]

    model = info.model()
    upsampler = RealESRGANer(
        scale=info.netscale,
        model_path=info.location,
        model=model,
        half=True
    )

    upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]

    image = Image.fromarray(upsampled)
    return image


def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
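    # optional GFPGAN face restoration followed by optional Real-ESRGAN
    # upscaling; GFPGAN is skipped at strength 0, Real-ESRGAN at scale 1.0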
    torch_gc()

    image = image.convert("RGB")

    outpath = opts.outdir or "outputs/extras-samples"

    if have_gfpgan and GFPGAN_strength > 0:
        gfpgan_model = gfpgan()
        cropped_faces, restored_faces, restored_img = gfpgan_model.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
        res = Image.fromarray(restored_img)

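        # partial strength: blend the restored faces back into the original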
        if GFPGAN_strength < 1.0:
            res = Image.blend(image, res, GFPGAN_strength)

        image = res

    if have_realesrgan and RealESRGAN_upscaling != 1.0:
        image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)

    os.makedirs(outpath, exist_ok=True)
    base_count = len(os.listdir(outpath))
    save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True)

    return image, 0, ''


extras_interface = gr.Interface(
    wrap_gradio_call(run_extras),
    inputs=[
        gr.Image(label="Source", source="upload", interactive=True, type="pil"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan),
        gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan),
        gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan),
    ],
    outputs=[
        gr.Image(label="Result"),
        gr.Number(label='Seed', visible=False),
        gr.HTML(),
    ],
    allow_flagging="never",
    analytics_enabled=False,
)


def run_pnginfo(image):
    info = ''
    for key, text in image.info.items():
        info += f"""
<div>
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip()+"\n"

    if len(info) == 0:
        message = "Nothing found in the image."
        info = f"<div><p>{message}<p></div>"

    return [info]


pnginfo_interface = gr.Interface(
    wrap_gradio_call(run_pnginfo),
    inputs=[
        gr.Image(label="Source", source="upload", interactive=True, type="pil"),
    ],
    outputs=[
        gr.HTML(),
    ],
    allow_flagging="never",
    analytics_enabled=False,
)


opts = Options()
if os.path.exists(config_filename):
    opts.load(config_filename)


def run_settings(*args):
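    # gradio passes one positional argument per input component, in the same
    # order as opts.data_labels, so zip pairs each value with its key and widget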
    up = []

    for key, value, comp in zip(opts.data_labels.keys(), args, settings_interface.input_components):
        opts.data[key] = value
        up.append(comp.update(value=value))

    opts.save(config_filename)

    return 'Settings saved.', ''


def create_setting_component(key):
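    # build a gradio widget for one settings key; the widget type is inferred
    # from the option's default value unless an explicit component is given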
    def fun():
        return opts.data[key] if key in opts.data else opts.data_labels[key].default

    info = opts.data_labels[key]
    t = type(info.default)

    if info.component is not None:
        item = info.component(label=info.label, value=fun, **(info.component_args or {}))
    elif t == str:
        item = gr.Textbox(label=info.label, value=fun, lines=1)
    elif t == int:
        item = gr.Number(label=info.label, value=fun)
    elif t == bool:
        item = gr.Checkbox(label=info.label, value=fun)
    else:
        raise Exception(f'bad options item type: {str(t)} for key {key}')

    return item


settings_interface = gr.Interface(
    run_settings,
    inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
    outputs=[
        gr.Textbox(label='Result'),
        gr.HTML(),
    ],
    title=None,
    description=None,
    allow_flagging="never",
    analytics_enabled=False,
)

interfaces = [
    (txt2img_interface, "txt2img"),
    (img2img_interface, "img2img"),
    (extras_interface, "Extras"),
    (pnginfo_interface, "PNG Info"),
    (settings_interface, "Settings"),
]

try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.

    from transformers import logging

    logging.set_verbosity_error()
except Exception:
    pass

sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())

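# low VRAM mode keeps the model off the GPU here; setup_for_low_vram is
# assumed to move modules to the device on demand instead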
if not cmd_opts.lowvram:
    sd_model = sd_model.to(device)

else:
    setup_for_low_vram(sd_model)

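# judging by the --embeddings-dir option, the hijack is where textual
# inversion embeddings get hooked into prompt processing (assumption)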
model_hijack = StableDiffusionModelHijack()
model_hijack.hijack(sd_model)

with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
    css = file.read()

demo = gr.TabbedInterface(
    interface_list=[x[0] for x in interfaces],
    tab_names=[x[1] for x in interfaces],
    css=("" if cmd_opts.no_progressbar_hiding else css_hide_progressbar) + """
.output-html p {margin: 0 0.5em;}
.performance { font-size: 0.85em; color: #444; }
""" + css,
    analytics_enabled=False,
)

# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(signal, frame):
    print('Interrupted')
    os._exit(0)

signal.signal(signal.SIGINT, sigint_handler)

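# a single-slot queue: only one generation job runs at a time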
demo.queue(concurrency_count=1)
demo.launch()