# processing.py

import json
import math
import os
import sys
import warnings

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
from typing import Any, Dict, List, Optional

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion

from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
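
# opt_C is the latent channel count and opt_f the VAE downscaling factor of Stable
# Diffusion's autoencoder; latents are shaped [opt_C, height // opt_f, width // opt_f],
# which is why the sample() methods below divide width and height by opt_f.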


def setup_color_correction(image):
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, original_image):
    logging.info("Applying color correction.")
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(original_image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image
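
# Typical round trip (a sketch): capture the LAB histogram of the initial image once,
# then match every generated image back to it so img2img batches don't drift in color:
#   correction = setup_color_correction(init_image)
#   corrected = apply_color_correction(correction, generated_image)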


def apply_overlay(image, paste_loc, index, overlays):
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = images.resize_image(1, image, w, h)
        base_image.paste(image, (x, y))
        image = base_image

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image
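
# paste_loc, when set, is the (x, y, width, height) of an inpaint-at-full-res crop
# region (stored as paste_to by StableDiffusionProcessingImg2Img.init): the generated
# crop is resized and pasted back onto a full-size canvas before the masked overlay
# is composited on top.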


def txt2img_image_conditioning(sd_model, x, width, height):
    if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
        # Dummy zero conditioning if we're not using inpainting model.
        # Still takes up a bit of memory, but no encoder call.
        # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)

    # The "masked-image" in this case will just be all zeros since the entire image is masked.
    image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
    image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))

    # Add the fake full 1s mask to the first dimension.
    image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
    image_conditioning = image_conditioning.to(x.dtype)

    return image_conditioning
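
# In the inpainting branch above, the returned tensor has 5 channels: a full-ones mask
# stacked on the 4 encoded latent channels, i.e. [batch, 5, height // opt_f, width // opt_f];
# the dummy branch only mimics that channel count with a 1x1 zero tensor.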


class StableDiffusionProcessing:
    """
    The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
    """
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
        self.override_settings_restore_afterwards = override_settings_restore_afterwards
        self.is_using_inpainting_conditioning = False

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = None
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None
        self.iteration = 0

    def txt2img_image_conditioning(self, x, width=None, height=None):
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        # Use the AddMiDaS helper to format our source image to suit the MiDaS model
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning
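
    # The depth map is min-max rescaled to [-1, 1] (the range depth2img models are
    # trained with) and is used as the image conditioning for sampling.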

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
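
    # Dispatch summary: depth2img models get a MiDaS depth map, inpainting-style
    # ('hybrid'/'concat') models get mask + masked latents, and everything else gets
    # the cheap zero dummy above.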

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        raise NotImplementedError()

    def close(self):
        self.sd_model = None
        self.sampler = None


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.comments = comments
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
        }

        return json.dumps(obj)
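
    # A sketch of the js() payload, which the web UI reads back as the generation
    # info object:
    #   {"prompt": "...", "all_prompts": ["..."], "seed": 12345, "width": 512, "height": 512, ...}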

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        # nearly parallel vectors: fall back to plain lerp (at val=0 this must return low)
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
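
# slerp spherically interpolates between two noise tensors; create_random_tensors
# below uses it to pull the base seed's noise toward the variation (sub)seed's noise:
#   noise = slerp(subseed_strength, noise, subnoise)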


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if eta_noise_seed_delta > 0:
                torch.manual_seed(seed + eta_noise_seed_delta)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x
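
# Typical call (a sketch), matching how the sample() methods below request one latent
# per seed in the batch:
#   x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds, p=p)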


def decode_first_stage(model, x):
    with devices.autocast(disable=x.dtype == devices.dtype_vae):
        x = model.decode_first_stage(x)

    return x


def get_fixed_seed(seed):
    if seed is None or seed == '' or seed == -1:
        return int(random.randrange(4294967294))

    return seed
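
# get_fixed_seed(-1) (or None or '') draws a fresh random seed; any explicit seed
# passes through unchanged.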


def fix_seed(p):
    p.seed = get_fixed_seed(p.seed)
    p.subseed = get_fixed_seed(p.subseed)


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
        "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
        "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
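
# The returned infotext is the plain-text parameters block stored in PNG metadata and
# shown in the UI; a sketch of its shape:
#   a photo of a cat
#   Negative prompt: blurry
#   Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 12345, Size: 512x512, ...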


def process_images(p: StableDiffusionProcessing) -> Processed:
    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

    try:
        for k, v in p.override_settings.items():
            setattr(opts, k, v)
            if k == 'sd_hypernetwork': shared.reload_hypernetworks()  # make onchange call for changing hypernet
            if k == 'sd_model_checkpoint': sd_models.reload_model_weights()  # make onchange call for changing SD model
            if k == 'sd_vae': sd_vae.reload_vae_weights()  # make onchange call for changing VAE

        res = process_images_inner(p)

    finally:
        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)
                if k == 'sd_hypernetwork': shared.reload_hypernetworks()
                if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
                if k == 'sd_vae': sd_vae.reload_vae_weights()

    return res
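
# process_images temporarily applies p.override_settings (e.g. sd_model_checkpoint,
# sd_vae, sd_hypernetwork) around the main loop and then restores the previous values,
# unless override_settings_restore_afterwards is False.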


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert len(p.prompt) > 0
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    comments = {}

    if type(p.prompt) == list:
        p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
    else:
        p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

    if type(p.negative_prompt) == list:
        p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
    else:
        p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]

    if type(seed) == list:
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if type(subseed) == list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)

    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
        processed = Processed(p, [], p.seed, "")
        file.write(processed.infotext(p, 0))

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if len(prompts) == 0:
                break

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)

            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext()
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()

    res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.hr_scale = hr_scale
        self.hr_upscaler = hr_upscaler
        self.hr_second_pass_steps = hr_second_pass_steps
        self.hr_resize_x = hr_resize_x
        self.hr_resize_y = hr_resize_y
        self.hr_upscale_to_x = hr_resize_x
        self.hr_upscale_to_y = hr_resize_y

        if firstphase_width != 0 or firstphase_height != 0:
            print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
            self.hr_scale = self.width / firstphase_width
            self.width = firstphase_width
            self.height = firstphase_height

        self.truncate_x = 0
        self.truncate_y = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter

                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True

            if self.hr_resize_x == 0 and self.hr_resize_y == 0:
                self.extra_generation_params["Hires upscale"] = self.hr_scale
                self.hr_upscale_to_x = int(self.width * self.hr_scale)
                self.hr_upscale_to_y = int(self.height * self.hr_scale)
            else:
                self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

                if self.hr_resize_y == 0:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                elif self.hr_resize_x == 0:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y
                else:
                    target_w = self.hr_resize_x
                    target_h = self.hr_resize_y
                    src_ratio = self.width / self.height
                    dst_ratio = self.hr_resize_x / self.hr_resize_y

                    if src_ratio < dst_ratio:
                        self.hr_upscale_to_x = self.hr_resize_x
                        self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                    else:
                        self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                        self.hr_upscale_to_y = self.hr_resize_y

                    self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                    self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f

            if self.hr_second_pass_steps:
                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
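
            # Worked example (a sketch): a 512x768 first pass with hr_resize set to
            # 1024x1024 gives src_ratio 512/768 < dst_ratio 1.0, so the image is
            # upscaled to 1024x1536 and (1536 - 1024) // opt_f = 64 latent rows are
            # truncated away in sample() to hit the requested size.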

            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
        if self.enable_hr and latent_scale_mode is None:
            assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))

        if not self.enable_hr:
            return samples

        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")

        if latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            decoded_samples = 2. * decoded_samples - 1.

792
            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
A
AUTOMATIC 已提交
793

794
            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
795

A
AUTOMATIC 已提交
796
        shared.state.nextjob()
797

798
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
A
AUTOMATIC 已提交
799

800 801
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

A
AUTOMATIC 已提交
802
        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
803 804 805 806

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()
807

808
        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
A
AUTOMATIC 已提交
809 810

        return samples
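
    # sample() can thus run two passes: the base txt2img pass and, when enable_hr is
    # set, an img2img-style hires pass over the upscaled latents using
    # hr_second_pass_steps (or the base step count when that is 0).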


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
        self.mask = None
        self.nmask = None
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        image_mask = self.image_mask

        if image_mask is not None:
            image_mask = image_mask.convert('L')

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)

            if self.mask_blur > 0:
                image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = images.flatten(img, opts.img2img_background_color)

            if crop_region is None and self.resize_mode != 3:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.resize_mode == 3:
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        if image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples
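
        # With a mask, inpainted regions come from the sampler output (nmask) while
        # the rest is restored from init_latent (mask), so unmasked areas stay
        # unchanged in latent space.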