diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h index 006062848e080eedcafc1bbf35ccd789bb57ce37..bfefeb2f4a3da5da4e7c2059decb1cf677f02a1e 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_helper.h @@ -108,6 +108,8 @@ class CublasHandleHolder { } #endif + const cublasHandle_t& GetCublasHandle() const { return handle_; } + ~CublasHandleHolder() PADDLE_MAY_THROW { #ifdef PADDLE_WITH_HIP PADDLE_RETRY_CUDA_SUCCESS(dynload::rocblas_destroy_handle(handle_)); @@ -117,7 +119,7 @@ class CublasHandleHolder { } template - inline void Call(Callback &&callback) const { + inline void Call(Callback&& callback) const { std::lock_guard guard(mtx_); callback(handle_); } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 53659314be7896ef86101535b1e7b3ecbc0b6c91..98dcf72aa4fb48709976aefa3c33cf518ba76fac 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -459,6 +459,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return context()->CudnnHandle(); } +cublasHandle_t CUDADeviceContext::cublas_handle() const { + return context()->CublasHandle()->GetCublasHandle(); +} + CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 72138b7909117736f06fc4827b63b1895a43d72e..11123c4e658ed9891336096b881a78d527dfd1c5 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -409,6 +409,9 @@ class CUDADeviceContext : public DeviceContext { cudnnHandle_t cudnn_handle() const; #endif + /*! \brief Return cublas handle in the device context. */ + cublasHandle_t cublas_handle() const; + /*! \brief Return a cudnn workspace handle to call multiple cudnn * functions without interrupting by other threads. * Once the first cudnn function is called by the handle, a lock diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 857d5d276516059acace1727b86f8d837e279cd3..3e9fe461d746ca800a8ddd3f8aa12776b12479d6 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -47,6 +47,8 @@ TEST(Device, CUDADeviceContext) { cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); #endif ASSERT_NE(nullptr, cudnn_handle); + cublasHandle_t cublas_handle = device_context->cublas_handle(); + ASSERT_NE(nullptr, cublas_handle); delete device_context; } }