import contextlib
import json
import math
import os
import sys

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random

import modules.sd_hijack
from modules import devices
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4  # number of channels in the latent space
opt_f = 8  # VAE downsampling factor: latents are width//8 x height//8



class StableDiffusionProcessing:
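    """Base class holding the parameters for a single generation job.

    Subclasses implement init() and sample(); see StableDiffusionProcessingTxt2Img
    and StableDiffusionProcessingImg2Img below.
    """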
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.prompt_style: str = prompt_style
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        self.paste_to = None

    def init(self, seed):
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        raise NotImplementedError()


class Processed:
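    """The result of process_images: generated images plus the parameters used to produce them."""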
    def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
        self.images = images_list
        self.prompt = p.prompt
        self.seed = seed
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps

    def js(self):
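        """Return the key generation parameters as a JSON string."""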
        obj = {
            "prompt": self.prompt if type(self.prompt) != list else self.prompt[0],
            "seed": int(self.seed if type(self.seed) != list else self.seed[0]),
            "width": self.width,
            "height": self.height,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
        }

        return json.dumps(obj)

# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
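    """Spherical linear interpolation between two noise tensors.

    The angle between `low` and `high` is computed from copies normalized along
    dim=1; the original tensors are then mixed along that arc. Unlike a plain
    lerp, this keeps the result's magnitude close to that of Gaussian noise,
    which is why it is used below to blend seed and subseed noise.
    """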
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm*high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
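    """Create one latent-space noise tensor per seed and stack them into a batch.

    If subseeds are given, each seed's noise is slerped towards the subseed's noise
    by subseed_strength. If seed_resize_from_w/h are set, noise is generated at that
    original resolution and its center is pasted into noise of the requested shape,
    so a seed keeps most of its composition when the image size changes.
    """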
    xs = []
    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]
            torch.manual_seed(subseed)
            subnoise = torch.randn(noise_shape, device=shared.device)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        torch.manual_seed(seed)
        noise = torch.randn(noise_shape, device=shared.device)

        if subnoise is not None:
            #noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
            # rather than interpolating, paste the overlapping center of the seed-resized noise
            # into fresh noise of the target shape, e.g. noise_shape = (64, 80) into shape = (64, 72)

            torch.manual_seed(seed)
            x = torch.randn(shape, device=shared.device)
            dx = (shape[2] - noise_shape[2]) // 2  # -4 in the example above
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        xs.append(noise)
    x = torch.stack(xs).to(shared.device)
    return x


def fix_seed(p):
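    # -1 (or None) means "pick a random seed"; the upper bound keeps seeds inside the 32-bit range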
    p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed
    p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    assert p.prompt is not None
    devices.torch_gc()

    fix_seed(p)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = []

    modules.styles.apply_style(p, shared.prompt_styles[p.prompt_style])

    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(p.seed) == list:
        all_seeds = p.seed
    else:
        all_seeds = [int(p.seed + x) for x in range(len(all_prompts))]

    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
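        """Build the generation-parameters text that is saved alongside each image."""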
        index = position_in_batch + iteration * p.batch_size

        generation_params = {
            "Steps": p.steps,
            "Sampler": samplers[p.sampler_index].name,
            "CFG scale": p.cfg_scale,
            "Seed": all_seeds[index],
            "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
            "Size": f"{p.width}x{p.height}",
            "Batch size": (None if p.batch_size < 2 else p.batch_size),
            "Batch pos": (None if p.batch_size < 2 else position_in_batch),
            "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
            "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
            "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        }

        if p.extra_generation_params is not None:
            generation_params.update(p.extra_generation_params)

        generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
        
        negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

        return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

    output_images = []
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
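    # no_grad: inference only; autocast: optional mixed precision; ema_scope: temporarily swap in EMA weights (skipped in lowvram mode)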
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(seed=all_seeds[0])

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            c = p.sd_model.get_learned_conditioning(prompts)

            if len(model_hijack.comments) > 0:
                comments += model_hijack.comments

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
            if state.interrupted:

                # if we are interrupted, sample returns just noise
                # use the image collected previously in the sampler loop
                samples_ddim = shared.state.current_latent

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                output_images.append(image)

            state.nextjob()

        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            return_grid = opts.return_grid

            grid = images.image_grid(output_images, p.batch_size)

            if return_grid:
                output_images.insert(0, grid)

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext())


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
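    """txt2img: sampling starts from pure noise."""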
    sampler = None

    def init(self, seed):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

    def sample(self, x, conditioning, unconditional_conditioning):
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim


def get_crop_region(mask, pad=0):
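    """Return an (x1, y1, x2, y2) rectangle containing every non-zero pixel of the mask, expanded by pad and clamped to the mask bounds."""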
    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )


def fill(image, mask):
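    """Fill the masked region of the image with colors bled in from the unmasked region.

    Gaussian-blurred copies of the unmasked content are composited at decreasing
    blur radii, leaving the masked area with plausible low-frequency color for
    inpainting to start from.
    """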
    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    image_masked = image_masked.convert('RGBa')

    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
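    """img2img and inpainting: sampling starts from an encoded init image, optionally confined to a masked region."""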
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, seed):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

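            # inpainting_fill modes: 0 = fill with blurred image colors, 1 = keep original content,
            # 2 = latent noise, 3 = latent zeros; everything except "original" is pre-filled here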
            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = fill(image, latent_mask)

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

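        # a single init image is repeated across the whole batch; several images shrink the batch to match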
        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
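            # downscale the mask to latent resolution and repeat it over the 4 latent channels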
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

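            # mode 2 replaces the masked part of the init latent with fresh noise; mode 3 zeroes it out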
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

    def sample(self, x, conditioning, unconditional_conditioning):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

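        # restore the original latent outside the inpainted region; keep the newly sampled latent inside it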
        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples