From 49a45f71a70d7abbd00d4623c1ed4343ebb7f5dc Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Wed, 14 Jun 2023 14:22:32 +0800 Subject: [PATCH] update flash attn select (#54630) --- python/paddle/nn/functional/flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index b0b60f6ab4f..178e3ebc90f 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -81,7 +81,7 @@ def _math_attention( def _select_sdp_cuda(head_dim): - if head_dim < 128: + if head_dim <= 128: return "flash_attn" else: return "mem_efficient" -- GitLab