diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index ae3776d2c5beacbccc7d63f05aff7882a9b2440a..ec95f2a14623e02f59b5d3576ab4a80dcba7424f 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include +#include #include #include "paddle/fluid/distributed/fleet_executor/global.h" @@ -53,40 +54,40 @@ FleetExecutor::~FleetExecutor() { } } -void FleetExecutor::Init( - const std::string& carrier_id, - const framework::ProgramDesc& program_desc, - framework::Scope* scope, - const platform::Place& place, - int64_t num_micro_batches, - const std::vector& task_nodes, - const std::unordered_map& task_id_to_rank, - const std::vector& inference_root_scope_vars, - const std::vector& micro_scope_list) { - PADDLE_ENFORCE_GT(task_nodes.size(), - 0, - platform::errors::InvalidArgument( - "Fleet executor is inited with empty task node")); - // TODO(fleet_exe devs): the unused_vars should be got from run time graph - std::vector> ops; - for (const auto& desc : program_desc.Block(0).AllOps()) { - ops.emplace_back(framework::OpRegistry::CreateOp(*desc)); +namespace { +void GetSubBlockTask(const std::vector& tasks, + TaskNode* cur_task, + std::set* sub_block_task) { + auto& downstream = cur_task->downstream(); + auto& id_to_dep_type = cur_task->id_to_dep_type(); + for (auto& down : downstream) { + int64_t task_id = down.first; + if (id_to_dep_type.at(task_id) == DependType::NORMAL) { + for (const auto& task : tasks) { + if (task->task_id() == task_id) { + sub_block_task->emplace(task); + GetSubBlockTask(tasks, task, sub_block_task); + } + } + } } - auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); +} - // NOTE: For inference, the vars in inference_root_scope_vars - // shouldn't be deleted during inf, for that they may be the result of the - // inf. If they are GCed, it will cause error during ZeroCopy the result. +void PreventVarsDelete( + std::unordered_map>* unused_vars, + const std::vector& vars_not_gc) { std::vector changed_ops; - for (auto pair : unused_vars) { + + for (const auto& pair : *unused_vars) { const framework::OperatorBase* op = pair.first; - std::vector unused = pair.second; - for (auto name : inference_root_scope_vars) { - auto iter = std::find(unused.begin(), unused.end(), name); - if (iter != unused.end()) { + std::vector cur_unused = pair.second; + for (auto name : vars_not_gc) { + auto iter = std::find(cur_unused.begin(), cur_unused.end(), name); + if (iter != cur_unused.end()) { VLOG(3) << "Removing var: [" << name << "] from the unused vars list of op: [" << op->Type() << "]"; - unused.erase(iter); + cur_unused.erase(iter); if (std::find(changed_ops.begin(), changed_ops.end(), op) == changed_ops.end()) { // record the op whose unused vars have been updated @@ -95,48 +96,118 @@ void FleetExecutor::Init( } } // update the unused vars list in the map - unused_vars[op] = unused; + unused_vars->at(op) = cur_unused; } for (auto op : changed_ops) { - auto iter = unused_vars.find(op); + const auto& iter = unused_vars->find(op); if (iter->second.empty()) { // remove those ops in the map that have empty unused vars list VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map."; - unused_vars.erase(iter); + unused_vars->erase(iter); } } - runtime_graph_ = std::make_shared(); - std::unordered_map interceptor_id_to_task; - for (auto task_node : task_nodes) { - task_node->SetUnusedVars(unused_vars); - if (task_node->type() == "Cond") { - std::vector while_block_vars; - VLOG(3) << "Vars in while sub block:"; - for (auto& var : program_desc.Block(1).AllVars()) { - VLOG(3) << var->Name(); - while_block_vars.emplace_back(var->Name()); - } - for (const auto& pair : unused_vars) { - if (pair.first->Type() == "while") { - for (const auto& var_name : pair.second) { - while_block_vars.emplace_back(var_name); - } - } +} + +std::vector GetUnusedVarsAfterWhile( + const framework::ProgramDesc& program_desc, + TaskNode* cond_task, + const std::vector vars_not_gc) { + std::vector while_block_vars; + std::vector> ops; + for (const auto& desc : program_desc.Block(0).AllOps()) { + ops.emplace_back(framework::OpRegistry::CreateOp(*desc)); + } + auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); + PreventVarsDelete(&unused_vars, vars_not_gc); + for (const auto& pair : unused_vars) { + if (pair.first->Type() == "while") { + for (const auto& var_name : pair.second) { + while_block_vars.emplace_back(var_name); } - VLOG(3) << "Vars below will be removed after while:"; - for (const auto& name : while_block_vars) { - VLOG(3) << name; + } + } + return while_block_vars; +} + +} // namespace + +void FleetExecutor::Init( + const std::string& carrier_id, + const framework::ProgramDesc& program_desc, + framework::Scope* scope, + const platform::Place& place, + int64_t num_micro_batches, + const std::vector& task_nodes, + const std::unordered_map& task_id_to_rank, + const std::vector& inference_root_scope_vars, + const std::vector& micro_scope_list) { + PADDLE_ENFORCE_GT(task_nodes.size(), + 0, + platform::errors::InvalidArgument( + "Fleet executor is inited with empty task node")); + // Set the unused var after running while op + std::set sub_block_tasks; + std::vector while_block_vars; + for (const auto& task_node : task_nodes) { + if (task_node->type() == "Cond") { + GetSubBlockTask(task_nodes, task_node, &sub_block_tasks); + while_block_vars = GetUnusedVarsAfterWhile( + program_desc, task_node, inference_root_scope_vars); + VLOG(3) << "Vars will be gced after while op"; + for (auto var : while_block_vars) { + VLOG(3) << var; } task_node->SetWhileBlockVars(while_block_vars); } + } + std::vector sub_block_ops; + for (const auto& task_node : sub_block_tasks) { + for (const auto& op : task_node->ops()) { + sub_block_ops.emplace_back(op); + } + } + // Analyse the unused vars in block 0. The operators in block 1 + // should be passed in first for prevent vars been released but removed soon. + // Since the unused vars in block 1 need to analyse separately. + std::vector> ops; + for (const auto& task_node : task_nodes) { + for (const auto& op : task_node->ops()) { + ops.emplace_back(std::unique_ptr(op)); + } + } + auto global_unused_vars = + framework::GetUnusedVars(program_desc.Block(0), ops, {}); + + // Analyse the unused vars in block 1. + std::unordered_map> + sub_unused_vars; + if (program_desc.Size() > 1) { + sub_unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {}); + PreventVarsDelete(&sub_unused_vars, while_block_vars); + } + for (auto& unique_op : ops) { + unique_op.release(); + } + + // NOTE: For inference, the vars in inference_root_scope_vars + // shouldn't be deleted during inf, for that they may be the result of the + // inf. If they are GCed, it will cause error during ZeroCopy the result. + PreventVarsDelete(&global_unused_vars, inference_root_scope_vars); + + runtime_graph_ = std::make_shared(); + std::unordered_map interceptor_id_to_task; + for (auto task_node : task_nodes) { + if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) { + task_node->SetUnusedVars(global_unused_vars); + } else { + task_node->SetUnusedVars(sub_unused_vars); + } int64_t interceptor_id = task_node->task_id(); interceptor_id_to_task.emplace(interceptor_id, task_node); } runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); - for (auto& unique_op : ops) { - unique_op.release(); - } + VLOG(5) << runtime_graph_->DebugString(); Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id);