import contextlib
import json
import math
import os
import sys

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure

import modules.sd_hijack
from modules import devices, prompt_parser, masking
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles


# some of those options should not be changed at all because they would break the model, so I removed them from options.
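# opt_C is the number of channels in the latent space; opt_f is the VAE downscaling
# factor (a 512x512 image corresponds to a 4x64x64 latent).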
opt_C = 4
opt_f = 8


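# remembers the original image in LAB space so that later outputs (e.g. img2img
# loopback iterations) can be histogram-matched back to it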
def setup_color_correction(image):
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, image):
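    # match the image's histogram to the stored LAB target, then convert back to RGB;
    # this counteracts the color drift that repeated img2img passes tend to introduce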
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    return image


class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        self.paste_to = None
        self.color_corrections = None

    def init(self, seed):
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        raise NotImplementedError()


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed, info, subseed=None):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps

    def js(self):
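        # serializes generation parameters to JSON for the UI; list-valued fields
        # (per-image prompts/seeds in a batch) are represented by their first element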
        obj = {
            "prompt": self.prompt if type(self.prompt) != list else self.prompt[0],
            "negative_prompt": self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0],
            "seed": int(self.seed if type(self.seed) != list else self.seed[0]),
            "subseed": int(self.subseed if type(self.subseed) != list else self.subseed[0]),
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
        }

        return json.dumps(obj)

# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
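    # spherical linear interpolation between two noise tensors; unlike a plain lerp,
    # the result still looks like gaussian noise of the same magnitude, e.g.
    # slerp(0.1, noise, subnoise) is 10% of the way from noise to subnoise along the arc.
    # note: if low and high are (anti-)parallel, sin(omega) is zero and this divides by
    # zero, but that is vanishingly unlikely for independent random tensors.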
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm*high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
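        # when seed resizing is requested, first generate noise at the resolution the
        # seed was originally used at (in latent units, hence //8); it is cropped or
        # padded into the requested shape further down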
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; GPU and CPU get different results for the same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets the same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it would break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            #noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
            noise = slerp(subseed_strength, noise, subnoise)

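        # the noise was generated at the seed-resize resolution; center-crop or
        # center-pad it into the requested latent shape to preserve composition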
        if noise_shape != shape:
            #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x


def fix_seed(p):
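    # a seed of -1 (or None) means "pick one at random"; the bound keeps seeds within
    # unsigned 32-bit range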
    p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed
    p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert len(p.prompt) > 0
    else:
        assert p.prompt is not None

    devices.torch_gc()

    fix_seed(p)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = {}

    shared.prompt_styles.apply_styles(p)

    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(p.seed) == list:
        all_seeds = p.seed
    else:
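        # with variation seeds active, every image in the batch keeps the same main seed
        # and differs only through its subseed; otherwise seeds increment per image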
        all_seeds = [int(p.seed + (x if p.subseed_strength == 0 else 0)) for x in range(len(all_prompts))]

    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
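        # builds the human-readable parameters block that is saved with images and shown
        # in the UI; entries whose value is None are omitted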
        index = position_in_batch + iteration * p.batch_size

        generation_params = {
            "Steps": p.steps,
            "Sampler": samplers[p.sampler_index].name,
            "CFG scale": p.cfg_scale,
            "Seed": all_seeds[index],
            "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
            "Size": f"{p.width}x{p.height}",
            "Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
            "Batch size": (None if p.batch_size < 2 else p.batch_size),
            "Batch pos": (None if p.batch_size < 2 else position_in_batch),
            "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
            "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
            "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
            "Denoising strength": getattr(p, 'denoising_strength', None),
        }

        if p.extra_generation_params is not None:
            generation_params.update(p.extra_generation_params)

        generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
        negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

        return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

    output_images = []
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(seed=all_seeds[0])

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if len(prompts) == 0:
                break

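            # prompt_parser schedules the conditioning over sampling steps (this is what
            # enables prompt-editing syntax), which is why it needs p.steps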
            #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            #c = p.sd_model.get_learned_conditioning(prompts)
            uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
            c = prompt_parser.get_learned_conditioning(prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
            if state.interrupted:

                # if we are interrupted, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent

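            # decode latents into image space; the model outputs [-1, 1], remap to [0, 1]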
            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.samples_save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    image = apply_color_correction(p.color_corrections[i], image)

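                # when inpainting at full resolution, paste_to is the region the crop was
                # taken from; scale the generated crop back into a full-size canvas first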
                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                output_images.append(image)

            state.nextjob()

        p.color_corrections = None

        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                output_images.insert(0, grid)

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0])


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def init(self, seed):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

    def sample(self, x, conditioning, unconditional_conditioning):
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim

class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, seed):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

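            # inpaint at full resolution: find the bounding box of the mask, expand it by
            # the configured padding, and work on that crop at the full target resolution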
            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

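        # a single init image is broadcast across the whole batch; several init images
        # shrink the batch to match their count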
        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

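        # numpy batch in [0, 1] -> tensor in [-1, 1], the range the VAE encoder expects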
        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
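            # downscale the mask to latent resolution, binarize it, and tile it across the
            # 4 latent channels so it can be applied to the latent directly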
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

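            # inpainting_fill modes: 0 = fill from surrounding colors, 1 = keep original,
            # 2 = latent noise, 3 = latent zeros ("latent nothing")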
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

    def sample(self, x, conditioning, unconditional_conditioning):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

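        # keep the original latent outside the mask; only masked areas receive new content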
        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples