From be3b77404f12ac0c382db394ebea5a9687eb778e Mon Sep 17 00:00:00 2001 From: WangXi Date: Wed, 24 Nov 2021 10:46:35 +0800 Subject: [PATCH] [fleet_executor] Complete compute interceptor (#37485) --- .../fleet_executor/compute_interceptor.cc | 133 +++++++++++++++--- .../fleet_executor/compute_interceptor.h | 16 ++- .../test/compute_interceptor_test.cc | 35 ++++- 3 files changed, 161 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 4307665f30e..27c46b23c50 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -27,31 +27,130 @@ ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node) void ComputeInterceptor::PrepareDeps() { auto& upstream = GetTaskNode()->upstream(); - upstream_deps_.insert(upstream.begin(), upstream.end()); + auto& downstream = GetTaskNode()->downstream(); /* PrepareDeps: seeds in_readys_ with (max_ready_size, 0) per upstream id and out_buffs_ with (max_buffer_size, 0) per downstream id. */ + + // TODO(wangxi): get from task node + int64_t in_buff_size = std::numeric_limits::max(); + int64_t out_buff_size = 2; + + for (auto up_id : upstream) { + in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0)); + } + for (auto down_id : downstream) { + out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0)); + } +} + +void ComputeInterceptor::IncreaseReady(int64_t up_id) { /* Counts one more ready input from upstream up_id. Enforces that up_id is present in in_readys_ and that the new count stays <= its max_ready_size. */ + auto it = in_readys_.find(up_id); + PADDLE_ENFORCE_NE(it, in_readys_.end(), + platform::errors::NotFound( + "Cannot find upstream=%lld in in_readys.", up_id)); + + auto max_ready_size = it->second.first; + auto ready_size = it->second.second; + ready_size += 1; + PADDLE_ENFORCE_LE(ready_size, max_ready_size, + platform::errors::OutOfRange( + "upstream=%lld ready_size must <= max_ready_size, but " + "now ready_size=%lld, max_ready_size=%lld", + up_id, ready_size, max_ready_size)); + it->second.second = ready_size; +} + +void ComputeInterceptor::DecreaseBuff(int64_t down_id) { /* Frees one used output-buffer slot for downstream down_id. Enforces that down_id is present in out_buffs_ and that the used count stays >= 0. */
+ auto it = out_buffs_.find(down_id); + PADDLE_ENFORCE_NE(it, out_buffs_.end(), + platform::errors::NotFound( + "Cannot find downstream=%lld in out_buffs.", down_id)); + auto used_size = it->second.second; + used_size -= 1; + PADDLE_ENFORCE_GE( + used_size, 0, + platform::errors::OutOfRange( + "downstream=%lld used buff size must >= 0, but now equal %lld", + down_id, used_size)); + it->second.second = used_size; +} + +bool ComputeInterceptor::IsInputReady() { /* Returns true only if every entry of in_readys_ has ready_size != 0 (true when in_readys_ is empty). */ + for (auto& ins : in_readys_) { + auto ready_size = ins.second.second; + // not ready, return false + if (ready_size == 0) return false; + } + return true; +} + +bool ComputeInterceptor::CanWriteOutput() { /* Returns false as soon as any entry of out_buffs_ has used_size equal to its max_buffer_size; true otherwise. */ + for (auto& outs : out_buffs_) { + auto max_buffer_size = outs.second.first; + auto used_size = outs.second.second; + // full, return false + if (used_size == max_buffer_size) return false; + } + return true; } void ComputeInterceptor::SendDataReadyToDownStream() { - auto& downstream = GetTaskNode()->downstream(); - for (auto dst_id : downstream) { - InterceptorMessage dst_msg; - dst_msg.set_message_type(DATA_IS_READY); - VLOG(3) << "ComputeInterceptor Send msg to " << dst_id; - Send(dst_id, dst_msg); + for (auto& outs : out_buffs_) { /* For each downstream: take one buffer slot (enforced <= max_buff_size), then send it a DATA_IS_READY message. */ + auto down_id = outs.first; + auto max_buff_size = outs.second.first; + auto used_size = outs.second.second; + used_size += 1; + PADDLE_ENFORCE_LE( + used_size, max_buff_size, + platform::errors::OutOfRange("downstream=%lld used buff size must <= " + "max_buff_size, but now used_size=%lld, " + "max_buff_size=%lld", + down_id, used_size, max_buff_size)); + outs.second.second = used_size; + + InterceptorMessage ready_msg; + ready_msg.set_message_type(DATA_IS_READY); + VLOG(3) << "ComputeInterceptor Send data_is_ready msg to " << down_id; + Send(down_id, ready_msg); + } +} + +void ComputeInterceptor::ReplyCompletedToUpStream() { /* For each upstream: decrement its ready_size (enforced >= 0), then send it a DATE_IS_USELESS message. */ + for (auto& ins : in_readys_) { + auto up_id = ins.first; + auto ready_size = ins.second.second; + ready_size -= 1; + PADDLE_ENFORCE_GE( + ready_size, 
0, + platform::errors::OutOfRange( + "upstream=%lld ready_size must >= 0, but now got %lld", up_id, + ready_size)); + ins.second.second = ready_size; + + InterceptorMessage reply_msg; + reply_msg.set_message_type(DATE_IS_USELESS); + VLOG(3) << "ComputeInterceptor Reply data_is_useless msg to " << up_id; + Send(up_id, reply_msg); + } +} + +void ComputeInterceptor::Run() { /* Loops while IsInputReady() and CanWriteOutput() both hold; each iteration notifies all downstreams, then replies to all upstreams. */ + while (IsInputReady() && CanWriteOutput()) { + VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; + // TODO(wangxi): add op run + + // send to downstream and increase buff used + SendDataReadyToDownStream(); + // reply to upstream and decrease ready data + ReplyCompletedToUpStream(); } } void ComputeInterceptor::Compute(const InterceptorMessage& msg) { if (msg.message_type() == DATA_IS_READY) { - auto src_id = msg.src_id(); - upstream_deps_.erase(src_id); - - // all input is ready - if (upstream_deps_.empty()) { - // TODO(wangxi): op run - VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; - SendDataReadyToDownStream(); - PrepareDeps(); - } + IncreaseReady(msg.src_id()); /* DATA_IS_READY: count one more ready input from the sender, then try to run. */ + Run(); + } else if (msg.message_type() == DATE_IS_USELESS) { + DecreaseBuff(msg.src_id()); /* DATE_IS_USELESS: free one output-buffer slot held for the sender, then try to run. */ + Run(); } } diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index 9b49910b9eb..7ebcf897396 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "paddle/fluid/distributed/fleet_executor/interceptor.h" namespace paddle { @@ -25,12 +27,24 @@ class ComputeInterceptor : public Interceptor { void PrepareDeps(); + void IncreaseReady(int64_t up_id); + void DecreaseBuff(int64_t down_id); + bool IsInputReady(); + bool CanWriteOutput(); + void SendDataReadyToDownStream(); + void ReplyCompletedToUpStream(); + void Run(); void Compute(const InterceptorMessage& msg); private: - 
std::unordered_set upstream_deps_; + // FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0 + int64_t step_{0}; + // upstream_id-->(max_ready_size, ready_size) + std::map> in_readys_{}; + // downstream_id-->(max_buffer_size, used_size) + std::map> out_buffs_{}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 658ff25672d..2366f106d11 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -35,17 +35,35 @@ class StopInterceptor : public Interceptor { void Stop(const InterceptorMessage& msg) { std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() << std::endl; + count_ += 1; /* The first message only bumps the counter; every later message broadcasts STOP to interceptors 0-3. */ + if (count_ == 1) return; InterceptorMessage stop; stop.set_message_type(STOP); Send(0, stop); Send(1, stop); Send(2, stop); + Send(3, stop); + } + int count_{0}; +}; + +class StartInterceptor : public Interceptor { /* Test-only interceptor whose registered handler just logs the sender id. */ + public: + StartInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); }); + } + + void NOP(const InterceptorMessage& msg) { + std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() + << std::endl; } }; TEST(ComputeInterceptor, Compute) { MessageBus& msg_bus = MessageBus::Instance(); - msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, "127.0.0.0:0"}}, + "127.0.0.0:0"); Carrier& carrier = Carrier::Instance(); @@ -53,21 +71,28 @@ TEST(ComputeInterceptor, Compute) { TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0); + TaskNode* node_d = new TaskNode(0, 0, 3, 0, 0); - // a->b->c + // 
a->b->c->d node_a->AddDownstreamTask(1); node_b->AddUpstreamTask(0); node_b->AddDownstreamTask(2); + node_c->AddUpstreamTask(1); + node_c->AddDownstreamTask(3); + node_d->AddUpstreamTask(2); - Interceptor* a = carrier.SetInterceptor( - 0, InterceptorFactory::Create("Compute", 0, node_a)); + Interceptor* a = + carrier.SetInterceptor(0, std::make_unique(0, node_a)); carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); - carrier.SetInterceptor(2, std::make_unique(2, node_c)); + carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); + carrier.SetInterceptor(3, std::make_unique(3, node_d)); /* interceptor 3 gets node_d, the TaskNode created for task_id 3 */ carrier.SetCreatingFlag(false); InterceptorMessage msg; msg.set_message_type(DATA_IS_READY); + // double buff, send twice + a->Send(1, msg); a->Send(1, msg); } -- GitLab