From ae2be49f402ea94ae228d04b32cad974fffbe459 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Wed, 24 Feb 2021 16:52:50 +0800 Subject: [PATCH] Add cublas_handle() to expose cublas_handle to ops (#31157) * add get_cublas_handle() api * update format * add unittests * alter function name --- paddle/fluid/platform/cuda_helper.h | 4 +++- paddle/fluid/platform/device_context.cc | 4 ++++ paddle/fluid/platform/device_context.h | 3 +++ paddle/fluid/platform/device_context_test.cu | 2 ++ 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h index 006062848e0..bfefeb2f4a3 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_helper.h @@ -108,6 +108,8 @@ class CublasHandleHolder { } #endif + const cublasHandle_t& GetCublasHandle() const { return handle_; } + ~CublasHandleHolder() PADDLE_MAY_THROW { #ifdef PADDLE_WITH_HIP PADDLE_RETRY_CUDA_SUCCESS(dynload::rocblas_destroy_handle(handle_)); @@ -117,7 +119,7 @@ class CublasHandleHolder { } template - inline void Call(Callback &&callback) const { + inline void Call(Callback&& callback) const { std::lock_guard guard(mtx_); callback(handle_); } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 53659314be7..98dcf72aa4f 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -459,6 +459,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return context()->CudnnHandle(); } +cublasHandle_t CUDADeviceContext::cublas_handle() const { + return context()->CublasHandle()->GetCublasHandle(); +} + CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 72138b79091..11123c4e658 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -409,6 +409,9 @@ class CUDADeviceContext : public DeviceContext { cudnnHandle_t cudnn_handle() const; #endif + /*! \brief Return cublas handle in the device context. */ + cublasHandle_t cublas_handle() const; + /*! \brief Return a cudnn workspace handle to call multiple cudnn * functions without interrupting by other threads. * Once the first cudnn function is called by the handle, a lock diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 857d5d27651..3e9fe461d74 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -47,6 +47,8 @@ TEST(Device, CUDADeviceContext) { cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); #endif ASSERT_NE(nullptr, cudnn_handle); + cublasHandle_t cublas_handle = device_context->cublas_handle(); + ASSERT_NE(nullptr, cublas_handle); delete device_context; } } -- GitLab