From f85bd5c942d5f4b14858fd7af9c3626db89d4c18 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 18 Nov 2021 11:05:09 +0800 Subject: [PATCH] [fleet_executor] Parse runtime graph to start carrier (#37282) --- .../distributed/fleet_executor/carrier.cc | 22 ++++++++++++++++ .../distributed/fleet_executor/carrier.h | 5 ++++ .../fleet_executor/fleet_executor.cc | 26 +++++++++++++++---- .../fleet_executor/fleet_executor.h | 1 + .../distributed/fleet_executor/interceptor.cc | 10 +++++++ python/paddle/fluid/executor.py | 5 ++++ 6 files changed, 64 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index b87f48bc27c..84548d7fd69 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" namespace paddle { @@ -22,8 +23,11 @@ namespace distributed { void Carrier::Init( const std::unordered_map& interceptor_id_to_node) { + PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( + "Carrier is already init.")); interceptor_id_to_node_ = interceptor_id_to_node; CreateInterceptors(); + is_init_ = true; } bool Carrier::EnqueueInterceptorMessage( @@ -63,6 +67,24 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { return iter->second.get(); } +void Carrier::Start() { + // TODO(fleet_executor dev): this start is a faked one, need replace + for (const auto& pair : interceptor_idx_to_interceptor_) { + VLOG(3) << "Fake run is sending start to interceptor " << pair.first << "."; + InterceptorMessage tmp_msg; + tmp_msg.set_src_id(pair.first); + tmp_msg.set_dst_id(pair.first); + tmp_msg.set_message_type(DATA_IS_READY); + MessageBus& message_bus_instance = MessageBus::Instance(); + PADDLE_ENFORCE_EQ(message_bus_instance.IsInit(), true, + platform::errors::PreconditionNotMet( + "Message bus has not been initialized.")); + message_bus_instance.Send(tmp_msg); + } +} + +bool Carrier::IsInit() const { return is_init_; } + Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, std::unique_ptr interceptor) { auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 95f9ffcdf49..6f3be48c75f 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -56,6 +56,10 @@ class Carrier final { void SetCreatingFlag(bool flag); + void Start(); + + bool IsInit() const; + DISABLE_COPY_AND_ASSIGN(Carrier); private: @@ -75,6 +79,7 @@ class Carrier final { std::vector message_tmp_{}; bool creating_interceptors_{true}; + bool is_init_{false}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 05e78e77cb7..13c6a4e8c39 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -34,14 +35,21 @@ FleetExecutor::~FleetExecutor() { void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) { runtime_graph_ = std::make_unique(program_desc, exe_desc_); + InitCarrier(); InitMessageBus(); } +void FleetExecutor::InitCarrier() { + Carrier& carrier_instance = Carrier::Instance(); + if (!carrier_instance.IsInit()) { + carrier_instance.Init(runtime_graph_->intercepter_id_to_node()); + } +} + void FleetExecutor::InitMessageBus() { std::stringstream ss; ss << "\nThe DNS table of the message bus is: \n"; int64_t cur_rank = exe_desc_.cur_rank(); - std::unordered_map interceptor_id_to_rank; std::unordered_map rank_to_addr; std::string addr; for (const auto& rank_info : exe_desc_.cluster_info()) { @@ -49,8 +57,6 @@ void FleetExecutor::InitMessageBus() { int64_t rank = rank_info.rank(); std::string ip_port = rank_info.ip_port(); ss << rank << "\t->\t" << ip_port << "\n"; - // TODO(Yuang): init interceptor_id_to_rank out of this loop - interceptor_id_to_rank.insert(std::make_pair(rank, rank)); rank_to_addr.insert(std::make_pair(rank, ip_port)); if (rank == cur_rank) { addr = ip_port; @@ -58,7 +64,7 @@ void FleetExecutor::InitMessageBus() { } if (addr == "") { PADDLE_ENFORCE_EQ( - rank_to_addr.size(), 0, + rank_to_addr.size(), 1, platform::errors::NotFound("Empty address is not valid for " "paddle.distributed.launch method.")); PADDLE_ENFORCE_EQ( @@ -72,12 +78,22 @@ void FleetExecutor::InitMessageBus() { VLOG(5) << ss.str(); MessageBus& message_bus_instance = MessageBus::Instance(); if (!message_bus_instance.IsInit()) { - message_bus_instance.Init(interceptor_id_to_rank, rank_to_addr, addr); + message_bus_instance.Init(runtime_graph_->intercepter_id_to_rank(), + rank_to_addr, addr); } } void FleetExecutor::Run() { // Run + Carrier& carrier_instance = Carrier::Instance(); + MessageBus& message_bus_instance = MessageBus::Instance(); + PADDLE_ENFORCE_EQ( + carrier_instance.IsInit(), true, + platform::errors::Unavailable("Carrier has not been init yet.")); + PADDLE_ENFORCE_EQ( + message_bus_instance.IsInit(), true, + platform::errors::Unavailable("MessageBus has not been init yet.")); + carrier_instance.Start(); } void FleetExecutor::Release() { diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index 779d2f91221..c939f70955c 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -43,6 +43,7 @@ class FleetExecutor final { FleetExecutorDesc exe_desc_; std::unique_ptr runtime_graph_; void InitMessageBus(); + void InitCarrier(); }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 6b606290fa1..696f7dd752e 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -33,6 +33,16 @@ void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } void Interceptor::Handle(const InterceptorMessage& msg) { if (handle_) { handle_(msg); + } else { + VLOG(3) << "Interceptor is using default message handler. This handler is " + "only used for test purpose. Check whether you init interceptor " + "in the proper way."; + if (msg.message_type() == DATA_IS_READY) { + VLOG(3) << "Fake handler is sending stop message to it self."; + InterceptorMessage msg; + msg.set_message_type(STOP); + Send(interceptor_id_, msg); + } } } diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 7dbf2bed3af..c493a420b94 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1958,6 +1958,11 @@ class Executor(object): fleet_exe_desc.cluster_info.append(rank_info) nrank = len(trainer_endpoints) else: + fleet_exe_desc.cur_rank = 0 + rank_info = fleet_executor_desc_pb2.RankInfo() + rank_info.rank = 0 + rank_info.ip_port = '' + fleet_exe_desc.cluster_info.append(rank_info) logging.warning("Fleet Executor will run on single device only.") fleet_opt = program._pipeline_opt["fleet_opt"] if "dist_strategy" in fleet_opt: -- GitLab