Commit b2ec296f authored by zhousiyi

add opt pass for tuple_getitem with constant input

Parent 5b14292f
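At the IR level, the new GetitemConstEliminater introduced by this commit is plain constant folding: when both operands of tuple_getitem are value nodes (a constant tuple and a constant index), the node is replaced by the selected element, and an out-of-range index leaves the node untouched — exactly the case the new unit test checks with (1, 2)[1] => 2. The stand-alone C++ sketch below only illustrates that decision logic; ConstTupleGetItem and FoldGetItem are made-up names for the example, not MindSpore APIs.

#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

// Toy stand-in for a tuple_getitem node whose inputs are both constants:
// a constant tuple plus a constant index.
struct ConstTupleGetItem {
  std::vector<int> tuple;  // the constant tuple operand
  std::size_t index;       // the constant index operand
};

// Fold the node to its element when the index is in range; otherwise leave it
// alone, mirroring the `tuple_->size() > id_` guard in GetitemConstEliminater.
std::optional<int> FoldGetItem(const ConstTupleGetItem &node) {
  if (node.index < node.tuple.size()) {
    return node.tuple[node.index];
  }
  return std::nullopt;  // out of range: no rewrite
}

int main() {
  ConstTupleGetItem ok{{1, 2}, 1};
  ConstTupleGetItem bad{{1, 2}, 5};
  std::cout << "fold (1, 2)[1] -> " << *FoldGetItem(ok) << "\n";  // prints 2
  std::cout << "fold (1, 2)[5] -> " << (FoldGetItem(bad) ? "folded" : "left as-is") << "\n";
  return 0;
}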
@@ -51,6 +51,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                           {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
                                            prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
+  arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul});
   special_op_eliminate_ =
     MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
                      {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
......
@@ -33,6 +33,7 @@ class OptimizeIRPassLib {
   ~OptimizeIRPassLib() = default;
   SubstitutionPtr arithmetic_simplify_;
+  SubstitutionPtr arithmetic_simplify2_;
   SubstitutionPtr special_op_eliminate_;
   SubstitutionPtr zero_like_fill_zero_;
   SubstitutionPtr adjust_all_reduce_mul_add_;
......
@@ -139,76 +139,8 @@ class CheckTensorConstant {
   int check_value_;
 };
 
-// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
-// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
-class TensorMultiplyByZeroOrOne : public AnfVisitor {
- public:
-  TensorMultiplyByZeroOrOne() : zero_(MakeValue(0)) {}
-  ~TensorMultiplyByZeroOrOne() override = default;
-
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    Reset();
-    AnfVisitor::Match(prim::kPrimMul)(node);
-
-    if (is_zero_) {
-      if (x_->func_graph() != node->func_graph()) {
-        return nullptr;
-      }
-      return NewTensorFilledWithData(node);
-    }
-    if (is_one_) {
-      return NewTensorFilledWithData(node, x_);
-    }
-    return nullptr;
-  }
-
-  void Visit(const AnfNodePtr &node) override {
-    if (is_zero_ || is_one_) {
-      x_ = node;
-      return;
-    }
-    if (IsParam(node)) {
-      x_ = node;
-      return;
-    }
-    if (IsCNode(node)) {
-      CNodePtr cnode = node->cast<CNodePtr>();
-      if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) {
-        is_zero_ = true;
-        return;
-      }
-      x_ = node;
-      return;
-    }
-
-    auto value = node->cast<ValueNodePtr>()->value();
-    if (CheckTensorConstant(0).IsTensorConstant(value)) {
-      is_zero_ = true;
-      return;
-    } else if (CheckTensorConstant(1).IsTensorConstant(value)) {
-      is_one_ = true;
-      return;
-    }
-    x_ = node;
-  }
-
-  void Visit(const ValueNodePtr &vnode) override {
-    auto value = vnode->value();
-    if (CheckTensorConstant(0).IsTensorConstant(value)) {
-      is_zero_ = true;
-      return;
-    } else if (CheckTensorConstant(1).IsTensorConstant(value)) {
-      is_one_ = true;
-      return;
-    }
-    x_ = vnode;
-  }
-
-  void Reset() {
-    x_ = nullptr;
-    is_one_ = false;
-    is_zero_ = false;
-  }
-
+class TensorMultiplyBase : public AnfVisitor {
+ protected:
   void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) {
     if (!node->isa<ValueNode>()) {
       return nullptr;
@@ -287,10 +219,122 @@ class TensorMultiplyByZeroOrOne : public AnfVisitor {
     return new_vnode;
   }
 
- private:
-  bool is_zero_{false}, is_one_{false};
-  ValuePtr zero_;
-  AnfNodePtr x_{nullptr};
-};
+  AnfNodePtr x_{nullptr};
+};
+
+// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
+class TensorMultiplyByZero : public TensorMultiplyBase {
+ public:
+  TensorMultiplyByZero() : zero_(MakeValue(0)) {}
+  ~TensorMultiplyByZero() override = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimMul)(node);
+
+    if (is_zero_) {
+      if (x_->func_graph() != node->func_graph()) {
+        return nullptr;
+      }
+      return NewTensorFilledWithData(node);
+    }
+    return nullptr;
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (is_zero_) {
+      x_ = node;
+      return;
+    }
+    if (IsParam(node)) {
+      x_ = node;
+      return;
+    }
+    if (IsCNode(node)) {
+      CNodePtr cnode = node->cast<CNodePtr>();
+      if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) {
+        is_zero_ = true;
+        return;
+      }
+      x_ = node;
+      return;
+    }
+
+    auto value = node->cast<ValueNodePtr>()->value();
+    if (CheckTensorConstant(0).IsTensorConstant(value)) {
+      is_zero_ = true;
+      return;
+    }
+    x_ = node;
+  }
+
+  void Visit(const ValueNodePtr &vnode) override {
+    auto value = vnode->value();
+    if (CheckTensorConstant(0).IsTensorConstant(value)) {
+      is_zero_ = true;
+      return;
+    }
+    x_ = vnode;
+  }
+
+  void Reset() {
+    x_ = nullptr;
+    is_zero_ = false;
+  }
+
+ private:
+  bool is_zero_{false};
+  ValuePtr zero_;
+};
+
+// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
+class TensorMultiplyByOne : public TensorMultiplyBase {
+ public:
+  TensorMultiplyByOne() {}
+  ~TensorMultiplyByOne() override = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimMul)(node);
+
+    if (is_one_) {
+      return NewTensorFilledWithData(node, x_);
+    }
+    return nullptr;
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (is_one_) {
+      x_ = node;
+      return;
+    }
+    if (IsParam(node) || IsCNode(node)) {
+      x_ = node;
+      return;
+    }
+
+    auto value = node->cast<ValueNodePtr>()->value();
+    if (CheckTensorConstant(1).IsTensorConstant(value)) {
+      is_one_ = true;
+      return;
+    }
+    x_ = node;
+  }
+
+  void Visit(const ValueNodePtr &vnode) override {
+    auto value = vnode->value();
+    if (CheckTensorConstant(1).IsTensorConstant(value)) {
+      is_one_ = true;
+      return;
+    }
+    x_ = vnode;
+  }
+
+  void Reset() {
+    x_ = nullptr;
+    is_one_ = false;
+  }
+
+ private:
+  bool is_one_{false};
+};
 
 // {prim::kPrimScalarAdd, X, 0}
@@ -699,7 +743,7 @@ class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
       : multiply_by_zero_or_one_(),
-        tensor_multiply_by_zero_or_one_(),
+        tensor_multiply_by_one_(),
         add_by_zero_(),
         tensor_add_by_zero_(),
         identity_(prim::kPrimIdentity),
@@ -707,7 +751,7 @@ class ArithmeticSimplify {
         constant_duplicate_mul_(),
         power_one_() {
     eliminaters_.emplace_back(multiply_by_zero_or_one_);
-    eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_);
+    eliminaters_.emplace_back(tensor_multiply_by_one_);
     eliminaters_.emplace_back(add_by_zero_);
     eliminaters_.emplace_back(tensor_add_by_zero_);
     eliminaters_.emplace_back(identity_);
@@ -730,7 +774,7 @@ class ArithmeticSimplify {
  private:
   MultiplyByZeroOrOne multiply_by_zero_or_one_;
-  TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_;
+  TensorMultiplyByOne tensor_multiply_by_one_;
   AddByZero add_by_zero_;
   TensorAddByZero tensor_add_by_zero_;
   PrimEliminater identity_;
@@ -739,6 +783,32 @@ class ArithmeticSimplify {
   PowerOneEliminate power_one_;
   std::vector<TransformFuncType> eliminaters_{};
 };
+
+// Arithmetic simplifications of this kind should be done after step_parallel.
+// For example, Mul(0, weight), where weight is a parameter, is simplified to a constant tensor
+// with the shape of weight; but step_parallel may change the shape of weight, and the shape of
+// the folded constant tensor would then have to change as well. That is why this pass is
+// separated from ArithmeticSimplify and deferred until after step_parallel.
+class ArithmeticSimplify2 {
+ public:
+  ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); }
+  ~ArithmeticSimplify2() = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
+    AnfNodePtr new_node;
+    for (auto &eliminater : eliminaters_) {
+      new_node = eliminater(optimizer, node);
+      if (new_node != nullptr) {
+        return new_node;
+      }
+    }
+    return nullptr;
+  }
+
+ private:
+  TensorMultiplyByZero tensor_multiply_by_zero_;
+  std::vector<TransformFuncType> eliminaters_{};
+};
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
......
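The comment on ArithmeticSimplify2 above describes a shape-staleness hazard: folding Mul(0, weight) materializes a zero tensor with weight's current shape, and step_parallel may later reslice that parameter, leaving an eagerly folded constant with the wrong shape. The minimal stand-alone sketch below only illustrates that hazard; Tensor and FoldMulByZero are made-up names, not MindSpore types.

#include <cassert>
#include <vector>

// Toy "tensor": only its shape matters for this illustration.
struct Tensor {
  std::vector<int> shape;
};

// Folding Mul(0, x) means materializing zeros with x's *current* shape.
Tensor FoldMulByZero(const Tensor &x) { return Tensor{x.shape}; }

int main() {
  Tensor weight{{8, 16}};
  Tensor folded_early = FoldMulByZero(weight);  // shape captured before parallel slicing

  // Suppose step_parallel later slices the parameter across 4 devices along dim 0.
  weight.shape = {2, 16};
  Tensor folded_late = FoldMulByZero(weight);   // shape captured after slicing

  assert(folded_late.shape == weight.shape);    // deferred fold stays consistent
  assert(folded_early.shape != weight.shape);   // early fold is now stale
  return 0;
}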
@@ -70,6 +70,45 @@ class GetitemEliminater : public AnfVisitor {
   CNodePtr tuple_{nullptr};
 };
 
+// (a, b, c, ...)[0] => a
+// (a, b, c, ...)[1] => b
+// {prim::kPrimTupleGetItem, C1, C}
+class GetitemConstEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node);
+
+    if (is_match_) {
+      return NewValueNode((*tuple_)[id_]);
+    }
+    return nullptr;
+  }
+
+  void Visit(const ValueNodePtr &vnode) override {
+    if (IsValueNode<ValueTuple>(vnode)) {
+      tuple_ = GetValueNode<ValueTuplePtr>(vnode);
+    }
+    if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) {
+      id_ = IntToSize(GetValue<int>(vnode->value()));
+      if (tuple_->size() > id_) {
+        is_match_ = true;
+      }
+    }
+  }
+
+  void Reset() {
+    id_ = 0;
+    tuple_ = nullptr;
+    is_match_ = false;
+  }
+
+ private:
+  bool is_match_{false};
+  size_t id_{0};
+  ValueTuplePtr tuple_{nullptr};
+};
+
 // setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
 // setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
 // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
@@ -225,8 +264,13 @@ class GetitemDependReorder : public AnfVisitor {
 class ItemTupleEliminater {
  public:
   ItemTupleEliminater()
-      : get_item_eliminater_(), set_item_eliminater_(), get_set_item_eliminater_(), get_item_depend_reorder_() {
+      : get_item_eliminater_(),
+        get_item_const_eliminater_(),
+        set_item_eliminater_(),
+        get_set_item_eliminater_(),
+        get_item_depend_reorder_() {
     eliminaters_.emplace_back(get_item_eliminater_);
+    eliminaters_.emplace_back(get_item_const_eliminater_);
     eliminaters_.emplace_back(set_item_eliminater_);
     eliminaters_.emplace_back(get_set_item_eliminater_);
     eliminaters_.emplace_back(get_item_depend_reorder_);
@@ -246,6 +290,7 @@ class ItemTupleEliminater {
  private:
   GetitemEliminater get_item_eliminater_;
+  GetitemConstEliminater get_item_const_eliminater_;
   SetitemEliminater set_item_eliminater_;
   GetSetitemEliminater get_set_item_eliminater_;
   GetitemDependReorder get_item_depend_reorder_;
......
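ItemTupleEliminater, like ArithmeticSimplify2 above, simply tries its registered eliminaters in order and returns the first non-null rewrite, so the position of get_item_const_eliminater_ in the list fixes its priority among the tuple transforms. A stripped-down sketch of that dispatch pattern follows, using std::function stand-ins rather than the actual TransformFuncType; the node strings are purely illustrative.

#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Stand-ins: a "node" is a string, a transform returns a rewritten node or nullopt.
using Node = std::string;
using Transform = std::function<std::optional<Node>(const Node &)>;

// Try each registered transform in order; the first one that fires wins.
std::optional<Node> ApplyFirstMatch(const std::vector<Transform> &transforms, const Node &node) {
  for (const auto &t : transforms) {
    if (auto rewritten = t(node)) {
      return rewritten;
    }
  }
  return std::nullopt;
}

int main() {
  std::vector<Transform> eliminaters{
      [](const Node &n) -> std::optional<Node> {  // e.g. getitem on a make_tuple result
        return n == "getitem(make_tuple(a,b),0)" ? std::optional<Node>("a") : std::nullopt;
      },
      [](const Node &n) -> std::optional<Node> {  // e.g. getitem on a constant tuple
        return n == "getitem((1,2),1)" ? std::optional<Node>("2") : std::nullopt;
      },
  };
  std::cout << ApplyFirstMatch(eliminaters, "getitem((1,2),1)").value_or("<no rewrite>") << "\n";  // prints 2
  return 0;
}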
@@ -114,6 +114,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.depend_value_elim_,
   });
   opt::OptPassConfig a_3 = opt::OptPassConfig({
+    irpass.arithmetic_simplify2_,
     irpass.same_eliminate_,
     irpass.check_bprop_eliminate_,
     irpass.replace_applicator_,
......
@@ -20,9 +20,12 @@
 #include "common/py_func_graph_fetcher.h"
 #include "ir/anf.h"
+#include "ir/func_graph.h"
 #include "ir/func_graph_cloner.h"
 #include "ir/manager.h"
+#include "ir/value.h"
 #include "ir/visitor.h"
+#include "operator/ops.h"
 #include "optimizer/irpass.h"
 #include "pipeline/resource.h"
 #include "debug/draw.h"
@@ -343,9 +346,26 @@ TEST_F(TestOptLib, test_tuple_getitem) {
   FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_0");
   FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_1");
 
+  FuncGraphPtr make_get_const = std::make_shared<FuncGraph>();
+  auto value_node_1 = NewValueNode(1);
+  auto value_node_2 = NewValueNode(2);
+  std::vector<int> vec{1, 2};
+  auto value_node_tuple = NewValueNode(MakeValue(vec));
+  std::vector<AnfNodePtr> node_list{
+    NewValueNode(prim::kPrimTupleGetItem),
+    value_node_tuple,
+    value_node_1
+  };
+  auto get_item = make_get_const->NewCNode(node_list);
+  make_get_const->set_output(get_item);
+
+  FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
+  after_2->set_output(value_node_2);
+
   auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
   ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
   ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
+  ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
 }
 
 TEST_F(TestOptLib, test_tuple_setitem) {
......