提交 26c8b551 编写于 作者: J jackzhang235 提交者: jackzhang235

fix a bug in subgraph division: put all intermediate outputs to local

outputs wrongly
上级 3f98791f
......@@ -450,9 +450,6 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
for (auto &var_node : output_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : local_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names",
input_var_names);
subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
......@@ -494,6 +491,9 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
for (auto &var_node : weight_var_nodes) {
input_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : local_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : unused_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
......@@ -579,13 +579,14 @@ void ExtractInputsOutputs(const std::vector<Node *> &op_nodes,
unused_var_nodes->insert(var_node);
continue;
}
// Var can have more than one next op node, So, if any one in the
// op_nodes then continue
bool next_op_in_nodes = false;
// Var can have more than one next op node, So, if all next nodes are in
// op_nodes then it should be put into local_var_nodes
bool next_op_in_nodes = true;
for (auto &next_op_node : var_node->outlinks) {
if (std::find(op_nodes.begin(), op_nodes.end(), next_op_node) !=
if (std::find(op_nodes.begin(), op_nodes.end(), next_op_node) ==
op_nodes.end()) {
next_op_in_nodes = true;
next_op_in_nodes = false;
break;
}
}
if (next_op_in_nodes) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册