from __future__ import annotations
import json
import logging
import math
import os
import sys
import hashlib
from dataclasses import dataclass, field

import torch
import numpy as np
from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
from typing import Any

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
import modules.face_restoration
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion

from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType


# some of those options should not be changed at all because they would break the model, so they were removed from the options
opt_C = 4
opt_f = 8


def setup_color_correction(image):
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, original_image):
    logging.info("Applying color correction.")
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(original_image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image.convert('RGB')
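
# Illustrative round trip of the two helpers above (comment only, not executed;
# `initial_image` and `result_image` stand for hypothetical PIL RGB images):
#
#   target = setup_color_correction(initial_image)         # LAB histogram target
#   fixed = apply_color_correction(target, result_image)   # match result to target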

def uncrop(image, dest_size, paste_loc):
    x, y, w, h = paste_loc
    base_image = Image.new('RGBA', dest_size)
    image = images.resize_image(1, image, w, h)
    base_image.paste(image, (x, y))
    image = base_image

    return image


def apply_overlay(image, paste_loc, overlay):
    if overlay is None:
        return image, image.copy()

    if paste_loc is not None:
        image = uncrop(image, (overlay.width, overlay.height), paste_loc)

    original_denoised_image = image.copy()

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image, original_denoised_image
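
# Sketch of the overlay flow above (comment only; names are hypothetical):
#
#   composited, denoised = apply_overlay(img, paste_loc, overlay_rgba)
#
# `composited` has the RGBA overlay alpha-blended on top; `denoised` is the
# uncropped image from before compositing, kept for the mask-composite output.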

def create_binary_mask(image, round=True):
    if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
        if round:
            image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
        else:
            image = image.split()[-1].convert("L")
    else:
        image = image.convert('L')
    return image
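
# For example, an RGBA input whose alpha is 200 at some pixel yields 255 there
# with round=True (threshold at alpha > 128), but keeps the soft value 200
# as a grayscale "L" mask with round=False.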

def txt2img_image_conditioning(sd_model, x, width, height):
    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models

        # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
        image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
        image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))

        # Add the fake full 1s mask to the first dimension.
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning

    elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models

        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

    else:
        sd = sd_model.model.state_dict()
        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
        if diffusion_model_input is not None:
            if diffusion_model_input.shape[1] == 9:
                # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
                image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
                image_conditioning = images_tensor_to_samples(image_conditioning,
                                                              approximation_indexes.get(opts.sd_vae_encode_method))

                # Add the fake full 1s mask to the first dimension.
                image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
                image_conditioning = image_conditioning.to(x.dtype)

                return image_conditioning

        # Dummy zero conditioning if we're not using inpainting or unclip models.
        # Still takes up a bit of memory, but no encoder call.
        # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
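
# Note: the 5-channel zero tensor above mirrors the inpainting conditioning
# layout (1 mask channel + 4 masked-image latent channels) at 1x1 spatial size,
# giving plain models a cheap placeholder with the right batch size.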


@dataclass(repr=False)
class StableDiffusionProcessing:
    sd_model: object = None
    outpath_samples: str = None
    outpath_grids: str = None
    prompt: str = ""
    prompt_for_display: str = None
    negative_prompt: str = ""
    styles: list[str] = None
    seed: int = -1
    subseed: int = -1
    subseed_strength: float = 0
    seed_resize_from_h: int = -1
    seed_resize_from_w: int = -1
    seed_enable_extras: bool = True
    sampler_name: str = None
    scheduler: str = None
    batch_size: int = 1
    n_iter: int = 1
    steps: int = 50
    cfg_scale: float = 7.0
    width: int = 512
    height: int = 512
    restore_faces: bool = None
    tiling: bool = None
    do_not_save_samples: bool = False
    do_not_save_grid: bool = False
    extra_generation_params: dict[str, Any] = None
    overlay_images: list = None
    eta: float = None
    do_not_reload_embeddings: bool = False
    denoising_strength: float = None
    ddim_discretize: str = None
    s_min_uncond: float = None
    s_churn: float = None
    s_tmax: float = None
    s_tmin: float = None
    s_noise: float = None
    override_settings: dict[str, Any] = None
    override_settings_restore_afterwards: bool = True
    sampler_index: int = None
    refiner_checkpoint: str = None
    refiner_switch_at: float = None
    token_merging_ratio = 0
    token_merging_ratio_hr = 0
    disable_extra_networks: bool = False
    firstpass_image: Image = None

    scripts_value: scripts.ScriptRunner = field(default=None, init=False)
    script_args_value: list = field(default=None, init=False)
    scripts_setup_complete: bool = field(default=False, init=False)

    cached_uc = [None, None]
    cached_c = [None, None]

    comments: dict = None
    sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
    is_using_inpainting_conditioning: bool = field(default=False, init=False)
    paste_to: tuple | None = field(default=None, init=False)

    is_hr_pass: bool = field(default=False, init=False)

    c: tuple = field(default=None, init=False)
    uc: tuple = field(default=None, init=False)

    rng: rng.ImageRNG | None = field(default=None, init=False)
    step_multiplier: int = field(default=1, init=False)
    color_corrections: list = field(default=None, init=False)

    all_prompts: list = field(default=None, init=False)
    all_negative_prompts: list = field(default=None, init=False)
    all_seeds: list = field(default=None, init=False)
    all_subseeds: list = field(default=None, init=False)
    iteration: int = field(default=0, init=False)
    main_prompt: str = field(default=None, init=False)
    main_negative_prompt: str = field(default=None, init=False)

    prompts: list = field(default=None, init=False)
    negative_prompts: list = field(default=None, init=False)
    seeds: list = field(default=None, init=False)
    subseeds: list = field(default=None, init=False)
    extra_network_data: dict = field(default=None, init=False)

    user: str = field(default=None, init=False)

    sd_model_name: str = field(default=None, init=False)
    sd_model_hash: str = field(default=None, init=False)
    sd_vae_name: str = field(default=None, init=False)
    sd_vae_hash: str = field(default=None, init=False)

    is_api: bool = field(default=False, init=False)

    def __post_init__(self):
        if self.sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.comments = {}

        if self.styles is None:
            self.styles = []

        self.sampler_noise_scheduler_override = None
        self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
        self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
        self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
        self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
        self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise

        self.extra_generation_params = self.extra_generation_params or {}
        self.override_settings = self.override_settings or {}
        self.script_args = self.script_args or {}

        self.refiner_checkpoint_info = None

        if not self.seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.cached_uc = StableDiffusionProcessing.cached_uc
        self.cached_c = StableDiffusionProcessing.cached_c

    @property
    def sd_model(self):
        return shared.sd_model

    @sd_model.setter
    def sd_model(self, value):
        pass

    @property
    def scripts(self):
        return self.scripts_value

    @scripts.setter
    def scripts(self, value):
        self.scripts_value = value

        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    @property
    def script_args(self):
        return self.script_args_value

    @script_args.setter
    def script_args(self, value):
        self.script_args_value = value

        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    def setup_scripts(self):
        self.scripts_setup_complete = True

        self.scripts.setup_scrips(self, is_ui=not self.is_api)  # note: setup_scrips (sic) is the method's actual name in modules.scripts

    def comment(self, text):
        self.comments[text] = 1

    def txt2img_image_conditioning(self, x, width=None, height=None):
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        # Use the AddMiDaS helper to format our source image to suit the MiDaS model
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()

        return conditioning_image

    def unclip_image_conditioning(self, source_image):
        c_adm = self.sd_model.embedder(source_image)
        if self.sd_model.noise_augmentor is not None:
            noise_level = 0 # TODO: Allow other noise levels?
            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                if round_image_mask:
                    # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
                    conditioning_mask = torch.round(conditioning_mask)

        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        sd = self.sampler.model_wrap.inner_model.model.state_dict()
        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
        if diffusion_model_input is not None:
            if diffusion_model_input.shape[1] == 9:
                return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
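
    # Note on the shape[1] == 9 check above: an SD inpainting UNet's first
    # convolution takes 4 noised-latent channels plus the 5-channel concat
    # conditioning (1 mask + 4 masked-image latents) = 9, so an inpainting
    # checkpoint can be detected from its weights alone.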

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        raise NotImplementedError()

    def close(self):
        self.sampler = None
        self.c = None
        self.uc = None
        if not opts.persistent_cond_cache:
            StableDiffusionProcessing.cached_c = [None, None]
            StableDiffusionProcessing.cached_uc = [None, None]

    def get_token_merging_ratio(self, for_hr=False):
        if for_hr:
            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

        return self.token_merging_ratio or opts.token_merging_ratio
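
    # Fallback illustration: with self.token_merging_ratio_hr == 0 and
    # opts.token_merging_ratio_hr == 0.5, get_token_merging_ratio(for_hr=True)
    # returns 0.5; if every hr value is 0 it falls through to the base ratios.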

    def setup_prompts(self):
        if isinstance(self.prompt, list):
            self.all_prompts = self.prompt
        elif isinstance(self.negative_prompt, list):
            self.all_prompts = [self.prompt] * len(self.negative_prompt)
        else:
            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

        if isinstance(self.negative_prompt, list):
            self.all_negative_prompts = self.negative_prompt
        else:
            self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)

        if len(self.all_prompts) != len(self.all_negative_prompts):
            raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")

        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

        self.main_prompt = self.all_prompts[0]
        self.main_negative_prompt = self.all_negative_prompts[0]
444
    def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
A
AUTOMATIC1111 已提交
445 446 447 448 449
        """Returns parameters that invalidate the cond cache if changed"""

        return (
            required_prompts,
            steps,
450 451
            hires_steps,
            use_old_scheduling,
A
AUTOMATIC1111 已提交
452 453 454 455 456 457 458
            opts.CLIP_stop_at_last_layers,
            shared.sd_model.sd_checkpoint_info,
            extra_network_data,
            opts.sdxl_crop_left,
            opts.sdxl_crop_top,
            self.width,
            self.height,
459 460
            opts.fp8_storage,
            opts.cache_fp16_weight,
461
            opts.emphasis,
A
AUTOMATIC1111 已提交
462 463
        )

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.

        cache is an array containing two elements. The first element is a tuple
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        computed result is stored.

        caches is a list with items described above.
        """

        if shared.opts.use_old_scheduling:
            old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
            new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
            if old_schedules != new_schedules:
                self.extra_generation_params["Old prompt editing timelines"] = True

        cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)

        for cache in caches:
            if cache[0] is not None and cached_params == cache[0]:
                return cache[1]

        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

        cache[0] = cached_params
        return cache[1]

    def setup_conds(self):
        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
        self.step_multiplier = total_steps // self.steps
        self.firstpass_steps = total_steps

        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)

    def get_conds(self):
        return self.c, self.uc

    def parse_extra_network_prompts(self):
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

    def save_samples(self) -> bool:
        """Returns whether generated images need to be written to disk"""
        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.comments = "".join(f"{comment}\n" for comment in p.comments)
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_name = p.sd_model_name
        self.sd_model_hash = p.sd_model_hash
        self.sd_vae_name = p.sd_vae_name
        self.sd_vae_hash = p.sd_vae_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers
        self.token_merging_ratio = p.token_merging_ratio
        self.token_merging_ratio_hr = p.token_merging_ratio_hr

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.s_min_uncond = p.s_min_uncond
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]
        self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]
        self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1
        self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]
        self.version = program_version()

    def js(self):
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_name": self.sd_model_name,
            "sd_model_hash": self.sd_model_hash,
            "sd_vae_name": self.sd_vae_name,
            "sd_vae_hash": self.sd_vae_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
            "version": self.version,
        }

        return json.dumps(obj, default=lambda o: None)

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)

    def get_token_merging_ratio(self, for_hr=False):
        return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
    return g.next()
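
# Illustrative call (comment only): latents for two seeds at 512x512, where
# opt_C=4 is the latent channel count and opt_f=8 the VAE downscaling factor:
#
#   noise = create_random_tensors((opt_C, 512 // opt_f, 512 // opt_f), seeds=[1, 2])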


class DecodedSamples(list):
    already_decoded = True


def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    samples = DecodedSamples()

    for i in range(batch.shape[0]):
        sample = decode_first_stage(model, batch[i:i + 1])[0]

        if check_for_nans:
            try:
                devices.test_for_nans(sample, "vae")
            except devices.NansException as e:
                if shared.opts.auto_vae_precision_bfloat16:
                    autofix_dtype = torch.bfloat16
                    autofix_dtype_text = "bfloat16"
                    autofix_dtype_setting = "Automatically convert VAE to bfloat16"
                    autofix_dtype_comment = ""
                elif shared.opts.auto_vae_precision:
                    autofix_dtype = torch.float32
                    autofix_dtype_text = "32-bit float"
                    autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
                    autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
                else:
                    raise e

                if devices.dtype_vae == autofix_dtype:
                    raise e

                errors.print_error_explanation(
                    "A tensor with all NaNs was produced in VAE.\n"
                    f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
                    f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
                )

                devices.dtype_vae = autofix_dtype
                model.first_stage_model.to(devices.dtype_vae)
                batch = batch.to(devices.dtype_vae)

                sample = decode_first_stage(model, batch[i:i + 1])[0]

        if target_device is not None:
            sample = sample.to(target_device)

        samples.append(sample)

    return samples
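
# DecodedSamples.already_decoded is what process_images_inner() later checks via
# getattr(samples_ddim, 'already_decoded', False), letting callers hand back
# batches that should skip the VAE decode entirely.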


def get_fixed_seed(seed):
    if seed == '' or seed is None:
        seed = -1
    elif isinstance(seed, str):
        try:
            seed = int(seed)
        except Exception:
            seed = -1

    if seed == -1:
        return int(random.randrange(4294967294))

    return seed
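
# Behavior sketch: get_fixed_seed('') and get_fixed_seed(-1) both return a fresh
# random seed; get_fixed_seed('42') parses to 42; an unparseable string also
# falls back to a random seed.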


def fix_seed(p):
    p.seed = get_fixed_seed(p.seed)
    p.subseed = get_fixed_seed(p.subseed)


def program_version():
    import launch

    res = launch.git_tag()
    if res == "<none>":
        res = None

    return res


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
    """
    this function is used to generate the infotext that is stored in the generated images; it contains the parameters that are required to generate the image
    Args:
        p: StableDiffusionProcessing
        all_prompts: list[str]
        all_seeds: list[int]
        all_subseeds: list[int]
        comments: list[str]
        iteration: int
        position_in_batch: int
        use_main_prompt: bool
        index: int
        all_negative_prompts: list[str]

    Returns: str

    Extra generation params
    p.extra_generation_params dictionary allows for additional parameters to be added to the infotext
    this can be used by the base webui or extensions.
    To add a new entry, add a new key value pair; the dictionary key will be used as the key of the parameter in the infotext
    the value generation_params can be defined as:
        - str | None
        - List[str|None]
        - callable func(**kwargs) -> str | None

    When defined as a string, it will be used as-is without extra processing; this is the most common use case.

    Defining as a list allows for a parameter that changes across images in the job, for example, the 'Seed' parameter.
    The list should have the same length as the total number of images in the entire job.

    Defining as a callable function allows for a parameter that cannot be generated earlier, or that requires extra logic.
    For example 'Hires prompt': the hr_prompt might be changed by processes in the pipeline or by extensions
    and may vary across different images, so defining it as a static string or list would not work.

    The function takes locals() as **kwargs, as such it will have access to variables like 'p' and 'index'.
    the base signature of the function should be:
        func(**kwargs) -> str | None
    optionally it can have additional arguments that will be used in the function:
        func(p, index, **kwargs) -> str | None
    note: for better future compatibility, even though this function will have access to all variables in locals(),
        it is recommended to only use the arguments present in the function signature of create_infotext.
    For actual implementation examples, see StableDiffusionProcessingTxt2Img.init > get_hr_prompt.
    """

    if use_main_prompt:
        index = 0
    elif index is None:
        index = position_in_batch + iteration * p.batch_size

    if all_negative_prompts is None:
        all_negative_prompts = p.all_negative_prompts

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    enable_hr = getattr(p, 'enable_hr', False)
    token_merging_ratio = p.get_token_merging_ratio()
    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
    negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]

    uses_ensd = opts.eta_noise_seed_delta != 0
    if uses_ensd:
        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "Schedule type": p.scheduler,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": opts.face_restoration_model if p.restore_faces else None,
        "Size": f"{p.width}x{p.height}",
        "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
        "Model": p.sd_model_name if opts.add_model_name_to_info else None,
        "FP8 weight": opts.fp8_storage if devices.fp8 else None,
        "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
        "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
        "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": p.extra_generation_params.get("Denoising strength"),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
        "Tiling": "True" if p.tiling else None,
        **p.extra_generation_params,
        "Version": program_version() if opts.add_version_to_infotext else None,
        "User": p.user if opts.add_user_name_to_info else None,
    }

    for key, value in generation_params.items():
        try:
            if isinstance(value, list):
                generation_params[key] = value[index]
            elif callable(value):
                generation_params[key] = value(**locals())
        except Exception:
            errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
            generation_params[key] = None

    generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""

    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
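
# Shape of the returned infotext (illustrative values only):
#
#   a photo of a cat
#   Negative prompt: blurry
#   Steps: 20, Sampler: Euler a, CFG scale: 7.0, Seed: 42, Size: 512x512, ...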


def process_images(p: StableDiffusionProcessing) -> Processed:
    if p.scripts is not None:
        p.scripts.before_process(p)

    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}

    try:
        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
        # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here; if a model override is present it will be reloaded afterwards
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        for k, v in p.override_settings.items():
            opts.set(k, v, is_api=True, run_callbacks=False)

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        res = process_images_inner(p)

    finally:
        sd_models.apply_token_merging(p.sd_model, 0)

        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if isinstance(p.prompt, list):
        assert len(p.prompt) > 0
    else:
        assert p.prompt is not None
    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    if p.restore_faces is None:
        p.restore_faces = opts.face_restoration

    if p.tiling is None:
        p.tiling = opts.tiling

    if p.refiner_checkpoint not in (None, "", "None", "none"):
        p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)  # note: get_closet_checkpoint_match (sic) is the function's actual name in sd_models
        if p.refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')

    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
    p.sd_model_hash = shared.sd_model.sd_model_hash
    p.sd_vae_name = sd_vae.get_loaded_vae_name()
    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    p.setup_prompts()

    if isinstance(seed, list):
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if isinstance(subseed, list):
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []
    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted or state.stopping_generation:
                break

            sd_models.reload_model_weights()  # model can be changed for example by refiner

            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            if len(p.prompts) == 0:
                break

            p.parse_extra_network_prompts()

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            p.setup_conds()

            p.extra_generation_params.update(model_hijack.extra_generation_params)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0 and not cmd_opts.no_prompt_history:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [])
                    file.write(processed.infotext(p, 0))

            for comment in model_hijack.comments:
                p.comment(comment)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            sd_models.apply_alpha_schedule_override(p.sd_model, p)

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if p.scripts is not None:
                ps = scripts.PostSampleArgs(samples_ddim)
                p.scripts.post_sample(p, ps)
                samples_ddim = ps.samples

            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            state.nextjob()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            def infotext(index=0, use_main_prompt=False):
                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

            save_samples = p.save_samples()

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                mask_for_overlay = getattr(p, "mask_for_overlay", None)

                if not shared.opts.overlay_inpaint:
                    overlay_image = None
                elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
                    overlay_image = p.overlay_images[i]
                else:
                    overlay_image = None

                if p.scripts is not None:
                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                    p.scripts.postprocess_maskoverlay(p, ppmo)
                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                # If the intention is to show the output from the model
                # that is being composited over the original image,
                # we need to keep the original image around
                # and use it in the composite step.
                image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image_after_composite(p, pp)
                    image = pp.image

                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

                if mask_for_overlay is not None:
                    if opts.return_mask or opts.save_mask:
                        image_mask = mask_for_overlay.convert('RGB')
                        if save_samples and opts.save_mask:
                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                        if opts.return_mask:
                            output_images.append(image_mask)

                    if opts.return_mask_composite or opts.save_mask_composite:
                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                        if save_samples and opts.save_mask_composite:
                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                        if opts.return_mask_composite:
                            output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

        if not infotexts:
            infotexts.append(Processed(p, []).infotext(p, 0))

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        info=infotexts[0],
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res


def old_hires_fix_first_pass_dimensions(width, height):
    """old algorithm for auto-calculating first pass size"""

    desired_pixel_count = 512 * 512
    actual_pixel_count = width * height
    scale = math.sqrt(desired_pixel_count / actual_pixel_count)
    width = math.ceil(scale * width / 64) * 64
    height = math.ceil(scale * height / 64) * 64

    return width, height
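
    # Worked example (hand-computed, not part of the original source): for a 1024x768
    # request, scale = sqrt(512*512 / (1024*768)) ≈ 0.577, so the first pass renders at
    # ceil(0.577*1024/64)*64 x ceil(0.577*768/64)*64 = 640x448.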


@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    enable_hr: bool = False
    denoising_strength: float = 0.75
    firstphase_width: int = 0
    firstphase_height: int = 0
    hr_scale: float = 2.0
    hr_upscaler: str = None
    hr_second_pass_steps: int = 0
    hr_resize_x: int = 0
    hr_resize_y: int = 0
    hr_checkpoint_name: str = None
    hr_sampler_name: str = None
    hr_scheduler: str = None
    hr_prompt: str = ''
    hr_negative_prompt: str = ''
    force_task_id: str = None

    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]

    hr_checkpoint_info: dict = field(default=None, init=False)
    hr_upscale_to_x: int = field(default=0, init=False)
    hr_upscale_to_y: int = field(default=0, init=False)
    truncate_x: int = field(default=0, init=False)
    truncate_y: int = field(default=0, init=False)
    applied_old_hires_behavior_to: tuple = field(default=None, init=False)
    latent_scale_mode: dict = field(default=None, init=False)
    hr_c: tuple | None = field(default=None, init=False)
    hr_uc: tuple | None = field(default=None, init=False)
    all_hr_prompts: list = field(default=None, init=False)
    all_hr_negative_prompts: list = field(default=None, init=False)
    hr_prompts: list = field(default=None, init=False)
    hr_negative_prompts: list = field(default=None, init=False)
    hr_extra_network_data: list = field(default=None, init=False)

    def __post_init__(self):
        super().__post_init__()

        if self.firstphase_width != 0 or self.firstphase_height != 0:
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height
            self.width = self.firstphase_width
            self.height = self.firstphase_height

        self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
        self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c

    def calculate_target_resolution(self):
        if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
            self.hr_resize_x = self.width
            self.hr_resize_y = self.height
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height

            self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
            self.applied_old_hires_behavior_to = (self.width, self.height)

        if self.hr_resize_x == 0 and self.hr_resize_y == 0:
            self.extra_generation_params["Hires upscale"] = self.hr_scale
            self.hr_upscale_to_x = int(self.width * self.hr_scale)
            self.hr_upscale_to_y = int(self.height * self.hr_scale)
        else:
            self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

            if self.hr_resize_y == 0:
                self.hr_upscale_to_x = self.hr_resize_x
                self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
            elif self.hr_resize_x == 0:
                self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                self.hr_upscale_to_y = self.hr_resize_y
            else:
                target_w = self.hr_resize_x
                target_h = self.hr_resize_y
                src_ratio = self.width / self.height
                dst_ratio = self.hr_resize_x / self.hr_resize_y

                if src_ratio < dst_ratio:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                else:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y

                self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
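
        # Worked example (hand-computed, not part of the original source): a 512x512 first
        # pass with hr_resize 768x512 gives src_ratio 1.0 < dst_ratio 1.5, so the image is
        # upscaled to 768x768 and truncate_y = (768 - 512) // opt_f = 32 latent rows are
        # cropped off later in sample_hr_pass to hit the exact requested size.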

    def init(self, all_prompts, all_seeds, all_subseeds):
        if self.enable_hr:
            self.extra_generation_params["Denoising strength"] = self.denoising_strength

            if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

                if self.hr_checkpoint_info is None:
                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

            def get_hr_prompt(p, index, prompt_text, **kwargs):
                hr_prompt = p.all_hr_prompts[index]
                return hr_prompt if hr_prompt != prompt_text else None

            def get_hr_negative_prompt(p, index, negative_prompt, **kwargs):
                hr_negative_prompt = p.all_hr_negative_prompts[index]
                return hr_negative_prompt if hr_negative_prompt != negative_prompt else None

            self.extra_generation_params["Hires prompt"] = get_hr_prompt
            self.extra_generation_params["Hires negative prompt"] = get_hr_negative_prompt

            self.extra_generation_params["Hires schedule type"] = None  # to be set in sd_samplers_kdiffusion.py

            if self.hr_scheduler is None:
                self.hr_scheduler = self.scheduler

            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
            if self.enable_hr and self.latent_scale_mode is None:
                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")

            self.calculate_target_resolution()

            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter
                if getattr(self, 'txt2img_upscale', False):
                    total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
                else:
                    total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
                shared.total_tqdm.updateTotal(total_steps)
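                # e.g. (hand-computed) steps=20, hr_second_pass_steps=0, n_iter=2: a full run
                # reports (20 + 20) * 2 = 80 steps, while txt2img_upscale skips the first
                # pass and reports 20 * 2 = 40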
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True

            if self.hr_second_pass_steps:
                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        if self.firstpass_image is not None and self.enable_hr:
            # no first pass needed here: we just take self.firstpass_image and prepare it for the hires fix

            if self.latent_scale_mode is None:
                image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
                image = np.moveaxis(image, 2, 0)

                samples = None
                decoded_samples = torch.asarray(np.expand_dims(image, 0))

            else:
                image = np.array(self.firstpass_image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                image = torch.from_numpy(np.expand_dims(image, axis=0))
                image = image.to(shared.device, dtype=devices.dtype_vae)

                if opts.sd_vae_encode_method != 'Full':
                    self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

                samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
                decoded_samples = None
                devices.torch_gc()

        else:
            # here we generate an image normally

            x = self.rng.next()
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            del x

            if not self.enable_hr:
                return samples

            devices.torch_gc()

            if self.latent_scale_mode is None:
                decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
            else:
                decoded_samples = None

        with sd_models.SkipWritingToConfig():
            sd_models.reload_model_weights(info=self.hr_checkpoint_info)

        return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)

    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
        if shared.state.interrupted:
            return samples

        self.is_hr_pass = True
        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or a batch with latent space images"""

            if not self.save_samples() or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

        img2img_sampler_name = self.hr_sampler_name or self.sampler_name

        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

        if self.latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])

            # Avoid making the inpainting conditioning unless necessary, since
            # it costs extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
            samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

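        # center-crop the upscaled latent to the exact target size: truncate_x/y are in
        # latent units (set in calculate_target_resolution) and are split between the two
        # opposite edges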
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

        self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
        noise = self.rng.next()

        # GC now before running the next img2img to prevent running out of memory
        devices.torch_gc()

        if not self.disable_extra_networks:
            with devices.autocast():
                extra_networks.activate(self, self.hr_extra_network_data)

        with devices.autocast():
            self.calculate_hr_conds()

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

        if self.scripts is not None:
            self.scripts.before_hr(self)

        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

        self.sampler = None
        devices.torch_gc()

        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        self.is_hr_pass = False
        return decoded_samples

    def close(self):
        super().close()
        self.hr_c = None
        self.hr_uc = None
        if not opts.persistent_cond_cache:
            StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
            StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]

    def setup_prompts(self):
        super().setup_prompts()

        if not self.enable_hr:
            return

        if self.hr_prompt == '':
            self.hr_prompt = self.prompt

        if self.hr_negative_prompt == '':
            self.hr_negative_prompt = self.negative_prompt

        if isinstance(self.hr_prompt, list):
            self.all_hr_prompts = self.hr_prompt
        else:
            self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]

        if isinstance(self.hr_negative_prompt, list):
            self.all_hr_negative_prompts = self.hr_negative_prompt
        else:
            self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]

        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]

    def calculate_hr_conds(self):
        if self.hr_c is not None:
            return

        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
        steps = self.hr_second_pass_steps or self.steps
        total_steps = sampler_config.total_steps(steps) if sampler_config else steps

        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)

    def setup_conds(self):
        if self.is_hr_pass:
            # if we are in the hr pass right now, the call is being made from the refiner, and we don't need to set up firstpass conds or switch the model
            self.hr_c = None
            self.calculate_hr_conds()
            return

        super().setup_conds()

        self.hr_uc = None
        self.hr_c = None

        if self.enable_hr and self.hr_checkpoint_info is None:
            if shared.opts.hires_fix_use_firstpass_conds:
                self.calculate_hr_conds()

            elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
                with devices.autocast():
                    extra_networks.activate(self, self.hr_extra_network_data)

                self.calculate_hr_conds()

                with devices.autocast():
                    extra_networks.activate(self, self.extra_network_data)

    def get_conds(self):
        if self.is_hr_pass:
            return self.hr_c, self.hr_uc

        return super().get_conds()

    def parse_extra_network_prompts(self):
        res = super().parse_extra_network_prompts()

        if self.enable_hr:
            self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]

            self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

        return res


@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    init_images: list = None
    resize_mode: int = 0
    denoising_strength: float = 0.75
    image_cfg_scale: float = None
    mask: Any = None
    mask_blur_x: int = 4
    mask_blur_y: int = 4
    mask_blur: int = None
    mask_round: bool = True
    inpainting_fill: int = 0
    inpaint_full_res: bool = True
    inpaint_full_res_padding: int = 0
    inpainting_mask_invert: int = 0
    initial_noise_multiplier: float = None
    latent_mask: Image = None
    force_task_id: str = None

    image_mask: Any = field(default=None, init=False)

    nmask: torch.Tensor = field(default=None, init=False)
    image_conditioning: torch.Tensor = field(default=None, init=False)
    init_img_hash: str = field(default=None, init=False)
    mask_for_overlay: Image = field(default=None, init=False)
    init_latent: torch.Tensor = field(default=None, init=False)

    def __post_init__(self):
        super().__post_init__()

        self.image_mask = self.mask
        self.mask = None
        self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier

    @property
    def mask_blur(self):
        if self.mask_blur_x == self.mask_blur_y:
            return self.mask_blur_x
        return None

    @mask_blur.setter
    def mask_blur(self, value):
        if isinstance(value, int):
            self.mask_blur_x = value
            self.mask_blur_y = value
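
    # usage note (not part of the original source): `p.mask_blur = 4` sets both axes at
    # once, while reading `p.mask_blur` returns None whenever the two axes differ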

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.extra_generation_params["Denoising strength"] = self.denoising_strength

        self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        image_mask = self.image_mask

        if image_mask is not None:
            # image_mask is passed in as RGBA by Gradio to support alpha masks,
            # but we still want to support binary masks.
            image_mask = create_binary_mask(image_mask, round=self.mask_round)

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)
                self.extra_generation_params["Mask mode"] = "Inpaint not masked"

            if self.mask_blur_x > 0:
                np_mask = np.array(image_mask)
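                # kernel covers roughly ±2.5σ, rounded and forced odd as cv2.GaussianBlur expects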
                kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
                np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
                image_mask = Image.fromarray(np_mask)

            if self.mask_blur_y > 0:
                np_mask = np.array(image_mask)
                kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
                image_mask = Image.fromarray(np_mask)

            if self.mask_blur_x > 0 or self.mask_blur_y > 0:
                self.extra_generation_params["Mask blur"] = self.mask_blur

            if self.inpaint_full_res:
                try:
                    self.mask_for_overlay = image_mask
                    mask = image_mask.convert('L')
                    crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
                    crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                    x1, y1, x2, y2 = crop_region

                    mask = mask.crop(crop_region)
                    image_mask = images.resize_image(2, mask, self.width, self.height)
                    self.paste_to = (x1, y1, x2-x1, y2-y1)

                    self.extra_generation_params["Inpaint area"] = "Only masked"
                    self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
                except ValueError:
                    self.mask_for_overlay = None
                    image_mask = None
                    crop_region = None
                    message = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.'
                    model_hijack.comments.append(message)
                    logging.info(message)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:

            # Save init image
            if opts.save_init_img:
                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)

            image = images.flatten(img, opts.img2img_background_color)

            if crop_region is None and self.resize_mode != 3:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

                    if self.inpainting_fill == 0:
                        self.extra_generation_params["Masked content"] = 'fill'

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
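            # a single init image is broadcast across the whole batch; the overlays (and the
            # color corrections just below) are duplicated to match batch_size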

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = image.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
        devices.torch_gc()

        if self.resize_mode == 3:
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        if image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            if self.mask_round:
                latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
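
            # latmask is 1.0 where inpainting happens, so self.mask weights the preserved
            # init latent and self.nmask weights the freshly denoised region in sample()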

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
                self.extra_generation_params["Masked content"] = 'latent noise'

            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask
                self.extra_generation_params["Masked content"] = 'latent nothing'

        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        x = self.rng.next()

        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            blended_samples = samples * self.nmask + self.init_latent * self.mask

            if self.scripts is not None:
                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
                self.scripts.on_mask_blend(self, mba)
                blended_samples = mba.blended_latent

            samples = blended_samples

        del x
        devices.torch_gc()

        return samples

    def get_token_merging_ratio(self, for_hr=False):
        return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
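        # precedence, as written above: the per-processing value, then the presence of an
        # override_settings entry (which falls back to the global option), then the
        # img2img-specific option, then the global option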