diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html
index b7118b0a71a8525cd0660a155893bf997c61947b..4399937e9e9bf3dd2b822f16fe684dbd7cb3f399 100644
--- a/docs/diffusion/stable_diffusion/model/unet_attention.html
+++ b/docs/diffusion/stable_diffusion/model/unet_attention.html
@@ -602,7 +602,7 @@
173 k = self.to_k(cond)
174 v = self.to_v(cond)
175
-176 print('use flash', CrossAttention.use_flash_attention)
+176 print('use flash', CrossAttention.use_flash_attention, self.flash)
177
178 if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
179 return self.flash_attention(q, k, v)
diff --git a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
index 6c8b1b1a45b4d5032cb3b6e8cb66396162146fb2..79baa603dd3eb511d899c0081b9ae6c6e774a629 100644
--- a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
+++ b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
@@ -173,7 +173,7 @@ class CrossAttention(nn.Module):
k = self.to_k(cond)
v = self.to_v(cond)
- print('use flash', CrossAttention.use_flash_attention)
+ print('use flash', CrossAttention.use_flash_attention, self.flash)
if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
return self.flash_attention(q, k, v)