From feff99f5e164afd0b9c4dd27264db37de75ce0cb Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Mon, 19 Jun 2023 14:00:11 +0800 Subject: [PATCH] update flash attn select (#54630) (#54716) --- python/paddle/nn/functional/flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 50e79c1ad17..7c9c1b201c8 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -81,7 +81,7 @@ def _math_attention( def _select_sdp_cuda(head_dim): - if head_dim < 128: + if head_dim <= 128: return "flash_attn" else: return "mem_efficient" -- GitLab