From 60356f67b6dc7296313c9adf20da3f4174883df7 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 19 Apr 2022 11:06:50 +0800 Subject: [PATCH] [FleetExecutor] Modified test cases using source and sink (#41926) --- .../distributed/fleet_executor/carrier.cc | 8 +++++- .../fleet_executor/compute_interceptor.cc | 2 +- .../distributed/fleet_executor/interceptor.h | 3 +++ .../fleet_executor/interceptor_message.proto | 4 +-- .../fleet_executor/sink_interceptor.cc | 2 +- .../distributed/fleet_executor/task_node.cc | 3 +++ .../distributed/fleet_executor/task_node.h | 1 + .../test/compute_interceptor_run_op_test.cc | 20 +++++++++++---- .../interceptor_pipeline_long_path_test.cc | 25 ++++++++++++++----- .../interceptor_pipeline_short_path_test.cc | 20 ++++++++++----- .../test/sink_interceptor_test.cc | 25 +++++++++++-------- .../test/source_interceptor_test.cc | 16 ++++++------ 12 files changed, 89 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 2d2a3b688fe..53bae87c002 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -186,7 +186,13 @@ int64_t Carrier::GetRank(int64_t interceptor_id) const { } bool Carrier::Send(const InterceptorMessage& msg) { - int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id(); + int64_t src_id = msg.src_id(); + // TODO(liyurui): compatible solution, will be removed completely in the + // future + if (interceptor_id_to_rank_.find(src_id) == interceptor_id_to_rank_.end() && + src_id == SOURCE_ID) { + src_id = msg.dst_id(); + } int64_t dst_id = msg.dst_id(); int64_t src_rank = GetRank(src_id); int64_t dst_rank = GetRank(dst_id); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index f49c84e6e5e..fb907e3b5c2 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -161,7 +161,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " Reply data_is_useless msg to " << up_id << " for step: " << step_; - if (up_id == -1) return; + if (is_source_ && up_id == -1) return; InterceptorMessage reply_msg; reply_msg.set_message_type(DATA_IS_USELESS); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index cb7ff2da89a..86ca7be7f44 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -40,6 +40,9 @@ class TaskNode; class Carrier; class TaskLoop; +constexpr int64_t SOURCE_ID = -1; +constexpr int64_t SINK_ID = -2; + class Interceptor { public: using MsgHandle = std::function; diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto index 7cf99e87419..8508bc35f29 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto @@ -27,8 +27,8 @@ enum MessageType { } message InterceptorMessage { - optional int64 src_id = 1 [ default = 0 ]; - optional int64 dst_id = 2 [ default = 0 ]; + optional sint64 src_id = 1 [ default = 0 ]; + optional sint64 dst_id = 2 [ default = 0 ]; optional MessageType message_type = 3 [ default = RESET ]; optional bool ctrl_message = 4 [ default = false ]; optional int64 scope_idx = 5 [ default = 0 ]; diff --git a/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc b/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc index af707c28acd..77fbb23a6c7 100644 --- a/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/sink_interceptor.cc @@ -30,7 +30,7 @@ SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node) void SinkInterceptor::StopCarrierIfComplete() { bool flag = true; for (const auto& up : upstream_step_) { - flag = flag & (up.second == max_run_times_); + flag = flag && (up.second == max_run_times_); } if (flag) { VLOG(3) << "Sink Interceptor is stopping carrier"; diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc index 95e4c733059..232317333ea 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.cc +++ b/paddle/fluid/distributed/fleet_executor/task_node.cc @@ -74,6 +74,9 @@ void TaskNode::Init(bool use_feed_fetch_ops) { } } +TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times) + : rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {} + TaskNode::TaskNode(int32_t role, const std::vector& op_descs, int64_t rank, int64_t task_id, int64_t max_run_times, diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h index 4764d4fd4af..7dd4b545456 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.h +++ b/paddle/fluid/distributed/fleet_executor/task_node.h @@ -32,6 +32,7 @@ namespace distributed { class TaskNode final { public: using OperatorBase = paddle::framework::OperatorBase; + TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times); TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times, int64_t max_slot_nums); TaskNode(int32_t role, const std::vector& op_descs, diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc index ba039385a74..35857fc86b5 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -69,32 +69,42 @@ TEST(ComputeInterceptor, Compute) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{0, 0}, {1, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // FIXME: don't delete, otherwise interceptor will use undefined node + TaskNode* source = + new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times TaskNode* node_a = new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0); + TaskNode* sink = new TaskNode(0, SINK_ID, 2); - // a->b + // source->a->b->sink + source->AddDownstreamTask(0); + node_a->AddUpstreamTask(SOURCE_ID); node_a->AddDownstreamTask(1); node_b->AddUpstreamTask(0); + sink->AddUpstreamTask(1); + node_b->AddDownstreamTask(SINK_ID); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); auto* a = carrier->SetInterceptor( 0, InterceptorFactory::Create("Compute", 0, node_a)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); a->SetPlace(place); a->SetMicroBatchScope(scopes); // start InterceptorMessage msg; - msg.set_message_type(DATA_IS_READY); - msg.set_src_id(-1); - msg.set_dst_id(0); + msg.set_message_type(START); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index 3860e9f4e13..e909744a4b5 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -55,27 +55,39 @@ TEST(AmplifierInterceptor, Amplifier) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, + {0, 0}, + {1, 0}, + {2, 0}, + {3, 0}, + {4, 0}, + {5, 0}, + {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); int64_t micro_steps = 3; // NOTE: don't delete, otherwise interceptor will use undefined node + TaskNode* source = + new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0); TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0); TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0); TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0); + TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); - // a->b->c->d->e->f - LinkNodes({node_a, node_b, node_c, node_d, node_e, node_f}); + // source->a->b->c->d->e->f->sink + LinkNodes({source, node_a, node_b, node_c, node_d, node_e, node_f, sink}); // LR->b(1:3)->F->B->e(3:1)->U node_b->SetReplyUpPerSteps(micro_steps); node_e->SetSendDownPerSteps(micro_steps); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a)); carrier->SetInterceptor(1, InterceptorFactory::Create("Amplifier", 1, node_b)); @@ -84,12 +96,13 @@ TEST(AmplifierInterceptor, Amplifier) { carrier->SetInterceptor(4, InterceptorFactory::Create("Amplifier", 4, node_e)); carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); // start InterceptorMessage msg; - msg.set_message_type(DATA_IS_READY); - msg.set_src_id(-1); - msg.set_dst_id(0); + msg.set_message_type(START); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); carrier->Release(); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index b510b68e4e2..0e57596bacb 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -73,39 +73,47 @@ TEST(AmplifierInterceptor, Amplifier) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}); + carrier->Init(0, + {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {2, 0}, {3, 0}, {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, ""}}, ""); int64_t micro_steps = 6; // NOTE: don't delete, otherwise interceptor will use undefined node + TaskNode* source = + new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0); + TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); - // a->b->c->d + // source->a->b->c->d->sink // LR->F->B->U - LinkNodes({node_a, node_b, node_c, node_d}, {{{node_b, node_c}, 1}}); + LinkNodes({source, node_a, node_b, node_c, node_d, sink}, + {{{node_b, node_c}, 1}}); node_a->SetRunPerSteps(micro_steps); node_d->SetRunPerSteps(micro_steps); node_d->SetRunAtOffset(micro_steps - 1); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); carrier->SetInterceptor(0, InterceptorFactory::Create("Amplifier", 0, node_a)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); carrier->SetInterceptor(3, InterceptorFactory::Create("Amplifier", 3, node_d)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); // start InterceptorMessage msg; - msg.set_message_type(DATA_IS_READY); - msg.set_src_id(-1); - msg.set_dst_id(0); + msg.set_message_type(START); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); carrier->Release(); diff --git a/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc index 6b1a555e987..8ff908f90ec 100644 --- a/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc @@ -39,10 +39,10 @@ class FakeInterceptor : public Interceptor { << std::endl; InterceptorMessage reply; reply.set_message_type(DATA_IS_USELESS); - Send(-1, reply); + Send(SOURCE_ID, reply); InterceptorMessage ready; ready.set_message_type(DATA_IS_READY); - Send(-2, ready); + Send(SINK_ID, ready); } else if (msg.message_type() == DATA_IS_USELESS) { std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx() << std::endl; @@ -57,28 +57,31 @@ TEST(SourceInterceptor, Source) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{-1, 0}, {0, 0}, {-2, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // NOTE: don't delete, otherwise interceptor will use undefined node - TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id - TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id - TaskNode* sink = new TaskNode(0, -2, 0, 3, 0); // role, rank, task_id + TaskNode* source = + new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id + TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id + TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id source->AddDownstreamTask(0, 1); - node_a->AddUpstreamTask(-1, 1); - node_a->AddDownstreamTask(-2, 1); + node_a->AddUpstreamTask(SOURCE_ID, 1); + node_a->AddDownstreamTask(SINK_ID, 1); sink->AddUpstreamTask(0, 1); - carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source)); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); carrier->SetInterceptor(0, std::make_unique(0, node_a)); - carrier->SetInterceptor(-2, InterceptorFactory::Create("Sink", -2, sink)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); // start InterceptorMessage msg; msg.set_message_type(START); - msg.set_dst_id(-1); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); diff --git a/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc index cf49e97474a..e9c0437c829 100644 --- a/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc @@ -40,7 +40,7 @@ class FakeInterceptor : public Interceptor { << std::endl; InterceptorMessage reply; reply.set_message_type(DATA_IS_USELESS); - Send(-1, reply); + Send(SOURCE_ID, reply); step_++; if (step_ == node_->max_run_times()) { carrier_->WakeUp(); @@ -56,24 +56,26 @@ TEST(SourceInterceptor, Source) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{-1, 0}, {0, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // NOTE: don't delete, otherwise interceptor will use undefined node - TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id - TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id + TaskNode* source = + new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id + TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id source->AddDownstreamTask(0, 1); - node_a->AddUpstreamTask(-1, 1); - carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source)); + node_a->AddUpstreamTask(SOURCE_ID, 1); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); carrier->SetInterceptor(0, std::make_unique(0, node_a)); // start InterceptorMessage msg; msg.set_message_type(START); - msg.set_dst_id(-1); + msg.set_dst_id(SOURCE_ID); carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); -- GitLab