From 50f75fb5d6ddf4651ff01f3e22fe60e91a5c49b1 Mon Sep 17 00:00:00 2001 From: WangXi Date: Thu, 25 Nov 2021 14:56:47 +0800 Subject: [PATCH] [fleet_executor] Compute Interceptor stop along data flow (#37531) --- .../distributed/fleet_executor/carrier.cc | 9 ++++ .../distributed/fleet_executor/carrier.h | 2 +- .../fleet_executor/compute_interceptor.cc | 49 +++++++++++++++++++ .../fleet_executor/compute_interceptor.h | 7 +++ .../distributed/fleet_executor/interceptor.cc | 20 ++++++-- .../distributed/fleet_executor/interceptor.h | 5 ++ .../test/compute_interceptor_test.cc | 42 +++++----------- 7 files changed, 100 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 8a42533f59e..728cfc62607 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -32,6 +32,15 @@ void Carrier::Init( is_init_ = true; } +Carrier::~Carrier() { + // NOTE(wangxi): must join before `Derived Interceptor` destruct, + // otherwise Derived object will be destructed before thread complete. + // TODO(wangxi): Maybe need a better to use thread. + for (auto& interceptor : interceptor_idx_to_interceptor_) { + interceptor.second->Join(); + } +} + bool Carrier::EnqueueInterceptorMessage( const InterceptorMessage& interceptor_message) { // enqueue message to interceptor diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index 6f3be48c75f..3413ed50004 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -42,7 +42,7 @@ class Carrier final { void Init( const std::unordered_map& interceptor_id_to_node); - ~Carrier() = default; + ~Carrier(); // Enqueue a message to corresponding interceptor id bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 27c46b23c50..9a80470c1e7 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -35,6 +35,7 @@ void ComputeInterceptor::PrepareDeps() { for (auto up_id : upstream) { in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0)); + in_stops_.emplace(up_id, false); } for (auto down_id : downstream) { out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0)); @@ -144,6 +145,52 @@ void ComputeInterceptor::Run() { } } +void ComputeInterceptor::ReceivedStop(int64_t up_id) { + received_stop_ = true; + + // source node has no upstream, stop is send by carrier or others + if (up_id == -1) return; + + auto it = in_stops_.find(up_id); + PADDLE_ENFORCE_NE(it, in_stops_.end(), + platform::errors::NotFound( + "Cannot find upstream=%lld in in_stops.", up_id)); + PADDLE_ENFORCE_EQ( + it->second, false, + platform::errors::AlreadyExists("Already received stop from %lld, stop " + "cannot be send more than once.")); + it->second = true; +} + +void ComputeInterceptor::TryStop() { + if (!received_stop_) return; + + // can stop only when all upstream is stop and + // downstream complete + for (auto& in_stop : in_stops_) { + if (!in_stop.second) return; + } + for (auto& out_buff : out_buffs_) { + auto used_size = out_buff.second.second; + if (used_size != 0) return; + } + + // send stop to downstream + for (auto& out : out_buffs_) { + auto down_id = out.first; + InterceptorMessage stop; + stop.set_message_type(STOP); + Send(down_id, stop); + } + stop_ = true; +} + +void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) { + ReceivedStop(msg.src_id()); + + TryStop(); +} + void ComputeInterceptor::Compute(const InterceptorMessage& msg) { if (msg.message_type() == DATA_IS_READY) { IncreaseReady(msg.src_id()); @@ -152,6 +199,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) { DecreaseBuff(msg.src_id()); Run(); } + + TryStop(); } REGISTER_INTERCEPTOR(Compute, ComputeInterceptor); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index 7ebcf897396..a116abeff80 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -38,6 +38,10 @@ class ComputeInterceptor : public Interceptor { void Run(); void Compute(const InterceptorMessage& msg); + void HandleStop(const InterceptorMessage& msg) override; + void ReceivedStop(int64_t up_id); + void TryStop(); + private: // FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0 int64_t step_{0}; @@ -45,6 +49,9 @@ class ComputeInterceptor : public Interceptor { std::map> in_readys_{}; // downstream_id-->(max_buffer_size, used_size) std::map> out_buffs_{}; + + bool received_stop_{false}; + std::map in_stops_{}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index a342d4431a1..916923ce590 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -28,7 +28,13 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) }); } -Interceptor::~Interceptor() { interceptor_thread_.join(); } +Interceptor::~Interceptor() { Join(); } + +void Interceptor::Join() { + if (interceptor_thread_.joinable()) { + interceptor_thread_.join(); + } +} void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } @@ -74,6 +80,9 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { return MessageBus::Instance().Send(msg); } +// maybe need a better method for interceptor base +void Interceptor::HandleStop(const InterceptorMessage& msg) { stop_ = true; } + void Interceptor::PoolTheMailbox() { // pool the local mailbox, parse the Message for (;;) { @@ -91,13 +100,18 @@ void Interceptor::PoolTheMailbox() { VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message" << " from interceptor " << interceptor_message.src_id() << " with message: " << message_type << "."; + if (message_type == STOP) { + HandleStop(interceptor_message); + } else { + Handle(interceptor_message); + } + + if (stop_) { // break the pooling thread VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting."; break; } - - Handle(interceptor_message); } } diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 9ea392ea5f8..9f74f99ded6 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -43,9 +43,13 @@ class Interceptor { virtual ~Interceptor(); + void Join(); + // register interceptor handle void RegisterMsgHandle(MsgHandle handle); + virtual void HandleStop(const InterceptorMessage& msg); + void Handle(const InterceptorMessage& msg); // return the interceptor id @@ -64,6 +68,7 @@ class Interceptor { protected: TaskNode* GetTaskNode() const { return node_; } + bool stop_{false}; private: // pool the local mailbox, parse the Message diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 2366f106d11..5b85abb4258 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -25,28 +25,6 @@ limitations under the License. */ namespace paddle { namespace distributed { -class StopInterceptor : public Interceptor { - public: - StopInterceptor(int64_t interceptor_id, TaskNode* node) - : Interceptor(interceptor_id, node) { - RegisterMsgHandle([this](const InterceptorMessage& msg) { Stop(msg); }); - } - - void Stop(const InterceptorMessage& msg) { - std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() - << std::endl; - count_ += 1; - if (count_ == 1) return; - InterceptorMessage stop; - stop.set_message_type(STOP); - Send(0, stop); - Send(1, stop); - Send(2, stop); - Send(3, stop); - } - int count_{0}; -}; - class StartInterceptor : public Interceptor { public: StartInterceptor(int64_t interceptor_id, TaskNode* node) @@ -57,13 +35,20 @@ class StartInterceptor : public Interceptor { void NOP(const InterceptorMessage& msg) { std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() << std::endl; + ++count_; + if (count_ == 3) { + InterceptorMessage stop; + stop.set_message_type(STOP); + Send(msg.dst_id(), stop); // stop 0, this + Send(msg.src_id(), stop); // stop 1, compute + } } + int count_{0}; }; TEST(ComputeInterceptor, Compute) { MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, "127.0.0.0:0"}}, - "127.0.0.0:0"); + msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); Carrier& carrier = Carrier::Instance(); @@ -71,27 +56,24 @@ TEST(ComputeInterceptor, Compute) { TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0); - TaskNode* node_d = new TaskNode(0, 0, 3, 0, 0); - // a->b->c->d + // a->b->c node_a->AddDownstreamTask(1); node_b->AddUpstreamTask(0); node_b->AddDownstreamTask(2); node_c->AddUpstreamTask(1); - node_c->AddDownstreamTask(3); - node_d->AddUpstreamTask(2); Interceptor* a = carrier.SetInterceptor(0, std::make_unique(0, node_a)); carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); - carrier.SetInterceptor(3, std::make_unique(3, node_c)); carrier.SetCreatingFlag(false); InterceptorMessage msg; msg.set_message_type(DATA_IS_READY); - // double buff, send twice + // test run three times + a->Send(1, msg); a->Send(1, msg); a->Send(1, msg); } -- GitLab