From 07e5552c7f4ec8805f5e38b38ea5267edd74bab9 Mon Sep 17 00:00:00 2001
From: Shijie <505749828@qq.com>
Date: Thu, 6 Jul 2023 16:06:13 +0800
Subject: [PATCH] skip flash attn ut on Hopper (#55148)

* skip flash attn ut on Hopper

* minor change
---
 test/legacy_test/test_flash_attention.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index 8d6af72c3f4..da5febf4de3 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -56,9 +56,25 @@ def attention_naive(q, k, v, causal=False):
     return paddle.transpose(o, [0, 2, 1, 3])
 
 
+is_sm75 = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 7
+    and paddle.device.cuda.get_device_capability()[1] == 5
+)
+is_sm8x = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 8
+    and paddle.device.cuda.get_device_capability()[1] >= 0
+)
+is_sm_supported = is_sm75 or is_sm8x
+
+
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11030
+    or not is_sm_supported,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    "and device's compute capability must be 7.5 or 8.x",
 )
 class TestFlashAttentionAPI(unittest.TestCase):
     def setUp(self):
-- 
GitLab
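Note: the patch gates the flash-attention test class on the GPU's compute capability, allowing only SM 7.5 and SM 8.x and therefore skipping Hopper (SM 9.x) and older devices. Below is a minimal, self-contained sketch of that gating pattern, not the committed test code: the helper name `_device_supports_flash_attn` and the placeholder test are illustrative, and it uses `paddle.is_compiled_with_cuda()` in place of the test file's `core.is_compiled_with_cuda()`.

```python
# Illustrative sketch of the capability-based skip used in the patch.
# Assumes a Paddle install; the helper and test names are hypothetical.
import unittest

import paddle


def _device_supports_flash_attn():
    """Return True only for SM 7.5 or SM 8.x GPUs on a CUDA build of Paddle."""
    if not paddle.is_compiled_with_cuda():
        return False
    # (major, minor) compute capability of the current CUDA device.
    major, minor = paddle.device.cuda.get_device_capability()
    return (major, minor) == (7, 5) or major == 8


@unittest.skipIf(
    not _device_supports_flash_attn(),
    "flash attention tests require a GPU with compute capability 7.5 or 8.x",
)
class ExampleFlashAttentionTest(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(paddle.is_compiled_with_cuda())


if __name__ == "__main__":
    unittest.main()
```

In the patch itself, the `get_device_capability()[1] >= 0` clause in `is_sm8x` is always true for an SM 8.x device; it appears to be kept only for symmetry with the exact-match check in `is_sm75`.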