未验证 提交 ae2be49f 编写于 作者: L liu zhengxi 提交者: GitHub

Add cublas_handle() to expose cublas_handle to ops (#31157)

* add get_cublas_handle() api

* update format

* add unittests

* alter function name
上级 406f4a75
...@@ -108,6 +108,8 @@ class CublasHandleHolder { ...@@ -108,6 +108,8 @@ class CublasHandleHolder {
} }
#endif #endif
const cublasHandle_t& GetCublasHandle() const { return handle_; }
~CublasHandleHolder() PADDLE_MAY_THROW { ~CublasHandleHolder() PADDLE_MAY_THROW {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(dynload::rocblas_destroy_handle(handle_)); PADDLE_RETRY_CUDA_SUCCESS(dynload::rocblas_destroy_handle(handle_));
...@@ -117,7 +119,7 @@ class CublasHandleHolder { ...@@ -117,7 +119,7 @@ class CublasHandleHolder {
} }
template <typename Callback> template <typename Callback>
inline void Call(Callback &&callback) const { inline void Call(Callback&& callback) const {
std::lock_guard<std::mutex> guard(mtx_); std::lock_guard<std::mutex> guard(mtx_);
callback(handle_); callback(handle_);
} }
......
...@@ -459,6 +459,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { ...@@ -459,6 +459,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle(); return context()->CudnnHandle();
} }
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return context()->CublasHandle()->GetCublasHandle();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
} }
......
...@@ -409,6 +409,9 @@ class CUDADeviceContext : public DeviceContext { ...@@ -409,6 +409,9 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle() const; cudnnHandle_t cudnn_handle() const;
#endif #endif
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;
/*! \brief Return a cudnn workspace handle to call multiple cudnn /*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads. * functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock * Once the first cudnn function is called by the handle, a lock
......
...@@ -47,6 +47,8 @@ TEST(Device, CUDADeviceContext) { ...@@ -47,6 +47,8 @@ TEST(Device, CUDADeviceContext) {
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
#endif #endif
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
delete device_context; delete device_context;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册