From 3b44b849d318bc60e9f6ceb4915f7262172c45e5 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 11 Mar 2018 22:27:38 -0700 Subject: [PATCH] address comments --- paddle/fluid/operators/math/math_function.cu | 9 +++++ .../operators/math/math_function_test.cu | 40 +++++++++---------- paddle/fluid/platform/device_context.cc | 5 +++ paddle/fluid/platform/device_context.h | 3 ++ 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 36655508be..3abbcdb71d 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -45,6 +45,9 @@ void gemm( const half* h_B = reinterpret_cast(B); half* h_C = reinterpret_cast(C); + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53, + "cublas Hgemm requires GPU compute capability >= 53"); PADDLE_ENFORCE(platform::dynload::cublasHgemm( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, N)); @@ -106,6 +109,9 @@ void gemm( const half* h_B = reinterpret_cast(B); half* h_C = reinterpret_cast(C); + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53, + "cublas Hgemm requires GPU compute capability >= 53"); PADDLE_ENFORCE(platform::dynload::cublasHgemm( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, ldc)); @@ -251,6 +257,9 @@ void batched_gemm( const half* h_B = reinterpret_cast(B); half* h_C = reinterpret_cast(C); + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53, + "cublas Hgemm requires GPU compute capability >= 53"); PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount)); diff --git a/paddle/fluid/operators/math/math_function_test.cu b/paddle/fluid/operators/math/math_function_test.cu index 49da6f69d6..8982d9d066 100644 --- a/paddle/fluid/operators/math/math_function_test.cu +++ b/paddle/fluid/operators/math/math_function_test.cu @@ -62,11 +62,6 @@ TEST(math_function, notrans_mul_trans_fp16) { using namespace paddle::framework; using namespace paddle::platform; - // fp16 GEMM in cublas requires GPU compute capability >= 53 - if (GetCUDAComputeCapability(0) < 53) { - return; - } - Tensor input1; Tensor input1_gpu; Tensor input2_gpu; @@ -77,6 +72,11 @@ TEST(math_function, notrans_mul_trans_fp16) { CUDAPlace gpu_place(0); CUDADeviceContext context(gpu_place); + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (context.GetComputeCapability() < 53) { + return; + } + float16* input1_ptr = input1.mutable_data({2, 3}, cpu_place); fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); @@ -144,11 +144,6 @@ TEST(math_function, trans_mul_notrans_fp16) { using namespace paddle::framework; using namespace paddle::platform; - // fp16 GEMM in cublas requires GPU compute capability >= 53 - if (GetCUDAComputeCapability(0) < 53) { - return; - } - Tensor input1; Tensor input1_gpu; Tensor input2_gpu; @@ -159,6 +154,11 @@ TEST(math_function, trans_mul_notrans_fp16) { CUDAPlace gpu_place(0); CUDADeviceContext context(gpu_place); + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (context.GetComputeCapability() < 53) { + return; + } + float16* input1_ptr = input1.mutable_data({2, 3}, cpu_place); fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); @@ -247,11 +247,6 @@ TEST(math_function, gemm_notrans_cublas_fp16) { using namespace paddle::framework; using namespace paddle::platform; - // fp16 GEMM in cublas requires GPU compute capability >= 53 - if (GetCUDAComputeCapability(0) < 53) { - return; - } - Tensor input1; Tensor input2; Tensor input3; @@ -263,6 +258,11 @@ TEST(math_function, gemm_notrans_cublas_fp16) { CUDAPlace gpu_place(0); CUDADeviceContext context(gpu_place); + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (context.GetComputeCapability() < 53) { + return; + } + int m = 2; int n = 3; int k = 3; @@ -359,11 +359,6 @@ TEST(math_function, gemm_trans_cublas_fp16) { using namespace paddle::framework; using namespace paddle::platform; - // fp16 GEMM in cublas requires GPU compute capability >= 53 - if (GetCUDAComputeCapability(0) < 53) { - return; - } - Tensor input1; Tensor input2; Tensor input3; @@ -375,6 +370,11 @@ TEST(math_function, gemm_trans_cublas_fp16) { CUDAPlace gpu_place(0); CUDADeviceContext context(gpu_place); + // fp16 GEMM in cublas requires GPU compute capability >= 53 + if (context.GetComputeCapability() < 53) { + return; + } + int m = 2; int n = 3; int k = 3; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index bb9fbd468f..98b4178177 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -127,6 +127,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { SetDeviceId(place_.device); + compute_capability = GetCUDAComputeCapability(place_.device); multi_process = GetCUDAMultiProcessors(place_.device); max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); @@ -162,6 +163,10 @@ void CUDADeviceContext::Wait() const { PADDLE_ENFORCE(cudaGetLastError()); } +int CUDADeviceContext::GetComputeCapability() const { + return compute_capability; +} + int CUDADeviceContext::GetMaxPhysicalThreadCount() const { return multi_process * max_threads_per_mp; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index e779644190..500891ac7a 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -79,6 +79,8 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return place in the device context. */ Place GetPlace() const override; + int GetComputeCapability() const; + /*! \brief Return the max physical thread count in the device context */ int GetMaxPhysicalThreadCount() const; @@ -104,6 +106,7 @@ class CUDADeviceContext : public DeviceContext { cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; + int compute_capability; int multi_process; int max_threads_per_mp; }; -- GitLab