未验证 提交 c7f36e7c 编写于 作者: Z Zeng Jinle 提交者: GitHub

Add lock to cudnn handle calls (#19845)

* refine reallocate of workspace size, test=develop

* add lock to cudnn handle calls, test=develop
上级 2c5c6365
......@@ -308,7 +308,7 @@ bool CUDADeviceContext::tensor_core_available() const {
// Returns the raw cudnn handle owned by this device context.
// NOTE(review): callers share this handle; serialized access is provided
// via cudnn_workspace_handle(), which locks around each cudnn call.
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return cudnn_handle_;
}
// Returns a workspace handle bound to this context and to the mutex that
// guards its cudnn handle, so cudnn calls issued through the workspace are
// serialized (per this change: "add lock to cudnn handle calls").
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
// Returns the CUDA stream associated with this device context.
cudaStream_t CUDADeviceContext::stream() const {
  return stream_;
}
......
......@@ -163,7 +163,10 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
mutable std::mutex cudnn_handle_mtx_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
......@@ -190,8 +193,8 @@ class CUDADeviceContext : public DeviceContext {
class CudnnWorkspaceHandle {
public:
// Constructs a workspace handle that borrows the device context and the
// mutex guarding its cudnn handle. Neither argument is owned; both must
// outlive this handle. The mutex is locked around each cudnn call so
// concurrent users of the shared cudnnHandle_t are serialized.
// (The former single-argument constructor is removed: it would leave
// mtx_ uninitialized and no caller uses it after this change.)
inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
    : device_context_(dev_ctx), mtx_(mtx) {}
template <typename Callback>
inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
......@@ -200,7 +203,10 @@ class CudnnWorkspaceHandle {
}
VLOG(2) << "Cudnn workspace size at RunFunc: "
<< static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
{
std::lock_guard<std::mutex> guard(*mtx_);
cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
}
}
/*! \brief Thread which call RunFuncSync() would release gpu memory after
......@@ -238,6 +244,7 @@ class CudnnWorkspaceHandle {
private:
memory::allocation::AllocationPtr allocation_;
const CUDADeviceContext& device_context_;
std::mutex* mtx_;
};
template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册