From 160f25a9381a7d847b17507aa0f7885da9ba28b8 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Sat, 24 Sep 2022 14:35:28 +0530 Subject: [PATCH] Fix CrossAttention.forward: return the attention result instead of discarding it --- docs/diffusion/stable_diffusion/model/unet_attention.html | 4 ++-- labml_nn/diffusion/stable_diffusion/model/unet_attention.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html index 0a717616..e9041d74 100644 --- a/docs/diffusion/stable_diffusion/model/unet_attention.html +++ b/docs/diffusion/stable_diffusion/model/unet_attention.html @@ -603,9 +603,9 @@ 174 v = self.to_v(cond) 175 176 if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128: -177 self.flash_attention(q, k, v) +177 return self.flash_attention(q, k, v) 178 else: -179 self.normal_attention(q, k, v) +179 return self.normal_attention(q, k, v)
diff --git a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py index 4042d500..77e018e6 100644 --- a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py +++ b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py @@ -174,9 +174,9 @@ class CrossAttention(nn.Module): v = self.to_v(cond) if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128: - self.flash_attention(q, k, v) + return self.flash_attention(q, k, v) else: - self.normal_attention(q, k, v) + return self.normal_attention(q, k, v) def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """ -- GitLab