diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index cd5af6f3abc563ebc90360a1c0f29165505fc768..ce0d73f520a711d1cd7d77358425a6bc2ab3de60 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -308,7 +308,7 @@ bool CUDADeviceContext::tensor_core_available() const { cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { - return CudnnWorkspaceHandle(*this); + return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); } cudaStream_t CUDADeviceContext::stream() const { return stream_; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 77443c72be8e76e655cf801a0a01ea11fb1e315b..cbb700fb35a648702670d31db5d339397b2c9f86 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -163,7 +163,10 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; cudaStream_t stream_; + cudnnHandle_t cudnn_handle_; + mutable std::mutex cudnn_handle_mtx_; + std::unique_ptr cublas_handle_; std::unique_ptr cublas_tensor_core_handle_; @@ -190,8 +193,8 @@ class CUDADeviceContext : public DeviceContext { class CudnnWorkspaceHandle { public: - inline explicit CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx) - : device_context_(dev_ctx) {} + inline CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx, std::mutex* mtx) + : device_context_(dev_ctx), mtx_(mtx) {} template inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) { @@ -200,7 +203,10 @@ class CudnnWorkspaceHandle { } VLOG(2) << "Cudnn workspace size at RunFunc: " << static_cast(WorkspaceSize()) / (1 << 20) << " MB"; - cudnn_func(allocation_ ? allocation_->ptr() : nullptr); + { + std::lock_guard guard(*mtx_); + cudnn_func(allocation_ ? allocation_->ptr() : nullptr); + } } /*! \brief Thread which call RunFuncSync() would release gpu memory after @@ -238,6 +244,7 @@ class CudnnWorkspaceHandle { private: memory::allocation::AllocationPtr allocation_; const CUDADeviceContext& device_context_; + std::mutex* mtx_; }; template <>