import contextlib
import json
import math
import os
import sys

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random

import modules.sd_hijack
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles

# some of those options should not be changed at all because they would break the model, so I removed them from options.
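# opt_C is the number of latent channels and opt_f is the VAE downscale factor,
# so a 512x512 image maps to a 4x64x64 latent.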
opt_C = 4
opt_f = 8


def torch_gc():
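    # release cached CUDA allocations (and collect IPC handles) so freed VRAM becomes usable again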
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style=0, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.prompt_style: int = prompt_style
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        self.paste_to = None

    def init(self, seed):
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        raise NotImplementedError()


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
        self.images = images_list
        self.prompt = p.prompt
        self.seed = seed
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps

    def js(self):
        obj = {
            "prompt": self.prompt if type(self.prompt) != list else self.prompt[0],
            "seed": int(self.seed if type(self.seed) != list else self.seed[0]),
            "width": self.width,
            "height": self.height,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
        }

        return json.dumps(obj)

# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm*high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
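
# Illustrative usage (hypothetical values): this is how create_random_tensors
# below blends subseed noise into seed noise.
#   a = torch.randn(4, 64, 64)
#   b = torch.randn(4, 64, 64)
#   mixed = slerp(0.3, a, b)  # 30% of the way from a to b along the hypersphere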


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
    xs = []
    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]
            torch.manual_seed(subseed)
            subnoise = torch.randn(noise_shape, device=shared.device)

        # randn results depend on the device; GPU and CPU give different results for the same seed;
        # the way I see it, it's better to do this on the CPU, so that everyone gets the same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        torch.manual_seed(seed)
        noise = torch.randn(noise_shape, device=shared.device)

        if subnoise is not None:
            #noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
            # noise_shape = (64, 80)
            # shape = (64, 72)
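            # fresh noise is generated at the target shape from the same seed, and the
            # saved-resolution noise is center-cropped (or padded) into it, roughly
            # preserving the composition of the original resolution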

            torch.manual_seed(seed)
            x = torch.randn(shape, device=shared.device)
            dx = (shape[2] - noise_shape[2]) // 2 # -4
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        xs.append(noise)
    x = torch.stack(xs).to(shared.device)
    return x
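
# Illustrative call (hypothetical seeds): two 512x512 latents, with variation
# seeds blended in at 40% strength:
#   x = create_random_tensors([opt_C, 64, 64], seeds=[1000, 1001],
#                             subseeds=[2000, 2001], subseed_strength=0.4)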


def fix_seed(p):
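    # a seed of -1 (or None) means "randomize"; 4294967294 == 2**32 - 2 keeps seeds within 32 bits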
    p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed
    p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    assert p.prompt is not None
    torch_gc()

    fix_seed(p)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = []

    modules.styles.apply_style(p, shared.prompt_styles[p.prompt_style])

    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]
A
AUTOMATIC 已提交

    if type(p.seed) == list:
        all_seeds = p.seed
    else:
        all_seeds = [int(p.seed + x) for x in range(len(all_prompts))]

    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        index = position_in_batch + iteration * p.batch_size

        generation_params = {
            "Steps": p.steps,
            "Sampler": samplers[p.sampler_index].name,
            "CFG scale": p.cfg_scale,
            "Seed": all_seeds[index],
            "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
            "Size": f"{p.width}x{p.height}",
            "Batch size": (None if p.batch_size < 2 else p.batch_size),
            "Batch pos": (None if p.batch_size < 2 else position_in_batch),
            "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
            "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
            "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        }

        if p.extra_generation_params is not None:
            generation_params.update(p.extra_generation_params)

        generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
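        # e.g. "Steps: 50, Sampler: Euler a, CFG scale: 7.0, Seed: 1000, Size: 512x512" (illustrative values)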
        
        negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

        return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

    output_images = []
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(seed=all_seeds[0])

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]

            uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            c = p.sd_model.get_learned_conditioning(prompts)

            if len(model_hijack.comments) > 0:
                comments += model_hijack.comments

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=all_subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
            if state.interrupted:
                # if we are interrupted, sample returns just noise;
                # use the image collected previously in the sampler loop
                samples_ddim = shared.state.current_latent

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i))

                output_images.append(image)

            state.nextjob()

        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            return_grid = opts.return_grid

            grid = images.image_grid(output_images, p.batch_size)

            if return_grid:
                output_images.insert(0, grid)

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)

    torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext())


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def init(self, seed):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

    def sample(self, x, conditioning, unconditional_conditioning):
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim


def get_crop_region(mask, pad=0):
    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )
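
# Worked example (illustrative): for a 100x100 mask whose nonzero pixels fill rows
# 20..29 and columns 40..59, get_crop_region(mask, pad=4) returns (36, 16, 64, 34):
# the tight bounding box grown by the padding and clamped to the image bounds.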


def fill(image, mask):
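    """Fill the masked region of the image with colors bled in from the unmasked
    surroundings, by compositing Gaussian blurs of decreasing radius."""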
    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    image_masked = image_masked.convert('RGBa')

    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, seed):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip(np_mask.astype(np.float32) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = fill(image, latent_mask)

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
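            # a single init image is repeated to fill the whole batch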
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            if self.inpainting_fill == 2:
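                # fill mode 2 ("latent noise" in the UI, assumed naming): overwrite the masked part of the latent with fresh seeded noise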
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

    def sample(self, x, conditioning, unconditional_conditioning):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask
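            # nmask is 1 inside the masked (inpainted) region and mask is 1 outside it,
            # so the original latent is restored everywhere else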

        return samples