diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 51f1d936bd70a910267511d219b8a7fe3cc61d4c..82444ae77dc9dc51151ed0d2a5ad34dcfe126f26 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -5,7 +5,7 @@ endif() proto_library(interceptor_message_proto SRCS interceptor_message.proto) if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) - set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog) + set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog) else() set(BRPC_DEPS "") endif() @@ -13,7 +13,7 @@ endif() cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry - executor_gc_helper ${BRPC_DEPS}) + executor_gc_helper gflags glog ${BRPC_DEPS}) if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 9d9755569b2fc0933c97c95ad7615a1107331c2b..3279f954fa5f880ab9ff1149bc1357f7d6cd210a 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -27,14 +27,16 @@ namespace distributed { USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Amplifier); -void Carrier::Init(std::shared_ptr runtime_graph, +void Carrier::Init(int64_t rank, std::shared_ptr runtime_graph, framework::Scope* root_scope, framework::Scope* minibatch_scope, const std::vector& microbatch_scopes, const platform::Place& place) { PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( "Carrier is already init.")); + rank_ = rank; runtime_graph_ = runtime_graph; + interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank(); minibatch_scope_ = minibatch_scope; microbatch_scopes_ = microbatch_scopes; place_ = place; @@ -48,12 +50,6 @@ void Carrier::Release() { // NOTE(wangxi): must join before `Derived Interceptor` destruct, // otherwise Derived object will be destructed before thread complete. - // Sending STOP msg to the source interceptor - PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true, - platform::errors::PreconditionNotMet( - "Using message bus since it has not been initialized. " - "Please invoke MessageBus::Init() before using it or " - "neccessary components are not ready.")); for (int64_t id : source_interceptor_ids_) { VLOG(3) << "Carrier Release is sending stop to source interceptor " << id << "."; @@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; } bool Carrier::EnqueueInterceptorMessage( const InterceptorMessage& interceptor_message) { - // enqueue message to interceptor if (interceptor_message.ctrl_message()) { - // handle control message - return true; + VLOG(3) << "Receiving control message from rank " + << interceptor_message.src_id() << " to rank " + << interceptor_message.dst_id(); } else { { std::unique_lock lock_creating(creating_flag_mutex_); @@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage( } int64_t dst_id = interceptor_message.dst_id(); Interceptor* dst_interceptor = GetInterceptor(dst_id); - bool rst = - dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message); - if (rst) { - std::condition_variable& interceptor_cond_var = - dst_interceptor->GetCondVar(); - interceptor_cond_var.notify_all(); - } - return rst; + dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message); } + return true; } Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { @@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; } bool Carrier::IsInit() const { return is_init_; } -// TODO(liyurui): Move SendIntra into carrier -bool Carrier::Send(const InterceptorMessage& msg) const { - return msg_bus_->Send(msg); +int64_t Carrier::GetRank(int64_t interceptor_id) const { + PADDLE_ENFORCE_NE( + interceptor_id_to_rank_.find(interceptor_id), + interceptor_id_to_rank_.end(), + platform::errors::NotFound("Cannot find rank for interceptor id %lld.", + interceptor_id)); + return interceptor_id_to_rank_.at(interceptor_id); +} + +bool Carrier::Send(const InterceptorMessage& msg) { + int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id(); + int64_t dst_id = msg.dst_id(); + int64_t src_rank = GetRank(src_id); + int64_t dst_rank = GetRank(dst_id); + PADDLE_ENFORCE_EQ( + src_rank, rank_, + platform::errors::Fatal("The source rank id %lld, which is not equal to " + "the carrier rank id %lld.", + src_rank, rank_)); + if (src_rank == dst_rank) { + VLOG(3) << "Send a message from interceptor " << src_id + << " to interceptor " << dst_id << ", which are in the same ranks."; + return EnqueueInterceptorMessage(msg); + } else { + PADDLE_ENFORCE_NOT_NULL( + msg_bus_.get(), + platform::errors::Unavailable("Message bus is released accidently")); + PADDLE_ENFORCE_EQ( + msg_bus_->IsInit(), true, + platform::errors::PreconditionNotMet( + "Using message bus since it has not been initialized. " + "Please invoke MessageBus::Init() before using it or " + "neccessary components are not ready.")); + VLOG(3) << "Send a message from interceptor " << src_id + << " to interceptor " << dst_id + << ", which are in different ranks."; + return msg_bus_->Send(dst_rank, msg); + } } Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, @@ -222,13 +247,13 @@ static std::shared_ptr GetGC( } void Carrier::CreateInterceptors() { - if (runtime_graph_->intercepter_id_to_node().empty()) return; + if (runtime_graph_->interceptor_id_to_node().empty()) return; auto gc = GetGC(place_); // create each Interceptor // no auto init since there is no config - for (const auto& item : runtime_graph_->intercepter_id_to_node()) { + for (const auto& item : runtime_graph_->interceptor_id_to_node()) { int64_t interceptor_id = item.first; TaskNode* task_node = item.second; diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index e850c120bdbe5d1b08e6577772038d214888eb31..54cf2150030fca9160437dacf647246e5ec34106 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -45,8 +45,11 @@ class MessageBus; class Carrier final { public: Carrier() = default; + Carrier(int64_t rank, + const std::unordered_map& interceptor_id_to_rank) + : rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {} ~Carrier(); - void Init(std::shared_ptr runtime_graph, + void Init(int64_t rank, std::shared_ptr runtime_graph, framework::Scope* root_scope, framework::Scope* minibatch_scope, const std::vector& microbatch_scopes, const platform::Place& place); @@ -75,7 +78,7 @@ class Carrier final { bool IsInit() const; - bool Send(const InterceptorMessage& msg) const; + bool Send(const InterceptorMessage& msg); // NOTE: This mutex will be used in interceptor's RunOps function. // This mutex is used for avoiding forward ops and backward ops run @@ -90,6 +93,8 @@ class Carrier final { void HandleTmpMessages(); + int64_t GetRank(int64_t interceptor_id) const; + // interceptor logic id to actually interceptor std::unordered_map> interceptor_idx_to_interceptor_; @@ -111,6 +116,7 @@ class Carrier final { paddle::platform::DeviceContext* dev_ctx_{nullptr}; std::shared_ptr runtime_graph_; std::shared_ptr msg_bus_; + int64_t rank_; std::unordered_map interceptor_id_to_rank_; }; diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index f7173a7b8bdfbc3d819f55c5349b7a3ef5025288..697c4aaaf3aaa0a172db47c88620aaa9c738b385 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" -#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -28,6 +27,8 @@ namespace paddle { namespace distributed { +std::unique_ptr FleetExecutor::carrier_; + FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { bool parse_flag = exe_desc_.ParseFromString(exe_desc_str); PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet( @@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); - GetCarrier().Release(); + GetCarrier()->Release(); } -Carrier& FleetExecutor::GetCarrier() { - static Carrier carrier; - return carrier; +Carrier* FleetExecutor::GetCarrier() { + PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound( + "Carrier has not been created.")); + return carrier_.get(); } void FleetExecutor::Init( @@ -84,16 +86,16 @@ void FleetExecutor::Init( } VLOG(5) << runtime_graph_->DebugString(); msg_bus_ = std::make_shared(); + CreateCarrier(); InitCarrier(); InitMessageBus(); } void FleetExecutor::InitCarrier() { - Carrier& carrier = GetCarrier(); - if (!carrier.IsInit()) { - carrier.SetMsgBus(msg_bus_); - carrier.Init(runtime_graph_, root_scope_, minibatch_scope_, - microbatch_scopes_, place_); + if (!GetCarrier()->IsInit()) { + GetCarrier()->SetMsgBus(msg_bus_); + GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_, + minibatch_scope_, microbatch_scopes_, place_); } } @@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() { << (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << "."; VLOG(5) << ss.str(); if (!msg_bus_->IsInit()) { - msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr, - addr); + msg_bus_->Init(cur_rank, rank_to_addr, addr); } } void FleetExecutor::Run() { // Run - Carrier& carrier = GetCarrier(); PADDLE_ENFORCE_EQ( - carrier.IsInit(), true, + GetCarrier()->IsInit(), true, platform::errors::Unavailable("Carrier has not been init yet.")); PADDLE_ENFORCE_EQ( msg_bus_->IsInit(), true, platform::errors::Unavailable("MessageBus has not been init yet.")); - carrier.Start(); + GetCarrier()->Start(); for (auto* micro_scop : microbatch_scopes_) { // By default, we should delete all kid scopes after run executor because // some operators may create local scope when running, such as while_op. diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index a66288525c6f9ba90905915014fe2ddfe2b626c4..3572e07efc5da6797961c2661fcbb5781bce6d7c 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -16,6 +16,7 @@ #include #include +#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" @@ -30,7 +31,6 @@ namespace distributed { class RuntimeGraph; class MessageBus; class TaskNode; -class Carrier; class FleetExecutor final { public: @@ -43,7 +43,15 @@ class FleetExecutor final { const std::unordered_map& task_id_to_rank); void Run(); // TODO(liyurui): Change to use registry table for multi-carrier. - static Carrier& GetCarrier(); + static Carrier* GetCarrier(); + template + static Carrier* CreateCarrier(Args&&... args) { + PADDLE_ENFORCE_EQ( + carrier_.get(), nullptr, + platform::errors::AlreadyExists("Carrier has been created already.")); + carrier_ = std::make_unique(std::forward(args)...); + return carrier_.get(); + } private: DISABLE_COPY_AND_ASSIGN(FleetExecutor); @@ -59,6 +67,7 @@ class FleetExecutor final { // The carriers under FleetExecutor will share message bus, // using shared_ptr to manage lifetime and condition race. std::shared_ptr msg_bus_; + static std::unique_ptr carrier_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index d649a84614e4d51a3426aac0a2e5fb4203929318..f5501754cd7299c61ba6463876694220d66d34f5 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -52,24 +52,17 @@ void Interceptor::StopCarrier() { cond_var.notify_all(); } -std::condition_variable& Interceptor::GetCondVar() { - // get the conditional var - return cond_var_; -} - int64_t Interceptor::GetInterceptorId() const { // return the interceptor id return interceptor_id_; } -bool Interceptor::EnqueueRemoteInterceptorMessage( +void Interceptor::EnqueueRemoteInterceptorMessage( const InterceptorMessage& interceptor_message) { // Called by Carrier, enqueue an InterceptorMessage to remote mailbox VLOG(3) << "Enqueue message: " << interceptor_message.message_type() << " into " << interceptor_id_ << "'s remote mailbox."; - std::unique_lock lock(remote_mailbox_mutex_); - remote_mailbox_.push(interceptor_message); - return true; + remote_mailbox_.Push(interceptor_message); } bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { @@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() { "Error encountered when fetch remote mailbox.")); } const InterceptorMessage interceptor_message = local_mailbox_.front(); - local_mailbox_.pop(); + local_mailbox_.pop_front(); const MessageType message_type = interceptor_message.message_type(); VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message" << " from interceptor " << interceptor_message.src_id() @@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() { } bool Interceptor::FetchRemoteMailbox() { - // fetch all Message from remote mailbox to local mailbox - // return true if remote mailbox not empty, otherwise return false - std::unique_lock lock(remote_mailbox_mutex_); - cond_var_.wait(lock, [this]() { return !remote_mailbox_.empty(); }); - if (remote_mailbox_.empty()) { - // the thread has been unblocked accidentally - return false; - } - while (!remote_mailbox_.empty()) { - local_mailbox_.push(std::move(remote_mailbox_.front())); - remote_mailbox_.pop(); - } - return true; + remote_mailbox_.PopAll(&local_mailbox_); + return !local_mailbox_.empty(); } static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() { diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index bc20058074441eafafa86b8cf20e65fbeed41b07..d9e8d050dd1fc8d03199c0ae490da08617a10021 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -15,14 +15,15 @@ #pragma once #include +#include #include #include #include -#include #include #include #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" +#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" @@ -59,11 +60,8 @@ class Interceptor { // return the interceptor id int64_t GetInterceptorId() const; - // return the conditional var - std::condition_variable& GetCondVar(); - // Called by Carrier, enqueue an InterceptorMessage to remote mailbox - bool EnqueueRemoteInterceptorMessage( + void EnqueueRemoteInterceptorMessage( const InterceptorMessage& interceptor_message); bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT @@ -115,23 +113,16 @@ class Interceptor { // interceptor handle which process message MsgHandle handle_{nullptr}; - // mutex to control read/write conflict for remote mailbox - std::mutex remote_mailbox_mutex_; - // interceptor runs PoolTheMailbox() function to poll local mailbox std::thread interceptor_thread_; - // conditional variable for blocking the thread when - // fetch an empty remote mailbox - std::condition_variable cond_var_; - // remote mailbox, written by EnqueueRemoteMessage() // read by FetchRemoteMailbox() - std::queue remote_mailbox_; + framework::BlockingQueue remote_mailbox_; // local mailbox, written by FetchRemoteMailbox() // read by PoolTheMailbox() - std::queue local_mailbox_; + std::deque local_mailbox_; int64_t already_run_times_{0}; int64_t used_slot_nums_{0}; diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index a8d29758ca16385ac2f340eb1aeee4b1fb76454d..231b6c780e24e77683def9955eca49a1b0a07b22 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -29,8 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( VLOG(3) << "Interceptor Message Service receives a message from interceptor " << request->src_id() << " to interceptor " << request->dst_id() << ", with the message: " << request->message_type(); - FleetExecutor::GetCarrier().EnqueueInterceptorMessage(*request); - response->set_rst(true); + bool flag = FleetExecutor::GetCarrier()->EnqueueInterceptorMessage(*request); + response->set_rst(flag); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index d4c986de5a03ca4810edbb4ee7abcf69517ea841..ac7b08c4b2868b5f4e9fa88be4f2e266bc3ceb8b 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -17,8 +17,6 @@ #include #include -#include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" @@ -26,16 +24,25 @@ namespace paddle { namespace distributed { void MessageBus::Init( - const std::unordered_map& interceptor_id_to_rank, - const std::unordered_map& rank_to_addr, + int64_t rank, const std::unordered_map& rank_to_addr, const std::string& addr) { PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( "MessageBus is already init.")); + rank_ = rank; is_init_ = true; - interceptor_id_to_rank_ = interceptor_id_to_rank; rank_to_addr_ = rank_to_addr; addr_ = addr; + if (addr_ != "") { + const auto& addr = GetAddr(rank_); + PADDLE_ENFORCE_EQ(addr, addr_, + platform::errors::Fatal( + "The current rank's addr is %s, while the " + "message bus's addr is %s, which are different. " + "Init error.", + addr, addr_)); + } + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) // NOTE: To make the brpc is compatible with collective, @@ -65,43 +72,57 @@ MessageBus::~MessageBus() { #endif } -bool MessageBus::Send(const InterceptorMessage& interceptor_message) { - // called by Interceptor, send InterceptorMessage to dst - int64_t src_id = interceptor_message.src_id(); - int64_t dst_id = interceptor_message.dst_id(); - if (IsSameRank(src_id, dst_id)) { - VLOG(3) << "Send a message from interceptor " << src_id - << " to interceptor " << dst_id << ", which are in the same ranks."; - return SendIntraRank(interceptor_message); - } else { - VLOG(3) << "Send a message from interceptor " << src_id - << " to interceptor " << dst_id - << ", which are in different ranks."; +const std::string& MessageBus::GetAddr(int64_t rank) const { + PADDLE_ENFORCE_NE( + rank_to_addr_.find(rank), rank_to_addr_.end(), + platform::errors::NotFound("Cannot find addr rank id %lld.", rank)); + return rank_to_addr_.at(rank); +} + +bool MessageBus::Send(int64_t dst_rank, + const InterceptorMessage& interceptor_message) { #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) - int retry_time = 0; // message bus will retry sending for 10 times - while (retry_time < 10) { - ++retry_time; - if (SendInterRank(interceptor_message)) { - VLOG(3) << "Message bus sends inter rank successfully with " - << retry_time << " times retries."; - return true; - } - VLOG(3) << "Message bus sends failed, retry after 1 seconds."; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + int retry_time = 0; // message bus will retry sending for 10 times + while (retry_time < 10) { + ++retry_time; + if (SendInterRank(dst_rank, interceptor_message)) { + VLOG(3) << "Message bus sends inter rank successfully with " << retry_time + << " times retries."; + return true; } - VLOG(3) << "Message bus sends inter rank fail after 10 times retries."; - return false; + VLOG(3) << "Message bus sends failed, retry after 1 seconds."; + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + VLOG(3) << "Message bus sends inter rank fail after 10 times retries."; + return false; #else - PADDLE_THROW(platform::errors::Unavailable( - "Fleet executor does not support sending message between different " - "ranks when Paddle is compiled with npu or " - "isn't compiled with distributed for now.")); + PADDLE_THROW(platform::errors::Unavailable( + "Fleet executor does not support sending message between different " + "ranks when Paddle is compiled with npu or " + "isn't compiled with distributed for now.")); #endif - } return true; } +void MessageBus::TestConnection() { + InterceptorMessage ctrl_msg; + ctrl_msg.set_ctrl_message(true); + ctrl_msg.set_src_id(rank_); + for (const auto& dst_rank_pair : rank_to_addr_) { + int64_t dst_rank = dst_rank_pair.first; + if (dst_rank != rank_) { + ctrl_msg.set_dst_id(dst_rank); + VLOG(3) << "Send control message bus from rank " << rank_ << " to rank " + << dst_rank; + while (!Send(dst_rank, ctrl_msg)) { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + VLOG(3) << "Message bus has connected to rank: " << dst_rank << "."; + } + } +} + void MessageBus::ListenPort() { if (addr_ == "") { LOG(INFO) << "No need listen to port since training on single card."; @@ -130,30 +151,7 @@ void MessageBus::ListenPort() { interval += 500; } LOG(INFO) << "Message bus's listen port thread starts successful."; - - std::set visit; - InterceptorMessage tmp_msg; - tmp_msg.set_ctrl_message(true); - for (auto pair : interceptor_id_to_rank_) { - if (rank_to_addr_.at(pair.second) == addr_) { - tmp_msg.set_src_id(pair.first); - } - } - for (auto pair : interceptor_id_to_rank_) { - int64_t rank = pair.second; - if (rank_to_addr_.at(rank) == addr_) { - continue; - } - tmp_msg.set_dst_id(pair.first); - if (visit.find(rank) == visit.end()) { - VLOG(3) << "Message bus is testing connection for rank: " << rank << "."; - visit.insert(rank); - while (!Send(tmp_msg)) { - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - } - VLOG(3) << "Message bus has connected to rank: " << rank << "."; - } - } + TestConnection(); #else LOG(WARNING) << "Fleet executor's ListenPort() is a fake function when Paddle is " @@ -162,53 +160,13 @@ void MessageBus::ListenPort() { #endif } -bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) { - // -1 is sent by carrier to source interceptor - if (src_id == -1) src_id = dst_id; - - // check whether the dst is the same rank or different rank with src - const auto& src_rank = interceptor_id_to_rank_.find(src_id); - const auto& dst_rank = interceptor_id_to_rank_.find(dst_id); - PADDLE_ENFORCE_NE( - src_rank, interceptor_id_to_rank_.end(), - platform::errors::NotFound( - "Cannot find rank for src interceptor id %lld. Init error.", src_id)); - PADDLE_ENFORCE_NE( - dst_rank, interceptor_id_to_rank_.end(), - platform::errors::NotFound( - "Cannot find rank for dst interceptor id %lld. Init error.", dst_id)); - if (addr_ == "") { - // single card training, must be same rank - return true; - } - const auto& src_ip = rank_to_addr_.find(src_rank->second); - PADDLE_ENFORCE_NE(src_ip, rank_to_addr_.end(), - platform::errors::NotFound( - "Cannot find addr for src rank id %lld. Init error.", - src_rank->second)); - PADDLE_ENFORCE_EQ( - src_ip->second, addr_, - platform::errors::Fatal("The src interceptor's addr is %s, while the " - "message bus's addr is %s, which are different. " - "Init error.", - src_ip->second, addr_)); - return src_rank->second == dst_rank->second; -} - #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) -bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { - // send the message inter rank (dst is different rank with src) - int64_t dst_id = interceptor_message.dst_id(); - int64_t dst_rank = interceptor_id_to_rank_[dst_id]; - auto dst_ip = rank_to_addr_.find(dst_rank); - PADDLE_ENFORCE_NE(dst_ip, rank_to_addr_.end(), - platform::errors::InvalidArgument( - "Cannot find rank for dst interceptor id %lld. " - "Init error.", - dst_id)); - VLOG(3) << "Message bus sending to addr: " << dst_ip->second; - const char* dst_ip_for_brpc = dst_ip->second.c_str(); +bool MessageBus::SendInterRank(int64_t dst_rank, + const InterceptorMessage& interceptor_message) { + const auto& dst_addr = GetAddr(dst_rank); + VLOG(3) << "Message bus sending to addr: " << dst_addr; + const char* dst_addr_for_brpc = dst_addr.c_str(); brpc::Channel channel; brpc::ChannelOptions options; options.protocol = "baidu_std"; @@ -216,7 +174,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { options.timeout_ms = 1000; options.max_retry = 5; PADDLE_ENFORCE_EQ( - channel.Init(dst_ip_for_brpc, &options), 0, + channel.Init(dst_addr_for_brpc, &options), 0, platform::errors::Unavailable("Message bus: init brpc channel error.")); TheInterceptorMessageService_Stub stub(&channel); InterceptorResponse response; @@ -239,11 +197,5 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { } #endif -bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) { - // send the message intra rank (dst is the same rank with src) - return FleetExecutor::GetCarrier().EnqueueInterceptorMessage( - interceptor_message); -} - } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index 3f151cab3a46c689f707f4ca3590772a8d6bc47f..d4a2af54e6cd4a538d6590255360e463384a4e11 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -42,14 +42,14 @@ class MessageBus final { MessageBus() = default; ~MessageBus(); - void Init(const std::unordered_map& interceptor_id_to_rank, + void Init(int64_t rank, const std::unordered_map& rank_to_addr, const std::string& addr); bool IsInit() const; // called by Interceptor, send InterceptorMessage to dst - bool Send(const InterceptorMessage& interceptor_message); + bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message); private: DISABLE_COPY_AND_ASSIGN(MessageBus); @@ -57,22 +57,20 @@ class MessageBus final { // function keep listen the port and handle the message void ListenPort(); - // check whether the dst is the same rank or different rank with src - bool IsSameRank(int64_t src_id, int64_t dst_id); + void TestConnection(); + + const std::string& GetAddr(int64_t rank) const; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) // send the message inter rank (dst is different rank with src) - bool SendInterRank(const InterceptorMessage& interceptor_message); + bool SendInterRank(int64_t dst_rank, + const InterceptorMessage& interceptor_message); #endif bool is_init_{false}; - // send the message intra rank (dst is the same rank with src) - bool SendIntraRank(const InterceptorMessage& interceptor_message); - - // handed by above layer, save the info mapping interceptor id to rank id - std::unordered_map interceptor_id_to_rank_; + int64_t rank_; // handed by above layer, save the info mapping rank id to addr std::unordered_map rank_to_addr_; diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc index 1ad144470af2668fcbc0098e91f272b8f5f96b96..614b4c37e82545d7898fbc5db03f35991e8d3f1d 100644 --- a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc +++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc @@ -21,7 +21,7 @@ namespace distributed { std::string RuntimeGraph::DebugString() const { std::ostringstream os; os << "\nRuntime Graph Debug: \n"; - for (const auto& pair : intercepter_id_to_node_) { + for (const auto& pair : interceptor_id_to_node_) { os << pair.second->DebugString(); os << "\n"; } diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.h b/paddle/fluid/distributed/fleet_executor/runtime_graph.h index 3678e2e860a9d9e746c9d338bd55dea47cf8edc4..1ca9f0174ed07f3c12a8fb937799cfc4dd444b37 100644 --- a/paddle/fluid/distributed/fleet_executor/runtime_graph.h +++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.h @@ -29,26 +29,26 @@ class RuntimeGraph final { public: RuntimeGraph() = default; ~RuntimeGraph() = default; - const std::unordered_map& intercepter_id_to_node() const { - return intercepter_id_to_node_; + const std::unordered_map& interceptor_id_to_node() const { + return interceptor_id_to_node_; } - const std::unordered_map& intercepter_id_to_rank() const { - return intercepter_id_to_rank_; + const std::unordered_map& interceptor_id_to_rank() const { + return interceptor_id_to_rank_; } void SetInterceptorIdToRank( - const std::unordered_map& intercepter_id_to_rank) { - intercepter_id_to_rank_ = intercepter_id_to_rank; + const std::unordered_map& interceptor_id_to_rank) { + interceptor_id_to_rank_ = interceptor_id_to_rank; } void SetInterceptorIdToNode( - const std::unordered_map& intercepter_id_to_node) { - intercepter_id_to_node_ = intercepter_id_to_node; + const std::unordered_map& interceptor_id_to_node) { + interceptor_id_to_node_ = interceptor_id_to_node; } std::string DebugString() const; private: DISABLE_COPY_AND_ASSIGN(RuntimeGraph); - std::unordered_map intercepter_id_to_node_; - std::unordered_map intercepter_id_to_rank_; + std::unordered_map interceptor_id_to_node_; + std::unordered_map interceptor_id_to_rank_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc index e56696d35f2a46c94343be4196ebe82568646fa9..2e0a12b4244cecd6374a4e08609737616c23b98c 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -62,11 +62,10 @@ TEST(ComputeInterceptor, Compute) { std::vector scopes = {scope, scope}; platform::Place place = platform::CPUPlace(); - // TODO(liyurui): Remove singleton when move SendIntra into Carrier - Carrier& carrier = FleetExecutor::GetCarrier(); + Carrier carrier(0, {{0, 0}, {1, 0}}); auto msg_bus = std::make_shared(); - msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, ""); + msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); carrier.SetMsgBus(msg_bus); // FIXME: don't delete, otherwise interceptor will use undefined node diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 3bd2ddec4effcb3f30061c612bb5babe7c1c228c..47f4cf0c04849e9aac3fcb436d9a52d56a2ad946 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -47,11 +47,10 @@ class StartInterceptor : public Interceptor { }; TEST(ComputeInterceptor, Compute) { - // TODO(liyurui): Remove singleton when move SendIntra into Carrier - Carrier& carrier = FleetExecutor::GetCarrier(); + Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}}); auto msg_bus = std::make_shared(); - msg_bus->Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, ""); + msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); carrier.SetMsgBus(msg_bus); // NOTE: don't delete, otherwise interceptor will use undefined node diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 8d9e609a2403405ccfe68880bff7f6cfd493a537..639e16a94a1d0260ca043858d5f40f57367dd916 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -18,7 +18,6 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -60,11 +59,9 @@ class PingPongInterceptor : public Interceptor { REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); TEST(InterceptorTest, PingPong) { - // TODO(liyurui): Remove singleton when move SendIntra into Carrier - Carrier& carrier = FleetExecutor::GetCarrier(); - + Carrier carrier(0, {{0, 0}, {1, 0}}); auto msg_bus = std::make_shared(); - msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, ""); + msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); carrier.SetMsgBus(msg_bus); Interceptor* a = carrier.SetInterceptor( diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index 93574609960a11b5853cf0a4c3c022d12210eb0d..a577b30fa8c0b7f2cd8d8969f7907bba034ccd32 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -104,35 +104,42 @@ TEST(InterceptorTest, PingPong) { std::string ip1 = "127.0.0.1:" + std::to_string(port1); std::cout << "ip0: " << ip0 << std::endl; std::cout << "ip1: " << ip1 << std::endl; - - int pid = fork(); - if (pid == 0) { - auto msg_bus = std::make_shared(); - msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0); - - // TODO(liyurui): Remove singleton when move SendIntra into Carrier - Carrier& carrier = FleetExecutor::GetCarrier(); - carrier.SetMsgBus(msg_bus); - - Interceptor* a = carrier.SetInterceptor( - 0, InterceptorFactory::Create("PingPong", 0, nullptr)); - carrier.SetCreatingFlag(false); - - InterceptorMessage msg; - a->Send(1, msg); - carrier.Wait(); + std::unordered_map interceptor_id_to_rank = {{0, 0}, + {1, 1}}; + + int exe_pid = fork(); + if (exe_pid == 0) { + int pid = fork(); + if (pid == 0) { + Carrier* carrier = + FleetExecutor::CreateCarrier(0, interceptor_id_to_rank); + carrier->SetCreatingFlag(false); + auto msg_bus = std::make_shared(); + msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0); + carrier->SetMsgBus(msg_bus); + Interceptor* a = carrier->SetInterceptor( + 0, InterceptorFactory::Create("PingPong", 0, nullptr)); + InterceptorMessage msg; + a->Send(1, msg); + carrier->Wait(); + } else { + Carrier* carrier = + FleetExecutor::CreateCarrier(1, interceptor_id_to_rank); + carrier->SetCreatingFlag(false); + auto msg_bus = std::make_shared(); + msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1); + carrier->SetMsgBus(msg_bus); + carrier->SetInterceptor( + 1, InterceptorFactory::Create("PingPong", 1, nullptr)); + carrier->Wait(); + int status; + int ret = waitpid(pid, &status, 0); + CHECK_EQ(ret, pid); + } } else { - auto msg_bus = std::make_shared(); - msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1); - - // TODO(liyurui): Remove singleton when move SendIntra into Carrier - Carrier& carrier = FleetExecutor::GetCarrier(); - carrier.SetMsgBus(msg_bus); - - carrier.SetInterceptor(1, - InterceptorFactory::Create("PingPong", 1, nullptr)); - carrier.SetCreatingFlag(false); - carrier.Wait(); + int status; + int ret = waitpid(exe_pid, &status, 0); + CHECK_EQ(ret, exe_pid); } } diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index cf66725a88f8003cc071f26a1e064730c34fe27a..b203617738571da1e6d03b8305c8408c10929180 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -18,7 +18,6 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -52,11 +51,9 @@ void LinkNodes(const std::vector& nodes) { } TEST(AmplifierInterceptor, Amplifier) { - // TODO(liyurui): Remove singleton when move SendIntra into Carrier - Carrier& carrier = FleetExecutor::GetCarrier(); + Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}); auto msg_bus = std::make_shared(); - msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}, - {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); carrier.SetMsgBus(msg_bus); int64_t micro_steps = 3; diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index e2ca934b5b02f58abeaf59d26384923356c46d44..68ee054e76fdacf9c35d44e3c69846379d312725 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -18,7 +18,6 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -70,10 +69,9 @@ void LinkNodes(const std::vector& nodes, } TEST(AmplifierInterceptor, Amplifier) { - // TODO(liyurui): Remove singleton when move SendIntra into Carrier - Carrier& carrier = FleetExecutor::GetCarrier(); + Carrier carrier(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}); auto msg_bus = std::make_shared(); - msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, ""); + msg_bus->Init(0, {{0, ""}}, ""); carrier.SetMsgBus(msg_bus); int64_t micro_steps = 6; diff --git a/paddle/fluid/framework/blocking_queue.h b/paddle/fluid/framework/blocking_queue.h index 4f35da402f3ec2b0616c29085d01e8b7f3d0d472..5bc38c1398aa5ee02470a59b79e8f5c0c97b22e8 100644 --- a/paddle/fluid/framework/blocking_queue.h +++ b/paddle/fluid/framework/blocking_queue.h @@ -75,6 +75,12 @@ class BlockingQueue { return ret; } + void PopAll(std::deque *empty_queue) { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return !q_.empty(); }); + std::swap(*empty_queue, q_); + } + T Pop() { std::unique_lock lock(mutex_); cv_.wait(lock, [=] { return !q_.empty(); });