Unverified commit 08e46d6f, authored by umiswing, committed by GitHub

Fix select sdp for FA-2 (#56045)

Parent 8d181e37
@@ -81,7 +81,7 @@ def _math_attention(
 def _select_sdp_cuda(head_dim):
-    if head_dim <= 128:
+    if head_dim <= 256:
         return "flash_attn"
     else:
         return "mem_efficient"
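For reference, this is the selector after the change, shown as a minimal sketch that reproduces the patched helper; the assertions are illustrative and not part of the commit. FlashAttention-2 supports head dimensions up to 256, so the threshold moves from 128 to 256 and head_dim values between 129 and 256 now dispatch to the FlashAttention kernel instead of the memory-efficient fallback.

    # Patched selector from this commit (the asserts below are illustrative only).
    def _select_sdp_cuda(head_dim):
        if head_dim <= 256:
            return "flash_attn"
        else:
            return "mem_efficient"

    assert _select_sdp_cuda(128) == "flash_attn"     # already chose FlashAttention before the fix
    assert _select_sdp_cuda(256) == "flash_attn"     # newly routed to FlashAttention (FA-2 head_dim limit)
    assert _select_sdp_cuda(512) == "mem_efficient"  # still falls back to the memory-efficient kernel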
@@ -410,6 +410,17 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
         self.use_sdp_kernel = False


+class TestFlashAttentionAPITest5(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 256)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = False
+
+
 class TestMathAttentionAPITest(TestFlashAttentionAPI):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
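The new TestFlashAttentionAPITest5 case covers exactly this regime: head_dim = 256 in float16. Below is a minimal, hedged sketch of exercising that shape directly through the flash_attention helper these tests build on; the import path and argument names are assumed from the test usage, and running it requires a CUDA device with a FlashAttention-enabled Paddle build.

    import numpy as np
    import paddle
    from paddle.nn.functional.flash_attention import flash_attention

    # Shape taken from the new test case: (batch, seq_len, num_heads, head_dim) with head_dim = 256.
    paddle.set_device("gpu")
    shape = [8, 1024, 16, 256]
    q = paddle.to_tensor(np.random.random(shape), dtype="float16")
    k = paddle.to_tensor(np.random.random(shape), dtype="float16")
    v = paddle.to_tensor(np.random.random(shape), dtype="float16")

    # With the fixed selector, head_dim = 256 is dispatched to the FlashAttention (FA-2) kernel
    # instead of the memory-efficient fallback.
    out, _ = flash_attention(q, k, v, dropout=0.0, causal=False, return_softmax=False)
    print(out.shape)  # [8, 1024, 16, 256]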