From 7695b713e1b4a60084202a5b6ed10525d507821f Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Fri, 1 Nov 2019 22:29:46 +0800 Subject: [PATCH] gpu info query refine test=develop (#20904) --- paddle/fluid/platform/gpu_info.cc | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index add5cabd444..c8d312c61bd 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -32,6 +32,11 @@ constexpr static float fraction_reserve_gpu_memory = 0.05f; namespace paddle { namespace platform { +/* Here is a very simple CUDA “pro tip”: cudaDeviceGetAttribute() is a much +faster way to query device properties. You can see details in +https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/ +*/ + inline std::string CudaErrorWebsite() { return "Please see detail in https://docs.nvidia.com/cuda/cuda-runtime-api" "/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3f51e3575c217824" @@ -75,14 +80,23 @@ int GetCUDADeviceCount() { int GetCUDAComputeCapability(int id) { PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); - cudaDeviceProp device_prop; - auto error_code = cudaGetDeviceProperties(&device_prop, id); - PADDLE_ENFORCE( - error_code, - "cudaGetDeviceProperties failed in " + int major, minor; + + auto major_error_code = + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id); + auto minor_error_code = + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id); + PADDLE_ENFORCE_EQ( + major_error_code, 0, + "cudaDevAttrComputeCapabilityMajor failed in " "paddle::platform::GetCUDAComputeCapability, error code : %d, %s", - error_code, CudaErrorWebsite()); - return device_prop.major * 10 + device_prop.minor; + major_error_code, CudaErrorWebsite()); + PADDLE_ENFORCE_EQ( + minor_error_code, 0, + "cudaDevAttrComputeCapabilityMinor failed in " + "paddle::platform::GetCUDAComputeCapability, error code : %d, %s", + minor_error_code, CudaErrorWebsite()); + return major * 10 + minor; } int GetCUDARuntimeVersion(int id) { -- GitLab