Unverified commit c48c9799, authored by Molly Smith, committed by GitHub

Fix for Diffusers 0.14.0 (#3142)

Adds cross_attention_kwargs support to DSUNet and exposes the VAE config on DSVAE for Diffusers 0.14.0 compatibility.
Parent 30d97705
@@ -62,5 +62,12 @@ class DSUNet(CUDAGraph, torch.nn.Module):
         self.cuda_graph_created = True

-    def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True):
-        return self.unet(sample, timestamp, encoder_hidden_states, return_dict)
+    def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True, cross_attention_kwargs=None):
+        if cross_attention_kwargs:
+            return self.unet(sample,
+                             timestamp,
+                             encoder_hidden_states,
+                             return_dict,
+                             cross_attention_kwargs=cross_attention_kwargs)
+        else:
+            return self.unet(sample, timestamp, encoder_hidden_states, return_dict)
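For context, Diffusers 0.14.0 pipelines pass cross_attention_kwargs to the UNet on every denoising step, so the wrapper's _forward must accept the keyword. The snippet below is a minimal, self-contained sketch of that call pattern; StubUNet, its tensor shapes, and the example kwargs are illustrative stand-ins, not part of this patch.

```python
# Illustrative only: StubUNet stands in for a wrapped UNet2DConditionModel so the
# example runs without loading a model; shapes are typical Stable Diffusion sizes.
import torch


class StubUNet(torch.nn.Module):
    def forward(self, sample, timestep, encoder_hidden_states, return_dict=True, cross_attention_kwargs=None):
        # A real UNet would predict noise here; we just report what was received.
        print("cross_attention_kwargs =", cross_attention_kwargs)
        return sample


unet = StubUNet()
sample = torch.randn(2, 4, 64, 64)               # latent batch
timestep = torch.tensor(981)                     # scheduler timestep
encoder_hidden_states = torch.randn(2, 77, 768)  # text-encoder hidden states

# Pre-0.14.0 call style (still supported by the patched signature):
unet(sample, timestep, encoder_hidden_states)

# Diffusers 0.14.0 pipelines also pass cross_attention_kwargs (e.g. a LoRA scale);
# before this patch, the wrapper's _forward would reject the extra keyword with a TypeError.
unet(sample, timestep, encoder_hidden_states, cross_attention_kwargs={"scale": 0.8})
```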
@@ -12,6 +12,7 @@ class DSVAE(CUDAGraph, torch.nn.Module):
     def __init__(self, vae, enable_cuda_graph=True):
         super().__init__(enable_cuda_graph=enable_cuda_graph)
         self.vae = vae
+        self.config = vae.config
         self.device = self.vae.device
         self.dtype = self.vae.dtype
         self.vae.requires_grad_(requires_grad=False)
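Mirroring config on the wrapper matters because Diffusers 0.14.0 pipeline code reads attributes such as vae.config.scaling_factor when scaling latents for decoding, and the pipeline sees the DSVAE wrapper rather than the original AutoencoderKL. A minimal sketch of that attribute path follows; WrappedVAE, FakeVAE, and the scaling value are illustrative stand-ins.

```python
# Illustrative only: FakeVAE/FakeConfig mimic a diffusers AutoencoderKL just enough
# to show why the wrapper has to mirror vae.config.
class FakeConfig:
    scaling_factor = 0.18215  # typical Stable Diffusion VAE scaling factor


class FakeVAE:
    config = FakeConfig()


class WrappedVAE:
    def __init__(self, vae):
        self.vae = vae
        # Without this line, pipeline code doing `self.vae.config.scaling_factor`
        # against the wrapper would raise AttributeError.
        self.config = vae.config


wrapped = WrappedVAE(FakeVAE())
latents_scale = 1 / wrapped.config.scaling_factor  # the kind of access 0.14.0 pipelines perform
print(latents_scale)
```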