提交 36d6e416 编写于 作者: H hong19860320 提交者: GitHub

[Core] Fix memory_optmize_pass for reshape/reshape2 op with inplace=True (#3045)

上级 09f1ec4d
...@@ -39,9 +39,16 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( ...@@ -39,9 +39,16 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
auto is_host = [](TargetType x) -> bool { auto is_host = [](TargetType x) -> bool {
return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM); return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
}; };
// The vars which inputs or outputs are invalid op will not be reused.
auto valid_var = [&](Node* node) -> bool { // Collect the invalid input and output variables that will not be reused.
std::set<std::string> invalid_op = {"while", std::unordered_set<std::string> invalid_var_names;
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto op_info = op_node->AsStmt().op_info();
auto op_type = op_info->Type();
// The all of input and output variables of the Ops will not be reused.
std::unordered_set<std::string> invalid_op_nodes = {
"while",
"conditional_block", "conditional_block",
"conditional_block_infer", "conditional_block_infer",
"merge_lod_tensor_infer", "merge_lod_tensor_infer",
...@@ -53,38 +60,58 @@ void MemoryOptimizePass::CollectLifeCycleByDevice( ...@@ -53,38 +60,58 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
"subgraph", "subgraph",
"feed", "feed",
"fetch"}; "fetch"};
for (auto* tmp : node->inlinks) { auto invalid_op_node = invalid_op_nodes.find(op_type);
CHECK(tmp->IsStmt()); if (invalid_op_node != invalid_op_nodes.end()) {
std::string op_type = tmp->AsStmt().op_info()->Type(); for (auto in_var_node : op_node->inlinks) {
if (std::find(invalid_op.begin(), invalid_op.end(), op_type) != CHECK(in_var_node->IsArg());
invalid_op.end()) { invalid_var_names.insert(in_var_node->AsArg().name);
return false; }
for (auto out_var_node : op_node->outlinks) {
CHECK(out_var_node->IsArg());
invalid_var_names.insert(out_var_node->AsArg().name);
}
continue;
}
// The specified input and output variables of the Ops whose 'inplace' attr
// is true will not be reused, such as reshape/reshape2's X and Out
// variables
std::unordered_map<std::string,
std::pair<std::unordered_set<std::string>,
std::unordered_set<std::string>>>
inplace_op_nodes = {{"reshape", {{"X"}, {"Out"}}},
{"reshape2", {{"X"}, {"Out"}}}};
auto inplace_op_node = inplace_op_nodes.find(op_type);
if (inplace_op_node != inplace_op_nodes.end()) {
bool inplace = false;
if (op_info->HasAttr("inplace")) {
inplace = op_info->GetAttr<bool>("inplace");
}
if (inplace) {
for (auto& in_param_name : inplace_op_node->second.first) {
const auto& in_arg_names = op_info->Input(in_param_name);
invalid_var_names.insert(in_arg_names.begin(), in_arg_names.end());
}
for (auto& out_param_name : inplace_op_node->second.second) {
const auto& out_arg_names = op_info->Output(out_param_name);
invalid_var_names.insert(out_arg_names.begin(), out_arg_names.end());
} }
} }
for (auto* tmp : node->outlinks) {
CHECK(tmp->IsStmt());
std::string op_type = tmp->AsStmt().op_info()->Type();
if (std::find(invalid_op.begin(), invalid_op.end(), op_type) !=
invalid_op.end()) {
return false;
} }
} }
return true;
};
for (auto& op_node : graph->StmtTopologicalOrder()) { for (auto& op_node : graph->StmtTopologicalOrder()) {
if (op_node->IsStmt()) { if (op_node->IsStmt()) {
auto inputs = op_node->inlinks; std::vector<Node*> var_nodes(op_node->inlinks.begin(),
auto outputs = op_node->outlinks; op_node->inlinks.end());
std::vector<Node*> requires(inputs.begin(), inputs.end()); var_nodes.insert(
requires.insert(requires.end(), outputs.begin(), outputs.end()); var_nodes.end(), op_node->outlinks.begin(), op_node->outlinks.end());
for (Node* node : requires) { for (auto* var_node : var_nodes) {
CHECK(node->IsArg()); CHECK(var_node->IsArg());
auto& arg = node->AsArg(); auto& arg = var_node->AsArg();
if (arg.is_weight || arg.is_persist) continue; if (arg.is_weight || arg.is_persist) continue;
if (!valid_var(node)) continue;
std::string var_name = arg.name; std::string var_name = arg.name;
TargetType target_type = node->AsArg().type->target(); if (invalid_var_names.count(var_name)) continue;
TargetType target_type = arg.type->target();
if (is_host(target_type)) target_type = TARGET(kHost); if (is_host(target_type)) target_type = TARGET(kHost);
if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) { if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册