diff --git a/lite/core/mir/subgraph/subgraph_detector.cc b/lite/core/mir/subgraph/subgraph_detector.cc index 454682043b5199a32e56ed7c8bbea1752942c212..9799cc72437c7581bde681ef2e80c0234635c2fe 100644 --- a/lite/core/mir/subgraph/subgraph_detector.cc +++ b/lite/core/mir/subgraph/subgraph_detector.cc @@ -450,9 +450,6 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, for (auto &var_node : output_var_nodes) { output_var_names.push_back(var_node->AsArg().name); } - for (auto &var_node : local_var_nodes) { - output_var_names.push_back(var_node->AsArg().name); - } subgraph_op_desc.SetAttr>("input_data_names", input_var_names); subgraph_op_desc.SetAttr>("output_data_names", @@ -494,6 +491,9 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, for (auto &var_node : weight_var_nodes) { input_var_names.push_back(var_node->AsArg().name); } + for (auto &var_node : local_var_nodes) { + output_var_names.push_back(var_node->AsArg().name); + } for (auto &var_node : unused_var_nodes) { output_var_names.push_back(var_node->AsArg().name); } @@ -579,13 +579,14 @@ void ExtractInputsOutputs(const std::vector &op_nodes, unused_var_nodes->insert(var_node); continue; } - // Var can have more than one next op node, So, if any one in the - // op_nodes then continue - bool next_op_in_nodes = false; + // Var can have more than one next op node, So, if all next nodes are in + // op_nodes then it should be put into local_var_nodes + bool next_op_in_nodes = true; for (auto &next_op_node : var_node->outlinks) { - if (std::find(op_nodes.begin(), op_nodes.end(), next_op_node) != + if (std::find(op_nodes.begin(), op_nodes.end(), next_op_node) == op_nodes.end()) { - next_op_in_nodes = true; + next_op_in_nodes = false; + break; } } if (next_op_in_nodes) {