diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h index 2a055fda4e9d4aebca5a8177734c7b743eb41dd2..24626d8859c57e52b5f84da908c6c06cdec7d50c 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_helper.h @@ -96,12 +96,14 @@ class CublasHandleHolder { #endif // CUDA_VERSION >= 9000 } + const cublasHandle_t& GetCublasHandle() const { return handle_; } + ~CublasHandleHolder() PADDLE_MAY_THROW { PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasDestroy(handle_)); } template - inline void Call(Callback &&callback) const { + inline void Call(Callback&& callback) const { std::lock_guard guard(mtx_); callback(handle_); } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index a8d56c0717d302b842c1469b66b52e2920e116ff..1f33fdd3ff14a11620a407cbf4fe380e355d9060 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -437,6 +437,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return context()->CudnnHandle(); } +cublasHandle_t CUDADeviceContext::cublas_handle() const { + return context()->CublasHandle()->GetCublasHandle(); +} + CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index a6612a5061f27e92c58134470db92ad58304f4e9..d40a898b859e3f90e96c89f139c0260b949310b8 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -346,6 +346,9 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; + /*! \brief Return cublas handle in the device context. */ + cublasHandle_t cublas_handle() const; + /*! \brief Return a cudnn workspace handle to call multiple cudnn * functions without interrupting by other threads. * Once the first cudnn function is called by the handle, a lock diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 5b3aa98efb46b51d6c3edb6d2cbd4200bd0a35c6..e5024c43ebc0f074b557e4cb19ef2a95b073aa5d 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -43,6 +43,8 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, gpu_device); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); ASSERT_NE(nullptr, cudnn_handle); + cublasHandle_t cublas_handle = device_context->cublas_handle(); + ASSERT_NE(nullptr, cublas_handle); delete device_context; } }