diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 50e79c1ad17f5c7945a6865dd00bf4a748e6faa8..7c9c1b201c89735da35095abb6710d95c389c536 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -81,7 +81,7 @@ def _math_attention(
 
 
 def _select_sdp_cuda(head_dim):
-    if head_dim < 128:
+    if head_dim <= 128:
         return "flash_attn"
     else:
         return "mem_efficient"