diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 288fe0bbb85fc68c85177c8008f74e28caa05cda..05ebcaccf84c774b2db6850c2fa62ff4b23874c1 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -225,7 +225,7 @@ static void BindCallArgsWithParameter(const std::vector ¶meters, // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] -static void UpdateRealInput(NotNull graph) { +static void UpdateRealInput(NotNull graph, bool split_flag) { auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); for (auto &call_node : call_nodes) { MS_EXCEPTION_IF_NULL(call_node); @@ -236,7 +236,9 @@ static void UpdateRealInput(NotNull graph) { std::vector(call_node->inputs().begin() + 2, call_node->inputs().end()); std::vector child_inputs = child_graphs[0]->inputs(); BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get()); - call_node->set_inputs(std::vector(call_node->inputs().begin(), call_node->inputs().begin() + 2)); + if (split_flag) { + call_node->set_inputs(std::vector(call_node->inputs().begin(), call_node->inputs().begin() + 2)); + } } else if (child_graphs.size() == 2) { auto get_partial_args = [&](size_t input_index) -> std::vector { auto switch_node = call_node->input(1); @@ -248,8 +250,10 @@ static void UpdateRealInput(NotNull graph) { auto partial_cnode = partial->cast(); MS_EXCEPTION_IF_NULL(partial_cnode); auto ret = std::vector(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); - partial_cnode->set_inputs( - std::vector(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); + if (split_flag) { + partial_cnode->set_inputs( + std::vector(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); + } return ret; }; BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); @@ -1678,6 +1682,7 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull graph, void AscendSession::SplitGraph(NotNull graph, const std::set &cut_prims) { MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); + bool split_flag = false; auto apply_list = GetCNodes(TopoSort(graph->get_return())); // update the root graph child graph order AscendControlParser::UpdateChildGraphOrder(graph); @@ -1710,9 +1715,10 @@ void AscendSession::SplitGraph(NotNull graph, const std::setgraph_id() << "] end"; // recurse to split child graph }