From a1aae040fed9ec581d1db80098e9ddd3c6b79833 Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 26 Aug 2021 11:30:03 +0800 Subject: [PATCH] [Inference] Replace unordered_map with map to support subgraph stability (#35147) --- paddle/fluid/framework/ir/subgraph_detector.cc | 6 +++--- .../fluid/tests/unittests/ir/test_ir_fusion_group_pass.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/subgraph_detector.cc b/paddle/fluid/framework/ir/subgraph_detector.cc index 5910daf547b..9a1db5e2578 100644 --- a/paddle/fluid/framework/ir/subgraph_detector.cc +++ b/paddle/fluid/framework/ir/subgraph_detector.cc @@ -117,7 +117,7 @@ void SubgraphDetector::MarkNodesInsideSubGraph() { // Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node // a's output is node b, that is a and b is in the same sub-graph. The UF // algorithm will group them to the same cluster. -using node_map_t = std::unordered_map; +using node_map_t = std::map; // Find the ancestor id of a node. int UnionFindGetAncestor(const node_map_t &node_map, size_t id) { int tmp = id; @@ -155,7 +155,7 @@ struct BriefNode { // 3. change all the dst's inputs and outputs // corresponding inlinks and outlinks to src node. // 4. delete all dst's inlinks and outlinks. -void UnionContractedNodes(const std::unordered_map &node_map, +void UnionContractedNodes(const std::map &node_map, int src_id, int dst_id) { // merge the two adjacent nodes into one node. BriefNode *src_node = node_map.at(src_id); @@ -262,7 +262,7 @@ std::vector> SubgraphDetector::ExtractSubGraphs() { std::vector marked_nodes; // We use brief_node_map to represent the original graph in order to avoid // changing the original graph. - std::unordered_map brief_node_map; + std::map brief_node_map; std::unordered_set valid_node_ids; for (auto *node : graph_->Nodes()) { diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py index 46d574dad0d..84d7bb5c969 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py @@ -167,7 +167,7 @@ class FusionGroupPassSumTest(FusionGroupPassTest): self.append_gradients(tmp_3) - self.num_fused_ops = 4 + self.num_fused_ops = 3 self.fetch_list = [tmp_3, self.grad(tmp_0)] -- GitLab