未验证 提交 60356f67 编写于 作者: L LiYuRio 提交者: GitHub

[FleetExecutor] Modified test cases using source and sink (#41926)

上级 771a4144
......@@ -186,7 +186,13 @@ int64_t Carrier::GetRank(int64_t interceptor_id) const {
}
bool Carrier::Send(const InterceptorMessage& msg) {
int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id();
int64_t src_id = msg.src_id();
// TODO(liyurui): compatible solution, will be removed completely in the
// future
if (interceptor_id_to_rank_.find(src_id) == interceptor_id_to_rank_.end() &&
src_id == SOURCE_ID) {
src_id = msg.dst_id();
}
int64_t dst_id = msg.dst_id();
int64_t src_rank = GetRank(src_id);
int64_t dst_rank = GetRank(dst_id);
......
......@@ -161,7 +161,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id
<< " for step: " << step_;
if (up_id == -1) return;
if (is_source_ && up_id == -1) return;
InterceptorMessage reply_msg;
reply_msg.set_message_type(DATA_IS_USELESS);
......
......@@ -40,6 +40,9 @@ class TaskNode;
class Carrier;
class TaskLoop;
constexpr int64_t SOURCE_ID = -1;
constexpr int64_t SINK_ID = -2;
class Interceptor {
public:
using MsgHandle = std::function<void(const InterceptorMessage&)>;
......
......@@ -27,8 +27,8 @@ enum MessageType {
}
message InterceptorMessage {
optional int64 src_id = 1 [ default = 0 ];
optional int64 dst_id = 2 [ default = 0 ];
optional sint64 src_id = 1 [ default = 0 ];
optional sint64 dst_id = 2 [ default = 0 ];
optional MessageType message_type = 3 [ default = RESET ];
optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ];
......
......@@ -30,7 +30,7 @@ SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
void SinkInterceptor::StopCarrierIfComplete() {
bool flag = true;
for (const auto& up : upstream_step_) {
flag = flag & (up.second == max_run_times_);
flag = flag && (up.second == max_run_times_);
}
if (flag) {
VLOG(3) << "Sink Interceptor is stopping carrier";
......
......@@ -74,6 +74,9 @@ void TaskNode::Init(bool use_feed_fetch_ops) {
}
}
TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times)
: rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {}
TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank, int64_t task_id, int64_t max_run_times,
......
......@@ -32,6 +32,7 @@ namespace distributed {
class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role, const std::vector<framework::OpDesc*>& op_descs,
......
......@@ -69,32 +69,42 @@ TEST(ComputeInterceptor, Compute) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, 2);
// a->b
// source->a->b->sink
source->AddDownstreamTask(0);
node_a->AddUpstreamTask(SOURCE_ID);
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);
sink->AddUpstreamTask(1);
node_b->AddDownstreamTask(SINK_ID);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
auto* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
a->SetPlace(place);
a->SetMicroBatchScope(scopes);
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
......
......@@ -55,27 +55,39 @@ TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}});
carrier->Init(0, {{SOURCE_ID, 0},
{0, 0},
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3;
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// a->b->c->d->e->f
LinkNodes({node_a, node_b, node_c, node_d, node_e, node_f});
// source->a->b->c->d->e->f->sink
LinkNodes({source, node_a, node_b, node_c, node_d, node_e, node_f, sink});
// LR->b(1:3)->F->B->e(3:1)->U
node_b->SetReplyUpPerSteps(micro_steps);
node_e->SetSendDownPerSteps(micro_steps);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1,
InterceptorFactory::Create("Amplifier", 1, node_b));
......@@ -84,12 +96,13 @@ TEST(AmplifierInterceptor, Amplifier) {
carrier->SetInterceptor(4,
InterceptorFactory::Create("Amplifier", 4, node_e));
carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
......
......@@ -73,39 +73,47 @@ TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}, {3, 0}});
carrier->Init(0,
{{SOURCE_ID, 0}, {0, 0}, {1, 0}, {2, 0}, {3, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, ""}}, "");
int64_t micro_steps = 6;
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// a->b->c->d
// source->a->b->c->d->sink
// LR->F->B->U
LinkNodes({node_a, node_b, node_c, node_d}, {{{node_b, node_c}, 1}});
LinkNodes({source, node_a, node_b, node_c, node_d, sink},
{{{node_b, node_c}, 1}});
node_a->SetRunPerSteps(micro_steps);
node_d->SetRunPerSteps(micro_steps);
node_d->SetRunAtOffset(micro_steps - 1);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0,
InterceptorFactory::Create("Amplifier", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3,
InterceptorFactory::Create("Amplifier", 3, node_d));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
......
......@@ -39,10 +39,10 @@ class FakeInterceptor : public Interceptor {
<< std::endl;
InterceptorMessage reply;
reply.set_message_type(DATA_IS_USELESS);
Send(-1, reply);
Send(SOURCE_ID, reply);
InterceptorMessage ready;
ready.set_message_type(DATA_IS_READY);
Send(-2, ready);
Send(SINK_ID, ready);
} else if (msg.message_type() == DATA_IS_USELESS) {
std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx()
<< std::endl;
......@@ -57,28 +57,31 @@ TEST(SourceInterceptor, Source) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{-1, 0}, {0, 0}, {-2, 0}});
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* sink = new TaskNode(0, -2, 0, 3, 0); // role, rank, task_id
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(-1, 1);
node_a->AddDownstreamTask(-2, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
node_a->AddDownstreamTask(SINK_ID, 1);
sink->AddUpstreamTask(0, 1);
carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source));
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
carrier->SetInterceptor(-2, InterceptorFactory::Create("Sink", -2, sink));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(-1);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
......
......@@ -40,7 +40,7 @@ class FakeInterceptor : public Interceptor {
<< std::endl;
InterceptorMessage reply;
reply.set_message_type(DATA_IS_USELESS);
Send(-1, reply);
Send(SOURCE_ID, reply);
step_++;
if (step_ == node_->max_run_times()) {
carrier_->WakeUp();
......@@ -56,24 +56,26 @@ TEST(SourceInterceptor, Source) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{-1, 0}, {0, 0}});
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source = new TaskNode(0, -1, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(-1, 1);
carrier->SetInterceptor(-1, InterceptorFactory::Create("Source", -1, source));
node_a->AddUpstreamTask(SOURCE_ID, 1);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
// start
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(-1);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册