webui.py 28.0 KB
Newer Older
A
first  
AUTOMATIC 已提交
1 2 3 4 5 6
import argparse, os, sys, glob
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
7
from PIL import Image, ImageFont, ImageDraw
A
first  
AUTOMATIC 已提交
8 9 10 11 12 13
from itertools import islice
from einops import rearrange, repeat
from torch import autocast
from contextlib import contextmanager, nullcontext
import mimetypes
import random
14
import math
A
first  
AUTOMATIC 已提交
15 16 17 18 19 20

import k_diffusion as K
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

21 22 23 24 25 26 27 28
try:
    # This silences the annoying "Some weights of the model checkpoint were not
    # used when initializing..." message that is printed at start.
    from transformers import logging
    logging.set_verbosity_error()
except Exception:
    # Best effort only: if transformers cannot be imported or configured we
    # simply keep its default verbosity. Catching Exception (instead of a bare
    # except) lets KeyboardInterrupt and SystemExit propagate normally.
    pass

A
first  
AUTOMATIC 已提交
29 30 31 32 33 34 35 36
# This is a fix for Windows users. Without it, javascript files will be served
# with a text/html content-type and the browser will not show any UI.
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

# Some of those options should not be changed at all because they would break
# the model, so they were removed from the command-line options.
# process_images uses them to build the noise shape
# [opt_C, height // opt_f, width // opt_f].
opt_C = 4
opt_f = 8

37
# Use Image.Resampling.LANCZOS when this Pillow build exposes the Resampling
# namespace, otherwise fall back to the Image.LANCZOS constant
# (presumably for compatibility across Pillow versions).
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
38
# Characters removed from prompts when they are turned into sample filenames.
# The backslash is escaped explicitly: a bare "\|" is an invalid escape sequence
# that only works by accident and triggers a warning on current Python versions.
invalid_filename_chars = '<>:"/\\|?*\n'
A
AUTOMATIC 已提交
39

A
first  
AUTOMATIC 已提交
40 41 42 43
# Command-line options for the web UI. Parsed once at import time into `opt`.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--outdir",
    type=str,
    nargs="?",
    default=None,
    help="dir to write results to",
)
parser.add_argument(
    "--skip_grid",
    action='store_true',
    help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
    "--skip_save",
    action='store_true',
    help="do not save indiviual samples. For speed measurements.",
)
parser.add_argument(
    "--n_rows",
    type=int,
    default=-1,
    help="rows in the grid; use -1 for autodetect and 0 for n_rows to be same as batch_size (default: -1)",
)
parser.add_argument(
    "--config",
    type=str,
    default="configs/stable-diffusion/v1-inference.yaml",
    help="path to config which constructs model",
)
parser.add_argument(
    "--ckpt",
    type=str,
    default="models/ldm/stable-diffusion-v1/model.ckpt",
    help="path to checkpoint of model",
)
parser.add_argument(
    "--precision",
    type=str,
    choices=["full", "autocast"],
    default="autocast",
    help="evaluate at this precision",
)
# The default GFPGAN location follows the layout used by the common setup
# guides: ./src/gfpgan when it exists, otherwise ./GFPGAN.
parser.add_argument(
    "--gfpgan-dir",
    type=str,
    default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'),
    help="GFPGAN directory",
)
parser.add_argument(
    "--no-verify-input",
    action='store_true',
    help="do not verify input to check if it's too long",
)
parser.add_argument(
    "--no-half",
    action='store_true',
    help="do not switch the model to 16-bit floats",
)
parser.add_argument(
    "--no-progressbar-hiding",
    action='store_true',
    help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)",
)
parser.add_argument(
    "--max-batch-count",
    type=int,
    default=16,
    help="maximum batch count value for the UI",
)
parser.add_argument(
    "--grid-format",
    type=str,
    default='png',
    help="file format for saved grids; can be png or jpg",
)
opt = parser.parse_args()

# Directory that is searched for the GFPGAN package and its pretrained weights.
GFPGAN_dir = opt.gfpgan_dir

58 59 60 61 62 63
# CSS handed to the gradio TabbedInterface unless --no-progressbar-hiding is
# given: hides the progress bar elements and shows a "Loading..." text instead.
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
A
first  
AUTOMATIC 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102

def chunk(it, size):
    """Lazily split an iterable into consecutive tuples of at most `size` items.

    The final tuple may be shorter than `size`; iteration ends as soon as the
    source is exhausted.
    """
    source = iter(it)

    def next_group():
        return tuple(islice(source, size))

    # Two-argument iter(): keep calling next_group until it yields the empty
    # tuple sentinel.
    return iter(next_group, ())


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by `config.model` and load weights from `ckpt`.

    Args:
        config: configuration object whose `model` attribute is passed to
            instantiate_from_config.
        ckpt: path of the checkpoint file; it must contain a "state_dict" entry.
        verbose: when true, print the missing and unexpected state-dict keys.

    Returns:
        The model in eval mode. It is moved to CUDA only when CUDA is available,
        so CPU-only hosts no longer fail here (the module selects a CPU fallback
        device for the same situation).
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: tolerate key mismatches and optionally report them below.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


class CFGDenoiser(nn.Module):
    """Wraps a denoiser so one batched call yields a guided prediction.

    The input is duplicated, run once with the unconditional and once with the
    conditional conditioning, and the two halves are combined as
    `uncond + (cond - uncond) * cond_scale`.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        # Stack both passes into one batch: first half unconditional,
        # second half conditional.
        batched_x = torch.cat([x, x])
        batched_sigma = torch.cat([sigma, sigma])
        batched_cond = torch.cat([uncond, cond])

        denoised = self.inner_model(batched_x, batched_sigma, cond=batched_cond)
        out_uncond, out_cond = denoised.chunk(2)
        return out_uncond + (out_cond - out_uncond) * cond_scale


A
AUTOMATIC 已提交
103 104 105 106 107 108 109 110 111
class KDiffusionSampler:
    """Adapter exposing k-diffusion's LMS sampler through the `sample(...)`
    call signature that txt2img uses for its samplers."""

    def __init__(self, m):
        self.model = m
        self.model_wrap = K.external.CompVisDenoiser(m)

    def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
        # batch_size, shape, verbose and eta are accepted for signature
        # compatibility only; they are not used here.
        sigmas = self.model_wrap.get_sigmas(S)
        guided_denoiser = CFGDenoiser(self.model_wrap)
        guidance_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
        }

        # Scale the initial noise by the first sigma of the schedule.
        start = x_T * sigmas[0]
        samples = K.sampling.sample_lms(guided_denoiser, start, sigmas, extra_args=guidance_args, disable=False)

        # Second element is always None.
        return samples, None


A
AUTOMATIC 已提交
118
def create_random_tensors(shape, seeds):
    """Return a stacked tensor holding one `randn(shape)` sample per seed.

    The global torch seed is reset before every draw so each entry depends only
    on its own seed.
    """
    noise_per_seed = []
    for current_seed in seeds:
        torch.manual_seed(current_seed)

        # randn results depend on device; gpu and cpu get different results for
        # the same seed. Generating on CPU would give everyone the same result,
        # but the original script generated on `device`, and changing that
        # would break everyone's existing seeds.
        noise_per_seed.append(torch.randn(shape, device=device))

    return torch.stack(noise_per_seed)


A
first  
AUTOMATIC 已提交
132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
def load_GFPGAN():
    """Build a GFPGANer from the v1.3 weights found under GFPGAN_dir.

    Raises:
        Exception: when the weights file does not exist.
    """
    weights_name = 'GFPGANv1.3'
    weights_path = os.path.join(GFPGAN_dir, 'experiments/pretrained_models', weights_name + '.pth')

    if os.path.isfile(weights_path):
        # Make the gfpgan package inside GFPGAN_dir importable.
        sys.path.append(os.path.abspath(GFPGAN_dir))
        from gfpgan import GFPGANer

        return GFPGANer(model_path=weights_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)

    raise Exception("GFPGAN model not found at path "+weights_path)


# Optional face restoration: GFPGAN stays None when its directory is missing or
# loading fails; the rest of the script checks `GFPGAN is not None`.
GFPGAN = None
if os.path.exists(GFPGAN_dir):
    try:
        GFPGAN = load_GFPGAN()
        print("Loaded GFPGAN")
    except Exception:
        import traceback
        error_details = traceback.format_exc()
        print("Error loading GFPGAN:", file=sys.stderr)
        print(error_details, file=sys.stderr)

# Load the model from the paths given on the command line. The --config and
# --ckpt defaults equal the previously hard-coded paths, so behaviour is
# unchanged unless the user overrides them (overrides used to be ignored).
config = OmegaConf.load(opt.config)
model = load_model_from_config(config, opt.ckpt)

# Run on CUDA when available, otherwise fall back to CPU. The model is switched
# to 16-bit floats unless --no-half was given.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if opt.no_half else model.half()).to(device)
A
first  
AUTOMATIC 已提交
159 160


A
AUTOMATIC 已提交
161 162 163 164
def image_grid(imgs, batch_size, round_down=False, force_n_rows=None):
    """Paste `imgs` row by row onto a single black grid image.

    Row count, in priority order: `force_n_rows` if given, --n_rows when
    positive, `batch_size` when --n_rows is 0, otherwise the square root of the
    image count (truncated when `round_down`, else rounded). Every cell has the
    size of the first image.
    """
    count = len(imgs)

    if force_n_rows is not None:
        rows = force_n_rows
    elif opt.n_rows > 0:
        rows = opt.n_rows
    elif opt.n_rows == 0:
        rows = batch_size
    else:
        side = math.sqrt(count)
        rows = int(side) if round_down else round(side)

    cols = math.ceil(count / rows)

    cell_w, cell_h = imgs[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h), color='black')

    for index, img in enumerate(imgs):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))

    return grid

182

183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
def draw_prompt_matrix(im, width, height, all_prompts):
    """Return a copy of grid image `im` with prompt-matrix captions around it.

    `all_prompts[0]` is skipped; the remaining prompt parts are split into a
    first half drawn above the columns and a second half drawn left of the
    rows. For each column/row index, bit i decides whether part i is drawn as
    active (black) or inactive (grey, struck through).
    `width` and `height` are the size of one grid cell.
    """
    def wrap(text, d, font, line_length):
        # Greedy word wrap: extend the last line while it still fits.
        wrapped = ['']
        for word in text.split():
            candidate = (wrapped[-1] + ' ' + word).strip()
            if d.textlength(candidate, font=font) > line_length:
                wrapped.append(word)
            else:
                wrapped[-1] = candidate
        return '\n'.join(wrapped)

    def measure(texts):
        # (width, height) of every multi-line text, from its bounding box.
        boxes = [d.multiline_textbbox((0, 0), t, font=fnt) for t in texts]
        return [(right - left, bottom - top) for left, top, right, bottom in boxes]

    def draw_texts(pos, x, y, texts, sizes):
        for bit, (text, size) in enumerate(zip(texts, sizes)):
            if pos & (1 << bit):
                fill = color_active
            else:
                # Combining long-stroke overlay after every character gives a
                # strike-through look for parts not used in this cell.
                text = '\u0336'.join(text) + '\u0336'
                fill = color_inactive

            d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=fill, anchor="mm", align="center")

            y += size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2
    fnt = ImageFont.truetype("arial.ttf", fontsize)
    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    # Margins that hold the captions; the left margin is only needed when there
    # are more than two prompt parts.
    pad_top = height // 4
    pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0

    cols = im.width // width
    rows = im.height // height

    prompts = all_prompts[1:]

    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))

    d = ImageDraw.Draw(result)

    boundary = math.ceil(len(prompts) / 2)
    prompts_horiz = [wrap(part, d, fnt, width) for part in prompts[:boundary]]
    prompts_vert = [wrap(part, d, fnt, pad_left) for part in prompts[boundary:]]

    sizes_hor = measure(prompts_horiz)
    sizes_ver = measure(prompts_vert)
    hor_text_height = sum([size[1] + line_spacing for size in sizes_hor]) - line_spacing
    ver_text_height = sum([size[1] + line_spacing for size in sizes_ver]) - line_spacing

    for col in range(cols):
        center_x = pad_left + width * col + width / 2
        top_y = pad_top / 2 - hor_text_height / 2
        draw_texts(col, center_x, top_y, prompts_horiz, sizes_hor)

    for row in range(rows):
        center_x = pad_left / 2
        top_y = pad_top + height * row + height / 2 - ver_text_height / 2
        draw_texts(row, center_x, top_y, prompts_vert, sizes_ver)

    return result


A
AUTOMATIC 已提交
248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
def resize_image(resize_mode, im, width, height):
    """Resize `im` to exactly `width` x `height` pixels.

    Args:
        resize_mode: 0 stretches the image to the target size; 1 scales it to
            cover the target and centres it (overflow is cut off); any other
            value scales it to fit inside the target, centres it and fills the
            remaining bands by stretching the image's edge rows/columns.
        im: source PIL image.
        width: target width in pixels.
        height: target height in pixels.

    Returns:
        A new image of the requested size.
    """
    if resize_mode == 0:
        return im.resize((width, height), resample=LANCZOS)

    ratio = width / height
    src_ratio = im.width / im.height

    if resize_mode == 1:
        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = im.resize((src_w, src_h), resample=LANCZOS)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
        return res

    src_w = width if ratio < src_ratio else im.width * height // im.height
    src_h = height if ratio >= src_ratio else im.height * width // im.width

    resized = im.resize((src_w, src_h), resample=LANCZOS)
    res = Image.new("RGB", (width, height))
    res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    if ratio < src_ratio:
        fill_height = height // 2 - src_h // 2
        # Integer rounding can leave no band to fill (e.g. odd target sizes);
        # resizing to a zero-sized image is rejected, so skip it.
        if fill_height > 0:
            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
    elif ratio > src_ratio:
        fill_width = width // 2 - src_w // 2
        if fill_width > 0:
            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res


284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302
def check_prompt_length(prompt, comments):
    """Append a warning to `comments` when `prompt` exceeds the tokenizer limit.

    The prompt is tokenized with truncation at the conditioning model's
    max_length; if any tokens overflow, their text is reported. Nothing is
    appended when the prompt fits.
    """
    cond_model = model.cond_stage_model
    tokenizer = cond_model.tokenizer
    max_length = cond_model.max_length

    info = tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
    ovf = info['overflowing_tokens'][0]
    if ovf.shape[0] == 0:
        return

    # Invert the vocabulary so token ids can be mapped back to token strings.
    id_to_token = {token_id: token for token, token_id in tokenizer.get_vocab().items()}
    overflowing_words = [id_to_token.get(int(token_id), "") for token_id in ovf]
    overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))

    comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")


A
AUTOMATIC 已提交
303
def _expand_prompt_matrix(prompt):
    """Split `prompt` on "|" and return (parts, every combination of the optional parts).

    The first part is always present; combination number k includes optional
    part n when bit n of k is set.
    """
    parts = prompt.split("|")
    base, optional = parts[0], parts[1:]

    combinations = []
    for mask in range(2 ** len(optional)):
        current = base
        for bit, text in enumerate(optional):
            if mask & (1 << bit):
                current += ("" if text.strip().startswith(",") else ", ") + text
        combinations.append(current)

    return parts, combinations


def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
    """Main generation loop shared by txt2img and img2img.

    Calls `func_init()` once inside the no_grad / precision / EMA scopes and
    `func_sample(init_data=..., x=..., conditioning=..., unconditional_conditioning=...)`
    once per batch, decodes the results, optionally runs GFPGAN on them, saves
    the samples and (unless disabled) a grid.

    Args:
        outpath: output directory; samples go to its "samples" sub-directory.
        func_init: zero-argument callable; its result is passed to func_sample.
        func_sample: callable producing the latent samples for one batch.
        prompt: text prompt; with `prompt_matrix` it is split on "|".
        seed: seed of the first image; -1 picks a random one.
        sampler_name: only used in the returned info text.
        batch_size: number of images per batch.
        n_iter: number of batches (recomputed in prompt-matrix mode).
        steps: only used in the returned info text.
        cfg_scale: guidance scale; 1.0 skips the unconditional conditioning.
        width, height: image size in pixels.
        prompt_matrix: generate every combination of the "|"-separated parts.
        use_GFPGAN: run face restoration when GFPGAN is loaded.
        do_not_save_grid: suppress the grid even if options would create one.

    Returns:
        (output_images, seed, info) - the generated PIL images (preceded by the
        grid in prompt-matrix mode), the seed actually used and a parameter
        summary text.
    """
    assert prompt is not None
    torch.cuda.empty_cache()

    if seed == -1:
        seed = random.randrange(4294967294)
    seed = int(seed)

    os.makedirs(outpath, exist_ok=True)

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    # -1 discounts the "samples" directory that lives next to the grids.
    grid_count = len(os.listdir(outpath)) - 1

    comments = []

    prompt_matrix_parts = []
    if prompt_matrix:
        prompt_matrix_parts, all_prompts = _expand_prompt_matrix(prompt)

        n_iter = math.ceil(len(all_prompts) / batch_size)
        # Every combination uses the same seed so only the prompt differs.
        all_seeds = len(all_prompts) * [seed]

        print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
    else:
        if not opt.no_verify_input:
            try:
                check_prompt_length(prompt, comments)
            except Exception:
                # Verification is best effort; report and keep generating.
                import traceback
                print("Error verifying input:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        all_prompts = batch_size * n_iter * [prompt]
        all_seeds = [seed + x for x in range(len(all_prompts))]

    # Characters that must not end up in sample filenames.
    filename_cleanup = {ord(x): '' for x in invalid_filename_chars}

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    output_images = []
    with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
        init_data = func_init()

        for n in range(n_iter):
            prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
            seeds = all_seeds[n * batch_size:(n + 1) * batch_size]

            uc = None
            if cfg_scale != 1.0:
                uc = model.get_learned_conditioning(len(prompts) * [""])
            if isinstance(prompts, tuple):
                prompts = list(prompts)
            c = model.get_learned_conditioning(prompts)

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)

            samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Map the decoder output from [-1, 1] to [0, 1].
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            if prompt_matrix or not opt.skip_save or not opt.skip_grid:
                for i, x_sample in enumerate(x_samples_ddim):
                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    x_sample = x_sample.astype(np.uint8)

                    if use_GFPGAN and GFPGAN is not None:
                        cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
                        x_sample = restored_img

                    image = Image.fromarray(x_sample)

                    # Honour --skip_save: the image may only be needed for the
                    # grid, in which case it must not be written to disk.
                    if not opt.skip_save:
                        filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate(filename_cleanup)[:128]}.png"
                        image.save(os.path.join(sample_path, filename))

                    output_images.append(image)
                    base_count += 1

        if (prompt_matrix or not opt.skip_grid) and not do_not_save_grid:
            grid = image_grid(output_images, batch_size, round_down=prompt_matrix)

            if prompt_matrix:
                try:
                    grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
                except Exception:
                    # Captions are optional; fall back to the plain grid.
                    import traceback
                    print("Error creating prompt_matrix text:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

                output_images.insert(0, grid)

            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.{opt.grid_format}'))

    info = f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
        """.strip()

    for comment in comments:
        info += "\n\n" + comment

    return output_images, seed, info

A
AUTOMATIC 已提交
421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462

def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
    """Generate images from a text prompt with the chosen sampler.

    Returns the (output_images, seed, info) triple produced by process_images.
    Raises Exception for an unknown `sampler_name`.
    """
    outpath = opt.outdir or "outputs/txt2img-samples"

    sampler_classes = {
        'PLMS': PLMSSampler,
        'DDIM': DDIMSampler,
        'k-diffusion': KDiffusionSampler,
    }
    if sampler_name not in sampler_classes:
        raise Exception("Unknown sampler: " + sampler_name)
    sampler = sampler_classes[sampler_name](model)

    def init():
        # txt2img has no per-run initial data.
        return None

    def sample(init_data, x, conditioning, unconditional_conditioning):
        samples_ddim, _ = sampler.sample(
            S=ddim_steps,
            conditioning=conditioning,
            batch_size=int(x.shape[0]),
            shape=x[0].shape,
            verbose=False,
            unconditional_guidance_scale=cfg_scale,
            unconditional_conditioning=unconditional_conditioning,
            eta=ddim_eta,
            x_T=x,
        )
        return samples_ddim

    result = process_images(
        outpath=outpath,
        func_init=init,
        func_sample=sample,
        prompt=prompt,
        seed=seed,
        sampler_name=sampler_name,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=ddim_steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN
    )
    output_images, seed, info = result

    del sampler

    return output_images, seed, info


A
AUTOMATIC 已提交
463 464 465 466 467
class Flagging(gr.FlaggingCallback):
    """Flagging callback that stores flagged txt2img results on disk.

    Images are written to log/images and one row per flag is appended to
    log/log.csv.
    """

    def setup(self, components, flagging_dir: str):
        # Nothing to prepare; directories are created on demand in flag().
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        """Save the flagged images and append their parameters to the CSV log.

        `flag_data` holds the txt2img inputs followed by its three outputs.
        Flagging with no images no longer raises: the row is logged with an
        empty filename.
        """
        import base64
        import csv
        import time

        os.makedirs("log/images", exist_ok=True)

        # those must match the "txt2img" function
        prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data

        # The gallery may be empty or None when nothing has been generated yet.
        images = images or []
        filenames = []

        with open("log/log.csv", "a", encoding="utf8", newline='') as file:
            # Write the header only when the log file is new/empty.
            at_start = file.tell() == 0
            writer = csv.writer(file)
            if at_start:
                writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])

            filename_base = str(int(time.time() * 1000))
            for i, filedata in enumerate(images):
                filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"

                # Strip the data-URL prefix, if present, before decoding.
                if filedata.startswith("data:image/png;base64,"):
                    filedata = filedata[len("data:image/png;base64,"):]

                with open(filename, "wb") as imgfile:
                    imgfile.write(base64.decodebytes(filedata.encode('utf-8')))

                filenames.append(filename)

            first_filename = filenames[0] if filenames else ""
            writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, first_filename])

        print("Logged:", first_filename)

A
first  
AUTOMATIC 已提交
503

A
AUTOMATIC 已提交
504 505
# Gradio tab for text-to-image. The inputs are passed positionally to txt2img,
# so their order must match its parameter list (and Flagging.flag).
txt2img_interface = gr.Interface(
    fn=txt2img,
    inputs=[
        # prompt
        gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
        # ddim_steps
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        # sampler_name
        gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
        # use_GFPGAN (hidden when GFPGAN could not be loaded)
        gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
        # prompt_matrix
        gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
        # ddim_eta (hidden)
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
        # n_iter
        gr.Slider(minimum=1, maximum=opt.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
        # batch_size
        gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
        # cfg_scale
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
        # seed
        gr.Number(label='Seed', value=-1),
        # height
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
        # width
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
    ],
    outputs=[
        gr.Gallery(label="Images"),
        gr.Number(label='Seed'),
        gr.Textbox(label="Copy-paste generation parameters"),
    ],
    title="Stable Diffusion Text-to-Image K",
    description="Generate images from text with Stable Diffusion (using K-LMS)",
    flagging_callback=Flagging(),
)


A
AUTOMATIC 已提交
531
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
    """Generate images from an initial image and a prompt using k-diffusion LMS.

    With `loopback`, `n_iter` single-image runs are chained: each output becomes
    the next input, the seed is incremented and the denoising strength is
    multiplied by 0.95 (never below 0.1); the chain is saved as a one-row grid
    and returned as the image list.

    Returns:
        (output_images, seed, info) as produced by process_images; in loopback
        mode the seed of the first run is returned.
    """
    outpath = opt.outdir or "outputs/img2img-samples"

    sampler = KDiffusionSampler(model)

    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    def init():
        # Reads init_img at call time, so loopback iterations see the latest image.
        image = init_img.convert("RGB")
        image = resize_image(resize_mode, image, width, height)
        image = np.array(image).astype(np.float32) / 255.0
        # HWC -> 1xCxHxW
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)

        # Rescale from [0, 1] to [-1, 1] before encoding.
        init_image = 2. * image - 1.
        init_image = init_image.to(device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

        return init_latent,

    def sample(init_data, x, conditioning, unconditional_conditioning):
        # Cap the strength just below 1.0: at exactly 1.0 t_enc would equal
        # ddim_steps, the start index below would become -1 and select the final
        # sigma, so no noise would be added and no sampling steps would run.
        t_enc = int(min(denoising_strength, 0.999) * ddim_steps)

        x0, = init_data

        sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
        start = ddim_steps - t_enc - 1
        noise = x * sigmas[start]

        xi = x0 + noise
        sigma_sched = sigmas[start:]
        model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
        samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
        return samples_ddim

    # Arguments shared by every process_images call below.
    common_args = dict(
        outpath=outpath,
        func_init=init,
        func_sample=sample,
        prompt=prompt,
        sampler_name='k-diffusion',
        steps=ddim_steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN,
    )

    if loopback:
        output_images, info = None, None
        history = []
        initial_seed = None

        for i in range(n_iter):
            output_images, seed, info = process_images(
                seed=seed,
                batch_size=1,
                n_iter=1,
                do_not_save_grid=True,
                **common_args
            )

            if initial_seed is None:
                initial_seed = seed

            # Feed the result back in; init() and sample() pick up these
            # rebinding through their closures.
            init_img = output_images[0]
            seed = seed + 1
            denoising_strength = max(denoising_strength * 0.95, 0.1)
            history.append(init_img)

        grid_count = len(os.listdir(outpath)) - 1
        grid = image_grid(history, batch_size, force_n_rows=1)
        grid.save(os.path.join(outpath, f'grid-{grid_count:04}.{opt.grid_format}'))

        output_images = history
        seed = initial_seed

    else:
        output_images, seed, info = process_images(
            seed=seed,
            batch_size=batch_size,
            n_iter=n_iter,
            **common_args
        )

    del sampler

    return output_images, seed, info
A
first  
AUTOMATIC 已提交
626 627


628 629 630
# Example input image pre-loaded in the img2img tab; None when the asset file
# is not present.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
if not os.path.exists(sample_img2img):
    sample_img2img = None

A
first  
AUTOMATIC 已提交
631
# Gradio tab for image-to-image. The inputs are passed positionally to img2img,
# so their order must match its parameter list.
img2img_interface = gr.Interface(
    fn=img2img,
    inputs=[
        # prompt
        gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
        # init_img
        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
        # ddim_steps
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        # use_GFPGAN (hidden when GFPGAN could not be loaded)
        gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
        # prompt_matrix
        gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
        # loopback
        gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
        # n_iter
        gr.Slider(minimum=1, maximum=opt.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
        # batch_size
        gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
        # cfg_scale
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
        # denoising_strength
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
        # seed
        gr.Number(label='Seed', value=-1),
        # height
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
        # width
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
        # resize_mode (type="index": the callback receives the choice's position)
        gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize"),
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed'),
        gr.Textbox(label="Copy-paste generation parameters"),
    ],
    title="Stable Diffusion Image-to-Image",
    description="Generate images from images with Stable Diffusion",
    allow_flagging="never",
)

659
# (interface, tab name) pairs shown in the tabbed UI, in display order.
interfaces = []
interfaces.append((txt2img_interface, "txt2img"))
interfaces.append((img2img_interface, "img2img"))

def run_GFPGAN(image, strength):
    """Run GFPGAN on `image` and blend the result with the original.

    `strength` below 1.0 is used as the blend factor between the RGB source and
    the restored image; otherwise the restored image is returned as is.
    """
    source = image.convert("RGB")
    pixels = np.array(source, dtype=np.uint8)

    _cropped, _restored_faces, restored_img = GFPGAN.enhance(pixels, has_aligned=False, only_center_face=False, paste_back=True)
    restored = Image.fromarray(restored_img)

    if strength < 1.0:
        return Image.blend(source, restored, strength)

    return restored


# Extra tab offering stand-alone face restoration; only added when GFPGAN was
# loaded successfully.
if GFPGAN is not None:
    interfaces.append((gr.Interface(
        run_GFPGAN,
        inputs=[
            gr.Image(label="Source", source="upload", interactive=True, type="pil"),
            # Default is full strength. The previous default of 100 lay outside
            # the slider's 0.0-1.0 range; run_GFPGAN treats any value >= 1.0 as
            # "no blending", so 1.0 keeps the same effect.
            gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Effect strength", value=1.0),
        ],
        outputs=[
            gr.Image(label="Result"),
        ],
        title="GFPGAN",
        description="Fix faces on images",
        allow_flagging="never",
    ), "GFPGAN"))

691 692 693 694 695
# Assemble all registered tabs into one UI and start the gradio server.
demo = gr.TabbedInterface(
    interface_list=[interface for interface, _ in interfaces],
    tab_names=[tab_name for _, tab_name in interfaces],
    css=(css_hide_progressbar if not opt.no_progressbar_hiding else ""),
)

demo.launch()