diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index a4fb1cdcd970f8c8e961f633b5cbf71fb67090be..58f7be12ce6b5d447e93cf86c4954a86fccf48ef 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -68,9 +68,11 @@ struct CUBlas { #if CUDA_VERSION >= 8000 VLOG(5) << "use_tensor_op_math: " << (dev_ctx->tensor_core_available() ? "True" : "False"); - PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( - dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k, - alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( + handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, + beta, C, Ctype, ldc)); + }); #else PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); #endif @@ -171,10 +173,11 @@ struct CUBlas { << (use_tensor_op_math ? "True" : "False"); #endif // CUDA_VERSION >= 9000 - PADDLE_ENFORCE(platform::dynload::cublasGemmEx( - dev_ctx->possible_cublas_tensor_core_handle(), transa, transb, m, n, k, - alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, - algo)); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE(platform::dynload::cublasGemmEx( + handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, + beta, C, Ctype, ldc, computeType, algo)); + }); #else PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); #endif @@ -204,9 +207,10 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, CUDA_R_32F, N); } else { #endif // CUDA_VERSION >= 8000 - - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, A, lda, &beta, C, N); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, + lda, &beta, C, N); + }); #if CUDA_VERSION >= 8000 } @@ -247,9 +251,12 @@ inline void Blas::GEMM( CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &h_alpha, h_B, ldb, h_A, lda, - &h_beta, h_C, N); + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, + &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, + N); + }); #endif // CUDA_VERSION >= 8000 } @@ -273,8 +280,10 @@ void Blas::GEMM(bool transA, bool transB, int M, } else { #endif // CUDA_VERSION >= 8000 - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, A, lda, &beta, C, ldc); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, + lda, &beta, C, ldc); + }); #if CUDA_VERSION >= 8000 } @@ -292,16 +301,19 @@ inline void Blas::GEMM( cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, A, lda, &beta, C, - ldc); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, A, lda, &beta, C, ldc); + }); } template <> template void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - CUBlas::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); + }); } template <> @@ -311,8 +323,9 @@ void Blas::GEMV(bool trans_a, int M, int N, T beta, T *C) const { cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1, - &beta, C, 1); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); } template <> @@ -342,16 +355,20 @@ void Blas::BatchedGEMM( VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); - PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( - context_.possible_cublas_tensor_core_handle(), cuTransB, cuTransA, N, M, - K, &alpha, B, CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, - &beta, C, CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( + handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb, + strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc, + strideC, batchCount, CUDA_R_32F, algo)); + }); } else { #endif // CUDA_VERSION >= 9010 - CUBlas::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, strideB, A, lda, - strideA, &beta, C, ldc, strideC, batchCount); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, strideB, A, lda, strideA, &beta, C, + ldc, strideC, batchCount); + }); #if CUDA_VERSION >= 9010 } diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..122de72e15d587cf33b5d9856ac8b1243f666881 --- /dev/null +++ b/paddle/fluid/platform/cuda_helper.h @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include // NOLINT + +#include "paddle/fluid/platform/dynload/cublas.h" +#include "paddle/fluid/platform/macros.h" + +#if CUDA_VERSION < 9000 +enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 }; +#endif + +namespace paddle { +namespace platform { + +class CublasHandleHolder { + public: + CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) { + PADDLE_ENFORCE(dynload::cublasCreate(&handle_)); + PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream)); +#if CUDA_VERSION >= 9000 + if (math_type == CUBLAS_TENSOR_OP_MATH) { + PADDLE_ENFORCE( + dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH)); + } +#endif + } + + ~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); } + + template + inline void Call(Callback &&callback) const { + std::lock_guard guard(mtx_); + callback(handle_); + } + + private: + DISABLE_COPY_AND_ASSIGN(CublasHandleHolder); + + cublasHandle_t handle_; + mutable std::mutex mtx_; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index e40928fe5d2b678df1b895a74fe401e95c04b08b..be7f4949d65cef36d61b726c1c656f177e298fcc 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -245,17 +245,12 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); - PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); - PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); + cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH)); if (TensorCoreAvailable()) { #if CUDA_VERSION >= 9000 - cublas_tensor_core_handle_.reset(new cublasHandle_t()); - PADDLE_ENFORCE(dynload::cublasCreate(cublas_tensor_core_handle_.get())); - PADDLE_ENFORCE( - dynload::cublasSetStream(*cublas_tensor_core_handle_, stream_)); - PADDLE_ENFORCE(dynload::cublasSetMathMode(*cublas_tensor_core_handle_, - CUBLAS_TENSOR_OP_MATH)); + cublas_tensor_core_handle_.reset( + new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH)); #endif } @@ -318,11 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() { SetDeviceId(place_.device); Wait(); WaitStreamCallback(); - PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); - if (cublas_tensor_core_handle_) { - PADDLE_ENFORCE(dynload::cublasDestroy(*cublas_tensor_core_handle_)); - cublas_tensor_core_handle_.reset(); - } + cublas_handle_.reset(); + cublas_tensor_core_handle_.reset(); eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -351,15 +343,6 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { return eigen_device_.get(); } -cublasHandle_t CUDADeviceContext::cublas_handle() const { - return cublas_handle_; -} - -cublasHandle_t CUDADeviceContext::possible_cublas_tensor_core_handle() const { - return cublas_tensor_core_handle_ ? *cublas_tensor_core_handle_ - : cublas_handle_; -} - bool CUDADeviceContext::tensor_core_available() const { return cublas_tensor_core_handle_ != nullptr; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 41b741a68fa0b1aed578423cef55241e9943abac..c81d17380cf894631d06588c007c2e11ce5c7836 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/temporary_allocator.h" #ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/gpu_info.h" @@ -229,15 +230,25 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return eigen device in the device context. */ Eigen::GpuDevice* eigen_device() const; - /*! \brief Return cublas handle in the device context. */ - cublasHandle_t cublas_handle() const; + /*! \brief Call cublas function safely. */ + template + inline void CublasCall(Callback&& callback) const { + cublas_handle_->Call(std::forward(callback)); + } /*! \brief Check whether tensor core is supported */ bool tensor_core_available() const; - /*! \brief Return cublas handle supporting Tensor Core. If Tensor Core is - * not supported, return the same handle as cublas_handle(). */ - cublasHandle_t possible_cublas_tensor_core_handle() const; + /*! \brief Call cublas function with Tensor Core safely. If + Tensor Core is not available, use DEFAULT_MATH instead. */ + template + inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const { + if (cublas_tensor_core_handle_) { + cublas_tensor_core_handle_->Call(std::forward(callback)); + } else { + cublas_handle_->Call(std::forward(callback)); + } + } /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; @@ -256,7 +267,6 @@ class CUDADeviceContext : public DeviceContext { template void RecordEvent(cudaEvent_t ev, Callback callback) { - std::lock_guard guard(mtx_); callback(); PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } @@ -275,8 +285,9 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_stream_; std::unique_ptr cudnn_holder_; cudaStream_t stream_; - cublasHandle_t cublas_handle_; - std::unique_ptr cublas_tensor_core_handle_; + + std::unique_ptr cublas_handle_; + std::unique_ptr cublas_tensor_core_handle_; int compute_capability_; int runtime_version_; @@ -284,12 +295,10 @@ class CUDADeviceContext : public DeviceContext { int multi_process_; int max_threads_per_mp_; - mutable std::mutex mtx_; - // StreamCallbackManager is thread-safe std::unique_ptr callback_manager_; - mutable std::mutex cublas_mtx_; + DISABLE_COPY_AND_ASSIGN(CUDADeviceContext); }; template <> diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 171d2979a0218ad5e22112190a59866b3e0b617f..5b3aa98efb46b51d6c3edb6d2cbd4200bd0a35c6 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, gpu_device); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); ASSERT_NE(nullptr, cudnn_handle); - cublasHandle_t cublas_handle = device_context->cublas_handle(); - ASSERT_NE(nullptr, cublas_handle); - ASSERT_NE(nullptr, device_context->stream()); delete device_context; } }