diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index d35073029a3440d8a17e383ce97fcfc582663888..a4fb1cdcd970f8c8e961f633b5cbf71fb67090be 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -62,27 +62,17 @@ struct CUBlas { cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, int ldb, const float *beta, void *C, cudaDataType_t Ctype, int ldc) { - // Because the gcc 4.8 doesn't expand template parameter pack that - // appears in a lambda-expression, I can not use template parameter pack - // here. - auto cublas_call = [&]() { +// Because the gcc 4.8 doesn't expand template parameter pack that +// appears in a lambda-expression, I can not use template parameter pack +// here. #if CUDA_VERSION >= 8000 - VLOG(5) << "use_tensor_op_math: " - << (platform::TensorCoreAvailable() ? "True" : "False"); - PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( - dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc)); + VLOG(5) << "use_tensor_op_math: " + << (dev_ctx->tensor_core_available() ? "True" : "False"); + PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( + dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k, + alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)); #else - PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); -#endif - }; - -#if CUDA_VERSION >= 9000 - // NOTES: To use Tensor Core, we should change the cublas config, - // but the cublas may be hold by multi-thread. - dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); -#else - cublas_call(); + PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); #endif } }; @@ -170,32 +160,23 @@ struct CUBlas { cudaDataType_t Btype, int ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType) { - auto cublas_call = [&]() { #if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 - bool use_tensor_op_math = platform::TensorCoreAvailable(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); #endif // CUDA_VERSION >= 9000 - PADDLE_ENFORCE(platform::dynload::cublasGemmEx( - dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); -#else - PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); -#endif - }; - -#if CUDA_VERSION >= 9000 - // NOTES: To use Tensor Core, we should change the cublas config, - // but the cublas may be hold by multi-thread. - dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); + PADDLE_ENFORCE(platform::dynload::cublasGemmEx( + dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k, + alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, + algo)); #else - cublas_call(); + PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); #endif } }; @@ -353,22 +334,18 @@ void Blas::BatchedGEMM( #if CUDA_VERSION >= 9010 if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto cublas_call = [&]() { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = platform::TensorCoreAvailable(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - - PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( - context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, - CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, - CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); - }; - auto &dev_ctx = const_cast(context_); - dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); + + PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( + context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M, + K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, + &beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); } else { #endif // CUDA_VERSION >= 9010 diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 022afb686b29c2c493cfd05600ee372470cbc710..e40928fe5d2b678df1b895a74fe401e95c04b08b 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -247,6 +247,18 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); + + if (TensorCoreAvailable()) { +#if CUDA_VERSION >= 9000 + cublas_tensor_core_handle_.reset(new cublasHandle_t()); + PADDLE_ENFORCE(dynload::cublasCreate(cublas_tensor_core_handle_.get())); + PADDLE_ENFORCE( + dynload::cublasSetStream(*cublas_tensor_core_handle_, stream_)); + PADDLE_ENFORCE(dynload::cublasSetMathMode(*cublas_tensor_core_handle_, + CUBLAS_TENSOR_OP_MATH)); +#endif + } + if (dynload::HasCUDNN()) { cudnn_holder_.reset(new CudnnHolder(&stream_, place)); } @@ -307,6 +319,10 @@ CUDADeviceContext::~CUDADeviceContext() { Wait(); WaitStreamCallback(); PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); + if (cublas_tensor_core_handle_) { + PADDLE_ENFORCE(dynload::cublasDestroy(*cublas_tensor_core_handle_)); + cublas_tensor_core_handle_.reset(); + } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -339,6 +355,15 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const { return cublas_handle_; } +cublasHandle_t CUDADeviceContext::possible_cublas_tensor_core_handle() const { + return cublas_tensor_core_handle_ ? *cublas_tensor_core_handle_ + : cublas_handle_; +} + +bool CUDADeviceContext::tensor_core_available() const { + return cublas_tensor_core_handle_ != nullptr; +} + cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_holder_->cudnn_handle(); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 7e875801893f3b73f8efaf33af690f8c855beee4..41b741a68fa0b1aed578423cef55241e9943abac 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -209,39 +209,6 @@ class CudnnWorkspaceHandle { std::unique_ptr> guard_; }; -#if CUDA_VERSION >= 9000 -class ScopedCublasMathMode { - public: - ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode) - : handle_(handle) { - need_reset = false; - PADDLE_ENFORCE( - platform::dynload::cublasGetMathMode(handle_, &old_math_mode_), - "Failed to get old cublas math mode"); - if (old_math_mode_ != new_math_mode) { - PADDLE_ENFORCE( - platform::dynload::cublasSetMathMode(handle_, new_math_mode), - "Failed to set old cublas math mode"); - need_reset = true; - } - } - - ~ScopedCublasMathMode() { - if (need_reset) { - PADDLE_ENFORCE( - platform::dynload::cublasSetMathMode(handle_, old_math_mode_), - "Failed to set old cublas math mode"); - } - } - - private: - cublasHandle_t handle_; - cublasMath_t old_math_mode_; - bool need_reset; -}; - -#endif - class CUDADeviceContext : public DeviceContext { public: explicit CUDADeviceContext(CUDAPlace place); @@ -265,6 +232,13 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cublas handle in the device context. */ cublasHandle_t cublas_handle() const; + /*! \brief Check whether tensor core is supported */ + bool tensor_core_available() const; + + /*! \brief Return cublas handle supporting Tensor Core. If Tensor Core is + * not supported, return the same handle as cublas_handle(). */ + cublasHandle_t possible_cublas_tensor_core_handle() const; + /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; @@ -294,18 +268,6 @@ class CUDADeviceContext : public DeviceContext { void WaitStreamCallback() const { callback_manager_->Wait(); } -#if CUDA_VERSION >= 9000 - /*! \brief CublasCall may need to change cublas's config, - * but the cublas may be hold by multi-thread, so we should - * add lock here. */ - template - void CublasCall(Callback callback, cublasMath_t new_math) { - std::lock_guard guard(cublas_mtx_); - ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math); - callback(); - } -#endif - private: CUDAPlace place_; @@ -314,6 +276,7 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr cudnn_holder_; cudaStream_t stream_; cublasHandle_t cublas_handle_; + std::unique_ptr cublas_tensor_core_handle_; int compute_capability_; int runtime_version_;