diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc index 4a4c83baef6c055320327409f2d8008a35f2f875..6956e805c673d8776d7bdd414dce0a5ddfcd965a 100644 --- a/lite/core/mir/memory_optimize_pass.cc +++ b/lite/core/mir/memory_optimize_pass.cc @@ -49,7 +49,9 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( "equal", "lod_reset", "concat", - "graph_op"}; + "graph_op", + "feed", + "fetch"}; for (auto* tmp : node->inlinks) { CHECK(tmp->IsStmt()); std::string op_type = tmp->AsStmt().op_info()->Type(); @@ -76,36 +78,23 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( std::vector requires(inputs.begin(), inputs.end()); requires.insert(requires.end(), outputs.begin(), outputs.end()); auto& stmt = op_node->AsStmt(); - // The feed and fetch op's inputs and outputs will not be reused. - if (stmt.op_info()->Type() == "feed" || - stmt.op_info()->Type() == "fetch") { - for (auto* node : op_node->outlinks) { - CHECK(node->IsArg()); - std::string var_name = node->AsArg().name; - TargetType target_type = node->AsArg().type->target(); - if (is_host(target_type)) target_type = TARGET(kHost); - (*lifecycles)[TargetToStr(target_type)].emplace( - var_name, std::make_pair(0, std::numeric_limits::max())); - } - } else { - for (Node* node : requires) { - CHECK(node->IsArg()); - auto& arg = node->AsArg(); - if (arg.is_weight || arg.is_persist) continue; - if (!valid_var(node)) continue; - std::string var_name = arg.name; - TargetType target_type = node->AsArg().type->target(); - if (is_host(target_type)) target_type = TARGET(kHost); + for (Node* node : requires) { + CHECK(node->IsArg()); + auto& arg = node->AsArg(); + if (arg.is_weight || arg.is_persist) continue; + if (!valid_var(node)) continue; + std::string var_name = arg.name; + TargetType target_type = node->AsArg().type->target(); + if (is_host(target_type)) target_type = TARGET(kHost); - if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) { - (*lifecycles)[TargetToStr(target_type)].emplace( - var_name, std::make_pair(max_lifecycle_, max_lifecycle_)); - } else { - int cur_life = - (*lifecycles)[TargetToStr(target_type)][var_name].second; - (*lifecycles)[TargetToStr(target_type)][var_name].second = - std::max(max_lifecycle_, cur_life); - } + if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) { + (*lifecycles)[TargetToStr(target_type)].emplace( + var_name, std::make_pair(max_lifecycle_, max_lifecycle_)); + } else { + int cur_life = + (*lifecycles)[TargetToStr(target_type)][var_name].second; + (*lifecycles)[TargetToStr(target_type)][var_name].second = + std::max(max_lifecycle_, cur_life); } } ++max_lifecycle_; @@ -167,6 +156,7 @@ void MemoryOptimizePass::MakeReusePlan( void MemoryOptimizePass::PerformReusePlan( SSAGraph* graph, const std::unordered_map& reuse_table) { + int node_append_idx = 0; for (auto& op_node : graph->StmtTopologicalOrder()) { if (!op_node->IsStmt()) continue; auto& stmt = op_node->AsStmt(); @@ -190,7 +180,9 @@ void MemoryOptimizePass::PerformReusePlan( std::string name = input_node->AsArg().name; if (reuse_table.count(name) && reuse_table.at(name) != name) { auto replace_name = reuse_table.at(name); - input_node->AsArg().name = replace_name; + input_node->AsArg().name = + replace_name + "(" + std::to_string(node_append_idx) + ")"; + node_append_idx++; } } @@ -212,7 +204,9 @@ void MemoryOptimizePass::PerformReusePlan( std::string name = out_node->AsArg().name; if (reuse_table.count(name) && reuse_table.at(name) != name) { auto replace_name = reuse_table.at(name); - out_node->AsArg().name = replace_name; + out_node->AsArg().name = + replace_name + "(" + std::to_string(node_append_idx) + ")"; + node_append_idx++; } } diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc index 4ab42d4d2129313220598d3ebc5f3cf7757308b2..18280616aa00b734596b620727f6dcfd5beb67d7 100644 --- a/lite/operators/split_op.cc +++ b/lite/operators/split_op.cc @@ -69,6 +69,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { auto input = opdesc.Input("X").front(); auto outs = opdesc.Output("Out"); param_.x = scope->FindVar(input)->GetMutable(); + param_.output.clear(); for (auto var : outs) { param_.output.push_back(scope->FindVar(var)->GetMutable()); }