diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 2d2a3b688fefeda4deef1f0a24a270788b380cfe..53bae87c0020eae2bbeec2a0e1e0d7098e897421 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -186,7 +186,13 @@ int64_t Carrier::GetRank(int64_t interceptor_id) const { } bool Carrier::Send(const InterceptorMessage& msg) { - int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id(); + int64_t src_id = msg.src_id(); + // TODO(liyurui): compatible solution, will be removed completely in the + // future + if (interceptor_id_to_rank_.find(src_id) == interceptor_id_to_rank_.end() && + src_id == SOURCE_ID) { + src_id = msg.dst_id(); + } int64_t dst_id = msg.dst_id(); int64_t src_rank = GetRank(src_id); int64_t dst_rank = GetRank(dst_id); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index f49c84e6e5edc0f7e0c97041cd82d196ce0c44bb..fb907e3b5c29f4b8441aedcffe7cfd9cb8125ff2 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -161,7 +161,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " Reply data_is_useless msg to " << up_id << " for step: " << step_; - if (up_id == -1) return; + if (is_source_ && up_id == -1) return; InterceptorMessage reply_msg; reply_msg.set_message_type(DATA_IS_USELESS); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index cb7ff2da89a9d1cf3ec0a9d6e96d5ed340551f46..86ca7be7f44db8b8c98e09093ab0fc1520b2b9ac 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -40,6 +40,9 @@ class TaskNode; class Carrier; class TaskLoop; +constexpr int64_t SOURCE_ID = -1; +constexpr int64_t SINK_ID = -2; + class Interceptor { public: using MsgHandle = std::function; diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto index 7cf99e8741943774c935e41d8e152c39fffbcd75..8508bc35f29bef90a8112ff01465bdf448463809 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto @@ -27,8 +27,8 @@ enum MessageType { } message InterceptorMessage { - optional int64 src_id = 1 [ default = 0 ]; - optional int64 dst_id = 2 [ default = 0 ]; + optional sint64 src_id = 1 [ default = 0 ]; + optional sint64 dst_id = 2 [ default = 0 ]; optional MessageType message_type = 3 [ default = RESET ]; optional bool ctrl_message = 4 [ default = false ]; optional int64 scope_idx = 5 [ default = 0 ]; diff --git a/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc b/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc index af707c28acd9e9576ebd5491bd4517ea0cbae32e..77fbb23a6c71b1aae986217295baaab2707b3ec0 100644 --- a/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc @@ -30,7 +30,7 @@ SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node) void SinkInterceptor::StopCarrierIfComplete() { bool flag = true; for (const auto& up : upstream_step_) { - flag = flag & (up.second == max_run_times_); + flag = flag && (up.second == max_run_times_); } if (flag) { VLOG(3) << "Sink Interceptor is stopping carrier"; diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc index 95e4c73305998e4190c1547cb2f92809e360b216..232317333ea11f85da3f7da8606eeeb3c18619f9 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.cc +++ b/paddle/fluid/distributed/fleet_executor/task_node.cc @@ -74,6 +74,9 @@ void TaskNode::Init(bool use_feed_fetch_ops) { } } +TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times) + : rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {} + TaskNode::TaskNode(int32_t role, const std::vector& op_descs, int64_t rank, int64_t task_id, int64_t max_run_times, diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h index 4764d4fd4af87adf3df31f2dabb614da7d719861..7dd4b5454567e58d33cf568f6c47f008dbca1fff 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.h +++ b/paddle/fluid/distributed/fleet_executor/task_node.h @@ -32,6 +32,7 @@ namespace distributed { class TaskNode final { public: using OperatorBase = paddle::framework::OperatorBase; + TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times); TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times, int64_t max_slot_nums); TaskNode(int32_t role, const std::vector& op_descs, diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc index ba039385a74ba45aa1f33ba38138d8e5213f2e00..35857fc86b5e0cc2c3c2dc84d017dcbfb8b948d2 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -69,32 +69,42 @@ TEST(ComputeInterceptor, Compute) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{0, 0}, {1, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // FIXME: don't delete, otherwise interceptor will use undefined node + TaskNode* source = + new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times TaskNode* node_a = new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0); + TaskNode* sink = new TaskNode(0, SINK_ID, 2); - // a->b + // source->a->b->sink + source->AddDownstreamTask(0); + node_a->AddUpstreamTask(SOURCE_ID); node_a->AddDownstreamTask(1); node_b->AddUpstreamTask(0); + sink->AddUpstreamTask(1); + node_b->AddDownstreamTask(SINK_ID); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); auto* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("Compute", 0, node_a)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); a->SetPlace(place); a->SetMicroBatchScope(scopes); // start InterceptorMessage msg; - msg.set_message_type(DATA_IS_READY); - msg.set_src_id(-1); - msg.set_dst_id(0); + msg.set_message_type(START); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index 3860e9f4e137e3f3222d5ce9995e466e3c22db00..e909744a4b5d65456e96dcc7d7f8d25b54992151 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -55,27 +55,39 @@ TEST(AmplifierInterceptor, Amplifier) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, + {0, 0}, + {1, 0}, + {2, 0}, + {3, 0}, + {4, 0}, + {5, 0}, + {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); int64_t micro_steps = 3; // NOTE: don't delete, otherwise interceptor will use undefined node + TaskNode* source = + new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0); TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0); TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0); TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0); + TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); - // a->b->c->d->e->f - LinkNodes({node_a, node_b, node_c, node_d, node_e, node_f}); + // source->a->b->c->d->e->f->sink + LinkNodes({source, node_a, node_b, node_c, node_d, node_e, node_f, sink}); // LR->b(1:3)->F->B->e(3:1)->U node_b->SetReplyUpPerSteps(micro_steps); node_e->SetSendDownPerSteps(micro_steps); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a)); carrier->SetInterceptor(1, InterceptorFactory::Create("Amplifier", 1, node_b)); @@ -84,12 +96,13 @@ TEST(AmplifierInterceptor, Amplifier) { carrier->SetInterceptor(4, InterceptorFactory::Create("Amplifier", 4, node_e)); carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); // start InterceptorMessage msg; - msg.set_message_type(DATA_IS_READY); - msg.set_src_id(-1); - msg.set_dst_id(0); + msg.set_message_type(START); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); carrier->Release(); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index b510b68e4e2ed5e770d96cf92575188f879d62b6..0e57596bacbe655ba8bbe8c02323a0affcb5ea11 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -73,39 +73,47 @@ TEST(AmplifierInterceptor, Amplifier) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}); + carrier->Init(0, + {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {2, 0}, {3, 0}, {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, ""}}, ""); int64_t micro_steps = 6; // NOTE: don't delete, otherwise interceptor will use undefined node + TaskNode* source = + new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0); + TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); - // a->b->c->d + // source->a->b->c->d->sink // LR->F->B->U - LinkNodes({node_a, node_b, node_c, node_d}, {{{node_b, node_c}, 1}}); + LinkNodes({source, node_a, node_b, node_c, node_d, sink}, + {{{node_b, node_c}, 1}}); node_a->SetRunPerSteps(micro_steps); node_d->SetRunPerSteps(micro_steps); node_d->SetRunAtOffset(micro_steps - 1); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); carrier->SetInterceptor(0, InterceptorFactory::Create("Amplifier", 0, node_a)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); carrier->SetInterceptor(3, InterceptorFactory::Create("Amplifier", 3, node_d)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); // start InterceptorMessage msg; - msg.set_message_type(DATA_IS_READY); - msg.set_src_id(-1); - msg.set_dst_id(0); + msg.set_message_type(START); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); carrier->Release(); diff --git a/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc index 6b1a555e987a380da6ca4db8ceeb5b9965150ff8..8ff908f90ec85e56d52042ceaec1a7af7920b223 100644 --- a/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc @@ -39,10 +39,10 @@ class FakeInterceptor : public Interceptor { << std::endl; InterceptorMessage reply; reply.set_message_type(DATA_IS_USELESS); - Send(-1, reply); + Send(SOURCE_ID, reply); InterceptorMessage ready; ready.set_message_type(DATA_IS_READY); - Send(-2, ready); + Send(SINK_ID, ready); } else if (msg.message_type() == DATA_IS_USELESS) { std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx() << std::endl; @@ -57,28 +57,31 @@ TEST(SourceInterceptor, Source) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{-1, 0}, {0, 0}, {-2, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // NOTE: don't delete, otherwise interceptor will use undefined node - TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id - TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id - TaskNode* sink = new TaskNode(0, -2, 0, 3, 0); // role, rank, task_id + TaskNode* source = + new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id + TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id + TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id source->AddDownstreamTask(0, 1); - node_a->AddUpstreamTask(-1, 1); - node_a->AddDownstreamTask(-2, 1); + node_a->AddUpstreamTask(SOURCE_ID, 1); + node_a->AddDownstreamTask(SINK_ID, 1); sink->AddUpstreamTask(0, 1); - carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source)); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); carrier->SetInterceptor(0, std::make_unique(0, node_a)); - carrier->SetInterceptor(-2, InterceptorFactory::Create("Sink", -2, sink)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); // start InterceptorMessage msg; msg.set_message_type(START); - msg.set_dst_id(-1); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); diff --git a/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc index cf49e97474af0373497dc04c3d254dae820b2caf..e9c0437c829d4df574d27fb013cba8ea57711c2c 100644 --- a/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc @@ -40,7 +40,7 @@ class FakeInterceptor : public Interceptor { << std::endl; InterceptorMessage reply; reply.set_message_type(DATA_IS_USELESS); - Send(-1, reply); + Send(SOURCE_ID, reply); step_++; if (step_ == node_->max_run_times()) { carrier_->WakeUp(); @@ -56,24 +56,26 @@ TEST(SourceInterceptor, Source) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{-1, 0}, {0, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // NOTE: don't delete, otherwise interceptor will use undefined node - TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id - TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id + TaskNode* source = + new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id + TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id source->AddDownstreamTask(0, 1); - node_a->AddUpstreamTask(-1, 1); - carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source)); + node_a->AddUpstreamTask(SOURCE_ID, 1); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); carrier->SetInterceptor(0, std::make_unique(0, node_a)); // start InterceptorMessage msg; msg.set_message_type(START); - msg.set_dst_id(-1); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait();