From effebd417549574f2c8fcf1765b517c16c825e90 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Thu, 1 Jun 2023 16:56:02 +0800 Subject: [PATCH] [ROCM] fix multihead_matmul (#54108) * [ROCM] fix multihead_matmul * skip bf16 uts * update --- .../operators/math/bert_encoder_functor.cu | 119 ++++++++++-------- paddle/fluid/pybind/place.cc | 12 +- paddle/phi/kernels/funcs/math_cuda_utils.h | 77 +++++++----- test/legacy_test/test_activation_op.py | 12 +- test/legacy_test/test_scale_op.py | 3 + test/legacy_test/test_softmax_op.py | 3 +- 6 files changed, 135 insertions(+), 91 deletions(-) diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index 8c5225edafd..c6bb45a0943 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -261,9 +261,9 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const int batch_size, const int head_num, const int seq_len, - const unsigned mask) { + const phi::funcs::warp_mask_t mask) { int qk_offset = blockIdx.x * seq_len; - assert(blockDim.x % 32 == 0); + assert(blockDim.x % WARP_SIZE == 0); float tmp = threadIdx.x < seq_len ? static_cast(qk_buf_[threadIdx.x + qk_offset] + @@ -281,15 +281,16 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, // HIP defined __HIP_NO_HALF_CONVERSIONS__ #ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd template <> -__global__ void SoftmaxKernelWithEltadd(half *qk_buf_, - const half *bias_qk_, - const int batch_size, - const int head_num, - const int seq_len, - const unsigned mask) { +__global__ void SoftmaxKernelWithEltadd( + half *qk_buf_, + const half *bias_qk_, + const int batch_size, + const int head_num, + const int seq_len, + const phi::funcs::warp_mask_t mask) { #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) int qk_offset = blockIdx.x * seq_len; - assert(blockDim.x % 32 == 0); + assert(blockDim.x % WARP_SIZE == 0); float tmp = threadIdx.x < seq_len ? static_cast(qk_buf_[threadIdx.x + qk_offset] + @@ -312,10 +313,10 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, const int batch_size, const int head_num, const int seq_len, - const unsigned mask) { + const phi::funcs::warp_mask_t mask) { int qk_offset = blockIdx.x * seq_len; int idx = threadIdx.x; - assert(blockDim.x % 32 == 0); + assert(blockDim.x % WARP_SIZE == 0); float2 tmp = idx < seq_len ? phi::funcs::ToFloat2(qk_buf_[idx + qk_offset] + @@ -335,19 +336,20 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, } template <> -__global__ void SoftmaxKernelWithEltadd2(half2 *qk_buf_, - const half2 *bias_qk_, - const int batch_size, - const int head_num, - const int seq_len, - const unsigned mask) { +__global__ void SoftmaxKernelWithEltadd2( + half2 *qk_buf_, + const half2 *bias_qk_, + const int batch_size, + const int head_num, + const int seq_len, + const phi::funcs::warp_mask_t mask) { // operator "+" of half only suppotted after cuda version 10.0 // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake #if defined(PADDLE_WITH_CUDA) && \ (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000) int qk_offset = blockIdx.x * seq_len; int idx = threadIdx.x; - assert(blockDim.x % 32 == 0); + assert(blockDim.x % WARP_SIZE == 0); float2 tmp = idx < seq_len ? phi::funcs::ToFloat2(qk_buf_[idx + qk_offset] + @@ -368,14 +370,15 @@ __global__ void SoftmaxKernelWithEltadd2(half2 *qk_buf_, } template -__global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf, - const T *bias_qk, - const int batch_size, - const int head_num, - const int seq_len, - const unsigned mask) { +__global__ void SoftmaxKernelWithEltaddForLarge( + T *qk_buf, + const T *bias_qk, + const int batch_size, + const int head_num, + const int seq_len, + const phi::funcs::warp_mask_t mask) { int qk_offset = blockIdx.x * seq_len; - assert(blockDim.x % 32 == 0); + assert(blockDim.x % WARP_SIZE == 0); T stride_max = -1e20f; for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -406,15 +409,16 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf, // HIP defined __HIP_NO_HALF_CONVERSIONS__ #ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd template <> -__global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf, - const half *bias_qk, - const int batch_size, - const int head_num, - const int seq_len, - const unsigned mask) { +__global__ void SoftmaxKernelWithEltaddForLarge( + half *qk_buf, + const half *bias_qk, + const int batch_size, + const int head_num, + const int seq_len, + const phi::funcs::warp_mask_t mask) { #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) int qk_offset = blockIdx.x * seq_len; - assert(blockDim.x % 32 == 0); + assert(blockDim.x % WARP_SIZE == 0); float stride_max = -1e20f; for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -444,14 +448,15 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf, #endif // @} End Half kernel: SoftmaxKernelWithEltadd template -__global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_, - const T *bias_qk_, - const int batch_size, - const int head_num, - const int seq_len, - const unsigned mask) { +__global__ void SoftmaxKernelWithEltaddForLarge2( + T *qk_buf_, + const T *bias_qk_, + const int batch_size, + const int head_num, + const int seq_len, + const phi::funcs::warp_mask_t mask) { int qk_offset = blockIdx.x * seq_len; - assert(blockDim.x % 32 == 0); + assert(blockDim.x % WARP_SIZE == 0); float2 stride_max = make_float2(-1e20f, -1e20f); for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -484,19 +489,20 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_, } template <> -__global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_, - const half2 *bias_qk_, - const int batch_size, - const int head_num, - const int seq_len, - const unsigned mask) { +__global__ void SoftmaxKernelWithEltaddForLarge2( + half2 *qk_buf_, + const half2 *bias_qk_, + const int batch_size, + const int head_num, + const int seq_len, + const phi::funcs::warp_mask_t mask) { // operator "+" of half only suppotted after cuda version 10.0 // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake #if defined(PADDLE_WITH_CUDA) && \ (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000) int qk_offset = blockIdx.x * seq_len; - assert(blockDim.x % 32 == 0); + assert(blockDim.x % WARP_SIZE == 0); float2 stride_max = make_float2(-1e20f, -1e20f); for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -637,7 +643,7 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_, } } - if (blockDim.x <= 32) { + if (blockDim.x <= WARP_SIZE) { phi::funcs::WarpReduceMaxV2(local_max); } else { phi::funcs::BlockReduceMaxV2(local_max); @@ -672,7 +678,7 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_, } } - if (blockDim.x <= 32) { + if (blockDim.x <= WARP_SIZE) { phi::funcs::WarpReduceSumV2(local_sum); } else { phi::funcs::BlockReduceSumV2(local_sum); @@ -761,7 +767,10 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, // Align block to 32, also limit seq_len to max block size. if (seq_len % 2 == 0) { - block = (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32; + block = + (seq_len <= (2 * WARP_SIZE)) + ? WARP_SIZE + : ((seq_len + (2 * WARP_SIZE - 1)) / (2 * WARP_SIZE)) * WARP_SIZE; if (std::is_same::value) { SoftmaxKernelWithEltadd2<<>>( reinterpret_cast(qk_buf_), @@ -780,7 +789,7 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, "cuda_arch<700")); #else dim3 grid(seq_len, batch_size, head_num); - dim3 block((seq_len / 2 + 31) / 32 * 32); + dim3 block((seq_len / 2 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE); SOFTMAX_KERNEL_WITH_MASK(1); #endif } else { @@ -794,7 +803,9 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, } } } else { - block = (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32; + block = (seq_len <= WARP_SIZE) + ? WARP_SIZE + : ((seq_len + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; SoftmaxKernelWithEltadd<<>>( qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK); } @@ -820,7 +831,7 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, "cuda_arch<700")); #else dim3 grid(seq_len, batch_size, head_num); - dim3 block((seq_len / 2 + 31) / 32 * 32); + dim3 block((seq_len / 2 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE); if (block.x > 0 && block.x <= 1024) { SOFTMAX_KERNEL_WITH_MASK(1); } else if (block.x <= 2048) { @@ -1176,8 +1187,8 @@ void SkipLayerNormFunctor::operator()(const int num, float eps, gpuStream_t stream) { int block = num / hidden; - if (hidden <= 32) { - const int threads = 32; + if (hidden <= WARP_SIZE) { + const int threads = WARP_SIZE; SkipLayerNormSmallKernel<<>>( num, hidden, input1, input2, output, scale, bias, eps); } else if (hidden <= 128) { diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index 477dfef6c72..7119aca5639 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -374,12 +374,20 @@ void BindPlace(pybind11::module &m) { // NOLINT .def("__str__", string::to_string); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool { - // Only GPUs with Compute Capability >= 53 support float16 + // Only GPUs with Compute Capability >= 53 support float16 +#ifdef PADDLE_WITH_HIP + return true; +#else return platform::GetGPUComputeCapability(place.device) >= 53; +#endif }); m.def("is_bfloat16_supported", [](const platform::CUDAPlace &place) -> bool { - // Only GPUs with Compute Capability >= 80 support bfloat16 + // Only GPUs with Compute Capability >= 80 support bfloat16 +#ifdef PADDLE_WITH_HIP + return false; +#else return platform::GetGPUComputeCapability(place.device) >= 80; +#endif }); #endif py::class_ xpuplace(m, "XPUPlace", R"DOC( diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index b493e2ac41b..1a6cca7f11a 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -163,12 +163,28 @@ struct KeyValuePair { } }; +// NOTE(wangran16): The warpSize variable is of type int and contains the warp +// size (in threads) for the target device. Note that all current NVIDIA devices +// return 32 for this variable, and all current AMD devices return 64. Device +// code should use the warpSize built-in to develop portable wave-aware code. +#ifdef PADDLE_WITH_HIP +#define FINAL_MASK 0xffffffffffffffffUL +#define HALF_WARP 32 +#define WARP_SIZE 64 +#define WARP_SIZE_WIDTH 6 +#define WARP_SIZE_WIDTH_MASK 0x3f +typedef u_int64_t warp_mask_t; +#else #define FINAL_MASK 0xffffffff #define HALF_WARP 16 #define WARP_SIZE 32 +#define WARP_SIZE_WIDTH 5 +#define WARP_SIZE_WIDTH_MASK 0x1f +typedef unsigned warp_mask_t; +#endif template -__inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { +__inline__ __device__ T WarpReduceSum(T val, warp_mask_t lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) val += __shfl_xor_sync(lane_mask, val, mask, warpSize); @@ -180,10 +196,10 @@ __inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { /* Calculate the sum of all elements in a block */ template -__inline__ __device__ T BlockReduceSum(T val, unsigned mask) { +__inline__ __device__ T BlockReduceSum(T val, warp_mask_t mask) { static __shared__ T shared[WARP_SIZE]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK; + int wid = threadIdx.x >> WARP_SIZE_WIDTH; val = WarpReduceSum(val, mask); @@ -193,7 +209,7 @@ __inline__ __device__ T BlockReduceSum(T val, unsigned mask) { __syncthreads(); // align block_span to warpSize - int block_span = (blockDim.x + warpSize - 1) >> 5; + int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH; val = (lane < block_span) ? shared[lane] : static_cast(0.0f); val = WarpReduceSum(val, mask); @@ -208,8 +224,8 @@ __inline__ __device__ T WarpReduceSumV2(T *val) { #pragma unroll for (int i = 0; i < NUM; i++) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, WARP_SIZE); } return (T)(0.0f); } @@ -217,8 +233,8 @@ __inline__ __device__ T WarpReduceSumV2(T *val) { template __inline__ __device__ T BlockReduceSumV2(T *val) { static __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK; + int wid = threadIdx.x >> WARP_SIZE_WIDTH; WarpReduceSumV2(val); @@ -231,7 +247,7 @@ __inline__ __device__ T BlockReduceSumV2(T *val) { __syncthreads(); - bool is_mask = threadIdx.x < (blockDim.x / 32.f); + bool is_mask = threadIdx.x < (blockDim.x / static_cast(WARP_SIZE)); #pragma unroll for (int i = 0; i < NUM; i++) { val[i] = is_mask ? shared[i][lane] : (T)(0.0f); @@ -241,7 +257,7 @@ __inline__ __device__ T BlockReduceSumV2(T *val) { } template -__inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) { +__inline__ __device__ T WarpReduceMax(T val, warp_mask_t lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); @@ -256,14 +272,15 @@ __inline__ __device__ T WarpReduceMaxV2(T *val) { #pragma unroll for (int i = 0; i < NUM; i++) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val[i] = + max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, WARP_SIZE)); } return (T)(0.0f); } template -__inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { +__inline__ __device__ T WarpReduceMin(T val, warp_mask_t lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); @@ -276,7 +293,7 @@ __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { /* Calculate the minimum of all elements in a warp when actual quantity of * threads are less than warpSize.*/ template -__inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) { +__inline__ __device__ T PartialWarpReduceMin(T val, warp_mask_t lane_mask) { #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) T warp_val = __shfl_sync(lane_mask, val, 0, warpSize); #else @@ -297,10 +314,10 @@ __inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) { /* Calculate the maximum of all elements in a block */ template -__inline__ __device__ T BlockReduceMax(T val, unsigned mask) { +__inline__ __device__ T BlockReduceMax(T val, warp_mask_t mask) { static __shared__ T shared[WARP_SIZE]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK; + int wid = threadIdx.x >> WARP_SIZE_WIDTH; val = WarpReduceMax(val, mask); @@ -309,7 +326,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) { __syncthreads(); // align block_span to warpSize - int block_span = (blockDim.x + warpSize - 1) >> 5; + int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH; val = (lane < block_span) ? shared[lane] : -1e10f; val = WarpReduceMax(val, mask); @@ -318,9 +335,9 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) { template __inline__ __device__ T BlockReduceMaxV2(T *val) { - static __shared__ T shared[32][NUM]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx + static __shared__ T shared[WARP_SIZE][NUM]; + int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK; // in-warp idx + int wid = threadIdx.x >> WARP_SIZE_WIDTH; // warp idx WarpReduceMaxV2(val); // get maxx in each warp @@ -335,7 +352,7 @@ __inline__ __device__ T BlockReduceMaxV2(T *val) { // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent // blockDim.x is not divided by 32 - bool is_mask = threadIdx.x < (blockDim.x / 32.f); + bool is_mask = threadIdx.x < (blockDim.x / static_cast(WARP_SIZE)); #pragma unroll for (int i = 0; i < NUM; i++) { val[i] = is_mask ? shared[lane][i] : (T)-1e20f; @@ -347,17 +364,17 @@ __inline__ __device__ T BlockReduceMaxV2(T *val) { /* Calculate the minimum of all elements in a block */ template -__inline__ __device__ T BlockReduceMin(T val, unsigned mask) { +__inline__ __device__ T BlockReduceMin(T val, warp_mask_t mask) { static __shared__ T shared[WARP_SIZE]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK; + int wid = threadIdx.x >> WARP_SIZE_WIDTH; val = WarpReduceMin(val, mask); if (lane == 0) shared[wid] = val; __syncthreads(); // align block_span to warpSize - int block_span = (blockDim.x + warpSize - 1) >> 5; + int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH; val = (lane < block_span) ? shared[lane] : 1e10f; val = WarpReduceMin(val, mask); @@ -367,11 +384,11 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) { /* Calculate the minimum of all elements in a warp when actual quantity of * threads are less than warpSize.*/ template -__inline__ __device__ T PartialBlockReduceMin(T val, unsigned mask) { +__inline__ __device__ T PartialBlockReduceMin(T val, warp_mask_t mask) { static __shared__ T shared[WARP_SIZE]; static __shared__ T min_value; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK; + int wid = threadIdx.x >> WARP_SIZE_WIDTH; val = PartialWarpReduceMin(val, mask); if (lane == 0) shared[wid] = val; diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index f081d2d872c..999510e6733 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -278,7 +278,8 @@ class TestSigmoid_ZeroDim(TestSigmoid): @unittest.skipIf( - not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), + "core is not compiled with CUDA", ) class TestSigmoidBF16(OpTest): def setUp(self): @@ -1237,7 +1238,8 @@ class TestSqrt_ZeroDim(TestSqrt): @unittest.skipIf( - not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), + "core is not compiled with CUDA", ) class TestSqrtBF16(OpTest): def setUp(self): @@ -3060,7 +3062,8 @@ class TestSquare_ZeroDim(TestSquare): @unittest.skipIf( - not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), + "core is not compiled with CUDA", ) class TestSquareBF16(OpTest): def setUp(self): @@ -3350,7 +3353,8 @@ class TestSoftplus_ZeroDim(TestSoftplus): @unittest.skipIf( - not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), + "core is not compiled with CUDA", ) class TestSoftplusBF16(OpTest): def setUp(self): diff --git a/test/legacy_test/test_scale_op.py b/test/legacy_test/test_scale_op.py index 137a31c89fb..40712745dec 100644 --- a/test/legacy_test/test_scale_op.py +++ b/test/legacy_test/test_scale_op.py @@ -154,6 +154,9 @@ class TestScaleFp16Op(TestScaleOp): self.check_grad(["X"], "Out") +@unittest.skipIf( + not core.is_compiled_with_rocm(), "core is not compiled with CUDA" +) class TestScaleBF16Op(OpTest): def setUp(self): self.op_type = "scale" diff --git a/test/legacy_test/test_softmax_op.py b/test/legacy_test/test_softmax_op.py index 9dc1f8d5408..abf753ef8e0 100644 --- a/test/legacy_test/test_softmax_op.py +++ b/test/legacy_test/test_softmax_op.py @@ -392,7 +392,8 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp): @unittest.skipIf( - not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(), + "core is not compiled with CUDA", ) class TestSoftmaxBF16Op(OpTest): def setUp(self): -- GitLab