提交 245ab319 编写于 作者: W wenchunjiang

add make_tuple before reture as graph outputs in ConstructKernelGraph

上级 93e7c97a
......@@ -646,6 +646,16 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL(func_graph_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
ConstructKernelGraph(sub_func_graph);
} else if (prim->name() == kReturnOpName) {
std::vector<AnfNodePtr> outputs;
auto inputs = cnode->inputs();
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "CNode[return] must have two inputs at least, actual inputs size is " << inputs.size();
}
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outputs));
// add a make_tuple before return as graph output
graph->set_output(ConstructOutput(outputs, graph));
continue;
}
}
......@@ -655,11 +665,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
graph->FrontBackendlMapAdd(node, new_cnode);
// set original return to kernel_graph
if (IsPrimitive(new_cnode->input(kAnfPrimitiveIndex), prim::kPrimReturn)) {
graph->set_return(new_cnode);
}
}
}
......
......@@ -144,6 +144,7 @@ constexpr auto kBNInferGradOpName = "BNInferGrad";
constexpr auto kCallOpName = "call";
constexpr auto kPartialOpName = "partial";
constexpr auto kSwitchOpName = "switch";
constexpr auto kReturnOpName = "return";
constexpr auto kLarsV2OpName = "LarsV2";
constexpr auto kLarsV2UpdateOpName = "LarsV2Update";
constexpr auto kSquareSumAllOpName = "SquareSumAll";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册