diff --git a/paddle/fluid/framework/ir/subgraph_detector.cc b/paddle/fluid/framework/ir/subgraph_detector.cc index 5910daf547bbd943cd0428bdab07db56b238e655..9a1db5e25784fc83764d7de2497aeec46d037613 100644 --- a/paddle/fluid/framework/ir/subgraph_detector.cc +++ b/paddle/fluid/framework/ir/subgraph_detector.cc @@ -117,7 +117,7 @@ void SubgraphDetector::MarkNodesInsideSubGraph() { // Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node // a's output is node b, that is a and b is in the same sub-graph. The UF // algorithm will group them to the same cluster. -using node_map_t = std::unordered_map; +using node_map_t = std::map; // Find the ancestor id of a node. int UnionFindGetAncestor(const node_map_t &node_map, size_t id) { int tmp = id; @@ -155,7 +155,7 @@ struct BriefNode { // 3. change all the dst's inputs and outputs // corresponding inlinks and outlinks to src node. // 4. delete all dst's inlinks and outlinks. -void UnionContractedNodes(const std::unordered_map &node_map, +void UnionContractedNodes(const std::map &node_map, int src_id, int dst_id) { // merge the two adjacent nodes into one node. BriefNode *src_node = node_map.at(src_id); @@ -262,7 +262,7 @@ std::vector> SubgraphDetector::ExtractSubGraphs() { std::vector marked_nodes; // We use brief_node_map to represent the original graph in order to avoid // changing the original graph. - std::unordered_map brief_node_map; + std::map brief_node_map; std::unordered_set valid_node_ids; for (auto *node : graph_->Nodes()) { diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py index 46d574dad0d0ae1f72617c6aaf3369b16195f76b..84d7bb5c969e610e2d70ee92e114b43863c87be6 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py @@ -167,7 +167,7 @@ class FusionGroupPassSumTest(FusionGroupPassTest): self.append_gradients(tmp_3) - self.num_fused_ops = 4 + self.num_fused_ops = 3 self.fetch_list = [tmp_3, self.grad(tmp_0)]