From d00bd9eb7256b2022b11491374e94adfc82e720c Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Fri, 6 Apr 2018 10:19:31 -0700 Subject: [PATCH] Update the cuda API and enable tensor core for GEMM (#9622) * change from hgemm to gemmEx * fix cpplint --- paddle/fluid/operators/math/math_function.cu | 33 ++++++++++++++------ paddle/fluid/platform/dynload/cublas.cc | 4 +++ paddle/fluid/platform/dynload/cublas.h | 14 +++++++-- 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 1e909db52..82e129431 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -39,18 +39,33 @@ void gemm( cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const half h_alpha = static_cast(alpha); - const half h_beta = static_cast(beta); - const half* h_A = reinterpret_cast(A); - const half* h_B = reinterpret_cast(B); - half* h_C = reinterpret_cast(C); + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); // TODO(kexinzhao): add processing code for compute capability < 53 case PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53, - "cublas Hgemm requires GPU compute capability >= 53"); - PADDLE_ENFORCE(platform::dynload::cublasHgemm( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, - h_A, lda, &h_beta, h_C, N)); + "cublas fp16 gemm requires GPU compute capability >= 53"); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + if (context.GetComputeCapability() >= 70) { + PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(), + CUBLAS_TENSOR_OP_MATH)); + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } else { + PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(), + CUBLAS_DEFAULT_MATH)); + } +#endif + + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. + PADDLE_ENFORCE(platform::dynload::cublasGemmEx( + context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B, + CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, + CUDA_R_32F, algo)); } template <> diff --git a/paddle/fluid/platform/dynload/cublas.cc b/paddle/fluid/platform/dynload/cublas.cc index e90e3105f..eb541579a 100644 --- a/paddle/fluid/platform/dynload/cublas.cc +++ b/paddle/fluid/platform/dynload/cublas.cc @@ -24,6 +24,10 @@ void *cublas_dso_handle = nullptr; CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP); +#ifdef CUBLAS_BLAS_ROUTINE_EACH_R2 +CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP); +#endif + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index fa9041134..3b8d192b6 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -15,8 +15,9 @@ limitations under the License. */ #pragma once #include +#include #include -#include +#include // NOLINT #include "paddle/fluid/platform/dynload/dynamic_loader.h" namespace paddle { @@ -70,6 +71,7 @@ extern void *cublas_dso_handle; __macro(cublasDgemm_v2); \ __macro(cublasHgemm); \ __macro(cublasSgemmEx); \ + __macro(cublasGemmEx); \ __macro(cublasSgeam_v2); \ __macro(cublasDgeam_v2); \ __macro(cublasCreate_v2); \ @@ -89,9 +91,15 @@ extern void *cublas_dso_handle; __macro(cublasSgetrfBatched); \ __macro(cublasSgetriBatched); \ __macro(cublasDgetrfBatched); \ - __macro(cublasDgetriBatched) + __macro(cublasDgetriBatched); -CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP); +CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) + +// APIs available after CUDA 9.0 +#if CUDA_VERSION >= 9000 +#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) __macro(cublasSetMathMode); +CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) +#endif #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP } // namespace dynload -- GitLab