提交 68533439 编写于 作者: J jackzhang235

fix bug in subgraph partition

上级 b6c35b17
......@@ -413,6 +413,14 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
auto* sub_block_op_desc = sub_block_desc->GetOp<cpp::OpDesc>(i);
UpdateOutputTo(
sub_block_op_desc, tail_node->AsArg().name, cur_node->AsArg().name);
/* graph like this
* subgraph_op_0
* / \
* / \
* subgraph_op_1 host_op
*/
UpdateInputTo(
sub_block_op_desc, tail_node->AsArg().name, cur_node->AsArg().name);
}
// recreate the op
......
......@@ -449,6 +449,9 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
for (auto &var_node : output_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : local_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
subgraph_op_desc.SetAttr<std::vector<std::string>>("input_data_names",
input_var_names);
subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
......@@ -490,9 +493,6 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
for (auto &var_node : weight_var_nodes) {
input_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : local_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
for (auto &var_node : unused_var_nodes) {
output_var_names.push_back(var_node->AsArg().name);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册