提交 8628b898 编写于 作者: Z zhoufeng

Make assign-node to be before jump-node, ensure child graph can get its

input
Signed-off-by: Nzhoufeng <zhoufeng54@huawei.com>
上级 12a359b9
......@@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace mindspore {
namespace session {
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
auto &nodes = parent_graph->execution_order();
for (auto &node : nodes) {
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
return node;
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
return node;
}
}
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
return nullptr;
}
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg.get()) != memo->end()) {
......@@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
if (target_graph_iter == graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
}
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
NOT_NULL(parameter));
}
}
}
......@@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
return {partial_cnode, branch_kg};
}
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
......@@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg
<< to_outputs.size() << "]";
}
for (size_t i = 0; i < from_outputs.size(); i++) {
InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
if (assign_node != nullptr) {
auto jump_node = GetJumpNode(from_graph, to_graph);
if (jump_node != nullptr) {
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
}
}
}
}
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
return;
return nullptr;
}
if (from.get() == to.get()) {
return;
return nullptr;
}
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
<< to->DebugString();
......@@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
assign_node->set_abstract(to->abstract());
// append the assign at the end of from graph
InsertDependToGraph(kg, NOT_NULL(assign_node));
return assign_node;
}
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
......
......@@ -52,8 +52,9 @@ class AscendControlParser {
const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
// root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册