未验证 提交 be3b7740 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Complete compute interceptor (#37485)

上级 1799c032
......@@ -27,31 +27,130 @@ ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
void ComputeInterceptor::PrepareDeps() {
auto& upstream = GetTaskNode()->upstream();
upstream_deps_.insert(upstream.begin(), upstream.end());
auto& downstream = GetTaskNode()->downstream();
// TODO(wangxi): get from task node
int64_t in_buff_size = std::numeric_limits<int64_t>::max();
int64_t out_buff_size = 2;
for (auto up_id : upstream) {
in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0));
}
for (auto down_id : downstream) {
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0));
}
}
// Records that upstream `up_id` has produced one more ready input.
// Enforces that `up_id` is a known upstream and that the new ready count
// does not exceed that upstream's max_ready_size.
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
  auto it = in_readys_.find(up_id);
  PADDLE_ENFORCE_NE(it, in_readys_.end(),
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_readys.", up_id));

  // Entry layout: (max_ready_size, ready_size).
  auto& entry = it->second;
  const int64_t max_ready_size = entry.first;
  const int64_t ready_size = entry.second + 1;
  PADDLE_ENFORCE_LE(ready_size, max_ready_size,
                    platform::errors::OutOfRange(
                        "upstream=%lld ready_size must <= max_ready_size, but "
                        "now ready_size=%lld, max_ready_size=%lld",
                        up_id, ready_size, max_ready_size));
  entry.second = ready_size;
}
// Releases one output-buffer slot held for downstream `down_id`.
// Enforces that `down_id` is a known downstream and that the used count
// does not drop below zero.
void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
  auto it = out_buffs_.find(down_id);
  PADDLE_ENFORCE_NE(it, out_buffs_.end(),
                    platform::errors::NotFound(
                        "Cannot find downstream=%lld in out_buffs.", down_id));

  // Entry layout: (max_buffer_size, used_size).
  auto& entry = it->second;
  const int64_t used_size = entry.second - 1;
  PADDLE_ENFORCE_GE(
      used_size, 0,
      platform::errors::OutOfRange(
          "downstream=%lld used buff size must >= 0, but now equal %lld",
          down_id, used_size));
  entry.second = used_size;
}
// Returns true when every upstream tracked in in_readys_ has a non-zero
// ready count (trivially true when there are no upstreams).
bool ComputeInterceptor::IsInputReady() {
  for (auto it = in_readys_.begin(); it != in_readys_.end(); ++it) {
    // second.second is the current ready_size for this upstream.
    if (it->second.second == 0) {
      return false;
    }
  }
  return true;
}
// Returns true when no downstream buffer tracked in out_buffs_ is exactly
// full, i.e. used_size differs from max_buffer_size for every downstream
// (trivially true when there are no downstreams).
bool ComputeInterceptor::CanWriteOutput() {
  for (auto it = out_buffs_.begin(); it != out_buffs_.end(); ++it) {
    const auto& buff = it->second;  // (max_buffer_size, used_size)
    if (buff.second == buff.first) {
      return false;
    }
  }
  return true;
}
// Notifies every downstream interceptor that one new output is available.
//
// For each entry in out_buffs_ (downstream_id --> (max_buffer_size,
// used_size)) this claims one buffer slot and then sends a DATA_IS_READY
// message to that downstream. Exactly one message is sent and exactly one
// slot is claimed per downstream per call.
//
// Enforces that the used count never exceeds max_buffer_size; callers are
// expected to check CanWriteOutput() first.
void ComputeInterceptor::SendDataReadyToDownStream() {
  for (auto& outs : out_buffs_) {
    auto down_id = outs.first;
    auto max_buff_size = outs.second.first;
    auto used_size = outs.second.second;
    used_size += 1;
    PADDLE_ENFORCE_LE(
        used_size, max_buff_size,
        platform::errors::OutOfRange("downstream=%lld used buff size must <= "
                                     "max_buff_size, but now used_size=%lld, "
                                     "max_buff_size=%lld",
                                     down_id, used_size, max_buff_size));
    // Commit the claimed slot before notifying the consumer.
    outs.second.second = used_size;

    InterceptorMessage ready_msg;
    ready_msg.set_message_type(DATA_IS_READY);
    VLOG(3) << "ComputeInterceptor Send data_is_ready msg to " << down_id;
    Send(down_id, ready_msg);
  }
}
// Consumes one ready input from every upstream and tells that upstream its
// data is no longer needed by sending a DATE_IS_USELESS message.
// Enforces that no ready count drops below zero.
void ComputeInterceptor::ReplyCompletedToUpStream() {
  for (auto it = in_readys_.begin(); it != in_readys_.end(); ++it) {
    const int64_t up_id = it->first;
    // second.second is the ready_size for this upstream; take one off.
    const int64_t ready_size = it->second.second - 1;
    PADDLE_ENFORCE_GE(
        ready_size, 0,
        platform::errors::OutOfRange(
            "upstream=%lld ready_size must >= 0, but now got %lld", up_id,
            ready_size));
    it->second.second = ready_size;

    InterceptorMessage reply_msg;
    reply_msg.set_message_type(DATE_IS_USELESS);
    VLOG(3) << "ComputeInterceptor Reply data_is_useless msg to " << up_id;
    Send(up_id, reply_msg);
  }
}
// Runs as many steps as the current state allows: a step is taken only
// while every upstream has ready data and no downstream buffer is full.
// Each step notifies downstreams first, then acknowledges upstreams.
void ComputeInterceptor::Run() {
  for (;;) {
    // Same short-circuit order as a `&&` condition: the output check is
    // skipped when inputs are not ready.
    if (!IsInputReady()) break;
    if (!CanWriteOutput()) break;

    VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
    // TODO(wangxi): add op run

    // Claim one output slot per downstream and announce the new data.
    SendDataReadyToDownStream();
    // Release one ready input per upstream and acknowledge it.
    ReplyCompletedToUpStream();
  }
}
// Message handler for the compute interceptor.
//
//  - DATA_IS_READY: the sender (msg.src_id()) has produced an input; bump
//    its ready count, then run as many steps as inputs and buffers allow.
//  - DATE_IS_USELESS: the sender (msg.src_id()) has finished with one of our
//    outputs; release its buffer slot, then try to run again.
//
// Other message types are ignored here.
//
// Downstream notification happens only inside Run(), which checks
// IsInputReady() and CanWriteOutput() first. Sending DATA_IS_READY directly
// from this handler would bypass those checks and produce a second
// notification for the same input.
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY) {
    IncreaseReady(msg.src_id());
    Run();
  } else if (msg.message_type() == DATE_IS_USELESS) {
    DecreaseBuff(msg.src_id());
    Run();
  }
}
......
......@@ -14,6 +14,8 @@
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
......@@ -25,12 +27,24 @@ class ComputeInterceptor : public Interceptor {
// Populates upstream_deps_, in_readys_ and out_buffs_ from the task node's
// upstream/downstream ids.
void PrepareDeps();
// Adds one to the ready count tracked for upstream `up_id`.
void IncreaseReady(int64_t up_id);
// Subtracts one from the used-buffer count tracked for downstream `down_id`.
void DecreaseBuff(int64_t down_id);
// True when every upstream in in_readys_ has a non-zero ready count.
bool IsInputReady();
// True when no downstream in out_buffs_ has used_size == max_buffer_size.
bool CanWriteOutput();
// Sends DATA_IS_READY to downstream interceptors and increments their
// used-buffer counts.
void SendDataReadyToDownStream();
// Decrements each upstream's ready count and sends it DATE_IS_USELESS.
void ReplyCompletedToUpStream();
// Loops while IsInputReady() && CanWriteOutput(), notifying downstreams and
// replying to upstreams on each iteration.
void Run();
// Message handler; reacts to DATA_IS_READY and DATE_IS_USELESS messages.
void Compute(const InterceptorMessage& msg);
private:
// Upstream task ids; filled by PrepareDeps().
std::unordered_set<int64_t> upstream_deps_;
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0
// NOTE(review): step_ is not read or written by any method visible in this
// change — confirm whether it is still needed.
int64_t step_{0};
// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
};
} // namespace distributed
......
......@@ -35,17 +35,35 @@ class StopInterceptor : public Interceptor {
// Test handler: logs the sender of every message, swallows the first one,
// and on each later message sends STOP to interceptors 0 through 3.
void Stop(const InterceptorMessage& msg) {
  std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
            << std::endl;

  ++count_;
  if (count_ != 1) {
    InterceptorMessage stop;
    stop.set_message_type(STOP);
    // Interceptor ids 0..3, in ascending order.
    for (int64_t dst_id = 0; dst_id < 4; ++dst_id) {
      Send(dst_id, stop);
    }
  }
}
int count_{0};
};
// Test-only interceptor whose message handler just logs the sender of each
// message it receives and takes no further action.
class StartInterceptor : public Interceptor {
 public:
  StartInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    // Route every incoming message to NOP().
    RegisterMsgHandle(
        [this](const InterceptorMessage& incoming) { this->NOP(incoming); });
  }

  // Prints "<own id> recv msg from <sender id>" and flushes stdout.
  void NOP(const InterceptorMessage& msg) {
    std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
              << std::endl;
  }
};
// Builds a four-interceptor chain a->b->c->d and sends two DATA_IS_READY
// messages from a to b.
// NOTE(review): this span appears to interleave pre-change and post-change
// lines of a diff (two msg_bus.Init calls, two declarations of `a`,
// interceptor 2 registered twice) and contains a hunk marker; it is not
// compilable as shown. Confirm against the real file before editing.
TEST(ComputeInterceptor, Compute) {
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, "127.0.0.0:0"}},
"127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
......@@ -53,21 +71,28 @@ TEST(ComputeInterceptor, Compute) {
TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 0, 0);
// a->b->c
// a->b->c->d
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);
node_b->AddDownstreamTask(2);
node_c->AddUpstreamTask(1);
node_c->AddDownstreamTask(3);
node_d->AddUpstreamTask(2);
Interceptor* a = carrier.SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a));
Interceptor* a =
carrier.SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, std::make_unique<StopInterceptor>(2, node_c));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
// NOTE(review): interceptor 3 is given node_c although node_d was created
// for task id 3 and wired as c's downstream — likely should be node_d;
// confirm.
carrier.SetInterceptor(3, std::make_unique<StopInterceptor>(3, node_c));
// Interceptor creation is finished; clear the carrier's creating flag.
carrier.SetCreatingFlag(false);
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
// double buff, send twice
a->Send(1, msg);
a->Send(1, msg);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册