# processing.py
import contextlib
import json
import math
import os
import sys

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure

import modules.sd_hijack
from modules import devices, prompt_parser
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles


# some of those options should not be changed at all because they would break the model, so I removed them from options.
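# opt_C is the number of latent channels the model works with, and opt_f is the VAE
# downscaling factor: a WxH RGB image maps to an opt_C x (H/opt_f) x (W/opt_f) latent.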
opt_C = 4
opt_f = 8


def setup_color_correction(image):
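    # Snapshot the reference image in LAB color space; apply_color_correction() later
    # histogram-matches generated images against it to counter img2img color drift.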
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, image):
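    # Match the generated image's LAB histograms to the stored correction target, then
    # convert back to RGB; channel_axis=2 tells skimage which axis holds the channels.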
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    return image


class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: str = styles
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        self.paste_to = None
        self.color_corrections = None

    def init(self, seed):
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        raise NotImplementedError()


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps

    def js(self):
        obj = {
            "prompt": self.prompt if type(self.prompt) != list else self.prompt[0],
            "negative_prompt": self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0],
            "seed": int(self.seed if type(self.seed) != list else self.seed[0]),
            "width": self.width,
            "height": self.height,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
        }

        return json.dumps(obj)

# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
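    # Spherical linear interpolation between two noise tensors: moving along the
    # hypersphere (rather than a straight line) keeps the blended noise's magnitude
    # gaussian-like; used below to mix seed and subseed noise by subseed_strength.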
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm*high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    xs = []

    # if we have multiple seeds, this means we are working with batch size > 1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
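    # sampler_noises pre-collects the extra noise that stochastic (e.g. ancestral)
    # samplers consume at each step, one list per needed noise, for the same reason.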
    if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for the same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets the same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            #noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
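            # Seed-resize: the noise above was generated at the old resolution; paste it
            # centered into a full-size noise tensor so a resized image keeps as much of
            # the original seed's composition as possible.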
            #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x


def fix_seed(p):
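    # -1 (or None) means "pick a random seed"; 4294967294 == 2**32 - 2 keeps the result
    # inside the 32-bit range, presumably for RNG compatibility across backends.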
    p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed
    p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    assert p.prompt is not None
    devices.torch_gc()

    fix_seed(p)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = {}

    shared.prompt_styles.apply_styles(p)

    if type(p.prompt) == list:
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(p.seed) == list:
        all_seeds = p.seed
    else:
        all_seeds = [int(p.seed + (x if p.subseed_strength == 0 else 0)) for x in range(len(all_prompts))]

    if type(p.subseed) == list:
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
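        # Builds the human-readable parameter text stored in image metadata; entries
        # whose value is None are dropped from the joined line below.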
        index = position_in_batch + iteration * p.batch_size

        generation_params = {
            "Steps": p.steps,
            "Sampler": samplers[p.sampler_index].name,
            "CFG scale": p.cfg_scale,
            "Seed": all_seeds[index],
            "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
            "Size": f"{p.width}x{p.height}",
            "Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model_hash else shared.sd_model_hash),
            "Batch size": (None if p.batch_size < 2 else p.batch_size),
            "Batch pos": (None if p.batch_size < 2 else position_in_batch),
            "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
            "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
            "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
            "Denoising strength": getattr(p, 'denoising_strength', None),
        }

        if p.extra_generation_params is not None:
            generation_params.update(p.extra_generation_params)

        generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
        
        negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

        return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

    output_images = []
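    # Run under fp16 autocast when requested; the model's EMA-weights context is skipped
    # in lowvram mode, presumably to avoid the extra weight swap on low-memory setups.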
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(seed=all_seeds[0])

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            #c = p.sd_model.get_learned_conditioning(prompts)
            uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
            c = prompt_parser.get_learned_conditioning(prompts, p.steps)
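            # note: prompt_parser needs the step count because prompt-editing syntax can
            # schedule different conditionings for different sampling steps.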

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
            if state.interrupted:

                # if we are interrupted, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    image = apply_color_correction(p.color_corrections[i], image)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
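                        # inpaint-at-full-res case: paste the processed crop back into a
                        # full-size canvas at its original location before compositing
                        # the unmasked overlay on top.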
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                output_images.append(image)

            state.nextjob()

        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                output_images.insert(0, grid)

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext())
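# A minimal usage sketch (hypothetical paths and values, assuming a model is already
# loaded into shared.sd_model):
#   p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, prompt="a photo of a cat",
#                                        outpath_samples="outputs", outpath_grids="outputs",
#                                        steps=20, width=512, height=512)
#   res = process_images(p)
#   res.images[0].save("cat.png")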


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def init(self, seed):
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

    def sample(self, x, conditioning, unconditional_conditioning):
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim


def get_crop_region(mask, pad=0):
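    """Finds the bounding box of the non-zero region of a 2D mask, expands it by pad on
    each side, clamps it to the image borders, and returns it as (x1, y1, x2, y2)."""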
    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )


def fill(image, mask):
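    """Fills the masked-out region of an image with colors bled in from the unmasked
    surroundings, by compositing progressively less-blurred copies of the image."""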
    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    image_masked = image_masked.convert('RGBa')

    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
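        # inpainting_fill modes (matching the img2img UI order): 0 = fill with blurred
        # surroundings, 1 = keep original content, 2 = latent noise, 3 = latent zeros.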
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, seed):
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
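                # Inpaint at full resolution: crop the padded mask region, upscale that
                # crop to the working size, and record the placement in paste_to so the
                # result can be pasted back into the full image afterwards.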
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
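                # The doubling below apparently exists so the blurred mask's soft edge
                # becomes fully opaque in the overlay preview.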
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
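            # Build a binary latent-space mask: downscale to the latent resolution, keep
            # one channel, round to {0, 1}, and tile it across the 4 latent channels.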
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            if self.inpainting_fill == 2:
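                # "latent noise" mode: replace the masked latent region with freshly
                # seeded noise; "latent nothing" (mode 3) below zeroes it out instead.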
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

    def sample(self, x, conditioning, unconditional_conditioning):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples