未验证 提交 50f75fb5 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Compute Interceptor stop along data flow (#37531)

上级 992d4ebb
...@@ -32,6 +32,15 @@ void Carrier::Init( ...@@ -32,6 +32,15 @@ void Carrier::Init(
is_init_ = true; is_init_ = true;
} }
// Destructor: joins the thread of every interceptor registered in
// interceptor_idx_to_interceptor_ before the members are destroyed.
Carrier::~Carrier() {
  // NOTE(wangxi): the join must happen before any derived interceptor is
  // destructed; otherwise the derived object would be destroyed while its
  // thread has not completed yet.
  // TODO(wangxi): a better way to use the threads may be needed.
  for (auto it = interceptor_idx_to_interceptor_.begin();
       it != interceptor_idx_to_interceptor_.end(); ++it) {
    it->second->Join();
  }
}
bool Carrier::EnqueueInterceptorMessage( bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor // enqueue message to interceptor
......
...@@ -42,7 +42,7 @@ class Carrier final { ...@@ -42,7 +42,7 @@ class Carrier final {
void Init( void Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node); const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);
~Carrier() = default; ~Carrier();
// Enqueue a message to corresponding interceptor id // Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
......
...@@ -35,6 +35,7 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -35,6 +35,7 @@ void ComputeInterceptor::PrepareDeps() {
for (auto up_id : upstream) { for (auto up_id : upstream) {
in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0)); in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0));
in_stops_.emplace(up_id, false);
} }
for (auto down_id : downstream) { for (auto down_id : downstream) {
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0)); out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0));
...@@ -144,6 +145,52 @@ void ComputeInterceptor::Run() { ...@@ -144,6 +145,52 @@ void ComputeInterceptor::Run() {
} }
} }
// Records that a stop request has arrived, and from which upstream.
//
// up_id: id of the upstream interceptor that sent the stop, or -1 when the
//        stop did not come from an upstream (source node).
//
// Enforces (via PADDLE_ENFORCE_*) that `up_id` is a known upstream and that
// each upstream sends a stop at most once.
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
  received_stop_ = true;

  // A source node has no upstream; its stop is sent by the carrier or others.
  if (up_id == -1) return;

  auto it = in_stops_.find(up_id);
  PADDLE_ENFORCE_NE(it, in_stops_.end(),
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_stops.", up_id));
  // The message has a %lld placeholder, so `up_id` must be supplied as the
  // matching format argument.
  PADDLE_ENFORCE_EQ(
      it->second, false,
      platform::errors::AlreadyExists("Already received stop from %lld, stop "
                                      "cannot be send more than once.",
                                      up_id));
  it->second = true;
}
void ComputeInterceptor::TryStop() {
if (!received_stop_) return;
// can stop only when all upstream is stop and
// downstream complete
for (auto& in_stop : in_stops_) {
if (!in_stop.second) return;
}
for (auto& out_buff : out_buffs_) {
auto used_size = out_buff.second.second;
if (used_size != 0) return;
}
// send stop to downstream
for (auto& out : out_buffs_) {
auto down_id = out.first;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(down_id, stop);
}
stop_ = true;
}
// STOP-message handler: records the stop coming from the message's source id,
// then stops this interceptor if the stop conditions are already met.
void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) {
  const auto sender_id = msg.src_id();
  ReceivedStop(sender_id);
  TryStop();
}
void ComputeInterceptor::Compute(const InterceptorMessage& msg) { void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) { if (msg.message_type() == DATA_IS_READY) {
IncreaseReady(msg.src_id()); IncreaseReady(msg.src_id());
...@@ -152,6 +199,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { ...@@ -152,6 +199,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
DecreaseBuff(msg.src_id()); DecreaseBuff(msg.src_id());
Run(); Run();
} }
TryStop();
} }
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor); REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
......
...@@ -38,6 +38,10 @@ class ComputeInterceptor : public Interceptor { ...@@ -38,6 +38,10 @@ class ComputeInterceptor : public Interceptor {
void Run(); void Run();
void Compute(const InterceptorMessage& msg); void Compute(const InterceptorMessage& msg);
void HandleStop(const InterceptorMessage& msg) override;
void ReceivedStop(int64_t up_id);
void TryStop();
private: private:
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0 // FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0
int64_t step_{0}; int64_t step_{0};
...@@ -45,6 +49,9 @@ class ComputeInterceptor : public Interceptor { ...@@ -45,6 +49,9 @@ class ComputeInterceptor : public Interceptor {
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{}; std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size) // downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{}; std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
bool received_stop_{false};
std::map<int64_t, bool> in_stops_{};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -28,7 +28,13 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) ...@@ -28,7 +28,13 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
}); });
} }
Interceptor::~Interceptor() { interceptor_thread_.join(); } Interceptor::~Interceptor() { Join(); }
// Waits for the interceptor thread to finish. Does nothing when the thread
// is not joinable (for example, when it has already been joined).
void Interceptor::Join() {
  if (!interceptor_thread_.joinable()) return;
  interceptor_thread_.join();
}
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
...@@ -74,6 +80,9 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { ...@@ -74,6 +80,9 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return MessageBus::Instance().Send(msg); return MessageBus::Instance().Send(msg);
} }
// Default STOP handling: flags the interceptor as stopped right away; the
// message itself is not inspected.
// NOTE: a better method may be needed for the interceptor base class.
void Interceptor::HandleStop(const InterceptorMessage& msg) {
  stop_ = true;
}
void Interceptor::PoolTheMailbox() { void Interceptor::PoolTheMailbox() {
// pool the local mailbox, parse the Message // pool the local mailbox, parse the Message
for (;;) { for (;;) {
...@@ -91,13 +100,18 @@ void Interceptor::PoolTheMailbox() { ...@@ -91,13 +100,18 @@ void Interceptor::PoolTheMailbox() {
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message" VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id() << " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << "."; << " with message: " << message_type << ".";
if (message_type == STOP) { if (message_type == STOP) {
HandleStop(interceptor_message);
} else {
Handle(interceptor_message);
}
if (stop_) {
// break the pooling thread // break the pooling thread
VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting."; VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting.";
break; break;
} }
Handle(interceptor_message);
} }
} }
......
...@@ -43,9 +43,13 @@ class Interceptor { ...@@ -43,9 +43,13 @@ class Interceptor {
virtual ~Interceptor(); virtual ~Interceptor();
void Join();
// register interceptor handle // register interceptor handle
void RegisterMsgHandle(MsgHandle handle); void RegisterMsgHandle(MsgHandle handle);
virtual void HandleStop(const InterceptorMessage& msg);
void Handle(const InterceptorMessage& msg); void Handle(const InterceptorMessage& msg);
// return the interceptor id // return the interceptor id
...@@ -64,6 +68,7 @@ class Interceptor { ...@@ -64,6 +68,7 @@ class Interceptor {
protected: protected:
TaskNode* GetTaskNode() const { return node_; } TaskNode* GetTaskNode() const { return node_; }
bool stop_{false};
private: private:
// pool the local mailbox, parse the Message // pool the local mailbox, parse the Message
......
...@@ -25,28 +25,6 @@ limitations under the License. */ ...@@ -25,28 +25,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
// Test helper interceptor: every incoming message is routed to Stop().
class StopInterceptor : public Interceptor {
 public:
  StopInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle(
        [this](const InterceptorMessage& message) { Stop(message); });
  }

  // Prints the sender of `msg`. The first message is only counted; every
  // later message additionally sends a STOP message to interceptors 0..3.
  void Stop(const InterceptorMessage& msg) {
    std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
              << std::endl;
    count_ += 1;
    if (count_ != 1) {
      InterceptorMessage stop;
      stop.set_message_type(STOP);
      for (int64_t dst_id = 0; dst_id <= 3; ++dst_id) {
        Send(dst_id, stop);
      }
    }
  }

  // Number of messages handled so far.
  int count_{0};
};
class StartInterceptor : public Interceptor { class StartInterceptor : public Interceptor {
public: public:
StartInterceptor(int64_t interceptor_id, TaskNode* node) StartInterceptor(int64_t interceptor_id, TaskNode* node)
...@@ -57,13 +35,20 @@ class StartInterceptor : public Interceptor { ...@@ -57,13 +35,20 @@ class StartInterceptor : public Interceptor {
void NOP(const InterceptorMessage& msg) { void NOP(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl; << std::endl;
++count_;
if (count_ == 3) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(msg.dst_id(), stop); // stop 0, this
Send(msg.src_id(), stop); // stop 1, compute
}
} }
int count_{0};
}; };
TEST(ComputeInterceptor, Compute) { TEST(ComputeInterceptor, Compute) {
MessageBus& msg_bus = MessageBus::Instance(); MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, "127.0.0.0:0"}}, msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
"127.0.0.0:0");
Carrier& carrier = Carrier::Instance(); Carrier& carrier = Carrier::Instance();
...@@ -71,27 +56,24 @@ TEST(ComputeInterceptor, Compute) { ...@@ -71,27 +56,24 @@ TEST(ComputeInterceptor, Compute) {
TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 0, 0);
// a->b->c->d // a->b->c
node_a->AddDownstreamTask(1); node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0); node_b->AddUpstreamTask(0);
node_b->AddDownstreamTask(2); node_b->AddDownstreamTask(2);
node_c->AddUpstreamTask(1); node_c->AddUpstreamTask(1);
node_c->AddDownstreamTask(3);
node_d->AddUpstreamTask(2);
Interceptor* a = Interceptor* a =
carrier.SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a)); carrier.SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(3, std::make_unique<StopInterceptor>(3, node_c));
carrier.SetCreatingFlag(false); carrier.SetCreatingFlag(false);
InterceptorMessage msg; InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY); msg.set_message_type(DATA_IS_READY);
// double buff, send twice // test run three times
a->Send(1, msg);
a->Send(1, msg); a->Send(1, msg);
a->Send(1, msg); a->Send(1, msg);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册