processing.py 67.5 KB
Newer Older
1
from __future__ import annotations
2
import json
3
import logging
4 5 6
import math
import os
import sys
7
import hashlib
8
from dataclasses import dataclass, field
9 10 11

import torch
import numpy as np
A
linter  
AUTOMATIC 已提交
12
from PIL import Image, ImageOps
13
import random
14 15
import cv2
from skimage import exposure
16
from typing import Any
17

18
import modules.sd_hijack
19
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
20
from modules.rng import slerp # noqa: F401
21
from modules.sd_hijack import model_hijack
22
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
23 24
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
25
import modules.paths as paths
A
AUTOMATIC 已提交
26
import modules.face_restoration
27
import modules.images as images
A
AUTOMATIC 已提交
28
import modules.styles
29 30
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
J
Jay Smith 已提交
31 32
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
33

J
Jay Smith 已提交
34
from einops import repeat, rearrange
35
from blendmodes.blend import blendLayers, BlendType
36 37


38 39 40 41 42
# Fixed model geometry constants. They are deliberately not exposed as user options because
# changing them would break the model. (Presumably opt_C is the latent channel count and opt_f
# the VAE downscale factor - confirm against the model configs before relying on that.)
opt_C, opt_f = 4, 8


43
def setup_color_correction(image):
    """Capture a color-correction reference from *image*.

    A copy of the image is turned into a NumPy array and converted from RGB to the
    LAB color space with OpenCV. The returned LAB array is the value that
    ``apply_color_correction`` takes as its ``correction`` argument.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)


49
def apply_color_correction(correction, original_image):
    """Match the colors of *original_image* to a previously captured reference.

    Args:
        correction: reference array in LAB space (expected to come from
            ``setup_color_correction``).
        original_image: PIL image to correct.

    Returns:
        An RGB PIL image.
    """
    logging.info("Applying color correction.")

    # Histogram matching runs per channel (channel_axis=2) in LAB space.
    source_lab = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(source_lab, correction, channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")

    # Blend the untouched original over the color-matched result in LUMINOSITY mode.
    blended = blendLayers(Image.fromarray(matched_rgb), original_image, BlendType.LUMINOSITY)
    return blended.convert('RGB')
63

A
AUTOMATIC 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80

def apply_overlay(image, paste_loc, index, overlays):
    """Alpha-composite ``overlays[index]`` on top of *image*.

    Args:
        image: PIL image to decorate.
        paste_loc: optional ``(x, y, w, h)``. When given, *image* is resized to ``w x h``
            with ``images.resize_image`` (mode 1) and pasted at ``(x, y)`` on a transparent
            canvas the size of the overlay before compositing.
        index: which overlay to use.
        overlays: list of RGBA overlay images, or None.

    Returns:
        The composited image in RGB mode, or *image* itself when there is no overlay
        for *index*.
    """
    if overlays is None or index >= len(overlays):
        # Nothing to overlay for this index; hand the image back untouched.
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        left, top, width, height = paste_loc
        canvas = Image.new('RGBA', (overlay.width, overlay.height))
        resized = images.resize_image(1, image, width, height)
        canvas.paste(resized, (left, top))
        image = canvas

    composed = image.convert('RGBA')
    composed.alpha_composite(overlay)
    return composed.convert('RGB')
83

84 85 86 87 88 89
def create_binary_mask(image):
    """Turn *image* into a single-channel ("L") mask.

    An RGBA image whose alpha channel is not fully opaque is masked by its alpha
    channel, thresholded so values above 128 become 255 and everything else 0.
    Any other image is simply converted to grayscale.
    """
    has_transparency = image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255)
    if not has_transparency:
        return image.convert('L')

    alpha_channel = image.split()[-1]
    return alpha_channel.convert("L").point(lambda value: 255 if value > 128 else 0)
F
frostydad 已提交
90

91
def txt2img_image_conditioning(sd_model, x, width, height):
    """Build the image-conditioning tensor used while sampling txt2img.

    Args:
        sd_model: loaded model; ``sd_model.model.conditioning_key`` selects the branch.
        x: latent tensor; its batch size, dtype and device are reused.
        width, height: pixel size of the (all-zero) masked image for inpainting models.

    Returns:
        * inpainting models ('hybrid' / 'concat'): the encoded all-zero image with an
          all-ones mask channel prepended along dim 1;
        * unCLIP models ('crossattn-adm'): zeros of shape
          ``(batch, 2 * noise_augmentor.time_embed.dim)``;
        * anything else: a dummy ``(batch, 5, 1, 1)`` zero tensor.
    """
    conditioning_key = sd_model.model.conditioning_key

    if conditioning_key == "crossattn-adm":  # UnCLIP models
        embed_dim = sd_model.noise_augmentor.time_embed.dim
        return x.new_zeros(x.shape[0], 2 * embed_dim, dtype=x.dtype, device=x.device)

    if conditioning_key not in {'hybrid', 'concat'}:
        # Dummy zero conditioning when the model needs neither inpainting nor unCLIP input.
        # It still costs a little memory but avoids an encoder call; only its batch size
        # matters, so a 1x1 spatial size is enough.
        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)

    # Inpainting models: the whole image is masked, so the "masked image" is all zeros.
    blank_pixels = torch.zeros(x.shape[0], 3, height, width, device=x.device)
    latent = images_tensor_to_samples(blank_pixels, approximation_indexes.get(opts.sd_vae_encode_method))

    # Prepend a full-ones mask channel: the pad tuple's last pair (1, 0) targets dim 1.
    with_mask = torch.nn.functional.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
    return with_mask.to(x.dtype)
113 114


115
@dataclass(repr=False)
class StableDiffusionProcessing:
    """Parameters and per-run state shared by all generation jobs.

    Fields declared with plain defaults are constructor arguments describing the request.
    Fields declared with ``field(..., init=False)`` are runtime state filled in while the job
    is prepared and sampled. ``sample()`` raises NotImplementedError here and must be provided
    by a subclass; ``init()`` is a hook that does nothing by default.

    NOTE: the order of the constructor fields defines the positional order of the generated
    ``__init__`` parameters, so new request fields should only ever be appended.
    """

    # Accepted by __init__ only: the ``sd_model`` property defined further down shadows this
    # field, always returns ``shared.sd_model`` and ignores every assignment (including __init__'s).
    sd_model: object = None
    outpath_samples: str = None
    outpath_grids: str = None
    prompt: str = ""
    prompt_for_display: str = None
    negative_prompt: str = ""
    styles: list[str] = None
    seed: int = -1
    subseed: int = -1
    subseed_strength: float = 0
    seed_resize_from_h: int = -1
    seed_resize_from_w: int = -1
    seed_enable_extras: bool = True
    sampler_name: str = None
    batch_size: int = 1
    n_iter: int = 1
    steps: int = 50
    cfg_scale: float = 7.0
    width: int = 512
    height: int = 512
    restore_faces: bool = None
    tiling: bool = None
    do_not_save_samples: bool = False
    do_not_save_grid: bool = False
    extra_generation_params: dict[str, Any] = None
    overlay_images: list = None
    eta: float = None
    do_not_reload_embeddings: bool = False
    denoising_strength: float = 0
    ddim_discretize: str = None
    # Sampler tuning values; None means "use the global option" (resolved in __post_init__).
    s_min_uncond: float = None
    s_churn: float = None
    s_tmax: float = None
    s_tmin: float = None
    s_noise: float = None
    override_settings: dict[str, Any] = None
    override_settings_restore_afterwards: bool = True
    # Has no effect; __post_init__ prints a warning pointing to sampler_name when it is set.
    sampler_index: int = None
    refiner_checkpoint: str = None
    refiner_switch_at: float = None
    # Unannotated, so these are plain class attributes rather than dataclass fields
    # (they are not __init__ parameters).
    token_merging_ratio = 0
    token_merging_ratio_hr = 0
    disable_extra_networks: bool = False

    # Backing storage for the ``scripts`` / ``script_args`` properties below.
    scripts_value: scripts.ScriptRunner = field(default=None, init=False)
    script_args_value: list = field(default=None, init=False)
    scripts_setup_complete: bool = field(default=False, init=False)

    # Class-level [cached_params, result] pairs shared by every instance: bound in __post_init__
    # and reset in close() unless opts.persistent_cond_cache is set. Unannotated, so they are not
    # dataclass fields.
    cached_uc = [None, None]
    cached_c = [None, None]

    comments: dict = None
    sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
    is_using_inpainting_conditioning: bool = field(default=False, init=False)
    paste_to: tuple | None = field(default=None, init=False)

    is_hr_pass: bool = field(default=False, init=False)

    c: tuple = field(default=None, init=False)
    uc: tuple = field(default=None, init=False)

    rng: rng.ImageRNG | None = field(default=None, init=False)
    step_multiplier: int = field(default=1, init=False)
    color_corrections: list = field(default=None, init=False)

    all_prompts: list = field(default=None, init=False)
    all_negative_prompts: list = field(default=None, init=False)
    all_seeds: list = field(default=None, init=False)
    all_subseeds: list = field(default=None, init=False)
    iteration: int = field(default=0, init=False)
    main_prompt: str = field(default=None, init=False)
    main_negative_prompt: str = field(default=None, init=False)

    prompts: list = field(default=None, init=False)
    negative_prompts: list = field(default=None, init=False)
    seeds: list = field(default=None, init=False)
    subseeds: list = field(default=None, init=False)
    extra_network_data: dict = field(default=None, init=False)

    user: str = field(default=None, init=False)

    sd_model_name: str = field(default=None, init=False)
    sd_model_hash: str = field(default=None, init=False)
    sd_vae_name: str = field(default=None, init=False)
    sd_vae_hash: str = field(default=None, init=False)

    is_api: bool = field(default=False, init=False)

    def __post_init__(self):
        """Normalize constructor arguments and fill unset sampler values from the global options."""
        if self.sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.comments = {}

        if self.styles is None:
            self.styles = []

        self.sampler_noise_scheduler_override = None
        # Sampler parameters left as None fall back to the corresponding global option.
        self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
        self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
        self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
        # A falsy s_tmax (e.g. 0) means "no upper bound".
        self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
        self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise

        self.extra_generation_params = self.extra_generation_params or {}
        self.override_settings = self.override_settings or {}
        # NOTE(review): script_args_value is declared as a list but the default here is {}; both are
        # falsy, so assigning the default never triggers setup_scripts().
        self.script_args = self.script_args or {}

        self.refiner_checkpoint_info = None

        if not self.seed_enable_extras:
            # Variation-seed and seed-resize extras are disabled: reset them to neutral values.
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        # Point the instance at the class-level caches so conditionings can be reused across jobs.
        self.cached_uc = StableDiffusionProcessing.cached_uc
        self.cached_c = StableDiffusionProcessing.cached_c

    @property
    def sd_model(self):
        """The globally loaded model (``shared.sd_model``)."""
        return shared.sd_model

    @sd_model.setter
    def sd_model(self, value):
        # Deliberately ignored: the model is always taken from shared state.
        pass

    @property
    def scripts(self):
        """The ScriptRunner attached to this job, or None."""
        return self.scripts_value

    @scripts.setter
    def scripts(self, value):
        self.scripts_value = value

        # Run one-time script setup as soon as both scripts and script_args are available.
        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    @property
    def script_args(self):
        """Arguments for the attached scripts."""
        return self.script_args_value

    @script_args.setter
    def script_args(self, value):
        self.script_args_value = value

        # Run one-time script setup as soon as both scripts and script_args are available.
        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    def setup_scripts(self):
        """Run the attached script runner's setup for this job (at most once).

        The completion flag is set before the call, so a property setter triggered during setup
        cannot re-enter it. The runner method is called as spelled (``setup_scrips``); presumably
        that matches the ScriptRunner definition - confirm before renaming.
        """
        self.scripts_setup_complete = True

        self.scripts.setup_scrips(self, is_ui=not self.is_api)

    def comment(self, text):
        """Record *text* as a job comment; the dict acts as an insertion-ordered set."""
        self.comments[text] = 1

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Build txt2img image conditioning for the current model.

        Also records whether the model uses inpainting conditioning. *width* / *height* default to
        the job's size.
        """
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        """Build the depth-map conditioning for depth2img models.

        The depth map is estimated from the first image of *source_image* only and repeated
        ``batch_size`` times. It is resized to the spatial size of the encoded source image and
        rescaled linearly to [-1, 1].
        """
        # Use the AddMiDaS helper to format our source image to suit the MiDaS model.
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        (depth_min, depth_max) = torch.aminmax(conditioning)
        depth_range = depth_max - depth_min
        # A perfectly flat depth map has zero range; dividing by it would fill the conditioning with
        # NaNs. Divide by 1 instead, which maps a flat map to a constant -1. Non-flat maps are unaffected.
        depth_range = torch.where(depth_range > 0, depth_range, torch.ones_like(depth_range))
        conditioning = 2. * (conditioning - depth_min) / depth_range - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        """Encode *source_image* (rescaled from [-1, 1] to [0, 1]) into latents for 'edit' models."""
        return images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))

    def unclip_image_conditioning(self, source_image):
        """Embed *source_image* for unCLIP models, appending a noise-level embedding when available."""
        c_adm = self.sd_model.embedder(source_image)
        if self.sd_model.noise_augmentor is not None:
            noise_level = 0 # TODO: Allow other noise levels?
            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Build the ``c_concat`` conditioning for inpainting ('hybrid' / 'concat') models.

        Args:
            source_image: image tensor that is masked and then encoded by the first stage.
            latent_image: latent tensor; only its spatial size (last two dims) is used.
            image_mask: optional mask, either a tensor or an image supporting ``convert("L")``
                (binarized by rounding). When omitted, the whole image is treated as masked.

        Returns:
            The mask (resized to the latent size) concatenated with the encoded masked image along
            dim 1, on ``shared.device`` in the model's dtype.
        """
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter:
        # weight 0 keeps the source image, weight 1 uses the fully masked one. A per-job
        # inpainting_mask_weight attribute, when present, overrides the global option.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Pick the img2img image conditioning that matches the loaded model.

        Checked in order: depth2img models, models whose ``cond_stage_key`` is 'edit', inpainting
        ('hybrid' / 'concat') samplers, unCLIP ('crossattn-adm') samplers. Anything else gets a
        dummy zero tensor of shape ``(batch, 5, 1, 1)``.
        """
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Hook for subclasses to prepare a job; the base implementation does nothing."""
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Produce samples for one batch; must be implemented by subclasses."""
        raise NotImplementedError()

    def close(self):
        """Drop the sampler and conditionings held by this job.

        The shared class-level conditioning caches are also reset unless
        ``opts.persistent_cond_cache`` is enabled.
        """
        self.sampler = None
        self.c = None
        self.uc = None
        if not opts.persistent_cond_cache:
            StableDiffusionProcessing.cached_c = [None, None]
            StableDiffusionProcessing.cached_uc = [None, None]

    def get_token_merging_ratio(self, for_hr=False):
        """Return the token-merging ratio, falling back to global options when unset (falsy).

        For the hires pass the hires-specific values are tried first, then the regular ones.
        """
        if for_hr:
            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

        return self.token_merging_ratio or opts.token_merging_ratio

    def setup_prompts(self):
        """Fill ``all_prompts`` / ``all_negative_prompts`` and the ``main_*`` prompts.

        A list prompt is used as-is; a list negative prompt alone sets the length; otherwise the
        prompt is repeated ``batch_size * n_iter`` times. Styles are applied to every entry.

        Raises:
            RuntimeError: when the prompt and negative-prompt lists differ in length.
        """
        if isinstance(self.prompt, list):
            self.all_prompts = self.prompt
        elif isinstance(self.negative_prompt, list):
            self.all_prompts = [self.prompt] * len(self.negative_prompt)
        else:
            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

        if isinstance(self.negative_prompt, list):
            self.all_negative_prompts = self.negative_prompt
        else:
            self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)

        if len(self.all_prompts) != len(self.all_negative_prompts):
            raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")

        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

        self.main_prompt = self.all_prompts[0]
        self.main_negative_prompt = self.all_negative_prompts[0]

    def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
        """Returns parameters that invalidate the cond cache if changed"""

        return (
            required_prompts,
            steps,
            hires_steps,
            use_old_scheduling,
            opts.CLIP_stop_at_last_layers,
            shared.sd_model.sd_checkpoint_info,
            extra_network_data,
            opts.sdxl_crop_left,
            opts.sdxl_crop_top,
            self.width,
            self.height,
        )

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
        """
        Returns the result of calling
        function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling),
        reusing a previously computed result when cached_params() is unchanged.

        caches is a list of cache entries. Each entry is a mutable two-element list: the first
        element is the cached_params() tuple the stored result was computed for (or None if the
        entry is empty), and the second element is that stored result.

        Every entry is checked for a hit, but on a miss only caches[0] is overwritten.
        """

        if shared.opts.use_old_scheduling:
            # Record in the infotext when the legacy prompt-editing schedule differs from the current one.
            old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
            new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
            if old_schedules != new_schedules:
                self.extra_generation_params["Old prompt editing timelines"] = True

        cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)

        for cache in caches:
            if cache[0] is not None and cached_params == cache[0]:
                return cache[1]

        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

        cache[0] = cached_params
        return cache[1]

    def setup_conds(self):
        """Compute (or fetch from cache) the conditionings for ``self.prompts`` / ``self.negative_prompts``.

        Also sets ``step_multiplier`` and ``firstpass_steps`` from the sampler's total step count.
        """
        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
        self.step_multiplier = total_steps // self.steps
        self.firstpass_steps = total_steps

        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)

    def get_conds(self):
        """Return the ``(c, uc)`` conditioning pair."""
        return self.c, self.uc

    def parse_extra_network_prompts(self):
        """Run ``self.prompts`` through ``extra_networks.parse_prompts``, keeping the returned prompts and network data."""
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

    def save_samples(self) -> bool:
        """Returns whether generated images need to be written to disk

        Images of interrupted or skipped jobs are only saved when ``opts.save_incomplete_images`` is set.
        """
        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)

482 483

class Processed:
    """Outcome of a generation job: the produced images plus the settings that produced them.

    Settings are read from the ``StableDiffusionProcessing`` object *p* (and a few global
    options / state values) when the instance is created.
    """

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        """Snapshot the results and parameters of job *p*.

        Args:
            p: the processing object the images were generated with.
            images_list: the generated images.
            seed, subseed: seed values to report; a list is reduced to its first element and
                ``None`` becomes -1.
            info: infotext used as the only infotext when *infotexts* is empty.
            all_prompts, all_negative_prompts, all_seeds, all_subseeds: per-image values; when
                empty they fall back to the matching attribute of *p*, then to a one-item list.
            index_of_first_image: stored unchanged (presumably the index of the first
                non-grid image in *images_list* - confirm with callers).
            infotexts: per-image infotexts.
            comments: accepted but not used; comments are taken from ``p.comments``.
        """
        def first_item(value):
            # Prompts and seeds may be given per image (as a list); keep only the first one.
            return value[0] if isinstance(value, list) else value

        def to_seed(value):
            return -1 if value is None else int(first_item(value))

        # Identity of the run.
        self.images = images_list
        self.prompt = first_item(p.prompt)
        self.negative_prompt = first_item(p.negative_prompt)
        self.seed = to_seed(seed)
        self.subseed = to_seed(subseed)
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.comments = "".join(f"{text}\n" for text in p.comments)

        # Core generation settings.
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None

        # Model / VAE identification.
        self.sd_model_name = p.sd_model_name
        self.sd_model_hash = p.sd_model_hash
        self.sd_vae_name = p.sd_vae_name
        self.sd_vae_hash = p.sd_vae_hash

        # Remaining request parameters and global context.
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers
        self.token_merging_ratio = p.token_merging_ratio
        self.token_merging_ratio_hr = p.token_merging_ratio_hr

        # Sampler tuning knobs.
        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.s_min_uncond = p.s_min_uncond
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        # Per-image values: explicit argument, then the job's value, then a single-item fallback.
        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Serialize the public result fields to a JSON string (keys in a fixed order)."""
        payload = dict(
            prompt=self.all_prompts[0],
            all_prompts=self.all_prompts,
            negative_prompt=self.all_negative_prompts[0],
            all_negative_prompts=self.all_negative_prompts,
            seed=self.seed,
            all_seeds=self.all_seeds,
            subseed=self.subseed,
            all_subseeds=self.all_subseeds,
            subseed_strength=self.subseed_strength,
            width=self.width,
            height=self.height,
            sampler_name=self.sampler_name,
            cfg_scale=self.cfg_scale,
            steps=self.steps,
            batch_size=self.batch_size,
            restore_faces=self.restore_faces,
            face_restoration_model=self.face_restoration_model,
            sd_model_name=self.sd_model_name,
            sd_model_hash=self.sd_model_hash,
            sd_vae_name=self.sd_vae_name,
            sd_vae_hash=self.sd_vae_hash,
            seed_resize_from_w=self.seed_resize_from_w,
            seed_resize_from_h=self.seed_resize_from_h,
            denoising_strength=self.denoising_strength,
            extra_generation_params=self.extra_generation_params,
            index_of_first_image=self.index_of_first_image,
            infotexts=self.infotexts,
            styles=self.styles,
            job_timestamp=self.job_timestamp,
            clip_skip=self.clip_skip,
            is_using_inpainting_conditioning=self.is_using_inpainting_conditioning,
        )
        return json.dumps(payload)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Build the infotext for the image at flat position *index* (batch-major order)."""
        iteration, position_in_batch = divmod(index, self.batch_size)
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=position_in_batch, iteration=iteration)

    def get_token_merging_ratio(self, for_hr=False):
        """Return the stored token-merging ratio for the hires pass or the first pass."""
        if for_hr:
            return self.token_merging_ratio_hr
        return self.token_merging_ratio

580

581
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Draw one batch of noise of the given *shape* from a seeded ``rng.ImageRNG``.

    All seed-related arguments are forwarded unchanged. *p* is not used by this function.
    """
    generator = rng.ImageRNG(
        shape,
        seeds,
        subseeds=subseeds,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
    )
    return generator.next()
584 585


A
AUTOMATIC1111 已提交
586 587 588 589
class DecodedSamples(list):
    """A list of samples that have already been decoded from latent space.

    The class-level ``already_decoded`` flag marks such lists; presumably consumers
    check it to avoid decoding twice - confirm at the call sites.
    """

    already_decoded = True


590
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    """Decode a batch of latents one sample at a time.

    Args:
        model: model whose first stage is used through ``decode_first_stage``.
        batch: latent tensor; dimension 0 indexes the samples.
        target_device: when given, each decoded sample is moved there.
        check_for_nans: when True, each decoded sample is checked with
            ``devices.test_for_nans``. On a NaN result, if the VAE dtype is not already
            float32 and ``opts.auto_vae_precision`` is enabled, the VAE and the batch tensor
            are switched to float32 and the sample is decoded again; otherwise the
            ``NansException`` propagates.

    Returns:
        A ``DecodedSamples`` list holding one decoded tensor per input sample.
    """
    decoded = DecodedSamples()
    sample_count = batch.shape[0]

    for index in range(sample_count):
        latent_slice = slice(index, index + 1)
        image = decode_first_stage(model, batch[latent_slice])[0]

        if check_for_nans:
            try:
                devices.test_for_nans(image, "vae")
            except devices.NansException:
                vae_is_full_precision = devices.dtype_vae == torch.float32
                if vae_is_full_precision or not shared.opts.auto_vae_precision:
                    raise

                errors.print_error_explanation(
                    "A tensor with all NaNs was produced in VAE.\n"
                    "Web UI will now convert VAE into 32-bit float and retry.\n"
                    "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."
                )

                # The switch is global: the remaining samples of this batch are decoded in float32 as well.
                devices.dtype_vae = torch.float32
                model.first_stage_model.to(devices.dtype_vae)
                batch = batch.to(devices.dtype_vae)

                image = decode_first_stage(model, batch[latent_slice])[0]

        if target_device is not None:
            image = image.to(target_device)

        decoded.append(image)

    return decoded


624
def get_fixed_seed(seed):
    """Resolve a user-supplied seed into a concrete one.

    ``None``, ``''``, ``-1`` (including the string ``"-1"``) and strings that do not parse
    as an integer all mean "pick a random seed". Numeric strings are converted with ``int``;
    any other value is returned unchanged.

    Returns:
        The resolved seed, or a random int in ``[0, 4294967294)``.
    """
    if seed is None or seed == '':
        seed = -1
    elif isinstance(seed, str):
        try:
            seed = int(seed)
        except ValueError:
            # Unparseable text falls back to a random seed instead of failing the job.
            seed = -1

    if seed == -1:
        return random.randrange(4294967294)

    return seed


639
def fix_seed(p):
    """Replace ``p.seed`` and then ``p.subseed`` in place with values resolved by ``get_fixed_seed``."""
    for attribute in ("seed", "subseed"):
        setattr(p, attribute, get_fixed_seed(getattr(p, attribute)))
A
AUTOMATIC 已提交
642 643


644 645 646 647 648 649 650 651 652 653
def program_version():
    """Return the tag reported by ``launch.git_tag()``, or None when it reports ``"<none>"``."""
    import launch

    tag = launch.git_tag()
    return None if tag == "<none>" else tag


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
    """Build the generation-parameters text ("infotext") for one image.

    The result is the prompt, an optional "Negative prompt: ..." line, and a
    comma-separated "Key: value" list of parameters; parameters whose value is
    None are omitted.

    Args:
        p: the processing object the image was produced with.
        all_prompts, all_seeds, all_subseeds: per-image lists indexed by ``index``.
        comments: accepted for backward compatibility; not used by this function.
        iteration, position_in_batch: used to compute ``index`` when it is not given.
        use_main_prompt: describe the job as a whole (``p.main_prompt``,
            ``p.main_negative_prompt``, first seed/subseed) instead of image ``index``.
        index: position of the image in the ``all_*`` lists.
        all_negative_prompts: defaults to ``p.all_negative_prompts``.

    Returns:
        The infotext string, stripped of surrounding whitespace.
    """
    if index is None:
        index = position_in_batch + iteration * p.batch_size

    if all_negative_prompts is None:
        all_negative_prompts = p.all_negative_prompts

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    enable_hr = getattr(p, 'enable_hr', False)
    token_merging_ratio = p.get_token_merging_ratio()
    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

    uses_ensd = opts.eta_noise_seed_delta != 0
    if uses_ensd:
        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": opts.face_restoration_model if p.restore_faces else None,
        "Size": f"{p.width}x{p.height}",
        "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
        "Model": p.sd_model_name if opts.add_model_name_to_info else None,
        "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
        "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
        "Tiling": "True" if p.tiling else None,
        **p.extra_generation_params,
        "Version": program_version() if opts.add_version_to_infotext else None,
        "User": p.user if opts.add_user_name_to_info else None,
    }

    # A parameter whose value equals its own key is written as the bare key.
    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]

    # Test the same negative prompt that gets printed; previously the per-image entry was
    # tested even when the main negative prompt was printed, which could drop a real
    # negative prompt or emit an empty "Negative prompt:" line.
    negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]
    negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""

    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()


def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run a generation job with ``p.override_settings`` temporarily applied.

    Each override is written into ``opts`` (reloading checkpoint / VAE weights
    when those settings are overridden), token merging is applied for the job,
    and ``process_images_inner`` does the actual work. Token merging is always
    reset to 0 afterwards; the overridden options are restored only when
    ``p.override_settings_restore_afterwards`` is set.
    """
    if p.scripts is not None:
        p.scripts.before_process(p)

    # snapshot the current value of every overridden option so it can be restored later
    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

    try:
        # if there is no checkpoint override, or the override checkpoint can't be found, remove the
        # override entry and load the opts checkpoint; this also swaps back to the main model if a
        # refiner was left loaded by a previous run (a present model override is reloaded below)
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        for k, v in p.override_settings.items():
            opts.set(k, v, is_api=True, run_callbacks=False)

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        res = process_images_inner(p)

    finally:
        sd_models.apply_token_merging(p.sd_model, 0)

        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """Main generation loop shared by txt2img and img2img.

    Resolves seeds and prompts, calls ``p.init`` once, then for each of the
    ``p.n_iter`` batches calls ``p.sample``, decodes, post-processes, saves and
    collects the images and their infotexts. Optionally prepends / saves a grid
    and returns everything wrapped in a ``Processed``.
    """

    if isinstance(p.prompt, list):
        assert len(p.prompt) > 0
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    if p.restore_faces is None:
        p.restore_faces = opts.face_restoration

    if p.tiling is None:
        p.tiling = opts.tiling

    if p.refiner_checkpoint not in (None, "", "None", "none"):
        p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
        if p.refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')

    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
    p.sd_model_hash = shared.sd_model.sd_model_hash
    p.sd_vae_name = sd_vae.get_loaded_vae_name()
    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    p.setup_prompts()

    # a scalar seed expands to consecutive seeds per image, unless variation seeds are in use
    if isinstance(seed, list):
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if isinstance(subseed, list):
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            sd_models.reload_model_weights()  # model can be changed for example by refiner

            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            if len(p.prompts) == 0:
                break

            p.parse_extra_network_prompts()

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [])
                    file.write(processed.infotext(p, 0))

            p.setup_conds()

            for comment in model_hijack.comments:
                p.comment(comment)

            p.extra_generation_params.update(model_hijack.extra_generation_params)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            # some samplers (e.g. the hires-fix path) hand back already-decoded images
            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            # map decoder output from [-1, 1] to [0, 1]
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            def infotext(index=0, use_main_prompt=False):
                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

            save_samples = p.save_samples()

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                # CHW float in [0, 1] -> HWC uint8 for PIL
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image
                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)
                if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                    image_mask = p.mask_for_overlay.convert('RGB')
                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')

                    if opts.save_mask:
                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")

                    if opts.save_mask_composite:
                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")

                    if opts.return_mask:
                        output_images.append(image_mask)

                    if opts.return_mask_composite:
                        output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        # With no images (e.g. interrupted before the first batch finished) there is nothing to
        # put in a grid, and the per-batch `infotext` helper used below was never defined.
        unwanted_grid_because_of_img_count = len(output_images) == 0 or (len(output_images) < 2 and opts.grid_only_if_multiple)
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        # infotexts is empty when no image was produced; avoid an IndexError in that case
        info=infotexts[0] if infotexts else "",
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res


def old_hires_fix_first_pass_dimensions(width, height):
    """Compute the first-pass size used by the legacy hires-fix algorithm.

    Scales ``width`` x ``height`` uniformly so the area is about 512*512
    pixels, then rounds each side up to a multiple of 64.

    Returns:
        ``(width, height)`` as ints.
    """

    desired_pixel_count = 512 * 512
    actual_pixel_count = width * height
    scale = math.sqrt(desired_pixel_count / actual_pixel_count)
    width = math.ceil(scale * width / 64) * 64
    height = math.ceil(scale * height / 64) * 64

    return width, height


@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    enable_hr: bool = False
    denoising_strength: float = 0.75
    firstphase_width: int = 0
    firstphase_height: int = 0
    hr_scale: float = 2.0
    hr_upscaler: str | None = None
    hr_second_pass_steps: int = 0
    hr_resize_x: int = 0
    hr_resize_y: int = 0
    hr_checkpoint_name: str | None = None
    hr_sampler_name: str | None = None
    hr_prompt: str = ''
    hr_negative_prompt: str = ''

    # Unannotated, so these are plain class attributes rather than dataclass fields:
    # __post_init__ aliases them onto every instance (a cache shared by all instances),
    # and close() resets them unless opts.persistent_cond_cache is set.
    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]

    # Derived state filled in during processing; not constructor arguments (init=False).
    hr_checkpoint_info: Any = field(default=None, init=False)  # checkpoint info object for hr_checkpoint_name
    hr_upscale_to_x: int = field(default=0, init=False)  # hires-fix target width
    hr_upscale_to_y: int = field(default=0, init=False)  # hires-fix target height
    truncate_x: int = field(default=0, init=False)  # latent-space crop, see calculate_target_resolution
    truncate_y: int = field(default=0, init=False)
    applied_old_hires_behavior_to: tuple | None = field(default=None, init=False)
    latent_scale_mode: dict | None = field(default=None, init=False)  # None means a non-latent upscaler
    hr_c: tuple | None = field(default=None, init=False)
    hr_uc: tuple | None = field(default=None, init=False)
    all_hr_prompts: list | None = field(default=None, init=False)
    all_hr_negative_prompts: list | None = field(default=None, init=False)
    hr_prompts: list | None = field(default=None, init=False)
    hr_negative_prompts: list | None = field(default=None, init=False)
    hr_extra_network_data: list | None = field(default=None, init=False)

    def __post_init__(self):
        """Finish dataclass construction: apply first-phase sizing and attach the shared hires cond caches."""
        super().__post_init__()

        # When firstphase_* is given it is the first-pass size, and the requested
        # width/height become the hires-fix upscale target.
        if self.firstphase_width != 0 or self.firstphase_height != 0:
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height
            self.width = self.firstphase_width
            self.height = self.firstphase_height

        # alias (not copy) the class-level caches so every instance shares them
        self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
        self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c

    def calculate_target_resolution(self):
        """Compute the hires-fix output size and latent crop for this job.

        Sets hr_upscale_to_x / hr_upscale_to_y (the size the first-pass result is
        upscaled to) and, when both hr_resize_x and hr_resize_y are given,
        truncate_x / truncate_y (how much to crop, in latent units, to reach the
        exact requested size). The chosen mode is recorded in extra_generation_params.
        """
        if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
            # old behaviour: width/height are the final size, and the first pass is derived
            # from them; applied_old_hires_behavior_to keeps this from being re-applied later
            self.hr_resize_x = self.width
            self.hr_resize_y = self.height
            self.hr_upscale_to_x = self.width
            self.hr_upscale_to_y = self.height

            self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
            self.applied_old_hires_behavior_to = (self.width, self.height)

        if self.hr_resize_x == 0 and self.hr_resize_y == 0:
            # scale-factor mode
            self.extra_generation_params["Hires upscale"] = self.hr_scale
            self.hr_upscale_to_x = int(self.width * self.hr_scale)
            self.hr_upscale_to_y = int(self.height * self.hr_scale)
        else:
            self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"

            if self.hr_resize_y == 0:
                # only width given: keep the aspect ratio
                self.hr_upscale_to_x = self.hr_resize_x
                self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
            elif self.hr_resize_x == 0:
                # only height given: keep the aspect ratio
                self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                self.hr_upscale_to_y = self.hr_resize_y
            else:
                # both given: upscale (aspect ratio kept) until the target box is covered,
                # then record the overflow to crop, converted to latent units (// opt_f)
                target_w = self.hr_resize_x
                target_h = self.hr_resize_y
                src_ratio = self.width / self.height
                dst_ratio = self.hr_resize_x / self.hr_resize_y

                if src_ratio < dst_ratio:
                    self.hr_upscale_to_x = self.hr_resize_x
                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
                else:
                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
                    self.hr_upscale_to_y = self.hr_resize_y

                self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Validate hires-fix settings and record them in extra_generation_params.

        Does nothing unless enable_hr is set.

        Raises:
            Exception: if hr_checkpoint_name or a non-latent hr_upscaler cannot be found.
        """
        if self.enable_hr:
            if self.hr_checkpoint_name:
                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

                if self.hr_checkpoint_info is None:
                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

            if tuple(self.hr_prompt) != tuple(self.prompt):
                self.extra_generation_params["Hires prompt"] = self.hr_prompt

            if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt

            if self.hr_upscaler is not None:
                self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None)
            else:
                # sample_hr_pass reads latent_scale_mode["mode"] and ["antialias"], so the
                # fallback must be a mode dict, not the bare string "nearest"
                self.latent_scale_mode = shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, {"mode": "nearest", "antialias": False})

            # None means a non-latent upscaler, which must be one of shared.sd_upscalers
            if self.latent_scale_mode is None:
                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")

            self.calculate_target_resolution()

            # account for the second pass in progress reporting, once per job
            if not state.processing_has_refined_job_count:
                if state.job_count == -1:
                    state.job_count = self.n_iter

                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
                state.job_count = state.job_count * 2
                state.processing_has_refined_job_count = True

            if self.hr_second_pass_steps:
                self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run the txt2img first pass and, when enable_hr is set, the hires-fix second pass.

        Returns the first-pass latent samples when hires fix is disabled, otherwise the
        result of sample_hr_pass. If an hr checkpoint is configured it is loaded for the
        second pass; the previously loaded checkpoint is reloaded afterwards, even on error.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        x = self.rng.next()
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x

        if not self.enable_hr:
            return samples

        # Non-latent upscalers work on images: decode with the current (first-pass) model,
        # before a possible checkpoint switch below.
        if self.latent_scale_mode is None:
            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
        else:
            decoded_samples = None

        current = shared.sd_model.sd_checkpoint_info
        try:
            if self.hr_checkpoint_info is not None:
                self.sampler = None
                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
                devices.torch_gc()

            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
        finally:
            self.sampler = None
            sd_models.reload_model_weights(info=current)
            devices.torch_gc()

    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
        """Upscale the first-pass result and run the img2img hires-fix pass on it.

        With a latent upscale mode (latent_scale_mode set), ``samples`` are interpolated
        directly in latent space; otherwise ``decoded_samples`` are turned into images,
        upscaled with ``hr_upscaler`` and re-encoded. Returns the decoded second-pass
        images, or ``samples`` unchanged if the job was interrupted.
        """
        if shared.state.interrupted:
            return samples

        self.is_hr_pass = True

        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

            if not self.save_samples() or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

        img2img_sampler_name = self.hr_sampler_name or self.sampler_name

        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

        if self.latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
            samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        # crop the overflow computed by calculate_target_resolution, split between both
        # sides of each axis (the odd extra row/column comes off the far side)
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

        self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
        noise = self.rng.next()

        # GC now before running the next img2img to prevent running out of memory
        devices.torch_gc()

        if not self.disable_extra_networks:
            with devices.autocast():
                extra_networks.activate(self, self.hr_extra_network_data)

        with devices.autocast():
            self.calculate_hr_conds()

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

        if self.scripts is not None:
            self.scripts.before_hr(self)

        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        # back to the first-pass token merging ratio
        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

        self.sampler = None
        devices.torch_gc()

        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        self.is_hr_pass = False

        return decoded_samples

    def close(self):
W
w-e-w 已提交
1262
        super().close()
1263 1264
        self.hr_c = None
        self.hr_uc = None
A
AUTOMATIC1111 已提交
1265
        if not opts.persistent_cond_cache:
W
w-e-w 已提交
1266 1267
            StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
            StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280

    def setup_prompts(self):
        super().setup_prompts()

        if not self.enable_hr:
            return

        if self.hr_prompt == '':
            self.hr_prompt = self.prompt

        if self.hr_negative_prompt == '':
            self.hr_negative_prompt = self.negative_prompt

X
XDOneDude 已提交
1281
        if isinstance(self.hr_prompt, list):
1282 1283 1284 1285
            self.all_hr_prompts = self.hr_prompt
        else:
            self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]

X
XDOneDude 已提交
1286
        if isinstance(self.hr_negative_prompt, list):
1287 1288 1289 1290 1291 1292 1293
            self.all_hr_negative_prompts = self.hr_negative_prompt
        else:
            self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]

        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]

1294 1295 1296 1297
    def calculate_hr_conds(self):
        if self.hr_c is not None:
            return

A
AUTOMATIC1111 已提交
1298 1299 1300
        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)

1301 1302 1303 1304
        sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
        steps = self.hr_second_pass_steps or self.steps
        total_steps = sampler_config.total_steps(steps) if sampler_config else steps

1305 1306
        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
1307

1308
    def setup_conds(self):
1309 1310 1311 1312 1313 1314
        if self.is_hr_pass:
            # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
            self.hr_c = None
            self.calculate_hr_conds()
            return

1315 1316
        super().setup_conds()

1317 1318 1319
        self.hr_uc = None
        self.hr_c = None

A
AUTOMATIC1111 已提交
1320
        if self.enable_hr and self.hr_checkpoint_info is None:
1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331
            if shared.opts.hires_fix_use_firstpass_conds:
                self.calculate_hr_conds()

            elif lowvram.is_enabled(shared.sd_model):  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
                with devices.autocast():
                    extra_networks.activate(self, self.hr_extra_network_data)

                self.calculate_hr_conds()

                with devices.autocast():
                    extra_networks.activate(self, self.extra_network_data)
1332

A
AUTOMATIC1111 已提交
1333 1334 1335 1336 1337 1338
    def get_conds(self):
        if self.is_hr_pass:
            return self.hr_c, self.hr_uc

        return super().get_conds()

1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349
    def parse_extra_network_prompts(self):
        res = super().parse_extra_network_prompts()

        if self.enable_hr:
            self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]

            self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

        return res

1350

1351
@dataclass(repr=False)
1352
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379
    init_images: list = None
    resize_mode: int = 0
    denoising_strength: float = 0.75
    image_cfg_scale: float = None
    mask: Any = None
    mask_blur_x: int = 4
    mask_blur_y: int = 4
    mask_blur: int = None
    inpainting_fill: int = 0
    inpaint_full_res: bool = True
    inpaint_full_res_padding: int = 0
    inpainting_mask_invert: int = 0
    initial_noise_multiplier: float = None
    latent_mask: Image = None

    image_mask: Any = field(default=None, init=False)

    nmask: torch.Tensor = field(default=None, init=False)
    image_conditioning: torch.Tensor = field(default=None, init=False)
    init_img_hash: str = field(default=None, init=False)
    mask_for_overlay: Image = field(default=None, init=False)
    init_latent: torch.Tensor = field(default=None, init=False)

    def __post_init__(self):
        super().__post_init__()

        self.image_mask = self.mask
1380
        self.mask = None
1381
        self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
1382

1383 1384 1385 1386 1387 1388 1389 1390
    @property
    def mask_blur(self):
        if self.mask_blur_x == self.mask_blur_y:
            return self.mask_blur_x
        return None

    @mask_blur.setter
    def mask_blur(self, value):
1391 1392 1393
        if isinstance(value, int):
            self.mask_blur_x = value
            self.mask_blur_y = value
1394

A
AUTOMATIC 已提交
1395
    def init(self, all_prompts, all_seeds, all_subseeds):
1396 1397
        self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

1398
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
1399 1400
        crop_region = None

1401
        image_mask = self.image_mask
A
AUTOMATIC 已提交
1402

1403
        if image_mask is not None:
1404 1405 1406
            # image_mask is passed in as RGBA by Gradio to support alpha masks,
            # but we still want to support binary masks.
            image_mask = create_binary_mask(image_mask)
A
AUTOMATIC 已提交
1407

1408 1409
            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)
A
AUTOMATIC 已提交
1410

1411 1412
            if self.mask_blur_x > 0:
                np_mask = np.array(image_mask)
1413
                kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
1414 1415 1416 1417 1418
                np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
                image_mask = Image.fromarray(np_mask)

            if self.mask_blur_y > 0:
                np_mask = np.array(image_mask)
1419
                kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
1420 1421
                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
                image_mask = Image.fromarray(np_mask)
1422 1423

            if self.inpaint_full_res:
1424 1425
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
1426
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
1427
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
1428 1429 1430
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
1431
                image_mask = images.resize_image(2, mask, self.width, self.height)
1432 1433
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
1434
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
1435
                np_mask = np.array(image_mask)
J
JJ 已提交
1436
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
1437
                self.mask_for_overlay = Image.fromarray(np_mask)
1438 1439 1440

            self.overlay_images = []

1441
        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
1442

1443 1444 1445
        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
1446 1447
        imgs = []
        for img in self.init_images:
1448 1449 1450 1451 1452 1453

            # Save init image
            if opts.save_init_img:
                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)

1454
            image = images.flatten(img, opts.img2img_background_color)
1455

A
Andrew Ryan 已提交
1456
            if crop_region is None and self.resize_mode != 3:
1457
                image = images.resize_image(self.resize_mode, image, self.width, self.height)
1458

1459
            if image_mask is not None:
1460 1461 1462 1463 1464
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

1465
            # crop_region is not None if we are doing inpaint full res
1466 1467 1468 1469
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

1470
            if image_mask is not None:
1471
                if self.inpainting_fill != 1:
1472
                    image = masking.fill(image, latent_mask)
1473

1474
            if add_color_corrections:
1475 1476
                self.color_corrections.append(setup_color_correction(image))

1477 1478 1479 1480 1481 1482 1483 1484 1485
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
1486 1487 1488 1489

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

1490 1491 1492 1493 1494 1495 1496
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
K
Kohaku-Blueleaf 已提交
1497
        image = image.to(shared.device, dtype=devices.dtype_vae)
1498 1499 1500 1501

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

K
Kohaku-Blueleaf 已提交
1502
        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
1503
        devices.torch_gc()
1504

1505 1506
        if self.resize_mode == 3:
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
A
Andrew Ryan 已提交
1507

1508
        if image_mask is not None:
1509
            init_mask = latent_mask
A
AUTOMATIC 已提交
1510
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
A
AUTOMATIC 已提交
1511
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
1512
            latmask = latmask[0]
1513
            latmask = np.around(latmask)
1514 1515 1516 1517 1518
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

A
AUTOMATIC 已提交
1519
            # this needs to be fixed to be done in sample() using actual seeds for batches
1520
            if self.inpainting_fill == 2:
A
AUTOMATIC 已提交
1521
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
1522 1523 1524
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

1525
        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
1526

1527
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
1528
        x = self.rng.next()
1529 1530 1531 1532

        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier
A
AUTOMATIC 已提交
1533

1534
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
1535 1536 1537 1538

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

1539 1540 1541
        del x
        devices.torch_gc()

1542
        return samples
1543 1544 1545

    def get_token_merging_ratio(self, for_hr=False):
        return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio