提交 668ae523 编写于 作者: T Tao Luo

speedup DetectPatterns

test=develop
上级 c27554ac
...@@ -167,10 +167,12 @@ struct HitGroup { ...@@ -167,10 +167,12 @@ struct HitGroup {
bool Match(Node *node, PDNode *pat) { bool Match(Node *node, PDNode *pat) {
if (nodes_.count(node)) { if (nodes_.count(node)) {
if (!roles.count(pat)) return false; if (roles.count(pat) && roles[pat] == node) return true;
return roles[pat] == node; return false;
} else {
if (roles.count(pat) && roles[pat] != node) return false;
return true;
} }
return !roles.count(pat) || roles.at(pat) == node;
} }
void Register(Node *node, PDNode *pat) { void Register(Node *node, PDNode *pat) {
...@@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() { ...@@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() {
std::vector<GraphPatternDetector::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::vector<HitGroup> init_groups; std::vector<HitGroup> init_groups;
std::array<std::vector<HitGroup>, 2> bi_records; std::array<std::vector<HitGroup>, 2> bi_records;
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first; : pattern_.edges().front().first;
if (!pdnodes2nodes_.count(first_pnode)) return result; if (!pdnodes2nodes_.count(first_pnode)) return result;
...@@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() { ...@@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() {
VLOG(80) << "check " << source->id() << " -- " << target->id(); VLOG(80) << "check " << source->id() << " -- " << target->id();
// TODO(Superjomn) add some prune strategies. // TODO(Superjomn) add some prune strategies.
for (const auto &group : pre_groups) { for (const auto &group : pre_groups) {
if (IsNodesLink(source, target)) {
HitGroup new_group = group; HitGroup new_group = group;
if (IsNodesLink(source, target) && bool flag = new_group.Match(source, edge.first) &&
new_group.Match(source, edge.first)) { new_group.Match(target, edge.second);
if (flag) {
new_group.Register(source, edge.first); new_group.Register(source, edge.first);
if (new_group.Match(target, edge.second)) {
new_group.Register(target, edge.second); new_group.Register(target, edge.second);
cur_groups.push_back(new_group); cur_groups.push_back(new_group);
// TODO(Superjomn) need to unique // TODO(Superjomn) need to unique
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册