From 438946b87379c683b3f8541a3c46520d56316e3c Mon Sep 17 00:00:00 2001 From: DannyIsFunny <912790387@qq.com> Date: Wed, 19 Aug 2020 13:01:14 +0000 Subject: [PATCH] test=develop --- lite/core/mir/memory_optimize_pass.cc | 36 +++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc index eddbebb545..a0c2bb9c33 100644 --- a/lite/core/mir/memory_optimize_pass.cc +++ b/lite/core/mir/memory_optimize_pass.cc @@ -127,10 +127,16 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( } } + std::vector> lifecycles_dims; + std::map lifecycles_tmp; + for (auto& op_node : graph->StmtTopologicalOrder()) { if (op_node->IsStmt()) { std::vector var_nodes(op_node->inlinks.begin(), op_node->inlinks.end()); + // Current Scope + auto* scope = op_node->AsStmt().op()->scope(); + // Collect var nodes var_nodes.insert( var_nodes.end(), op_node->outlinks.begin(), op_node->outlinks.end()); for (auto* var_node : var_nodes) { @@ -141,20 +147,40 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( if (invalid_var_names.count(var_name)) continue; TargetType target_type = arg.type->target(); if (is_host(target_type)) target_type = TARGET(kHost); - - if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) { - (*lifecycles)[TargetToStr(target_type)].emplace( + // Collect var dims + auto var = scope->FindVar(var_name); + int var_dims = var->Get().numel(); + if (var_dims < 0) { + var_dims = -4 * var_dims; + } + // Calculate lifecycle + if (!(lifecycles_tmp)[TargetToStr(target_type)].count(var_name)) { + (lifecycles_tmp)[TargetToStr(target_type)].emplace( var_name, std::make_pair(max_lifecycle_, max_lifecycle_)); + lifecycles_dims.push_back(std::make_pair(var_name, var_dims)); } else { int cur_life = - (*lifecycles)[TargetToStr(target_type)][var_name].second; - (*lifecycles)[TargetToStr(target_type)][var_name].second = + (lifecycles_tmp)[TargetToStr(target_type)][var_name].second; + (lifecycles_tmp)[TargetToStr(target_type)][var_name].second = std::max(max_lifecycle_, cur_life); } } ++max_lifecycle_; } } + // sort nodes according to their dims + sort(lifecycles_dims.begin(), + lifecycles_dims.end(), + [](const std::pair& x, + const std::pair& y) -> int { + return x.second > y.second; + }); + for (auto it = lifecycles_dims.begin(); it != lifecycles_dims.end(); it++) { + auto node_name = it->first; + (*lifecycles)[TargetToStr(TARGET(kHost))].emplace( + node_name, (lifecycles_tmp)[TargetToStr(TARGET(kHost))][node_name]); + } + LOG(INFO) << "There are " << (*lifecycles).size() << " types device var."; } -- GitLab