diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 868b968d9e4cee625c29c3635c14cd3b9e3ddd76..573c1c1d3566be9cf28b72e6ec080acb73c85489 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3; namespace mindspore { namespace session { +static CNodePtr GetJumpNode(NotNull parent_graph, NotNull child_graph) { + auto &nodes = parent_graph->execution_order(); + for (auto &node : nodes) { + if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) { + return node; + } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) && + (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || + child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) { + return node; + } + } + MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); + return nullptr; +} + static void InitUnionFindSet(NotNull kg, const NotNull *> union_find_set, const NotNull *> memo) { if (memo->find(kg.get()) != memo->end()) { @@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::mapsecond), NOT_NULL(arg), NOT_NULL(parameter)); + InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), + NOT_NULL(parameter)); } } } @@ -433,7 +449,8 @@ std::tuple AscendControlParser::ParsePartial(NotNull kg, NotNull from, +void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, + NotNull to_graph, NotNull from, NotNull to) { std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); @@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull kg << to_outputs.size() << "]"; } for (size_t i = 0; i < from_outputs.size(); i++) { - InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + if (assign_node != nullptr) { + auto jump_node = GetJumpNode(from_graph, to_graph); + if (jump_node != nullptr) { + InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); + } + } } } -void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, - NotNull to) { +AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, + NotNull to) { if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { - return; + return nullptr; } if (from.get() == to.get()) { - return; + return nullptr; } MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " << to->DebugString(); @@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNul assign_node->set_abstract(to->abstract()); // append the assign at the end of from graph InsertDependToGraph(kg, NOT_NULL(assign_node)); + return assign_node; } std::vector AscendControlParser::RecurseGraph(NotNull graph, diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index 73d68449b31f003dbf2ac57bf27f245af3319c8c..0cf7069046d49e153e592a99792273a13fe7d3db 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -52,8 +52,9 @@ class AscendControlParser { const CNodePtr &last_label); static std::tuple ParsePartial(NotNull node); - static void InsertMultipleAssignToGraph(NotNull kg, NotNull from, NotNull to); - static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); + static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, + NotNull from, NotNull to); + static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); // root graph order static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,