diff --git a/docs/diffusion/stable_diffusion/model/unet_attention.html b/docs/diffusion/stable_diffusion/model/unet_attention.html
index b7118b0a71a8525cd0660a155893bf997c61947b..4399937e9e9bf3dd2b822f16fe684dbd7cb3f399 100644
--- a/docs/diffusion/stable_diffusion/model/unet_attention.html
+++ b/docs/diffusion/stable_diffusion/model/unet_attention.html
@@ -602,7 +602,7 @@
 173        k = self.to_k(cond)
 174        v = self.to_v(cond)
 175
-176        print('use flash', CrossAttention.use_flash_attention)
+176        print('use flash', CrossAttention.use_flash_attention, self.flash)
 177
 178        if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
 179            return self.flash_attention(q, k, v)
diff --git a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
index 6c8b1b1a45b4d5032cb3b6e8cb66396162146fb2..79baa603dd3eb511d899c0081b9ae6c6e774a629 100644
--- a/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
+++ b/labml_nn/diffusion/stable_diffusion/model/unet_attention.py
@@ -173,7 +173,7 @@ class CrossAttention(nn.Module):
         k = self.to_k(cond)
         v = self.to_v(cond)

-        print('use flash', CrossAttention.use_flash_attention)
+        print('use flash', CrossAttention.use_flash_attention, self.flash)

         if CrossAttention.use_flash_attention and self.flash is not None and cond is None and self.d_head <= 128:
             return self.flash_attention(q, k, v)
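
For reference, a minimal, self-contained sketch of the gating logic that the extended debug print reports on: the class-level `use_flash_attention` switch alone does not select the flash path; the instance must also have a constructed `self.flash` module, be doing self-attention (`cond is None`), and have `d_head <= 128`. The class below is illustrative only and is not the repository's `CrossAttention`; only the names `use_flash_attention`, `flash`, `d_head`, and the condition itself are taken from the diff.

```python
# Minimal sketch (not from the repository) of the condition the debug print inspects.
class CrossAttentionSketch:
    # Class-level switch shared by every instance, as in the diffed code.
    use_flash_attention = False

    def __init__(self, d_head: int, flash=None):
        self.d_head = d_head
        # `flash` stays None when no flash-attention implementation is available.
        self.flash = flash

    def forward(self, q, k, v, cond=None):
        # The added third argument shows whether this instance actually built a
        # flash module, not just whether the class-level switch is on.
        print('use flash', CrossAttentionSketch.use_flash_attention, self.flash)
        if (CrossAttentionSketch.use_flash_attention and self.flash is not None
                and cond is None and self.d_head <= 128):
            return 'flash_attention path'
        return 'normal_attention path'


if __name__ == '__main__':
    attn = CrossAttentionSketch(d_head=64)
    # Prints "use flash False None" and takes the normal-attention path,
    # since neither the switch is enabled nor a flash module is set.
    print(attn.forward(q=None, k=None, v=None))
```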