diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 0e79656edea091b9dea48d4415d228ab88270758..b87f48bc27c544e3bfe1d24a2da1e78047565d96 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -33,6 +33,13 @@ bool Carrier::EnqueueInterceptorMessage( // handle control message return true; } else { + if (creating_interceptors_) { + // Cannot handle the message to interceptor since interceptors + // are still under creating. Will enqueue into a tmp stack. + VLOG(3) << "Receiving message while creating interceptors."; + message_tmp_.emplace_back(interceptor_message); + return true; + } int64_t dst_id = interceptor_message.dst_id(); Interceptor* dst_interceptor = GetInterceptor(dst_id); bool rst = @@ -70,16 +77,45 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, return ptr; } +void Carrier::SetCreatingFlag(bool flag) { + // set the creating flag + VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_ + << " to " << flag << "."; + creating_interceptors_ = flag; + if (!flag) { + // finish create interceptors outside, handle tmp messsages + HandleTmpMessages(); + } +} + +void Carrier::HandleTmpMessages() { + VLOG(3) << "Carrier has received " << message_tmp_.size() + << " messages during creating interceptors."; + for (const auto& msg : message_tmp_) { + EnqueueInterceptorMessage(msg); + } + message_tmp_.clear(); +} + void Carrier::CreateInterceptors() { // create each Interceptor - for (const auto& item : interceptor_id_to_node_) { - int64_t interceptor_id = item.first; - TaskNode* task_node = item.second; + if (!interceptor_id_to_node_.empty()) { + // no auto init since there is no config + for (const auto& item : interceptor_id_to_node_) { + int64_t interceptor_id = item.first; + TaskNode* task_node = item.second; - // TODO(wangxi): use node_type to select different Interceptor - auto interceptor = std::make_unique(interceptor_id, task_node); - SetInterceptor(interceptor_id, std::move(interceptor)); - VLOG(3) << "Create Interceptor for " << interceptor_id; + // TODO(wangxi): use node_type to select different Interceptor + auto interceptor = + std::make_unique(interceptor_id, task_node); + SetInterceptor(interceptor_id, std::move(interceptor)); + VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id + << "."; + } + // The carrier will be always waiting for outside initializer + // since there is no interceptor has been created during auto init + creating_interceptors_ = false; + HandleTmpMessages(); } } diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 64974714f7b1c779c8a078f51c412c798090650b..95f9ffcdf4960f7831b405d15fa52393e2c30a90 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -17,6 +17,7 @@ #include #include #include +#include #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" @@ -53,6 +54,8 @@ class Carrier final { Interceptor* SetInterceptor(int64_t interceptor_id, std::unique_ptr); + void SetCreatingFlag(bool flag); + DISABLE_COPY_AND_ASSIGN(Carrier); private: @@ -61,12 +64,17 @@ class Carrier final { // create each Interceptor void CreateInterceptors(); + void HandleTmpMessages(); + // interceptor logic id to the Nodes info std::unordered_map interceptor_id_to_node_; // interceptor logic id to actually interceptor std::unordered_map> interceptor_idx_to_interceptor_; + + std::vector message_tmp_{}; + bool creating_interceptors_{true}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index dbee46afcf86fac3f2b47f7d6d729afe2968407b..6b606290fa160e421772d9742443d07a0490bedd 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -56,10 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage( return true; } -void Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { +bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { msg.set_src_id(interceptor_id_); msg.set_dst_id(dst_id); - MessageBus::Instance().Send(msg); + return MessageBus::Instance().Send(msg); } void Interceptor::PoolTheMailbox() { @@ -76,10 +76,12 @@ void Interceptor::PoolTheMailbox() { const InterceptorMessage interceptor_message = local_mailbox_.front(); local_mailbox_.pop(); const MessageType message_type = interceptor_message.message_type(); - VLOG(3) << interceptor_id_ << " has received a message: " << message_type - << "."; + VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message" + << " from interceptor " << interceptor_message.src_id() + << " with message: " << message_type << "."; if (message_type == STOP) { // break the pooling thread + VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting."; break; } diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 24fad8331863e29099c53d7e536d2dffb9df3d12..2e86dc2fe525d44d491aad7b8ef730317b50777f 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -58,7 +58,7 @@ class Interceptor { bool EnqueueRemoteInterceptorMessage( const InterceptorMessage& interceptor_message); - void Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT + bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT DISABLE_COPY_AND_ASSIGN(Interceptor); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index 2205c6e5544bb5128f329edc0844c795e66462ea..44195467045c34258f206d98fa330dd94f784d96 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -26,8 +26,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( const InterceptorMessage* request, InterceptorResponse* response, google::protobuf::Closure* done) { brpc::ClosureGuard done_guard(done); - VLOG(3) << "Interceptor Message Service receives a message from: " - << request->src_id() + VLOG(3) << "Interceptor Message Service receives a message from interceptor " + << request->src_id() << " to interceptor " << request->dst_id() << ", with the message: " << request->message_type(); response->set_rst(true); // call interceptor manager's method to handle the message diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 27a1f90767fe6ed55d8af5ca81780282621ca169..309982bc04bebd5df9b75ce43d7605440abe6fec 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -55,12 +55,13 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) { int64_t src_id = interceptor_message.src_id(); int64_t dst_id = interceptor_message.dst_id(); if (IsSameRank(src_id, dst_id)) { - VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id - << ", which are same ranks."; + VLOG(3) << "Send a message from interceptor " << src_id + << " to interceptor " << dst_id << ", which are in the same ranks."; return SendIntraRank(interceptor_message); } else { - VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id - << ", which are different ranks."; + VLOG(3) << "Send a message from interceptor " << src_id + << " to interceptor " << dst_id + << ", which are in different ranks."; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) int retry_time = 0; // message bus will retry sending for 10 times @@ -155,6 +156,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { "Cannot find rank for dst interceptor id %lld. " "Init error.", dst_id)); + VLOG(3) << "Message bus sending to addr: " << dst_ip->second; const char* dst_ip_for_brpc = dst_ip->second.c_str(); brpc::Channel channel; brpc::ChannelOptions options; diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 856bbb4754738efd4bde8a569e38df1716711d23..783c924398a70307b47c28b90082241fb711b344 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -59,8 +59,8 @@ TEST(InterceptorTest, PingPong) { Interceptor* a = carrier.SetInterceptor( 0, std::make_unique(0, nullptr)); - carrier.SetInterceptor(1, std::make_unique(1, nullptr)); + carrier.SetCreatingFlag(false); InterceptorMessage msg; a->Send(1, msg);