# processing.py
import json
import math
import os
import sys
import random
from typing import Any, Dict, List, Optional

import torch
import numpy as np
import cv2
from PIL import Image, ImageFilter, ImageOps
from skimage import exposure

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
import logging


# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8


def setup_color_correction(image):
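    """Convert the reference image to CIELAB and keep it as the target for later histogram matching."""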
    logging.info("Calibrating color correction.")
    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, image):
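    """Match the image's histogram to the stored CIELAB correction target and convert back to RGB; keeps img2img results close to the source image's colors."""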
    logging.info("Applying color correction.")
    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
        cv2.cvtColor(
            np.asarray(image),
            cv2.COLOR_RGB2LAB
        ),
        correction,
        channel_axis=2
    ), cv2.COLOR_LAB2RGB).astype("uint8"))

    return image


def apply_overlay(image, paste_loc, index, overlays):
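    """Paste the generated image back into its original location (if paste_loc is set) and alpha-composite the stored overlay on top; returns the image unchanged when no overlay matches the index."""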
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = images.resize_image(1, image, w, h)
        base_image.paste(image, (x, y))
        image = base_image

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image

def get_correct_sampler(p):
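    """Return the sampler list that applies to the given processing object."""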
    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
        return sd_samplers.samplers
    elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
        return sd_samplers.samplers_for_img2img
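    # NOTE: modules.api.processing is not imported at the top of this file; this branch
    # assumes the API layer has already imported it (importing it here would likely be circular).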
    elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
        return sd_samplers.samplers

class StableDiffusionProcessing:
    """
    The first set of parameters, sd_model through do_not_reload_embeddings, represents the minimum required to create a StableDiffusionProcessing object.
    """
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = None
        self.all_prompts = None
        self.all_seeds = None
        self.all_subseeds = None

    def txt2img_image_conditioning(self, x, width=None, height=None):
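        """Build the image conditioning for inpainting-style models (conditioning_key 'hybrid' or 'concat'): an all-ones mask over an all-zeros masked-image latent. Plain models get a cheap dummy tensor whose only meaningful dimension is the batch size."""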
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            # Still takes up a bit of memory, but no encoder call.
            # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
            return torch.zeros(
                x.shape[0], 5, 1, 1, 
                dtype=x.dtype, 
                device=x.device
            )

        height = height or self.height
        width = width or self.width

        # The "masked-image" in this case will just be all zeros since the entire image is masked.
        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
        image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) 

        # Add the fake full 1s mask to the first dimension.
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)            

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
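        """Build the c_concat conditioning for inpainting-style models: a discretized mask concatenated with the latent of the masked source image. Plain models get a dummy zero tensor instead."""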
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            return torch.zeros(
                latent_image.shape[0], 5, 1, 1,
                dtype=latent_image.dtype,
                device=latent_image.device
            )

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image, weighted by the inpainting_mask_weight option.
        conditioning_mask = conditioning_mask.to(source_image.device)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )
        
        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        raise NotImplementedError()


class Processed:
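    """The output of a processing run: the generated images plus a snapshot of the parameters that produced them. js() serializes the snapshot for the UI and infotext() reproduces the generation-parameters text for a given image."""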
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = sd_samplers.samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1

        self.all_prompts = all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        obj = {
            "prompt": self.prompt,
            "all_prompts": self.all_prompts,
            "negative_prompt": self.negative_prompt,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_index": self.sampler_index,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
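    """Spherical linear interpolation between two noise tensors:
        res = sin((1 - val) * omega) / sin(omega) * low + sin(val * omega) / sin(omega) * high
    where omega is the angle between the normalized inputs. Nearly parallel inputs
    fall back to plain linear interpolation to avoid dividing by a vanishing sin(omega)."""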
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        return low * (1 - val) + high * val  # linear fallback for nearly parallel vectors; matches the slerp endpoints below

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
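    """Create one latent-space noise tensor per seed.

    Optionally slerps toward subseed noise (variation seeds), generates at the
    seed-resize resolution and pastes the result into the requested shape, and
    pre-generates the extra noises a sampler needs so that batched generation
    reproduces single-image results."""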
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and ((len(seeds) > 1 and opts.enable_batch_seeds) or opts.eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if opts.eta_noise_seed_delta > 0:
                torch.manual_seed(seed + opts.eta_noise_seed_delta)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x


def decode_first_stage(model, x):
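    """Decode latents into image space through the model's first stage (the VAE); autocast is disabled when the latents already match the VAE dtype."""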
    with devices.autocast(disable=x.dtype == devices.dtype_vae):
        x = model.decode_first_stage(x)

    return x


def get_fixed_seed(seed):
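    """Return the seed unchanged, or a random seed in [0, 2**32 - 2) when it is unset (None, empty string, or -1)."""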
    if seed is None or seed == '' or seed == -1:
        return int(random.randrange(4294967294))

    return seed


def fix_seed(p):
    p.seed = get_fixed_seed(p.seed)
    p.subseed = get_fixed_seed(p.subseed)


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
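    """Assemble the generation-parameters text embedded in saved images: the prompt, an optional negative prompt line, and a comma-separated list of 'Key: value' settings (entries whose value is None are omitted)."""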
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

    generation_params = {
        "Steps": p.steps,
        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()


def process_images(p: StableDiffusionProcessing) -> Processed:
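    """Apply p.override_settings on top of the shared opts, run process_images_inner, and restore the original settings afterwards.

    A rough usage sketch (assumes an initialized webui environment; the prompt and
    step count are illustrative):
        p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, prompt="a photo of a cat", steps=20)
        res = process_images(p)
        first_image = res.images[0]
    """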
    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

    try:
        for k, v in p.override_settings.items():
            opts.data[k] = v  # we don't call onchange for simplicity which makes changing model, hypernet impossible

        res = process_images_inner(p)

    finally:
        for k, v in stored_opts.items():
            opts.data[k] = v

    return res


def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    if type(p.prompt) == list:
        assert len(p.prompt) > 0
    else:
        assert p.prompt is not None

    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
        processed = Processed(p, [], p.seed, "")
        file.write(processed.infotext(p, 0))

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    comments = {}

    shared.prompt_styles.apply_styles(p)

    if type(p.prompt) == list:
        p.all_prompts = p.prompt
    else:
        p.all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if type(seed) == list:
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if type(subseed) == list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.skipped:
                state.skipped = False
            
            if state.interrupted:
                break

            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if len(prompts) == 0:
                break

            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)

            samples_ddim = samples_ddim.to(devices.dtype_vae)
            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim 

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext()
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()

    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.firstphase_width = firstphase_width
        self.firstphase_height = firstphase_height
        self.truncate_x = 0
        self.truncate_y = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
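        """When the highres fix is enabled, double the expected job count and work out the first-pass size: either the user-supplied firstphase dimensions, or an aspect-preserving size of roughly 512x512 pixels rounded up to multiples of 64. truncate_x/truncate_y record how much of the first-pass latent to crop afterwards."""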
        if self.enable_hr:
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"

            if self.firstphase_width == 0 or self.firstphase_height == 0:
                desired_pixel_count = 512 * 512
                actual_pixel_count = self.width * self.height
                scale = math.sqrt(desired_pixel_count / actual_pixel_count)
                self.firstphase_width = math.ceil(scale * self.width / 64) * 64
                self.firstphase_height = math.ceil(scale * self.height / 64) * 64
                firstphase_width_truncated = int(scale * self.width)
                firstphase_height_truncated = int(scale * self.height)

            else:

                width_ratio = self.width / self.firstphase_width
                height_ratio = self.height / self.firstphase_height

                if width_ratio > height_ratio:
                    firstphase_width_truncated = self.firstphase_width
                    firstphase_height_truncated = self.firstphase_width * self.height / self.width
                else:
                    firstphase_width_truncated = self.firstphase_height * self.width / self.height
                    firstphase_height_truncated = self.firstphase_height

            self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
            self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
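        """Sample the image at the target resolution, or, when the highres fix is enabled, sample at the first-pass size, crop and upscale (in latent space or via decoded images), then run img2img over the upscaled latents."""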
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            return samples

        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))

        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]

        if opts.use_scale_latent_for_hires_fix:
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)
                image = images.resize_image(0, image, self.width, self.height)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

        shared.state.nextjob()

        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        image_conditioning = self.img2img_image_conditioning(
            decoded_samples, 
            samples, 
            decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
        )
        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)

        return samples


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    sampler = None

    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
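        """Prepare everything sampling needs: resolve the mask (invert, blur, optional inpaint-at-full-resolution crop), build overlay images and color-correction targets, encode the init images into self.init_latent, and set up the latent mask and image conditioning."""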
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)


    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
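        """Denoise starting from self.init_latent with freshly created noise; when an inpainting mask is set, masked-off areas are restored from the original latents afterwards."""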
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples