From c48c97991bff44917db4b40d4f6b22447a3459cb Mon Sep 17 00:00:00 2001 From: Molly Smith <112220543+molly-smith@users.noreply.github.com> Date: Thu, 6 Apr 2023 10:56:22 -0700 Subject: [PATCH] Fix for Diffusers 0.14.0 (#3142) cross attention kwargs and vae config for diffusers 0.14.0 --- deepspeed/model_implementations/diffusers/unet.py | 11 +++++++++-- deepspeed/model_implementations/diffusers/vae.py | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/deepspeed/model_implementations/diffusers/unet.py b/deepspeed/model_implementations/diffusers/unet.py index 3ceb0e0a..6086d9fb 100644 --- a/deepspeed/model_implementations/diffusers/unet.py +++ b/deepspeed/model_implementations/diffusers/unet.py @@ -62,5 +62,12 @@ class DSUNet(CUDAGraph, torch.nn.Module): self.cuda_graph_created = True - def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True): - return self.unet(sample, timestamp, encoder_hidden_states, return_dict) + def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True, cross_attention_kwargs=None): + if cross_attention_kwargs: + return self.unet(sample, + timestamp, + encoder_hidden_states, + return_dict, + cross_attention_kwargs=cross_attention_kwargs) + else: + return self.unet(sample, timestamp, encoder_hidden_states, return_dict) diff --git a/deepspeed/model_implementations/diffusers/vae.py b/deepspeed/model_implementations/diffusers/vae.py index 8a2dd567..445a9843 100644 --- a/deepspeed/model_implementations/diffusers/vae.py +++ b/deepspeed/model_implementations/diffusers/vae.py @@ -12,6 +12,7 @@ class DSVAE(CUDAGraph, torch.nn.Module): def __init__(self, vae, enable_cuda_graph=True): super().__init__(enable_cuda_graph=enable_cuda_graph) self.vae = vae + self.config = vae.config self.device = self.vae.device self.dtype = self.vae.dtype self.vae.requires_grad_(requires_grad=False) -- GitLab