提交 e3c4ee75 编写于 作者: W wenchunjiang

reset call inputs only when graph has been splited

上级 2a84a86b
...@@ -225,7 +225,7 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, ...@@ -225,7 +225,7 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters,
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
static void UpdateRealInput(NotNull<KernelGraphPtr> graph) { static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
for (auto &call_node : call_nodes) { for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node); MS_EXCEPTION_IF_NULL(call_node);
...@@ -236,7 +236,9 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) { ...@@ -236,7 +236,9 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()); std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs(); std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get()); BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get());
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); if (split_flag) {
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
}
} else if (child_graphs.size() == 2) { } else if (child_graphs.size() == 2) {
auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> { auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
auto switch_node = call_node->input(1); auto switch_node = call_node->input(1);
...@@ -248,8 +250,10 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) { ...@@ -248,8 +250,10 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
auto partial_cnode = partial->cast<CNodePtr>(); auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode); MS_EXCEPTION_IF_NULL(partial_cnode);
auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
partial_cnode->set_inputs( if (split_flag) {
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); partial_cnode->set_inputs(
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
}
return ret; return ret;
}; };
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
...@@ -1678,6 +1682,7 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, ...@@ -1678,6 +1682,7 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) { void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id();
bool split_flag = false;
auto apply_list = GetCNodes(TopoSort(graph->get_return())); auto apply_list = GetCNodes(TopoSort(graph->get_return()));
// update the root graph child graph order // update the root graph child graph order
AscendControlParser::UpdateChildGraphOrder(graph); AscendControlParser::UpdateChildGraphOrder(graph);
...@@ -1710,9 +1715,10 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri ...@@ -1710,9 +1715,10 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node)); AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node));
} }
} }
split_flag = true;
} }
AscendControlParser::UpdateChildGraphOrder(graph); AscendControlParser::UpdateChildGraphOrder(graph);
UpdateRealInput(graph); UpdateRealInput(graph, split_flag);
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
// recurse to split child graph // recurse to split child graph
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册