# sd_samplers.py
from collections import namedtuple
import numpy as np
import torch
import tqdm
from PIL import Image

import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser

from modules.shared import opts, cmd_opts, state
import modules.shared as shared

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases'])

# (UI label, name of the function in k_diffusion.sampling, alternative names)
samplers_k_diffusion = [
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a']),
    ('Euler', 'sample_euler', ['k_euler']),
    ('LMS', 'sample_lms', ['k_lms']),
    ('Heun', 'sample_heun', ['k_heun']),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2']),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
]

# Only k-diffusion samplers whose function exists in the installed
# k_diffusion.sampling module are offered. The `funcname=funcname` default
# freezes the current loop value into each constructor lambda.
samplers_data_k_diffusion = [
    SamplerData(
        name=label,
        constructor=lambda model, funcname=funcname: KDiffusionSampler(funcname, model),
        aliases=aliases,
    )
    for label, funcname, aliases in samplers_k_diffusion
    if hasattr(k_diffusion.sampling, funcname)
]

# Every sampler selectable for txt2img: the k-diffusion ones followed by
# LDM's own DDIM and PLMS samplers.
samplers = samplers_data_k_diffusion + [
    SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
    SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]

# PLMS is excluded from img2img. NOTE(review): presumably because the vanilla
# img2img path relies on DDIM-only APIs (stochastic_encode / decode /
# p_sample_ddim) — confirm.
samplers_for_img2img = [sampler for sampler in samplers if sampler.name != 'PLMS']


def sample_to_image(samples):
    """Decode the first latent of ``samples`` into a PIL image.

    Only ``samples[0:1]`` is decoded, after casting it to the model's dtype.
    The decoder output is rescaled with ``(v + 1) / 2``, clamped to [0, 1],
    scaled to [0, 255] and truncated to ``uint8``.
    """
    model = shared.sd_model
    decoded = model.decode_first_stage(samples[0:1].type(model.dtype))[0]
    unit_range = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
    # Move axis 0 (presumably channels) to the end, the layout PIL expects.
    channels_last = np.moveaxis(unit_range.cpu().numpy(), 0, 2)
    pixels = (255. * channels_last).astype(np.uint8)
    return Image.fromarray(pixels)


def store_latent(decoded):
    """Publish the current latent and occasionally render a preview image.

    The latent is always stored in ``state.current_latent``. A preview image is
    decoded into ``shared.state.current_image`` only when
    ``opts.show_progress_every_n_steps`` is positive, the current sampling step
    is a multiple of it, and parallel processing is not allowed.
    """
    state.current_latent = decoded

    every_n_steps = opts.show_progress_every_n_steps
    # Checked first, so a non-positive setting never reaches the modulo.
    if every_n_steps <= 0:
        return
    if shared.state.sampling_step % every_n_steps != 0:
        return
    if shared.parallel_processing_allowed:
        return

    shared.state.current_image = sample_to_image(decoded)


def extended_tdqm(sequence, *args, desc=None, **kwargs):
    """Drop-in replacement for ``tqdm.tqdm`` that reports into the web UI state.

    Resets ``state.sampling_steps`` / ``state.sampling_step`` and then yields the
    items of ``sequence`` through a tqdm bar. After each item it advances
    ``state.sampling_step`` and ``shared.total_tqdm``. Iteration stops early once
    ``state.interrupted`` is set.

    ``desc`` is accepted only for signature compatibility. The bar is always
    labelled with ``state.job``.
    """
    state.sampling_steps = len(sequence)
    state.sampling_step = 0

    progress = tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
    for item in progress:
        if state.interrupted:
            break

        yield item

        state.sampling_step += 1
        shared.total_tqdm.update()


# Point the `tqdm` name in LDM's DDIM and PLMS modules at extended_tdqm.
# NOTE(review): this assumes their sampling loops look up the module-level
# `tqdm` at call time — confirm. The lambda resolves extended_tdqm on each call.
ldm.models.diffusion.ddim.tqdm = ldm.models.diffusion.plms.tqdm = (
    lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
)


class VanillaStableDiffusionSampler:
    """Adapter around LDM's own samplers (DDIM, PLMS).

    ``constructor`` is called with the model to build the wrapped sampler. Its
    per-step method (``p_sample_ddim`` or ``p_sample_plms``) is replaced by
    :meth:`p_sample_ddim_hook`. The hook adds prompt scheduling, optional mask
    blending and progress previews around the original method.
    """

    def __init__(self, constructor, sd_model):
        self.sampler = constructor(sd_model)
        # Unpatched per-step function: p_sample_ddim when the sampler has one,
        # otherwise p_sample_plms.
        self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
        self.mask = None
        self.nmask = None
        self.init_latent = None
        # Index of the current sampling step; selects the prompt-schedule entry.
        self.step = 0

    def _reset_run_state(self, mask=None, nmask=None, init_latent=None):
        """Set the mask inputs for a new run and restart the step counter at 0."""
        self.mask = mask
        self.nmask = nmask
        self.init_latent = init_latent
        self.step = 0

    def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
        """Replacement for the wrapped sampler's per-step method.

        The hook does the following, in order:

        - Resolves the scheduled conditioning for ``self.step``.
        - When a mask is set, blends ``init_latent`` (passed through
          ``q_sample`` at ``ts``) into ``x_dec`` where ``mask`` applies.
        - Calls the original per-step method.
        - Stores ``res[1]`` as the progress latent, re-blended with
          ``init_latent`` under the mask.
        - Advances ``self.step``.

        Returns whatever the original per-step method returns.
        """
        cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
        unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)

        if self.mask is not None:
            # NOTE(review): q_sample presumably noises init_latent to timestep ts — confirm.
            img_orig = self.sampler.model.q_sample(self.init_latent, ts)
            x_dec = img_orig * self.mask + self.nmask * x_dec

        res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)

        # NOTE(review): res[1] is presumably LDM's predicted x0 for this step — confirm.
        if self.mask is not None:
            store_latent(self.init_latent * self.mask + self.nmask * res[1])
        else:
            store_latent(res[1])

        self.step += 1
        return res

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        """Run DDIM img2img: encode ``x`` to step ``t_enc``, then decode with the hook installed.

        ``t_enc`` is ``p.denoising_strength`` (capped at 0.999) times ``p.steps``.
        Returns the result of the wrapped sampler's ``decode``.
        """
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)

        # The LDM schedule code fails for certain step counts (e.g. 9); retry with one extra step.
        try:
            self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
        except Exception:
            self.sampler.make_schedule(ddim_num_steps=p.steps + 1, verbose=False)

        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)

        self.sampler.p_sample_ddim = self.p_sample_ddim_hook
        self._reset_run_state(p.mask, p.nmask, p.init_latent)

        samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)

        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning):
        """Run txt2img sampling starting from the latent ``x`` (passed as ``x_T``).

        Returns the first element of the wrapped sampler's ``sample`` result.
        """
        for fieldname in ['p_sample_ddim', 'p_sample_plms']:
            if hasattr(self.sampler, fieldname):
                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)

        def run(steps):
            # Every attempt starts at step 0 with no mask. Otherwise a failed attempt
            # that already advanced the hook would offset the retry's prompt schedule.
            self._reset_run_state()
            samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
            return samples_ddim

        # The LDM schedule code fails for certain step counts (e.g. 9); retry with one extra step.
        try:
            return run(p.steps)
        except Exception:
            return run(p.steps + 1)


class CFGDenoiser(torch.nn.Module):
    """Classifier-free guidance wrapper around an inner denoiser.

    Each forward call evaluates the inner model on the unconditional and the
    conditional inputs and combines them as
    ``uncond + (cond - uncond) * cond_scale``. When ``mask`` is set, the result
    is blended with ``init_latent``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        self.mask = None
        self.nmask = None
        self.init_latent = None
        # Counts forward() calls and selects the prompt-schedule entry.
        # NOTE(review): a sampler that evaluates the model more than once per
        # step advances this faster than the step count — confirm intended.
        self.step = 0

    def forward(self, x, sigma, uncond, cond, cond_scale):
        cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        if not shared.batch_cond_uncond:
            # Two separate model calls.
            uncond_out = self.inner_model(x, sigma, cond=uncond)
            cond_out = self.inner_model(x, sigma, cond=cond)
        else:
            # One batched call: first half unconditional, second half conditional.
            batched = self.inner_model(torch.cat([x, x]), torch.cat([sigma, sigma]), cond=torch.cat([uncond, cond]))
            uncond_out, cond_out = batched.chunk(2)

        denoised = uncond_out + (cond_out - uncond_out) * cond_scale

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        self.step += 1

        return denoised


def extended_trange(count, *args, **kwargs):
    """Drop-in replacement for ``tqdm.trange`` that reports into the web UI state.

    Resets ``state.sampling_steps`` to ``count`` and ``state.sampling_step`` to 0,
    then yields ``0 .. count-1`` through a tqdm bar labelled with ``state.job``.
    After each index it advances ``state.sampling_step`` and
    ``shared.total_tqdm``. Iteration stops early once ``state.interrupted`` is set.
    """
    state.sampling_steps = count
    state.sampling_step = 0

    progress = tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
    for index in progress:
        if state.interrupted:
            break

        yield index

        state.sampling_step += 1
        shared.total_tqdm.update()


class KDiffusionSampler:
    """Adapter that runs a ``k_diffusion.sampling`` function on the SD model.

    The model is wrapped in k_diffusion's ``CompVisDenoiser`` and then in
    ``CFGDenoiser``. The sampling function named ``funcname`` is called with the
    CFG wrapper, a starting latent and a sigma schedule.
    """

    def __init__(self, funcname, sd_model):
        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
        self.funcname = funcname
        self.func = getattr(k_diffusion.sampling, self.funcname)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)

    def callback_state(self, d):
        """Per-step callback for the sampler; forwards ``d["denoised"]`` to store_latent."""
        store_latent(d["denoised"])

    def _prepare_run(self, mask=None, nmask=None, init_latent=None):
        """Reset the CFG wrapper's per-run state and route k-diffusion's trange through the UI."""
        self.model_wrap_cfg.mask = mask
        self.model_wrap_cfg.nmask = nmask
        self.model_wrap_cfg.init_latent = init_latent
        # CFGDenoiser.step indexes the prompt schedule, so every run must start at 0.
        self.model_wrap_cfg.step = 0

        if hasattr(k_diffusion.sampling, 'trange'):
            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)

    def _run(self, x, sigmas, p, conditioning, unconditional_conditioning):
        """Call the sampling function with CFG arguments and the progress callback."""
        extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}
        return self.func(self.model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False, callback=self.callback_state)

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        """Run img2img sampling over the tail of the sigma schedule.

        ``t_enc`` is ``p.denoising_strength`` (capped at 0.999) times ``p.steps``.
        Sampling starts at index ``p.steps - t_enc - 1`` of the schedule.
        ``noise`` is scaled by the sigma at that index and added to ``x``.
        """
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
        sigmas = self.model_wrap.get_sigmas(p.steps)

        start = p.steps - t_enc - 1
        xi = x + noise * sigmas[start]
        sigma_sched = sigmas[start:]

        self._prepare_run(p.mask, p.nmask, p.init_latent)
        return self._run(xi, sigma_sched, p, conditioning, unconditional_conditioning)

    def sample(self, p, x, conditioning, unconditional_conditioning):
        """Run txt2img sampling: scale ``x`` by the first sigma and sample the full schedule."""
        sigmas = self.model_wrap.get_sigmas(p.steps)
        x = x * sigmas[0]

        # txt2img uses no mask; clear any left over from an earlier img2img run.
        self._prepare_run()
        return self._run(x, sigmas, p, conditioning, unconditional_conditioning)