提交 44d1499e 编写于 作者: Z zhoufeng

Adjust layer number of outputs of empty graph

Signed-off-by: Nzhoufeng <zhoufeng54@huawei.com>
上级 17319d8d
......@@ -51,6 +51,7 @@
namespace mindspore {
namespace session {
const size_t kInvalidIndex = SIZE_MAX;
constexpr size_t kReturnDataIndex = 1;
namespace {
void DumpGraphExeOrder(const std::vector<CNodePtr> &execution_order, const std::string &tag = "") {
MS_LOG(INFO) << "Dump execution_order size " << execution_order.size();
......@@ -288,6 +289,19 @@ static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
// this action should from bottom to top
graph->UpdateCallRealInput();
}
void InsertMakeTupleForEmptyGraph(NotNull<KernelGraphPtr> graph) {
auto return_node = graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
auto origin_output = return_node->input(kReturnDataIndex);
MS_EXCEPTION_IF_NULL(origin_output);
std::vector<AnfNodePtr> make_tuple_input{
std::make_shared<ValueNode>(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), origin_output};
auto new_outputs = graph->NewCNode(make_tuple_input);
MS_EXCEPTION_IF_NULL(new_outputs);
new_outputs->set_abstract(origin_output->abstract());
return_node->set_input(kReturnDataIndex, new_outputs);
}
} // namespace
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
......@@ -305,8 +319,10 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
BackendOptimization(all_graphs);
// empty graph dont entry to backend
if (root_graph->execution_order().empty()) {
if (std::none_of(root_graph->execution_order().begin(), root_graph->execution_order().end(),
[](const CNodePtr &cnode) -> bool { return AnfAlgo::IsRealKernel(cnode); })) {
MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
InsertMakeTupleForEmptyGraph(NOT_NULL(root_graph));
root_graph->set_executable(false);
InitRuntimeResource();
return root_graph->graph_id();
......@@ -1027,7 +1043,7 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
// append switch at the end of condition graph
auto return_node = condition_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
InsertControlDependToGraph(condition_graph_id, return_node->input(1), switch_node);
InsertControlDependToGraph(condition_graph_id, return_node->input(kReturnDataIndex), switch_node);
MS_LOG(INFO) << "Finish!";
}
......@@ -1477,7 +1493,7 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived
// append the active node at the end of from graph
auto return_node = from_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
InsertControlDependToGraph(graph_id, return_node->input(1), active_node);
InsertControlDependToGraph(graph_id, return_node->input(kReturnDataIndex), active_node);
}
void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册