diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 009df6438e27078c44605a3ecd97901b6ed532a7..1fa1f119191044f1a220962fc6c53b1902bc1d5e 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -49,10 +49,11 @@ void Carrier::Release() { // otherwise Derived object will be destructed before thread complete. // Sending STOP msg to the source interceptor - MessageBus& msg_bus = MessageBus::Instance(); - PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true, + PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true, platform::errors::PreconditionNotMet( - "Message bus has not been initialized.")); + "Using message bus since it has not been initialized. " + "Please invoke MessageBus::Init() before using it or " + "neccessary components are not ready.")); for (int64_t id : source_interceptor_ids_) { VLOG(3) << "Carrier Release is sending stop to source interceptor " << id << "."; @@ -61,7 +62,7 @@ void Carrier::Release() { stop_msg.set_src_id(-1); stop_msg.set_dst_id(id); stop_msg.set_message_type(STOP); - msg_bus.Send(stop_msg); + Send(stop_msg); } // TODO(wangxi): Maybe need a better to use thread. @@ -113,11 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { return iter->second.get(); } +void Carrier::Wait() { + std::unique_lock lock(running_mutex_); + cond_var_.wait(lock); +} + void Carrier::Start() { - MessageBus& msg_bus = MessageBus::Instance(); - PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true, + PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true, platform::errors::PreconditionNotMet( - "Message bus has not been initialized.")); + "Using message bus since it has not been initialized. " + "Please invoke MessageBus::Init() before using it or " + "neccessary components are not ready.")); for (int64_t id : source_interceptor_ids_) { VLOG(3) << "Carrier Start is sending start to source interceptor " << id @@ -127,11 +134,9 @@ void Carrier::Start() { start_msg.set_src_id(-1); start_msg.set_dst_id(id); start_msg.set_message_type(DATA_IS_READY); - msg_bus.Send(start_msg); + Send(start_msg); } - - std::unique_lock lock(running_mutex_); - cond_var_.wait(lock); + Wait(); dev_ctx_->Wait(); } @@ -139,6 +144,11 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; } bool Carrier::IsInit() const { return is_init_; } +// TODO(liyurui): Move SendIntra into carrier +bool Carrier::Send(const InterceptorMessage& msg) const { + return msg_bus_->Send(msg); +} + Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, std::unique_ptr interceptor) { auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); @@ -147,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, "The interceptor id %lld has already been created! " "The interceptor id should be unique.", interceptor_id)); + interceptor->RegisterCarrier(this); auto* ptr = interceptor.get(); interceptor_idx_to_interceptor_.insert( std::make_pair(interceptor_id, std::move(interceptor))); diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index f9411aa73fad49daa11cf0573c44664b6b01933b..e850c120bdbe5d1b08e6577772038d214888eb31 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -40,22 +40,19 @@ namespace distributed { class TaskNode; class InterceptorMessageServiceImpl; class RuntimeGraph; +class MessageBus; -// A singleton MessageBus class Carrier final { public: - static Carrier& Instance() { - static Carrier carrier; - return carrier; - } - + Carrier() = default; + ~Carrier(); void Init(std::shared_ptr runtime_graph, framework::Scope* root_scope, framework::Scope* minibatch_scope, const std::vector& microbatch_scopes, const platform::Place& place); - ~Carrier(); void Release(); + void Wait(); // Enqueue a message to corresponding interceptor id bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); @@ -68,6 +65,9 @@ class Carrier final { std::unique_ptr); void SetCreatingFlag(bool flag); + void SetMsgBus(const std::shared_ptr& msg_bus) { + msg_bus_ = msg_bus; + } std::condition_variable& GetCondVar(); @@ -75,15 +75,15 @@ class Carrier final { bool IsInit() const; + bool Send(const InterceptorMessage& msg) const; + // NOTE: This mutex will be used in interceptor's RunOps function. // This mutex is used for avoiding forward ops and backward ops run // simultaneously, which will lead to a random hang for some sync ops. std::mutex run; - DISABLE_COPY_AND_ASSIGN(Carrier); - private: - Carrier() = default; + DISABLE_COPY_AND_ASSIGN(Carrier); // create each Interceptor void CreateInterceptors(); @@ -110,6 +110,8 @@ class Carrier final { paddle::platform::Place place_; paddle::platform::DeviceContext* dev_ctx_{nullptr}; std::shared_ptr runtime_graph_; + std::shared_ptr msg_bus_; + std::unordered_map interceptor_id_to_rank_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 98583de84e7ea9c0fef595b097e7fe9116142c81..1f0d3408a3da857011e47cd893ee63f2c8282d47 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -170,8 +170,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { } void ComputeInterceptor::RunOps() { - Carrier& carrier_instance = Carrier::Instance(); - std::unique_lock lock(carrier_instance.run); + std::unique_lock lock(carrier_->run); VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the " << step_ + 1 << " time."; for (auto op : node_->ops()) { diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 0369c442734a43e12bf051ce8c5a905620a5d803..a2a51c45f4390e6932b73337439a0fd8288ee95d 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -34,7 +34,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) { "Error occurs while parsing string to proto")); } -FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); } +FleetExecutor::~FleetExecutor() { + root_scope_->DropKids(); + GetCarrier().Release(); +} + +Carrier& FleetExecutor::GetCarrier() { + static Carrier carrier; + return carrier; +} void FleetExecutor::Init( const framework::ProgramDesc& program_desc, framework::Scope* scope, @@ -78,15 +86,17 @@ void FleetExecutor::Init( CopyParameters(i, program_desc); } VLOG(5) << runtime_graph_->DebugString(); + msg_bus_ = std::make_shared(); InitCarrier(); InitMessageBus(); } void FleetExecutor::InitCarrier() { - Carrier& carrier_instance = Carrier::Instance(); - if (!carrier_instance.IsInit()) { - carrier_instance.Init(runtime_graph_, root_scope_, minibatch_scope_, - microbatch_scopes_, place_); + Carrier& carrier = GetCarrier(); + if (!carrier.IsInit()) { + carrier.SetMsgBus(msg_bus_); + carrier.Init(runtime_graph_, root_scope_, minibatch_scope_, + microbatch_scopes_, place_); } } @@ -120,24 +130,22 @@ void FleetExecutor::InitMessageBus() { VLOG(3) << "The number of ranks are " << (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << "."; VLOG(5) << ss.str(); - MessageBus& message_bus_instance = MessageBus::Instance(); - if (!message_bus_instance.IsInit()) { - message_bus_instance.Init(runtime_graph_->intercepter_id_to_rank(), - rank_to_addr, addr); + if (!msg_bus_->IsInit()) { + msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr, + addr); } } void FleetExecutor::Run() { // Run - Carrier& carrier_instance = Carrier::Instance(); - MessageBus& message_bus_instance = MessageBus::Instance(); + Carrier& carrier = GetCarrier(); PADDLE_ENFORCE_EQ( - carrier_instance.IsInit(), true, + carrier.IsInit(), true, platform::errors::Unavailable("Carrier has not been init yet.")); PADDLE_ENFORCE_EQ( - message_bus_instance.IsInit(), true, + msg_bus_->IsInit(), true, platform::errors::Unavailable("MessageBus has not been init yet.")); - carrier_instance.Start(); + carrier.Start(); for (auto* micro_scop : microbatch_scopes_) { // By default, we should delete all kid scopes after run executor because // some operators may create local scope when running, such as while_op. diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index 9fddeae63f6765f1e7bb403611e89c4a5a02d185..a66288525c6f9ba90905915014fe2ddfe2b626c4 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -28,9 +28,9 @@ class Scope; namespace distributed { class RuntimeGraph; -class Carrier; class MessageBus; class TaskNode; +class Carrier; class FleetExecutor final { public: @@ -42,6 +42,8 @@ class FleetExecutor final { const std::vector& task_nodes, const std::unordered_map& task_id_to_rank); void Run(); + // TODO(liyurui): Change to use registry table for multi-carrier. + static Carrier& GetCarrier(); private: DISABLE_COPY_AND_ASSIGN(FleetExecutor); @@ -54,6 +56,9 @@ class FleetExecutor final { framework::Scope* minibatch_scope_; platform::Place place_; std::vector microbatch_scopes_; + // The carriers under FleetExecutor will share message bus, + // using shared_ptr to manage lifetime and condition race. + std::shared_ptr msg_bus_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index dd7b89c4b81199ee76fcef6db0c35d4f9d73185e..d649a84614e4d51a3426aac0a2e5fb4203929318 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -14,7 +14,6 @@ #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" -#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" namespace paddle { @@ -46,8 +45,9 @@ void Interceptor::Handle(const InterceptorMessage& msg) { } void Interceptor::StopCarrier() { - Carrier& carrier_instance = Carrier::Instance(); - std::condition_variable& cond_var = carrier_instance.GetCondVar(); + PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet( + "Carrier is not registered.")); + std::condition_variable& cond_var = carrier_->GetCondVar(); // probably double notify, but ok for ut cond_var.notify_all(); } @@ -73,9 +73,11 @@ bool Interceptor::EnqueueRemoteInterceptorMessage( } bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { + PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet( + "Carrier is not registered.")); msg.set_src_id(interceptor_id_); msg.set_dst_id(dst_id); - return MessageBus::Instance().Send(msg); + return carrier_->Send(msg); } void Interceptor::PoolTheMailbox() { diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index b0c1e46f03138214f4f1dc0aab8fc35647b902e8..bc20058074441eafafa86b8cf20e65fbeed41b07 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -36,6 +36,7 @@ class GarbageCollector; namespace distributed { class TaskNode; +class Carrier; class Interceptor { public: @@ -77,6 +78,7 @@ class Interceptor { void SetGC(const std::shared_ptr& gc) { gc_ = gc; } + void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; } TaskNode* GetTaskNode() const { return node_; } @@ -100,6 +102,8 @@ class Interceptor { std::vector microbatch_scopes_{}; std::shared_ptr gc_{nullptr}; + Carrier* carrier_; + private: // pool the local mailbox, parse the Message void PoolTheMailbox(); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index 44195467045c34258f206d98fa330dd94f784d96..a8d29758ca16385ac2f340eb1aeee4b1fb76454d 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -29,9 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( VLOG(3) << "Interceptor Message Service receives a message from interceptor " << request->src_id() << " to interceptor " << request->dst_id() << ", with the message: " << request->message_type(); + FleetExecutor::GetCarrier().EnqueueInterceptorMessage(*request); response->set_rst(true); - // call interceptor manager's method to handle the message - Carrier::Instance().EnqueueInterceptorMessage(*request); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index f087de69fa96b2861d916e28fa4f4a292f791401..d4c986de5a03ca4810edbb4ee7abcf69517ea841 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -57,10 +57,6 @@ void MessageBus::Init( bool MessageBus::IsInit() const { return is_init_; } MessageBus::~MessageBus() { - // NOTE: fleet_executor inits carrier before message bus, - // therefore the message bus's destructor will be called first - Carrier& carrier = Carrier::Instance(); - carrier.Release(); VLOG(3) << "Message bus releases resource."; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) @@ -245,7 +241,8 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) { // send the message intra rank (dst is the same rank with src) - return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message); + return FleetExecutor::GetCarrier().EnqueueInterceptorMessage( + interceptor_message); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index 5b19a894aa35171a8b672804a9b90e7480db2668..3f151cab3a46c689f707f4ca3590772a8d6bc47f 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -39,10 +39,8 @@ class Carrier; // A singleton MessageBus class MessageBus final { public: - static MessageBus& Instance() { - static MessageBus msg_bus; - return msg_bus; - } + MessageBus() = default; + ~MessageBus(); void Init(const std::unordered_map& interceptor_id_to_rank, const std::unordered_map& rank_to_addr, @@ -53,12 +51,8 @@ class MessageBus final { // called by Interceptor, send InterceptorMessage to dst bool Send(const InterceptorMessage& interceptor_message); - ~MessageBus(); - - DISABLE_COPY_AND_ASSIGN(MessageBus); - private: - MessageBus() = default; + DISABLE_COPY_AND_ASSIGN(MessageBus); // function keep listen the port and handle the message void ListenPort(); @@ -72,12 +66,11 @@ class MessageBus final { bool SendInterRank(const InterceptorMessage& interceptor_message); #endif + bool is_init_{false}; + // send the message intra rank (dst is the same rank with src) bool SendIntraRank(const InterceptorMessage& interceptor_message); - bool is_init_{false}; - std::once_flag once_flag_; - // handed by above layer, save the info mapping interceptor id to rank id std::unordered_map interceptor_id_to_rank_; diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc index c5348db83e0298db1c25c7424fa0e37c5724c24b..e56696d35f2a46c94343be4196ebe82568646fa9 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -61,10 +62,12 @@ TEST(ComputeInterceptor, Compute) { std::vector scopes = {scope, scope}; platform::Place place = platform::CPUPlace(); - Carrier& carrier = Carrier::Instance(); + // TODO(liyurui): Remove singleton when move SendIntra into Carrier + Carrier& carrier = FleetExecutor::GetCarrier(); - MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + auto msg_bus = std::make_shared(); + msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, ""); + carrier.SetMsgBus(msg_bus); // FIXME: don't delete, otherwise interceptor will use undefined node TaskNode* node_a = @@ -90,6 +93,9 @@ TEST(ComputeInterceptor, Compute) { msg.set_src_id(-1); msg.set_dst_id(0); carrier.EnqueueInterceptorMessage(msg); + + carrier.Wait(); + carrier.Release(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 44dc0c9bc9b0c9b1010ab20d8696e97e13165173..3bd2ddec4effcb3f30061c612bb5babe7c1c228c 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -46,9 +47,12 @@ class StartInterceptor : public Interceptor { }; TEST(ComputeInterceptor, Compute) { - Carrier& carrier = Carrier::Instance(); - MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + // TODO(liyurui): Remove singleton when move SendIntra into Carrier + Carrier& carrier = FleetExecutor::GetCarrier(); + + auto msg_bus = std::make_shared(); + msg_bus->Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, ""); + carrier.SetMsgBus(msg_bus); // NOTE: don't delete, otherwise interceptor will use undefined node TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id @@ -74,6 +78,9 @@ TEST(ComputeInterceptor, Compute) { a->Send(1, msg); a->Send(1, msg); a->Send(1, msg); + + carrier.Wait(); + carrier.Release(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index c68688bfea646b11470eef0a5de62ddf3369f6da..8d9e609a2403405ccfe68880bff7f6cfd493a537 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -44,6 +45,7 @@ class PingPongInterceptor : public Interceptor { stop.set_message_type(STOP); Send(0, stop); Send(1, stop); + StopCarrier(); return; } @@ -58,10 +60,12 @@ class PingPongInterceptor : public Interceptor { REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor); TEST(InterceptorTest, PingPong) { - MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + // TODO(liyurui): Remove singleton when move SendIntra into Carrier + Carrier& carrier = FleetExecutor::GetCarrier(); - Carrier& carrier = Carrier::Instance(); + auto msg_bus = std::make_shared(); + msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, ""); + carrier.SetMsgBus(msg_bus); Interceptor* a = carrier.SetInterceptor( 0, InterceptorFactory::Create("PingPong", 0, nullptr)); @@ -71,6 +75,8 @@ TEST(InterceptorTest, PingPong) { InterceptorMessage msg; a->Send(1, msg); + + carrier.Wait(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index 233c4d92c9f3943b880a4e34db3877ff6b5c9096..93574609960a11b5853cf0a4c3c022d12210eb0d 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -36,6 +37,7 @@ class PingPongInterceptor : public Interceptor { void PingPong(const InterceptorMessage& msg) { if (msg.message_type() == STOP) { stop_ = true; + StopCarrier(); return; } std::cout << GetInterceptorId() << " recv msg, count=" << count_ @@ -105,10 +107,12 @@ TEST(InterceptorTest, PingPong) { int pid = fork(); if (pid == 0) { - MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0); + auto msg_bus = std::make_shared(); + msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip0); - Carrier& carrier = Carrier::Instance(); + // TODO(liyurui): Remove singleton when move SendIntra into Carrier + Carrier& carrier = FleetExecutor::GetCarrier(); + carrier.SetMsgBus(msg_bus); Interceptor* a = carrier.SetInterceptor( 0, InterceptorFactory::Create("PingPong", 0, nullptr)); @@ -116,15 +120,19 @@ TEST(InterceptorTest, PingPong) { InterceptorMessage msg; a->Send(1, msg); + carrier.Wait(); } else { - MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1); + auto msg_bus = std::make_shared(); + msg_bus->Init({{0, 0}, {1, 1}}, {{0, ip0}, {1, ip1}}, ip1); - Carrier& carrier = Carrier::Instance(); + // TODO(liyurui): Remove singleton when move SendIntra into Carrier + Carrier& carrier = FleetExecutor::GetCarrier(); + carrier.SetMsgBus(msg_bus); carrier.SetInterceptor(1, InterceptorFactory::Create("PingPong", 1, nullptr)); carrier.SetCreatingFlag(false); + carrier.Wait(); } } diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index b3fdb0b7adff016985b7c546a6517ce752120e2f..cf66725a88f8003cc071f26a1e064730c34fe27a 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -51,10 +52,12 @@ void LinkNodes(const std::vector& nodes) { } TEST(AmplifierInterceptor, Amplifier) { - Carrier& carrier = Carrier::Instance(); - MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}, - {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + // TODO(liyurui): Remove singleton when move SendIntra into Carrier + Carrier& carrier = FleetExecutor::GetCarrier(); + auto msg_bus = std::make_shared(); + msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}, + {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + carrier.SetMsgBus(msg_bus); int64_t micro_steps = 3; @@ -88,6 +91,8 @@ TEST(AmplifierInterceptor, Amplifier) { msg.set_src_id(-1); msg.set_dst_id(0); carrier.EnqueueInterceptorMessage(msg); + carrier.Wait(); + carrier.Release(); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index 936a970c05f7c507a30fc9434fa3c8013f2d0362..e2ca934b5b02f58abeaf59d26384923356c46d44 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -69,9 +70,11 @@ void LinkNodes(const std::vector& nodes, } TEST(AmplifierInterceptor, Amplifier) { - Carrier& carrier = Carrier::Instance(); - MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, ""); + // TODO(liyurui): Remove singleton when move SendIntra into Carrier + Carrier& carrier = FleetExecutor::GetCarrier(); + auto msg_bus = std::make_shared(); + msg_bus->Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, ""}}, ""); + carrier.SetMsgBus(msg_bus); int64_t micro_steps = 6; @@ -103,6 +106,8 @@ TEST(AmplifierInterceptor, Amplifier) { msg.set_src_id(-1); msg.set_dst_id(0); carrier.EnqueueInterceptorMessage(msg); + carrier.Wait(); + carrier.Release(); } } // namespace distributed