diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 728cfc626079e108b53220577b4e5b1af373b716..939e987b397a37b9de2ecc33056819994e5e2eb2 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -48,12 +48,16 @@ bool Carrier::EnqueueInterceptorMessage( // handle control message return true; } else { - if (creating_interceptors_) { - // Cannot handle the message to interceptor since interceptors - // are still under creating. Will enqueue into a tmp stack. - VLOG(3) << "Receiving message while creating interceptors."; - message_tmp_.emplace_back(interceptor_message); - return true; + { + std::unique_lock lock_creating(creating_flag_mutex_); + if (creating_interceptors_) { + std::unique_lock lock_message(tmp_message_mutex_); + // Cannot deliver the message yet because the interceptors + // are still being created. Buffer it in message_tmp_ for now. + VLOG(3) << "Receiving message while creating interceptors."; + message_tmp_.emplace_back(interceptor_message); + return true; + } } int64_t dst_id = interceptor_message.dst_id(); Interceptor* dst_interceptor = GetInterceptor(dst_id); @@ -112,9 +116,11 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, void Carrier::SetCreatingFlag(bool flag) { // set the creating flag + creating_flag_mutex_.lock(); VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_ << " to " << flag << "."; creating_interceptors_ = flag; + creating_flag_mutex_.unlock(); if (!flag) { // finish create interceptors outside, handle tmp messsages HandleTmpMessages(); @@ -122,6 +128,12 @@ void Carrier::SetCreatingFlag(bool flag) { } void Carrier::HandleTmpMessages() { + // NOTE: It is safe to lock tmp_message_mutex_ here. When this + // `HandleTmpMessages` method is entered, the creating_interceptors_ flag + // must already be false, so there will be no conflict with the + // lock on the 
tmp_message_mutex_ inside `EnqueueInterceptorMessage` + // on the same thread. + std::unique_lock lock(tmp_message_mutex_); VLOG(3) << "Carrier has received " << message_tmp_.size() << " messages during creating interceptors."; for (const auto& msg : message_tmp_) { @@ -147,7 +159,9 @@ void Carrier::CreateInterceptors() { } // The carrier will be always waiting for outside initializer // since there is no interceptor has been created during auto init + creating_flag_mutex_.lock(); creating_interceptors_ = false; + creating_flag_mutex_.unlock(); HandleTmpMessages(); } } diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 3413ed50004845e4ec21b920dd756b2017ed5bdb..980847c716b79f7e5be91cb1751699e31ba3c26a 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -78,7 +79,9 @@ class Carrier final { interceptor_idx_to_interceptor_; std::vector message_tmp_{}; + std::mutex tmp_message_mutex_; bool creating_interceptors_{true}; + std::mutex creating_flag_mutex_; bool is_init_{false}; };