提交 668ae523 编写于 作者: T Tao Luo

speedup DetectPatterns

test=develop
上级 c27554ac
......@@ -167,10 +167,12 @@ struct HitGroup {
bool Match(Node *node, PDNode *pat) {
if (nodes_.count(node)) {
if (!roles.count(pat)) return false;
return roles[pat] == node;
if (roles.count(pat) && roles[pat] == node) return true;
return false;
} else {
if (roles.count(pat) && roles[pat] != node) return false;
return true;
}
return !roles.count(pat) || roles.at(pat) == node;
}
void Register(Node *node, PDNode *pat) {
......@@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() {
std::vector<GraphPatternDetector::subgraph_t> result;
std::vector<HitGroup> init_groups;
std::array<std::vector<HitGroup>, 2> bi_records;
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first;
if (!pdnodes2nodes_.count(first_pnode)) return result;
......@@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() {
VLOG(80) << "check " << source->id() << " -- " << target->id();
// TODO(Superjomn) add some prune strategies.
for (const auto &group : pre_groups) {
HitGroup new_group = group;
if (IsNodesLink(source, target) &&
new_group.Match(source, edge.first)) {
new_group.Register(source, edge.first);
if (new_group.Match(target, edge.second)) {
if (IsNodesLink(source, target)) {
HitGroup new_group = group;
bool flag = new_group.Match(source, edge.first) &&
new_group.Match(target, edge.second);
if (flag) {
new_group.Register(source, edge.first);
new_group.Register(target, edge.second);
cur_groups.push_back(new_group);
// TODO(Superjomn) need to unique
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册