Commit 3b44b849 authored by Kexin Zhao

address comments

Parent 95de7617
@@ -45,6 +45,9 @@ void gemm<platform::CUDADeviceContext, float16>(
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, N));
@@ -106,6 +109,9 @@ void gemm<platform::CUDADeviceContext, float16>(
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, ldc));
@@ -251,6 +257,9 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
......
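The three hunks above add the same guard: cuBLAS only provides a native half-precision GEMM (cublasHgemm / cublasHgemmStridedBatched) on GPUs with compute capability 5.3 or higher, so the float16 specializations now fail loudly instead of invoking an unsupported routine. Below is a minimal standalone sketch of that guard outside Paddle, assuming raw CUDA and cuBLAS; the helper name `CapabilityOf` is hypothetical and only mirrors the `major * 10 + minor` encoding implied by the `>= 53` check.

```cpp
// Standalone sketch (not Paddle code): guard a cublasHgemm call on compute
// capability >= 5.3, encoded as major * 10 + minor to match the ">= 53" check.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical helper mirroring what GetCUDAComputeCapability(device_id) returns.
int CapabilityOf(int device_id) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  return prop.major * 10 + prop.minor;
}

void HalfGemm(cublasHandle_t handle, int m, int n, int k, const __half* a,
              const __half* b, __half* c) {
  if (CapabilityOf(0) < 53) {
    std::fprintf(stderr, "cublasHgemm needs compute capability >= 53\n");
    return;  // a real implementation could fall back to a float32 path here
  }
  const __half alpha = __float2half(1.0f);
  const __half beta = __float2half(0.0f);
  // Column-major GEMM: C = alpha * A * B + beta * C.
  cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, a, m, b, k,
              &beta, c, m);
}
```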
@@ -62,11 +62,6 @@ TEST(math_function, notrans_mul_trans_fp16) {
using namespace paddle::framework;
using namespace paddle::platform;
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (GetCUDAComputeCapability(0) < 53) {
return;
}
Tensor input1;
Tensor input1_gpu;
Tensor input2_gpu;
@@ -77,6 +72,11 @@ TEST(math_function, notrans_mul_trans_fp16) {
CUDAPlace gpu_place(0);
CUDADeviceContext context(gpu_place);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (context.GetComputeCapability() < 53) {
return;
}
float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
@@ -144,11 +144,6 @@ TEST(math_function, trans_mul_notrans_fp16) {
using namespace paddle::framework;
using namespace paddle::platform;
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (GetCUDAComputeCapability(0) < 53) {
return;
}
Tensor input1;
Tensor input1_gpu;
Tensor input2_gpu;
@@ -159,6 +154,11 @@ TEST(math_function, trans_mul_notrans_fp16) {
CUDAPlace gpu_place(0);
CUDADeviceContext context(gpu_place);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (context.GetComputeCapability() < 53) {
return;
}
float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
@@ -247,11 +247,6 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
using namespace paddle::framework;
using namespace paddle::platform;
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (GetCUDAComputeCapability(0) < 53) {
return;
}
Tensor input1;
Tensor input2;
Tensor input3;
@@ -263,6 +258,11 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
CUDAPlace gpu_place(0);
CUDADeviceContext context(gpu_place);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (context.GetComputeCapability() < 53) {
return;
}
int m = 2;
int n = 3;
int k = 3;
@@ -359,11 +359,6 @@ TEST(math_function, gemm_trans_cublas_fp16) {
using namespace paddle::framework;
using namespace paddle::platform;
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (GetCUDAComputeCapability(0) < 53) {
return;
}
Tensor input1;
Tensor input2;
Tensor input3;
@@ -375,6 +370,11 @@ TEST(math_function, gemm_trans_cublas_fp16) {
CUDAPlace gpu_place(0);
CUDADeviceContext context(gpu_place);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (context.GetComputeCapability() < 53) {
return;
}
int m = 2;
int n = 3;
int k = 3;
......
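Each of the four fp16 test hunks above moves the early-return check from before the tensors are declared to after the CUDADeviceContext is constructed, and switches from the free function GetCUDAComputeCapability(0) to the new context.GetComputeCapability() accessor. A distilled sketch of the shared pattern, with a hypothetical test name standing in for the real ones:

```cpp
// Distilled sketch of the skip pattern the fp16 tests now share: the
// capability is read from the already-constructed device context instead of
// from GetCUDAComputeCapability(0).
TEST(math_function, fp16_capability_skip_pattern) {
  using namespace paddle::platform;
  CUDAPlace gpu_place(0);
  CUDADeviceContext context(gpu_place);
  // fp16 GEMM in cublas requires GPU compute capability >= 53
  if (context.GetComputeCapability() < 53) {
    return;
  }
  // ... the actual fp16 GEMM assertions follow here ...
}
```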
@@ -127,6 +127,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
SetDeviceId(place_.device);
compute_capability = GetCUDAComputeCapability(place_.device);
multi_process = GetCUDAMultiProcessors(place_.device);
max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
@@ -162,6 +163,10 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaGetLastError());
}
int CUDADeviceContext::GetComputeCapability() const {
return compute_capability;
}
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
return multi_process * max_threads_per_mp;
}
......
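The device_context.cc change queries the compute capability once in the constructor and caches it in the new member, so GetComputeCapability() is a cheap field read rather than a CUDA query per call. The actual GetCUDAComputeCapability helper is defined elsewhere in Paddle's platform code and may differ, but a plausible standalone implementation looks like this:

```cpp
// Sketch of one way a helper like GetCUDAComputeCapability could be written;
// this is an assumption for illustration, not the Paddle implementation.
#include <cuda_runtime.h>

int QueryComputeCapability(int device_id) {
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
  return major * 10 + minor;  // e.g. 53 for sm_53, 60 for sm_60
}
```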
@@ -79,6 +79,8 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return place in the device context. */
Place GetPlace() const override;
int GetComputeCapability() const;
/*! \brief Return the max physical thread count in the device context */
int GetMaxPhysicalThreadCount() const;
@@ -104,6 +106,7 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_;
int compute_capability;
int multi_process;
int max_threads_per_mp;
};
......
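With the accessor declared on CUDADeviceContext, any code holding the context can branch on the cached capability instead of querying the device itself. A hypothetical caller (the function name is illustrative only):

```cpp
// Hypothetical dispatch sketch: pick the fp16 path only when the capability
// cached by the device context allows it.
void GemmDispatch(const paddle::platform::CUDADeviceContext& context) {
  if (context.GetComputeCapability() >= 53) {
    // safe to use the float16 specialization backed by cublasHgemm
  } else {
    // per the TODO in the diff: processing for capability < 53 is still pending
  }
}
```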