提交 36d6e416 编写于 作者: H hong19860320 提交者: GitHub

[Core] Fix memory_optmize_pass for reshape/reshape2 op with inplace=True (#3045)

上级 09f1ec4d
......@@ -39,9 +39,16 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
auto is_host = [](TargetType x) -> bool {
return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
};
// The vars which inputs or outputs are invalid op will not be reused.
auto valid_var = [&](Node* node) -> bool {
std::set<std::string> invalid_op = {"while",
// Collect the invalid input and output variables that will not be reused.
std::unordered_set<std::string> invalid_var_names;
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto op_info = op_node->AsStmt().op_info();
auto op_type = op_info->Type();
// The all of input and output variables of the Ops will not be reused.
std::unordered_set<std::string> invalid_op_nodes = {
"while",
"conditional_block",
"conditional_block_infer",
"merge_lod_tensor_infer",
......@@ -53,38 +60,58 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
"subgraph",
"feed",
"fetch"};
for (auto* tmp : node->inlinks) {
CHECK(tmp->IsStmt());
std::string op_type = tmp->AsStmt().op_info()->Type();
if (std::find(invalid_op.begin(), invalid_op.end(), op_type) !=
invalid_op.end()) {
return false;
auto invalid_op_node = invalid_op_nodes.find(op_type);
if (invalid_op_node != invalid_op_nodes.end()) {
for (auto in_var_node : op_node->inlinks) {
CHECK(in_var_node->IsArg());
invalid_var_names.insert(in_var_node->AsArg().name);
}
for (auto out_var_node : op_node->outlinks) {
CHECK(out_var_node->IsArg());
invalid_var_names.insert(out_var_node->AsArg().name);
}
continue;
}
// The specified input and output variables of the Ops whose 'inplace' attr
// is true will not be reused, such as reshape/reshape2's X and Out
// variables
std::unordered_map<std::string,
std::pair<std::unordered_set<std::string>,
std::unordered_set<std::string>>>
inplace_op_nodes = {{"reshape", {{"X"}, {"Out"}}},
{"reshape2", {{"X"}, {"Out"}}}};
auto inplace_op_node = inplace_op_nodes.find(op_type);
if (inplace_op_node != inplace_op_nodes.end()) {
bool inplace = false;
if (op_info->HasAttr("inplace")) {
inplace = op_info->GetAttr<bool>("inplace");
}
if (inplace) {
for (auto& in_param_name : inplace_op_node->second.first) {
const auto& in_arg_names = op_info->Input(in_param_name);
invalid_var_names.insert(in_arg_names.begin(), in_arg_names.end());
}
for (auto& out_param_name : inplace_op_node->second.second) {
const auto& out_arg_names = op_info->Output(out_param_name);
invalid_var_names.insert(out_arg_names.begin(), out_arg_names.end());
}
}
for (auto* tmp : node->outlinks) {
CHECK(tmp->IsStmt());
std::string op_type = tmp->AsStmt().op_info()->Type();
if (std::find(invalid_op.begin(), invalid_op.end(), op_type) !=
invalid_op.end()) {
return false;
}
}
return true;
};
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (op_node->IsStmt()) {
auto inputs = op_node->inlinks;
auto outputs = op_node->outlinks;
std::vector<Node*> requires(inputs.begin(), inputs.end());
requires.insert(requires.end(), outputs.begin(), outputs.end());
for (Node* node : requires) {
CHECK(node->IsArg());
auto& arg = node->AsArg();
std::vector<Node*> var_nodes(op_node->inlinks.begin(),
op_node->inlinks.end());
var_nodes.insert(
var_nodes.end(), op_node->outlinks.begin(), op_node->outlinks.end());
for (auto* var_node : var_nodes) {
CHECK(var_node->IsArg());
auto& arg = var_node->AsArg();
if (arg.is_weight || arg.is_persist) continue;
if (!valid_var(node)) continue;
std::string var_name = arg.name;
TargetType target_type = node->AsArg().type->target();
if (invalid_var_names.count(var_name)) continue;
TargetType target_type = arg.type->target();
if (is_host(target_type)) target_type = TARGET(kHost);
if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册