提交 438946b8 编写于 作者: D DannyIsFunny

test=develop

上级 17f00635
......@@ -127,10 +127,16 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
}
}
std::vector<std::pair<std::string, int>> lifecycles_dims;
std::map<std::string, lifecycle_map_t> lifecycles_tmp;
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (op_node->IsStmt()) {
std::vector<Node*> var_nodes(op_node->inlinks.begin(),
op_node->inlinks.end());
// Current Scope
auto* scope = op_node->AsStmt().op()->scope();
// Collect var nodes
var_nodes.insert(
var_nodes.end(), op_node->outlinks.begin(), op_node->outlinks.end());
for (auto* var_node : var_nodes) {
......@@ -141,20 +147,40 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
if (invalid_var_names.count(var_name)) continue;
TargetType target_type = arg.type->target();
if (is_host(target_type)) target_type = TARGET(kHost);
if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
(*lifecycles)[TargetToStr(target_type)].emplace(
// Collect var dims
auto var = scope->FindVar(var_name);
int var_dims = var->Get<lite::Tensor>().numel();
if (var_dims < 0) {
var_dims = -4 * var_dims;
}
// Calculate lifecycle
if (!(lifecycles_tmp)[TargetToStr(target_type)].count(var_name)) {
(lifecycles_tmp)[TargetToStr(target_type)].emplace(
var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
lifecycles_dims.push_back(std::make_pair(var_name, var_dims));
} else {
int cur_life =
(*lifecycles)[TargetToStr(target_type)][var_name].second;
(*lifecycles)[TargetToStr(target_type)][var_name].second =
(lifecycles_tmp)[TargetToStr(target_type)][var_name].second;
(lifecycles_tmp)[TargetToStr(target_type)][var_name].second =
std::max(max_lifecycle_, cur_life);
}
}
++max_lifecycle_;
}
}
// sort nodes according to their dims
sort(lifecycles_dims.begin(),
lifecycles_dims.end(),
[](const std::pair<std::string, int>& x,
const std::pair<std::string, int>& y) -> int {
return x.second > y.second;
});
for (auto it = lifecycles_dims.begin(); it != lifecycles_dims.end(); it++) {
auto node_name = it->first;
(*lifecycles)[TargetToStr(TARGET(kHost))].emplace(
node_name, (lifecycles_tmp)[TargetToStr(TARGET(kHost))][node_name]);
}
LOG(INFO) << "There are " << (*lifecycles).size() << " types device var.";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册