提交 b7f5d94b 编写于 作者: Z Zhaolong Xing 提交者: GitHub

1. The split op's bug would trigger a failure in the memory optimize pass. (#2070)

test=develop
上级 4a948cfc
...@@ -49,7 +49,9 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( ...@@ -49,7 +49,9 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
"equal", "equal",
"lod_reset", "lod_reset",
"concat", "concat",
"graph_op"}; "graph_op",
"feed",
"fetch"};
for (auto* tmp : node->inlinks) { for (auto* tmp : node->inlinks) {
CHECK(tmp->IsStmt()); CHECK(tmp->IsStmt());
std::string op_type = tmp->AsStmt().op_info()->Type(); std::string op_type = tmp->AsStmt().op_info()->Type();
...@@ -76,36 +78,23 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( ...@@ -76,36 +78,23 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
std::vector<Node*> requires(inputs.begin(), inputs.end()); std::vector<Node*> requires(inputs.begin(), inputs.end());
requires.insert(requires.end(), outputs.begin(), outputs.end()); requires.insert(requires.end(), outputs.begin(), outputs.end());
auto& stmt = op_node->AsStmt(); auto& stmt = op_node->AsStmt();
// The feed and fetch op's inputs and outputs will not be reused. for (Node* node : requires) {
if (stmt.op_info()->Type() == "feed" || CHECK(node->IsArg());
stmt.op_info()->Type() == "fetch") { auto& arg = node->AsArg();
for (auto* node : op_node->outlinks) { if (arg.is_weight || arg.is_persist) continue;
CHECK(node->IsArg()); if (!valid_var(node)) continue;
std::string var_name = node->AsArg().name; std::string var_name = arg.name;
TargetType target_type = node->AsArg().type->target(); TargetType target_type = node->AsArg().type->target();
if (is_host(target_type)) target_type = TARGET(kHost); if (is_host(target_type)) target_type = TARGET(kHost);
(*lifecycles)[TargetToStr(target_type)].emplace(
var_name, std::make_pair(0, std::numeric_limits<int>::max()));
}
} else {
for (Node* node : requires) {
CHECK(node->IsArg());
auto& arg = node->AsArg();
if (arg.is_weight || arg.is_persist) continue;
if (!valid_var(node)) continue;
std::string var_name = arg.name;
TargetType target_type = node->AsArg().type->target();
if (is_host(target_type)) target_type = TARGET(kHost);
if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) { if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
(*lifecycles)[TargetToStr(target_type)].emplace( (*lifecycles)[TargetToStr(target_type)].emplace(
var_name, std::make_pair(max_lifecycle_, max_lifecycle_)); var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
} else { } else {
int cur_life = int cur_life =
(*lifecycles)[TargetToStr(target_type)][var_name].second; (*lifecycles)[TargetToStr(target_type)][var_name].second;
(*lifecycles)[TargetToStr(target_type)][var_name].second = (*lifecycles)[TargetToStr(target_type)][var_name].second =
std::max(max_lifecycle_, cur_life); std::max(max_lifecycle_, cur_life);
}
} }
} }
++max_lifecycle_; ++max_lifecycle_;
...@@ -167,6 +156,7 @@ void MemoryOptimizePass::MakeReusePlan( ...@@ -167,6 +156,7 @@ void MemoryOptimizePass::MakeReusePlan(
void MemoryOptimizePass::PerformReusePlan( void MemoryOptimizePass::PerformReusePlan(
SSAGraph* graph, SSAGraph* graph,
const std::unordered_map<std::string, std::string>& reuse_table) { const std::unordered_map<std::string, std::string>& reuse_table) {
int node_append_idx = 0;
for (auto& op_node : graph->StmtTopologicalOrder()) { for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue; if (!op_node->IsStmt()) continue;
auto& stmt = op_node->AsStmt(); auto& stmt = op_node->AsStmt();
...@@ -190,7 +180,9 @@ void MemoryOptimizePass::PerformReusePlan( ...@@ -190,7 +180,9 @@ void MemoryOptimizePass::PerformReusePlan(
std::string name = input_node->AsArg().name; std::string name = input_node->AsArg().name;
if (reuse_table.count(name) && reuse_table.at(name) != name) { if (reuse_table.count(name) && reuse_table.at(name) != name) {
auto replace_name = reuse_table.at(name); auto replace_name = reuse_table.at(name);
input_node->AsArg().name = replace_name; input_node->AsArg().name =
replace_name + "(" + std::to_string(node_append_idx) + ")";
node_append_idx++;
} }
} }
...@@ -212,7 +204,9 @@ void MemoryOptimizePass::PerformReusePlan( ...@@ -212,7 +204,9 @@ void MemoryOptimizePass::PerformReusePlan(
std::string name = out_node->AsArg().name; std::string name = out_node->AsArg().name;
if (reuse_table.count(name) && reuse_table.at(name) != name) { if (reuse_table.count(name) && reuse_table.at(name) != name) {
auto replace_name = reuse_table.at(name); auto replace_name = reuse_table.at(name);
out_node->AsArg().name = replace_name; out_node->AsArg().name =
replace_name + "(" + std::to_string(node_append_idx) + ")";
node_append_idx++;
} }
} }
......
...@@ -69,6 +69,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -69,6 +69,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto input = opdesc.Input("X").front(); auto input = opdesc.Input("X").front();
auto outs = opdesc.Output("Out"); auto outs = opdesc.Output("Out");
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>(); param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.output.clear();
for (auto var : outs) { for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>()); param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册