From 00b9e9a1357bb3fa6e6adceb4e650d9f6424aa2a Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 22 Nov 2018 20:40:56 +0800 Subject: [PATCH] Refine cublas to support CUBLAS_TENSOR_OP_MATH (#13929) * refine cublase test=develop * code refine * refine cublas * add GEMME_EX * add enable_cublas_tensor_op_math doc and add cublasCall test=develop * fix CublasCall for cuda version test=develop * fix error test=develop * fix GEMM_EX to be compatible with gcc 4.8 test=develop * add GEMM_EX test=develop * to compatiable with gcc4.8 test=develop --- paddle/fluid/operators/math/blas_impl.cu.h | 206 +++++++++++++++++---- paddle/fluid/platform/device_context.h | 47 +++++ paddle/fluid/platform/dynload/cublas.h | 26 ++- paddle/fluid/platform/gpu_info.cc | 20 ++ paddle/fluid/platform/gpu_info.h | 3 + python/paddle/fluid/__init__.py | 3 +- 6 files changed, 256 insertions(+), 49 deletions(-) diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index d84c88cb3b..d35073029a 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -16,6 +16,9 @@ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/dynload/cublas.h" +#include "paddle/fluid/platform/gpu_info.h" + +DECLARE_bool(enable_cublas_tensor_op_math); namespace paddle { namespace operators { @@ -42,11 +45,44 @@ struct CUBlas { } template - static void GEMM_BATCH(ARGS... args) { + static void GEMM_STRIDED_BATCH(ARGS... args) { #if CUDA_VERSION >= 8000 PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...)); #else PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5"); +#endif + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, + cublasOperation_t transa, cublasOperation_t transb, int m, + int n, int k, const float *alpha, const void *A, + cudaDataType_t Atype, int lda, const void *B, + cudaDataType_t Btype, int ldb, const float *beta, void *C, + cudaDataType_t Ctype, int ldc) { + // Because the gcc 4.8 doesn't expand template parameter pack that + // appears in a lambda-expression, I can not use template parameter pack + // here. + auto cublas_call = [&]() { +#if CUDA_VERSION >= 8000 + VLOG(5) << "use_tensor_op_math: " + << (platform::TensorCoreAvailable() ? "True" : "False"); + PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( + dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, + lda, B, Btype, ldb, beta, C, Ctype, ldc)); +#else + PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); +#endif + }; + +#if CUDA_VERSION >= 9000 + // NOTES: To use Tensor Core, we should change the cublas config, + // but the cublas may be hold by multi-thread. + dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); +#else + cublas_call(); #endif } }; @@ -69,13 +105,18 @@ struct CUBlas { } template - static void GEMM_BATCH(ARGS... args) { + static void GEMM_STRIDED_BATCH(ARGS... args) { #if CUDA_VERSION >= 8000 PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...)); #else PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5"); #endif } + + template + static void GEMM_EX(ARGS... args) { + PADDLE_THROW("Currently there are not cublasDgemmEx."); + } }; template <> @@ -96,14 +137,16 @@ struct CUBlas { reinterpret_cast<__half *>(C), ldc)); } - static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float16 *alpha, const float16 *A, int lda, - long long int strideA, const float16 *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const float16 *beta, float16 *C, int ldc, - long long int strideC, // NOLINT - int batchCount) { + static void GEMM_STRIDED_BATCH(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float16 *alpha, const float16 *A, + int lda, long long int strideA, // NOLINT + const float16 *B, // NOLINT + int ldb, long long int strideB, // NOLINT + const float16 *beta, float16 *C, int ldc, + long long int strideC, // NOLINT + int batchCount) { #if CUDA_VERSION >= 8000 PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( handle, transa, transb, m, n, k, @@ -114,6 +157,45 @@ struct CUBlas { ldc, strideC, batchCount)); #else PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5"); +#endif + } + + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. + // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode + template + static void GEMM_EX(platform::CUDADeviceContext *dev_ctx, + cublasOperation_t transa, cublasOperation_t transb, int m, + int n, int k, const void *alpha, const void *A, + cudaDataType_t Atype, int lda, const void *B, + cudaDataType_t Btype, int ldb, const void *beta, void *C, + cudaDataType_t Ctype, int ldc, + cudaDataType_t computeType) { + auto cublas_call = [&]() { +#if CUDA_VERSION >= 8000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + bool use_tensor_op_math = platform::TensorCoreAvailable(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); +#endif // CUDA_VERSION >= 9000 + + PADDLE_ENFORCE(platform::dynload::cublasGemmEx( + dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, + lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); +#else + PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); +#endif + }; + +#if CUDA_VERSION >= 9000 + // NOTES: To use Tensor Core, we should change the cublas config, + // but the cublas may be hold by multi-thread. + dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); +#else + cublas_call(); #endif } }; @@ -133,8 +215,21 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, A, lda, &beta, C, N); +#if CUDA_VERSION >= 8000 + if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B, + CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C, + CUDA_R_32F, N); + } else { +#endif // CUDA_VERSION >= 8000 + + CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, + &alpha, B, ldb, A, lda, &beta, C, N); + +#if CUDA_VERSION >= 8000 + } +#endif // CUDA_VERSION >= 8000 } template <> @@ -157,30 +252,18 @@ inline void Blas::GEMM( PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53, "cublas fp16 gemm requires GPU compute capability >= 53"); -#if CUDA_VERSION >= 8000 float h_alpha = static_cast(alpha); float h_beta = static_cast(beta); - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - if (context_.GetComputeCapability() >= 70) { - PADDLE_ENFORCE(platform::dynload::cublasSetMathMode( - context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH)); - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } else { - PADDLE_ENFORCE(platform::dynload::cublasSetMathMode( - context_.cublas_handle(), CUBLAS_DEFAULT_MATH)); - } -#endif // CUDA_VERSION >= 9000 - +#if CUDA_VERSION >= 8000 // cublasHgemm does true FP16 computation which is slow for non-Volta // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: // input/output in fp16, computation in fp32, which can also be accelerated // using tensor cores in volta GPUs. - PADDLE_ENFORCE(platform::dynload::cublasGemmEx( - context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B, - CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, - CUDA_R_32F, algo)); + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX( + &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16F, ldb, A, + CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, @@ -199,8 +282,38 @@ void Blas::GEMM(bool transA, bool transB, int M, // the cblas convention. cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, - B, ldb, A, lda, &beta, C, ldc); + +#if CUDA_VERSION >= 8000 + if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, cuTransB, cuTransA, N, M, K, &alpha, B, + CUDA_R_32F, ldb, A, CUDA_R_32F, lda, &beta, C, + CUDA_R_32F, ldc); + } else { +#endif // CUDA_VERSION >= 8000 + + CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, + &alpha, B, ldb, A, lda, &beta, C, ldc); + +#if CUDA_VERSION >= 8000 + } +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM( + bool transA, bool transB, int M, int N, int K, platform::float16 alpha, + const platform::float16 *A, int lda, const platform::float16 *B, int ldb, + platform::float16 beta, platform::float16 *C, int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + + CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, + N, M, K, &alpha, B, ldb, A, lda, &beta, C, + ldc); } template <> @@ -238,9 +351,34 @@ void Blas::BatchedGEMM( (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; const int64_t strideC = M * N; - CUBlas::GEMM_BATCH(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc, - strideC, batchCount); +#if CUDA_VERSION >= 9010 + if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + auto cublas_call = [&]() { + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = platform::TensorCoreAvailable(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); + + PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( + context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, + CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, + CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); + }; + auto &dev_ctx = const_cast(context_); + dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); + } else { +#endif // CUDA_VERSION >= 9010 + + CUBlas::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA, + N, M, K, &alpha, B, ldb, strideB, A, lda, + strideA, &beta, C, ldc, strideC, batchCount); + +#if CUDA_VERSION >= 9010 + } +#endif // CUDA_VERSION >= 9010 } } // namespace math diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 9a9018cdea..3edd727978 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -143,6 +143,39 @@ class CudnnWorkspaceHandle { std::unique_ptr> guard_; }; +#if CUDA_VERSION >= 9000 +class ScopedCublasMathMode { + public: + ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode) + : handle_(handle) { + need_reset = false; + PADDLE_ENFORCE( + platform::dynload::cublasGetMathMode(handle_, &old_math_mode_), + "Failed to get old cublas math mode"); + if (old_math_mode_ != new_math_mode) { + PADDLE_ENFORCE( + platform::dynload::cublasSetMathMode(handle_, new_math_mode), + "Failed to set old cublas math mode"); + need_reset = true; + } + } + + ~ScopedCublasMathMode() { + if (need_reset) { + PADDLE_ENFORCE( + platform::dynload::cublasSetMathMode(handle_, old_math_mode_), + "Failed to set old cublas math mode"); + } + } + + private: + cublasHandle_t handle_; + cublasMath_t old_math_mode_; + bool need_reset; +}; + +#endif + class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(CUDAPlace place); @@ -199,6 +232,18 @@ class CUDADeviceContext : public DeviceContext { callback_manager_->Wait(); } +#if CUDA_VERSION >= 9000 + /*! \brief CublasCall may need to change cublas's config, + * but the cublas may be hold by multi-thread, so we should + * add lock here. */ + template + void CublasCall(Callback callback, cublasMath_t new_math) { + std::lock_guard guard(cublas_mtx_); + ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math); + callback(); + } +#endif + private: CUDAPlace place_; @@ -220,6 +265,8 @@ class CUDADeviceContext : public DeviceContext { // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes mutable std::mutex callback_mtx_; std::unique_ptr callback_manager_; + + mutable std::mutex cublas_mtx_; }; template <> diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index 4ea0cd7283..ff80bd525c 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -61,9 +61,6 @@ extern void *cublas_dso_handle; extern DynLoad__##__name __name #endif -#define DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) \ - DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) - #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasSaxpy_v2); \ __macro(cublasDaxpy_v2); \ @@ -93,22 +90,23 @@ CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) // APIs available after CUDA 8.0 #if CUDA_VERSION >= 8000 -#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \ - __macro(cublasGemmEx); \ - __macro(cublasSgemmStridedBatched); \ - __macro(cublasDgemmStridedBatched); \ - __macro(cublasCgemmStridedBatched); \ - __macro(cublasZgemmStridedBatched); \ - __macro(cublasHgemmStridedBatched); - -CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmEx); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmStridedBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmStridedBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmStridedBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmStridedBatched); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasHgemmStridedBatched); #endif // APIs available after CUDA 9.0 #if CUDA_VERSION >= 9000 -#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) __macro(cublasSetMathMode); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSetMathMode); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGetMathMode); +#endif -CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) +#if CUDA_VERSION >= 9010 +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmBatchedEx); +DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmStridedBatchedEx); #endif #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index c78f159ad2..833d48347f 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -26,6 +26,16 @@ DEFINE_double(fraction_of_gpu_memory_to_use, 0.92, "additional trunks of the same size will be requested from gpu " "until the gpu has no memory left for another trunk."); +DEFINE_bool( + enable_cublas_tensor_op_math, false, + "The enable_cublas_tensor_op_math indicate whether to use Tensor Core, " + "but it may loss precision. Currently, There are two CUDA libraries that" + " use Tensor Cores, cuBLAS and cuDNN. cuBLAS uses Tensor Cores to speed up" + " GEMM computations(the matrices must be either half precision or single " + "precision); cuDNN uses Tensor Cores to speed up both convolutions(the " + "input and output must be half precision) and recurrent neural networks " + "(RNNs)."); + namespace paddle { namespace platform { @@ -64,6 +74,16 @@ int GetCUDADriverVersion(int id) { return driver_version; } +bool TensorCoreAvailable() { +#if CUDA_VERSION >= 9000 + int device = GetCurrentDeviceId(); + int driver_version = GetCUDAComputeCapability(device); + return driver_version >= 70; +#else + return false; +#endif +} + int GetCUDAMultiProcessors(int id) { PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); int count; diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index be44158431..6a0b3c8e02 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -35,6 +35,9 @@ int GetCUDARuntimeVersion(int id); //! Get the driver version of the ith GPU int GetCUDADriverVersion(int id); +//! Wheter the current device support TensorCore +bool TensorCoreAvailable(); + //! Get the MultiProcessors of the ith GPU. int GetCUDAMultiProcessors(int i); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 543acf2d34..3c092dee34 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -133,7 +133,8 @@ def __bootstrap__(): if core.is_compiled_with_cuda(): read_env_flags += [ 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', - 'conv_workspace_size_limit', 'cudnn_exhaustive_search' + 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', + 'cudnn_exhaustive_search' ] core.init_gflags([sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) -- GitLab