From 30234dd75bcc2415d672140342745ce8472ee601 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Tue, 10 May 2022 11:12:41 +0800 Subject: [PATCH] pdnode_compare (#42597) * pdnode_compare * panode compare * pdnode_compare --- .../framework/ir/graph_pattern_detector.h | 43 ++++++++++++++++++- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index d7e265fe28..96a1e5c071 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -81,6 +81,7 @@ struct PDNode { bool IsVar() const { return type_ == Type::kVar; } const std::string& name() const { return name_; } + const PDPattern* pdpattern() const { return pattern_; } PDNode& operator=(const PDNode&) = delete; PDNode(const PDNode&) = delete; @@ -277,7 +278,44 @@ class PDPattern { */ class GraphPatternDetector { public: - using subgraph_t = std::map; + struct NodeIdCompare { + bool operator()(Node* node1, Node* node2) const { + return node1->id() < node2->id(); + } + }; + + struct PDNodeCompare { + bool operator()(const PDNode* node1, const PDNode* node2) const { + auto& nodes1 = node1->pdpattern()->nodes(); + auto& nodes2 = node2->pdpattern()->nodes(); + if (nodes1.size() != nodes2.size()) { + return nodes1.size() < nodes2.size(); + } else { + std::string pdnode_hash_key1 = ""; + std::string pdnode_hash_key2 = ""; + for (auto& node : nodes1) { + pdnode_hash_key1 += node.get()->name(); + pdnode_hash_key1 += "#"; + } + pdnode_hash_key1 += node1->name(); + for (auto& node : nodes2) { + pdnode_hash_key2 += node.get()->name(); + pdnode_hash_key2 += "#"; + } + pdnode_hash_key2 += node2->name(); + + auto pdnode_key1 = + std::to_string(std::hash()(pdnode_hash_key1)); + auto pdnode_key2 = + std::to_string(std::hash()(pdnode_hash_key2)); + + return pdnode_key1 < pdnode_key2; + } + return false; + } + }; + + using subgraph_t = std::map; // Operate on the detected pattern. using handle_t = @@ -321,7 +359,8 @@ class GraphPatternDetector { using hit_rcd_t = std::pair; PDPattern pattern_; - std::map> pdnodes2nodes_; + std::map, PDNodeCompare> + pdnodes2nodes_; }; // some helper methods. -- GitLab