diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 3c3f9d17c871ac1cb4df83db17cf489d5b9e0563..3dbbd75b1e945208395c42ace3235db7891936c5 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -56,7 +56,7 @@ class VarHandle { const std::string& name, const platform::DeviceContext* p_ctx = nullptr, const framework::Scope* p_scope = nullptr) - : ok_(kVarHandleDefaultState) { + : status_(kDefaultState) { ep_ = ep; ctx_ = p_ctx; scope_ = p_scope; @@ -68,18 +68,20 @@ class VarHandle { public: bool Wait() { + int ret = kDefaultState; { std::unique_lock lk(sync_mutex_); - wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; }); + wait_cond_.wait(lk, [this] { return status_ != kDefaultState; }); + ret = status_; } - VLOG(7) << "VarHandle wait:" << ok_; - return ok_ != 0; + VLOG(7) << "VarHandle wait:" << ret; + return ret != kErrorState; } void Finish(bool ok) { { std::unique_lock lk(sync_mutex_); - ok_ = ok; + status_ = ok ? kFinishState : kErrorState; } VLOG(7) << "VarHandle finish:" << ok; wait_cond_.notify_all(); @@ -87,8 +89,8 @@ class VarHandle { std::string String() const { std::ostringstream s; - s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_ - << "]"; + s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:[" + << status_ << "]"; return s.str(); } @@ -111,9 +113,13 @@ class VarHandle { protected: std::mutex sync_mutex_; std::condition_variable wait_cond_; - int ok_; - static const int kVarHandleDefaultState = -1; + enum VarHandleStatus { + kDefaultState = -1, + kErrorState = 0, + kFinishState = 1, + }; + VarHandleStatus status_; private: DISABLE_COPY_AND_ASSIGN(VarHandle);