未验证 提交 50f75fb5 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Compute Interceptor stop along data flow (#37531)

上级 992d4ebb
......@@ -32,6 +32,15 @@ void Carrier::Init(
is_init_ = true;
}
Carrier::~Carrier() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
// TODO(wangxi): Maybe need a better to use thread.
for (auto& interceptor : interceptor_idx_to_interceptor_) {
interceptor.second->Join();
}
}
bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
......
......@@ -42,7 +42,7 @@ class Carrier final {
void Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);
~Carrier() = default;
~Carrier();
// Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
......
......@@ -35,6 +35,7 @@ void ComputeInterceptor::PrepareDeps() {
for (auto up_id : upstream) {
in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0));
in_stops_.emplace(up_id, false);
}
for (auto down_id : downstream) {
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0));
......@@ -144,6 +145,52 @@ void ComputeInterceptor::Run() {
}
}
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
received_stop_ = true;
// source node has no upstream, stop is send by carrier or others
if (up_id == -1) return;
auto it = in_stops_.find(up_id);
PADDLE_ENFORCE_NE(it, in_stops_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_stops.", up_id));
PADDLE_ENFORCE_EQ(
it->second, false,
platform::errors::AlreadyExists("Already received stop from %lld, stop "
"cannot be send more than once."));
it->second = true;
}
void ComputeInterceptor::TryStop() {
if (!received_stop_) return;
// can stop only when all upstream is stop and
// downstream complete
for (auto& in_stop : in_stops_) {
if (!in_stop.second) return;
}
for (auto& out_buff : out_buffs_) {
auto used_size = out_buff.second.second;
if (used_size != 0) return;
}
// send stop to downstream
for (auto& out : out_buffs_) {
auto down_id = out.first;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(down_id, stop);
}
stop_ = true;
}
void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) {
ReceivedStop(msg.src_id());
TryStop();
}
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id());
......@@ -152,6 +199,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
DecreaseBuff(msg.src_id());
Run();
}
TryStop();
}
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
......
......@@ -38,6 +38,10 @@ class ComputeInterceptor : public Interceptor {
void Run();
void Compute(const InterceptorMessage& msg);
void HandleStop(const InterceptorMessage& msg) override;
void ReceivedStop(int64_t up_id);
void TryStop();
private:
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0
int64_t step_{0};
......@@ -45,6 +49,9 @@ class ComputeInterceptor : public Interceptor {
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
bool received_stop_{false};
std::map<int64_t, bool> in_stops_{};
};
} // namespace distributed
......
......@@ -28,7 +28,13 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
});
}
Interceptor::~Interceptor() { interceptor_thread_.join(); }
Interceptor::~Interceptor() { Join(); }
void Interceptor::Join() {
if (interceptor_thread_.joinable()) {
interceptor_thread_.join();
}
}
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
......@@ -74,6 +80,9 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return MessageBus::Instance().Send(msg);
}
// maybe need a better method for interceptor base
void Interceptor::HandleStop(const InterceptorMessage& msg) { stop_ = true; }
void Interceptor::PoolTheMailbox() {
// pool the local mailbox, parse the Message
for (;;) {
......@@ -91,13 +100,18 @@ void Interceptor::PoolTheMailbox() {
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << ".";
if (message_type == STOP) {
HandleStop(interceptor_message);
} else {
Handle(interceptor_message);
}
if (stop_) {
// break the pooling thread
VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting.";
break;
}
Handle(interceptor_message);
}
}
......
......@@ -43,9 +43,13 @@ class Interceptor {
virtual ~Interceptor();
void Join();
// register interceptor handle
void RegisterMsgHandle(MsgHandle handle);
virtual void HandleStop(const InterceptorMessage& msg);
void Handle(const InterceptorMessage& msg);
// return the interceptor id
......@@ -64,6 +68,7 @@ class Interceptor {
protected:
TaskNode* GetTaskNode() const { return node_; }
bool stop_{false};
private:
// pool the local mailbox, parse the Message
......
......@@ -25,28 +25,6 @@ limitations under the License. */
namespace paddle {
namespace distributed {
class StopInterceptor : public Interceptor {
public:
StopInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { Stop(msg); });
}
void Stop(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
count_ += 1;
if (count_ == 1) return;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
Send(2, stop);
Send(3, stop);
}
int count_{0};
};
class StartInterceptor : public Interceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node)
......@@ -57,13 +35,20 @@ class StartInterceptor : public Interceptor {
void NOP(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
++count_;
if (count_ == 3) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(msg.dst_id(), stop); // stop 0, this
Send(msg.src_id(), stop); // stop 1, compute
}
}
int count_{0};
};
TEST(ComputeInterceptor, Compute) {
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, "127.0.0.0:0"}},
"127.0.0.0:0");
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
......@@ -71,27 +56,24 @@ TEST(ComputeInterceptor, Compute) {
TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 0, 0);
// a->b->c->d
// a->b->c
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);
node_b->AddDownstreamTask(2);
node_c->AddUpstreamTask(1);
node_c->AddDownstreamTask(3);
node_d->AddUpstreamTask(2);
Interceptor* a =
carrier.SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(3, std::make_unique<StopInterceptor>(3, node_c));
carrier.SetCreatingFlag(false);
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
// double buff, send twice
// test run three times
a->Send(1, msg);
a->Send(1, msg);
a->Send(1, msg);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册