From 0685b3ec302fb8298856e86ab0c3e1c5ca786dea Mon Sep 17 00:00:00 2001 From: iamsonderr <38247842+iamsonderr@users.noreply.github.com> Date: Fri, 7 Jul 2023 10:27:12 +0800 Subject: [PATCH] [Paddle Inference] del inplace op in memory_optimize_pass.cc (#55081) * commit * del inplace op in memory_optimize_pass.cc * check code style --- .../analysis/passes/memory_optimize_pass.cc | 120 ------------------ 1 file changed, 120 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index d6baea5e65c..9fe1ba1f15d 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -248,125 +248,6 @@ void MakeSimpleReusePlan( } } -// Remove the inplace operation from the plan because it does not support memory -// reuse -void DelInplaceOpFromPlan( - Graph* graph, - std::unordered_map* node2cluster, - int sort_kind) { - auto topo_nodes = TopologyVarientSort( - *graph, static_cast(sort_kind)); - for (auto* op_node : topo_nodes) { - if (!op_node->IsOp()) continue; - auto input_tensors = op_node->inputs; - auto output_tensors = op_node->outputs; - - std::unordered_set in_names; - for (const Node* node : input_tensors) { - if (!node->Var()) continue; - if (node->Var()->Persistable()) continue; - std::string var = node->Name(); - in_names.insert(var); - } - - for (const Node* node : output_tensors) { - if (!node->Var()) continue; - if (node->Var()->Persistable()) continue; - std::string var = node->Name(); - if (in_names.find(var) != in_names.end()) { - // delete key - if (node2cluster->count(var)) { - node2cluster->erase(var); - } - // delete value - std::string tmp_name = ""; - for (auto it = node2cluster->begin(); it != node2cluster->end(); ++it) { - if (it->second == var) { - if (tmp_name == "") { - tmp_name = it->first; - } - it->second = tmp_name; - } - } - } - } - } -} - -// NOTE The optimized opdesc doesn't match ir::Graph. -void UpdateOpDescsByReuse( - Graph* graph, - const std::unordered_map& reuse_table, - int sort_kind) { - // TODO(Superjomn) change here to be compatible with the runtime order. - for (auto* node : TopologyVarientSort( - *graph, static_cast(sort_kind))) { - if (node->IsOp()) { - // Replace the original inputs/outputs with the reused tensors. - std::unordered_map> in_args, - out_args; - for (auto argument : node->Op()->Inputs()) { - for (const auto& x : argument.second) { - auto name = x; - if (reuse_table.count(x) && reuse_table.at(x) != x) { - name = reuse_table.at(x); - } - in_args[argument.first].push_back(name); - VLOG(4) << node->Name() << " input " << x << " -> " << name; - } - } - - // modify the graph - for (auto input_node : node->inputs) { - PADDLE_ENFORCE_EQ(input_node->IsVar(), - true, - platform::errors::PreconditionNotMet( - "The input node should be a variable.")); - std::string input_node_name = input_node->Name(); - if (reuse_table.count(input_node_name) && - reuse_table.at(input_node_name) != input_node_name) { - auto name = reuse_table.at(input_node_name); - input_node->RenameVar(name); - } - } - - for (auto argument : node->Op()->Outputs()) { - for (const auto& x : argument.second) { - auto name = x; - if (reuse_table.count(x) && reuse_table.at(x) != x) { - name = reuse_table.at(x); - } - out_args[argument.first].push_back(name); - VLOG(4) << node->Name() << " output " << x << " -> " << name; - } - } - - // modify the graph - for (auto out_node : node->outputs) { - PADDLE_ENFORCE_EQ(out_node->IsVar(), - true, - platform::errors::PreconditionNotMet( - "The output node should be a variable.")); - std::string out_node_name = out_node->Name(); - if (reuse_table.count(out_node_name) && - reuse_table.at(out_node_name) != out_node_name) { - auto name = reuse_table.at(out_node_name); - out_node->RenameVar(name); - } - } - - // Update arguments. - for (auto& arg : in_args) { - node->Op()->SetInput(arg.first, arg.second); - } - for (auto& arg : out_args) { - node->Op()->SetOutput(arg.first, arg.second); - } - node->Op()->Flush(); - } - } -} - std::string MemoryOptimizePass::repr() const { return "memory_optimize_pass"; } void MemoryOptimizePass::RunImpl(Argument* argument) { @@ -395,7 +276,6 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { CollectLifeCycle(graph, &lifecycles, sort_kind); CollectVarMemorySize(graph, &space_table); MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size); - DelInplaceOpFromPlan(graph, &node2cluster, sort_kind); auto* pass_res_info = PassResultInfoForRuntime::Instance(); pass_res_info->Set( -- GitLab