diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 6b49b4b87873fce1af8efc8f4ce4f64f92549192..c44cd8bbbdab634784958effeba1f2de58f2b345 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -198,7 +198,7 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue seed_nodes; + std::stack seed_nodes; UpdateNodeEdgeList(&seed_nodes); execution_order_.clear(); std::unordered_set visited_nodes; @@ -211,7 +211,7 @@ void KernelGraph::SetExecOrderByDefault() { VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); last_communication_node = nullptr; } else { - zero_input_nodes.push(seed_nodes.front()); + zero_input_nodes.push(seed_nodes.top()); seed_nodes.pop(); } // all reduce node descendant first, then common queue @@ -785,7 +785,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *seed_nodes) { +void KernelGraph::UpdateNodeEdgeList(std::stack *seed_nodes) { MS_EXCEPTION_IF_NULL(seed_nodes); node_output_edges_.clear(); node_input_num_.clear(); @@ -868,7 +868,7 @@ void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNo void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull new_anf_node) { MS_EXCEPTION_IF_NULL(inputs_); { - std::queue seed_nodes; + std::stack seed_nodes; UpdateNodeEdgeList(&seed_nodes); } auto it = node_output_edges_.find(old_anf_node); @@ -894,7 +894,7 @@ void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull seed_nodes; + std::stack seed_nodes; UpdateNodeEdgeList(&seed_nodes); } } diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 047c21ea203ba06a478d913c8a5ca0e10949c8b0..536571c10bda50fc6f4f9e0b90aeea973414c4fe 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -17,6 +17,7 @@ #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H #include +#include #include #include #include @@ -170,7 +171,7 @@ class KernelGraph : public FuncGraph { void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, std::unordered_set *visited_nodes); // update node edge list - void UpdateNodeEdgeList(std::queue *seed_nodes); + void UpdateNodeEdgeList(std::stack *seed_nodes); // add node depend edge by data edge or control depend void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); // handle control depend