diff --git a/mindspore/ccsrc/ir/pattern_matcher.h b/mindspore/ccsrc/ir/pattern_matcher.h index e955c0afd41b6b347d750b2b0c9e26f1a4a0d413..6605b9ce4c81e2163cc35a0a4b03528ddeb64ef8 100644 --- a/mindspore/ccsrc/ir/pattern_matcher.h +++ b/mindspore/ccsrc/ir/pattern_matcher.h @@ -39,6 +39,10 @@ namespace mindspore { template class PBase { public: + bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) { + return func(get_object().GetNode(node)); + } + const T &get_object() const { return *static_cast(this); } template diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/optimizer/irpass/branch_culling.h index bb5e0218864172a49089686dbf05bf80d707f306..2b5b30bdbfd9b0b1256f7cb7951e843af0e3d0d6 100644 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/optimizer/irpass/branch_culling.h @@ -45,7 +45,7 @@ class SwitchSimplify : public OptimizerCaller { }; MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda, - IsValueNode(cond.GetNode(node))); + cond.CheckFunc(IsValueNode, node)); return nullptr; } @@ -61,7 +61,7 @@ class FloatTupleGetItemSwitch : public OptimizerCaller { PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x), PPrimitive(prim::kPrimTupleGetItem, false_br, x)), - IsVNode(x.GetNode(node))); + x.CheckFunc(IsVNode, node)); return nullptr; } }; @@ -72,11 +72,10 @@ class FloatEnvGetItemSwitch : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode cond, true_br, false_br, x, x2; - MATCH_REPLACE_IF(node, - PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), - PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), - PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)), - IsNode(x.GetNode(node)) && IsNode(x2.GetNode(node))); + MATCH_REPLACE(node, + PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), + PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), + PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2))); return nullptr; } @@ -142,9 +141,9 @@ class ConvertSwitchReplacement : public OptimizerCaller { return nnode; }; - MATCH_REPLACE_LAMBDA_IF(node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, - IsNode(cond.GetNode(node_)) && IsValueNode(true_br.GetNode(node_)) && - IsValueNode(false_br.GetNode(node_))); + MATCH_REPLACE_LAMBDA_IF( + node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, + true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); return nullptr; } diff --git a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h index 8d700ec7f8c0c2af18ee8967a2b301e0b344fdab..41f379221c61aae8a2de07f229cf09612aa5920b 100644 --- a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h @@ -21,109 +21,70 @@ #include "optimizer/optimizer.h" #include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "utils/graph_utils.h" -#include "operator/composite/composite.h" +#include "ir/pattern_matcher.h" namespace mindspore { namespace opt { namespace irpass { // {prim::kPrimMakeRef, X, Y, Z} -> Y -class MakeRefEliminater : public AnfVisitor { +class MakeRefEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - y_ = nullptr; - auto gety = [this](const AnfNodePtr &node) -> bool { - this->y_ = node; - return true; - }; - - AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node); - return y_; + PatternNode x, y, z; + MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y); + return nullptr; } - - void Visit(const AnfNodePtr &) override {} - - private: - AnfNodePtr y_{nullptr}; }; // {prim::kPrimGetRefValue, Parameter} -> Parameter // {prim::kPrimGetRefOrigin, Parameter} -> Parameter -class GetRefParamEliminater : public AnfVisitor { +class GetRefParamEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node); - if (x_ != nullptr) { - return x_; - } - AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node); - return x_; + PatternNode x; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node)); + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node)); + return nullptr; } - - void Visit(const AnfNodePtr &node) override { x_ = node; } - - private: - AnfNodePtr x_{nullptr}; }; // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y // {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z -class GetMakeRefEliminater : public AnfVisitor { +class GetMakeRefEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - auto cnode = node->cast(); - if (cnode == nullptr || cnode->size() != 2) { - return nullptr; - } - - // {prim::kPrimGetRefKey/Value, {...}} - auto ref = cnode->input(1)->cast(); - if (ref == nullptr || !ref->IsApply(prim::kPrimMakeRef) || ref->size() != 4) { - return nullptr; - } - - // {prim::kPrimMakeRef, X, Y, Z} - if (cnode->IsApply(prim::kPrimGetRefKey)) { - return ref->input(1); - } - - if (cnode->IsApply(prim::kPrimGetRefValue)) { - return ref->input(2); - } - - if (cnode->IsApply(prim::kPrimGetRefOrigin)) { - return ref->input(3); - } + PatternNode x, y, z; + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); return nullptr; } }; // IsValueNode -class ReplaceRefkeyByParam : public AnfVisitor { +class ReplaceRefkeyByParam : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - if (!IsValueNode(node)) { - return nullptr; - } - - auto refkey = GetValueNode(node); - auto resource = std::dynamic_pointer_cast(optimizer->resource()); - MS_EXCEPTION_IF_NULL(resource); - - auto top_graph = resource->func_graph(); - MS_EXCEPTION_IF_NULL(top_graph); - - for (const auto &tnode : top_graph->parameters()) { - auto para = tnode->cast(); - if (para != nullptr && para->name() == refkey->tag()) { - return para; + auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr { + auto refkey = GetValueNode(node); + auto resource = std::dynamic_pointer_cast(optimizer->resource()); + MS_EXCEPTION_IF_NULL(resource); + + auto top_graph = resource->func_graph(); + MS_EXCEPTION_IF_NULL(top_graph); + + for (const auto &tnode : top_graph->parameters()) { + auto para = tnode->cast(); + if (para != nullptr && para->name() == refkey->tag()) { + return para; + } } - } + return nullptr; + }; + PatternNode x; + MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode, node)); return nullptr; } };