diff --git a/paddle/fluid/framework/rw_lock.h b/paddle/fluid/framework/rw_lock.h
index a068d3543d9d2abec203f86362a8be5ba135d04d..4f1d9adbfc5d534b1b8f4780c9a0ade26fb53ecc 100644
--- a/paddle/fluid/framework/rw_lock.h
+++ b/paddle/fluid/framework/rw_lock.h
@@ -56,5 +56,70 @@ struct RWLock {
 };
 #endif
 
+class RWLockGuard {
+ public:
+  enum Status { kUnLock, kWRLock, kRDLock };
+
+  RWLockGuard(RWLock* rw_lock, Status init_status)
+      : lock_(rw_lock), status_(Status::kUnLock) {
+    switch (init_status) {
+      case Status::kRDLock: {
+        RDLock();
+        break;
+      }
+      case Status::kWRLock: {
+        WRLock();
+        break;
+      }
+    }
+  }
+
+  void WRLock() {
+    switch (status_) {
+      case Status::kUnLock: {
+        lock_->WRLock();
+        break;
+      }
+      case Status::kWRLock: {
+        break;
+      }
+      case Status::kRDLock: {
+        PADDLE_THROW(
+            "Please unlock read lock first before invoking write lock.");
+        break;
+      }
+    }
+  }
+
+  void RDLock() {
+    switch (status_) {
+      case Status::kUnLock: {
+        lock_->RDLock();
+        break;
+      }
+      case Status::kRDLock: {
+        break;
+      }
+      case Status::kWRLock: {
+        PADDLE_THROW(
+            "Please unlock write lock first before invoking read lock.");
+        break;
+      }
+    }
+  }
+
+  void UnLock() {
+    if (status_ != Status::kUnLock) {
+      lock_->UNLock();
+    }
+  }
+
+  ~RWLockGuard() { UnLock(); }
+
+ private:
+  RWLock* lock_;
+  Status status_;
+};
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 1e46e5de9dfa78eb171dcf3d1fbea5f8310f5f51..ec297ea9a6365f82819d329eb63754ba42393d61 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -15,12 +15,11 @@ limitations under the License. */
 
 #include
 #include
+#include "paddle/fluid/memory/memory.h"
 
 #ifdef PADDLE_WITH_CUDA
-#include
+#include "paddle/fluid/framework/rw_lock.h"
 #endif
 
-#include "paddle/fluid/memory/memory.h"
-
 namespace paddle {
 namespace platform {
@@ -158,9 +157,14 @@ class CudnnHolder {
 
   void RunFunc(const std::function<void(void*)>& cudnn_func,
                size_t required_workspace_len) {
-    boost::upgrade_lock<boost::shared_mutex> shared_lock(mtx_);
+    framework::RWLockGuard lock_guard(&rw_lock_,
+                                      framework::RWLockGuard::Status::kRDLock);
     if (required_workspace_len > workspace_len_) {
-      ReallocateWorkspace(required_workspace_len, &shared_lock);
+      lock_guard.UnLock();
+      lock_guard.WRLock();
+      ReallocateWorkspace(required_workspace_len);
+      lock_guard.UnLock();
+      lock_guard.RDLock();
     }
     cudnn_func(workspace_);
   }
@@ -168,9 +172,7 @@ class CudnnHolder {
   ~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
 
  private:
-  void ReallocateWorkspace(size_t required_workspace_len,
-                           boost::upgrade_lock<boost::shared_mutex>* lock) {
-    boost::upgrade_to_unique_lock<boost::shared_mutex> unique_lock(*lock);
+  void ReallocateWorkspace(size_t required_workspace_len) {
     if (required_workspace_len <= workspace_len_) {
       return;
     }
@@ -192,7 +194,7 @@ class CudnnHolder {
   const cudaStream_t* stream_;  // not owned;
   const CUDAPlace place_;
 
-  boost::shared_mutex mtx_;
+  framework::RWLock rw_lock_;
 };
 
 CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
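
For reference, a minimal usage sketch of the new RWLockGuard, mirroring the read-to-write upgrade pattern that CudnnHolder::RunFunc adopts in this patch. It is illustrative only and not part of the change: the Workspace struct and EnsureCapacityAndClear function are hypothetical names, and the sketch assumes WRLock(), RDLock(), and UnLock() keep status_ in sync with the lock that is actually held.

    // Illustrative sketch, not part of the patch above. Workspace and
    // EnsureCapacityAndClear are hypothetical; the sketch assumes RWLockGuard
    // records the currently held lock in status_ so UnLock()/re-lock works.
    #include <cstddef>
    #include <cstring>
    #include <vector>

    #include "paddle/fluid/framework/rw_lock.h"

    struct Workspace {
      std::vector<char> buf;              // shared buffer
      paddle::framework::RWLock rw_lock;  // guards buf
    };

    void EnsureCapacityAndClear(Workspace* ws, size_t required) {
      // Start with a shared (read) lock, like CudnnHolder::RunFunc.
      paddle::framework::RWLockGuard guard(
          &ws->rw_lock, paddle::framework::RWLockGuard::Status::kRDLock);
      if (ws->buf.size() < required) {
        guard.UnLock();  // drop the read lock ...
        guard.WRLock();  // ... and take the write lock to grow the buffer
        if (ws->buf.size() < required) {  // re-check: another writer may have run
          ws->buf.resize(required);
        }
        guard.UnLock();  // downgrade back to a read lock
        guard.RDLock();
      }
      std::memset(ws->buf.data(), 0, required);  // use the buffer under read lock
    }  // ~RWLockGuard releases whatever lock is still held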