From 668ae523d2cdb61dfac1b2b64cbdba9fd9abc8e6 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Mon, 12 Nov 2018 21:08:45 +0800 Subject: [PATCH] speedup DetectPatterns test=develop --- .../framework/ir/graph_pattern_detector.cc | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 0a3c8a6cb5c..0d504b3048b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -167,10 +167,12 @@ struct HitGroup { bool Match(Node *node, PDNode *pat) { if (nodes_.count(node)) { - if (!roles.count(pat)) return false; - return roles[pat] == node; + if (roles.count(pat) && roles[pat] == node) return true; + return false; + } else { + if (roles.count(pat) && roles[pat] != node) return false; + return true; } - return !roles.count(pat) || roles.at(pat) == node; } void Register(Node *node, PDNode *pat) { @@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() { std::vector result; std::vector init_groups; std::array, 2> bi_records; - // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() : pattern_.edges().front().first; if (!pdnodes2nodes_.count(first_pnode)) return result; @@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() { VLOG(80) << "check " << source->id() << " -- " << target->id(); // TODO(Superjomn) add some prune strategies. for (const auto &group : pre_groups) { - HitGroup new_group = group; - if (IsNodesLink(source, target) && - new_group.Match(source, edge.first)) { - new_group.Register(source, edge.first); - if (new_group.Match(target, edge.second)) { + if (IsNodesLink(source, target)) { + HitGroup new_group = group; + bool flag = new_group.Match(source, edge.first) && + new_group.Match(target, edge.second); + if (flag) { + new_group.Register(source, edge.first); new_group.Register(target, edge.second); cur_groups.push_back(new_group); // TODO(Superjomn) need to unique -- GitLab