diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index b87f48bc27c544e3bfe1d24a2da1e78047565d96..84548d7fd69c056c2cfe8de3d1f9092e978a44ab 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" namespace paddle { @@ -22,8 +23,11 @@ namespace distributed { void Carrier::Init( const std::unordered_map& interceptor_id_to_node) { + PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( + "Carrier is already init.")); interceptor_id_to_node_ = interceptor_id_to_node; CreateInterceptors(); + is_init_ = true; } bool Carrier::EnqueueInterceptorMessage( @@ -63,6 +67,24 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { return iter->second.get(); } +void Carrier::Start() { + // TODO(fleet_executor dev): this start is a faked one, need replace + for (const auto& pair : interceptor_idx_to_interceptor_) { + VLOG(3) << "Fake run is sending start to interceptor " << pair.first << "."; + InterceptorMessage tmp_msg; + tmp_msg.set_src_id(pair.first); + tmp_msg.set_dst_id(pair.first); + tmp_msg.set_message_type(DATA_IS_READY); + MessageBus& message_bus_instance = MessageBus::Instance(); + PADDLE_ENFORCE_EQ(message_bus_instance.IsInit(), true, + platform::errors::PreconditionNotMet( + "Message bus has not been initialized.")); + message_bus_instance.Send(tmp_msg); + } +} + +bool Carrier::IsInit() const { return is_init_; } + Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, std::unique_ptr interceptor) { auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 95f9ffcdf4960f7831b405d15fa52393e2c30a90..6f3be48c75fcfa5f5f158cb035a3aa4db8874c33 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -56,6 +56,10 @@ class Carrier final { void SetCreatingFlag(bool flag); + void Start(); + + bool IsInit() const; + DISABLE_COPY_AND_ASSIGN(Carrier); private: @@ -75,6 +79,7 @@ class Carrier final { std::vector message_tmp_{}; bool creating_interceptors_{true}; + bool is_init_{false}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 05e78e77cb7650992c3543de2d5c637c9ef0883d..13c6a4e8c39d783f60542918d6f6efd5725a80b4 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" +#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" @@ -34,14 +35,21 @@ FleetExecutor::~FleetExecutor() { void FleetExecutor::Init(const paddle::framework::ProgramDesc& program_desc) { runtime_graph_ = std::make_unique(program_desc, exe_desc_); + InitCarrier(); InitMessageBus(); } +void FleetExecutor::InitCarrier() { + Carrier& carrier_instance = Carrier::Instance(); + if (!carrier_instance.IsInit()) { + carrier_instance.Init(runtime_graph_->intercepter_id_to_node()); + } +} + void FleetExecutor::InitMessageBus() { std::stringstream ss; ss << "\nThe DNS table of the message bus is: \n"; int64_t cur_rank = exe_desc_.cur_rank(); - std::unordered_map interceptor_id_to_rank; std::unordered_map rank_to_addr; std::string addr; for (const auto& rank_info : exe_desc_.cluster_info()) { @@ -49,8 +57,6 @@ void FleetExecutor::InitMessageBus() { int64_t rank = rank_info.rank(); std::string ip_port = rank_info.ip_port(); ss << rank << "\t->\t" << ip_port << "\n"; - // TODO(Yuang): init interceptor_id_to_rank out of this loop - interceptor_id_to_rank.insert(std::make_pair(rank, rank)); rank_to_addr.insert(std::make_pair(rank, ip_port)); if (rank == cur_rank) { addr = ip_port; @@ -58,7 +64,7 @@ void FleetExecutor::InitMessageBus() { } if (addr == "") { PADDLE_ENFORCE_EQ( - rank_to_addr.size(), 0, + rank_to_addr.size(), 1, platform::errors::NotFound("Empty address is not valid for " "paddle.distributed.launch method.")); PADDLE_ENFORCE_EQ( @@ -72,12 +78,22 @@ void FleetExecutor::InitMessageBus() { VLOG(5) << ss.str(); MessageBus& message_bus_instance = MessageBus::Instance(); if (!message_bus_instance.IsInit()) { - message_bus_instance.Init(interceptor_id_to_rank, rank_to_addr, addr); + message_bus_instance.Init(runtime_graph_->intercepter_id_to_rank(), + rank_to_addr, addr); } } void FleetExecutor::Run() { // Run + Carrier& carrier_instance = Carrier::Instance(); + MessageBus& message_bus_instance = MessageBus::Instance(); + PADDLE_ENFORCE_EQ( + carrier_instance.IsInit(), true, + platform::errors::Unavailable("Carrier has not been init yet.")); + PADDLE_ENFORCE_EQ( + message_bus_instance.IsInit(), true, + platform::errors::Unavailable("MessageBus has not been init yet.")); + carrier_instance.Start(); } void FleetExecutor::Release() { diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index 779d2f91221dfa4cb97f86315ac54c4e7a304353..c939f70955c61337c861425f28e12941d58ec626 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -43,6 +43,7 @@ class FleetExecutor final { FleetExecutorDesc exe_desc_; std::unique_ptr runtime_graph_; void InitMessageBus(); + void InitCarrier(); }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 6b606290fa160e421772d9742443d07a0490bedd..696f7dd752eec3f2f5e800f8f7c76bf7e8befb1d 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -33,6 +33,16 @@ void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } void Interceptor::Handle(const InterceptorMessage& msg) { if (handle_) { handle_(msg); + } else { + VLOG(3) << "Interceptor is using default message handler. This handler is " + "only used for test purpose. Check whether you init interceptor " + "in the proper way."; + if (msg.message_type() == DATA_IS_READY) { + VLOG(3) << "Fake handler is sending stop message to it self."; + InterceptorMessage msg; + msg.set_message_type(STOP); + Send(interceptor_id_, msg); + } } } diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 7dbf2bed3af776da01ff12fc879fb2b40f8790c1..c493a420b946b10bd26a8f6a229b47b97d824532 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1958,6 +1958,11 @@ class Executor(object): fleet_exe_desc.cluster_info.append(rank_info) nrank = len(trainer_endpoints) else: + fleet_exe_desc.cur_rank = 0 + rank_info = fleet_executor_desc_pb2.RankInfo() + rank_info.rank = 0 + rank_info.ip_port = '' + fleet_exe_desc.cluster_info.append(rank_info) logging.warning("Fleet Executor will run on single device only.") fleet_opt = program._pipeline_opt["fleet_opt"] if "dist_strategy" in fleet_opt: