From e227d0939124a97766c9ed373991b82bd419df8d Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 18 Apr 2023 17:07:49 +0800 Subject: [PATCH] Revert "fix peak memory (#52175)" This reverts commit bd3b6adf2587dd13f69ea7ff38fcde91addcb57a. --- .../fleet_executor/fleet_executor.cc | 42 +++++-------------- 1 file changed, 10 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 05f75ad79ce..915b1f82804 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -110,12 +110,15 @@ void PreventVarsDelete( std::vector GetUnusedVarsAfterWhile( const framework::ProgramDesc& program_desc, + TaskNode* cond_task, const std::vector& vars_not_gc) { // NOTE: Since while op won't appear in task node, in order to analyze // the vars which should be free after calling while op, we rebuild the // whole program and get the unused vars after calling while op. - // vars in parent block should not be free until the while op is finished. - // The local vars will be free while running op in sub block. + // The vars in while block should not be free until the while op is finished. + // In a word, the vars need to be free after while op is: + // 1. Vars in parent block and being used in while block. + // 2. Local vars only defined in while block. // The unused vars above will be free in cond interceptor. std::vector while_block_vars; std::vector> ops; @@ -129,29 +132,14 @@ std::vector GetUnusedVarsAfterWhile( for (const auto& var_name : pair.second) { while_block_vars.emplace_back(var_name); } + for (auto& var : program_desc.Block(1).AllVars()) { + while_block_vars.emplace_back(var->Name()); + } } } return while_block_vars; } -std::unordered_map> -GetSubUnusedVars(const framework::ProgramDesc& program_desc, - const std::set& sub_block_tasks, - const std::vector& vars_not_gc) { - std::vector> ops; - for (auto* task_node : sub_block_tasks) { - for (const auto& op : task_node->ops()) { - ops.emplace_back(std::unique_ptr(op)); - } - } - auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {}); - for (auto& unique_op : ops) { - unique_op.release(); - } - PreventVarsDelete(&unused_vars, vars_not_gc); - return unused_vars; -} - } // namespace void FleetExecutor::Init( @@ -174,13 +162,8 @@ void FleetExecutor::Init( for (const auto& task_node : task_nodes) { if (task_node->type() == "Cond") { GetSubBlockTask(task_nodes, task_node, &sub_block_tasks); - while_block_vars = - GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars); - for (auto* task_node : sub_block_tasks) { - for (auto iter : task_node->vars_to_dtype()) { - while_block_vars.emplace_back(iter.first); - } - } + while_block_vars = GetUnusedVarsAfterWhile( + program_desc, task_node, inference_root_scope_vars); VLOG(3) << "Vars will be gced after while op"; for (auto var : while_block_vars) { VLOG(3) << var; @@ -210,9 +193,6 @@ void FleetExecutor::Init( unique_op.release(); } - auto sub_unused_vars = - GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars); - // NOTE: For inference, the vars in inference_root_scope_vars // shouldn't be deleted during inf, for that they may be the result of the // inf. If they are GCed, it will cause error during ZeroCopy the result. @@ -223,8 +203,6 @@ void FleetExecutor::Init( for (auto task_node : task_nodes) { if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) { task_node->SetUnusedVars(global_unused_vars); - } else { - task_node->SetUnusedVars(sub_unused_vars); } int64_t interceptor_id = task_node->task_id(); interceptor_id_to_task.emplace(interceptor_id, task_node); -- GitLab