未验证 提交 60356f67 编写于 作者: L LiYuRio 提交者: GitHub

[FleetExecutor] Modified test cases using source and sink (#41926)

上级 771a4144
...@@ -186,7 +186,13 @@ int64_t Carrier::GetRank(int64_t interceptor_id) const { ...@@ -186,7 +186,13 @@ int64_t Carrier::GetRank(int64_t interceptor_id) const {
} }
bool Carrier::Send(const InterceptorMessage& msg) { bool Carrier::Send(const InterceptorMessage& msg) {
int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id(); int64_t src_id = msg.src_id();
// TODO(liyurui): compatible solution, will be removed completely in the
// future
if (interceptor_id_to_rank_.find(src_id) == interceptor_id_to_rank_.end() &&
src_id == SOURCE_ID) {
src_id = msg.dst_id();
}
int64_t dst_id = msg.dst_id(); int64_t dst_id = msg.dst_id();
int64_t src_rank = GetRank(src_id); int64_t src_rank = GetRank(src_id);
int64_t dst_rank = GetRank(dst_id); int64_t dst_rank = GetRank(dst_id);
......
...@@ -161,7 +161,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -161,7 +161,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id << " Reply data_is_useless msg to " << up_id
<< " for step: " << step_; << " for step: " << step_;
if (up_id == -1) return; if (is_source_ && up_id == -1) return;
InterceptorMessage reply_msg; InterceptorMessage reply_msg;
reply_msg.set_message_type(DATA_IS_USELESS); reply_msg.set_message_type(DATA_IS_USELESS);
......
...@@ -40,6 +40,9 @@ class TaskNode; ...@@ -40,6 +40,9 @@ class TaskNode;
class Carrier; class Carrier;
class TaskLoop; class TaskLoop;
constexpr int64_t SOURCE_ID = -1;
constexpr int64_t SINK_ID = -2;
class Interceptor { class Interceptor {
public: public:
using MsgHandle = std::function<void(const InterceptorMessage&)>; using MsgHandle = std::function<void(const InterceptorMessage&)>;
......
...@@ -27,8 +27,8 @@ enum MessageType { ...@@ -27,8 +27,8 @@ enum MessageType {
} }
message InterceptorMessage { message InterceptorMessage {
optional int64 src_id = 1 [ default = 0 ]; optional sint64 src_id = 1 [ default = 0 ];
optional int64 dst_id = 2 [ default = 0 ]; optional sint64 dst_id = 2 [ default = 0 ];
optional MessageType message_type = 3 [ default = RESET ]; optional MessageType message_type = 3 [ default = RESET ];
optional bool ctrl_message = 4 [ default = false ]; optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ]; optional int64 scope_idx = 5 [ default = 0 ];
......
...@@ -30,7 +30,7 @@ SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node) ...@@ -30,7 +30,7 @@ SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
void SinkInterceptor::StopCarrierIfComplete() { void SinkInterceptor::StopCarrierIfComplete() {
bool flag = true; bool flag = true;
for (const auto& up : upstream_step_) { for (const auto& up : upstream_step_) {
flag = flag & (up.second == max_run_times_); flag = flag && (up.second == max_run_times_);
} }
if (flag) { if (flag) {
VLOG(3) << "Sink Interceptor is stopping carrier"; VLOG(3) << "Sink Interceptor is stopping carrier";
......
...@@ -74,6 +74,9 @@ void TaskNode::Init(bool use_feed_fetch_ops) { ...@@ -74,6 +74,9 @@ void TaskNode::Init(bool use_feed_fetch_ops) {
} }
} }
TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times)
: rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {}
TaskNode::TaskNode(int32_t role, TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs, const std::vector<framework::OpDesc*>& op_descs,
int64_t rank, int64_t task_id, int64_t max_run_times, int64_t rank, int64_t task_id, int64_t max_run_times,
......
...@@ -32,6 +32,7 @@ namespace distributed { ...@@ -32,6 +32,7 @@ namespace distributed {
class TaskNode final { class TaskNode final {
public: public:
using OperatorBase = paddle::framework::OperatorBase; using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times, TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t max_slot_nums); int64_t max_slot_nums);
TaskNode(int32_t role, const std::vector<framework::OpDesc*>& op_descs, TaskNode(int32_t role, const std::vector<framework::OpDesc*>& op_descs,
......
...@@ -69,32 +69,42 @@ TEST(ComputeInterceptor, Compute) { ...@@ -69,32 +69,42 @@ TEST(ComputeInterceptor, Compute) {
std::string carrier_id = "0"; std::string carrier_id = "0";
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}}); carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// FIXME: don't delete, otherwise interceptor will use undefined node // FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times
TaskNode* node_a = TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, 2);
// a->b // source->a->b->sink
source->AddDownstreamTask(0);
node_a->AddUpstreamTask(SOURCE_ID);
node_a->AddDownstreamTask(1); node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0); node_b->AddUpstreamTask(0);
sink->AddUpstreamTask(1);
node_b->AddDownstreamTask(SINK_ID);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
auto* a = carrier->SetInterceptor( auto* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a)); 0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
a->SetPlace(place); a->SetPlace(place);
a->SetMicroBatchScope(scopes); a->SetMicroBatchScope(scopes);
// start // start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(START);
msg.set_src_id(-1); msg.set_dst_id(SOURCE_ID);
msg.set_dst_id(0);
carrier->EnqueueInterceptorMessage(msg); carrier->EnqueueInterceptorMessage(msg);
carrier->Wait(); carrier->Wait();
......
...@@ -55,27 +55,39 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -55,27 +55,39 @@ TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0"; std::string carrier_id = "0";
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}}); carrier->Init(0, {{SOURCE_ID, 0},
{0, 0},
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3; int64_t micro_steps = 3;
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0); TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0); TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0); TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// a->b->c->d->e->f // source->a->b->c->d->e->f->sink
LinkNodes({node_a, node_b, node_c, node_d, node_e, node_f}); LinkNodes({source, node_a, node_b, node_c, node_d, node_e, node_f, sink});
// LR->b(1:3)->F->B->e(3:1)->U // LR->b(1:3)->F->B->e(3:1)->U
node_b->SetReplyUpPerSteps(micro_steps); node_b->SetReplyUpPerSteps(micro_steps);
node_e->SetSendDownPerSteps(micro_steps); node_e->SetSendDownPerSteps(micro_steps);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a)); carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1, carrier->SetInterceptor(1,
InterceptorFactory::Create("Amplifier", 1, node_b)); InterceptorFactory::Create("Amplifier", 1, node_b));
...@@ -84,12 +96,13 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -84,12 +96,13 @@ TEST(AmplifierInterceptor, Amplifier) {
carrier->SetInterceptor(4, carrier->SetInterceptor(4,
InterceptorFactory::Create("Amplifier", 4, node_e)); InterceptorFactory::Create("Amplifier", 4, node_e));
carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f)); carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start // start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(START);
msg.set_src_id(-1); msg.set_dst_id(SOURCE_ID);
msg.set_dst_id(0);
carrier->EnqueueInterceptorMessage(msg); carrier->EnqueueInterceptorMessage(msg);
carrier->Wait(); carrier->Wait();
carrier->Release(); carrier->Release();
......
...@@ -73,39 +73,47 @@ TEST(AmplifierInterceptor, Amplifier) { ...@@ -73,39 +73,47 @@ TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0"; std::string carrier_id = "0";
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}}); carrier->Init(0,
{{SOURCE_ID, 0}, {0, 0}, {1, 0}, {2, 0}, {3, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, ""}}, ""); msg_bus->Init(0, {{0, ""}}, "");
int64_t micro_steps = 6; int64_t micro_steps = 6;
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = TaskNode* node_a =
new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0); TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// a->b->c->d // source->a->b->c->d->sink
// LR->F->B->U // LR->F->B->U
LinkNodes({node_a, node_b, node_c, node_d}, {{{node_b, node_c}, 1}}); LinkNodes({source, node_a, node_b, node_c, node_d, sink},
{{{node_b, node_c}, 1}});
node_a->SetRunPerSteps(micro_steps); node_a->SetRunPerSteps(micro_steps);
node_d->SetRunPerSteps(micro_steps); node_d->SetRunPerSteps(micro_steps);
node_d->SetRunAtOffset(micro_steps - 1); node_d->SetRunAtOffset(micro_steps - 1);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, carrier->SetInterceptor(0,
InterceptorFactory::Create("Amplifier", 0, node_a)); InterceptorFactory::Create("Amplifier", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3, carrier->SetInterceptor(3,
InterceptorFactory::Create("Amplifier", 3, node_d)); InterceptorFactory::Create("Amplifier", 3, node_d));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start // start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(START);
msg.set_src_id(-1); msg.set_dst_id(SOURCE_ID);
msg.set_dst_id(0);
carrier->EnqueueInterceptorMessage(msg); carrier->EnqueueInterceptorMessage(msg);
carrier->Wait(); carrier->Wait();
carrier->Release(); carrier->Release();
......
...@@ -39,10 +39,10 @@ class FakeInterceptor : public Interceptor { ...@@ -39,10 +39,10 @@ class FakeInterceptor : public Interceptor {
<< std::endl; << std::endl;
InterceptorMessage reply; InterceptorMessage reply;
reply.set_message_type(DATA_IS_USELESS); reply.set_message_type(DATA_IS_USELESS);
Send(-1, reply); Send(SOURCE_ID, reply);
InterceptorMessage ready; InterceptorMessage ready;
ready.set_message_type(DATA_IS_READY); ready.set_message_type(DATA_IS_READY);
Send(-2, ready); Send(SINK_ID, ready);
} else if (msg.message_type() == DATA_IS_USELESS) { } else if (msg.message_type() == DATA_IS_USELESS) {
std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx() std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx()
<< std::endl; << std::endl;
...@@ -57,28 +57,31 @@ TEST(SourceInterceptor, Source) { ...@@ -57,28 +57,31 @@ TEST(SourceInterceptor, Source) {
std::string carrier_id = "0"; std::string carrier_id = "0";
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{-1, 0}, {0, 0}, {-2, 0}}); carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id TaskNode* source =
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* sink = new TaskNode(0, -2, 0, 3, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id
source->AddDownstreamTask(0, 1); source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(-1, 1); node_a->AddUpstreamTask(SOURCE_ID, 1);
node_a->AddDownstreamTask(-2, 1); node_a->AddDownstreamTask(SINK_ID, 1);
sink->AddUpstreamTask(0, 1); sink->AddUpstreamTask(0, 1);
carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source)); carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a)); carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
carrier->SetInterceptor(-2, InterceptorFactory::Create("Sink", -2, sink)); carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start // start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(START); msg.set_message_type(START);
msg.set_dst_id(-1); msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg); carrier->EnqueueInterceptorMessage(msg);
carrier->Wait(); carrier->Wait();
......
...@@ -40,7 +40,7 @@ class FakeInterceptor : public Interceptor { ...@@ -40,7 +40,7 @@ class FakeInterceptor : public Interceptor {
<< std::endl; << std::endl;
InterceptorMessage reply; InterceptorMessage reply;
reply.set_message_type(DATA_IS_USELESS); reply.set_message_type(DATA_IS_USELESS);
Send(-1, reply); Send(SOURCE_ID, reply);
step_++; step_++;
if (step_ == node_->max_run_times()) { if (step_ == node_->max_run_times()) {
carrier_->WakeUp(); carrier_->WakeUp();
...@@ -56,24 +56,26 @@ TEST(SourceInterceptor, Source) { ...@@ -56,24 +56,26 @@ TEST(SourceInterceptor, Source) {
std::string carrier_id = "0"; std::string carrier_id = "0";
Carrier* carrier = Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id); GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{-1, 0}, {0, 0}}); carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create(); MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node // NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id TaskNode* source =
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
source->AddDownstreamTask(0, 1); source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(-1, 1); node_a->AddUpstreamTask(SOURCE_ID, 1);
carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source)); carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a)); carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
// start // start
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(START); msg.set_message_type(START);
msg.set_dst_id(-1); msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg); carrier->EnqueueInterceptorMessage(msg);
carrier->Wait(); carrier->Wait();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册