processing.py 42.7 KB
Newer Older
1 2 3 4
import json
import math
import os
import sys
5
import warnings
6 7 8 9 10

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
11 12
import cv2
from skimage import exposure
A
arcticfaded 已提交
13
from typing import Any, Dict, List, Optional
14

15
import modules.sd_hijack
16
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
17 18 19
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
A
AUTOMATIC 已提交
20
import modules.face_restoration
21
import modules.images as images
A
AUTOMATIC 已提交
22
import modules.styles
23 24
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
R
Robin Fernandes 已提交
25
import logging
J
Jay Smith 已提交
26 27
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
28

J
Jay Smith 已提交
29
from einops import repeat, rearrange
30
from blendmodes.blend import blendLayers, BlendType
31

32 33 34 35 36
# Fixed model constants. They used to be user-facing options, but changing them
# would break the model, so they were removed from the options and pinned here.
# NOTE(review): presumably opt_C is the latent channel count and opt_f the
# image-to-latent downsampling factor; confirm against the code that uses them.
opt_C, opt_f = 4, 8


37
def setup_color_correction(image):
    """Capture the colour reference later consumed by ``apply_color_correction``.

    The image is copied, turned into a numpy array and converted from RGB to
    the LAB colour space with OpenCV; that LAB array is returned.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)


43
def apply_color_correction(correction, original_image):
    """Re-colour ``original_image`` so its histogram matches ``correction``.

    Matching is done per channel in LAB space (``correction`` is the LAB array
    produced by ``setup_color_correction``). The matched image is converted
    back to RGB and then blended with the original using the LUMINOSITY blend
    mode. NOTE(review): presumably this keeps the corrected colours while taking
    luminosity from the original; confirm against the blendmodes docs.
    """
    logging.info("Applying color correction.")

    source_lab = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(source_lab, correction, channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")
    corrected = Image.fromarray(matched_rgb)

    return blendLayers(corrected, original_image, BlendType.LUMINOSITY)

A
AUTOMATIC 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74

def apply_overlay(image, paste_loc, index, overlays):
    """Alpha-composite ``overlays[index]`` on top of ``image`` and return an RGB image.

    If ``overlays`` is None or has no entry at ``index``, ``image`` is returned
    unchanged. When ``paste_loc`` is an ``(x, y, w, h)`` tuple, ``image`` is first
    resized to ``w`` x ``h`` (via ``images.resize_image`` with mode 1) and pasted
    at ``(x, y)`` onto a transparent canvas the size of the overlay.
    """
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        x, y, w, h = paste_loc
        canvas = Image.new('RGBA', (overlay.width, overlay.height))
        canvas.paste(images.resize_image(1, image, w, h), (x, y))
        image = canvas

    composed = image.convert('RGBA')
    composed.alpha_composite(overlay)
    return composed.convert('RGB')
77

F
frostydad 已提交
78

A
arcticfaded 已提交
79 80 81 82
class StableDiffusionProcessing:
    """Parameters and shared helpers for one txt2img/img2img generation job.

    The constructor arguments from ``sd_model`` through
    ``do_not_reload_embeddings`` are the minimum needed to build an instance;
    the remaining ones tune sampling. ``sampler_index`` is deprecated and only
    prints a warning.

    ``self.sampler`` is read by several methods but is not set by this
    constructor; it must be provided by a subclass or by the caller.
    ``sample`` must be overridden.
    """

    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        # Model and output locations.
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids

        # Prompts.
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = negative_prompt or ""
        self.styles: list = styles or []

        # Seeds. With seed extras disabled, the variation seed and the
        # seed-resize settings are forced to neutral values.
        self.seed: int = seed
        if seed_enable_extras:
            self.subseed: int = subseed
            self.subseed_strength: float = subseed_strength
            self.seed_resize_from_h: int = seed_resize_from_h
            self.seed_resize_from_w: int = seed_resize_from_w
        else:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        # Sampling and image geometry.
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.denoising_strength: float = denoising_strength
        self.eta = eta

        # Post-processing and saving.
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None

        # Sampler fine-tuning; falsy arguments fall back to the global options.
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise

        # Per-job option overrides; keys listed in shared.restricted_opts are dropped.
        requested_overrides = override_settings or {}
        self.override_settings = {key: value for key, value in requested_overrides.items() if key not in shared.restricted_opts}
        self.override_settings_restore_afterwards = override_settings_restore_afterwards

        # Populated later: the conditioning helpers below set
        # is_using_inpainting_conditioning, and the processing loop fills all_*.
        self.is_using_inpainting_conditioning = False
        self.scripts = None
        self.script_args = None
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Build the image-conditioning tensor for a txt2img batch ``x``.

        For samplers whose conditioning key is neither 'hybrid' nor 'concat',
        a cheap ``(batch, 5, 1, 1)`` zero tensor is returned. Otherwise an
        all-zero "masked image" of size ``height`` x ``width`` (default: the
        job's size) is VAE-encoded, and a constant-1 mask channel is prepended.
        The result is cast to ``x.dtype``.
        """
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Placeholder for non-inpainting models: no encoder call, and only
            # the batch dimension matters, so spatial size is 1x1.
            return x.new_zeros(x.shape[0], 5, 1, 1)

        self.is_using_inpainting_conditioning = True

        target_h = height or self.height
        target_w = width or self.width

        # Everything is masked, so the masked image is simply all zeros.
        blank = torch.zeros(x.shape[0], 3, target_h, target_w, device=x.device)
        encoded = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(blank))

        # The 5th/6th pad entries apply to dim -3: prepend one channel filled with 1.0 (the full mask).
        with_mask = torch.nn.functional.pad(encoded, (0, 0, 0, 0, 1, 0), value=1.0)
        return with_mask.to(x.dtype)

    def depth2img_image_conditioning(self, source_image):
        """Return a depth map for ``source_image`` normalised to [-1, 1].

        The first image of the batch is prepared with the MiDaS helper, repeated
        ``batch_size`` times, and run through ``sd_model.depth_model``. The
        result is resized (bicubic) to the spatial size of the VAE encoding
        of ``source_image``.
        """
        midas_helper = AddMiDaS(model_type="dpt_hybrid")
        first_image_hwc = rearrange(source_image[0], "c h w -> h w c")
        prepared = midas_helper({"jpg": first_image_hwc})
        midas_in = torch.from_numpy(prepared["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        # Only the spatial size of the encoding is used, as the interpolation target.
        encoded_source = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
        depth = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=encoded_source.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        lowest, highest = torch.aminmax(depth)
        return 2. * (depth - lowest) / (highest - lowest) - 1.

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
        """Build the ``c_concat`` tensor for inpainting models: [mask, encoded masked image].

        ``image_mask`` may be None (everything masked), a tensor (used as-is),
        or an image-like object with ``.convert("L")`` (scaled to [0, 1] and
        rounded to a hard 0/1 mask). The source image is linearly blended
        toward its masked version by ``inpainting_mask_weight`` before encoding.
        """
        self.is_using_inpainting_conditioning = True

        if image_mask is None:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
        elif torch.is_tensor(image_mask):
            conditioning_mask = image_mask
        else:
            mask_pixels = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
            # The inpainting model takes a discretized mask, so round to exactly 0.0 or 1.0.
            conditioning_mask = torch.round(torch.from_numpy(mask_pixels[None, None]))

        conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)

        # Blend between the original and the masked image; a per-job weight wins over the global option.
        masked_source = source_image * (1.0 - conditioning_mask)
        mask_weight = getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        blended = torch.lerp(source_image, masked_source, mask_weight)

        encoded = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(blended))

        latent_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        latent_mask = latent_mask.expand(encoded.shape[0], -1, -1, -1)
        return torch.cat([latent_mask, encoded], dim=1).to(shared.device).type(self.sd_model.dtype)

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Dispatch to depth, inpainting, or placeholder image conditioning."""
        # HACK: the depth model has no distinguishing field shared by all models
        # (its conditioning_key is also 'hybrid'), so detect it by type.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Placeholder zero conditioning for models that are neither inpainting nor depth.
            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

        return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Hook run once before sampling starts; a no-op in the base class."""

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Produce latents for one batch; must be overridden by subclasses."""
        raise NotImplementedError()

    def close(self):
        """Drop the references to the model and the sampler."""
        self.sd_model = None
        self.sampler = None

240 241

class Processed:
    """Result of a generation job: the images plus a snapshot of the settings used.

    Most fields are copied from the processing object at construction time.
    ``js()`` serialises them to JSON, and ``infotext()`` rebuilds the
    parameters text for one image.
    """

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        def first_if_list(value):
            # Prompts and seeds may be given per image as lists; keep only the first entry.
            return value[0] if type(value) == list else value

        self.images = images_list

        # Prompt and seed summary (first image wins when lists were given).
        self.prompt = first_if_list(p.prompt)
        self.negative_prompt = first_if_list(p.negative_prompt)
        self.seed = -1 if seed is None else int(first_if_list(seed))
        self.subseed = -1 if subseed is None else int(first_if_list(subseed))
        self.subseed_strength = p.subseed_strength

        self.info = info
        self.comments = comments

        # Generation settings copied from the processing object.
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        # Sampler fine-tuning.
        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        # Per-image lists: explicit argument, else the job's list, else a one-element fallback.
        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Return the result's settings as a JSON string."""
        payload = dict(
            prompt=self.all_prompts[0],
            all_prompts=self.all_prompts,
            negative_prompt=self.all_negative_prompts[0],
            all_negative_prompts=self.all_negative_prompts,
            seed=self.seed,
            all_seeds=self.all_seeds,
            subseed=self.subseed,
            all_subseeds=self.all_subseeds,
            subseed_strength=self.subseed_strength,
            width=self.width,
            height=self.height,
            sampler_name=self.sampler_name,
            cfg_scale=self.cfg_scale,
            steps=self.steps,
            batch_size=self.batch_size,
            restore_faces=self.restore_faces,
            face_restoration_model=self.face_restoration_model,
            sd_model_hash=self.sd_model_hash,
            seed_resize_from_w=self.seed_resize_from_w,
            seed_resize_from_h=self.seed_resize_from_h,
            denoising_strength=self.denoising_strength,
            extra_generation_params=self.extra_generation_params,
            index_of_first_image=self.index_of_first_image,
            infotexts=self.infotexts,
            styles=self.styles,
            job_timestamp=self.job_timestamp,
            clip_skip=self.clip_skip,
            is_using_inpainting_conditioning=self.is_using_inpainting_conditioning,
        )
        return json.dumps(payload)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Rebuild the parameters text for the image at flat position ``index``."""
        iteration, position = divmod(index, self.batch_size)
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=position, iteration=iteration)


326 327 328 329
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate from ``low`` (``val=0``) to ``high`` (``val=1``).

    Vectors are taken along dim 1: norms and dot products are computed over
    dim 1, and the interpolation weights are broadcast back with
    ``unsqueeze(1)``.

    When the inputs are nearly parallel on average (mean cosine > 0.9995),
    ``sin(omega)`` is close to zero and the spherical formula is unstable, so
    a plain linear interpolation is used. It has the same orientation as the
    spherical branch: ``val=0`` gives ``low`` and ``val=1`` gives ``high``.

    Adapted from
    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)

    if dot.mean() > 0.9995:
        # Linear fallback, weighted consistently with the spherical branch below.
        return low * (1 - val) + high * val

    # Rounding can push a cosine slightly outside [-1, 1], which would make acos return NaN.
    omega = torch.acos(dot.clamp(-1.0, 1.0))
    so = torch.sin(omega)
    return (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
339

340

341
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Create the starting noise for a batch, one slice per seed, on ``shared.device``.

    Args:
        shape: per-image noise shape, indexed here as ``shape[0..2]``.
        seeds: one seed per image.
        subseeds: optional variation seeds; positions past its end use 0.
        subseed_strength: ``slerp`` weight between the seed noise and the subseed noise.
        seed_resize_from_h, seed_resize_from_w: when both are > 0, noise is drawn at
            ``(shape[0], h // 8, w // 8)`` and then centred/cropped into a fresh
            ``shape``-sized tensor drawn from the same seed.
        p: processing object. When it has a sampler and either batch seeds
            (several seeds plus ``opts.enable_batch_seeds``) or ENSD are in use,
            per-seed sampler noises are also generated and stored on
            ``p.sampler.sampler_noises``.
    """
    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0

    # With several seeds in one batch, also pre-generate the extra noises the
    # sampler needs while stepping. Using these instead of a plain torch.randn
    # lets a batch with seeds [100, 101] reproduce two batches [100] and [101].
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    per_seed_noise = []

    for i, seed in enumerate(seeds):
        if seed_resize_from_h <= 0 or seed_resize_from_w <= 0:
            noise_shape = shape
        else:
            noise_shape = (shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8)

        subnoise = None
        if subseeds is not None:
            subseed = subseeds[i] if i < len(subseeds) else 0
            subnoise = devices.randn(subseed, noise_shape)

        # randn output depends on the device (GPU and CPU differ for the same
        # seed). Generating on CPU would be more portable, but changing it now
        # would break everyone's existing seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            canvas = devices.randn(seed, shape)
            offset_x = (shape[2] - noise_shape[2]) // 2
            offset_y = (shape[1] - noise_shape[1]) // 2
            # Positive offset: target is larger, so paste the noise centred.
            # Negative offset: noise is larger, so crop it symmetrically.
            copy_w = noise_shape[2] + 2 * min(offset_x, 0)
            copy_h = noise_shape[1] + 2 * min(offset_y, 0)
            dst_x, dst_y = max(offset_x, 0), max(offset_y, 0)
            src_x, src_y = max(-offset_x, 0), max(-offset_y, 0)
            canvas[:, dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = noise[:, src_y:src_y + copy_h, src_x:src_x + copy_w]
            noise = canvas

        if sampler_noises is not None:
            noise_count = p.sampler.number_of_needed_noises(p)

            if eta_noise_seed_delta > 0:
                torch.manual_seed(seed + eta_noise_seed_delta)

            for slot in range(noise_count):
                sampler_noises[slot].append(devices.randn_without_seed(tuple(noise_shape)))

        per_seed_noise.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    return torch.stack(per_seed_noise).to(shared.device)


A
AUTOMATIC 已提交
404 405 406 407 408 409 410
def decode_first_stage(model, x):
    """Decode latents ``x`` with ``model.decode_first_stage``.

    Autocast is disabled when ``x`` already has the VAE dtype (``devices.dtype_vae``).
    """
    already_vae_dtype = x.dtype == devices.dtype_vae
    with devices.autocast(disable=already_vae_dtype):
        decoded = model.decode_first_stage(x)

    return decoded


411 412 413 414 415 416 417
def get_fixed_seed(seed):
    """Return ``seed``, or a fresh random seed if it means "random".

    ``None``, ``''`` and ``-1`` mean "random" and are replaced by an integer
    in ``[0, 4294967294)``. Any other value, including lists, is returned as-is.
    """
    wants_random = seed is None or seed == '' or seed == -1
    if not wants_random:
        return seed

    return int(random.randrange(4294967294))


418
def fix_seed(p):
    """Resolve ``p.seed`` and then ``p.subseed`` in place via ``get_fixed_seed``.

    Placeholder "random" values become concrete seeds; other values are kept.
    """
    for attribute in ('seed', 'subseed'):
        setattr(p, attribute, get_fixed_seed(getattr(p, attribute)))
A
AUTOMATIC 已提交
421 422


423 424 425
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    """Build the generation-parameters text for one image.

    The text has these parts:
    - the prompt;
    - an optional ``Negative prompt:`` line;
    - a final line of ``key: value`` pairs joined with ", ".

    Entries whose value is ``None`` are omitted. Entries whose value equals
    their key are written as the bare key. ``p.extra_generation_params``
    overrides or extends the built-in entries.

    Args:
        p: processing object the image came from.
        all_prompts, all_seeds, all_subseeds: per-image lists for the whole job.
        comments: accepted for interface compatibility; not used here.
        iteration: batch number.
        position_in_batch: index of the image within its batch.
    """
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    hypernetwork = shared.loaded_hypernetwork
    batched = p.batch_size >= 2
    has_variation = p.subseed_strength != 0

    # Fallback when p carries no sd_model_hash of its own.
    if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash:
        default_model_hash = None
    else:
        default_model_hash = shared.sd_model.sd_model_hash

    if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name:
        model_name = None
    else:
        # Commas and colons would clash with the "key: value, ..." format.
        model_name = shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": opts.face_restoration_model if p.restore_faces else None,
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', default_model_hash),
        "Model": model_name,
        "Hypernet": None if hypernetwork is None else hypernetwork.name,
        "Hypernet hash": None if hypernetwork is None else sd_models.model_hash(hypernetwork.filename),
        "Hypernet strength": None if hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength,
        "Batch size": p.batch_size if batched else None,
        "Batch pos": position_in_batch if batched else None,
        "Variation seed": all_subseeds[index] if has_variation else None,
        "Variation seed strength": p.subseed_strength if has_variation else None,
        "Seed resize from": None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}",
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Eta": None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }
    generation_params.update(p.extra_generation_params)

    param_chunks = []
    for key, value in generation_params.items():
        if value is None:
            continue
        param_chunks.append(key if key == value else f'{key}: {generation_parameters_copypaste.quote(value)}')
    generation_params_text = ", ".join(param_chunks)

    negative_prompt = p.all_negative_prompts[index]
    negative_prompt_text = "\nNegative prompt: " + negative_prompt if negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
459 460


461
def _apply_option_with_onchange(key, value):
    """Set ``opts.<key> = value`` and run the reload that option needs, if any.

    Only the three options below get a manual "onchange" call here. Every
    other option is a plain attribute write.
    """
    setattr(opts, key, value)
    if key == 'sd_hypernetwork':
        shared.reload_hypernetworks()  # make onchange call for changing hypernet
    elif key == 'sd_model_checkpoint':
        sd_models.reload_model_weights()  # make onchange call for changing SD model
    elif key == 'sd_vae':
        sd_vae.reload_vae_weights()  # make onchange call for changing VAE


def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run ``process_images_inner(p)`` with ``p.override_settings`` temporarily applied.

    The current values of the overridden options are recorded first. Each
    override is applied through ``_apply_option_with_onchange``, the same
    helper used for restoring, so the apply path and the restore path cannot
    drift apart. If ``p.override_settings_restore_afterwards`` is set, the
    recorded values are restored afterwards, even when processing raises.

    Returns:
        The ``Processed`` result of ``process_images_inner``.
    """
    # Read before any change; a key missing from opts.data raises KeyError here, before anything is applied.
    stored_opts = {key: opts.data[key] for key in p.override_settings}

    try:
        for key, value in p.override_settings.items():
            _apply_option_with_onchange(key, value)

        res = process_images_inner(p)

    finally:
        # Restore opts to their original state.
        if p.override_settings_restore_afterwards:
            for key, value in stored_opts.items():
                _apply_option_with_onchange(key, value)

    return res


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """Main generation loop shared by txt2img and img2img.

    The loop proceeds as follows:
    1. Expand prompts, negative prompts, seeds and subseeds to per-image lists on ``p``.
    2. Write ``params.txt``.
    3. Call ``p.init`` once inside the no-grad / EMA / autocast scopes.
    4. Call ``p.sample`` once per batch (``p.n_iter`` batches of ``p.batch_size``).
    5. For each image: decode, optionally face-restore, colour-correct, overlay, save and collect it.
    6. Optionally build a grid and return a ``Processed`` with the results.
    """
    if type(p.prompt) is list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    # Used as an insertion-ordered set of model_hijack comments seen during the job.
    comments = {}

    # Expand prompts (with styles applied) to one entry per image.
    prompt_styles = shared.prompt_styles
    if type(p.prompt) is list:
        p.all_prompts = [prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
    else:
        p.all_prompts = p.batch_size * p.n_iter * [prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

    if type(p.negative_prompt) is list:
        p.all_negative_prompts = [prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
    else:
        p.all_negative_prompts = p.batch_size * p.n_iter * [prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]

    # Seeds count up per image unless a variation strength is set, in which
    # case every image shares the base seed. Subseeds always count up.
    prompt_count = len(p.all_prompts)
    if type(seed) is list:
        p.all_seeds = seed
    else:
        seed_step = 1 if p.subseed_strength == 0 else 0
        p.all_seeds = [int(seed) + seed_step * offset for offset in range(prompt_count)]

    if type(subseed) is list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + offset for offset in range(prompt_count)]

    def build_infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)

    params_path = os.path.join(shared.script_path, "params.txt")
    with open(params_path, "w", encoding="utf8") as params_file:
        params_file.write(Processed(p, [], p.seed, "").infotext(p, 0))

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            batch_start, batch_end = n * p.batch_size, (n + 1) * p.batch_size
            prompts = p.all_prompts[batch_start:batch_end]
            negative_prompts = p.all_negative_prompts[batch_start:batch_end]
            seeds = p.all_seeds[batch_start:batch_end]
            subseeds = p.all_subseeds[batch_start:batch_end]

            if not prompts:
                break

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)

            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

            for comment in model_hijack.comments:
                comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

            # Decode one latent at a time onto the CPU, then map from [-1, 1] to [0, 1].
            decoded = [decode_first_stage(p.sd_model, samples_ddim[k:k+1].to(dtype=devices.dtype_vae))[0].cpu() for k in range(samples_ddim.size(0))]
            x_samples_ddim = torch.clamp((torch.stack(decoded).float() + 1.0) / 2.0, min=0.0, max=1.0)
            del decoded, samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

            for i, x_sample in enumerate(x_samples_ddim):
                # CHW float in [0, 1] -> HWC uint8.
                pixels = (255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)).astype(np.uint8)

                if p.restore_faces:
                    # NOTE(review): this gate uses `opts.save`, while the final save below uses
                    # `opts.samples_save`; confirm `opts.save` is the intended option.
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(pixels), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=build_infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()
                    pixels = modules.face_restoration.restore_faces(pixels)
                    devices.torch_gc()

                image = Image.fromarray(pixels)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=build_infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=build_infotext(n, i), p=p)

                text = build_infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        too_few_for_grid = len(output_images) < 2 and opts.grid_only_if_multiple
        wants_grid = (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not too_few_for_grid
        if wants_grid:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = build_infotext()
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=build_infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()

    res = Processed(p, output_images, p.all_seeds[0], build_infotext(), comments="".join("\n\n" + x for x in comments), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
656 657 658 659


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """Text-to-image job with an optional second "hires fix" pass.

    When ``enable_hr`` is set, ``sample`` first generates latents at
    ``self.width`` x ``self.height``. It then upscales them by ``hr_scale``,
    either by interpolating in latent space or by decoding, running a named
    image upscaler and re-encoding. Finally it runs an img2img sampling pass
    over the upscaled latents.
    """

    sampler = None

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: Optional[str] = None, **kwargs):
        """Create a txt2img job.

        Args:
            enable_hr: run the second (hires) sampling pass in ``sample``.
            denoising_strength: stored on the instance. It is not read in this
                class; presumably the img2img sampler reads it from ``self``.
            firstphase_width, firstphase_height: deprecated. When either is
                non-zero, it becomes the first-pass size, and ``hr_scale`` is
                recomputed from the width/height already set by the base
                class. A dimension left at 0 is derived from that scale.
            hr_scale: factor applied to width/height for the hires pass.
            hr_upscaler: a key of ``shared.latent_upscale_modes`` (latent
                interpolation) or the name of an image upscaler. ``None``
                selects the default latent upscale mode.
            **kwargs: forwarded to ``StableDiffusionProcessing.__init__``.
        """
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.hr_scale = hr_scale
        self.hr_upscaler = hr_upscaler

        if firstphase_width != 0 or firstphase_height != 0:
            print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
            # Derive the scale from whichever first-pass dimension was given.
            # Dividing by firstphase_width unconditionally raised
            # ZeroDivisionError when only firstphase_height was supplied.
            if firstphase_width != 0:
                self.hr_scale = self.width / firstphase_width
            else:
                self.hr_scale = self.height / firstphase_height
            # A dimension left at 0 is derived from the final size and the
            # scale, so the first pass never runs with a zero width/height.
            self.width = firstphase_width if firstphase_width != 0 else int(self.width / self.hr_scale)
            self.height = firstphase_height if firstphase_height != 0 else int(self.height / self.hr_scale)

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Per-job setup.

        With hires enabled, ``sample`` runs two sampling passes per batch (it
        calls ``state.nextjob()`` between them), so the job count is doubled.
        The hires settings are also recorded in the generation parameters.
        """
        if self.enable_hr:
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            self.extra_generation_params["Hires upscale"] = self.hr_scale
            if self.hr_upscaler is not None:
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Sample one batch of latents, followed by the hires pass if it is enabled.

        Returns the latents from the first pass when ``enable_hr`` is False.
        Otherwise it returns the latents from the img2img pass at
        ``int(width * hr_scale)`` x ``int(height * hr_scale)``.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        # A non-None mode means the upscale happens by latent interpolation;
        # None means hr_upscaler must name an image upscaler.
        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
        if self.enable_hr and latent_scale_mode is None:
            assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))

        if not self.enable_hr:
            return samples

        target_width = int(self.width * self.hr_scale)
        target_height = int(self.height * self.hr_scale)

        def save_intermediate(image, index):
            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index)

            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")

        if latent_scale_mode is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            # Image-space upscale: decode, upscale each image, then re-encode.
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)

        return samples
762 763 764 765 766


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Image-to-image / inpainting job.

    ``init`` encodes the init images into ``self.init_latent`` and, when a
    mask is given, builds the latent-space masks ``self.mask`` /
    ``self.nmask``. ``sample`` runs the img2img sampler from that latent.
    Where a mask is present, it blends the original latent back in outside
    the masked area.
    """

    sampler = None

    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
        """Store img2img / inpainting settings; the heavy work happens in ``init``.

        Args:
            init_images: source images. ``init`` iterates over them and
                passes each to ``images.flatten``.
            resize_mode: passed to ``images.resize_image``. The value 3 skips
                pixel resizing and instead bilinearly interpolates the
                encoded latent to the target size.
            denoising_strength: stored on the instance; presumably read by
                the sampler.
            mask: optional inpainting mask image. It is converted to
                grayscale ('L') in ``init``.
            mask_blur: Gaussian blur radius applied to the mask when > 0.
            inpainting_fill: 1 leaves the masked area untouched before
                encoding. Any other value prefills it with ``masking.fill``.
                Additionally, 2 replaces the masked latent with random noise
                and 3 zeroes the masked latent.
            inpaint_full_res: crop around the mask, process the crop at full
                ``width`` x ``height``, and record ``paste_to`` for compositing.
            inpaint_full_res_padding: padding passed to
                ``masking.get_crop_region``.
            inpainting_mask_invert: truthy to invert the mask.
            initial_noise_multiplier: scale for the initial noise. ``None``
                falls back to ``opts.initial_noise_multiplier``.
            **kwargs: forwarded to ``StableDiffusionProcessing.__init__``.
        """
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        # If set externally, overrides the processed image mask for latent masking / fill.
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
        self.mask = None
        self.nmask = None
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare the mask, overlays, color corrections and ``self.init_latent``.

        Raises:
            RuntimeError: if more than one init image is given and their
                count exceeds ``batch_size``.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        # Set only in the inpaint-full-res branch below; later branches key off it.
        crop_region = None

        image_mask = self.image_mask

        if image_mask is not None:
            image_mask = image_mask.convert('L')

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)

            if self.mask_blur > 0:
                image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                # The overlay keeps the full-size blurred mask; the working mask is cropped.
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                # (x, y, width, height) of the crop in the original image.
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask)
                # Doubling (clipped to 255) strengthens partially-masked pixels in the overlay mask.
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        # Only collect corrections here if the caller did not provide them.
        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = images.flatten(img, opts.img2img_background_color)

            # resize_mode 3 defers resizing to latent interpolation after encoding.
            if crop_region is None and self.resize_mode != 3:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                # Keep the unmasked part of the source image for compositing over the result.
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            # HWC uint8 -> CHW float in [0, 1].
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # A single init image is repeated to fill the whole batch.
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            # NOTE(review): assumes the base class initializes overlay_images (None when no mask) — confirm.
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            # Fewer images than batch_size shrinks the batch to the image count.
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        # Scale pixels from [0, 1] to [-1, 1] before VAE encoding.
        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.resize_mode == 3:
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        if image_mask is not None:
            init_mask = latent_mask
            # Downscale the mask to the latent's spatial size and take one channel.
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            # Binarize: 1 where the latent will be regenerated, 0 where it is kept.
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            # mask = keep-original region, nmask = regenerate region (see sample()).
            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run img2img sampling from ``self.init_latent`` and return the latents.

        When a mask was prepared in ``init``, the result keeps
        ``self.init_latent`` wherever ``self.mask`` is 1.
        """
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        # Release the noise tensor before returning to reduce peak memory.
        del x
        devices.torch_gc()

        return samples