diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index 8d6af72c3f4e33d2b6df14f9ae6b633dc2fa04ef..da5febf4de3e6816c69743a77f5c297a4cce88db 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -56,9 +56,25 @@ def attention_naive(q, k, v, causal=False):
     return paddle.transpose(o, [0, 2, 1, 3])


+is_sm75 = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 7
+    and paddle.device.cuda.get_device_capability()[1] == 5
+)
+is_sm8x = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 8
+    and paddle.device.cuda.get_device_capability()[1] >= 0
+)
+is_sm_supported = is_sm75 or is_sm8x
+
+
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11030
+    or not is_sm_supported,
+    "core is not compiled with CUDA, the CUDA version must be at least 11.3, "
+    "and the device's compute capability must be 7.5 or 8.x",
 )
 class TestFlashAttentionAPI(unittest.TestCase):
     def setUp(self):
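For readers skimming the patch, the sketch below shows the same compute-capability gate in a self-contained form. It is not part of the patch: `_flash_attention_supported` and `FlashAttentionSmokeTest` are hypothetical names, and it assumes a Paddle build that exposes the public `paddle.is_compiled_with_cuda()` and `paddle.device.cuda.get_device_capability()` APIs (the counterparts of the `core` helpers used in the test file).

```python
# Hypothetical, standalone sketch of the compute-capability gate used in the
# patch above; the names here are illustrative, not part of the Paddle test suite.
import unittest

import paddle


def _flash_attention_supported():
    """Return True only on CUDA builds running on an SM 7.5 or SM 8.x GPU."""
    if not paddle.is_compiled_with_cuda():
        return False
    major, minor = paddle.device.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8  # any 8.x minor revision is accepted
    return is_sm75 or is_sm8x


@unittest.skipIf(
    not _flash_attention_supported(),
    "flash attention needs CUDA and a GPU with compute capability 7.5 or 8.x",
)
class FlashAttentionSmokeTest(unittest.TestCase):
    def test_runs_only_on_supported_gpus(self):
        self.assertTrue(_flash_attention_supported())
```

Unlike the patch, the sketch unpacks the (major, minor) tuple once instead of calling `get_device_capability()` per comparison, and it omits the `get_cuda_version() < 11030` check, which relies on a helper local to the test file.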