From f8ff8c0638997fd0aef217db1505598846f14782 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 22:09:40 +0300 Subject: [PATCH] merge errors --- modules/sd_samplers_cfg_denoiser.py | 23 +++++++++++++++++++++-- modules/sd_samplers_common.py | 6 +++++- modules/sd_samplers_kdiffusion.py | 17 ++++++++++++----- modules/sd_samplers_timesteps.py | 27 +++++++++++++++++---------- 4 files changed, 55 insertions(+), 18 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index d826222cd..a532e0137 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -38,16 +38,24 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model, sampler): + def __init__(self, sampler): super().__init__() - self.inner_model = model + self.model_wrap = None self.mask = None self.nmask = None self.init_latent = None + self.steps = None self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False self.sampler = sampler + self.model_wrap = None + self.p = None + + @property + def inner_model(self): + raise NotImplementedError() + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -68,10 +76,21 @@ class CFGDenoiser(torch.nn.Module): def get_pred_x0(self, x_in, x_out, sigma): return x_out + def update_inner_model(self): + self.model_wrap = None + + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 15f279707..fa3614ff0 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -202,8 +202,9 @@ class Sampler: self.conditioning_key = shared.sd_model.model.conditioning_key - self.model_wrap = None + self.p = None self.model_wrap_cfg = None + self.sampler_extra_args = None def callback_state(self, d): step = d['i'] @@ -215,6 +216,7 @@ class Sampler: shared.total_tqdm.update() def launch_sampling(self, steps, func): + self.model_wrap_cfg.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -234,6 +236,8 @@ class Sampler: return p.steps def initialize(self, p) -> dict: + self.p = p + self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3ff4b6345..95a43ceff 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -52,17 +52,24 @@ k_diffusion_scheduler = { } +class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + + return self.model_wrap + + class KDiffusionSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): - super().__init__(funcname) - self.extra_params = sampler_extra_params.get(funcname, []) self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserKDiffusion(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index d89d0efb3..965e61c67 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -44,10 +44,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module): class CFGDenoiserTimesteps(CFGDenoiser): - def __init__(self, model, sampler): - super().__init__(model, sampler) + def __init__(self, sampler): + super().__init__(sampler) - self.alphas = model.inner_model.alphas_cumprod + self.alphas = shared.sd_model.alphas_cumprod def get_pred_x0(self, x_in, x_out, sigma): ts = int(sigma.item()) @@ -60,6 +60,14 @@ class CFGDenoiserTimesteps(CFGDenoiser): return pred_x0 + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser + self.model_wrap = denoiser(shared.sd_model) + + return self.model_wrap + class CompVisSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): @@ -68,9 +76,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.eta_option_field = 'eta_ddim' self.eta_infotext_field = 'Eta DDIM' - denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser - self.model_wrap = denoiser(sd_model) - self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserTimesteps(self) def get_timesteps(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -106,7 +112,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -114,7 +120,7 @@ class CompVisSampler(sd_samplers_common.Sampler): 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -132,13 +138,14 @@ class CompVisSampler(sd_samplers_common.Sampler): extra_params_kwargs['timesteps'] = timesteps self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True -- GitLab