diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html
index 0a7176168a26f9ac9b940d6f342de68542d504bc..e9041d74af9dfeba63513a53a44d8d4feff66edc 100644
--- a/docs/diffusion/stable_diffusion/model/unet_attention.html
+++ b/docs/diffusion/stable_diffusion/model/unet_attention.html
@@ -603,9 +603,9 @@
 174    v = self.to_v(cond)
 175
 176    if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-177        self.flash_attention(q, k, v)
+177        return self.flash_attention(q, k, v)
 178    else:
-179        self.normal_attention(q, k, v)
+179        return self.normal_attention(q, k, v)
diff --git a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
index 4042d5002485ab5fc6b4ca3b060a6a17e8d38beb..77e018e61f2397a854a912988e9deffccd85fcdb 100644
--- a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
+++ b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
@@ -174,9 +174,9 @@ class CrossAttention(nn.Module):
         v = self.to_v(cond)
 
         if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
-            self.flash_attention(q, k, v)
+            return self.flash_attention(q, k, v)
         else:
-            self.normal_attention(q, k, v)
+            return self.normal_attention(q, k, v)
 
     def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         """
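
Note on the change: without the `return` statements, `CrossAttention.forward` computed the attention output inside `flash_attention` / `normal_attention` and then discarded it, so the module implicitly returned `None`; any caller that adds the attention output back to a residual stream would then fail. The following is a minimal, self-contained sketch of that behavioral difference (the `Broken` and `Fixed` classes are illustrative only, not code from this repository):

    import torch
    import torch.nn as nn


    class Broken(nn.Module):
        def forward(self, x: torch.Tensor):
            self.helper(x)            # result is computed, then silently dropped

        def helper(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2


    class Fixed(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.helper(x)     # result is propagated to the caller

        def helper(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2


    x = torch.ones(2)
    assert Broken()(x) is None                 # a downstream `attn(x) + x` would raise a TypeError
    assert torch.equal(Fixed()(x), x * 2)

The patch applies the same fix to both dispatch branches, so the forward pass returns the attention result regardless of whether the flash-attention or the standard path is taken.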