未验证 提交 403b503f 编写于 作者: J JingZhuangzhuang 提交者: GitHub

pdnode_compare (#42597) (#42633)

* pdnode_compare

* panode compare

* pdnode_compare
上级 25124d7f
......@@ -81,6 +81,7 @@ struct PDNode {
bool IsVar() const { return type_ == Type::kVar; }
const std::string& name() const { return name_; }
const PDPattern* pdpattern() const { return pattern_; }
PDNode& operator=(const PDNode&) = delete;
PDNode(const PDNode&) = delete;
......@@ -277,7 +278,44 @@ class PDPattern {
*/
class GraphPatternDetector {
public:
using subgraph_t = std::map<PDNode*, Node*>;
struct NodeIdCompare {
bool operator()(Node* node1, Node* node2) const {
return node1->id() < node2->id();
}
};
struct PDNodeCompare {
bool operator()(const PDNode* node1, const PDNode* node2) const {
auto& nodes1 = node1->pdpattern()->nodes();
auto& nodes2 = node2->pdpattern()->nodes();
if (nodes1.size() != nodes2.size()) {
return nodes1.size() < nodes2.size();
} else {
std::string pdnode_hash_key1 = "";
std::string pdnode_hash_key2 = "";
for (auto& node : nodes1) {
pdnode_hash_key1 += node.get()->name();
pdnode_hash_key1 += "#";
}
pdnode_hash_key1 += node1->name();
for (auto& node : nodes2) {
pdnode_hash_key2 += node.get()->name();
pdnode_hash_key2 += "#";
}
pdnode_hash_key2 += node2->name();
auto pdnode_key1 =
std::to_string(std::hash<std::string>()(pdnode_hash_key1));
auto pdnode_key2 =
std::to_string(std::hash<std::string>()(pdnode_hash_key2));
return pdnode_key1 < pdnode_key2;
}
return false;
}
};
using subgraph_t = std::map<PDNode*, Node*, PDNodeCompare>;
// Operate on the detected pattern.
using handle_t =
......@@ -321,7 +359,8 @@ class GraphPatternDetector {
using hit_rcd_t =
std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
PDPattern pattern_;
std::map<const PDNode*, std::set<Node*>> pdnodes2nodes_;
std::map<const PDNode*, std::set<Node*, NodeIdCompare>, PDNodeCompare>
pdnodes2nodes_;
};
// some helper methods.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册