diff --git a/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc b/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc index 72c689732b5b7df5f61d28d93a3bef6e305f426d..a166ff0b6dfa2f381da02ff0e90dadc08732de5e 100644 --- a/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc @@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() { // run_per_steps_, run_at_offset_ // 4, 0 --> run at step 0, 4, 8, 12 // 4, 3 --> run at step 3, 7, 11, 15 - if ((step_ % run_per_steps_) == run_at_offset_) { + if ((cur_scope_id_ % run_per_steps_) == run_at_offset_) { ComputeInterceptor::RunOps(); } } @@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() { void AmplifierInterceptor::SendDataReadyToDownStream() { // run multi times, send ready one times to downstream, that is // input multi times, output one times - if (step_ % send_down_per_steps_ == 0) { + if (cur_scope_id_ % send_down_per_steps_ == 0) { ComputeInterceptor::SendDataReadyToDownStream(); } } @@ -49,7 +49,7 @@ void AmplifierInterceptor::SendDataReadyToDownStream() { void AmplifierInterceptor::ReplyCompletedToUpStream() { // run multi times, reply one times to upstream, that is // input one times, output multi times - if (step_ % reply_up_per_steps_ == 0) { + if (cur_scope_id_ % reply_up_per_steps_ == 0) { ComputeInterceptor::ReplyCompletedToUpStream(); } } diff --git a/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h b/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h index 776aa8d3e88db10d551d6fd0180a5da9d6a6f3db..93e8ffa1d75aecc063b05fff84545238e7a1fba2 100644 --- a/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h @@ -21,7 +21,7 @@ namespace paddle { namespace distributed { -class AmplifierInterceptor : public ComputeInterceptor { +class AmplifierInterceptor final : public ComputeInterceptor { public: AmplifierInterceptor(int64_t interceptor_id, TaskNode* node); diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 6fb0d55a4859ef39d04857d39d1e70f6a31bb4a3..3449c87998a9dba21824e854afdb7216cb818164 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -71,6 +71,9 @@ void Carrier::Init( microbatch_scopes_[i] = &minibatch_scope_->NewScope(); CopyParameters(i, program, inference_root_scope_vars); } + // Add source and sink interceptor id to rank + interceptor_id_to_rank_.emplace(SOURCE_ID, rank); + interceptor_id_to_rank_.emplace(SINK_ID, rank); // TODO(fleet_exe dev): thread pool thread_num_ = 1; @@ -159,16 +162,10 @@ void Carrier::Start() { true, platform::errors::PreconditionNotMet( "Using carrier before initialized.")); - for (int64_t id : source_interceptor_ids_) { - VLOG(3) << "Carrier Start is sending start to source interceptor " << id - << "."; - InterceptorMessage start_msg; - // source node data_is_ready is send by carrier, so set src_id=-1 - start_msg.set_src_id(-1); - start_msg.set_dst_id(id); - start_msg.set_message_type(DATA_IS_READY); - Send(start_msg); - } + InterceptorMessage start_msg; + start_msg.set_dst_id(SOURCE_ID); + start_msg.set_message_type(START); + Send(start_msg); // TODO(wangxi): async step Wait(); dev_ctx_->Wait(); @@ -270,6 +267,38 @@ void Carrier::CreateInterceptors() { auto gc = GetGC(place_); + // create source and sink task node + auto max_run_times = microbatch_scopes_.size(); + TaskNode* source = new TaskNode( + rank_, SOURCE_ID, max_run_times); // rank, task_id, max_run_times + TaskNode* sink = new TaskNode(rank_, SINK_ID, max_run_times); + // find nodes without upstreams or without downstreams + std::vector origin_sources, origin_sinks; + for (const auto& item : interceptor_id_to_node_) { + TaskNode* task_node = item.second; + if (task_node->upstream().empty()) { + origin_sources.emplace_back(task_node); + } + if (task_node->downstream().empty()) { + origin_sinks.emplace_back(task_node); + } + } + // link source node with origin source + for (const auto& node : origin_sources) { + source->AddDownstreamTask(node->task_id(), + std::numeric_limits::max()); + node->AddUpstreamTask(SOURCE_ID, std::numeric_limits::max()); + } + // link sink node with origin sink + for (const auto& node : origin_sinks) { + sink->AddUpstreamTask(node->task_id(), std::numeric_limits::max()); + node->AddDownstreamTask(SINK_ID, std::numeric_limits::max()); + } + // create source and sink interceptor + SetInterceptor(SOURCE_ID, + InterceptorFactory::Create("Source", SOURCE_ID, source)); + SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink)); + // create each Interceptor // no auto init since there is no config for (const auto& item : interceptor_id_to_node_) { @@ -303,9 +332,15 @@ void Carrier::CreateInterceptors() { VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id << " with type: " << task_node->type() << "."; - if (task_node->upstream().empty()) { - source_interceptor_ids_.emplace_back(interceptor_id); - } + PADDLE_ENFORCE_EQ( + task_node->upstream().empty(), + false, + platform::errors::PreconditionNotMet( + "There should not have normal nodes as source nodes")); + PADDLE_ENFORCE_EQ(task_node->downstream().empty(), + false, + platform::errors::PreconditionNotMet( + "There should not have normal nodes as sink nodes")); } } diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index fe3d492676655838f6f077718ef65681bcdb53cb..2523942e06223f6210461a625a1a3bce2dcedb92 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -100,8 +100,6 @@ class Carrier final { std::unordered_map> interceptor_idx_to_interceptor_; - std::vector source_interceptor_ids_; - bool is_init_{false}; std::mutex running_mutex_; diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 5b96ee76e7144692bad974c14a2bce1f6ae2f3b4..5017f81523c8aea31fb8732e001e4af311313d32 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -34,29 +34,10 @@ void ComputeInterceptor::PrepareDeps() { for (auto up : upstream) { in_readys_.emplace(up.first, std::make_pair(up.second, 0)); - in_stops_.emplace(up.first, false); } for (auto down : downstream) { out_buffs_.emplace(down.first, std::make_pair(down.second, 0)); } - - // source compute node, should we add a new SourceInterceptor? - if (upstream.empty()) { - is_source_ = true; - PADDLE_ENFORCE_GT(node_->max_run_times(), - 0, - platform::errors::InvalidArgument( - "Source ComputeInterceptor must run at least one " - "times, but now max_run_times=%ld", - node_->max_run_times())); - in_readys_.emplace(-1, - std::make_pair(std::numeric_limits::max(), 0)); - } - - // If there is no downstream or every downstream is in different rank, - // then this interceptor is the last one for current rank. - // This can be get during init, can be cached for later use. - is_last_ = downstream.empty(); } void ComputeInterceptor::IncreaseReady(int64_t up_id) { @@ -66,12 +47,6 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { platform::errors::NotFound( "Cannot find upstream=%lld in in_readys.", up_id)); - // source node has no upstream, data_is_ready is send by carrier or others - if (is_source_ && up_id == -1) { - it->second.second += GetTaskNode()->max_run_times(); - return; - } - auto max_ready_size = it->second.first; auto ready_size = it->second.second; ready_size += 1; @@ -152,7 +127,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ready_msg.set_message_type(DATA_IS_READY); VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " Send data_is_ready msg to " << down_id - << " for step: " << step_; + << " in scope: " << cur_scope_id_; Send(down_id, ready_msg); } } @@ -173,8 +148,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " Reply data_is_useless msg to " << up_id - << " for step: " << step_; - if (is_source_ && up_id == -1) return; + << " in scope: " << cur_scope_id_; InterceptorMessage reply_msg; reply_msg.set_message_type(DATA_IS_USELESS); @@ -183,16 +157,20 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { } void ComputeInterceptor::RunOps() { - VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the " - << step_ + 1 << " time."; for (auto op : node_->ops()) { - op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_); + PADDLE_ENFORCE_LT(cur_scope_id_, + microbatch_scopes_.size(), + platform::errors::InvalidArgument( + "Step out of range. There are %ld " + "microbatch_scopes, but recevice scope index %ld", + microbatch_scopes_.size(), + cur_scope_id_)); + op->Run(*microbatch_scopes_[cur_scope_id_], place_); if (gc_) { - framework::DeleteUnusedTensors( - *microbatch_scopes_[step_ % node_->max_run_times()], - op, - node_->unused_vars(), - gc_.get()); + framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_], + op, + node_->unused_vars(), + gc_.get()); } } } @@ -201,77 +179,28 @@ void ComputeInterceptor::Run() { while (IsInputReady() && CanWriteOutput()) { VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; + // get the ready scope id from queue + cur_scope_id_ = ready_queue_.front(); + ready_queue_.pop(); + RunOps(); - ++step_; // send to downstream and increase buff used SendDataReadyToDownStream(); // reply to upstream and decrease ready data ReplyCompletedToUpStream(); - // Try to stop Carrier - if (is_last_ && (step_ % node_->max_run_times() == 0)) { - VLOG(3) << "Interceptor " << GetInterceptorId() - << " is stopping carrier."; - // FIXME(wangxi): with multi sink interceptor - StopCarrier(); - } - } -} - -void ComputeInterceptor::ReceivedStop(int64_t up_id) { - received_stop_ = true; - - // source node has no upstream, stop is send by carrier or others - if (is_source_ && up_id == -1) return; - - auto it = in_stops_.find(up_id); - PADDLE_ENFORCE_NE(it, - in_stops_.end(), - platform::errors::NotFound( - "Cannot find upstream=%lld in in_stops.", up_id)); - PADDLE_ENFORCE_EQ( - it->second, - false, - platform::errors::AlreadyExists("Already received stop from %lld, stop " - "cannot be send more than once.")); - it->second = true; -} - -void ComputeInterceptor::TryStop() { - if (!received_stop_) return; - - // can stop only when all upstream is stop and - // downstream complete - for (auto& in_stop : in_stops_) { - if (!in_stop.second) return; - } - for (auto& out_buff : out_buffs_) { - auto used_size = out_buff.second.second; - if (used_size != 0) return; } - - // send stop to downstream - for (auto& out : out_buffs_) { - auto down_id = out.first; - InterceptorMessage stop; - stop.set_message_type(STOP); - Send(down_id, stop); - } - stop_ = true; } void ComputeInterceptor::Compute(const InterceptorMessage& msg) { if (msg.message_type() == DATA_IS_READY) { IncreaseReady(msg.src_id()); + ready_queue_.push(msg.scope_idx()); Run(); } else if (msg.message_type() == DATA_IS_USELESS) { DecreaseBuff(msg.src_id()); Run(); - } else if (msg.message_type() == STOP) { - ReceivedStop(msg.src_id()); } - - TryStop(); } REGISTER_INTERCEPTOR(Compute, ComputeInterceptor); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index fb82ce76c7bdb851c32b1959121059cfca041b94..9709cd4437f1019fea80cf04ecce5a38f74bb463 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/fluid/distributed/fleet_executor/interceptor.h" @@ -30,7 +31,8 @@ class ComputeInterceptor : public Interceptor { virtual void SendDataReadyToDownStream(); virtual void ReplyCompletedToUpStream(); - int64_t step_{0}; + std::queue ready_queue_; + int64_t cur_scope_id_; private: void PrepareDeps(); @@ -43,19 +45,10 @@ class ComputeInterceptor : public Interceptor { void Run(); void Compute(const InterceptorMessage& msg); - void ReceivedStop(int64_t up_id); - void TryStop(); - - bool is_source_{false}; - bool is_last_{false}; - // upstream_id-->(max_ready_size, ready_size) std::map> in_readys_{}; // downstream_id-->(max_buffer_size, used_size) std::map> out_buffs_{}; - - bool received_stop_{false}; - std::map in_stops_{}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 6a761072027a924f21b38f7a694bba65b77e425d..2c20e1ad6113ecda58404429697fa4077fece492 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -93,7 +93,6 @@ class Interceptor { TaskNode* node_; // for stop - bool stop_{false}; void StopCarrier(); // for runtime @@ -114,9 +113,6 @@ class Interceptor { std::mutex mutex_; std::deque messages_; - - int64_t already_run_times_{0}; - int64_t used_slot_nums_{0}; }; class InterceptorFactory { diff --git a/paddle/fluid/distributed/fleet_executor/sink_interceptor.h b/paddle/fluid/distributed/fleet_executor/sink_interceptor.h index cb1d698a78526fdde61586304e588e8009340584..1abb7a641e23a5237570b9f469009f4fa3fb72a7 100644 --- a/paddle/fluid/distributed/fleet_executor/sink_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/sink_interceptor.h @@ -25,7 +25,7 @@ namespace distributed { * 1. record the num of micro-step * 2. check whether to notify carrier the current step is finished */ -class SinkInterceptor : public Interceptor { +class SinkInterceptor final : public Interceptor { public: SinkInterceptor(int64_t interceptor_id, TaskNode* node); diff --git a/paddle/fluid/distributed/fleet_executor/source_interceptor.h b/paddle/fluid/distributed/fleet_executor/source_interceptor.h index f8b18fb1848645c44c75db90a7d123ba48aeae21..95e8c1b3b03781a653152219a73e6b590cced631 100644 --- a/paddle/fluid/distributed/fleet_executor/source_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/source_interceptor.h @@ -25,7 +25,7 @@ namespace distributed { * 1. receive `start` message from carrier * 2. send num_of_steps `data_is_ready` message to downstream */ -class SourceInterceptor : public Interceptor { +class SourceInterceptor final : public Interceptor { public: SourceInterceptor(int64_t interceptor_id, TaskNode* node); diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 4992a8b34c9da163af6bb64cad0094da9142afb2..e484031161489f4e6cd54403fbd15da0128433e8 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -25,57 +25,42 @@ limitations under the License. */ namespace paddle { namespace distributed { -class StartInterceptor : public Interceptor { - public: - StartInterceptor(int64_t interceptor_id, TaskNode* node) - : Interceptor(interceptor_id, node) { - RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); }); - } - - void NOP(const InterceptorMessage& msg) { - if (msg.message_type() == STOP) { - stop_ = true; - InterceptorMessage stop; - stop.set_message_type(STOP); - Send(1, stop); // stop 1, compute - return; - } - std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() - << std::endl; - } -}; - TEST(ComputeInterceptor, Compute) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // NOTE: don't delete, otherwise interceptor will use undefined node - TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id + TaskNode* source = + new TaskNode(0, SOURCE_ID, 3); // rank, task_id, max_run_times + TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); - TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); + TaskNode* sink = new TaskNode(0, SINK_ID, 3); - // a->b->c + // source->a->b->sink + source->AddDownstreamTask(0); + node_a->AddUpstreamTask(SOURCE_ID); node_a->AddDownstreamTask(1, 3); node_b->AddUpstreamTask(0, 3); - node_b->AddDownstreamTask(2); - node_c->AddUpstreamTask(1); + node_b->AddDownstreamTask(SINK_ID); + sink->AddUpstreamTask(1); - Interceptor* a = - carrier->SetInterceptor(0, std::make_unique(0, node_a)); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); + carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); - carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); + // start InterceptorMessage msg; - msg.set_message_type(DATA_IS_READY); - // test run three times - a->Send(1, msg); - a->Send(1, msg); - a->Send(1, msg); + msg.set_message_type(START); + msg.set_dst_id(SOURCE_ID); + carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); carrier->Release(); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 54adf06fb67ddf6e5d9ac803b3aa097289c33c38..f43f3860199fb772bc5d4537a41490a70c8270e5 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor { void PingPong(const InterceptorMessage& msg) { if (msg.message_type() == STOP) { - stop_ = true; return; } std::cout << GetInterceptorId() << " recv msg, count=" << count_ diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index 3828c4478cbe6eecad18a88ce5501eae84eb0589..62c23068d7d4a9eb6574aacc53d0c258ae2ddc51 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor { void PingPong(const InterceptorMessage& msg) { if (msg.message_type() == STOP) { - stop_ = true; StopCarrier(); return; }