diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h index c237bec0ece1c61315cb7dbf263b207920a56002..abfc54327a85407907eaec64f623c3cd6c184c79 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h @@ -20,10 +20,10 @@ #include #include +#include "frontend/operator/ops.h" +#include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" -#include "frontend/optimizer/anf_visitor.h" -#include "frontend/operator/ops.h" namespace mindspore { namespace opt { @@ -31,43 +31,16 @@ namespace irpass { // {prim::kPrimRowTensorGetIndices, {prim::kPrimMakeRowTensor, Xs}} // {prim::kPrimRowTensorGetValues, {prim::kPrimMakeRowTensor, Xs}} // {prim::kPrimRowTensorGetDenseShape, {prim::kPrimMakeRowTensor, Xs}} -class RowTensorEliminater : public AnfVisitor { +class RowTensorEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimRowTensorGetIndices, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(1); - } - AnfVisitor::Match(prim::kPrimRowTensorGetValues, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(2); - } - AnfVisitor::Match(prim::kPrimRowTensorGetDenseShape, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(3); - } + PatternNode x, y, z; + auto slices = PPrimitive(prim::kPrimMakeRowTensor, x, y, z).MinExtraNodes(0); + MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetIndices, slices), x); + MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetValues, slices), y); + MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetDenseShape, slices), z); return nullptr; } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeRowTensor)) { - tuple_ = cnode; - is_match_ = true; - } - } - - void Reset() { - tuple_ = nullptr; - is_match_ = false; - } - - private: - bool is_match_{false}; - CNodePtr tuple_{nullptr}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h index 07fb4e80b1f4413ef8e33724b9a20b4060fc3fea..c0a7bb5cf70a687fd91b27add158331653ee71f6 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h @@ -20,10 +20,10 @@ #include #include +#include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" #include "ir/visitor.h" -#include "frontend/operator/ops.h" namespace mindspore { namespace opt { @@ -31,43 +31,16 @@ namespace irpass { // {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}} // {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}} // {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}} -class SparseTensorEliminater : public AnfVisitor { +class SparseTensorEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(1); - } - AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(2); - } - AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(3); - } + PatternNode x, y, z; + auto sparse = PPrimitive(prim::kPrimMakeSparseTensor, x, y, z).MinExtraNodes(0); + MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetIndices, sparse), x); + MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetValues, sparse), y); + MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetDenseShape, sparse), z); return nullptr; } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) { - tuple_ = cnode; - is_match_ = true; - } - } - - void Reset() { - tuple_ = nullptr; - is_match_ = false; - } - - private: - bool is_match_{false}; - CNodePtr tuple_{nullptr}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 1ed559d656572285c2dd2851cafbf57df9ccd134..b04fbed11f4d8120b197f0d2cbf673d51bece3a2 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -372,6 +372,8 @@ class PPrimitive : public PBase > { return *this; } + const AnfNodePtrList &GetCapturedExtraNodes() const { return extra_nodes_; } + /// Returns the FuncGraph of the original node captured by this Primitive Pattern. /// Throws exception if a node was not captured before. FuncGraphPtr GetFuncGraph() const {