提交 64923214 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4248 Update RowTensorEliminater and IndexedSliceEliminate to Pattern Matcher

Merge pull request !4248 from Giancarlo/pm_update_sparsetensor
......@@ -20,10 +20,10 @@
#include <vector>
#include <algorithm>
#include "frontend/operator/ops.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
......@@ -31,43 +31,16 @@ namespace irpass {
// {prim::kPrimRowTensorGetIndices, {prim::kPrimMakeRowTensor, Xs}}
// {prim::kPrimRowTensorGetValues, {prim::kPrimMakeRowTensor, Xs}}
// {prim::kPrimRowTensorGetDenseShape, {prim::kPrimMakeRowTensor, Xs}}
class RowTensorEliminater : public AnfVisitor {
class RowTensorEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimRowTensorGetIndices, {IsCNode})(node);
if (is_match_) {
return tuple_->input(1);
}
AnfVisitor::Match(prim::kPrimRowTensorGetValues, {IsCNode})(node);
if (is_match_) {
return tuple_->input(2);
}
AnfVisitor::Match(prim::kPrimRowTensorGetDenseShape, {IsCNode})(node);
if (is_match_) {
return tuple_->input(3);
}
PatternNode x, y, z;
auto slices = PPrimitive(prim::kPrimMakeRowTensor, x, y, z).MinExtraNodes(0);
MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetIndices, slices), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetValues, slices), y);
MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorGetDenseShape, slices), z);
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeRowTensor)) {
tuple_ = cnode;
is_match_ = true;
}
}
void Reset() {
tuple_ = nullptr;
is_match_ = false;
}
private:
bool is_match_{false};
CNodePtr tuple_{nullptr};
};
} // namespace irpass
} // namespace opt
......
......@@ -20,10 +20,10 @@
#include <vector>
#include <algorithm>
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
......@@ -31,43 +31,16 @@ namespace irpass {
// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
class SparseTensorEliminater : public AnfVisitor {
class SparseTensorEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);
if (is_match_) {
return tuple_->input(1);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);
if (is_match_) {
return tuple_->input(2);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);
if (is_match_) {
return tuple_->input(3);
}
PatternNode x, y, z;
auto sparse = PPrimitive(prim::kPrimMakeSparseTensor, x, y, z).MinExtraNodes(0);
MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetIndices, sparse), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetValues, sparse), y);
MATCH_REPLACE(node, PPrimitive(prim::kPrimSparseTensorGetDenseShape, sparse), z);
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
tuple_ = cnode;
is_match_ = true;
}
}
void Reset() {
tuple_ = nullptr;
is_match_ = false;
}
private:
bool is_match_{false};
CNodePtr tuple_{nullptr};
};
} // namespace irpass
} // namespace opt
......
......@@ -372,6 +372,8 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
return *this;
}
const AnfNodePtrList &GetCapturedExtraNodes() const { return extra_nodes_; }
/// Returns the FuncGraph of the original node captured by this Primitive Pattern.
/// Throws exception if a node was not captured before.
FuncGraphPtr GetFuncGraph() const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册