提交 8eaea744 编写于 作者: G Giancarlo Colmenares

Added a Pattern Matcher class to help with future optimization...

Added a Pattern Matcher class to help with future optimization implementations. Includes changes to barnch_culling to show how to use the new Pattern Matcher infrastructure.
上级 5cba231b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
#include <tuple>
#include <vector>
#include "ir/anf.h"
#include "operator/ops.h"
namespace mindspore {
///
/// Base class for all recognizable patterns.
/// We implement an Expression Template approach using static polymorphism based on
/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
/// to the use of virtual functions without the costs..." as described in:
/// https://en.wikipedia.org/wiki/Expression_templates and
/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
/// The TryCapture function tries to capture the pattern with the given node.
/// The GetNode function builds a new node using the captured values.
///
template <typename T>
class PBase {
public:
const T &get_object() const { return *static_cast<const T *>(this); }
template <typename TN>
bool TryCapture(const TN &value) const {
get_object().Reset();
return get_object().TryCapture_(value);
}
using Internal = T;
};
template <typename T>
class PIsEqual {
public:
bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; }
};
template <typename T>
class PatternNode : public PBase<PatternNode<T> > {
public:
T GetNode(const AnfNodePtr &node) const {
if (!captured_) {
MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
}
return captured_node_;
}
bool TryCapture_(const T &node) const {
if (!captured_) {
captured_node_ = node;
captured_ = true;
return true;
}
return PIsEqual<T>()(captured_node_, node);
}
void Reset() const { captured_ = false; }
using Internal = const PatternNode<T> &;
protected:
mutable T captured_node_;
mutable bool captured_{false};
};
template <typename T, typename T2>
class PBinOperation : public PBase<PBinOperation<T, T2> > {
public:
PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {}
AnfNodePtr GetNode(const AnfNodePtr &node) const {
AnfNodePtr lhs = x_.GetNode(node->func_graph());
AnfNodePtr rhs = y_.GetNode(node->func_graph());
AnfNodePtrList list = {prim_->cast<AnfNodePtr>(), lhs, rhs};
return NewCNode(list, node->func_graph());
}
bool TryCapture_(const AnfNodePtr &node) const {
if (IsPrimitiveCNode(node, prim_)) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (inputs.size() == 3) {
// Binary Prim assumes only two inputs
if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) {
return false;
}
return true;
}
}
return false;
}
void Reset() const {
x_.Reset();
y_.Reset();
}
private:
const PrimitivePtr prim_;
typename T::Internal x_;
typename T2::Internal y_;
};
///
/// Helper functions to apply a pattern function on all elements of a tuple
///
namespace tuple_utils {
template <bool stop, size_t Index, typename Func>
struct apply_func_tuple_item {
template <typename TTuple>
static void apply(Func *func, const TTuple &tuple) {
(*func)(Index, std::get<Index>(tuple));
apply_func_tuple_item<(Index + 1) == std::tuple_size<TTuple>::value, (Index + 1), Func>::apply(func, tuple);
}
};
template <size_t Index, typename Func>
struct apply_func_tuple_item<true, Index, Func> {
template <typename TTuple>
static void apply(Func *func, const TTuple &tuple) {}
};
template <typename Func, typename TTuple>
inline void apply_func_tuple(Func *func, const TTuple &tuple) {
apply_func_tuple_item<std::tuple_size<TTuple>::value == 0, 0, Func>::apply(func, tuple);
}
struct PTupleResetCapture {
template <typename T>
void operator()(size_t i, const T &pattern) const {
pattern.Reset();
}
};
struct PTupleCapture {
explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {}
template <typename TPattern>
void operator()(size_t i, const TPattern &pattern) {
// Check if the first node is a Primitive
if (i == 0 && tuple_[i]->isa<Primitive>()) {
auto prim = tuple_[i]->cast<PrimitivePtr>();
if (tuple_[i] != pattern.GetNode(tuple_[i])) {
captured_ = false;
}
} else {
captured_ = captured_ && pattern.TryCapture_(tuple_[i]);
}
}
const AnfNodePtrList tuple_;
bool captured_{true};
};
struct PTupleGetNode {
explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {}
template <typename TPattern>
void operator()(size_t, const TPattern &pattern) {
args_.push_back(pattern.GetNode(node_));
}
const AnfNodePtr &node_;
std::vector<AnfNodePtr> args_;
};
} // namespace tuple_utils
template <typename... TArgs>
class PCNode : public PBase<PCNode<TArgs...> > {
public:
explicit PCNode(const TArgs &... args) : args_(args...) {}
AnfNodePtr GetNode(const AnfNodePtr &node) const {
tuple_utils::PTupleGetNode get_node(node);
tuple_utils::apply_func_tuple(&get_node, args_);
return NewCNode(get_node.args_, node->func_graph());
}
bool TryCapture_(const AnfNodePtr &node) const {
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (inputs.size() != sizeof...(TArgs)) {
return false;
}
tuple_utils::PTupleCapture capture_func(inputs);
tuple_utils::apply_func_tuple(&capture_func, args_);
return capture_func.captured_;
}
return false;
}
void Reset() const {
tuple_utils::PTupleResetCapture reset;
tuple_utils::apply_func_tuple(&reset, args_);
}
private:
std::tuple<typename TArgs::Internal...> args_;
};
template <typename... TArgs>
class PPrimitive : public PBase<PPrimitive<TArgs...> > {
public:
explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}
AnfNodePtr GetNode(const AnfNodePtr &node) const {
tuple_utils::PTupleGetNode get_node(node);
tuple_utils::apply_func_tuple(&get_node, args_);
auto prim_cnode = get_node.args_;
prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_));
return NewCNode(prim_cnode, node->func_graph());
}
bool TryCapture_(const AnfNodePtr &node) const {
if (IsPrimitiveCNode(node, prim_)) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if ((inputs.size() - 1) != sizeof...(TArgs)) {
return false;
}
AnfNodePtrList rest(inputs.begin() + 1, inputs.end());
tuple_utils::PTupleCapture capture_func(rest);
tuple_utils::apply_func_tuple(&capture_func, args_);
return capture_func.captured_;
}
return false;
}
void Reset() const {
tuple_utils::PTupleResetCapture reset;
tuple_utils::apply_func_tuple(&reset, args_);
}
private:
const PrimitivePtr prim_;
std::tuple<typename TArgs::Internal...> args_;
};
// Macro for binary operation functions
#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \
template <typename T, typename T2> \
inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \
return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \
}
// Arithmetic operations
BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd);
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul);
// Macros for match and replace
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \
if ((CaptureNode).TryCapture(OrigNode)) { \
return (ReplaceWith).GetNode(OrigNode); \
}
#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
return (ReplaceWith).GetNode(OrigNode); \
}
#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \
if ((CaptureNode).TryCapture(OrigNode)) { \
if ((Condition)) { \
return (ReplaceWith).GetNode(OrigNode); \
} \
return (ElseNode).GetNode(OrigNode); \
}
#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \
if ((CaptureNode).TryCapture(OrigNode)) { \
return (Lambda)(); \
}
#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
return (Lambda)(); \
}
} // namespace mindspore
#endif // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
......@@ -26,141 +26,61 @@
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSwitch, true, X, Y}
// {prim::kPrimSwitch, false, X, Y}
class SwitchSimplify : public AnfVisitor {
class SwitchSimplify {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
auto getx = [this](const AnfNodePtr &node) -> bool {
this->x_ = node;
return true;
};
auto gety = [this](const AnfNodePtr &node) -> bool {
this->y_ = node;
return true;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> cond, true_br, false_br;
auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
if (cond_value_) {
return true_br.GetNode(node);
}
return false_br.GetNode(node);
};
AnfVisitor::Match(prim::kPrimSwitch, {IsValueNode<BoolImm>, getx, gety})(node);
// simplify the switch
if (is_match_) {
if (cond_) {
return x_;
}
return y_;
}
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda,
IsValueNode<BoolImm>(cond.GetNode(node)));
return nullptr;
}
void Visit(const AnfNodePtr &node) override {
if (!is_match_ && IsValueNode<BoolImm>(node)) {
cond_ = GetValue<bool>(GetValueNode(node));
is_match_ = true;
}
}
void Reset() {
x_ = nullptr;
y_ = nullptr;
cond_ = false;
is_match_ = false;
}
private:
bool is_match_{false}, cond_{false};
AnfNodePtr x_{nullptr}, y_{nullptr};
};
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
class FloatTupleGetItemSwitch : public AnfVisitor {
class FloatTupleGetItemSwitch {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
auto fg = node->func_graph();
if (Xs_.empty() || c_ == nullptr || fg == nullptr) {
return nullptr;
}
auto true_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[1], c_});
auto false_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[2], c_});
return fg->NewCNode({NewValueNode(prim::kPrimSwitch), Xs_[0], true_node, false_node});
}
void Visit(const CNodePtr &cnode) override {
// {prim::kPrimSwith, X1, X2, X3}
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch) || cnode->size() != 4) {
return;
}
// copy X1, X2, X3
auto &inputs = cnode->inputs();
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
}
void Visit(const ValueNodePtr &vnode) override { c_ = vnode; }
void Reset() {
Xs_.clear();
c_ = nullptr;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> cond, true_br, false_br, x;
MATCH_REPLACE_IF(node,
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x),
PPrimitive(prim::kPrimTupleGetItem, false_br, x)),
IsVNode(x.GetNode(node)));
return nullptr;
}
private:
AnfNodePtr c_{nullptr};
std::vector<AnfNodePtr> Xs_{};
};
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
class FloatEnvGetItemSwitch : public AnfVisitor {
class FloatEnvGetItemSwitch {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false;
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsNode, IsNode})(node);
if (!is_match_) {
return nullptr;
}
// {prim::kPrimEnvGetItem, {...}, X4, X5}
auto cnode = node->cast<CNodePtr>();
auto sw_node = cnode->input(1)->cast<CNodePtr>();
auto x4 = cnode->input(2);
auto x5 = cnode->input(3);
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
MATCH_REPLACE_IF(node,
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)),
IsNode(x.GetNode(node)) && IsNode(x2.GetNode(node)));
is_match_ = false;
AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsNode, IsNode})(sw_node);
if (!is_match_) {
return nullptr;
}
// {prim::kPrimSwitch, X1, X2, X3}
auto x1 = sw_node->input(1);
auto x2 = sw_node->input(2);
auto x3 = sw_node->input(3);
auto fg = node->func_graph();
if (fg == nullptr) {
return nullptr;
}
auto true_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x2, x4, x5});
auto false_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x3, x4, x5});
return fg->NewCNode({NewValueNode(prim::kPrimSwitch), x1, true_node, false_node});
return nullptr;
}
void Visit(const AnfNodePtr &) override { is_match_ = true; }
private:
bool is_match_{false};
};
namespace internal {
......@@ -173,79 +93,64 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN
} // namespace internal
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
class ConvertSwitchReplacement : public AnfVisitor {
class ConvertSwitchReplacement {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
Reset();
auto cnode = node->cast<CNodePtr>();
if (cnode->size() < 1) {
auto cnode_ = node->cast<CNodePtr>();
if (cnode_->size() < 1) {
return nullptr;
}
// {prim::kPrimSwitch, X, G1, G2}
AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(cnode->input(0));
if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) {
return nullptr;
}
// for switch replace method, only graphs without graph inside can be replaced
for (auto &item : g1_->value_nodes()) {
auto value_node = item.first;
if (IsValueNode<FuncGraph>(value_node)) {
return nullptr;
auto node_ = cnode_->input(0);
PatternNode<AnfNodePtr> cond, true_br, false_br;
auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr {
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node_));
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node_));
auto x_ = cond.GetNode(node_);
// for switch replace method, only graphs without graph inside can be replaced
for (auto &item : g1_->value_nodes()) {
auto value_node = item.first;
if (IsValueNode<FuncGraph>(value_node)) {
return nullptr;
}
}
}
for (auto &item : g2_->value_nodes()) {
auto value_node = item.first;
if (IsValueNode<FuncGraph>(value_node)) {
return nullptr;
for (auto &item : g2_->value_nodes()) {
auto value_node = item.first;
if (IsValueNode<FuncGraph>(value_node)) {
return nullptr;
}
}
}
auto true_output = g1_->output()->abstract();
auto false_output = g2_->output()->abstract();
auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_);
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);
std::vector<AnfNodePtr> params;
auto fg = node->func_graph();
auto cloned_g1 = InlineClone(trans_g1, fg, params);
auto cloned_g2 = InlineClone(trans_g2, fg, params);
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
return nnode;
}
auto true_output = g1_->output()->abstract();
auto false_output = g2_->output()->abstract();
auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_);
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);
void Visit(const AnfNodePtr &node) override {
if (x_ == nullptr) {
x_ = node;
return;
}
AnfVisitor::Visit(node);
}
std::vector<AnfNodePtr> params;
auto fg = node_->func_graph();
auto cloned_g1 = InlineClone(trans_g1, fg, params);
auto cloned_g2 = InlineClone(trans_g2, fg, params);
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
void Visit(const ValueNodePtr &vnode) override {
auto g = GetValueNode<FuncGraphPtr>(vnode);
if (g1_ == nullptr) {
g1_ = g;
} else {
g2_ = g;
}
}
return nnode;
};
void Reset() {
x_ = nullptr;
g1_ = nullptr;
g2_ = nullptr;
}
MATCH_REPLACE_LAMBDA_IF(node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
IsNode(cond.GetNode(node_)) && IsValueNode<FuncGraph>(true_br.GetNode(node_)) &&
IsValueNode<FuncGraph>(false_br.GetNode(node_)));
private:
AnfNodePtr x_{nullptr};
FuncGraphPtr g1_{nullptr}, g2_{nullptr};
return nullptr;
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册