提交 26c8b551 编写于 作者: J jackzhang235 提交者: jackzhang235

fix a bug in subgraph division: put all intermediate outputs to local

outputs wrongly
上级 3f98791f
...@@ -450,9 +450,6 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -450,9 +450,6 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
for (auto &var_node : output_var_nodes) { for (auto &var_node : output_var_nodes) {
output_var_names.push_back(var_node->AsArg().name); output_var_names.push_back(var_node->AsArg().name);
} }
for (auto &var_node : local_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names", subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names",
input_var_names); input_var_names);
subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names", subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
...@@ -494,6 +491,9 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, ...@@ -494,6 +491,9 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
for (auto &var_node : weight_var_nodes) { for (auto &var_node : weight_var_nodes) {
input_var_names.push_back(var_node->AsArg().name); input_var_names.push_back(var_node->AsArg().name);
} }
for (auto &var_node : local_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : unused_var_nodes) { for (auto &var_node : unused_var_nodes) {
output_var_names.push_back(var_node->AsArg().name); output_var_names.push_back(var_node->AsArg().name);
} }
...@@ -579,13 +579,14 @@ void ExtractInputsOutputs(const std::vector<Node *> &op_nodes, ...@@ -579,13 +579,14 @@ void ExtractInputsOutputs(const std::vector<Node *> &op_nodes,
unused_var_nodes->insert(var_node); unused_var_nodes->insert(var_node);
continue; continue;
} }
// Var can have more than one next op node, So, if any one in the // Var can have more than one next op node, So, if all next nodes are in
// op_nodes then continue // op_nodes then it should be put into local_var_nodes
bool next_op_in_nodes = false; bool next_op_in_nodes = true;
for (auto &next_op_node : var_node->outlinks) { for (auto &next_op_node : var_node->outlinks) {
if (std::find(op_nodes.begin(), op_nodes.end(), next_op_node) != if (std::find(op_nodes.begin(), op_nodes.end(), next_op_node) ==
op_nodes.end()) { op_nodes.end()) {
next_op_in_nodes = true; next_op_in_nodes = false;
break;
} }
} }
if (next_op_in_nodes) { if (next_op_in_nodes) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册