diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc old mode 100644 new mode 100755 index 2ff82986e945caf3ecd0ee91bac02c9a9ad48272..40a8c5ce66a2a5b7c5f54784abdcbdc2c9e3e531 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -222,6 +222,51 @@ void MakeSimpleReusePlan( } } +// Remove the inplace operation from the plan because it does not support memory +// reuse +void DelInplaceOpFromPlan( + Graph* graph, + std::unordered_map* node2cluster, + int sort_kind) { + auto topo_nodes = TopologyVarientSort( + *graph, static_cast(sort_kind)); + for (auto* op_node : topo_nodes) { + if (!op_node->IsOp()) continue; + auto input_tensors = op_node->inputs; + auto output_tensors = op_node->outputs; + + std::unordered_set in_names; + for (const Node* node : input_tensors) { + if (!node->Var()) continue; + if (node->Var()->Persistable()) continue; + std::string var = node->Name(); + in_names.insert(var); + } + + for (const Node* node : output_tensors) { + if (!node->Var()) continue; + if (node->Var()->Persistable()) continue; + std::string var = node->Name(); + if (in_names.find(var) != in_names.end()) { + // delete key + if (node2cluster->count(var)) { + node2cluster->erase(var); + } + // delete value + std::string tmp_name = ""; + for (auto it = node2cluster->begin(); it != node2cluster->end(); ++it) { + if (it->second == var) { + if (tmp_name == "") { + tmp_name = it->first; + } + it->second = tmp_name; + } + } + } + } + } +} + // NOTE The optimized opdesc doesn't match ir::Graph. void UpdateOpDescsByReuse( Graph* graph, @@ -324,6 +369,7 @@ void MemoryOptimizePass::RunImpl(Argument* argument) { CollectLifeCycle(graph, &lifecycles, sort_kind); CollectVarMemorySize(graph, &space_table); MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size); + DelInplaceOpFromPlan(graph, &node2cluster, sort_kind); auto* pass_res_info = PassResultInfoForRuntime::Instance(); pass_res_info->Set(