未验证 提交 c7b32fe1 编写于 作者: L liu zhengxi 提交者: GitHub

Add cublas_handle() to expose cublas_handle to ops (#31157) (#31190)

* add get_cublas_handle() api

* update format

* add unittests

* alter function name
上级 0def5938
......@@ -96,12 +96,14 @@ class CublasHandleHolder {
#endif // CUDA_VERSION >= 9000
}
const cublasHandle_t& GetCublasHandle() const { return handle_; }
~CublasHandleHolder() PADDLE_MAY_THROW {
PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasDestroy(handle_));
}
template <typename Callback>
inline void Call(Callback &&callback) const {
inline void Call(Callback&& callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
......
......@@ -437,6 +437,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return context()->CudnnHandle();
}
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return context()->CublasHandle()->GetCublasHandle();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
......
......@@ -346,6 +346,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
/*! \brief Return cublas handle in the device context. */
cublasHandle_t cublas_handle() const;
/*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without interrupting by other threads.
* Once the first cudnn function is called by the handle, a lock
......
......@@ -43,6 +43,8 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
delete device_context;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册