未验证 提交 30234dd7 编写于 作者: J JingZhuangzhuang 提交者: GitHub

pdnode_compare (#42597)

* pdnode_compare

* panode compare

* pdnode_compare
上级 0ce42fb0
...@@ -81,6 +81,7 @@ struct PDNode { ...@@ -81,6 +81,7 @@ struct PDNode {
bool IsVar() const { return type_ == Type::kVar; } bool IsVar() const { return type_ == Type::kVar; }
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
const PDPattern* pdpattern() const { return pattern_; }
PDNode& operator=(const PDNode&) = delete; PDNode& operator=(const PDNode&) = delete;
PDNode(const PDNode&) = delete; PDNode(const PDNode&) = delete;
...@@ -277,7 +278,44 @@ class PDPattern { ...@@ -277,7 +278,44 @@ class PDPattern {
*/ */
class GraphPatternDetector { class GraphPatternDetector {
public: public:
using subgraph_t = std::map<PDNode*, Node*>; struct NodeIdCompare {
bool operator()(Node* node1, Node* node2) const {
return node1->id() < node2->id();
}
};
struct PDNodeCompare {
bool operator()(const PDNode* node1, const PDNode* node2) const {
auto& nodes1 = node1->pdpattern()->nodes();
auto& nodes2 = node2->pdpattern()->nodes();
if (nodes1.size() != nodes2.size()) {
return nodes1.size() < nodes2.size();
} else {
std::string pdnode_hash_key1 = "";
std::string pdnode_hash_key2 = "";
for (auto& node : nodes1) {
pdnode_hash_key1 += node.get()->name();
pdnode_hash_key1 += "#";
}
pdnode_hash_key1 += node1->name();
for (auto& node : nodes2) {
pdnode_hash_key2 += node.get()->name();
pdnode_hash_key2 += "#";
}
pdnode_hash_key2 += node2->name();
auto pdnode_key1 =
std::to_string(std::hash<std::string>()(pdnode_hash_key1));
auto pdnode_key2 =
std::to_string(std::hash<std::string>()(pdnode_hash_key2));
return pdnode_key1 < pdnode_key2;
}
return false;
}
};
using subgraph_t = std::map<PDNode*, Node*, PDNodeCompare>;
// Operate on the detected pattern. // Operate on the detected pattern.
using handle_t = using handle_t =
...@@ -321,7 +359,8 @@ class GraphPatternDetector { ...@@ -321,7 +359,8 @@ class GraphPatternDetector {
using hit_rcd_t = using hit_rcd_t =
std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>; std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
PDPattern pattern_; PDPattern pattern_;
std::map<const PDNode*, std::set<Node*>> pdnodes2nodes_; std::map<const PDNode*, std::set<Node*, NodeIdCompare>, PDNodeCompare>
pdnodes2nodes_;
}; };
// some helper methods. // some helper methods.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册