diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
index cc5ed287e954f67d9c2877a413333d72a4bde534..9cf1cdde223488a5ae5a56676b1d97bdeac93b5c 100755
--- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
@@ -36,6 +36,7 @@ cc_library(
   interceptor.cc
   compute_interceptor.cc
   amplifier_interceptor.cc
+  cond_interceptor.cc
   source_interceptor.cc
   sink_interceptor.cc
   message_service.cc
@@ -66,6 +67,8 @@ if(WITH_DISTRIBUTE)
   set_source_files_properties(
     amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
                                         ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(
+    cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(
     source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(
diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc
index 3449c87998a9dba21824e854afdb7216cb818164..094afff577a9e851640cfe947f72656d8395e556 100644
--- a/paddle/fluid/distributed/fleet_executor/carrier.cc
+++ b/paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -33,6 +33,7 @@ USE_INTERCEPTOR(Source);
 USE_INTERCEPTOR(Compute);
 USE_INTERCEPTOR(Amplifier);
 USE_INTERCEPTOR(Sink);
+USE_INTERCEPTOR(Cond);
 
 void Carrier::Init(
     int64_t rank,
@@ -96,29 +97,30 @@ void Carrier::CopyParameters(
     int microbatch_id,
     const framework::ProgramDesc& program,
     const std::vector<std::string>& inference_root_scope_vars) {
-  auto& global_block = program.Block(0);
-
   std::map<std::string, int> inference_root_scope_var_map;
   for (auto var_name : inference_root_scope_vars) {
     inference_root_scope_var_map.insert({var_name, 1});
   }
-  for (auto& var : global_block.AllVars()) {
-    std::string var_name = var->Name();
-    bool force_root = inference_root_scope_var_map.find(var_name) !=
-                      inference_root_scope_var_map.end();
-    if (force_root) {
-      VLOG(4) << var_name << " will be forced to be created in the root scope.";
-    }
-    if ((var->Persistable() || force_root) && microbatch_id == 0) {
-      auto* ptr = root_scope_->Var(var->Name());
-      InitializeVariable(ptr, var->GetType());
-      VLOG(5) << "Create persistable var: " << var->Name()
-              << ", which pointer is " << ptr;
-    } else if (!var->Persistable()) {
-      auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
-      VLOG(5) << "Create variable " << var->Name() << " for microbatch "
-              << microbatch_id << ", which pointer is " << ptr << ".";
-      InitializeVariable(ptr, var->GetType());
+  for (size_t i = 0; i < program.Size(); ++i) {
+    for (auto& var : program.Block(i).AllVars()) {
+      std::string var_name = var->Name();
+      bool force_root = inference_root_scope_var_map.find(var_name) !=
+                        inference_root_scope_var_map.end();
+      if (force_root) {
+        VLOG(4) << var_name
+                << " will be forced to be created in the root scope.";
+      }
+      if ((var->Persistable() || force_root) && microbatch_id == 0) {
+        auto* ptr = root_scope_->Var(var->Name());
+        InitializeVariable(ptr, var->GetType());
+        VLOG(5) << "Create persistable var: " << var->Name()
+                << ", which pointer is " << ptr;
+      } else if (!var->Persistable()) {
+        auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
+        VLOG(5) << "Create variable " << var->Name() << " for microbatch "
+                << microbatch_id << ", which pointer is " << ptr << ".";
+        InitializeVariable(ptr, var->GetType());
+      }
     }
   }
 }
diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
index 5017f81523c8aea31fb8732e001e4af311313d32..9aedaa131400f3bfd6be24953050071e8970a557 100644
--- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
+++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -125,6 +125,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
 
     InterceptorMessage ready_msg;
     ready_msg.set_message_type(DATA_IS_READY);
+    ready_msg.set_scope_idx(cur_scope_id_);
     VLOG(3) << "ComputeInterceptor " << interceptor_id_
             << " Send data_is_ready msg to " << down_id
             << " in scope: " << cur_scope_id_;
@@ -152,6 +153,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
 
     InterceptorMessage reply_msg;
     reply_msg.set_message_type(DATA_IS_USELESS);
+    reply_msg.set_scope_idx(cur_scope_id_);
     Send(up_id, reply_msg);
   }
 }
diff --git a/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc b/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1d82b73fb898c7d2cd81bcd4e60d16dfea56c777
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc
@@ -0,0 +1,141 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/executor_gc_helper.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/errors.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/errors.h"
+
+namespace paddle {
+namespace distributed {
+
+CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node)
+    : Interceptor(interceptor_id, node) {
+  PrepareDeps();
+  RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
+}
+
+void CondInterceptor::PrepareDeps() {
+  auto& upstream = node_->upstream();
+  auto& downstream = node_->downstream();
+  auto& id_to_dep_type = node_->id_to_dep_type();
+
+  for (const auto& up : upstream) {
+    if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
+      normal_in_id_.insert(up.first);
+    }
+  }
+
+  for (const auto& down : downstream) {
+    if (id_to_dep_type.at(down.first) == DependType::NORMAL) {
+      normal_out_id_.insert(down.first);
+    } else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) {
+      stop_loop_id_ = down.first;
+    }
+  }
+}
+
+bool CondInterceptor::GetCondResult() {
+  PADDLE_ENFORCE_LT(cur_scope_id_,
+                    microbatch_scopes_.size(),
+                    platform::errors::InvalidArgument(
+                        "Step out of range. There are %ld "
+                        "microbatch_scopes, but received scope index %ld",
+                        microbatch_scopes_.size(),
+                        cur_scope_id_));
+  auto* cond_var =
+      microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var());
+  PADDLE_ENFORCE(cond_var,
+                 platform::errors::NotFound(
+                     "Condition variable %s does not exist in scope %ld",
+                     node_->cond_var(),
+                     cur_scope_id_));
+  const auto& cond_tensor = cond_var->Get<phi::DenseTensor>();
+  bool res = false;
+  if (platform::is_gpu_place(cond_tensor.place())) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    phi::DenseTensor cpu_tensor;
+    framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
+    platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
+    res = cpu_tensor.data<bool>()[0];
+#endif
+  } else if (platform::is_cpu_place(cond_tensor.place())) {
+    res = cond_tensor.data<bool>()[0];
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Unsupported device for cond interceptor."));
+  }
+  return res;
+}
+
+void CondInterceptor::SendDataReady(int64_t down_id) {
+  InterceptorMessage ready_msg;
+  ready_msg.set_message_type(DATA_IS_READY);
+  ready_msg.set_scope_idx(cur_scope_id_);
+  Send(down_id, ready_msg);
+}
+
+void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
+  InterceptorMessage ready_msg;
+  ready_msg.set_message_type(DATA_IS_USELESS);
+  ready_msg.set_scope_idx(cur_scope_id_);
+  Send(up_id, ready_msg);
+}
+
+void CondInterceptor::Compute() {
+  cur_scope_id_ = ready_queue_.front();
+  ready_queue_.pop();
+  bool cond = GetCondResult();
+  VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
+          << " with value " << cond;
+  if (cond) {
+    VLOG(3) << "Loop again in scope " << cur_scope_id_;
+    for (auto& down_id : normal_out_id_) {
+      SendDataReady(down_id);
+    }
+  } else {
+    VLOG(3) << "Finish loop in scope " << cur_scope_id_;
+    SendDataReady(stop_loop_id_);
+  }
+}
+
+void CondInterceptor::Run(const InterceptorMessage& msg) {
+  if (msg.message_type() == DATA_IS_READY) {
+    ready_queue_.push(msg.scope_idx());
+    Compute();
+  } else if (msg.message_type() == DATA_IS_USELESS) {
+    if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
+      for (auto& up_id : normal_in_id_) {
+        ReplyDataIsUseless(up_id);
+      }
+      // GC the variables in the while block
+      int64_t scope_id = msg.scope_idx();
+      if (gc_) {
+        VLOG(3) << "Release vars in while block in scope " << scope_id;
+        framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id],
+                                       node_->while_block_vars(),
+                                       gc_.get());
+      }
+    }
+  }
+}
+
+REGISTER_INTERCEPTOR(Cond, CondInterceptor);
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/fleet_executor/cond_interceptor.h b/paddle/fluid/distributed/fleet_executor/cond_interceptor.h
new file mode 100644
index 0000000000000000000000000000000000000000..81b001135f189ef7e85ee279774b103d7dec7368
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/cond_interceptor.h
@@ -0,0 +1,51 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <queue>
+#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
+
+namespace paddle {
+namespace distributed {
+
+/* Condition Interceptor
+ * A special interceptor whose task node holds only one conditional op.
+ * It has two kinds of downstreams:
+ *  1. If the condition result is true, it selects one downstream; otherwise
+ *     it selects the other.
+ *  2. It is used to implement the while op in a program.
+ */
+class CondInterceptor final : public Interceptor {
+ public:
+  CondInterceptor(int64_t interceptor_id, TaskNode* node);
+
+ private:
+  void PrepareDeps();
+  void Run(const InterceptorMessage& msg);
+  void Compute();
+  bool GetCondResult();
+  void SendDataReady(int64_t down_id);
+  void ReplyDataIsUseless(int64_t up_id);
+
+  std::queue<int64_t> ready_queue_;
+  int64_t cur_scope_id_;
+
+  std::set<int64_t> normal_in_id_;
+  std::set<int64_t> normal_out_id_;
+  int64_t stop_loop_id_;
+};
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index a2d2ecd9bbf106c1ca3c774fc338c8a1eb82fe20..1f397a91746b96035fa420452f06702a43ef2c45 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -66,12 +66,11 @@ void FleetExecutor::Init(
           "Fleet executor is inited with empty task node"));
   // TODO(fleet_exe devs): the unused_vars should be got from run time graph
   std::vector<std::unique_ptr<framework::OperatorBase>> ops;
-  for (auto task_node : task_nodes) {
-    for (auto op : task_node->ops()) {
-      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
-    }
+  for (const auto& desc : program_desc.Block(0).AllOps()) {
+    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
   }
   auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
+
   // NOTE: For inference, the vars in inference_root_scope_vars
   // shouldn't be deleted during inf, for that they may be the result of the
   // inf. If they are GCed, it will cause error during ZeroCopy the result.
@@ -107,6 +106,25 @@ void FleetExecutor::Init(
   std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
   for (auto task_node : task_nodes) {
     task_node->SetUnusedVars(unused_vars);
+    if (task_node->type() == "Cond") {
+      std::vector<std::string> while_block_vars;
+      std::vector<std::string> vars_in_parent;
+      std::vector<std::string> vars_in_sub;
+      for (auto& var : program_desc.Block(0).AllVars()) {
+        vars_in_parent.emplace_back(var->Name());
+      }
+      for (auto& var : program_desc.Block(1).AllVars()) {
+        vars_in_sub.emplace_back(var->Name());
+      }
+      std::sort(vars_in_parent.begin(), vars_in_parent.end());
+      std::sort(vars_in_sub.begin(), vars_in_sub.end());
+      std::set_difference(vars_in_sub.begin(),
+                          vars_in_sub.end(),
+                          vars_in_parent.begin(),
+                          vars_in_parent.end(),
+                          std::back_inserter(while_block_vars));
+      task_node->SetWhileBlockVars(while_block_vars);
+    }
     int64_t interceptor_id = task_node->task_id();
     interceptor_id_to_task.emplace(interceptor_id, task_node);
   }
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc
index 341ffe290a52055143db2729d30dd18582cbb6df..f02ac1c0f9aa03d5576ac22e524e390632338aaf 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.cc
+++ b/paddle/fluid/distributed/fleet_executor/task_node.cc
@@ -141,13 +141,19 @@ TaskNode::TaskNode(int32_t role,
       max_run_times_(max_run_times),
       max_slot_nums_(max_slot_nums) {}
 
-bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
+bool TaskNode::AddUpstreamTask(int64_t task_id,
+                               int64_t buff_size,
+                               DependType type) {
   const auto& ret = upstream_.emplace(task_id, buff_size);
+  id_to_dep_type_.emplace(task_id, type);
   return ret.second;
 }
 
-bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
+bool TaskNode::AddDownstreamTask(int64_t task_id,
+                                 int64_t buff_size,
+                                 DependType type) {
   const auto& ret = downstream_.emplace(task_id, buff_size);
+  id_to_dep_type_.emplace(task_id, type);
   return ret.second;
 }
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h
index 8538ac9ff81faccac10f6c3dddd2d8f143268ccf..a9f474d93c5bd7cc98af882951337814b08ada3f 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.h
+++ b/paddle/fluid/distributed/fleet_executor/task_node.h
@@ -14,8 +14,10 @@
 #pragma once
 #include
+#include
 #include
 #include
+#include
 #include
 #include
 
@@ -29,6 +31,8 @@ class OpDesc;
 }  // namespace framework
 namespace distributed {
 
+enum class DependType { NORMAL, LOOP, STOP_LOOP };
+
 class TaskNode final {
  public:
   using OperatorBase = paddle::framework::OperatorBase;
@@ -61,6 +65,7 @@ class TaskNode final {
            int64_t rank,
            int64_t max_run_times,
            int64_t max_slot_nums);
+  ~TaskNode() = default;
 
   void SetProgram(paddle::framework::ProgramDesc* program);
 
@@ -74,6 +79,7 @@ class TaskNode final {
   int64_t run_at_offset() const { return run_at_offset_; }
   int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
   int64_t send_down_per_steps() const { return send_down_per_steps_; }
+  const std::string& cond_var() const { return cond_var_; }
   const std::unordered_map<int64_t, int64_t>& upstream() const {
     return upstream_;
   }
@@ -86,11 +92,20 @@ class TaskNode final {
   const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
     return ops_vec_;
   }
+  const std::unordered_map<int64_t, DependType> id_to_dep_type() const {
+    return id_to_dep_type_;
+  }
   const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
   unused_vars() const {
     return unused_vars_;
   }
+  const std::vector<std::string> while_block_vars() const {
+    return while_block_vars_;
+  }
+  void SetCondVarName(const std::string& cond_var_name) {
+    cond_var_ = cond_var_name;
+  }
 
   void SetRunPerSteps(int64_t value);
   void SetRunAtOffset(int64_t value);
   void SetReplyUpPerSteps(int64_t value);
@@ -101,10 +116,17 @@ class TaskNode final {
                          unused_vars) {
     unused_vars_ = unused_vars;
   }
+  void SetWhileBlockVars(const std::vector<std::string>& vars) {
+    while_block_vars_ = vars;
+  }
 
   // upstream need buffs?
-  bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
-  bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
+  bool AddUpstreamTask(int64_t task_id,
+                       int64_t buff_size = 1,
+                       DependType type = DependType::NORMAL);
+  bool AddDownstreamTask(int64_t task_id,
+                         int64_t buff_size = 1,
+                         DependType type = DependType::NORMAL);
 
   std::string DebugString() const;
 
  private:
@@ -115,10 +137,15 @@ class TaskNode final {
   // task_id-->buff_size
   std::unordered_map<int64_t, int64_t> upstream_;
   std::unordered_map<int64_t, int64_t> downstream_;
+  // task_id-->type
+  std::unordered_map<int64_t, DependType> id_to_dep_type_;
+
   framework::ProgramDesc* program_;
+  std::string cond_var_;
   std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
       unused_vars_;
+  std::vector<std::string> while_block_vars_;
 
   int32_t role_;
   int64_t rank_;
diff --git a/paddle/fluid/pybind/bind_fleet_executor.cc b/paddle/fluid/pybind/bind_fleet_executor.cc
index b4a6432e9e58b2ae8993651e8767731516302b0c..cd2341a219762c8853ac9d6d318293547eb0c8b1 100644
--- a/paddle/fluid/pybind/bind_fleet_executor.cc
+++ b/paddle/fluid/pybind/bind_fleet_executor.cc
@@ -65,6 +65,7 @@ struct npy_format_descriptor {
 namespace paddle {
 namespace pybind {
 
+using paddle::distributed::DependType;
 using paddle::distributed::DistModel;
 using paddle::distributed::DistModelConfig;
 using paddle::distributed::DistModelDataBuf;
@@ -164,6 +165,11 @@ void BindFleetExecutor(py::module* m) {
       .def(
           "run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());
 
+  py::enum_<DependType>(*m, "DependType")
+      .value("NORMAL", DependType::NORMAL)
+      .value("LOOP", DependType::LOOP)
+      .value("STOP_LOOP", DependType::STOP_LOOP);
+
   py::class_<TaskNode>(*m, "TaskNode")
       .def(py::init
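
Usage sketch (illustrative, not part of the patch): the DependType argument added to AddUpstreamTask/AddDownstreamTask is what lets CondInterceptor::PrepareDeps tell the loop-body edges apart from the loop-exit edge. Below is a minimal sketch of how a while-loop task graph might be wired against the interface above. The helper name, the node parameter names, and the use of DependType::LOOP for the back edge are assumptions; only AddUpstreamTask, AddDownstreamTask, SetCondVarName, and task_id come from this diff.

// Hypothetical helper showing how a Cond task node could be wired with the
// DependType-aware edges introduced in this patch.
#include <string>

#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

void WireWhileLoop(TaskNode* pre_loop_task,
                   TaskNode* cond_task,
                   TaskNode* loop_entry_task,
                   TaskNode* loop_tail_task,
                   TaskNode* after_loop_task,
                   const std::string& cond_var_name) {
  // CondInterceptor reads this bool variable from the microbatch scope on
  // every iteration (see GetCondResult).
  cond_task->SetCondVarName(cond_var_name);

  // NORMAL upstream: the task feeding the loop. CondInterceptor replies
  // DATA_IS_USELESS to it once the loop has finished.
  cond_task->AddUpstreamTask(pre_loop_task->task_id(),
                             /*buff_size=*/1,
                             DependType::NORMAL);

  // NORMAL downstream: the first task of the loop body, re-entered while the
  // condition variable stays true.
  cond_task->AddDownstreamTask(loop_entry_task->task_id(),
                               /*buff_size=*/1,
                               DependType::NORMAL);

  // Back edge from the last task of the loop body, tagged LOOP so that
  // PrepareDeps() does not collect it into normal_in_id_ (assumed usage of
  // the LOOP value; the diff itself only defines the enum).
  cond_task->AddUpstreamTask(loop_tail_task->task_id(),
                             /*buff_size=*/1,
                             DependType::LOOP);

  // STOP_LOOP downstream: receives DATA_IS_READY once GetCondResult()
  // returns false, i.e. when the loop exits.
  cond_task->AddDownstreamTask(after_loop_task->task_id(),
                               /*buff_size=*/1,
                               DependType::STOP_LOOP);

  // The reverse edges (e.g. pre_loop_task->AddDownstreamTask(...)) would be
  // registered symmetrically on the other nodes; omitted for brevity.
}

}  // namespace distributed
}  // namespace paddle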