From c7b32fe1bdb3819ce1eb76affd28462d1201cd0c Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Thu, 25 Feb 2021 13:03:48 +0800 Subject: [PATCH] Add cublas_handle() to expose cublas_handle to ops (#31157) (#31190) * add get_cublas_handle() api * update format * add unittests * alter function name --- paddle/fluid/platform/cuda_helper.h | 4 +++- paddle/fluid/platform/device_context.cc | 4 ++++ paddle/fluid/platform/device_context.h | 3 +++ paddle/fluid/platform/device_context_test.cu | 2 ++ 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h index 2a055fda4e9..24626d8859c 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_helper.h @@ -96,12 +96,14 @@ class CublasHandleHolder { #endif // CUDA_VERSION >= 9000 } + const cublasHandle_t& GetCublasHandle() const { return handle_; } + ~CublasHandleHolder() PADDLE_MAY_THROW { PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasDestroy(handle_)); } template - inline void Call(Callback &&callback) const { + inline void Call(Callback&& callback) const { std::lock_guard guard(mtx_); callback(handle_); } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index a8d56c0717d..1f33fdd3ff1 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -437,6 +437,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return context()->CudnnHandle(); } +cublasHandle_t CUDADeviceContext::cublas_handle() const { + return context()->CublasHandle()->GetCublasHandle(); +} + CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index a6612a5061f..d40a898b859 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -346,6 +346,9 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; + /*! \brief Return cublas handle in the device context. */ + cublasHandle_t cublas_handle() const; + /*! \brief Return a cudnn workspace handle to call multiple cudnn * functions without interrupting by other threads. * Once the first cudnn function is called by the handle, a lock diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 5b3aa98efb4..e5024c43ebc 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -43,6 +43,8 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, gpu_device); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); ASSERT_NE(nullptr, cudnn_handle); + cublasHandle_t cublas_handle = device_context->cublas_handle(); + ASSERT_NE(nullptr, cublas_handle); delete device_context; } } -- GitLab