未验证 提交 ae2be49f 编写于 作者: L liu zhengxi 提交者: GitHub

Add cublas_handle() to expose cublas_handle to ops (#31157)

* add get_cublas_handle() api

* update format

* add unittests

* alter function name
上级 406f4a75
......@@ -108,6 +108,8 @@ class CublasHandleHolder {
}
#endif
const cublasHandle_t& GetCublasHandle() const { return handle_; }
~CublasHandleHolder() PADDLE_MAY_THROW {
#ifdef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS(dynload::rocblas_destroy_handle(handle_));
......@@ -117,7 +119,7 @@ class CublasHandleHolder {
}
template <typename Callback>
inline void Call(Callback &&callback) const {
inline void Call(Callback&& callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
......
......@@ -459,6 +459,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle();
}
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return context()->CublasHandle()->GetCublasHandle();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
......
......@@ -409,6 +409,9 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle() const;
#endif
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;
/*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock
......
......@@ -47,6 +47,8 @@ TEST(Device, CUDADeviceContext) {
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
#endif
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
delete device_context;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册