提交 e227d093 编写于 作者: L LiYuRio 提交者: GitHub

Revert "fix peak memory (#52175)"

This reverts commit bd3b6adf.
上级 6f3c9643
...@@ -110,12 +110,15 @@ void PreventVarsDelete( ...@@ -110,12 +110,15 @@ void PreventVarsDelete(
std::vector<std::string> GetUnusedVarsAfterWhile( std::vector<std::string> GetUnusedVarsAfterWhile(
const framework::ProgramDesc& program_desc, const framework::ProgramDesc& program_desc,
TaskNode* cond_task,
const std::vector<std::string>& vars_not_gc) { const std::vector<std::string>& vars_not_gc) {
// NOTE: Since while op won't appear in task node, in order to analyze // NOTE: Since while op won't appear in task node, in order to analyze
// the vars which should be free after calling while op, we rebuild the // the vars which should be free after calling while op, we rebuild the
// whole program and get the unused vars after calling while op. // whole program and get the unused vars after calling while op.
// vars in parent block should not be free until the while op is finished. // The vars in while block should not be free until the while op is finished.
// The local vars will be free while running op in sub block. // In a word, the vars need to be free after while op is:
// 1. Vars in parent block and being used in while block.
// 2. Local vars only defined in while block.
// The unused vars above will be free in cond interceptor. // The unused vars above will be free in cond interceptor.
std::vector<std::string> while_block_vars; std::vector<std::string> while_block_vars;
std::vector<std::unique_ptr<framework::OperatorBase>> ops; std::vector<std::unique_ptr<framework::OperatorBase>> ops;
...@@ -129,29 +132,14 @@ std::vector<std::string> GetUnusedVarsAfterWhile( ...@@ -129,29 +132,14 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
for (const auto& var_name : pair.second) { for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name); while_block_vars.emplace_back(var_name);
} }
for (auto& var : program_desc.Block(1).AllVars()) {
while_block_vars.emplace_back(var->Name());
}
} }
} }
return while_block_vars; return while_block_vars;
} }
// Computes which variables become unused after each op of the given sub-block
// tasks, by running framework::GetUnusedVars over block 1 of the program
// (presumably the while sub block -- confirm against the caller).
//
// program_desc:    program whose Block(1) is analyzed.
// sub_block_tasks: task nodes whose ops are collected for the analysis.
// vars_not_gc:     names handed to PreventVarsDelete so that they are kept
//                  out of the returned result.
// Returns the per-op unused-variable map after the vars_not_gc filtering.
std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
GetSubUnusedVars(const framework::ProgramDesc& program_desc,
                 const std::set<TaskNode*>& sub_block_tasks,
                 const std::vector<std::string>& vars_not_gc) {
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;

  // The ops are only borrowed here: they are wrapped in unique_ptr solely
  // because framework::GetUnusedVars is called with such a vector. The guard
  // releases every wrapper on all exit paths (including exceptions), so this
  // function never deletes an op it does not own.
  struct ReleaseOnExit {
    std::vector<std::unique_ptr<framework::OperatorBase>>* wrapped;
    ~ReleaseOnExit() {
      for (auto& borrowed_op : *wrapped) {
        borrowed_op.release();
      }
    }
  } release_guard{&ops};

  for (auto* task_node : sub_block_tasks) {
    for (const auto& op : task_node->ops()) {
      // Construct the wrapper in place so no temporary unique_ptr can delete
      // the op if growing the vector fails.
      ops.emplace_back(op);
    }
  }

  auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
  PreventVarsDelete(&unused_vars, vars_not_gc);
  return unused_vars;
}
} // namespace } // namespace
void FleetExecutor::Init( void FleetExecutor::Init(
...@@ -174,13 +162,8 @@ void FleetExecutor::Init( ...@@ -174,13 +162,8 @@ void FleetExecutor::Init(
for (const auto& task_node : task_nodes) { for (const auto& task_node : task_nodes) {
if (task_node->type() == "Cond") { if (task_node->type() == "Cond") {
GetSubBlockTask(task_nodes, task_node, &sub_block_tasks); GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
while_block_vars = while_block_vars = GetUnusedVarsAfterWhile(
GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars); program_desc, task_node, inference_root_scope_vars);
for (auto* task_node : sub_block_tasks) {
for (auto iter : task_node->vars_to_dtype()) {
while_block_vars.emplace_back(iter.first);
}
}
VLOG(3) << "Vars will be gced after while op"; VLOG(3) << "Vars will be gced after while op";
for (auto var : while_block_vars) { for (auto var : while_block_vars) {
VLOG(3) << var; VLOG(3) << var;
...@@ -210,9 +193,6 @@ void FleetExecutor::Init( ...@@ -210,9 +193,6 @@ void FleetExecutor::Init(
unique_op.release(); unique_op.release();
} }
auto sub_unused_vars =
GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars);
// NOTE: For inference, the vars in inference_root_scope_vars // NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inf, for that they may be the result of the // shouldn't be deleted during inf, for that they may be the result of the
// inf. If they are GCed, it will cause error during ZeroCopy the result. // inf. If they are GCed, it will cause error during ZeroCopy the result.
...@@ -223,8 +203,6 @@ void FleetExecutor::Init( ...@@ -223,8 +203,6 @@ void FleetExecutor::Init(
for (auto task_node : task_nodes) { for (auto task_node : task_nodes) {
if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) { if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars); task_node->SetUnusedVars(global_unused_vars);
} else {
task_node->SetUnusedVars(sub_unused_vars);
} }
int64_t interceptor_id = task_node->task_id(); int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node); interceptor_id_to_task.emplace(interceptor_id, task_node);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册