diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html
index 0a7176168a26f9ac9b940d6f342de68542d504bc..e9041d74af9dfeba63513a53a44d8d4feff66edc 100644
--- a/docs/diffusion/stable_diffusion/model/unet_attention.html
+++ b/docs/diffusion/stable_diffusion/model/unet_attention.html
@@ -603,9 +603,9 @@
174 v = self.to_v(cond)
175
176 if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-177 self.flash_attention(q, k, v)
+177 return self.flash_attention(q, k, v)
178 else:
-179 self.normal_attention(q, k, v)
+179 return self.normal_attention(q, k, v)
diff --git a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
index 4042d5002485ab5fc6b4ca3b060a6a17e8d38beb..77e018e61f2397a854a912988e9deffccd85fcdb 100644
--- a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
+++ b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
@@ -174,9 +174,9 @@ class CrossAttention(nn.Module):
v = self.to_v(cond)
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
- self.flash_attention(q, k, v)
+ return self.flash_attention(q, k, v)
else:
- self.normal_attention(q, k, v)
+ return self.normal_attention(q, k, v)
def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""