processing.py 61.5 KB
Newer Older
1
import json
2
import logging
3 4 5
import math
import os
import sys
6
import hashlib
7 8 9

import torch
import numpy as np
A
linter  
AUTOMATIC 已提交
10
from PIL import Image, ImageOps
11
import random
12 13
import cv2
from skimage import exposure
A
AUTOMATIC 已提交
14
from typing import Any, Dict, List
15

16
import modules.sd_hijack
A
AUTOMATIC 已提交
17
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
18 19 20
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
21
import modules.paths as paths
A
AUTOMATIC 已提交
22
import modules.face_restoration
23
import modules.images as images
A
AUTOMATIC 已提交
24
import modules.styles
25 26
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
J
Jay Smith 已提交
27 28
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
29

J
Jay Smith 已提交
30
from einops import repeat, rearrange
31
from blendmodes.blend import blendLayers, BlendType
32 33


34 35 36 37 38
# Some of these values must never be changed because doing so would break the model,
# so they were removed from the user-facing options and kept as constants here.
# NOTE(review): presumably opt_C is the latent channel count and opt_f the latent
# downsampling factor — confirm against the code that reads them.
opt_C = 4
opt_f = 8


39
def setup_color_correction(image):
    """Build the colour-correction reference consumed by apply_color_correction.

    *image* (presumably a PIL RGB image — anything np.asarray accepts) is copied,
    converted to an array and transformed into the LAB colour space; the LAB
    array is returned.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    lab_reference = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)
    return lab_reference


45
def apply_color_correction(correction, original_image):
    """Colour-match *original_image* against a LAB reference from setup_color_correction.

    The image is converted to LAB, its histograms are matched channel-wise to
    *correction*, and the result is converted back to an 8-bit RGB PIL image.
    The original image is then blended on top with BlendType.LUMINOSITY
    (presumably keeping the original's brightness while taking colour from the
    matched image) and the blended image is returned.
    """
    logging.info("Applying color correction.")

    source_lab = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(source_lab, correction, channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")
    corrected = Image.fromarray(matched_rgb)

    return blendLayers(corrected, original_image, BlendType.LUMINOSITY)

A
AUTOMATIC 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76

def apply_overlay(image, paste_loc, index, overlays):
    """Alpha-composite ``overlays[index]`` over *image* and return an RGB image.

    If *overlays* is None or has no entry at *index*, *image* is returned
    unchanged. When *paste_loc* ``(x, y, w, h)`` is given, the image is first
    resized to w x h via images.resize_image (mode 1) and pasted at (x, y) onto a
    new RGBA canvas the size of the overlay.
    """
    has_overlay = overlays is not None and index < len(overlays)
    if not has_overlay:
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        left, top, box_w, box_h = paste_loc
        canvas = Image.new('RGBA', (overlay.width, overlay.height))
        canvas.paste(images.resize_image(1, image, box_w, box_h), (left, top))
        image = canvas

    composited = image.convert('RGBA')
    composited.alpha_composite(overlay)
    return composited.convert('RGB')
79

F
frostydad 已提交
80

81
def txt2img_image_conditioning(sd_model, x, width, height):
    """Build the image-conditioning tensor for a txt2img batch of latents *x*.

    The result depends on ``sd_model.model.conditioning_key``:

    * 'hybrid' / 'concat' (inpainting models): the first-stage encoding of an
      all-zero (batch, 3, height, width) image with a channel of ones prepended,
      cast to x.dtype.
    * 'crossattn-adm' (unCLIP models): zeros of shape
      (batch, 2 * noise_augmentor.time_embed.dim).
    * anything else: a (batch, 5, 1, 1) zero dummy, which avoids an encoder call.
    """
    conditioning_key = sd_model.model.conditioning_key
    batch_size = x.shape[0]

    if conditioning_key in {'hybrid', 'concat'}:
        # The "masked image" is all zeros because the entire image counts as masked.
        blank = torch.zeros(batch_size, 3, height, width, device=x.device)
        latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(blank))
        # Prepend the fake all-ones mask along the channel dimension.
        with_mask = torch.nn.functional.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
        return with_mask.to(x.dtype)

    if conditioning_key == "crossattn-adm":
        adm_width = 2 * sd_model.noise_augmentor.time_embed.dim
        return x.new_zeros(batch_size, adm_width, dtype=x.dtype, device=x.device)

    # Presumably only the batch size of this dummy is used, so a 1x1 tensor suffices.
    return x.new_zeros(batch_size, 5, 1, 1, dtype=x.dtype, device=x.device)
103 104


105
class StableDiffusionProcessing:
    """
    Base class holding the settings and per-run state of a generation job; subclasses implement sample().

    The first set of parameters: sd_model -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
    """

    # Conditioning caches shared by every instance. Each is a two-element list
    # [cached_params, cached_result] used by get_conds_with_caching; close()
    # resets them unless opts.experimental_persistent_cond_cache is enabled.
    cached_uc = [None, None]
    cached_c = [None, None]

    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
        # sd_model is accepted but not stored: the sd_model property always returns shared.sd_model.
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        # Falsy sampler parameters (including an explicit 0) fall back to the global options.
        self.s_min_uncond = s_min_uncond or opts.s_min_uncond
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise
        # Restricted options can never be overridden per job.
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
        self.override_settings_restore_afterwards = override_settings_restore_afterwards
        self.is_using_inpainting_conditioning = False
        self.disable_extra_networks = False
        self.token_merging_ratio = 0
        self.token_merging_ratio_hr = 0

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = script_args
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None
        self.iteration = 0
        self.is_hr_pass = False
        self.sampler = None

        self.prompts = None
        self.negative_prompts = None
        self.extra_network_data = None
        self.seeds = None
        self.subseeds = None

        self.step_multiplier = 1
        # Instances share the class-level cache lists (see the class attributes).
        self.cached_uc = StableDiffusionProcessing.cached_uc
        self.cached_c = StableDiffusionProcessing.cached_c
        self.uc = None
        self.c = None

        self.user = None

    @property
    def sd_model(self):
        """The currently loaded model; always shared.sd_model."""
        return shared.sd_model

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Return txt2img image conditioning for latents *x*, defaulting to this job's width/height.

        Also records on the instance whether the loaded model uses 'hybrid'/'concat'
        (inpainting) conditioning.
        """
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        # Refers to the module-level function of the same name.
        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        """Return a depth map for *source_image* at latent resolution, normalised to [-1, 1].

        Only the first image of the batch is run through the MiDaS transform; the
        resulting input is repeated batch_size times before the depth model runs.
        """
        # Use the AddMiDaS helper to format our source image to suit the MiDaS model.
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
        # Resize the predicted depth to the latent's spatial size.
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        # Min-max normalise to [-1, 1].
        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        """Return the mode of the first-stage encoding of *source_image* (used for "edit" models)."""
        conditioning_image = self.sd_model.encode_first_stage(source_image).mode()

        return conditioning_image

    def unclip_image_conditioning(self, source_image):
        """Return unCLIP ADM conditioning for *source_image*.

        The embedder output is used directly, or, when the model has a noise augmentor,
        augmented at noise level 0 and concatenated with the noise-level embedding.
        """
        c_adm = self.sd_model.embedder(source_image)
        if self.sd_model.noise_augmentor is not None:
            noise_level = 0  # TODO: Allow other noise levels?
            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Return the inpainting-model conditioning: [mask, encoded masked image] along dim 1.

        *image_mask* may be a tensor (used as-is), a PIL image (converted to
        grayscale, scaled to [0, 1] and rounded to 0/1), or None (everything masked).
        The masked image is blended with the original by inpainting_mask_weight
        (instance attribute if set, otherwise the global option).
        """
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Dispatch to the image-conditioning builder that matches the loaded model.

        Depth2Image, "edit", inpainting ('hybrid'/'concat') and unCLIP ('crossattn-adm')
        models get their specific conditioning; any other model gets a
        (batch, 5, 1, 1) zero dummy.
        """
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Hook for subclasses to prepare before sampling; does nothing here."""
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Produce samples for one batch; must be implemented by subclasses."""
        raise NotImplementedError()

    def close(self):
        """Drop the sampler and conditionings; clear the shared caches unless they are configured to persist."""
        self.sampler = None
        self.c = None
        self.uc = None
        if not opts.experimental_persistent_cond_cache:
            StableDiffusionProcessing.cached_c = [None, None]
            StableDiffusionProcessing.cached_uc = [None, None]

    def get_token_merging_ratio(self, for_hr=False):
        """Return the effective token merging ratio.

        For the hires pass the first non-zero value among the hires instance value,
        the hires option, the instance value and the option is used; otherwise the
        instance value, falling back to the option.
        """
        if for_hr:
            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

        return self.token_merging_ratio or opts.token_merging_ratio

    def setup_prompts(self):
        """Fill all_prompts / all_negative_prompts and apply the selected styles.

        A list prompt is used as given; a single prompt is repeated
        batch_size * n_iter times. The same applies to the negative prompt.
        """
        if isinstance(self.prompt, list):
            self.all_prompts = self.prompt
        else:
            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

        if isinstance(self.negative_prompt, list):
            self.all_negative_prompts = self.negative_prompt
        else:
            self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]

        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps),
        reusing a previously computed result when the same arguments were used before.

        caches is a list of caches. Each cache is a two-element list: the first element is
        the tuple of parameters the stored result was computed with (or None if nothing has
        been stored yet), the second element is the stored result. Every cache is checked
        for a hit; on a miss the result is computed under autocast and stored in caches[0].

        Besides the prompts and steps, the cache key includes the CLIP skip option, the
        loaded checkpoint, extra_network_data, the SDXL crop options and the image size.
        """

        cached_params = (
            required_prompts,
            steps,
            opts.CLIP_stop_at_last_layers,
            shared.sd_model.sd_checkpoint_info,
            extra_network_data,
            opts.sdxl_crop_left,
            opts.sdxl_crop_top,
            self.width,
            self.height,
        )

        for cache in caches:
            if cache[0] is not None and cached_params == cache[0]:
                return cache[1]

        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps)

        cache[0] = cached_params
        return cache[1]

    def setup_conds(self):
        """Compute (or fetch from cache) the conditional and unconditional conditionings.

        Second-order samplers get twice the number of scheduled steps.
        """
        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

    def parse_extra_network_prompts(self):
        """Strip extra-network syntax from self.prompts, storing the parsed data in extra_network_data."""
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
369

370 371

class Processed:
    """Outcome of a generation run: the produced images plus the settings that produced them.

    Most fields are snapshots of the StableDiffusionProcessing object *p* (and of the
    global opts/state) taken at construction time; js() serializes them to JSON.
    """

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.comments = comments
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers
        self.token_merging_ratio = p.token_merging_ratio
        self.token_merging_ratio_hr = p.token_merging_ratio_hr

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.s_min_uncond = p.s_min_uncond
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override

        # Prompts and seeds may arrive as per-image lists; the scalar fields keep the first entry.
        self.prompt = self.prompt[0] if isinstance(self.prompt, list) else self.prompt
        self.negative_prompt = self.negative_prompt[0] if isinstance(self.negative_prompt, list) else self.negative_prompt
        self.seed = int(self.seed[0] if isinstance(self.seed, list) else self.seed) if self.seed is not None else -1
        self.subseed = int(self.subseed[0] if isinstance(self.subseed, list) else self.subseed) if self.subseed is not None else -1
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        # Explicit arguments win, then the lists recorded on p, then a single-element fallback.
        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Return the public result fields as a JSON string."""
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Regenerate the infotext for image *index* of this run (split into batch position and iteration)."""
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)

    def get_token_merging_ratio(self, for_hr=False):
        """Return the recorded token merging ratio (hires one when *for_hr*)."""
        return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio

462

463 464 465 466
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate from *low* (val=0) to *high* (val=1).

    The angle between the inputs is measured after normalising along dim 1. When
    they are almost parallel (mean cosine > 0.9995) the spherical formula divides by
    a near-zero sine, so plain linear interpolation is used instead — with the same
    orientation as the spherical branch (val=0 -> low, val=1 -> high).
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)

    if dot.mean() > 0.9995:
        # The upstream snippet had these weights swapped, which inverted the
        # interpolation direction relative to the spherical branch below.
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
476

477

478
def _center_paste(target, source, target_shape, source_shape):
    """Copy *source* centred into *target* in place (cropping *source* if larger) and return *target*.

    Shapes are (channels, height, width) tuples; offsets come from the tuples, not the tensors.
    """
    off_x = (target_shape[2] - source_shape[2]) // 2
    off_y = (target_shape[1] - source_shape[1]) // 2
    span_w = source_shape[2] + 2 * off_x if off_x < 0 else source_shape[2]
    span_h = source_shape[1] + 2 * off_y if off_y < 0 else source_shape[1]
    dst_x, dst_y = max(off_x, 0), max(off_y, 0)
    src_x, src_y = max(-off_x, 0), max(-off_y, 0)
    target[:, dst_y:dst_y + span_h, dst_x:dst_x + span_w] = source[:, src_y:src_y + span_h, src_x:src_x + span_w]
    return target


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Generate one seeded noise tensor per seed and return them stacked on shared.device.

    *shape* is the per-image (channels, height, width) shape. With *subseeds*, each
    noise is slerp-blended towards its subseed's noise by *subseed_strength*
    (subseed 0 when the list runs short). When both seed_resize_from_h and
    seed_resize_from_w are positive, noise is drawn at (h // 8, w // 8) and pasted,
    centred, into full-size noise drawn from the same seed.

    If *p* has a sampler and either several seeds are used with opts.enable_batch_seeds
    or an eta noise seed delta is set, the extra noises the sampler needs are
    pre-generated per seed and stored on p.sampler.sampler_noises. Using those
    pre-generated tensors instead of simple torch.randn allows a batch with seeds
    [100, 101] to produce the same images as two batches [100], [101].
    """
    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0

    wants_sampler_noises = p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0)
    sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] if wants_sampler_noises else None

    resize_requested = seed_resize_from_h > 0 and seed_resize_from_w > 0
    noise_shape = (shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8) if resize_requested else shape

    per_seed_noise = []
    for position, seed in enumerate(seeds):
        subnoise = None
        if subseeds is not None:
            subseed = subseeds[position] if position < len(subseeds) else 0
            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed.
        # Doing it on CPU would make results portable, but changing it now would break
        # everyone's existing seeds, so the original behaviour is kept.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            noise = _center_paste(devices.randn(seed, shape), noise, shape, noise_shape)

        if sampler_noises is not None:
            needed = p.sampler.number_of_needed_noises(p)

            # Order matters: these unseeded draws continue from the global RNG state.
            if eta_noise_seed_delta > 0:
                torch.manual_seed(seed + eta_noise_seed_delta)

            for slot in range(needed):
                sampler_noises[slot].append(devices.randn_without_seed(tuple(noise_shape)))

        per_seed_noise.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    return torch.stack(per_seed_noise).to(shared.device)


A
AUTOMATIC 已提交
541
def decode_first_stage(model, x):
    """Decode latents *x* with *model*'s first stage after casting them to devices.dtype_vae."""
    latents = x.to(devices.dtype_vae)
    return model.decode_first_stage(latents)


547 548 549 550 551 552 553
def get_fixed_seed(seed):
    """Return *seed* unchanged, or a freshly drawn random seed if a random one is requested.

    None, the empty string and -1 mean "random"; the replacement comes from
    random.randrange(4294967294).
    """
    if not (seed is None or seed == '' or seed == -1):
        return seed

    return int(random.randrange(4294967294))


554
def fix_seed(p):
    """Replace random-request values of p.seed and then p.subseed with concrete seeds, in place."""
    for attribute in ('seed', 'subseed'):
        setattr(p, attribute, get_fixed_seed(getattr(p, attribute)))
A
AUTOMATIC 已提交
557 558


559 560 561 562 563 564 565 566 567 568
def program_version():
    """Return launch.git_tag(), or None when it reports "<none>"."""
    import launch

    tag = launch.git_tag()
    return None if tag == "<none>" else tag


569
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, all_negative_prompts=None):
    """Build the generation-parameters text for one image.

    Args:
        p: the processing object the image was generated with.
        all_prompts, all_seeds, all_subseeds: per-image lists, indexed by
            ``position_in_batch + iteration * p.batch_size``.
        comments: accepted for backward compatibility; not used.
        iteration, position_in_batch: locate the image within the run.
        use_main_prompt: show ``p.prompt`` instead of the per-image prompt.
        all_negative_prompts: per-image negative prompts, indexed like
            all_prompts. Defaults to ``p.all_negative_prompts`` so existing
            callers behave as before.

    Returns:
        ``"<prompt>[\\nNegative prompt: ...]\\n<key: value, ...>"`` stripped of
        surrounding whitespace; parameters whose value is None are omitted.
    """
    index = position_in_batch + iteration * p.batch_size

    # Take negative prompts from the same source as the prompts when the caller
    # provides them; otherwise keep using the list recorded on p.
    if all_negative_prompts is None:
        all_negative_prompts = p.all_negative_prompts

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    enable_hr = getattr(p, 'enable_hr', False)
    token_merging_ratio = p.get_token_merging_ratio()
    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

    # ENSD is only reported when it is non-zero and the sampler actually uses it.
    uses_ensd = opts.eta_noise_seed_delta != 0
    if uses_ensd:
        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
        **p.extra_generation_params,
        "Version": program_version() if opts.add_version_to_infotext else None,
        "User": p.user if opts.add_user_name_to_info else None,
    }

    # None values are dropped; an entry whose key equals its value is written as a bare flag.
    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    prompt_text = p.prompt if use_main_prompt else all_prompts[index]
    negative_prompt = all_negative_prompts[index]
    negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""

    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
614 615


616
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run a generation job with ``p.override_settings`` temporarily applied to ``opts``.

    Steps:
    1. Applies every override to the global ``opts``, reloading model weights when
       ``sd_model_checkpoint`` is overridden and VAE weights when ``sd_vae`` is.
    2. Enables token merging for the job and runs ``process_images_inner``.
    3. In a ``finally`` block (so it also runs when generation raises), disables
       token merging. If ``p.override_settings_restore_afterwards`` is set, it also
       restores the overridden options.

    Args:
        p: the processing job to run.

    Returns:
        The ``Processed`` result produced by ``process_images_inner``.
    """
    if p.scripts is not None:
        p.scripts.before_process(p)

    # Snapshot the current values of the options being overridden so they can be
    # restored afterwards. Keys that opts.data does not know are skipped: indexing
    # them directly raised KeyError here, before the try/finally below was entered.
    stored_opts = {k: opts.data[k] for k in p.override_settings if k in opts.data}

    try:
        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        for k, v in p.override_settings.items():
            setattr(opts, k, v)

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        res = process_images_inner(p)

    finally:
        sd_models.apply_token_merging(p.sd_model, 0)

        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                # Only the VAE is reloaded eagerly. Restoring sd_model_checkpoint
                # does not reload weights here; the next call without a checkpoint
                # override reaches reload_model_weights() at the top of the try block.
                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch

    Overall flow:
    - Resolves seeds/subseeds and sets up prompts.
    - Calls ``p.init`` once inside the no-grad / EMA / autocast scopes.
    - Runs ``p.n_iter`` batches. Each batch:
      - slices the prompts and seeds for that batch;
      - activates extra networks and builds conditionings;
      - calls ``p.sample``;
      - decodes the latents with the VAE;
      - post-processes each image (face restoration, script hooks, color
        correction, overlay), optionally saves it, and collects it.
    - Optionally builds a grid, prepending it to the results and/or saving it.

    Returns:
        A ``Processed`` with the output images (the grid first when it is
        returned), their infotexts, and the first seed/subseed.
    """

    if isinstance(p.prompt, list):
        assert len(p.prompt) > 0
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    # Dict used as an insertion-ordered set of unique hijack comments.
    comments = {}

    p.setup_prompts()

    if isinstance(seed, list):
        p.all_seeds = seed
    else:
        # Consecutive seeds per image, unless a variation strength is set, in
        # which case every image shares the same base seed.
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if isinstance(subseed, list):
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            if len(p.prompts) == 0:
                break

            p.parse_extra_network_prompts()

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [], p.seed, "")
                    file.write(processed.infotext(p, 0))

            p.setup_conds()

            for comment in model_hijack.comments:
                comments[comment] = 1

            p.extra_generation_params.update(model_hijack.extra_generation_params)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            # Decode the latents one image at a time (in the VAE dtype), then check each decoded image for NaNs.
            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
            for x in x_samples_ddim:
                devices.test_for_nans(x, "vae")

            x_samples_ddim = torch.stack(x_samples_ddim).float()
            # Apply (x + 1) / 2 and clamp to [0, 1]; presumably the decoder output is nominally in [-1, 1].
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                # Move the leading axis last (presumably CHW -> HWC) and scale to uint8 for PIL.
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any((opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite)):
                    image_mask = p.mask_for_overlay.convert('RGB')
                    # Result image composited over a blank canvas, using the overlay mask (resized to the image) as alpha.
                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')

                    if opts.save_mask:
                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")

                    if opts.save_mask_composite:
                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")

                    if opts.return_mask:
                        output_images.append(image_mask)

                    if opts.return_mask_composite:
                        output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        info=infotext(),
        comments="".join(f"{comment}\n" for comment in comments),
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
877 878


879 880 881 882 883 884 885 886 887 888 889 890
def old_hires_fix_first_pass_dimensions(width, height, target_pixel_count=512 * 512, multiple=64):
    """old algorithm for auto-calculating first pass size

    Scales ``(width, height)`` uniformly so that their product is approximately
    ``target_pixel_count``, then rounds each side *up* to a multiple of ``multiple``.

    Args:
        width: requested final width in pixels.
        height: requested final height in pixels.
        target_pixel_count: desired pixel count of the first pass (default 512*512).
        multiple: each returned side is rounded up to a multiple of this (default 64).

    Returns:
        ``(width, height)`` as ints.

    Raises:
        ZeroDivisionError: if ``width * height`` is 0.
    """
    scale = math.sqrt(target_pixel_count / (width * height))
    width = math.ceil(scale * width / multiple) * multiple
    height = math.ceil(scale * height / multiple) * multiple

    return width, height


891 892
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """txt2img processing with an optional second "hires fix" pass.

    When ``enable_hr`` is set, the first-pass samples are enlarged in one of two ways:
    - interpolating the latents directly, or
    - decoding them, resizing with a named upscaler, and re-encoding.
    The result is then refined by an img2img sampling pass at the target size.
    """
    sampler = None

    # Class-level conditioning caches shared by all instances; close() resets
    # them unless opts.experimental_persistent_cond_cache is enabled.
    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
        """Store the hires-fix configuration on top of the base processing options.

        Args:
            enable_hr: run the second (hires) pass.
            denoising_strength: stored on the instance (presumably consumed by the
                img2img sampler during the second pass).
            firstphase_width, firstphase_height: legacy options. When either is
                non-zero, the base width/height become the upscale target and these
                become the first-pass size.
            hr_scale: upscale factor used when hr_resize_x and hr_resize_y are both 0.
            hr_upscaler: a latent upscale mode name or an upscaler name. None uses
                the default latent mode.
            hr_second_pass_steps: step count for the second pass; 0 reuses ``steps``.
            hr_resize_x, hr_resize_y: explicit target size. A side left at 0 is
                derived from the aspect ratio.
            hr_sampler_name: sampler for the second pass; None reuses ``sampler_name``.
            hr_prompt, hr_negative_prompt: prompts for the second pass; '' reuses
                the main prompts.
            **kwargs: forwarded to StableDiffusionProcessing.
        """
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.hr_scale = hr_scale
        self.hr_upscaler = hr_upscaler
        self.hr_second_pass_steps = hr_second_pass_steps
        self.hr_resize_x = hr_resize_x
        self.hr_resize_y = hr_resize_y
        self.hr_upscale_to_x = hr_resize_x
        self.hr_upscale_to_y = hr_resize_y
        self.hr_sampler_name = hr_sampler_name
        self.hr_prompt = hr_prompt
        self.hr_negative_prompt = hr_negative_prompt
        self.all_hr_prompts = None
        self.all_hr_negative_prompts = None

        if firstphase_width != 0 or firstphase_height != 0:
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height
            self.width = firstphase_width
            self.height = firstphase_height

        # Latent-space crop amounts, computed in init() and applied in sample().
        self.truncate_x = 0
        self.truncate_y = 0
        # First-pass size the old hires behaviour was last applied to; prevents re-applying it.
        self.applied_old_hires_behavior_to = None

        self.hr_prompts = None
        self.hr_negative_prompts = None
        self.hr_extra_network_data = None

        self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
        self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
        self.hr_c = None
        self.hr_uc = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Finalize the hires configuration before sampling.

        When hires is enabled, this method:
        - records the hires infotext parameters;
        - computes the upscale target (hr_upscale_to_x/y) and the latent crop;
        - disables hires entirely if the target equals the first-pass size;
        - doubles the job count for progress reporting.
        """
        if self.enable_hr:
            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

            if tuple(self.hr_prompt) != tuple(self.prompt):
                self.extra_generation_params["Hires prompt"] = self.hr_prompt

            if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt

            if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                self.hr_resize_x = self.width
                self.hr_resize_y = self.height
                self.hr_upscale_to_x = self.width
                self.hr_upscale_to_y = self.height

                self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
                self.applied_old_hires_behavior_to = (self.width, self.height)

            if self.hr_resize_x == 0 and self.hr_resize_y == 0:
                self.extra_generation_params["Hires upscale"] = self.hr_scale
                self.hr_upscale_to_x = int(self.width * self.hr_scale)
                self.hr_upscale_to_y = int(self.height * self.hr_scale)
            else:
                self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

                if self.hr_resize_y == 0:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                elif self.hr_resize_x == 0:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y
                else:
                    target_w = self.hr_resize_x
                    target_h = self.hr_resize_y
                    src_ratio = self.width / self.height
                    dst_ratio = self.hr_resize_x / self.hr_resize_y

                    # Upscale so the image covers the target box, then crop the excess.
                    if src_ratio < dst_ratio:
                        self.hr_upscale_to_x = self.hr_resize_x
                        self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                    else:
                        self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                        self.hr_upscale_to_y = self.hr_resize_y

                    # Excess beyond the requested size, in latent units (opt_f pixels each).
                    self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                    self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f

            # special case: the user has chosen to do nothing
            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
                self.enable_hr = False
                self.denoising_strength = None
                self.extra_generation_params.pop("Hires upscale", None)
                self.extra_generation_params.pop("Hires resize", None)
                return

            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter

                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True

            if self.hr_second_pass_steps:
                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run the first txt2img pass and, if enabled, the hires img2img pass.

        Returns:
            The latent samples from the last pass that was run.

        Raises:
            Exception: if hires is enabled and ``hr_upscaler`` is neither a latent
                mode nor the name of a known upscaler.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        if self.hr_upscaler is not None:
            latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None)
        else:
            # The fallback must be a mode dict, not a bare string, because it is
            # indexed with ["mode"] / ["antialias"] below.
            latent_scale_mode = shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, {"mode": "nearest", "antialias": False})

        if self.enable_hr and latent_scale_mode is None:
            if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                raise Exception(f"could not find upscaler named {self.hr_upscaler}")

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))

        if not self.enable_hr:
            return samples

        self.is_hr_pass = True

        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

        if latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            # Undo the [0, 1] scaling (2x - 1) before re-encoding into latents.
            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        img2img_sampler_name = self.hr_sampler_name or self.sampler_name

        # PLMS/UniPC do not support img2img so we just silently switch to DDIM.
        # Test the sampler that will actually run this second pass (not the
        # first-pass sampler), so that an explicit hires sampler is honoured and
        # an unsupported one is redirected.
        if img2img_sampler_name in ['PLMS', 'UniPC']:
            img2img_sampler_name = 'DDIM'

        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

        # Crop the upscaled latents to the requested size, split roughly evenly between both edges.
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        if not self.disable_extra_networks:
            with devices.autocast():
                extra_networks.activate(self, self.hr_extra_network_data)

        with devices.autocast():
            self.calculate_hr_conds()

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

        if self.scripts is not None:
            self.scripts.before_hr(self)

        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

        self.is_hr_pass = False

        return samples

    def close(self):
        """Release per-job hires conds; also reset the shared caches unless persistence is enabled."""
        super().close()
        self.hr_c = None
        self.hr_uc = None
        if not opts.experimental_persistent_cond_cache:
            StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
            StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]

    def setup_prompts(self):
        """Build per-image hires prompts (defaulting to the main prompts) with styles applied."""
        super().setup_prompts()

        if not self.enable_hr:
            return

        if self.hr_prompt == '':
            self.hr_prompt = self.prompt

        if self.hr_negative_prompt == '':
            self.hr_negative_prompt = self.negative_prompt

        if type(self.hr_prompt) == list:
            self.all_hr_prompts = self.hr_prompt
        else:
            self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]

        if type(self.hr_negative_prompt) == list:
            self.all_hr_negative_prompts = self.hr_negative_prompt
        else:
            self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]

        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]

    def calculate_hr_conds(self):
        """Compute hires conds once per batch; a no-op if hr_c is already set."""
        if self.hr_c is not None:
            return

        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)

    def setup_conds(self):
        """Set up first-pass conds; compute hires conds eagerly when required.

        Hires conds are computed here (rather than lazily in sample()) in two cases:
        - hires_fix_use_firstpass_conds is enabled;
        - lowvram mode is on, because the cond model is unloaded later.
        """
        super().setup_conds()

        self.hr_uc = None
        self.hr_c = None

        if self.enable_hr:
            if shared.opts.hires_fix_use_firstpass_conds:
                self.calculate_hr_conds()

            elif lowvram.is_enabled(shared.sd_model):  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
                with devices.autocast():
                    extra_networks.activate(self, self.hr_extra_network_data)

                self.calculate_hr_conds()

                with devices.autocast():
                    extra_networks.activate(self, self.extra_network_data)

    def parse_extra_network_prompts(self):
        """Slice this batch's hires prompts and strip their extra-network tags."""
        res = super().parse_extra_network_prompts()

        if self.enable_hr:
            self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]

            self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

        return res

1178 1179 1180 1181

class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

1182
    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
        """Store the img2img / inpainting settings. The heavy lifting happens in ``init()``.

        Args:
            init_images: source images for the run.
            resize_mode: passed to ``images.resize_image``. Mode 3 instead resizes the
                encoded latent in ``init()``.
            denoising_strength: stored as-is.
            image_cfg_scale: kept only when ``shared.sd_model.cond_stage_key == "edit"``,
                otherwise ``None``.
            mask: optional inpainting mask image (stored as ``self.image_mask``).
            mask_blur: when not ``None``, overrides both ``mask_blur_x`` and ``mask_blur_y``.
            mask_blur_x: horizontal Gaussian blur sigma for the mask (0 disables).
            mask_blur_y: vertical Gaussian blur sigma for the mask (0 disables).
            inpainting_fill: 1 skips ``masking.fill``, 2 replaces the masked latent with
                noise, 3 zeroes the masked latent (see ``init()``).
            inpaint_full_res: process only a crop around the mask.
            inpaint_full_res_padding: padding passed to ``masking.get_crop_region``.
            inpainting_mask_invert: invert the mask when truthy.
            initial_noise_multiplier: noise scale used by ``sample()``. ``None`` means
                ``opts.initial_noise_multiplier``.
            **kwargs: forwarded to the base ``__init__``.
        """
        super().__init__(**kwargs)

        # Source images and how they are fitted to the target size.
        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength

        is_edit_model = shared.sd_model.cond_stage_key == "edit"
        self.image_cfg_scale: float = image_cfg_scale if is_edit_model else None

        # init_latent and mask_for_overlay are filled in by init(); latent_mask is an
        # optional override for the mask that init() uses in latent space.
        self.init_latent = None
        self.image_mask = mask
        self.latent_mask = None
        self.mask_for_overlay = None

        # A single mask_blur value, when given, wins over the per-axis values.
        self.mask_blur_x = mask_blur_x if mask_blur is None else mask_blur
        self.mask_blur_y = mask_blur_y if mask_blur is None else mask_blur

        # Inpainting behavior.
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert

        self.initial_noise_multiplier = initial_noise_multiplier if initial_noise_multiplier is not None else opts.initial_noise_multiplier

        # Latent-space masks and conditioning, produced by init().
        self.mask = None
        self.nmask = None
        self.image_conditioning = None
1206

A
AUTOMATIC 已提交
1207
    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare the sampler, masks and initial latent for an img2img / inpainting run.

        Steps, in order:

        1. Create ``self.sampler`` from ``self.sampler_name``.
        2. If ``self.image_mask`` is set: convert it to grayscale, invert it if
           ``inpainting_mask_invert``, and apply a separable Gaussian blur
           (``mask_blur_x`` horizontally, ``mask_blur_y`` vertically). Then either:

           - (``inpaint_full_res``) crop the mask to the region returned by
             ``masking.get_crop_region`` / ``masking.expand_crop_region``, resize it to
             width x height, and record the region as ``self.paste_to``
             (x, y, width, height); or
           - resize the mask to width x height and keep a doubled-intensity copy as
             ``self.mask_for_overlay``.

        3. For each init image: optionally save it under its MD5 hash, flatten it onto
           ``opts.img2img_background_color``, resize / crop it, store an overlay of its
           unmasked part (mask only), fill the masked area unless ``inpainting_fill == 1``,
           optionally record a color-correction target, and convert it to a
           channels-first float32 array scaled by 1/255.
        4. Build the batch. A single image is repeated ``batch_size`` times (together with
           its overlay and color correction). Several images shrink ``batch_size`` to
           their count.
        5. Map pixels to ``2 * x - 1`` and encode them with
           ``sd_model.encode_first_stage`` / ``get_first_stage_encoding`` into
           ``self.init_latent``. When ``resize_mode == 3``, the latent is bilinearly
           resized to the target latent size.
        6. With a mask, build the rounded latent masks ``self.mask`` (1 - mask) and
           ``self.nmask`` (mask). For ``inpainting_fill`` 2 / 3, replace the masked latent
           area with seeded noise / zeros.
        7. Compute ``self.image_conditioning``.

        Args:
            all_prompts: not used here.
            all_seeds: the first ``init_latent.shape[0]`` seeds drive the
                ``inpainting_fill == 2`` noise.
            all_subseeds: not used here.

        Raises:
            RuntimeError: when more than one init image is given and there are more of
                them than ``self.batch_size``.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        def blur_along_axis(mask_image, sigma, horizontal):
            # One-dimensional Gaussian blur. The kernel spans about +/-4 sigma and has an odd size.
            size = 2 * int(4 * sigma + 0.5) + 1
            ksize = (size, 1) if horizontal else (1, size)
            return Image.fromarray(cv2.GaussianBlur(np.array(mask_image), ksize, sigma))

        crop_region = None
        mask_img = self.image_mask

        if mask_img is not None:
            mask_img = mask_img.convert('L')

            if self.inpainting_mask_invert:
                mask_img = ImageOps.invert(mask_img)

            if self.mask_blur_x > 0:
                mask_img = blur_along_axis(mask_img, self.mask_blur_x, horizontal=True)
            if self.mask_blur_y > 0:
                mask_img = blur_along_axis(mask_img, self.mask_blur_y, horizontal=False)

            if self.inpaint_full_res:
                # Work only on a crop around the mask. self.paste_to records where that
                # crop sits in the full image (its consumer is not visible here).
                self.mask_for_overlay = mask_img
                gray = mask_img.convert('L')
                region = masking.get_crop_region(np.array(gray), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(region, self.width, self.height, gray.width, gray.height)
                left, top, right, bottom = crop_region
                mask_img = images.resize_image(2, gray.crop(crop_region), self.width, self.height)
                self.paste_to = (left, top, right - left, bottom - top)
            else:
                mask_img = images.resize_image(self.resize_mode, mask_img, self.width, self.height)
                brightened = np.clip(np.array(mask_img).astype(np.float32) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(brightened)

            self.overlay_images = []

        latent_mask = mask_img if self.latent_mask is None else self.latent_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []

        def to_model_input(source):
            # Save init image
            if opts.save_init_img:
                self.init_img_hash = hashlib.md5(source.tobytes()).hexdigest()
                images.save_image(source, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)

            picture = images.flatten(source, opts.img2img_background_color)

            if crop_region is None and self.resize_mode != 3:
                picture = images.resize_image(self.resize_mode, picture, self.width, self.height)

            if mask_img is not None:
                # Unmasked part of the image on a transparent background.
                overlay = Image.new('RGBa', (picture.width, picture.height))
                overlay.paste(picture.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
                self.overlay_images.append(overlay.convert('RGBA'))

            # crop_region is only set when doing inpaint full res
            if crop_region is not None:
                picture = images.resize_image(2, picture.crop(crop_region), self.width, self.height)

            if mask_img is not None and self.inpainting_fill != 1:
                picture = masking.fill(picture, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(picture))

            # Channels-last image -> channels-first float32, scaled by 1/255.
            pixels = np.array(picture).astype(np.float32) / 255.0
            return np.moveaxis(pixels, 2, 0)

        imgs = []
        for source in self.init_images:
            imgs.append(to_model_input(source))

        if len(imgs) == 1:
            batch_images = np.repeat(imgs[0][None], self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        pixels = torch.from_numpy(batch_images)
        pixels = (2. * pixels - 1.).to(shared.device, dtype=devices.dtype_vae)

        encoded = self.sd_model.encode_first_stage(pixels)
        self.init_latent = self.sd_model.get_first_stage_encoding(encoded)

        if self.resize_mode == 3:
            target_size = (self.height // opt_f, self.width // opt_f)
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=target_size, mode="bilinear")

        if mask_img is not None:
            latent_h, latent_w = self.init_latent.shape[2], self.init_latent.shape[3]
            resized = latent_mask.convert('RGB').resize((latent_w, latent_h))
            first_channel = np.array(resized, dtype=np.float32)[:, :, 0] / 255
            latmask = np.tile(np.around(first_channel)[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                batch_seeds = all_seeds[0:self.init_latent.shape[0]]
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], batch_seeds) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(pixels, self.init_latent, mask_img)
1331

1332
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run the img2img sampler starting from ``self.init_latent``.

        Seeded noise of shape ``[opt_C, height // opt_f, width // opt_f]`` is generated,
        scaled by ``self.initial_noise_multiplier`` when that differs from 1.0 (the value is
        then also recorded as "Noise multiplier"), and handed to
        ``self.sampler.sample_img2img``.

        When latent masks exist, the result is composited as
        ``samples * self.nmask + self.init_latent * self.mask``. This keeps the original
        latent wherever ``self.mask`` is 1.

        Note: ``self.subseed_strength`` is used for the noise; the ``subseed_strength`` and
        ``prompts`` arguments are not read here.

        Returns:
            The (possibly composited) samples from the sampler.
        """
        latent_shape = [opt_C, self.height // opt_f, self.width // opt_f]
        noise = create_random_tensors(
            latent_shape,
            seeds=seeds,
            subseeds=subseeds,
            subseed_strength=self.subseed_strength,
            seed_resize_from_h=self.seed_resize_from_h,
            seed_resize_from_w=self.seed_resize_from_w,
            p=self,
        )

        multiplier = self.initial_noise_multiplier
        if multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = multiplier
            noise *= multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, noise, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        # Drop the noise tensor before asking the device to free cached memory.
        del noise
        devices.torch_gc()

        return samples
1348 1349 1350

    def get_token_merging_ratio(self, for_hr=False):
        """Return the token merging ratio for this img2img job.

        Precedence, taking the first truthy value:

        1. the per-job ``self.token_merging_ratio``;
        2. ``opts.token_merging_ratio``, if ``"token_merging_ratio"`` appears in
           ``self.override_settings``;
        3. ``opts.token_merging_ratio_img2img``.

        If none of these is truthy, ``opts.token_merging_ratio`` is returned as-is.
        ``for_hr`` is accepted for interface compatibility and ignored.
        """
        if self.token_merging_ratio:
            return self.token_merging_ratio

        if "token_merging_ratio" in self.override_settings and opts.token_merging_ratio:
            return opts.token_merging_ratio

        if opts.token_merging_ratio_img2img:
            return opts.token_merging_ratio_img2img

        return opts.token_merging_ratio