未验证 提交 c7f36e7c 编写于 作者: Z Zeng Jinle 提交者: GitHub

Add lock to cudnn handle calls (#19845)

* refine reallocation of workspace size, test=develop

* add lock to cudnn handle calls, test=develop
上级 2c5c6365
...@@ -308,7 +308,7 @@ bool CUDADeviceContext::tensor_core_available() const { ...@@ -308,7 +308,7 @@ bool CUDADeviceContext::tensor_core_available() const {
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this); return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
} }
cudaStream_t CUDADeviceContext::stream() const { return stream_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
...@@ -163,7 +163,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -163,7 +163,10 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
cudaStream_t stream_; cudaStream_t stream_;
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
mutable std::mutex cudnn_handle_mtx_;
std::unique_ptr<CublasHandleHolder> cublas_handle_; std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_; std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
...@@ -190,8 +193,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -190,8 +193,8 @@ class CUDADeviceContext : public DeviceContext {
class CudnnWorkspaceHandle { class CudnnWorkspaceHandle {
public: public:
inline explicit CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx) inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx)
: device_context_(dev_ctx) {} : device_context_(dev_ctx), mtx_(mtx) {}
template <typename Callback> template <typename Callback>
inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) { inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
...@@ -200,8 +203,11 @@ class CudnnWorkspaceHandle { ...@@ -200,8 +203,11 @@ class CudnnWorkspaceHandle {
} }
VLOG(2) << "Cudnn workspace size at RunFunc: " VLOG(2) << "Cudnn workspace size at RunFunc: "
<< static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB"; << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
{
std::lock_guard<std::mutex> guard(*mtx_);
cudnn_func(allocation_ ? allocation_->ptr() : nullptr); cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
} }
}
/*! \brief Thread which call RunFuncSync() would release gpu memory after /*! \brief Thread which call RunFuncSync() would release gpu memory after
* running the function. Currently this function is only used when cudnn * running the function. Currently this function is only used when cudnn
...@@ -238,6 +244,7 @@ class CudnnWorkspaceHandle { ...@@ -238,6 +244,7 @@ class CudnnWorkspaceHandle {
private: private:
memory::allocation::AllocationPtr allocation_; memory::allocation::AllocationPtr allocation_;
const CUDADeviceContext& device_context_; const CUDADeviceContext& device_context_;
std::mutex* mtx_;
}; };
template <> template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册