diff --git a/paddle/fluid/operators/math/math_function_test.cu b/paddle/fluid/operators/math/math_function_test.cu index 442e62d563ebd40316d001914c93447c102cbf61..45628530864d84ae532527715eb6125cab3c1998 100644 --- a/paddle/fluid/operators/math/math_function_test.cu +++ b/paddle/fluid/operators/math/math_function_test.cu @@ -14,6 +14,8 @@ #include "gtest/gtest.h" #include "paddle/fluid/operators/math/math_function.h" +#include + void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, const std::vector& data) { PADDLE_ENFORCE_EQ(size, data.size()); @@ -22,6 +24,15 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, } } +bool is_fp16_supported(int device_id) { + cudaDeviceProp device_prop; + cudaDeviceProperties(&device_prop, device_id); + PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess); + int compute_capability = device_prop.major * 10 + device_prop.minor; + std::cout << "compute_capability is " << compute_capability << std::endl; + return compute_capability >= 53; +} + TEST(math_function, notrans_mul_trans_fp32) { using namespace paddle::framework; using namespace paddle::platform; @@ -62,6 +73,10 @@ TEST(math_function, notrans_mul_trans_fp16) { using namespace paddle::framework; using namespace paddle::platform; + if (!is_fp16_supported(0)) { + return; + } + Tensor input1; Tensor input1_gpu; Tensor input2_gpu; @@ -139,6 +154,10 @@ TEST(math_function, trans_mul_notrans_fp16) { using namespace paddle::framework; using namespace paddle::platform; + if (!is_fp16_supported(0)) { + return; + } + Tensor input1; Tensor input1_gpu; Tensor input2_gpu; @@ -237,6 +256,10 @@ TEST(math_function, gemm_notrans_cublas_fp16) { using namespace paddle::framework; using namespace paddle::platform; + if (!is_fp16_supported(0)) { + return; + } + Tensor input1; Tensor input2; Tensor input3; @@ -344,6 +367,10 @@ TEST(math_function, gemm_trans_cublas_fp16) { using namespace paddle::framework; using namespace paddle::platform; + if (!is_fp16_supported(0)) { + return; + } + Tensor input1; Tensor input2; Tensor input3;