processing.py 65.4 KB
Newer Older
1
import json
2
import logging
3 4 5
import math
import os
import sys
6
import hashlib
7 8 9

import torch
import numpy as np
A
linter  
AUTOMATIC 已提交
10
from PIL import Image, ImageOps
11
import random
12 13
import cv2
from skimage import exposure
A
AUTOMATIC 已提交
14
from typing import Any, Dict, List
15

16
import modules.sd_hijack
17
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
18
from modules.sd_hijack import model_hijack
19
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
20 21
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
22
import modules.paths as paths
A
AUTOMATIC 已提交
23
import modules.face_restoration
24
import modules.images as images
A
AUTOMATIC 已提交
25
import modules.styles
26 27
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
J
Jay Smith 已提交
28 29
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
30

J
Jay Smith 已提交
31
from einops import repeat, rearrange
32
from blendmodes.blend import blendLayers, BlendType
33 34


35 36 37 38 39
# Fixed model constants. They are deliberately not exposed as user options because changing them would break the model.
# NOTE(review): presumably opt_C is the latent channel count and opt_f the image-to-latent downscale factor — confirm.
opt_C = 4
opt_f = 8


40
def setup_color_correction(image):
    """Return the LAB-colorspace array of `image`, used later as the color correction target.

    `image` is copied before conversion, so the caller's image object is not touched.
    The result is what `apply_color_correction` expects as its `correction` argument.
    """
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


46
def apply_color_correction(correction, original_image):
    """Match the colors of `original_image` to a previously captured LAB target.

    `correction` is the LAB array produced by `setup_color_correction`.
    Returns a new image; `original_image` is only read.
    """
    logging.info("Applying color correction.")

    # Work in LAB space: histogram-match every channel against the target, then go back to RGB.
    original_lab = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(original_lab, correction, channel_axis=2)
    image = Image.fromarray(cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8"))

    # Blend the corrected result with the original using the luminosity blend mode.
    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image

A
AUTOMATIC 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77

def apply_overlay(image, paste_loc, index, overlays):
    """Composite `overlays[index]` on top of `image` and return the result as RGB.

    If `overlays` is None or has no entry at `index`, `image` is returned unchanged.
    When `paste_loc` is an (x, y, w, h) tuple, `image` is first resized to w x h and
    pasted at (x, y) onto a blank RGBA canvas the size of the overlay.
    """
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = images.resize_image(1, image, w, h)
        base_image.paste(image, (x, y))
        image = base_image

    # alpha_composite works in place and requires RGBA, hence the round trip through RGBA.
    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image
80

F
frostydad 已提交
81

82
def txt2img_image_conditioning(sd_model, x, width, height):
    """Build the image-conditioning tensor for a txt2img run.

    The result depends on `sd_model.model.conditioning_key`:
      - 'hybrid' / 'concat' (inpainting models): an encoded all-zero image with a channel of ones prepended as the mask;
      - 'crossattn-adm' (unCLIP models): zeros of shape (batch, 2 * noise_augmentor.time_embed.dim);
      - anything else: a dummy (batch, 5, 1, 1) zero tensor.

    The batch size, dtype and device are taken from `x`.
    """
    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models

        # The "masked-image" in this case will just be all zeros since the entire image is masked.
        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
        image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))

        # Add the fake full 1s mask as an extra leading channel.
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning

    elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models

        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

    else:
        # Dummy zero conditioning if we're not using inpainting or unclip models.
        # Still takes up a bit of memory, but no encoder call.
        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
104 105


106
class StableDiffusionProcessing:
    """
    Parameters and working state of one generation job, plus helpers for building conditioning.

    The first set of parameters: sd_model -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
    """
    # Conditioning caches as [cache key, cached result]. They live on the class and every
    # instance aliases the same lists in __init__, so cached conds are shared between instances.
    cached_uc = [None, None]
    cached_c = [None, None]

    def __init__(
        self,
        sd_model=None,
        outpath_samples=None,
        outpath_grids=None,
        prompt: str = "",
        styles: List[str] = None,
        seed: int = -1,
        subseed: int = -1,
        subseed_strength: float = 0,
        seed_resize_from_h: int = -1,
        seed_resize_from_w: int = -1,
        seed_enable_extras: bool = True,
        sampler_name: str = None,
        batch_size: int = 1,
        n_iter: int = 1,
        steps: int = 50,
        cfg_scale: float = 7.0,
        width: int = 512,
        height: int = 512,
        restore_faces: bool = False,
        tiling: bool = False,
        do_not_save_samples: bool = False,
        do_not_save_grid: bool = False,
        extra_generation_params: Dict[Any, Any] = None,
        overlay_images: Any = None,
        negative_prompt: str = None,
        eta: float = None,
        do_not_reload_embeddings: bool = False,
        denoising_strength: float = 0,
        ddim_discretize: str = None,
        s_min_uncond: float = 0.0,
        s_churn: float = 0.0,
        s_tmax: float = None,
        s_tmin: float = 0.0,
        s_noise: float = 1.0,
        override_settings: Dict[str, Any] = None,
        override_settings_restore_afterwards: bool = True,
        sampler_index: int = None,
        script_args: list = None,
    ):
        """Store the job parameters and reset all per-run state.

        `sd_model` is accepted but not stored; the `sd_model` property always returns `shared.sd_model`.
        `sampler_index` is ignored apart from a warning printed to stderr.
        Falsy sampler parameters (ddim_discretize, s_min_uncond, s_churn, s_tmin, s_noise) fall back to the matching value in `opts`.
        """
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_min_uncond = s_min_uncond or opts.s_min_uncond
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        # Honour an explicit s_tmax argument; otherwise use the stored option.
        # A falsy result (0 / None) means "no upper bound".
        self.s_tmax = (s_tmax if s_tmax is not None else opts.data.get('s_tmax', 0)) or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise
        # Settings listed in shared.restricted_opts are dropped from the overrides.
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
        self.override_settings_restore_afterwards = override_settings_restore_afterwards
        self.is_using_inpainting_conditioning = False
        self.disable_extra_networks = False
        self.token_merging_ratio = 0
        self.token_merging_ratio_hr = 0

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = script_args
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None
        self.iteration = 0
        self.is_hr_pass = False
        self.sampler = None

        self.prompts = None
        self.negative_prompts = None
        self.extra_network_data = None
        self.seeds = None
        self.subseeds = None

        self.step_multiplier = 1
        self.cached_uc = StableDiffusionProcessing.cached_uc
        self.cached_c = StableDiffusionProcessing.cached_c
        self.uc = None
        self.c = None

        self.user = None

    @property
    def sd_model(self):
        """The currently loaded model, always read from `shared.sd_model`."""
        return shared.sd_model

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Return txt2img image conditioning for `x`, defaulting to this job's width and height.

        Also records whether the model uses inpainting-style ('hybrid'/'concat') conditioning.
        """
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        """Return a depth map of `source_image`, resized to latent size and rescaled to [-1, 1]."""
        # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        # The encoded image is only used for its spatial size, which the depth map is resized to.
        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        # Min-max normalize the depth values into [-1, 1].
        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        """Return the encoded `source_image` (rescaled by *0.5+0.5 first) as conditioning for "edit" models."""
        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))

        return conditioning_image

    def unclip_image_conditioning(self, source_image):
        """Return the embedding of `source_image`, with the noise-level embedding appended when the model has a noise augmentor."""
        c_adm = self.sd_model.embedder(source_image)
        if self.sd_model.noise_augmentor is not None:
            noise_level = 0 # TODO: Allow other noise levels?
            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Return inpainting conditioning: the latent-sized mask concatenated with the encoded masked image.

        `image_mask` may be a tensor (used as is), an image (converted to a rounded 0/1 mask),
        or None (an all-ones mask is used). Sets `is_using_inpainting_conditioning` to True.
        """
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Pick and build the image conditioning that matches the loaded model and sampler.

        Checked in order: depth2img model, "edit" model, inpainting ('hybrid'/'concat'),
        unCLIP ('crossattn-adm'); otherwise a dummy (batch, 5, 1, 1) zero tensor is returned.
        """
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Hook for subclasses; the base implementation raises NotImplementedError."""
        raise NotImplementedError()

    def close(self):
        """Drop the sampler and conds; also reset the class-level cond caches unless the persistent cache option is on."""
        self.sampler = None
        self.c = None
        self.uc = None
        if not opts.experimental_persistent_cond_cache:
            StableDiffusionProcessing.cached_c = [None, None]
            StableDiffusionProcessing.cached_uc = [None, None]

    def get_token_merging_ratio(self, for_hr=False):
        """Return the first truthy token merging ratio, preferring this job's values over `opts`.

        With `for_hr=True` the hires values (job, then opts) are tried before the regular ones.
        """
        if for_hr:
            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

        return self.token_merging_ratio or opts.token_merging_ratio

    def setup_prompts(self):
        """Fill `all_prompts` / `all_negative_prompts` and apply the selected styles to each entry.

        A list prompt is used as is; a single prompt is repeated batch_size * n_iter times.
        """
        if isinstance(self.prompt, list):
            self.all_prompts = self.prompt
        else:
            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

        if isinstance(self.negative_prompt, list):
            self.all_negative_prompts = self.negative_prompt
        else:
            self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]

        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.

        cache is an array containing two elements. The first element is a tuple
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        computed result is stored.

        caches is a list with items described above. Every cache is checked for a hit;
        on a miss the result is computed and stored in the first cache.
        """

        cached_params = (
            required_prompts,
            steps,
            opts.CLIP_stop_at_last_layers,
            shared.sd_model.sd_checkpoint_info,
            extra_network_data,
            opts.sdxl_crop_left,
            opts.sdxl_crop_top,
            self.width,
            self.height,
        )

        for cache in caches:
            if cache[0] is not None and cached_params == cache[0]:
                return cache[1]

        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps)

        # The key is written only after the computation succeeded.
        cache[0] = cached_params
        return cache[1]

    def setup_conds(self):
        """Compute (or fetch from cache) `self.uc` and `self.c` for the current prompts."""
        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        # Samplers flagged "second_order" get twice the step count when building conds.
        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

    def parse_extra_network_prompts(self):
        """Replace `self.prompts` with the result of `extra_networks.parse_prompts` and store the extracted data."""
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
370

371 372

class Processed:
    """Result of a generation job: the images plus a snapshot of the parameters that produced them."""

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        """Copy the relevant fields from `p` and normalize prompt/seed values.

        List-valued prompt, negative prompt, seed and subseed are reduced to their first element;
        a None seed or subseed becomes -1. The `all_*` lists fall back to the values on `p`,
        and finally to a one-element list built from the scalar field.
        """
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.comments = comments
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers
        self.token_merging_ratio = p.token_merging_ratio
        self.token_merging_ratio_hr = p.token_merging_ratio_hr

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.s_min_uncond = p.s_min_uncond
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override

        # Scalar views of possibly list-valued inputs: keep the first element only.
        self.prompt = self.prompt[0] if isinstance(self.prompt, list) else self.prompt
        self.negative_prompt = self.negative_prompt[0] if isinstance(self.negative_prompt, list) else self.negative_prompt
        self.seed = int(self.seed[0] if isinstance(self.seed, list) else self.seed) if self.seed is not None else -1
        self.subseed = int(self.subseed[0] if isinstance(self.subseed, list) else self.subseed) if self.subseed is not None else -1
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Return the main result fields serialized as a JSON string."""
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Build the infotext for the image at flat position `index` (split into iteration and position in batch)."""
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)

    def get_token_merging_ratio(self, for_hr=False):
        """Return the stored token merging ratio (the hires one when `for_hr` is true); no fallback to `opts`."""
        return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio

463

464 465 466 467
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate between `low` and `high`; val=0 gives `low`, val=1 gives `high`.

    Norms and dot products are taken along dim 1. When the two inputs are almost
    parallel (mean dot product above 0.9995) a plain linear interpolation is used.
    """
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        # Linear fallback with the same orientation as the spherical branch below
        # (the weights used to be swapped, returning `high` for val=0).
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
477

478

479
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Create one seeded noise tensor per entry of `seeds` and return them stacked on `shared.device`.

    - When `subseeds` is given and `subseed_strength` is non-zero, each noise is slerp-mixed with
      noise generated from the matching subseed (subseed 0 is used if `subseeds` is too short).
    - When both seed_resize_from_h/w are positive, noise is generated at that size divided by 8
      and centered (cropped or padded) into noise of the requested `shape`.
    - When `p` has a sampler and either batch seeds are enabled for several seeds or
      `opts.eta_noise_seed_delta` is positive, extra per-seed noises are generated and stored
      in `p.sampler.sampler_noises`.

    The order of the random calls below is significant: it determines the generated values.
    """
    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and ((len(seeds) > 1 and opts.enable_batch_seeds) or eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None and subseed_strength != 0:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            # Center the resized noise inside fresh noise of the target shape.
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            # Size of the overlapping region (shrinks when the resized noise is larger than the target).
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            # Top-left corner in the target (tx, ty) and in the resized noise (dx, dy).
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if eta_noise_seed_delta > 0:
                devices.manual_seed(seed + eta_noise_seed_delta)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x


A
AUTOMATIC1111 已提交
542 543 544 545
class DecodedSamples(list):
    """List of decoded samples, tagged with `already_decoded` so it can be told apart from a plain list."""
    # NOTE(review): presumably checked by callers to avoid decoding the same batch twice — confirm against usage.
    already_decoded = True


546
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    """Decode `batch` one latent at a time and return the results as a `DecodedSamples` list.

    With `check_for_nans`, each decoded sample is tested for NaNs. If NaNs are found, the VAE
    is not already float32 and `shared.opts.auto_vae_precision` is on, the VAE and the batch
    are switched to float32 and the sample is decoded again; otherwise the exception propagates.
    Each sample is moved to `target_device` when one is given.
    """
    samples = DecodedSamples()

    for i in range(batch.shape[0]):
        sample = decode_first_stage(model, batch[i:i + 1])[0]

        if check_for_nans:
            try:
                devices.test_for_nans(sample, "vae")
            except devices.NansException:
                if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
                    raise

                errors.print_error_explanation(
                    "A tensor with all NaNs was produced in VAE.\n"
                    "Web UI will now convert VAE into 32-bit float and retry.\n"
                    "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."
                )

                # The switch is global: later samples and later calls keep using float32.
                devices.dtype_vae = torch.float32
                model.first_stage_model.to(devices.dtype_vae)
                batch = batch.to(devices.dtype_vae)

                sample = decode_first_stage(model, batch[i:i + 1])[0]

        if target_device is not None:
            sample = sample.to(target_device)

        samples.append(sample)

    return samples


580 581 582 583 584 585 586
def get_fixed_seed(seed):
    """Return `seed` unchanged, or a fresh random integer when it is None, '' or -1."""
    wants_random = seed is None or seed == '' or seed == -1
    if not wants_random:
        return seed

    return int(random.randrange(4294967294))


587
def fix_seed(p):
    """Replace unset seed/subseed values on `p` (None, '' or -1) with concrete random seeds, in place."""
    p.seed = get_fixed_seed(p.seed)
    p.subseed = get_fixed_seed(p.subseed)
A
AUTOMATIC 已提交
590 591


592 593 594 595 596 597 598 599 600 601
def program_version():
    """Return the git tag reported by `launch.git_tag()`, or None when it is the "<none>" placeholder."""
    # Imported here rather than at module level, as in the original code.
    import launch

    tag = launch.git_tag()
    return None if tag == "<none>" else tag


602 603 604 605 606 607
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
    """Build the generation-parameters text (infotext) for one image.

    The image is selected by `index`, or by `position_in_batch + iteration * p.batch_size` when
    `index` is None. With `use_main_prompt`, the prompt, seed and variation seed come from `p`
    (`p.prompt`, `p.all_seeds[0]`, `p.all_subseeds[0]`) instead of the per-image lists.
    `all_negative_prompts` defaults to `p.all_negative_prompts`. `comments` is accepted but not used here.

    Returns "<prompt>[\\nNegative prompt: ...]\\n<key: value, ...>" with None-valued parameters omitted.
    """
    if index is None:
        index = position_in_batch + iteration * p.batch_size

    if all_negative_prompts is None:
        all_negative_prompts = p.all_negative_prompts

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    enable_hr = getattr(p, 'enable_hr', False)
    token_merging_ratio = p.get_token_merging_ratio()
    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

    # ENSD is only reported when it is non-zero and the sampler actually uses it.
    uses_ensd = opts.eta_noise_seed_delta != 0
    if uses_ensd:
        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

    # Insertion order is the output order; entries whose value is None are dropped below.
    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
        **p.extra_generation_params,
        "Version": program_version() if opts.add_version_to_infotext else None,
        "User": p.user if opts.add_user_name_to_info else None,
    }

    # An entry whose value equals its key is written as the bare key, without ": value".
    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    prompt_text = p.prompt if use_main_prompt else all_prompts[index]
    negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""

    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
651 652


653
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run a processing job with its per-job setting overrides applied.

    Calls the scripts' ``before_process`` hook, applies every entry of
    ``p.override_settings`` onto ``opts`` (reloading the checkpoint / VAE when
    those particular settings are overridden), enables token merging and then
    delegates to ``process_images_inner``. Token merging is always switched
    off again afterwards, and the overridden options are put back when
    ``p.override_settings_restore_afterwards`` is set.
    """
    if p.scripts is not None:
        p.scripts.before_process(p)

    # remember the current value of every option that is about to be overridden
    previous_values = {name: opts.data[name] for name in p.override_settings}

    try:
        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
        requested_checkpoint = p.override_settings.get('sd_model_checkpoint')
        if sd_models.checkpoint_aliases.get(requested_checkpoint) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        for name, value in p.override_settings.items():
            setattr(opts, name, value)

            if name == 'sd_model_checkpoint':
                sd_models.reload_model_weights()
            elif name == 'sd_vae':
                sd_vae.reload_vae_weights()

        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        return process_images_inner(p)

    finally:
        sd_models.apply_token_merging(p.sd_model, 0)

        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for name, value in previous_values.items():
                setattr(opts, name, value)

                if name == 'sd_vae':
                    sd_vae.reload_vae_weights()


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """Main generation loop shared by txt2img and img2img.

    Prepares seeds and prompts, calls ``p.init`` once, then for each of the
    ``p.n_iter`` batches: slices the per-batch prompts/seeds, activates extra
    networks, runs the script hooks, samples via ``p.sample``, decodes the
    result to images and post-processes / saves every image. Optionally
    builds a grid from all produced images.

    Returns a ``Processed`` holding the output images and their infotexts.
    When no image was produced (for example the job was interrupted before
    the first batch) the returned object carries an empty info string instead
    of raising ``IndexError``.
    """

    if isinstance(p.prompt, list):
        assert len(p.prompt) > 0
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    # used as an insertion-ordered set of comment strings collected from model_hijack
    comments = {}

    p.setup_prompts()

    if isinstance(seed, list):
        p.all_seeds = seed
    else:
        # with a variation seed in play the main seed stays fixed and only the subseed advances per image
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if isinstance(subseed, list):
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    def infotext(index=0, use_main_prompt=False):
        # reads the per-batch prompts/seeds from p at call time, so it always describes the current batch;
        # defined once here (rather than inside the loop) because it is also used for the grid after the loop
        return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            batch_slice = slice(n * p.batch_size, (n + 1) * p.batch_size)
            p.prompts = p.all_prompts[batch_slice]
            p.negative_prompts = p.all_negative_prompts[batch_slice]
            p.seeds = p.all_seeds[batch_slice]
            p.subseeds = p.all_subseeds[batch_slice]

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            if len(p.prompts) == 0:
                break

            p.parse_extra_network_prompts()

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0:
                # build the text before opening the file so a failure here does not leave a truncated params.txt
                params_text = Processed(p, [], p.seed, "").infotext(p, 0)
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    file.write(params_text)

            p.setup_conds()

            for comment in model_hijack.comments:
                comments[comment] = 1

            p.extra_generation_params.update(model_hijack.extra_generation_params)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            # p.sample may hand back already decoded images (flagged via an attribute) instead of latents
            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            x_samples_ddim = torch.stack(x_samples_ddim).float()
            # map from [-1, 1] to [0, 1]
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                # re-slice the prompts after the callback, before they are used for infotexts and filenames
                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                # channels-first float in [0, 1] -> channels-last uint8
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    # gated on the same `samples_save` option as the final-image save further down
                    if opts.samples_save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    # gated on the same `samples_save` option as the final-image save further down
                    if opts.samples_save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                    image_mask = p.mask_for_overlay.convert('RGB')
                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')

                    if opts.save_mask:
                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")

                    if opts.save_mask_composite:
                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")

                    if opts.return_mask:
                        output_images.append(image_mask)

                    if opts.return_mask_composite:
                        output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                # the grid goes first, so the first real image moves to index 1
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        # no infotext exists when no image was produced (e.g. interrupted before the first batch)
        info=infotexts[0] if infotexts else "",
        comments="".join(f"{comment}\n" for comment in comments),
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
925 926


927 928 929 930 931 932 933 934 935 936 937 938
def old_hires_fix_first_pass_dimensions(width, height):
    """Legacy heuristic for the first-pass size of the hires fix.

    Scales the requested size so that its area is about 512x512 pixels while
    keeping the aspect ratio, then rounds each side up to a multiple of 64.
    Returns the ``(width, height)`` pair.
    """

    target_area = 512 * 512
    scale = math.sqrt(target_area / (width * height))

    def snap_up(side):
        # round the scaled side up to the next multiple of 64
        return math.ceil(scale * side / 64) * 64

    return snap_up(width), snap_up(height)


939 940
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """txt2img processing job with an optional second "hires" pass.

    When ``enable_hr`` is set, ``sample`` first generates at ``width`` x
    ``height``, upscales the result (in latent space or via an image
    upscaler) to ``hr_upscale_to_x`` x ``hr_upscale_to_y`` and then runs an
    img2img pass over it, optionally with a different checkpoint, sampler and
    prompts.
    """

    sampler = None
    # class-level conditioning caches for the hires pass; every instance aliases these lists in
    # __init__, and close() replaces them unless opts.experimental_persistent_cond_cache is set
    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
        """Store the hires-fix parameters; everything else is forwarded to the base class.

        If ``firstphase_width`` or ``firstphase_height`` is non-zero, the
        width/height set by the base class become the upscale target and the
        first-phase values become the size of the first pass.
        """
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.hr_scale = hr_scale
        self.hr_upscaler = hr_upscaler
        self.hr_second_pass_steps = hr_second_pass_steps
        self.hr_resize_x = hr_resize_x
        self.hr_resize_y = hr_resize_y
        self.hr_upscale_to_x = hr_resize_x
        self.hr_upscale_to_y = hr_resize_y
        self.hr_checkpoint_name = hr_checkpoint_name
        self.hr_checkpoint_info = None
        self.hr_sampler_name = hr_sampler_name
        self.hr_prompt = hr_prompt
        self.hr_negative_prompt = hr_negative_prompt
        self.all_hr_prompts = None
        self.all_hr_negative_prompts = None
        self.latent_scale_mode = None

        if firstphase_width != 0 or firstphase_height != 0:
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height
            self.width = firstphase_width
            self.height = firstphase_height

        # amount cropped from the upscaled latent when both hr_resize_x and hr_resize_y are given; set in init()
        self.truncate_x = 0
        self.truncate_y = 0
        self.applied_old_hires_behavior_to = None

        self.hr_prompts = None
        self.hr_negative_prompts = None
        self.hr_extra_network_data = None

        self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
        self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
        self.hr_c = None
        self.hr_uc = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Resolve hires-fix settings before sampling; does nothing when ``enable_hr`` is off.

        Looks up the hires checkpoint and upscaler, records the hires options
        in ``extra_generation_params``, computes the upscale target size and
        the truncation, and doubles the job count for the second pass.

        Raises:
            Exception: if the hires checkpoint or the upscaler cannot be found.
        """
        if not self.enable_hr:
            return

        if self.hr_checkpoint_name:
            self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

            if self.hr_checkpoint_info is None:
                raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

            self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

        if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
            self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

        if tuple(self.hr_prompt) != tuple(self.prompt):
            self.extra_generation_params["Hires prompt"] = self.hr_prompt

        if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
            self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt

        # a None latent_scale_mode means the upscaler is not a latent mode, so it has to be a regular image upscaler
        self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
        if self.latent_scale_mode is None:
            if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                raise Exception(f"could not find upscaler named {self.hr_upscaler}")

        if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
            # old behaviour: the requested size is the final size and the first pass size is derived from it
            self.hr_resize_x = self.width
            self.hr_resize_y = self.height
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height

            self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
            self.applied_old_hires_behavior_to = (self.width, self.height)

        if self.hr_resize_x == 0 and self.hr_resize_y == 0:
            self.extra_generation_params["Hires upscale"] = self.hr_scale
            self.hr_upscale_to_x = int(self.width * self.hr_scale)
            self.hr_upscale_to_y = int(self.height * self.hr_scale)
        else:
            self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

            if self.hr_resize_y == 0:
                # only the width was given: derive the height from the aspect ratio
                self.hr_upscale_to_x = self.hr_resize_x
                self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
            elif self.hr_resize_x == 0:
                # only the height was given: derive the width from the aspect ratio
                self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                self.hr_upscale_to_y = self.hr_resize_y
            else:
                target_w = self.hr_resize_x
                target_h = self.hr_resize_y
                src_ratio = self.width / self.height
                dst_ratio = self.hr_resize_x / self.hr_resize_y

                # upscale so the image covers the target, then truncate the excess
                if src_ratio < dst_ratio:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                else:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y

                self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f

        if not state.processing_has_refined_job_count:
            if state.job_count == -1:
                state.job_count = self.n_iter

            shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
            # each iteration is two jobs: the first pass and the hires pass
            state.job_count = state.job_count * 2
            state.processing_has_refined_job_count = True

        if self.hr_second_pass_steps:
            self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

        if self.hr_upscaler is not None:
            self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run the first pass and, when ``enable_hr`` is set, the hires pass.

        Returns the first-pass samples when hires fix is disabled, otherwise
        the result of ``sample_hr_pass``. The checkpoint that was loaded when
        the hires pass started is reloaded afterwards, even if the pass fails.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x

        if not self.enable_hr:
            return samples

        # decoded images are only needed when upscaling with an image upscaler rather than in latent space
        if self.latent_scale_mode is None:
            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
        else:
            decoded_samples = None

        current = shared.sd_model.sd_checkpoint_info
        try:
            if self.hr_checkpoint_info is not None:
                self.sampler = None
                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
                devices.torch_gc()

            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
        finally:
            self.sampler = None
            sd_models.reload_model_weights(info=current)
            devices.torch_gc()

    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
        """Upscale the first-pass result and run the img2img pass over it.

        ``samples`` are the first-pass latents; ``decoded_samples`` are the
        decoded first-pass images and are only used when ``latent_scale_mode``
        is None. Returns the decoded output of ``decode_latent_batch``.
        """
        self.is_hr_pass = True

        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

            # gated on `samples_save`, the option that process_images_inner uses for saving final images
            if not opts.samples_save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

        img2img_sampler_name = self.hr_sampler_name or self.sampler_name

        # PLMS/UniPC do not support img2img so we just silently switch to DDIM;
        # the check is on the sampler chosen for this pass, so an explicit hires sampler is honoured
        if img2img_sampler_name in ['PLMS', 'UniPC']:
            img2img_sampler_name = 'DDIM'

        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

        if self.latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            # map from [-1, 1] to [0, 1]
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                # channels-first float -> channels-last uint8 image
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                # back to channels-first float in [0, 1]
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
            samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        # crop the excess computed in init() evenly from both sides of the last two dimensions
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)

        # GC now before running the next img2img to prevent running out of memory
        devices.torch_gc()

        if not self.disable_extra_networks:
            with devices.autocast():
                extra_networks.activate(self, self.hr_extra_network_data)

        with devices.autocast():
            self.calculate_hr_conds()

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

        if self.scripts is not None:
            self.scripts.before_hr(self)

        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        # switch back to the first-pass token merging ratio
        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        self.is_hr_pass = False

        return decoded_samples

    def close(self):
        """Drop the hires conditionings and, unless the persistent cache option is set, reset the class-level caches."""
        super().close()
        self.hr_c = None
        self.hr_uc = None
        if not opts.experimental_persistent_cond_cache:
            StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
            StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]

    def setup_prompts(self):
        """Build ``all_hr_prompts`` / ``all_hr_negative_prompts`` in addition to the base prompts.

        Empty hires prompts fall back to the first-pass prompts. A single
        prompt is repeated ``batch_size * n_iter`` times; styles are applied
        to every entry.
        """
        super().setup_prompts()

        if not self.enable_hr:
            return

        if self.hr_prompt == '':
            self.hr_prompt = self.prompt

        if self.hr_negative_prompt == '':
            self.hr_negative_prompt = self.negative_prompt

        if isinstance(self.hr_prompt, list):
            self.all_hr_prompts = self.hr_prompt
        else:
            self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]

        if isinstance(self.hr_negative_prompt, list):
            self.all_hr_negative_prompts = self.hr_negative_prompt
        else:
            self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]

        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]

    def calculate_hr_conds(self):
        """Compute ``hr_c`` / ``hr_uc`` for the current batch unless ``hr_c`` is already set."""
        if self.hr_c is not None:
            return

        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)

        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)

    def setup_conds(self):
        """Set up first-pass conditionings and, where possible, precompute the hires ones.

        The hires conditionings are reset on every call. They are computed
        right away only when no separate hires checkpoint is used and either
        ``hires_fix_use_firstpass_conds`` is set or lowvram mode is enabled;
        otherwise ``sample_hr_pass`` computes them later.
        """
        super().setup_conds()

        self.hr_uc = None
        self.hr_c = None

        if self.enable_hr and self.hr_checkpoint_info is None:
            if shared.opts.hires_fix_use_firstpass_conds:
                self.calculate_hr_conds()

            elif lowvram.is_enabled(shared.sd_model):  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
                with devices.autocast():
                    extra_networks.activate(self, self.hr_extra_network_data)

                self.calculate_hr_conds()

                # switch back to the first-pass extra networks
                with devices.autocast():
                    extra_networks.activate(self, self.extra_network_data)

    def parse_extra_network_prompts(self):
        """Slice the hires prompts for the current iteration and parse extra networks out of the positive ones.

        Returns whatever the base implementation returns.
        """
        res = super().parse_extra_network_prompts()

        if self.enable_hr:
            self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]

            self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

        return res

1253 1254 1255 1256

class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Processing parameters and working state for an image-to-image generation run."""

    # Class-level default; init() assigns the sampler created by sd_samplers.create_sampler.
    sampler = None

1257
    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
        """Store the img2img-specific settings; remaining kwargs go to the base class.

        - A non-None ``mask_blur`` overrides both ``mask_blur_x`` and ``mask_blur_y``.
        - ``image_cfg_scale`` is kept only when the loaded model's
          ``cond_stage_key`` is ``"edit"``; otherwise None is stored.
        - A None ``initial_noise_multiplier`` falls back to
          ``opts.initial_noise_multiplier``.
        """
        super().__init__(**kwargs)

        # a single blur value, when given, applies to both axes
        if mask_blur is not None:
            mask_blur_x = mask_blur_y = mask_blur

        # --- settings supplied by the caller ---
        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength

        model_is_edit = shared.sd_model.cond_stage_key == "edit"
        self.image_cfg_scale: float = image_cfg_scale if model_is_edit else None

        self.image_mask = mask
        self.mask_blur_x = mask_blur_x
        self.mask_blur_y = mask_blur_y
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert

        if initial_noise_multiplier is None:
            initial_noise_multiplier = opts.initial_noise_multiplier
        self.initial_noise_multiplier = initial_noise_multiplier

        # --- working state, filled in later (see init()) ---
        self.init_latent = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask = None
        self.nmask = None
        self.image_conditioning = None
1281

A
AUTOMATIC 已提交
1282
    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare sampler, masks, initial latents and image conditioning.

        Steps, in order:
          1. create the sampler;
          2. process ``self.image_mask`` (see ``_prepare_image_mask``);
          3. turn every init image into a float32 channel-first array
             (see ``_prepare_init_image``) and stack them into a batch;
          4. encode the batch into ``self.init_latent``;
          5. build the latent-space masks (see ``_apply_latent_mask``);
          6. compute ``self.image_conditioning``.

        Args:
            all_prompts: not used by this method.
            all_seeds: seeds; only read when ``inpainting_fill == 2``.
            all_subseeds: not used by this method.

        Raises:
            RuntimeError: if no init images were supplied, or if more images
                than ``self.batch_size`` were supplied.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        image_mask, crop_region = self._prepare_image_mask()

        # an explicitly assigned latent mask takes precedence over the processed image mask
        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []

        imgs = []
        for img in self.init_images:
            imgs.append(self._prepare_init_image(img, image_mask, latent_mask, crop_region, add_color_corrections))

        if not imgs:
            # Without this guard the batch-size branch below would set self.batch_size to 0
            # and continue with an empty batch.
            raise RuntimeError("no init images were provided for img2img")

        if len(imgs) == 1:
            # a single init image is repeated to fill the whole batch
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            # several init images: the batch size shrinks to the number of images
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = image.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
        devices.torch_gc()

        if self.resize_mode == 3:
            # resize mode 3 skips the pixel-space resize and rescales the latent instead
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        if image_mask is not None:
            self._apply_latent_mask(latent_mask, all_seeds)

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)

    def _prepare_image_mask(self):
        """Process ``self.image_mask`` and return ``(image_mask, crop_region)``.

        Both values are None when no mask is set. ``crop_region`` is only
        non-None in ``inpaint_full_res`` mode. When a mask is set, this also
        assigns ``self.mask_for_overlay``, resets ``self.overlay_images`` to an
        empty list and, in ``inpaint_full_res`` mode, sets ``self.paste_to`` to
        ``(x, y, width, height)`` of the crop region.
        """
        image_mask = self.image_mask
        if image_mask is None:
            return None, None

        image_mask = image_mask.convert('L')

        if self.inpainting_mask_invert:
            image_mask = ImageOps.invert(image_mask)

        # Blur each axis separately so the two radii can differ.
        # The kernel radius is 4 * sigma, rounded; OpenCV's ksize is (width, height).
        if self.mask_blur_x > 0:
            np_mask = np.array(image_mask)
            kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
            image_mask = Image.fromarray(np_mask)

        if self.mask_blur_y > 0:
            np_mask = np.array(image_mask)
            kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
            image_mask = Image.fromarray(np_mask)

        crop_region = None

        if self.inpaint_full_res:
            # only the region around the mask is processed, at the full target resolution
            self.mask_for_overlay = image_mask
            mask = image_mask.convert('L')
            crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
            crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
            x1, y1, x2, y2 = crop_region

            mask = mask.crop(crop_region)
            image_mask = images.resize_image(2, mask, self.width, self.height)
            self.paste_to = (x1, y1, x2-x1, y2-y1)
        else:
            image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
            # the overlay mask is the resized mask with intensities doubled and clipped to 255
            np_mask = np.array(image_mask)
            np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
            self.mask_for_overlay = Image.fromarray(np_mask)

        self.overlay_images = []

        return image_mask, crop_region

    def _prepare_init_image(self, img, image_mask, latent_mask, crop_region, add_color_corrections):
        """Convert one init image into a float32 channel-first array scaled to [0, 1].

        Side effects: optionally saves the init image, appends an overlay image
        to ``self.overlay_images`` when a mask is in use, and appends a colour
        correction target to ``self.color_corrections`` when requested.
        """
        # Save init image
        if opts.save_init_img:
            self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
            images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)

        image = images.flatten(img, opts.img2img_background_color)

        if crop_region is None and self.resize_mode != 3:
            image = images.resize_image(self.resize_mode, image, self.width, self.height)

        if image_mask is not None:
            # keep the part of the image outside the (inverted) overlay mask for later compositing
            image_masked = Image.new('RGBa', (image.width, image.height))
            image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

            self.overlay_images.append(image_masked.convert('RGBA'))

        # crop_region is not None if we are doing inpaint full res
        if crop_region is not None:
            image = image.crop(crop_region)
            image = images.resize_image(2, image, self.width, self.height)

        if image_mask is not None and self.inpainting_fill != 1:
            image = masking.fill(image, latent_mask)

        if add_color_corrections:
            self.color_corrections.append(setup_color_correction(image))

        image = np.array(image).astype(np.float32) / 255.0
        return np.moveaxis(image, 2, 0)

    def _apply_latent_mask(self, latent_mask, all_seeds):
        """Build ``self.mask`` / ``self.nmask`` at latent resolution and pre-fill the init latent.

        ``self.nmask`` is the rounded (0 or 1) latent-resolution mask and
        ``self.mask`` is its complement. With ``inpainting_fill == 2`` the
        ``nmask`` area of ``self.init_latent`` is replaced by random tensors;
        with ``inpainting_fill == 3`` it is zeroed.
        """
        latmask = latent_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
        latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
        latmask = latmask[0]
        latmask = np.around(latmask)
        # replicate the mask over every latent channel; the channel count is taken from the latent itself
        latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))

        self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
        self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

        # this needs to be fixed to be done in sample() using actual seeds for batches
        if self.inpainting_fill == 2:
            self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
        elif self.inpainting_fill == 3:
            self.init_latent = self.init_latent * self.mask
1409

1410
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run the img2img sampler on ``self.init_latent`` and return the sampled latents.

        Noise is generated from ``seeds`` / ``subseeds`` and scaled by
        ``self.initial_noise_multiplier`` when that differs from 1.0 (the value
        is then also recorded in ``self.extra_generation_params``). If a latent
        mask is present, the result is blended with the init latent:
        ``self.nmask`` weights the new samples and ``self.mask`` the init latent.

        Note: the ``subseed_strength`` and ``prompts`` arguments are not read
        here; ``self.subseed_strength`` is used for noise generation.
        """
        noise_shape = [opt_C, self.height // opt_f, self.width // opt_f]
        noise = create_random_tensors(noise_shape, seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        multiplier = self.initial_noise_multiplier
        if multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = multiplier
            noise *= multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, noise, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        # drop the noise tensor before asking torch to collect garbage
        del noise
        devices.torch_gc()

        return samples
1426 1427 1428

    def get_token_merging_ratio(self, for_hr=False):
        """Return the token merging ratio for img2img.

        The first truthy value wins, in this order:
          1. ``self.token_merging_ratio``;
          2. ``opts.token_merging_ratio``, only if ``"token_merging_ratio"`` is in
             ``self.override_settings``;
          3. ``opts.token_merging_ratio_img2img``;
          4. ``opts.token_merging_ratio`` (returned as-is, even if falsy).

        ``for_hr`` is accepted but not used by this implementation.
        """
        own_ratio = self.token_merging_ratio
        if own_ratio:
            return own_ratio

        if "token_merging_ratio" in self.override_settings:
            overridden_ratio = opts.token_merging_ratio
            if overridden_ratio:
                return overridden_ratio

        return opts.token_merging_ratio_img2img or opts.token_merging_ratio