提交 c833d5bf 编写于 作者: J Jay Smith

fixes #3449 - VRAM leak when switching to/from inpainting model

上级 828438b4
from collections import namedtuple from collections import namedtuple, deque
import numpy as np import numpy as np
from math import floor from math import floor
import torch import torch
...@@ -335,18 +335,28 @@ class CFGDenoiser(torch.nn.Module): ...@@ -335,18 +335,28 @@ class CFGDenoiser(torch.nn.Module):
class TorchHijack: class TorchHijack:
def __init__(self, kdiff_sampler): def __init__(self, sampler_noises):
self.kdiff_sampler = kdiff_sampler # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
# implementation.
self.sampler_noises = deque(sampler_noises)
def __getattr__(self, item): def __getattr__(self, item):
if item == 'randn_like': if item == 'randn_like':
return self.kdiff_sampler.randn_like return self.randn_like
if hasattr(torch, item): if hasattr(torch, item):
return getattr(torch, item) return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
def randn_like(self, x):
if self.sampler_noises:
noise = self.sampler_noises.popleft()
if noise.shape == x.shape:
return noise
return torch.randn_like(x)
class KDiffusionSampler: class KDiffusionSampler:
def __init__(self, funcname, sd_model): def __init__(self, funcname, sd_model):
...@@ -356,7 +366,6 @@ class KDiffusionSampler: ...@@ -356,7 +366,6 @@ class KDiffusionSampler:
self.extra_params = sampler_extra_params.get(funcname, []) self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None self.sampler_noises = None
self.sampler_noise_index = 0
self.stop_at = None self.stop_at = None
self.eta = None self.eta = None
self.default_eta = 1.0 self.default_eta = 1.0
...@@ -389,26 +398,14 @@ class KDiffusionSampler: ...@@ -389,26 +398,14 @@ class KDiffusionSampler:
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return p.steps return p.steps
def randn_like(self, x):
noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
if noise is not None and x.shape == noise.shape:
res = noise
else:
res = torch.randn_like(x)
self.sampler_noise_index += 1
return res
def initialize(self, p): def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap.step = 0 self.model_wrap.step = 0
self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral self.eta = p.eta or opts.eta_ancestral
if self.sampler_noises is not None: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self) k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
extra_params_kwargs = {} extra_params_kwargs = {}
for param_name in self.extra_params: for param_name in self.extra_params:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册