From b6f71b8a013383f84db466fe41756bfbcfd3ee31 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Wed, 5 Jun 2019 15:35:49 +0800 Subject: [PATCH] Add graph fusion(pattern matcher) in the lite framework. (#17839) * add the pattern matcher logic. test=develop * update op_info usage. test=develop --- paddle/fluid/lite/core/mir/CMakeLists.txt | 3 + paddle/fluid/lite/core/mir/node.h | 10 + paddle/fluid/lite/core/mir/pattern_matcher.cc | 405 ++++++++++++++++++ paddle/fluid/lite/core/mir/pattern_matcher.h | 371 ++++++++++++++++ .../lite/core/mir/pattern_matcher_tester.cc | 233 ++++++++++ 5 files changed, 1022 insertions(+) create mode 100644 paddle/fluid/lite/core/mir/pattern_matcher.cc create mode 100644 paddle/fluid/lite/core/mir/pattern_matcher.h create mode 100644 paddle/fluid/lite/core/mir/pattern_matcher_tester.cc diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 00b177fe6ba..c66078645f8 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -48,4 +48,7 @@ if (LITE_WITH_CUDA) endif() cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc DEPS ${test_variable_place_infrence_pass_DEPS}) + +cc_library(pattern_matcher_lite SRCS pattern_matcher.cc DEPS mir_node mir_ssa_graph op_lite) +cc_test(test_pattern_matcher_lite SRCS pattern_matcher_tester.cc DEPS pattern_matcher_lite) diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h index e10472c5885..67ee47a9e12 100644 --- a/paddle/fluid/lite/core/mir/node.h +++ b/paddle/fluid/lite/core/mir/node.h @@ -93,6 +93,16 @@ class Node { return x; } + Stmt* stmt() const { + CHECK(IsStmt()); + return stmt_.get(); + } + + Arg* arg() const { + CHECK(IsArg()); + return arg_.get(); + } + // Set roles. Arg& AsArg() { if (role_ != Role::kUnk) { diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.cc b/paddle/fluid/lite/core/mir/pattern_matcher.cc new file mode 100644 index 00000000000..1f5a1c6e4e5 --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher.cc @@ -0,0 +1,405 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/fluid/inference/analysis/dot.h" +#include "paddle/fluid/lite/core/mir/pattern_matcher.h" +#include "paddle/fluid/lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace mir { + +size_t PMPattern::id_ = 0UL; + +PMNode *PMPattern::NewNode(const std::string &name) { + if (!name.empty()) { + CHECK_EQ(node_map_.count(name), 0UL) + << "PMNode's name should be unique, get duplicate " << name; + } + + nodes_.emplace_back(new PMNode(this, name)); + auto *cur = nodes_.back().get(); + node_map_[name] = cur; + return cur; +} + +PMNode *PMPattern::NewNode(PMNode::teller_t &&teller, const std::string &name) { + if (!name.empty()) { + CHECK_EQ(node_map_.count(name), 0UL) + << "PMNode's name should be unique, get duplicate " << name; + } + + nodes_.emplace_back(new PMNode(std::move(teller), this, name)); + auto *cur = nodes_.back().get(); + node_map_[name] = cur; + return cur; +} + +PMNode *PMPattern::RetrieveNode(const std::string &id) const { + auto it = node_map_.find(id); + if (it == node_map_.end()) { + return nullptr; + } + + return it->second; +} + +void PMPattern::AddEdge(PMNode *a, PMNode *b) { + CHECK(a); + CHECK(b); + CHECK_NE(a, b) << "Can't connect to the same nodes."; + edges_.emplace_back(a, b); +} + +void PatternMatcher::operator()(SSAGraph *graph, + PatternMatcher::handle_t handler) { + if (!MarkPMNodesInGraph(graph)) { + return; + } + + auto subgraphs = DetectPatterns(); + UniquePatterns(&subgraphs); + RemoveOverlappedMatch(&subgraphs); + ValidateByNodeRole(&subgraphs); + + if (subgraphs.empty()) return; + LOG(INFO) << "--- detected " << subgraphs.size() << " subgraphs."; + int id = 0; + for (auto &g : subgraphs) { + VLOG(3) << "optimizing #" << id++ << " subgraph"; + handler(g, graph); + } +} + +bool PatternMatcher::MarkPMNodesInGraph(SSAGraph *graph) { + VLOG(3) << "mark pmnodes in graph"; + if (graph->nodes().empty()) return false; + + for (auto &node : graph->mutable_nodes()) { + for (const auto &pmnode : pattern_.nodes()) { + if (pmnode->Tell(&node)) { + pmnodes2nodes_[pmnode.get()].insert(&node); + } + } + } + // Check to early stop if some PMNode can't find matched Node. + for (auto &pmnode : pattern_.nodes()) { + if (!pmnodes2nodes_.count(pmnode.get())) { + VLOG(4) << pmnode->name() << " can't find matched Node, early stop"; + // return false; + } + } + VLOG(3) << pmnodes2nodes_.size() << " nodes marked"; + + return !pmnodes2nodes_.empty(); +} + +// The intermediate Nodes can only link to the nodes inside the pattern, or this +// subgraph will be droped. +void PatternMatcher::ValidateByNodeRole( + std::vector *subgraphs) { + std::vector result; + + subgraphs->erase( + std::remove_if(subgraphs->begin(), subgraphs->end(), + [](const PatternMatcher::subgraph_t &subgraph) -> bool { + // Collect the inlinks and outlinks. + std::unordered_set ios; + for (auto &item : subgraph) { + if (!item.first->IsIntermediate()) { + ios.insert(item.second); + } + } + for (auto &item : subgraph) { + if (item.first->IsIntermediate()) { + for (auto *x : item.second->inlinks) { + if (!ios.count(x)) { + return true; + } + } + for (auto *x : item.second->outlinks) { + if (!ios.count(x)) { + return true; + } + } + } + } + return false; + }), + subgraphs->end()); +} + +struct HitGroup { + std::unordered_map roles; + + bool Match(Node *node, PMNode *pat) { + if (nodes_.count(node)) { + if (roles.count(pat) && roles[pat] == node) return true; + return false; + } else { + if (roles.count(pat) && roles[pat] != node) return false; + return true; + } + } + + void Register(Node *node, PMNode *pat) { + roles[pat] = node; + nodes_.insert(node); + } + + private: + std::unordered_set nodes_; +}; + +// Tell whether Node a links to b. +bool IsNodesLink(Node *a, Node *b) { + for (auto *node : a->outlinks) { + if (b == node) { + return true; + } + } + return false; +} + +std::vector PatternMatcher::DetectPatterns() { + // Init empty subgraphs. + std::vector result; + std::vector init_groups; + std::array, 2> bi_records; + auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() + : pattern_.edges().front().first; + if (!pmnodes2nodes_.count(first_pnode)) return result; + for (auto *node : pmnodes2nodes_[first_pnode]) { + HitGroup group; + group.roles[first_pnode] = node; + init_groups.emplace_back(group); + } + + int step = 0; + bi_records[0] = std::move(init_groups); + + // Extend a PMNode to subgraphs by deducing the connection relations defined + // in edges of PMNodes. + for (const auto &edge : pattern_.edges()) { + VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name(); + // TODO(Superjomn) Fix bug here, the groups might be duplicate here. + // Each role has two PMNodes, which indicates two roles. + // Detect two Nodes that can match these two roles and they are connected. + auto &pre_groups = bi_records[step % 2]; + auto &cur_groups = bi_records[1 - (step++ % 2)]; + cur_groups.clear(); + if (pre_groups.empty()) break; + // source -> target + for (Node *source : pmnodes2nodes_[edge.first]) { + for (Node *target : pmnodes2nodes_[edge.second]) { + // TODO(Superjomn) add some prune strategies. + for (const auto &group : pre_groups) { + if (IsNodesLink(source, target)) { + HitGroup new_group = group; + bool flag = new_group.Match(source, edge.first) && + new_group.Match(target, edge.second); + if (flag) { + new_group.Register(source, edge.first); + new_group.Register(target, edge.second); + cur_groups.push_back(new_group); + // TODO(Superjomn) need to unique + } + } + } + } + } + VLOG(3) << "step " << step << " get records: " << cur_groups.size(); + } + + for (auto &group : bi_records[step % 2]) { + PatternMatcher::subgraph_t subgraph; + for (auto &role : group.roles) { + subgraph.emplace(role.first, role.second); + } + result.emplace_back(subgraph); + } + return result; +} + +struct GraphItemLessThan { + bool operator()(const std::pair &a, + const std::pair &b) { + if (a.first != b.first) { + return a.first < b.first; + } else { + return a.second < b.second; + } + } +}; + +// TODO(Superjomn) enhance the function as it marks unique unique as duplicates +// see https://github.com/PaddlePaddle/Paddle/issues/13550 +void PatternMatcher::UniquePatterns( + std::vector *subgraphs) { + if (subgraphs->empty()) return; + std::vector result; + + std::unordered_set set; + std::hash hasher; + for (auto &g : *subgraphs) { + // Sort the items in the sub-graph, and transform to a string key. + std::vector> sorted_keys(g.begin(), g.end()); + std::sort(sorted_keys.begin(), sorted_keys.end(), GraphItemLessThan()); + std::stringstream ss; + for (auto &item : sorted_keys) { + ss << item.first << ":" << item.second; + } + auto key = hasher(ss.str()); + if (!set.count(key)) { + result.emplace_back(g); + set.insert(key); + } + } + *subgraphs = result; +} + +void PatternMatcher::RemoveOverlappedMatch(std::vector *subgraphs) { + std::vector result; + std::unordered_set node_set; + + for (const auto &subgraph : *subgraphs) { + bool valid = true; + for (auto &item : subgraph) { + if (item.first->IsIntermediate() && node_set.count(item.second)) { + valid = false; + break; + } + } + if (valid) { + for (auto &item : subgraph) { + node_set.insert(item.second); + } + result.push_back(subgraph); + } + } + *subgraphs = result; +} + +std::string PMPattern::DotString() const { + using inference::analysis::Dot; + Dot dot; + int id = 0; + // Create Nodes + std::unordered_map node2dot; + for (const auto &node : nodes()) { + std::string node_id = "Node" + std::to_string(id++); + dot.AddNode(node_id, {}, node->name()); + node2dot[node.get()] = node_id; + } + // Create Edges + for (const auto &edge : edges()) { + if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) { + LOG(ERROR) << "no node " << edge.first << " " << edge.second; + continue; + } + auto &src = node2dot.at(edge.first); + auto &trg = node2dot.at(edge.second); + dot.AddEdge(src, trg, {}); + } + return dot.Build(); +} + +PMNode &PMNode::LinksTo(const std::vector &others) { + // extend outlinks. + for (PMNode *x : others) { + pattern_->AddEdge(this, x); + } + return *this; +} + +PMNode &PMNode::LinksFrom(const std::vector &others) { + // extend outlinks. + for (PMNode *x : others) { + pattern_->AddEdge(x, this); + } + return *this; +} + +PMNode *PMNode::assert_is_op() { + asserts_.emplace_back([](const Node *x) { return x && x->IsStmt(); }); + return this; +} + +PMNode *PMNode::assert_is_op(const std::string &op_type) { + asserts_.emplace_back([op_type](const Node *x) { + if (x && x->IsStmt()) { + auto *op_info = x->stmt()->op_info(); + return op_info->Type() == op_type; + } else { + return false; + } + }); + return this; +} + +PMNode *PMNode::assert_is_var() { + asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); }); + return this; +} + +PMNode *PMNode::assert_var_not_persistable() { + assert_is_var(); + asserts_.emplace_back([](const Node *x) { return !x->arg()->is_weight; }); + return this; +} + +PMNode *PMNode::assert_is_persistable_var() { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { return x->arg()->is_weight; }); + return this; +} + +PMNode *PMNode::assert_is_op_output(const std::string &op_type) { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->inlinks) { + if (op && op->IsStmt()) { + auto *op_info = x->stmt()->op_info(); + if (op_info->Type() == op_type) return true; + } + } + return false; + }); + return this; +} + +PMNode *PMNode::assert_is_op_input(const std::string &op_type) { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->outlinks) { + if (op && op->IsStmt()) { + auto *op_info = op->stmt()->op_info(); + if (op_info->Type() == op_type) { + return true; + } + } + } + return false; + }); + return this; +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.h b/paddle/fluid/lite/core/mir/pattern_matcher.h new file mode 100644 index 00000000000..8ea8f615aeb --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher.h @@ -0,0 +1,371 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_TESTING +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/lite/core/mir/node.h" +#include "paddle/fluid/lite/core/mir/ssa_graph.h" +#include "paddle/fluid/lite/model_parser/pb/op_desc.h" + +namespace paddle { +namespace lite { +namespace mir { +class PMPattern; + +// Some basic terminologies: +// - PMPattern: a pattern defined as a data flow graph. +// - PMNode: the node in the pattern, each PMNode represents an `mir::Node` +// that meets some conditions defined in `PMNode.teller`. +// - A pattern is defined with PMNodes with edges. + +// Pattern matcher node. This node helps to build a pattern. +struct PMNode { + // tell whether an mir::Node* is a candidation for a PMNode. + using teller_t = std::function; + enum class Type { kOp, kVar }; + enum class Role { + kUnknown, // No role, + kInput, // an input and will be retained, + kOutput, // an output and will be retained, + kIntermediate // will be removed after handler. + }; + + // this link to others + PMNode& LinksTo(const std::vector& others); + PMNode& LinksFrom(const std::vector& others); + + bool Tell(const Node* node) const { + if (teller_) return teller_(node); + + for (auto& asrt : asserts_) { + if (!asrt(node)) return false; + } + return true; + } + + bool IsOp() const { return type_ == Type::kOp; } + bool IsVar() const { return type_ == Type::kVar; } + + const std::string& name() const { return name_; } + + PMNode& operator=(const PMNode&) = delete; + PMNode(const PMNode&) = delete; + + // Mark this node is an Input of a subgraph and will be retained. + PMNode* AsInput() { + role_ = Role::kInput; + return this; + } + // Mark this node is an Output of a subgraph and will be retained. + PMNode* AsOutput() { + role_ = Role::kOutput; + return this; + } + // Mark this node will be removed, so all the links should be inside a matched + // sub-graph. + PMNode* AsIntermediate() { + role_ = Role::kIntermediate; + return this; + } + + bool IsIntermediate() const { return role_ == Role::kIntermediate; } + bool IsInput() const { return role_ == Role::kInput; } + bool IsOutput() const { return role_ == Role::kOutput; } + + // Assertions, helper functions to simplify the pattern definition. + PMNode* assert_is_op(); + PMNode* assert_is_op(const std::string& op_type); + PMNode* assert_is_var(); + PMNode* assert_var_not_persistable(); + PMNode* assert_is_persistable_var(); + PMNode* assert_is_op_output(const std::string& op_type); + PMNode* assert_is_op_input(const std::string& op_type); + + template + PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { + asserts_.emplace_back([=](Node* x) { + if (x && x->IsStmt()) { + auto* op_info = x->stmt()->op_info(); + return op_info->HasAttr(attr_name) && + op_info->GetAttr(attr_name) == attr; + } else { + return false; + } + }); + return this; + } + + private: + PMNode(PMPattern* pattern, const std::string& name = "", + Type type = Type::kVar) + : pattern_(pattern), name_(name), type_(type) {} + PMNode(teller_t&& teller, PMPattern* pattern, const std::string& name = "", + Type type = Type::kVar) + : teller_(std::move(teller)), + pattern_(pattern), + name_(name), + type_(type) { + CHECK(teller_ != nullptr) << "invalid teller functer is set."; + } + + PMNode(PMNode&& other) = default; + + friend class PMPattern; + + // Will removed latter. + teller_t teller_; + std::vector asserts_; + PMPattern* pattern_; + std::string name_; + Type type_; + Role role_{Role::kUnknown}; +}; + +/* + * A pattern in a graph, which defined with PMNode and edges. Most graph + * patterns can be divided into PMNodes and link relations between them. + * + * For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD + * operators from the computation graph, the MUL's output should have only one + * consumer which is the ELEMENTWISE_ADD. + * This pattern can be defined as with the following pseudo codes + * + * // Create two operator PMNodes. + * MUL = PMPattern.NewNode().assert_is_op("mul"); + * ELE = PMPattern.NewNode().assert_is_op("elementwise_add"); + * // Create the variable PMNodes. + * MUL_out = PMPattern.NewNode().assert_is_op_output("mul") \ + * .assert_is_op_input("elementwise_add") \ + * .AsIntermediate(); + * // Add relations. + * MUL->LinksTo({MUL_out}); + * MUL_out->LinksTo({ELE}); + * + * One can add more specific asserts for PMNodes or edges, both the Operator + * and Variable Nodes can be ruled in PMNode.assert_more(...). + * + * PMPattern can record the general patterns, such as the pattern represents + * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place. + * - Ops whose inputs and outputs share the same variables + */ +class PMPattern { + public: + using edge_t = std::pair; + + void AddEdge(PMNode* a, PMNode* b); + + PMNode* NewNode(PMNode::teller_t&& teller, const std::string& name = NewID()); + PMNode* NewNode(const std::string& name = NewID()); + PMNode* NewNode(const std::string& prefix, const std::string& name) { + return NewNode(prefix + "/" + name); + } + PMNode* RetrieveNode(const std::string& id) const; + + const std::vector>& nodes() const { return nodes_; } + const std::vector& edges() const { return edges_; } + + std::string DotString() const; + + private: +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(PMPattern, AddEdge); + FRIEND_TEST(PMPattern, NewNode); +#endif + + static std::string NewID() { return "pmnode-" + std::to_string(id_++); } + + std::vector> nodes_; + std::vector edges_; + std::unordered_map node_map_; + static size_t id_; +}; + +/* + * PatternMatcher helps to detect the specific patterns in the graph. + * Input a pattern, output a list of the matched subgraphs/nodes. + * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). + * + * The algorithm has three phases: + * 1. Mark the nodes that match the defined PMNodes in a PMPattern, + * 2. Extend a PMNode to subgraphs by deducing the connection relation defined + * in PAPattern(the edges), + * 3. Get the filtered subgraphs and treat them with a pre-defined handler. + * + * Usage: + * // Create a matcher + * PatternMatcher matcher; + * // Define the matcher's pattern, by adding PMNode and define the edges. + * auto* node0 = matcher.mutable_pattern().AddNode(...) + * auto* node1 = matcher.mutable_pattern().AddNode(...) + * node0->teller = some lambda. + * node1->teller = some lambda. + * matcher.mutable_pattern().AddEdge(node0, node1); + * // Create an handler, to define the behavior of treating the filtered + * // subgraphs that comply with the patterns. + * PatternMatcher::handle_t handler = some labmda + * // Execute the matcher. + * matcher(&graph, handler); + */ +class PatternMatcher { + public: + using subgraph_t = std::unordered_map; + + // Operate on the detected pattern. + using handle_t = + std::function; + + void operator()(SSAGraph* graph, handle_t handler); + + const PMPattern& pattern() const { return pattern_; } + PMPattern* mutable_pattern() { return &pattern_; } + + private: + // Mark the nodes that fits the pattern. + bool MarkPMNodesInGraph(SSAGraph* graph); + + // Detect all the pattern and output the hit records. + std::vector DetectPatterns(); + + // Remove duplicate patterns. + void UniquePatterns(std::vector* subgraphs); + + // Remove overlapped match subgraphs, when overlapped, keep the previous one. + // The intermediate PMNodes will be removed, so can't shared by multiple + // patterns. + void RemoveOverlappedMatch(std::vector* subgraphs); + + // Validate whether the intermediate nodes are linked by external nodes. + void ValidateByNodeRole(std::vector* subgraphs); + +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(PatternMatcher, MarkPMNodesInGraph); + FRIEND_TEST(PatternMatcher, DetectPatterns); +#endif + + private: + using hit_rcd_t = + std::pair; + PMPattern pattern_; + std::unordered_map> pmnodes2nodes_; +}; + +// Some pre-defined patterns those can be reused in multiple passes. +// The related Fluid Layer or Op should be one pattern here for better re-usage +// across different fusion. +namespace patterns { + +struct KeyCounter { + static KeyCounter& Instance() { + static KeyCounter x; + return x; + } + + int IncCounter(const std::string& key) { return dic_[key]++; } + + private: + std::unordered_map dic_; +}; + +// Generate a unique PMNode's name with name_scope and id. +// The format is {name_scope}/{repr}/{id}/{name} +static std::string PMNodeName(const std::string& name_scope, + const std::string& repr, size_t id, + const std::string& name) { + std::stringstream ss; + ss << name_scope << "/" << repr << "/" << id << "/" << name; + return ss.str(); +} +// Generate a unique PMNode's name. +// The format is {name_scope}/{repr}/{id} +static std::string PMNodeName(const std::string& name_scope, + const std::string& repr) { + std::stringstream ss; + ss << name_scope << "/" << repr << "/" + << KeyCounter::Instance().IncCounter(repr); + return ss.str(); +} +// Generate a unique key. It can be used for a universally unique temporary +// name. +// The format is {repr}/{id} +static std::string UniqueKey(const std::string& repr) { + std::stringstream ss; + ss << repr << "/" << KeyCounter::Instance().IncCounter(repr); + return ss.str(); +} + +// Declare a PMNode in a pattern, will create two methods: +// std::string xxx_repr(); return this PMNode's string id. +// PMNode* xxx_n(); return the corresponding PMNode. +#define PATTERN_DECL_NODE(name__) \ + std::string name__##_repr() const { \ + return PMNodeName(name_scope_, repr_, id_, #name__); \ + } \ + PMNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); } + +// Get an mir::Node* from the matched subgraph. +// var: variable. +// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition. +// pat: the pattern object. +#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \ + CHECK(subgraph.count(pat.arg##_n())) \ + << "Node not found for PMNode " pat.arg##_repr(); \ + Node* var = subgraph.at(pat.arg##_n()); \ + CHECK(var) << "node " << #arg << "not exists in the sub-graph" + +// The base class of all the patterns. +struct PatternBase { + PatternBase(PMPattern* pattern, const std::string& name_scope, + const std::string& repr) + : pattern(pattern), + name_scope_(name_scope), + repr_(repr), + id_(KeyCounter::Instance().IncCounter(repr)) {} + + PMPattern* pattern; + + protected: + std::string name_scope_; + std::string repr_; + size_t id_; +}; + +} // namespace patterns + +// Link two mir::Nodes from each other. +#define IR_NODE_LINK_TO(a, b) \ + a->outlinks.push_back(b); \ + b->inlinks.push_back(a); + +// Set the out_var as the output of the op +#define IR_OP_VAR_LINK(op, out_var) \ + op->outlinks.push_back(out_var); \ + out_var->inlinks.clear(); \ + out_var->inlinks.push_back(op); + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_tester.cc b/paddle/fluid/lite/core/mir/pattern_matcher_tester.cc new file mode 100644 index 00000000000..3b082060fe2 --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher_tester.cc @@ -0,0 +1,233 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pattern_matcher.h" + +#include + +namespace paddle { +namespace lite { +namespace mir { + +void BuildGraph(SSAGraph* g) { + g->mutable_nodes().emplace_back(); + Node& o1 = g->mutable_nodes().back(); + o1.AsStmt().op_type = "op1"; + g->mutable_nodes().emplace_back(); + Node& o2 = g->mutable_nodes().back(); + o2.AsStmt().op_type = "op2"; + g->mutable_nodes().emplace_back(); + Node& o3 = g->mutable_nodes().back(); + o3.AsStmt().op_type = "op3"; + g->mutable_nodes().emplace_back(); + Node& o4 = g->mutable_nodes().back(); + o4.AsStmt().op_type = "op4"; + g->mutable_nodes().emplace_back(); + Node& o5 = g->mutable_nodes().back(); + o5.AsStmt().op_type = "op5"; + g->mutable_nodes().emplace_back(); + Node& v1 = g->mutable_nodes().back(); + v1.AsArg("var1"); + g->mutable_nodes().emplace_back(); + Node& v2 = g->mutable_nodes().back(); + v2.AsArg("var2"); + g->mutable_nodes().emplace_back(); + Node& v3 = g->mutable_nodes().back(); + v3.AsArg("var3"); + g->mutable_nodes().emplace_back(); + Node& v4 = g->mutable_nodes().back(); + v4.AsArg("var4"); + + // o1->v1->o2 + o1.outlinks.push_back(&v1); + o2.inlinks.push_back(&v1); + v1.inlinks.push_back(&o1); + v1.outlinks.push_back(&o2); + // o2->v2->o3 + // o2->v2->o4 + o2.outlinks.push_back(&v2); + o3.inlinks.push_back(&v2); + o4.inlinks.push_back(&v2); + v2.inlinks.push_back(&o2); + v2.outlinks.push_back(&o3); + v2.outlinks.push_back(&o4); + // o2->v3->o5 + o2.outlinks.push_back(&v3); + o5.inlinks.push_back(&v3); + v3.inlinks.push_back(&o2); + v3.outlinks.push_back(&o5); + // o3-v4->o5 + o3.outlinks.push_back(&v4); + o5.inlinks.push_back(&v4); + v4.inlinks.push_back(&o3); + v4.outlinks.push_back(&o5); +} + +TEST(PMPattern, NewNode) { + PMPattern x; + auto* n = x.NewNode([](const Node* x) { return true; }); + ASSERT_TRUE(n); + ASSERT_EQ(x.nodes_.size(), 1UL); +} + +TEST(PMPattern, AddEdge) { + PMPattern x; + auto* a = x.NewNode([](const Node* x) { return true; }); + auto* b = x.NewNode([](const Node* x) { return true; }); + ASSERT_TRUE(a); + ASSERT_TRUE(b); + x.AddEdge(a, b); + ASSERT_EQ(x.nodes_.size(), 2UL); + ASSERT_EQ(x.edges_.size(), 1UL); + ASSERT_EQ(x.edges_.front().first, a); + ASSERT_EQ(x.edges_.front().second, b); + + ASSERT_EQ(x.nodes().size(), 2UL); + ASSERT_EQ(x.edges().size(), 1UL); + ASSERT_EQ(x.edges().front().first, a); + ASSERT_EQ(x.edges().front().second, b); +} + +TEST(PatternMatcher, MarkPMNodesInGraph) { + PatternMatcher x; + // mark o2, o3, v2 + + // The pattern is a graph: + // o2(a node named o2) -> v2(a node named v2) + // v2 -> o3(a node named o3) + auto* o2 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsStmt() && node->stmt()->op_type == "op2"; + }); + auto* o3 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsStmt() && node->stmt()->op_type == "op3"; + }); + auto* v2 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsArg() && node->arg()->name == "var2"; + }); + + ASSERT_FALSE(o2->Tell(nullptr)); + ASSERT_FALSE(o3->Tell(nullptr)); + ASSERT_FALSE(v2->Tell(nullptr)); + + x.pattern_.AddEdge(o2, v2); + x.pattern_.AddEdge(v2, o3); + + ASSERT_EQ(x.pattern_.edges().size(), 2UL); + ASSERT_EQ(x.pattern_.edges()[0].first, o2); + ASSERT_EQ(x.pattern_.edges()[0].second, v2); + ASSERT_EQ(x.pattern_.edges()[1].first, v2); + ASSERT_EQ(x.pattern_.edges()[1].second, o3); + + SSAGraph graph; + BuildGraph(&graph); + + x.MarkPMNodesInGraph(&graph); + + ASSERT_EQ(x.pmnodes2nodes_.size(), 3UL); + + auto subgraphs = x.DetectPatterns(); + ASSERT_EQ(subgraphs.size(), 1UL); +} + +TEST(PatternMatcher, MultiSubgraph) { + SSAGraph graph; + BuildGraph(&graph); + + PatternMatcher x; + + // The pattern is a graph: + // op -> var + auto* any_op = x.mutable_pattern()->NewNode( + [](const Node* node) { + return node->IsStmt() && (node->stmt()->op_type == "op2" || + node->stmt()->op_type == "op3"); + }, + "OP0"); + auto* any_var = + x.mutable_pattern() + ->NewNode([](const Node* node) { return node->IsArg(); }, "VAR") + ->AsIntermediate(); + auto* any_op1 = x.mutable_pattern()->NewNode( + [](const Node* node) { return node->IsStmt(); }, "OP1"); + + x.mutable_pattern()->AddEdge(any_op, any_var); + x.mutable_pattern()->AddEdge(any_var, any_op1); + + int count = 0; + PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s, + SSAGraph* g) { + LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> " + << s.at(any_var)->arg()->name << " -> " + << s.at(any_op1)->stmt()->op_type; + count++; + }; + + x(&graph, handle); + + // 1. Detect op3 -> var4 -> op5 + // 2. Detect op2 -> var2 -> op3 + // 3. Detect op2 -> var2 -> op4 + // 4. Detect op2 -> var3 -> op5 + // But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2 + ASSERT_GE(count, 1); + ASSERT_LE(count, 2); +} + +TEST(PatternMatcher, IntermediateCheck) { + SSAGraph graph; + BuildGraph(&graph); + + // o2->v2->o3 + // o2->v2->o4 + // check o2+o3 fuse, should fail because v2 also link to o4. + PatternMatcher matcher; + auto* op2 = matcher.mutable_pattern()->NewNode( + [](const Node* x) { + return x && x->IsStmt() && x->stmt()->op_type == "op2"; + }, + "op2"); + auto* op3 = matcher.mutable_pattern()->NewNode( + [](const Node* x) { + return x && x->IsStmt() && x->stmt()->op_type == "op3"; + }, + "op3"); + auto* v2 = matcher.mutable_pattern() + ->NewNode( + [](const Node* x) { + return x && x->IsArg() && x->arg()->name == "var2"; + }, + "var2") + ->AsIntermediate(); + v2->LinksFrom({op2}).LinksTo({op3}); + + int count = 0; + matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) { + ++count; + }); + EXPECT_EQ(count, 0); + + count = 0; + v2->AsInput(); + matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) { + ++count; + }); + ASSERT_EQ(count, 1); +} + +} // namespace mir +} // namespace lite +} // namespace paddle -- GitLab