diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 108a21b92fdfd2a61debdfe8308238065c5f3e32..8a4f10473e3d279757fe429bb122aef1064c287a 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -92,19 +92,22 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { } void Carrier::Start() { - // TODO(fleet_executor dev): this start is a faked one, need replace - for (const auto& pair : interceptor_idx_to_interceptor_) { - VLOG(3) << "Fake run is sending start to interceptor " << pair.first << "."; - InterceptorMessage tmp_msg; - tmp_msg.set_src_id(pair.first); - tmp_msg.set_dst_id(pair.first); - tmp_msg.set_message_type(DATA_IS_READY); - MessageBus& message_bus_instance = MessageBus::Instance(); - PADDLE_ENFORCE_EQ(message_bus_instance.IsInit(), true, - platform::errors::PreconditionNotMet( - "Message bus has not been initialized.")); - message_bus_instance.Send(tmp_msg); + MessageBus& msg_bus = MessageBus::Instance(); + PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true, + platform::errors::PreconditionNotMet( + "Message bus has not been initialized.")); + + for (int64_t id : source_interceptor_ids_) { + VLOG(3) << "Carrier Start is sending start to source interceptor " << id + << "."; + InterceptorMessage start_msg; + // source node data_is_ready is sent by carrier, so set src_id=-1 + start_msg.set_src_id(-1); + start_msg.set_dst_id(id); + start_msg.set_message_type(DATA_IS_READY); + msg_bus.Send(start_msg); } + std::unique_lock lock(running_mutex_); cond_var_.wait(lock); dev_ctx_->Wait(); @@ -164,16 +167,26 @@ void Carrier::CreateInterceptors() { int64_t interceptor_id = item.first; TaskNode* task_node = item.second; - // TODO(wangxi): use node_type to select different Interceptor - auto interceptor = - std::make_unique(interceptor_id, task_node); + std::unique_ptr interceptor; + if (task_node->type().empty()) { + // TODO(wangxi): 
delete this in future + interceptor.reset(new Interceptor(interceptor_id, task_node)); + } else { + interceptor = InterceptorFactory::Create(task_node->type(), + interceptor_id, task_node); + } interceptor->SetPlace(place_); interceptor->SetMiniBatchScope(minibatch_scope_); interceptor->SetMicroBatchScope(microbatch_scopes_); interceptor->SetRootScope(root_scope_); + SetInterceptor(interceptor_id, std::move(interceptor)); VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id << "."; + + if (task_node->upstream().empty()) { + source_interceptor_ids_.emplace_back(interceptor_id); + } } // The carrier will be always waiting for outside initializer // since there is no interceptor has been created during auto init diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index c4c6a418464747f7343a6c3d31ffb5e5991423fe..b5976b297f91394f8317293c58907548ba47b08f 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -90,6 +91,8 @@ class Carrier final { std::unordered_map> interceptor_idx_to_interceptor_; + std::vector source_interceptor_ids_; + std::vector message_tmp_{}; std::mutex tmp_message_mutex_; bool creating_interceptors_{true}; diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 3008c83069942c2b7bf8cf3759de7d1ec5dde2b0..fd55aa2aa1c4656a5b14a3eb936f18e77076076e 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -154,18 +154,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { } void ComputeInterceptor::Run() { - // If there is no limit, source interceptor can be executed - // an unlimited number of times. 
- // Now source node can only run - if (ShouldReset()) { - for (auto& out_buff : out_buffs_) { - // buffer is using - if (out_buff.second.second != 0) return; - } - step_ = 0; // reset - return; - } - while (IsInputReady() && CanWriteOutput() && !ShouldReset()) { VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; @@ -181,6 +169,18 @@ void ComputeInterceptor::Run() { // reply to upstream and decrease ready data ReplyCompletedToUpStream(); } + + // If there is no limit, source interceptor can be executed + // an unlimited number of times. + // For now, a source node can only run max_run_times times. + if (ShouldReset()) { + for (auto& out_buff : out_buffs_) { + // buffer is still in use + if (out_buff.second.second != 0) return; + } + step_ = 0; // reset + return; + } } void ComputeInterceptor::ReceivedStop(int64_t up_id) { diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 40429502825c9ca02c3503a51da0fd87b6805af2..63c2bb3fc6eecb3f79896e9116806fb5dc494028 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -46,11 +46,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) { VLOG(3) << "Interceptor is using default message handler. This handler is " "only used for test purpose. 
Check whether you init interceptor " "in the proper way."; + if (msg.message_type() == DATA_IS_READY) { + if (node_->role() != 2) { + VLOG(3) << "Fake handler is sending DATA_IS_READY message to: " + << interceptor_id_ + 1 << "."; + InterceptorMessage data_is_ready_msg; + data_is_ready_msg.set_message_type(DATA_IS_READY); + Send(interceptor_id_ + 1, data_is_ready_msg); + } VLOG(3) << "Fake handler is sending stop message to it self."; - InterceptorMessage msg; - msg.set_message_type(STOP); - Send(interceptor_id_, msg); + InterceptorMessage stop_msg; + stop_msg.set_message_type(STOP); + Send(interceptor_id_, stop_msg); } else if (msg.message_type() == STOP) { stop_ = true; StopCarrier(); diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 2071477372c9e7ee5a58ae4f01af268b5b014211..de2171e68e19e20f0661856916a67189dabb5630 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -136,6 +136,9 @@ void MessageBus::ListenPort() { } bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) { + // -1 is sent by carrier to source interceptor + if (src_id == -1) src_id = dst_id; + // check whether the dst is the same rank or different rank with src const auto& src_rank = interceptor_id_to_rank_.find(src_id); const auto& dst_rank = interceptor_id_to_rank_.find(dst_id); diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc index 3a76bd43f9d55be8e5ac6dc6caa1d3008e3687e4..b32db6c2294b808997c75c7479cdc9997f9e18c5 100644 --- a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc +++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc @@ -112,6 +112,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) { for (const auto& op_desc : program.Block(0).AllOps()) { ops_.emplace_back(OpRegistry::CreateOp(*op_desc)); } + 
std::unordered_map> role_to_ops; for (const auto& op : ops_) { int32_t op_role = op->Attr("op_role");