Unverified · Commit 07e5552c authored by S Shijie, committed by GitHub

skip flash attn ut on Hopper (#55148)

* skip flash attn ut on Hopper

* minor change
Parent 8b9b9400
@@ -56,9 +56,25 @@ def attention_naive(q, k, v, causal=False):
     return paddle.transpose(o, [0, 2, 1, 3])
 
 
+is_sm75 = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 7
+    and paddle.device.cuda.get_device_capability()[1] == 5
+)
+is_sm8x = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 8
+    and paddle.device.cuda.get_device_capability()[1] >= 0
+)
+is_sm_supported = is_sm75 or is_sm8x
+
+
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11030
+    or not is_sm_supported,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3 "
+    "and device's compute capability must be 7.5 or 8.x",
 )
 class TestFlashAttentionAPI(unittest.TestCase):
     def setUp(self):
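For readers outside the Paddle test suite, here is a minimal, self-contained sketch of the same capability-gating pattern this commit applies: the test is skipped unless the GPU reports compute capability 7.5 (sm75) or 8.x (sm8x), which excludes Hopper (sm90). The helpers device_capability and is_supported_sm are illustrative names, not part of this commit; the only Paddle APIs assumed are paddle.is_compiled_with_cuda() and paddle.device.cuda.get_device_capability().

import unittest

import paddle


def device_capability():
    # Return (major, minor) compute capability, or None when CUDA is unavailable.
    if not paddle.is_compiled_with_cuda():
        return None
    return paddle.device.cuda.get_device_capability()


def is_supported_sm():
    # Mirrors the commit's check: accept sm75 or any sm8x device;
    # anything else (including Hopper's sm90) is treated as unsupported.
    cap = device_capability()
    if cap is None:
        return False
    major, minor = cap
    return (major == 7 and minor == 5) or major == 8


@unittest.skipIf(
    not is_supported_sm(),
    "requires a CUDA device with compute capability 7.5 or 8.x",
)
class TestCapabilityGatedCase(unittest.TestCase):
    # Placeholder body; the real suite runs FlashAttention forward/backward here.
    def test_placeholder(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()

Evaluating the capability once at import time, as the commit does with module-level is_sm75/is_sm8x flags, keeps the skipIf decorator cheap and makes the skip reason visible in the test report.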