diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index b0b60f6ab4fa065f869478c30e6dbb1178482dd4..178e3ebc90fa7c29b5ae6d5c02c9ddfb3c103786 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -81,7 +81,7 @@ def _math_attention(


 def _select_sdp_cuda(head_dim):
-    if head_dim < 128:
+    if head_dim <= 128:
         return "flash_attn"
     else:
         return "mem_efficient"
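
# Illustration (not part of the patch): a self-contained sketch of the
# selection boundary this diff fixes, mirroring _select_sdp_cuda above.
# The stated rationale is an assumption not spelled out in the diff:
# FlashAttention kernels are presumed to support head dimensions up to
# and including 128, so the boundary check should be inclusive.

def _select_sdp_cuda(head_dim):
    # Inclusive comparison: head_dim == 128 is still eligible for the
    # FlashAttention backend; larger head dims fall back to the
    # memory-efficient attention backend.
    if head_dim <= 128:
        return "flash_attn"
    else:
        return "mem_efficient"

# Before this patch, the strict `<` sent head_dim == 128 to the
# memory-efficient backend; with `<=` it selects flash_attn.
assert _select_sdp_cuda(128) == "flash_attn"
assert _select_sdp_cuda(64) == "flash_attn"
assert _select_sdp_cuda(256) == "mem_efficient"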