未验证 提交 c7b32fe1 编写于 作者: L liu zhengxi 提交者: GitHub

Add cublas_handle() to expose cublas_handle to ops (#31157) (#31190)

* add get_cublas_handle() api

* update format

* add unittests

* alter function name
上级 0def5938
...@@ -96,12 +96,14 @@ class CublasHandleHolder { ...@@ -96,12 +96,14 @@ class CublasHandleHolder {
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
} }
const cublasHandle_t& GetCublasHandle() const { return handle_; }
~CublasHandleHolder() PADDLE_MAY_THROW { ~CublasHandleHolder() PADDLE_MAY_THROW {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasDestroy(handle_)); PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasDestroy(handle_));
} }
template <typename Callback> template <typename Callback>
inline void Call(Callback &&callback) const { inline void Call(Callback&& callback) const {
std::lock_guard<std::mutex> guard(mtx_); std::lock_guard<std::mutex> guard(mtx_);
callback(handle_); callback(handle_);
} }
......
...@@ -437,6 +437,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { ...@@ -437,6 +437,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle(); return context()->CudnnHandle();
} }
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return context()->CublasHandle()->GetCublasHandle();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
} }
......
...@@ -346,6 +346,9 @@ class CUDADeviceContext : public DeviceContext { ...@@ -346,6 +346,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */ /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const; cudnnHandle_t cudnn_handle() const;
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;
/*! \brief Return a cudnn workspace handle to call multiple cudnn /*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads. * functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock * Once the first cudnn function is called by the handle, a lock
......
...@@ -43,6 +43,8 @@ TEST(Device, CUDADeviceContext) { ...@@ -43,6 +43,8 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, gpu_device); ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
delete device_context; delete device_context;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册