import json
import logging
import math
import os
import sys
import hashlib

import torch
import numpy as np
from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
from typing import Any, Dict, List

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
import modules.face_restoration
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae

from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion

from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType


# some of those options should not be changed at all because they would break the model, so I removed them from options.
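# opt_C is the latent channel count and opt_f the VAE downscaling factor: at the
# default 512x512, the model works on a 4-channel 64x64 latent.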
opt_C = 4
opt_f = 8


def setup_color_correction(image):
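    """Store the image in LAB color space as the histogram-matching target that
    apply_color_correction will match generated images against."""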
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, original_image):
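    """Match the image's histogram to the stored LAB correction target, then blend
    with a LUMINOSITY blend against the original so its luminance detail is kept."""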
    logging.info("Applying color correction.")
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(original_image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image


def apply_overlay(image, paste_loc, index, overlays):
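    """Paste the generated image back into the full-size original at paste_loc
    (set e.g. when inpainting a cropped region) and alpha-composite the stored
    overlay image on top; returns an RGB image."""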
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = images.resize_image(1, image, w, h)
        base_image.paste(image, (x, y))
        image = base_image

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image


def txt2img_image_conditioning(sd_model, x, width, height):
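    """Build image conditioning for txt2img: inpainting models get an all-zero
    masked image plus an all-ones mask, unCLIP models get zero adm conditioning,
    and everything else gets a tiny dummy tensor that only carries batch size."""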
    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models

        # The "masked-image" in this case will just be all zeros since the entire image is masked.
        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
        image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))

        # Add the fake full 1s mask to the first dimension.
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning

    elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models

        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

    else:
        # Dummy zero conditioning if we're not using inpainting or unclip models.
        # Still takes up a bit of memory, but no encoder call.
        # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)


class StableDiffusionProcessing:
    """
    The first set of parameters: sd_model -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
    """
    cached_uc = [None, None]
    cached_c = [None, None]

    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_min_uncond = s_min_uncond or opts.s_min_uncond
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
        self.s_noise = s_noise if s_noise is not None else opts.s_noise
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
        self.override_settings_restore_afterwards = override_settings_restore_afterwards
        self.is_using_inpainting_conditioning = False
        self.disable_extra_networks = False
        self.token_merging_ratio = 0
        self.token_merging_ratio_hr = 0

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = script_args
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None
        self.iteration = 0
        self.is_hr_pass = False
        self.sampler = None

        self.prompts = None
        self.negative_prompts = None
        self.extra_network_data = None
        self.seeds = None
        self.subseeds = None

        self.step_multiplier = 1
        self.cached_uc = StableDiffusionProcessing.cached_uc
        self.cached_c = StableDiffusionProcessing.cached_c
        self.uc = None
        self.c = None

        self.user = None

    @property
    def sd_model(self):
        return shared.sd_model

    def txt2img_image_conditioning(self, x, width=None, height=None):
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        # Use the AddMiDaS helper to format our source image to suit the MiDaS model
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))

        return conditioning_image

    def unclip_image_conditioning(self, source_image):
        c_adm = self.sd_model.embedder(source_image)
        if self.sd_model.noise_augmentor is not None:
            noise_level = 0 # TODO: Allow other noise levels?
            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
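        """Build the c_concat conditioning for inpainting models: a rounded 0/1 mask
        channel concatenated with the first-stage encoding of the masked source image."""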
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        raise NotImplementedError()

    def close(self):
        self.sampler = None
        self.c = None
        self.uc = None
        if not opts.persistent_cond_cache:
            StableDiffusionProcessing.cached_c = [None, None]
            StableDiffusionProcessing.cached_uc = [None, None]

    def get_token_merging_ratio(self, for_hr=False):
        if for_hr:
            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

        return self.token_merging_ratio or opts.token_merging_ratio

    def setup_prompts(self):
        if type(self.prompt) == list:
            self.all_prompts = self.prompt
        else:
            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

        if type(self.negative_prompt) == list:
            self.all_negative_prompts = self.negative_prompt
        else:
            self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]

        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

    def cached_params(self, required_prompts, steps, extra_network_data):
        """Returns parameters that invalidate the cond cache if changed"""

        return (
            required_prompts,
            steps,
            opts.CLIP_stop_at_last_layers,
            shared.sd_model.sd_checkpoint_info,
            extra_network_data,
            opts.sdxl_crop_left,
            opts.sdxl_crop_top,
            self.width,
            self.height,
        )

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.

        cache is an array containing two elements. The first element is a tuple
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        computed result is stored.

        caches is a list with items described above.
        """

        cached_params = self.cached_params(required_prompts, steps, extra_network_data)

        for cache in caches:
            if cache[0] is not None and cached_params == cache[0]:
                return cache[1]

        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps)

        cache[0] = cached_params
        return cache[1]

    def setup_conds(self):
        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

    def parse_extra_network_prompts(self):
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

    def save_samples(self) -> bool:
        """Returns whether generated images need to be written to disk"""
        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)


class Processed:
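    """Container for the results of a processing run: output images, the exact
    seeds/prompts used, and enough parameters to reproduce and label them."""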
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.comments = comments
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers
        self.token_merging_ratio = p.token_merging_ratio
        self.token_merging_ratio_hr = p.token_merging_ratio_hr

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.s_min_uncond = p.s_min_uncond
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)

    def get_token_merging_ratio(self, for_hr=False):
        return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
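    """Spherical linear interpolation between noise tensors: normalizes both inputs,
    walks along the great circle between them by fraction val, and falls back to
    plain linear interpolation when the vectors are nearly collinear (dot > 0.9995)
    to avoid dividing by a near-zero sin(omega)."""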
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        return low * val + high * (1 - val)

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
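    """Create the initial latent noise for each seed in the batch; optionally slerps
    toward subseed noise for variation seeds, supports seed-resize by pasting noise
    generated at another resolution into the center, and pre-generates sampler
    noises so batched and single-image runs produce identical results."""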
    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None and subseed_strength != 0:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if eta_noise_seed_delta > 0:
                devices.manual_seed(seed + eta_noise_seed_delta)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x


class DecodedSamples(list):
    already_decoded = True


def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
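    """Decode a batch of latents through the VAE one sample at a time; if a decoded
    sample is all NaNs, optionally retry once with the VAE upcast to float32."""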
    samples = DecodedSamples()

    for i in range(batch.shape[0]):
        sample = decode_first_stage(model, batch[i:i + 1])[0]

        if check_for_nans:
            try:
                devices.test_for_nans(sample, "vae")
            except devices.NansException as e:
                if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
                    raise e

                errors.print_error_explanation(
                    "A tensor with all NaNs was produced in VAE.\n"
                    "Web UI will now convert VAE into 32-bit float and retry.\n"
                    "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."
                )

                devices.dtype_vae = torch.float32
                model.first_stage_model.to(devices.dtype_vae)
                batch = batch.to(devices.dtype_vae)

                sample = decode_first_stage(model, batch[i:i + 1])[0]

        if target_device is not None:
            sample = sample.to(target_device)

        samples.append(sample)

    return samples


def get_fixed_seed(seed):
    if seed is None or seed == '' or seed == -1:
        return int(random.randrange(4294967294))

    return seed


def fix_seed(p):
    p.seed = get_fixed_seed(p.seed)
    p.subseed = get_fixed_seed(p.subseed)


def program_version():
    import launch

    res = launch.git_tag()
    if res == "<none>":
        res = None

    return res


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
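    """Assemble the text block (prompt, negative prompt and comma-separated
    generation parameters) that is embedded in saved images and shown in the UI."""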
    if index is None:
        index = position_in_batch + iteration * p.batch_size

    if all_negative_prompts is None:
        all_negative_prompts = p.all_negative_prompts

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    enable_hr = getattr(p, 'enable_hr', False)
    token_merging_ratio = p.get_token_merging_ratio()
    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

    uses_ensd = opts.eta_noise_seed_delta != 0
    if uses_ensd:
        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
        **p.extra_generation_params,
        "Version": program_version() if opts.add_version_to_infotext else None,
        "User": p.user if opts.add_user_name_to_info else None,
    }

    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    prompt_text = p.prompt if use_main_prompt else all_prompts[index]
    negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""

    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()


def process_images(p: StableDiffusionProcessing) -> Processed:
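    """Apply per-run override settings (checkpoint, VAE, etc.) and token merging,
    delegate to process_images_inner, then restore the original settings if
    override_settings_restore_afterwards is set."""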
    if p.scripts is not None:
        p.scripts.before_process(p)

    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

    try:
        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        for k, v in p.override_settings.items():
            setattr(opts, k, v)

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        res = process_images_inner(p)

    finally:
        sd_models.apply_token_merging(p.sd_model, 0)

        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    comments = {}

    p.setup_prompts()

    if type(seed) == list:
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if type(subseed) == list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            if len(p.prompts) == 0:
                break

            p.parse_extra_network_prompts()

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [], p.seed, "")
                    file.write(processed.infotext(p, 0))

            p.setup_conds()

            for comment in model_hijack.comments:
                comments[comment] = 1

            p.extra_generation_params.update(model_hijack.extra_generation_params)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            def infotext(index=0, use_main_prompt=False):
                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

            save_samples = p.save_samples()

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

                if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                    image_mask = p.mask_for_overlay.convert('RGB')
                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')

                    if opts.save_mask:
                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")

                    if opts.save_mask_composite:
                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")

                    if opts.return_mask:
                        output_images.append(image_mask)

                    if opts.return_mask_composite:
                        output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        info=infotexts[0],
        comments="".join(f"{comment}\n" for comment in comments),
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res


def old_hires_fix_first_pass_dimensions(width, height):
    """old algorithm for auto-calculating first pass size"""

    desired_pixel_count = 512 * 512
    actual_pixel_count = width * height
    scale = math.sqrt(desired_pixel_count / actual_pixel_count)
    width = math.ceil(scale * width / 64) * 64
    height = math.ceil(scale * height / 64) * 64

    return width, height


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None
    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.hr_scale = hr_scale
        self.hr_upscaler = hr_upscaler
        self.hr_second_pass_steps = hr_second_pass_steps
        self.hr_resize_x = hr_resize_x
        self.hr_resize_y = hr_resize_y
        self.hr_upscale_to_x = hr_resize_x
        self.hr_upscale_to_y = hr_resize_y
        self.hr_checkpoint_name = hr_checkpoint_name
        self.hr_checkpoint_info = None
        self.hr_sampler_name = hr_sampler_name
        self.hr_prompt = hr_prompt
        self.hr_negative_prompt = hr_negative_prompt
        self.all_hr_prompts = None
        self.all_hr_negative_prompts = None
        self.latent_scale_mode = None

        if firstphase_width != 0 or firstphase_height != 0:
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height
            self.width = firstphase_width
            self.height = firstphase_height

        self.truncate_x = 0
        self.truncate_y = 0
        self.applied_old_hires_behavior_to = None

        self.hr_prompts = None
        self.hr_negative_prompts = None
        self.hr_extra_network_data = None

        self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
        self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
        self.hr_c = None
        self.hr_uc = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            if self.hr_checkpoint_name:
                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

                if self.hr_checkpoint_info is None:
                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

            if tuple(self.hr_prompt) != tuple(self.prompt):
                self.extra_generation_params["Hires prompt"] = self.hr_prompt

            if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt

            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
            if self.enable_hr and self.latent_scale_mode is None:
                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")

            if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                self.hr_resize_x = self.width
                self.hr_resize_y = self.height
                self.hr_upscale_to_x = self.width
                self.hr_upscale_to_y = self.height

                self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
                self.applied_old_hires_behavior_to = (self.width, self.height)

            if self.hr_resize_x == 0 and self.hr_resize_y == 0:
                self.extra_generation_params["Hires upscale"] = self.hr_scale
                self.hr_upscale_to_x = int(self.width * self.hr_scale)
                self.hr_upscale_to_y = int(self.height * self.hr_scale)
            else:
                self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

                if self.hr_resize_y == 0:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                elif self.hr_resize_x == 0:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y
                else:
                    target_w = self.hr_resize_x
                    target_h = self.hr_resize_y
                    src_ratio = self.width / self.height
                    dst_ratio = self.hr_resize_x / self.hr_resize_y

                    if src_ratio < dst_ratio:
                        self.hr_upscale_to_x = self.hr_resize_x
                        self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                    else:
                        self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                        self.hr_upscale_to_y = self.hr_resize_y

                    self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                    self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f

            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter

                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True

            if self.hr_second_pass_steps:
                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x

        if not self.enable_hr:
            return samples

        if self.latent_scale_mode is None:
            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
        else:
            decoded_samples = None

        current = shared.sd_model.sd_checkpoint_info
        try:
            if self.hr_checkpoint_info is not None:
                self.sampler = None
                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
                devices.torch_gc()

            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
        finally:
            self.sampler = None
            sd_models.reload_model_weights(info=current)
            devices.torch_gc()

    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
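        """Second (high-res) pass of hires fix: upscale the first-pass result either
        in latent space or as an image with the chosen upscaler, rebuild the image
        conditioning, and sample again at the target resolution."""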
A
AUTOMATIC 已提交
1096 1097
        self.is_hr_pass = True

1098 1099
        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y
A
AUTOMATIC 已提交
1100

1101
        def save_intermediate(image, index):
A
AUTOMATIC 已提交
1102 1103
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

A
AUTOMATIC1111 已提交
1104
            if not self.save_samples() or not opts.save_images_before_highres_fix:
1105 1106 1107
                return

            if not isinstance(image, Image.Image):
M
MMaker 已提交
1108
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

        img2img_sampler_name = self.hr_sampler_name or self.sampler_name

        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
            img2img_sampler_name = 'DDIM'

        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

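        # two upscale paths: interpolate directly in latent space below, or (in the else branch) decode to images, upscale with a regular upscaler, and re-encode through the VAE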
        if self.latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])

            # Avoid building the inpainting conditioning unless necessary, since
            # it takes extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
            samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

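        # crop off the excess recorded as truncate_x/truncate_y in init() so the latent matches the requested target resolution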
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)

        # GC now before running the next img2img to prevent running out of memory
        devices.torch_gc()

        if not self.disable_extra_networks:
            with devices.autocast():
                extra_networks.activate(self, self.hr_extra_network_data)

        with devices.autocast():
            self.calculate_hr_conds()

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

        if self.scripts is not None:
            self.scripts.before_hr(self)

        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        self.is_hr_pass = False

        return decoded_samples

    def close(self):
        super().close()
        self.hr_c = None
        self.hr_uc = None
        if not opts.persistent_cond_cache:
            StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
            StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]

    def setup_prompts(self):
        super().setup_prompts()

        if not self.enable_hr:
            return

        if self.hr_prompt == '':
            self.hr_prompt = self.prompt

        if self.hr_negative_prompt == '':
            self.hr_negative_prompt = self.negative_prompt

        if isinstance(self.hr_prompt, list):
            self.all_hr_prompts = self.hr_prompt
        else:
            self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]

        if isinstance(self.hr_negative_prompt, list):
            self.all_hr_negative_prompts = self.hr_negative_prompt
        else:
            self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]

        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]

    def calculate_hr_conds(self):
        if self.hr_c is not None:
            return

        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)

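        # both the hr cache and the first-pass cache are passed in, so unchanged prompts can reuse already-computed conds instead of being re-encoded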
        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)

    def setup_conds(self):
        super().setup_conds()

        self.hr_uc = None
        self.hr_c = None

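        # hires conds can only be computed up front when the second pass uses the same checkpoint; with a separate hr checkpoint they are computed after that model is loaded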
        if self.enable_hr and self.hr_checkpoint_info is None:
            if shared.opts.hires_fix_use_firstpass_conds:
                self.calculate_hr_conds()

            elif lowvram.is_enabled(shared.sd_model):  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
                with devices.autocast():
                    extra_networks.activate(self, self.hr_extra_network_data)

                self.calculate_hr_conds()

                with devices.autocast():
                    extra_networks.activate(self, self.extra_network_data)

    def parse_extra_network_prompts(self):
        res = super().parse_extra_network_prompts()

        if self.enable_hr:
            self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]

            self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

        return res


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
        self.init_latent = None
        self.image_mask = mask
        self.latent_mask = None
        self.mask_for_overlay = None
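        # a single mask_blur value, if given, overrides both per-axis blur settings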
        if mask_blur is not None:
            mask_blur_x = mask_blur
            mask_blur_y = mask_blur
        self.mask_blur_x = mask_blur_x
        self.mask_blur_y = mask_blur_y
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
        self.mask = None
        self.nmask = None
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        image_mask = self.image_mask

        if image_mask is not None:
            image_mask = image_mask.convert('L')

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)

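            # blur each axis separately with a 1-D Gaussian kernel sized to cover roughly ±4 sigma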
            if self.mask_blur_x > 0:
                np_mask = np.array(image_mask)
                kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
                np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
                image_mask = Image.fromarray(np_mask)

            if self.mask_blur_y > 0:
                np_mask = np.array(image_mask)
                kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
                image_mask = Image.fromarray(np_mask)

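            # "inpaint at full resolution": crop to the masked region plus padding, work on that crop at the full target size, and record paste_to so the result can be pasted back later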
            if self.inpaint_full_res:
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

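        # the latent mask, if provided, is used instead of the image mask for pre-filling masked content and for latent blending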
        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:

            # Save init image
            if opts.save_init_img:
                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)

            image = images.flatten(img, opts.img2img_background_color)

            if crop_region is None and self.resize_mode != 3:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
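                # fill mode 1 keeps the original content under the mask; other modes pre-fill it here (modes 2 and 3 get further latent-space treatment later in this method)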
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = image.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
        devices.torch_gc()

        if self.resize_mode == 3:
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        if image_mask is not None:
            init_mask = latent_mask
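            # downscale the mask to latent resolution and binarize it: self.mask keeps the original latents, self.nmask marks the region to be regenerated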
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

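        # blend the sampled latents with the originals: keep init_latent where mask is set, the sampled result where nmask is set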
        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples

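    # precedence: an explicit per-job ratio, then a ratio coming from override_settings, then the img2img-specific option, then the global option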
    def get_token_merging_ratio(self, for_hr=False):
        return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio