From 0dd41a2a53590a58a9a9756db8eb82c9d41f41f0 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 7 Feb 2023 11:57:55 +0800 Subject: [PATCH] Add start interceptor and fix bug in switch scope (#50225) --- .../distributed/fleet_executor/CMakeLists.txt | 3 + .../distributed/fleet_executor/carrier.cc | 1 + .../fleet_executor/compute_interceptor.cc | 64 ++++++---- .../fleet_executor/compute_interceptor.h | 21 ++-- .../fleet_executor/cond_interceptor.cc | 6 +- .../fleet_executor/cond_interceptor.h | 1 - .../fleet_executor/start_interceptor.cc | 114 ++++++++++++++++++ .../fleet_executor/start_interceptor.h | 39 ++++++ .../test/compute_interceptor_test.cc | 3 + .../interceptor_pipeline_long_path_test.cc | 2 +- .../interceptor_pipeline_short_path_test.cc | 4 +- .../test_fleet_executor_cond_interceptor.py | 4 +- 12 files changed, 220 insertions(+), 42 deletions(-) create mode 100644 paddle/fluid/distributed/fleet_executor/start_interceptor.cc create mode 100644 paddle/fluid/distributed/fleet_executor/start_interceptor.h diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 9cf1cdde223..ff8ed811ee6 100755 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -37,6 +37,7 @@ cc_library( compute_interceptor.cc amplifier_interceptor.cc cond_interceptor.cc + start_interceptor.cc source_interceptor.cc sink_interceptor.cc message_service.cc @@ -69,6 +70,8 @@ if(WITH_DISTRIBUTE) ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties( cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties( + start_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties( source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties( diff --git 
a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 4a759646067..efb09bf736e 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -36,6 +36,7 @@ USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Amplifier); USE_INTERCEPTOR(Sink); USE_INTERCEPTOR(Cond); +USE_INTERCEPTOR(Start); void Carrier::Init( int64_t rank, diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index a03ac900e9f..9d229d640c9 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/errors.h" namespace paddle { namespace distributed { @@ -33,14 +34,18 @@ void ComputeInterceptor::PrepareDeps() { auto& downstream = node_->downstream(); for (auto up : upstream) { - in_readys_.emplace(up.first, std::make_pair(up.second, 0)); + std::map ready_size_map; + for (int64_t i = 0; i < node_->max_run_times(); ++i) { + ready_size_map.emplace(i, 0); + } + in_readys_.emplace(up.first, std::make_pair(up.second, ready_size_map)); } for (auto down : downstream) { out_buffs_.emplace(down.first, std::make_pair(down.second, 0)); } } -void ComputeInterceptor::IncreaseReady(int64_t up_id) { +void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) { auto it = in_readys_.find(up_id); PADDLE_ENFORCE_NE(it, in_readys_.end(), @@ -48,8 +53,11 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { "Cannot find upstream=%lld in in_readys.", up_id)); auto max_ready_size = it->second.first; - auto ready_size = it->second.second; - ready_size += 1; + const auto& ready_scope_map = it->second.second; + int64_t 
ready_size = 0; + for (auto& scope_iter : ready_scope_map) { + ready_size += scope_iter.second; + } if (max_ready_size != INFINITE_BUFFER_SIZE) { PADDLE_ENFORCE_LE( ready_size, @@ -61,7 +69,14 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { ready_size, max_ready_size)); } - it->second.second = ready_size; + PADDLE_ENFORCE_NE( + it->second.second.find(scope_id), + it->second.second.end(), + platform::errors::OutOfRange( + "Interceptor %lld can not find scope %lld in upstream ready map", + interceptor_id_, + scope_id)); + it->second.second.at(scope_id) = ready_scope_map.at(scope_id) + 1; } void ComputeInterceptor::DecreaseBuff(int64_t down_id) { @@ -83,16 +98,21 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) { } bool ComputeInterceptor::IsInputReady() { - for (auto& ins : in_readys_) { - auto ready_size = ins.second.second; - // not ready, return false - if (ready_size == 0) { - VLOG(3) << "Interceptor " << GetInterceptorId() + for (int64_t i = 0; i < node_->max_run_times(); ++i) { + bool flag = true; + for (auto& ins : in_readys_) { + auto ready_size_map = ins.second.second; + flag = flag && (ready_size_map.at(i) != 0); + } + if (flag) { + cur_scope_id_ = i; + return true; + } else { + VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i << "'s upstreams aren't all ready."; - return false; } } - return true; + return false; } bool ComputeInterceptor::CanWriteOutput() { @@ -144,7 +164,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() { void ComputeInterceptor::ReplyCompletedToUpStream() { for (auto& ins : in_readys_) { auto up_id = ins.first; - auto ready_size = ins.second.second; + auto ready_size = ins.second.second.at(cur_scope_id_); ready_size -= 1; PADDLE_ENFORCE_GE( ready_size, @@ -153,7 +173,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { "upstream=%lld ready_size must >= 0, but now got %lld", up_id, ready_size)); - ins.second.second = ready_size; + ins.second.second[cur_scope_id_] = ready_size; VLOG(3) 
<< "ComputeInterceptor " << interceptor_id_ << " Reply data_is_useless msg to " << up_id @@ -187,11 +207,8 @@ void ComputeInterceptor::RunOps() { void ComputeInterceptor::Run() { while (IsInputReady() && CanWriteOutput()) { - VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; - - // get the ready scope id from queue - cur_scope_id_ = ready_queue_.front(); - ready_queue_.pop(); + VLOG(0) << "id=" << GetInterceptorId() + << " ComputeInterceptor running in scope " << cur_scope_id_; RunOps(); @@ -204,10 +221,15 @@ void ComputeInterceptor::Run() { void ComputeInterceptor::Compute(const InterceptorMessage& msg) { if (msg.message_type() == DATA_IS_READY) { - IncreaseReady(msg.src_id()); - ready_queue_.push(msg.scope_idx()); + VLOG(3) << "Compute interceptor " << interceptor_id_ + << " receive data_is_ready " << msg.src_id() << " " + << msg.scope_idx() << " "; + IncreaseReady(msg.src_id(), msg.scope_idx()); Run(); } else if (msg.message_type() == DATA_IS_USELESS) { + VLOG(3) << "Compute interceptor " << interceptor_id_ + << " receive data_is_useless " << msg.src_id() << " " + << msg.scope_idx() << " "; DecreaseBuff(msg.src_id()); Run(); } diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index eade47fd878..519d0af4306 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -32,25 +32,24 @@ class ComputeInterceptor : public Interceptor { virtual void RunOps(); virtual void SendDataReadyToDownStream(); virtual void ReplyCompletedToUpStream(); + virtual void Compute(const InterceptorMessage& msg); + void Run(); + void IncreaseReady(int64_t up_id, int64_t scope_id); + void DecreaseBuff(int64_t down_id); - std::queue ready_queue_; int64_t cur_scope_id_; + // upstream_id-->(max_ready_size, scope-->ready_size) + std::map>> + in_readys_{}; + // downstream_id-->(max_buffer_size, 
used_size) + std::map> out_buffs_{}; + private: void PrepareDeps(); - void IncreaseReady(int64_t up_id); - void DecreaseBuff(int64_t down_id); bool IsInputReady(); bool CanWriteOutput(); - - void Run(); - void Compute(const InterceptorMessage& msg); - - // upstream_id-->(max_ready_size, ready_size) - std::map> in_readys_{}; - // downstream_id-->(max_buffer_size, used_size) - std::map> out_buffs_{}; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc b/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc index 1d82b73fb89..ace969e7980 100644 --- a/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc @@ -98,8 +98,6 @@ void CondInterceptor::ReplyDataIsUseless(int64_t up_id) { } void CondInterceptor::Compute() { - cur_scope_id_ = ready_queue_.front(); - ready_queue_.pop(); bool cond = GetCondResult(); VLOG(3) << "Cond interceptor get condition var " << node_->cond_var() << " with value " << cond; @@ -109,14 +107,14 @@ void CondInterceptor::Compute() { SendDataReady(down_id); } } else { - VLOG(3) << "Finish loop in scope " << cur_scope_id_; + VLOG(0) << "Finish loop in scope " << cur_scope_id_; SendDataReady(stop_loop_id_); } } void CondInterceptor::Run(const InterceptorMessage& msg) { if (msg.message_type() == DATA_IS_READY) { - ready_queue_.push(msg.scope_idx()); + cur_scope_id_ = msg.scope_idx(); Compute(); } else if (msg.message_type() == DATA_IS_USELESS) { if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) { diff --git a/paddle/fluid/distributed/fleet_executor/cond_interceptor.h b/paddle/fluid/distributed/fleet_executor/cond_interceptor.h index 81b001135f1..8ea2d4b370c 100644 --- a/paddle/fluid/distributed/fleet_executor/cond_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/cond_interceptor.h @@ -39,7 +39,6 @@ class CondInterceptor final : public Interceptor { void SendDataReady(int64_t down_id); void 
ReplyDataIsUseless(int64_t up_id); - std::queue ready_queue_; int64_t cur_scope_id_; std::set normal_in_id_; diff --git a/paddle/fluid/distributed/fleet_executor/start_interceptor.cc b/paddle/fluid/distributed/fleet_executor/start_interceptor.cc new file mode 100644 index 00000000000..b5f3bcb2404 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/start_interceptor.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/fluid/distributed/fleet_executor/start_interceptor.h"
+
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/core/errors.h"
+
+namespace paddle {
+namespace distributed {
+
+StartInterceptor::StartInterceptor(int64_t interceptor_id, TaskNode* node)
+    : ComputeInterceptor(interceptor_id, node) {
+  auto& downstream = node_->downstream();
+  PADDLE_ENFORCE_EQ(
+      downstream.size(),
+      1,
+      platform::errors::OutOfRange(
+          "The downstream for StartInterceptor only support 1 for now."));
+  batch_size_ = downstream.begin()->second;  // sole downstream's buffer size
+  // Reject batch_size_ <= 0: 0 divides by zero, <0 never completes a batch.
+  bool evenly_divisible =
+      batch_size_ > 0 && (node_->max_run_times() % batch_size_) == 0;
+  PADDLE_ENFORCE(
+      evenly_divisible,
+      platform::errors::Fatal(
+          "Wrong config: Num of step should be divided by batch_size, "
+          "num_step=%lld, batch_size=%lld",
+          node_->max_run_times(),
+          batch_size_));
+}
+
+void StartInterceptor::RunOps() {
+  finish_count_++;  // runs not yet acknowledged by a DATA_IS_USELESS message
+  ComputeInterceptor::RunOps();
+}
+
+// Books one buffer slot per run; notifies downstream once per full batch.
+void StartInterceptor::SendDataReadyToDownStream() {
+  for (auto& outs : out_buffs_) {
+    auto down_id = outs.first;
+    auto max_buff_size = outs.second.first;
+    auto used_size = outs.second.second + 1;
+    if (max_buff_size != INFINITE_BUFFER_SIZE) {
+      PADDLE_ENFORCE_LE(
+          used_size,
+          max_buff_size,
+          platform::errors::OutOfRange("downstream=%lld used buff size must <= "
+                                       "max_buff_size, but now used_size=%lld, "
+                                       "max_buff_size=%lld",
+                                       down_id,
+                                       used_size,
+                                       max_buff_size));
+    }
+    outs.second.second = used_size;
+  }
+  if (finish_count_ == batch_size_) {  // a full batch has run
+    for (int64_t i = 0; i < batch_size_; ++i) {
+      for (auto& outs : out_buffs_) {
+        auto down_id = outs.first;
+        InterceptorMessage ready_msg;
+        ready_msg.set_message_type(DATA_IS_READY);
+        ready_msg.set_scope_idx(step_);
+        VLOG(3) << "StartInterceptor " << interceptor_id_
+                << " Send data_is_ready msg to " << down_id
+                << " in scope: " << step_;
+        Send(down_id, ready_msg);
+      }
+      step_++;  // each message of the batch carries its own scope index
+    }
+  }
+}
+
+// Handles upstream readiness and per-run downstream acknowledgements.
+void StartInterceptor::Compute(const InterceptorMessage& msg) {
+  if (msg.message_type() == DATA_IS_READY) {
+    VLOG(3) << "Start interceptor " << interceptor_id_
+            << " receive data_is_ready " << msg.src_id() << " "
+            << msg.scope_idx() << " ";
+    IncreaseReady(msg.src_id(), msg.scope_idx());
+    Run();
+  } else if (msg.message_type() == DATA_IS_USELESS) {
+    VLOG(3) << "Start interceptor receive data_is_useless " << msg.src_id()
+            << " " << finish_count_;
+    finish_count_--;
+    if (finish_count_ == 0) {  // every outstanding run has been acknowledged
+      for (int64_t i = 0; i < batch_size_; ++i) {
+        for (auto& outs : out_buffs_) {
+          DecreaseBuff(outs.first);
+        }
+      }
+      for (int64_t i = 0; i < batch_size_; ++i) {
+        Run();
+      }
+    }
+  }
+}
+
+REGISTER_INTERCEPTOR(Start, StartInterceptor);
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/fleet_executor/start_interceptor.h b/paddle/fluid/distributed/fleet_executor/start_interceptor.h
new file mode 100644
index 00000000000..f082c48922b
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/start_interceptor.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include + +#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" + +namespace paddle { +namespace distributed { + +class StartInterceptor final : public ComputeInterceptor { + public: + StartInterceptor(int64_t interceptor_id, TaskNode* node); + + private: + void SendDataReadyToDownStream() override; + void RunOps() override; + void Compute(const InterceptorMessage& msg) override; + + int64_t batch_size_{0}; + int64_t finish_count_{0}; + int64_t step_{0}; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 618e55ba6ef..1a4f3f2ce9a 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -21,6 +21,9 @@ limitations under the License. */ #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/phi/core/kernel_registry.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index 3dca7aed141..12fc77a2717 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -66,7 +66,7 @@ TEST(AmplifierInterceptor, Amplifier) { MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); - int64_t micro_steps = 3; + int64_t micro_steps = 1; // NOTE: don't delete, otherwise 
interceptor will use undefined node TaskNode* source = diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index 3101ad5f489..4a29f07db5b 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -84,8 +84,8 @@ TEST(AmplifierInterceptor, Amplifier) { TaskNode* source = new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id - TaskNode* node_b = new TaskNode(0, 0, 1, 3); - TaskNode* node_c = new TaskNode(0, 0, 2, 3); + TaskNode* node_b = new TaskNode(0, 0, 1, micro_steps); + TaskNode* node_c = new TaskNode(0, 0, 2, micro_steps); TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps); TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py index f6418cdee2c..bb235af71bc 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py @@ -33,7 +33,7 @@ def body(i, ten, data): return [i, ten, data] -num_micro_batches = 3 +num_micro_batches = 4 def batch_generator_creator(): @@ -126,7 +126,7 @@ class TestFleetExecutor(unittest.TestCase): task_a = TaskNode( 0, num_micro_batches, - node_type="Compute", + node_type="Start", task_id=0, program=program_a, lazy_initialize=True, -- GitLab