sd_samplers.py 7.6 KB
Newer Older
1
from collections import namedtuple
2
import numpy as np
3 4
import torch
import tqdm
5
from PIL import Image
6 7

import k_diffusion.sampling
A
AUTOMATIC 已提交
8 9
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
10 11 12 13

from modules.shared import opts, cmd_opts, state
import modules.shared as shared

A
AUTOMATIC 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31

# One selectable sampler: its display name, a factory that takes the model and
# returns a sampler object, and a list of alternative names.
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases'])

# Each row: (display name, attribute name looked up on k_diffusion.sampling, aliases).
samplers_k_diffusion = [
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a']),
    ('Euler', 'sample_euler', ['k_euler']),
    ('LMS', 'sample_lms', ['k_lms']),
    ('Heun', 'sample_heun', ['k_heun']),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2']),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
]

# Only rows whose function actually exists on k_diffusion.sampling are kept.
# The function name is bound as a lambda default so every factory keeps its
# own value instead of the comprehension's last one.
samplers_data_k_diffusion = [
    SamplerData(
        name=display_name,
        constructor=lambda model, funcname=sampling_fn: KDiffusionSampler(funcname, model),
        aliases=alias_names,
    )
    for display_name, sampling_fn, alias_names in samplers_k_diffusion
    if hasattr(k_diffusion.sampling, sampling_fn)
]

32
# Full sampler list: the available k-diffusion samplers followed by the two
# samplers wrapped by VanillaStableDiffusionSampler.
samplers = [
    *samplers_data_k_diffusion,
    SamplerData(
        name='DDIM',
        constructor=lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model),
        aliases=[],
    ),
    SamplerData(
        name='PLMS',
        constructor=lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model),
        aliases=[],
    ),
]

# Same list with the PLMS entry left out.
samplers_for_img2img = [entry for entry in samplers if entry.name != 'PLMS']


40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
def sample_to_image(samples):
    """Decode the first element of *samples* with the shared model and return it as a PIL image.

    Only ``samples[0:1]`` is used; it is cast to the model's dtype before
    being passed to ``decode_first_stage``.
    """
    model = shared.sd_model
    first_latent = samples[0:1].type(model.dtype)
    decoded = model.decode_first_stage(first_latent)[0]

    # Shift/scale by (v + 1) / 2 and clip the result to [0, 1].
    unit_range = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)

    # Move axis 0 to the last position (presumably channels-first to
    # channels-last -- confirm against the decoder output) and scale to 0..255.
    pixels = 255. * np.moveaxis(unit_range.cpu().numpy(), 0, 2)
    return Image.fromarray(pixels.astype(np.uint8))


def store_latent(decoded):
    """Record *decoded* as the current latent and, when due, refresh the preview image.

    The preview is rebuilt only if ``opts.show_progress_every_n_steps`` is
    positive, the current sampling step is a multiple of it, and
    ``shared.parallel_processing_allowed`` is falsy.
    """
    state.current_latent = decoded

    interval = opts.show_progress_every_n_steps
    preview_due = interval > 0 and shared.state.sampling_step % interval == 0
    if preview_due and not shared.parallel_processing_allowed:
        shared.state.current_image = sample_to_image(decoded)


56 57 58 59 60
def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
    """Store the current latent, apply the wrapper's mask if any, then run the original step.

    Without a mask, ``x_dec`` is stored and forwarded unchanged. With a mask,
    ``x_dec`` is first blended with ``q_sample(init_latent, ts)`` using
    ``mask`` / ``nmask``; the stored latent is a second blend of that result
    with the un-noised ``init_latent``.

    Returns whatever ``sampler_wrapper.orig_p_sample_ddim`` returns.
    """
    mask = sampler_wrapper.mask

    if mask is None:
        store_latent(x_dec)
    else:
        nmask = sampler_wrapper.nmask
        init_latent = sampler_wrapper.init_latent

        # Value handed on to the original step: q_sample output where mask
        # applies, the running sample where nmask applies.
        noised_init = sampler_wrapper.sampler.model.q_sample(init_latent, ts)
        x_dec = noised_init * mask + nmask * x_dec

        # Value stored for display uses init_latent itself rather than the
        # q_sample output.
        store_latent(init_latent * mask + nmask * x_dec)

    return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)


A
AUTOMATIC 已提交
69 70 71 72
def extended_tdqm(sequence, *args, desc=None, **kwargs):
    """Iterate *sequence* through tqdm while keeping the shared progress state up to date.

    ``desc`` is accepted but not used: the bar is always labelled with
    ``state.job``. Output goes to ``shared.progress_print_out``.

    This is a generator, so the step counters are reset on the first
    ``next()``, not when the function is called. Iteration ends early as soon
    as ``state.interrupted`` is truthy.
    """
    state.sampling_steps = len(sequence)
    state.sampling_step = 0

    bar = tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
    for item in bar:
        if state.interrupted:
            return

        yield item

        # Runs only after the consumer has finished with this item and asks
        # for the next one.
        state.sampling_step += 1
        shared.total_tqdm.update()
A
AUTOMATIC 已提交
81 82 83 84 85 86


# Replace the `tqdm` name in ldm's ddim and plms modules with a wrapper that
# forwards to extended_tdqm (presumably those modules use it for their
# sampling loops -- confirm against ldm).
ldm.models.diffusion.ddim.tqdm = ldm.models.diffusion.plms.tqdm = (
    lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
)


87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
class VanillaStableDiffusionSampler:
    """Wraps a sampler built by *constructor* and exposes ``sample`` / ``sample_img2img``.

    In this file it is constructed with ldm's ``DDIMSampler`` and
    ``PLMSSampler``.
    """

    def __init__(self, constructor, sd_model):
        """Build the wrapped sampler and remember its original ``p_sample_ddim``, if it has one."""
        self.sampler = constructor(sd_model)
        # None when the wrapped sampler has no p_sample_ddim attribute.
        self.orig_p_sample_ddim = getattr(self.sampler, 'p_sample_ddim', None)
        # Masking state; filled in by sample_img2img and read by p_sample_ddim_hook.
        self.mask = None
        self.nmask = None
        self.init_latent = None

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        """Encode *x* to step ``t_enc`` with *noise*, then decode it back with guidance.

        ``t_enc`` is ``int(min(p.denoising_strength, 0.999) * p.steps)``.
        Returns the result of the wrapped sampler's ``decode``.
        """
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)

        # make_schedule fails with certain step counts, like 9; retry once
        # with one extra step when that happens.
        try:
            self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
        except Exception:
            self.sampler.make_schedule(ddim_num_steps=p.steps+1, verbose=False)

        # One copy of t_enc per batch element, moved to the shared device.
        batch_size = int(x.shape[0])
        encode_steps = torch.tensor([t_enc] * batch_size).to(shared.device)
        x1 = self.sampler.stochastic_encode(x, encode_steps, noise=noise)

        def hooked_p_sample_ddim(x_dec, cond, ts, *args, **kwargs):
            return p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)

        # Route every p_sample_ddim call through the hook, which reads the
        # masking state set just below.
        self.sampler.p_sample_ddim = hooked_p_sample_ddim
        self.mask = p.mask
        self.nmask = p.nmask
        self.init_latent = p.init_latent

        return self.sampler.decode(
            x1,
            conditioning,
            t_enc,
            unconditional_guidance_scale=p.cfg_scale,
            unconditional_conditioning=unconditional_conditioning,
        )

    def sample(self, p, x, conditioning, unconditional_conditioning):
        """Run the wrapped sampler's ``sample`` starting from *x* and return its first result."""
        samples_ddim, _ = self.sampler.sample(
            S=p.steps,
            conditioning=conditioning,
            batch_size=int(x.shape[0]),
            shape=x[0].shape,
            verbose=False,
            unconditional_guidance_scale=p.cfg_scale,
            unconditional_conditioning=unconditional_conditioning,
            x_T=x,
        )
        return samples_ddim


class CFGDenoiser(torch.nn.Module):
    """Combines unconditional and conditional outputs of *model*, with optional masking.

    The combined output is ``uncond + (cond - uncond) * cond_scale`` (the
    class name suggests classifier-free guidance). When ``mask`` is set, the
    result is blended with ``init_latent`` via ``mask`` / ``nmask``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        # Masking state; left as None unless a caller assigns it.
        self.mask = None
        self.nmask = None
        self.init_latent = None

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return the guided (and, if a mask is set, masked) output for *x* at *sigma*."""
        if shared.batch_cond_uncond:
            # Single call on a doubled batch; the first half of the output
            # corresponds to uncond, the second to cond.
            doubled_x = torch.cat([x, x])
            doubled_sigma = torch.cat([sigma, sigma])
            stacked_cond = torch.cat([uncond, cond])
            uncond, cond = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond).chunk(2)
        else:
            # Two separate calls, unconditional first.
            uncond = self.inner_model(x, sigma, cond=uncond)
            cond = self.inner_model(x, sigma, cond=cond)

        denoised = uncond + (cond - uncond) * cond_scale

        if self.mask is None:
            return denoised

        return self.init_latent * self.mask + self.nmask * denoised


A
AUTOMATIC 已提交
146 147 148 149
def extended_trange(count, *args, **kwargs):
    """Yield the values of ``tqdm.trange(count, ...)`` while keeping the shared progress state up to date.

    The bar is labelled with ``state.job`` and written to
    ``shared.progress_print_out``.

    This is a generator, so the step counters are reset on the first
    ``next()``, not when the function is called. Iteration ends early as soon
    as ``state.interrupted`` is truthy.
    """
    state.sampling_steps = count
    state.sampling_step = 0

    bar = tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
    for index in bar:
        if state.interrupted:
            return

        yield index

        # Runs only after the consumer has finished with this step and asks
        # for the next one.
        state.sampling_step += 1
        shared.total_tqdm.update()
A
AUTOMATIC 已提交
158

159 160 161 162 163 164 165 166

class KDiffusionSampler:
    """Runs the ``k_diffusion.sampling`` function named *funcname* through a ``CFGDenoiser``."""

    def __init__(self, funcname, sd_model):
        """Wrap *sd_model* and look up the sampling function by name."""
        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
        self.funcname = funcname
        self.func = getattr(k_diffusion.sampling, self.funcname)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)

    def callback_state(self, d):
        """Sampler callback: store the ``"denoised"`` entry of *d* as the current latent."""
        denoised = d["denoised"]
        store_latent(denoised)

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        """Add scaled *noise* to *x* and sample over the tail of the sigma schedule.

        ``t_enc`` is ``int(min(p.denoising_strength, 0.999) * p.steps)``; the
        schedule starts at index ``p.steps - t_enc - 1``. Returns whatever
        the selected sampling function returns.
        """
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
        sigmas = self.model_wrap.get_sigmas(p.steps)

        # The same index selects both the noise scale and the schedule start.
        start = p.steps - t_enc - 1
        xi = x + noise * sigmas[start]
        sigma_sched = sigmas[start:]

        denoiser = self.model_wrap_cfg
        denoiser.mask = p.mask
        denoiser.nmask = p.nmask
        denoiser.init_latent = p.init_latent

        # Replace k_diffusion's trange (when it has one) so progress goes
        # through extended_trange.
        if hasattr(k_diffusion.sampling, 'trange'):
            def patched_trange(*args, **kwargs):
                return extended_trange(*args, **kwargs)

            k_diffusion.sampling.trange = patched_trange

        extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}
        return self.func(denoiser, xi, sigma_sched, extra_args=extra_args, disable=False, callback=self.callback_state)

    def sample(self, p, x, conditioning, unconditional_conditioning):
        """Scale *x* by the first sigma and sample over the full schedule for ``p.steps``."""
        sigmas = self.model_wrap.get_sigmas(p.steps)
        x = x * sigmas[0]

        # Replace k_diffusion's trange (when it has one) so progress goes
        # through extended_trange.
        if hasattr(k_diffusion.sampling, 'trange'):
            def patched_trange(*args, **kwargs):
                return extended_trange(*args, **kwargs)

            k_diffusion.sampling.trange = patched_trange

        extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}
        return self.func(self.model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False, callback=self.callback_state)