diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 72177ccb06d10e0e42c3bf6406d096c62d68e7a7..0033e386d8a2b2a488473d0d25ca37e18b1419ac 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -51,6 +51,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); + arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h index 5e1550c883a52f44437743c114900684050906b8..fa4d1e4cae49cba0b202518298b5213aa73f37bd 100644 --- a/mindspore/ccsrc/optimizer/irpass.h +++ b/mindspore/ccsrc/optimizer/irpass.h @@ -33,6 +33,7 @@ class OptimizeIRPassLib { ~OptimizeIRPassLib() = default; SubstitutionPtr arithmetic_simplify_; + SubstitutionPtr arithmetic_simplify2_; SubstitutionPtr special_op_eliminate_; SubstitutionPtr zero_like_fill_zero_; SubstitutionPtr adjust_all_reduce_mul_add_; diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 1836a88dbcdc8859d309e71df4855282ed645e69..ae44ec1f7dc84408182d873edd4e424da821f8e2 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -139,76 +139,8 @@ class CheckTensorConstant { int check_value_; }; -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -class TensorMultiplyByZeroOrOne : public AnfVisitor { - public: - TensorMultiplyByZeroOrOne() : 
zero_(MakeValue(0)) {} - ~TensorMultiplyByZeroOrOne() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_zero_) { - if (x_->func_graph() != node->func_graph()) { - return nullptr; - } - return NewTensorFilledWithData(node); - } - if (is_one_) { - return NewTensorFilledWithData(node, x_); - } - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (is_zero_ || is_one_) { - x_ = node; - return; - } - - if (IsParam(node)) { - x_ = node; - return; - } - - if (IsCNode(node)) { - CNodePtr cnode = node->cast(); - if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { - is_zero_ = true; - return; - } - x_ = node; - return; - } - auto value = node->cast()->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } else if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = node; - } - - void Visit(const ValueNodePtr &vnode) override { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } else if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = vnode; - } - void Reset() { - x_ = nullptr; - is_one_ = false; - is_zero_ = false; - } - +class TensorMultiplyBase : public AnfVisitor { + protected: void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) { if (!node->isa()) { return nullptr; @@ -287,10 +219,122 @@ class TensorMultiplyByZeroOrOne : public AnfVisitor { return new_vnode; } + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} +class TensorMultiplyByZero : public TensorMultiplyBase { + public: + TensorMultiplyByZero() : zero_(MakeValue(0)) {} + ~TensorMultiplyByZero() override = default; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + 
AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_zero_) { + if (x_->func_graph() != node->func_graph()) { + return nullptr; + } + return NewTensorFilledWithData(node); + } + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (is_zero_) { + x_ = node; + return; + } + + if (IsParam(node)) { + x_ = node; + return; + } + + if (IsCNode(node)) { + CNodePtr cnode = node->cast(); + if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { + is_zero_ = true; + return; + } + x_ = node; + return; + } + auto value = node->cast()->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = node; + } + + void Visit(const ValueNodePtr &vnode) override { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = vnode; + } + void Reset() { + x_ = nullptr; + is_zero_ = false; + } + private: - bool is_zero_{false}, is_one_{false}; + bool is_zero_{false}; ValuePtr zero_; - AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} +class TensorMultiplyByOne : public TensorMultiplyBase { + public: + TensorMultiplyByOne() {} + ~TensorMultiplyByOne() override = default; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_one_) { + return NewTensorFilledWithData(node, x_); + } + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (is_one_) { + x_ = node; + return; + } + + if (IsParam(node) || IsCNode(node)) { + x_ = node; + return; + } + + auto value = node->cast()->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = node; + } + + void Visit(const ValueNodePtr &vnode) override { + auto value = vnode->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = vnode; + } + void Reset() { + x_ = nullptr; + is_one_ = 
false; + } + + private: + bool is_one_{false}; }; // {prim::kPrimScalarAdd, X, 0} @@ -699,7 +743,7 @@ class ArithmeticSimplify { public: ArithmeticSimplify() : multiply_by_zero_or_one_(), - tensor_multiply_by_zero_or_one_(), + tensor_multiply_by_one_(), add_by_zero_(), tensor_add_by_zero_(), identity_(prim::kPrimIdentity), @@ -707,7 +751,7 @@ class ArithmeticSimplify { constant_duplicate_mul_(), power_one_() { eliminaters_.emplace_back(multiply_by_zero_or_one_); - eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_); + eliminaters_.emplace_back(tensor_multiply_by_one_); eliminaters_.emplace_back(add_by_zero_); eliminaters_.emplace_back(tensor_add_by_zero_); eliminaters_.emplace_back(identity_); @@ -730,7 +774,7 @@ class ArithmeticSimplify { private: MultiplyByZeroOrOne multiply_by_zero_or_one_; - TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_; + TensorMultiplyByOne tensor_multiply_by_one_; AddByZero add_by_zero_; TensorAddByZero tensor_add_by_zero_; PrimEliminater identity_; @@ -739,6 +783,32 @@ class ArithmeticSimplify { PowerOneEliminate power_one_; std::vector eliminaters_{}; }; + +// Arithmetic Simplifications should be done after step_parallel. +// e.g. Mul(0, weight) where weight is a parameter will be simplified to a constant tensor +// with shape(weight), but after step_parallel, shape of weight may be changed, so the +// shape of the constant tensor should also be changed. So this pass is separated from +// ArithmeticSimplify and deferred until step_parallel. 
+class ArithmeticSimplify2 { + public: + ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } + ~ArithmeticSimplify2() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = eliminater(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + TensorMultiplyByZero tensor_multiply_by_zero_; + std::vector eliminaters_{}; +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h index 2693aec1c93b9c329ab6f9265a14a1c8e8f91520..21cdff51ad02d687d68b7d7e5c18e5ef6e161538 100644 --- a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h @@ -70,6 +70,45 @@ class GetitemEliminater : public AnfVisitor { CNodePtr tuple_{nullptr}; }; +// (a, b, c, ...)[0] => a +// (a, b, c, ...)[1] => b +// {prim::kPrimTupleGetItem, C1, C} +class GetitemConstEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node); + + if (is_match_) { + return NewValueNode((*tuple_)[id_]); + } + return nullptr; + } + + void Visit(const ValueNodePtr &vnode) override { + if (IsValueNode(vnode)) { + tuple_ = GetValueNode(vnode); + } + if (tuple_ != nullptr && IsValueNode(vnode)) { + id_ = IntToSize(GetValue(vnode->value())); + if (tuple_->size() > id_) { + is_match_ = true; + } + } + } + + void Reset() { + id_ = 0; + tuple_ = nullptr; + is_match_ = false; + } + + private: + bool is_match_{false}; + size_t id_{0}; + ValueTuplePtr tuple_{nullptr}; +}; + // setitem((a, b, c, ...), 0, z) => (z, b, c, ...) 
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...) // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} @@ -225,8 +264,13 @@ class GetitemDependReorder : public AnfVisitor { class ItemTupleEliminater { public: ItemTupleEliminater() - : get_item_eliminater_(), set_item_eliminater_(), get_set_item_eliminater_(), get_item_depend_reorder_() { + : get_item_eliminater_(), + get_item_const_eliminater_(), + set_item_eliminater_(), + get_set_item_eliminater_(), + get_item_depend_reorder_() { eliminaters_.emplace_back(get_item_eliminater_); + eliminaters_.emplace_back(get_item_const_eliminater_); eliminaters_.emplace_back(set_item_eliminater_); eliminaters_.emplace_back(get_set_item_eliminater_); eliminaters_.emplace_back(get_item_depend_reorder_); @@ -246,6 +290,7 @@ class ItemTupleEliminater { private: GetitemEliminater get_item_eliminater_; + GetitemConstEliminater get_item_const_eliminater_; SetitemEliminater set_item_eliminater_; GetSetitemEliminater get_set_item_eliminater_; GetitemDependReorder get_item_depend_reorder_; diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index 94063fb780d4da778ba7f2101930e0941db7fcac..9876c0280ad17f32971484c176df3131d618de2a 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -114,6 +114,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.depend_value_elim_, }); opt::OptPassConfig a_3 = opt::OptPassConfig({ + irpass.arithmetic_simplify2_, irpass.same_eliminate_, irpass.check_bprop_eliminate_, irpass.replace_applicator_, diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index ed4497f9a56c95e4be1c95992666cb9e94f0a094..ebbcdf6f7c53bd13df2ae4147337e2214af595fc 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -20,9 +20,12 @@ #include "common/py_func_graph_fetcher.h" #include "ir/anf.h" +#include "ir/func_graph.h" #include "ir/func_graph_cloner.h" 
#include "ir/manager.h" +#include "ir/value.h" #include "ir/visitor.h" +#include "operator/ops.h" #include "optimizer/irpass.h" #include "pipeline/resource.h" #include "debug/draw.h" @@ -343,9 +346,26 @@ TEST_F(TestOptLib, test_tuple_getitem) { FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_0"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_1"); + FuncGraphPtr make_get_const = std::make_shared(); + auto value_node_1 = NewValueNode(1); + auto value_node_2 = NewValueNode(2); + std::vector vec{1, 2}; + auto value_node_tuple = NewValueNode(MakeValue(vec)); + std::vector node_list{ + NewValueNode(prim::kPrimTupleGetItem), + value_node_tuple, + value_node_1 + }; + auto get_item = make_get_const->NewCNode(node_list); + make_get_const->set_output(get_item); + + FuncGraphPtr after_2 = std::make_shared(); + after_2->set_output(value_node_2); + auto patterns = std::vector({irpass.item_tuple_eliminate_}); ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); + ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); } TEST_F(TestOptLib, test_tuple_setitem) {