diff --git a/paddle/fluid/operators/math/math_function_test.cc b/paddle/fluid/operators/math/math_function_test.cc index 25a9d0111eee45b28adff012b705cbfa2407d2b6..39f766ff8aa4afc1ead77639cc07898cce22cded 100644 --- a/paddle/fluid/operators/math/math_function_test.cc +++ b/paddle/fluid/operators/math/math_function_test.cc @@ -14,11 +14,20 @@ #include "paddle/fluid/operators/math/math_function.h" #include "gtest/gtest.h" +#include + TEST(math_function, gemm_notrans_cblas) { paddle::framework::Tensor input1; paddle::framework::Tensor input2; paddle::framework::Tensor input3; + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (GetCUDAComputeCapability(0) >= 53) { + std::cout << "Compute capability is " << GetCUDAComputeCapability(0) + << std::endl; + return; + } + int m = 2; int n = 3; int k = 3; diff --git a/paddle/fluid/operators/math/math_function_test.cu b/paddle/fluid/operators/math/math_function_test.cu index 45628530864d84ae532527715eb6125cab3c1998..2316df40fe834d0a6adbfc4e4f430509cb04bb1f 100644 --- a/paddle/fluid/operators/math/math_function_test.cu +++ b/paddle/fluid/operators/math/math_function_test.cu @@ -24,15 +24,6 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, } } -bool is_fp16_supported(int device_id) { - cudaDeviceProp device_prop; - cudaDeviceProperties(&device_prop, device_id); - PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess); - int compute_capability = device_prop.major * 10 + device_prop.minor; - std::cout << "compute_capability is " << compute_capability << std::endl; - return compute_capability >= 53; -} - TEST(math_function, notrans_mul_trans_fp32) { using namespace paddle::framework; using namespace paddle::platform; @@ -73,7 +64,10 @@ TEST(math_function, notrans_mul_trans_fp16) { using namespace paddle::framework; using namespace paddle::platform; - if (!is_fp16_supported(0)) { + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (GetCUDAComputeCapability(0) >= 53) { + std::cout << "Compute capability is " << GetCUDAComputeCapability(0) + << std::endl; return; } @@ -154,7 +148,8 @@ TEST(math_function, trans_mul_notrans_fp16) { using namespace paddle::framework; using namespace paddle::platform; - if (!is_fp16_supported(0)) { + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (GetCUDAComputeCapability(0) >= 53) { return; } @@ -256,7 +251,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) { using namespace paddle::framework; using namespace paddle::platform; - if (!is_fp16_supported(0)) { + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (GetCUDAComputeCapability(0) >= 53) { return; } @@ -367,7 +363,8 @@ TEST(math_function, gemm_trans_cublas_fp16) { using namespace paddle::framework; using namespace paddle::platform; - if (!is_fp16_supported(0)) { + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (GetCUDAComputeCapability(0) >= 53) { return; } diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index da4041bad0d82fe1c8c7a12fd0c7177e6dbddef3..dd70ff9ff574b32bc96a9e8255b1bf77a5cc84e4 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -33,6 +33,15 @@ int GetCUDADeviceCount() { return count; } +int GetCUDAComputeCapability(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); + cudaDeviceProp device_prop; + PADDLE_ENFORCE(cudaGetDeviceProperties(&device_prop, id), + "cudaGetDeviceProperties failed in " + "paddle::platform::GetCUDAComputeCapability"); + return device_prop.major * 10 + device_prop.minor; +} + int GetCUDAMultiProcessors(int id) { PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); int count; diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index c38ccf0f2ade1d2405177b541b33fd84283726ff..fa469fa77f5ca780da153cc87da8d04f239711f3 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -30,6 +30,9 @@ const std::string kEnvFractionGpuMemoryToUse = //! Get the total number of GPU devices in system. int GetCUDADeviceCount(); +//! Get the compute capability of the ith GPU (format: major * 10 + minor) +int GetCUDAComputeCapability(int i); + //! Get the MultiProcessors of the ith GPU. int GetCUDAMultiProcessors(int i);