From aa96ddc303c05b581c9fc1dd36501503faeed91b Mon Sep 17 00:00:00 2001
From: gem5 <117625383+linsheng011@users.noreply.github.com>
Date: Sun, 1 Jan 2023 11:05:10 +0800
Subject: [PATCH] memorty_optimize remove inplace op (#49431)

---
 .../analysis/passes/memory_optimize_pass.cc   | 46 +++++++++++++++++++
 1 file changed, 46 insertions(+)
 mode change 100644 => 100755 paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc

diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
old mode 100644
new mode 100755
index 2ff82986e9..40a8c5ce66
--- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
@@ -222,6 +222,51 @@ void MakeSimpleReusePlan(
   }
 }
 
+// Remove in-place variables from the reuse plan because in-place operations do
+// not support memory reuse. A variable counts as in-place when the same
+// non-persistable name is both an input and an output of one op. Its entry is
+// erased from `node2cluster`; entries that were clustered under its name are
+// re-pointed at one of themselves. `sort_kind` is cast to
+// framework::ir::SortKind and passed to TopologyVarientSort.
+void DelInplaceOpFromPlan(
+    Graph* graph,
+    std::unordered_map<std::string, std::string>* node2cluster,
+    int sort_kind) {
+  // Nodes without a VarDesc and persistable variables are ignored.
+  auto is_candidate = [](const Node* node) {
+    return node->Var() && !node->Var()->Persistable();
+  };
+
+  auto topo_nodes = TopologyVarientSort(
+      *graph, static_cast<framework::ir::SortKind>(sort_kind));
+  for (auto* op_node : topo_nodes) {
+    if (!op_node->IsOp()) continue;
+
+    std::unordered_set<std::string> in_names;
+    for (const Node* node : op_node->inputs) {
+      if (is_candidate(node)) in_names.insert(node->Name());
+    }
+
+    for (const Node* node : op_node->outputs) {
+      if (!is_candidate(node)) continue;
+      const std::string& var = node->Name();
+      if (in_names.count(var) == 0) continue;
+
+      // Drop the in-place variable's own entry (no-op when it is absent).
+      node2cluster->erase(var);
+
+      // Entries still clustered under `var` need a new cluster name: use the
+      // key of the first such entry found, which then maps to itself.
+      std::string new_cluster;
+      for (auto& item : *node2cluster) {
+        if (item.second != var) continue;
+        if (new_cluster.empty()) new_cluster = item.first;
+        item.second = new_cluster;
+      }
+    }
+  }
+}
+
 // NOTE The optimized opdesc doesn't match ir::Graph.
 void UpdateOpDescsByReuse(
     Graph* graph,
@@ -324,6 +369,7 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
   CollectLifeCycle(graph, &lifecycles, sort_kind);
   CollectVarMemorySize(graph, &space_table);
   MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
+  DelInplaceOpFromPlan(graph, &node2cluster, sort_kind);
 
   auto* pass_res_info = PassResultInfoForRuntime::Instance();
   pass_res_info->Set(
-- 
GitLab