processing.py 30.5 KB
Newer Older
1 2 3 4 5 6 7 8 9
import json
import math
import os
import sys

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random
10 11
import cv2
from skimage import exposure
12

13
import modules.sd_hijack
J
Jairo Correa 已提交
14
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
15 16 17
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
A
AUTOMATIC 已提交
18
import modules.face_restoration
19
import modules.images as images
A
AUTOMATIC 已提交
20
import modules.styles
R
Robin Fernandes 已提交
21
import logging
22

23

24 25 26 27 28
# some of those options should not be changed at all because they would break the model, so I removed them from options.
# opt_C is used below as the first dimension of the latent noise shape passed to create_random_tensors.
opt_C = 4
# opt_f is the divisor applied to image height/width to obtain the latent noise height/width.
opt_f = 8


29
def setup_color_correction(image):
    """Return the LAB-converted pixels of *image* for use as a histogram-matching target.

    The result is what apply_color_correction() expects as its `correction` argument.
    The image is copied before conversion, so the caller's image is not touched.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)


def apply_color_correction(correction, image):
    """Match the colour histogram of *image* to *correction* and return a new PIL image.

    `correction` is a LAB array as produced by setup_color_correction(). The matching
    is done per channel (channel_axis=2) in LAB space, then converted back to RGB.
    """
    logging.info("Applying color correction.")

    source_lab = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(source_lab, correction, channel_axis=2)
    corrected_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")

    image = Image.fromarray(corrected_rgb)
    return image


49
class StableDiffusionProcessing:
    """Parameter container for one generation job.

    process_images() drives an instance of this class: it calls init() once and
    sample() once per batch. The base class only stores settings; init() is a
    no-op and sample() must be provided by a subclass.
    """

    # Subclasses assign the active sampler here. It is defined on the base class
    # so that code reading p.sampler (create_random_tensors, create_infotext)
    # also works with a plain StableDiffusionProcessing instance.
    sampler = None

    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
        """Store the job settings.

        None values for negative_prompt, styles and extra_generation_params are
        replaced by "", [] and {} respectively. When seed_enable_extras is false,
        the variation-seed and seed-resize settings are reset to their disabled
        values regardless of what was passed in.
        """
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = 0
        self.sampler_noise_scheduler_override = None

        # sampler tuning values are snapshotted from the global options at construction time
        self.ddim_discretize = opts.ddim_discretize
        self.s_churn = opts.s_churn
        self.s_tmin = opts.s_tmin
        self.s_tmax = float('inf')  # not representable as a standard ui option
        self.s_noise = opts.s_noise

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Hook called once by process_images() before sampling; no-op in the base class."""
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        """Produce samples for one batch; must be overridden.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()


class Processed:
    """Result of a processing run: the output images plus the settings that produced them.

    Most attributes are copied from the StableDiffusionProcessing object `p`;
    js() serialises a subset of them to JSON.
    """

    @staticmethod
    def _first_if_list(value):
        """Return value[0] when *value* is a list, otherwise *value* unchanged."""
        return value[0] if isinstance(value, list) else value

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        """Capture settings from *p* and the per-run values passed by the caller.

        prompt, negative_prompt, seed and subseed may arrive as lists; the scalar
        attributes keep only the first element. A subseed of None becomes -1.
        all_prompts / all_seeds / all_subseeds / infotexts fall back to
        single-element lists built from the scalar values when empty or None.
        """
        self.images = images_list
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = sd_samplers.samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override

        # list-valued inputs are reduced to their first element for the scalar fields
        self.prompt = self._first_if_list(p.prompt)
        self.negative_prompt = self._first_if_list(p.negative_prompt)
        self.seed = int(self._first_if_list(seed))
        self.subseed = int(self._first_if_list(subseed)) if subseed is not None else -1

        self.all_prompts = all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Return a JSON string describing this result (images themselves are not included)."""
        obj = {
            "prompt": self.prompt,
            "all_prompts": self.all_prompts,
            "negative_prompt": self.negative_prompt,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_index": self.sampler_index,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_hash": self.sd_model_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
        }

        return json.dumps(obj)

    def infotext(self,  p: StableDiffusionProcessing, index):
        """Build the infotext for the image at flat position *index* (batch-major order)."""
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)


181 182 183 184
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate between tensors *low* and *high*.

    val=0 yields *low* and val=1 yields *high*. Norms and dot products are taken
    along dim 1, so the inputs need at least two dimensions.

    When the inputs are nearly parallel (mean cosine above 0.9995) sin(omega) is
    close to zero, so plain linear interpolation is used instead. That fallback
    uses the same orientation as the spherical branch.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)

    if dot.mean() > 0.9995:
        # weights must mirror the spherical branch: (1 - val) on low, val on high
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
194

195

196
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Create one seeded noise tensor per entry of *seeds* and return them stacked.

    shape: shape of a single noise tensor (without the batch dimension).
    subseeds / subseed_strength: when subseeds is given, each noise is slerp-blended
        with noise from the matching subseed (0 is used if subseeds is shorter than seeds).
    seed_resize_from_h / seed_resize_from_w: when both are positive, noise is generated
        at that size (divided by 8) and pasted, centred, into noise of the requested shape.
    p: optional processing object; when it has a sampler, there is more than one seed and
        opts.enable_batch_seeds is set, extra unseeded noise tensors are generated and
        stored in p.sampler.sampler_noises as a side effect.

    The result is moved to shared.device.
    """
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            # paste the (possibly cropped) resized-seed noise into the centre of full-size noise
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            # one extra noise per bucket allocated above; the bucket count is fixed
            # before the loop, so it is not re-queried from the sampler per seed
            for per_step_noises in sampler_noises:
                per_step_noises.append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x


255 256 257 258 259 260 261
def get_fixed_seed(seed):
    """Return *seed* unchanged, or a fresh random integer seed when it is unset.

    None, the empty string and -1 all count as "unset"; the replacement is drawn
    from range(4294967294).
    """
    is_unset = seed is None or seed == '' or seed == -1
    if not is_unset:
        return seed

    return int(random.randrange(4294967294))


262
def fix_seed(p):
    """Replace unset seed/subseed values on *p* in place with concrete ones (see get_fixed_seed)."""
    for field in ("seed", "subseed"):
        setattr(p, field, get_fixed_seed(getattr(p, field)))
A
AUTOMATIC 已提交
265 266


267 268 269 270 271
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    """Build the text block describing how one image was generated.

    The image is identified by (iteration, position_in_batch); its flat index into
    all_prompts / all_seeds / all_subseeds is position_in_batch + iteration * p.batch_size.
    The result is the prompt, an optional "Negative prompt:" line, and a comma-separated
    list of "key: value" parameters. Parameters whose value is None are omitted, and
    p.extra_generation_params can add to or override the built-in ones.

    `comments` is accepted for interface compatibility and is not used here.
    """
    index = position_in_batch + iteration * p.batch_size

    # not every processing object is guaranteed to carry a sampler attribute
    sampler = getattr(p, 'sampler', None)

    # seed resize is disabled for any non-positive size (the constructor default is -1),
    # matching the "<= 0" check in create_random_tensors
    seed_resize_disabled = p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0

    generation_params = {
        "Steps": p.steps,
        "Sampler": sd_samplers.samplers[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if seed_resize_disabled else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if sampler is None or sampler.eta == sampler.default_eta else sampler.eta),
    }

    generation_params.update(p.extra_generation_params)

    # an entry whose key equals its value is written as a bare flag rather than "key: value"
    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
294 295


296 297 298
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Main loop shared by txt2img and img2img.

    Calls p.init() once, then for each of p.n_iter batches computes the prompt
    conditioning, calls p.sample(), decodes the result to images and applies the
    optional post-processing steps (NSFW filter, face restoration, colour correction,
    overlay compositing). Images and an optional grid are saved according to the
    global options. Returns a Processed object holding the images and their infotexts.

    p.prompt may be a single string (repeated batch_size * n_iter times) or a
    non-empty list with one prompt per image.
    """

    if isinstance(p.prompt, list):
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    if p.outpath_samples is not None:
        os.makedirs(p.outpath_samples, exist_ok=True)

    if p.outpath_grids is not None:
        os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    # used as an insertion-ordered set of comment strings collected from model_hijack
    comments = {}

    shared.prompt_styles.apply_styles(p)

    if isinstance(p.prompt, list):
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if isinstance(seed, list):
        all_seeds = seed
    else:
        # with a variation seed active every image keeps the same base seed
        all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]

    if isinstance(subseed, list):
        all_subseeds = subseed
    else:
        all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    infotexts = []
    output_images = []

    with torch.no_grad():
        with devices.autocast():
            p.init(all_prompts, all_seeds, all_subseeds)

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            # a skip request only applies to the batch it was raised in
            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if (len(prompts) == 0):
                break

            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)

            if state.interrupted or state.skipped:

                # if we are interrupted, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent

            samples_ddim = samples_ddim.to(devices.dtype)

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            # map decoder output from [-1, 1] to [0, 1]
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if opts.filter_nsfw:
                # the alias is required: a plain "import modules.safety" here would make
                # `modules` a local name for the whole function
                import modules.safety as safety
                x_samples_ddim = safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    # NOTE(review): this checks `opts.save` while the final save below checks
                    # `opts.samples_save`; looks unintended but left unchanged — confirm against opts
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    # NOTE(review): same `opts.save` vs `opts.samples_save` question as above
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        # place the generated region back at its position on an overlay-sized canvas
                        x, y, w, h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, w, h)
                        base_image.paste(image, (x, y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                # the grid goes first, so individual images start at index 1
                text = infotext()
                infotexts.insert(0, text)
                grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
472 473 474 475


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """Text-to-image job, optionally run in two passes (enable_hr).

    With enable_hr the first pass samples at a reduced size of roughly 512*512
    pixels, the result is scaled up to the requested size, and a second
    img2img pass refines it.
    """
    sampler = None
    firstphase_width = 0
    firstphase_height = 0
    firstphase_width_truncated = 0
    firstphase_height_truncated = 0

    def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
        """Store the two-pass settings; every other keyword goes to the base class."""
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.scale_latent = scale_latent
        self.denoising_strength = denoising_strength

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare first-pass dimensions and the job count; does nothing unless enable_hr is set."""
        if not self.enable_hr:
            return

        # each batch is sampled twice, so the job count doubles
        if state.job_count == -1:
            state.job_count = self.n_iter * 2
        else:
            state.job_count = state.job_count * 2

        # scale the requested size so that its area becomes 512*512 pixels
        scale = math.sqrt((512 * 512) / (self.width * self.height))

        # sampling size is rounded up to a multiple of 64; the truncated size is the exact scaled size
        self.firstphase_width = math.ceil(scale * self.width / 64) * 64
        self.firstphase_height = math.ceil(scale * self.height / 64) * 64
        self.firstphase_width_truncated = int(scale * self.width)
        self.firstphase_height_truncated = int(scale * self.height)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        """Sample one batch and return the resulting latents.

        Without enable_hr this is a single pass at the requested size. With enable_hr
        the first-pass latents are cropped, upscaled (in latent space when scale_latent
        is set, otherwise in image space) and passed through sample_img2img.
        """
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        pass_width = self.firstphase_width if self.enable_hr else self.width
        pass_height = self.firstphase_height if self.enable_hr else self.height

        x = create_random_tensors([opt_C, pass_height // opt_f, pass_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)

        if not self.enable_hr:
            return samples

        # trim the padding that rounding the first-pass size up to a multiple of 64 introduced
        margin_x = ((self.firstphase_width - self.firstphase_width_truncated) // opt_f) // 2
        margin_y = ((self.firstphase_height - self.firstphase_height_truncated) // opt_f) // 2
        samples = samples[:, :, margin_y:samples.shape[2] - margin_y, margin_x:samples.shape[3] - margin_x]

        if self.scale_latent:
            latent_size = (self.height // opt_f, self.width // opt_f)
            samples = torch.nn.functional.interpolate(samples, size=latent_size, mode="bilinear")
        else:
            decoded_samples = self.sd_model.decode_first_stage(samples)

            upscaler_name = opts.upscaler_for_img2img
            if upscaler_name is None or upscaler_name == "None":
                decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
            else:
                # go through PIL images so images.resize_image can do the upscaling
                lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

                upscaled = []
                for lowres in lowres_samples:
                    pixels = (255. * np.moveaxis(lowres.cpu().numpy(), 0, 2)).astype(np.uint8)
                    resized = images.resize_image(0, Image.fromarray(pixels), self.width, self.height)
                    resized = np.array(resized).astype(np.float32) / 255.0
                    upscaled.append(np.moveaxis(resized, 2, 0))

                decoded_samples = torch.from_numpy(np.array(upscaled)).to(shared.device)
                decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

        shared.state.nextjob()

        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
558 559 560 561 562


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Image-to-image job, with optional masked (inpainting) processing.

    init() prepares the mask, overlay images, colour-correction targets and the
    encoded init latent; sample() runs the sampler's img2img pass and, when a mask
    is present, blends the result with the init latent.
    """
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
        """Store the img2img / inpainting settings; every other keyword goes to the base class."""
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare sampler, masks, overlays, colour corrections and the init latent.

        Raises:
            RuntimeError: if more init images are supplied than batch_size allows.
        """
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                # work only on the (padded, aspect-expanded) region around the mask;
                # paste_to records where the result goes back in the full image
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                # the overlay mask is the image mask with values doubled and clipped to 255
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        # only build colour-correction targets when the caller has not supplied any
        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                # overlay = original image kept wherever the overlay mask is dark
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # a single init image is repeated across the batch; per-image lists must be
            # repeated too, because process_images indexes them by position in the batch
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            # shrink the mask to latent resolution, binarise it and repeat it over 4 channels
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            # mask multiplies the init latent in sample(); nmask multiplies the newly sampled latent
            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        """Run the img2img pass for one batch and return the resulting latents."""
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            # take the sampled latent where nmask is set and the init latent elsewhere
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples