提交 b2ec296f 编写于 作者: Z zhousiyi

add opt pass for tuple_getitem with constant input

上级 5b14292f
......@@ -51,6 +51,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul});
special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
......
......@@ -33,6 +33,7 @@ class OptimizeIRPassLib {
~OptimizeIRPassLib() = default;
SubstitutionPtr arithmetic_simplify_;
SubstitutionPtr arithmetic_simplify2_;
SubstitutionPtr special_op_eliminate_;
SubstitutionPtr zero_like_fill_zero_;
SubstitutionPtr adjust_all_reduce_mul_add_;
......
......@@ -139,76 +139,8 @@ class CheckTensorConstant {
int check_value_;
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class TensorMultiplyByZeroOrOne : public AnfVisitor {
public:
TensorMultiplyByZeroOrOne() : zero_(MakeValue(0)) {}
~TensorMultiplyByZeroOrOne() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimMul)(node);
if (is_zero_) {
if (x_->func_graph() != node->func_graph()) {
return nullptr;
}
return NewTensorFilledWithData(node);
}
if (is_one_) {
return NewTensorFilledWithData(node, x_);
}
return nullptr;
}
void Visit(const AnfNodePtr &node) override {
if (is_zero_ || is_one_) {
x_ = node;
return;
}
if (IsParam(node)) {
x_ = node;
return;
}
if (IsCNode(node)) {
CNodePtr cnode = node->cast<CNodePtr>();
if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) {
is_zero_ = true;
return;
}
x_ = node;
return;
}
auto value = node->cast<ValueNodePtr>()->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
} else if (CheckTensorConstant(1).IsTensorConstant(value)) {
is_one_ = true;
return;
}
x_ = node;
}
void Visit(const ValueNodePtr &vnode) override {
auto value = vnode->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
} else if (CheckTensorConstant(1).IsTensorConstant(value)) {
is_one_ = true;
return;
}
x_ = vnode;
}
void Reset() {
x_ = nullptr;
is_one_ = false;
is_zero_ = false;
}
class TensorMultiplyBase : public AnfVisitor {
protected:
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) {
if (!node->isa<ValueNode>()) {
return nullptr;
......@@ -287,10 +219,122 @@ class TensorMultiplyByZeroOrOne : public AnfVisitor {
return new_vnode;
}
AnfNodePtr x_{nullptr};
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
// Visitor for Mul nodes with a zero operand. When one operand is recognised as zero, the whole
// Mul node is replaced by the result of NewTensorFilledWithData(node).
class TensorMultiplyByZero : public TensorMultiplyBase {
public:
TensorMultiplyByZero() : zero_(MakeValue(0)) {}
~TensorMultiplyByZero() override = default;
// Returns the replacement for `node`, or nullptr when no zero operand was seen or when the
// recorded operand x_ lives in a different func graph than `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
// Clear the state left by the previously visited node before matching this one.
Reset();
AnfVisitor::Match(prim::kPrimMul)(node);
if (is_zero_) {
// Only rewrite when the other operand belongs to the same graph as the Mul node.
// NOTE(review): x_ is dereferenced without a null check; presumably Match always visits two
// inputs so x_ has been set by now - confirm.
if (x_->func_graph() != node->func_graph()) {
return nullptr;
}
return NewTensorFilledWithData(node);
}
return nullptr;
}
// Classifies one input of the Mul node: either it marks the match as "zero found" or it is
// recorded in x_.
void Visit(const AnfNodePtr &node) override {
// A zero operand was already found, so this input is simply recorded.
if (is_zero_) {
x_ = node;
return;
}
// Parameters are never treated as the zero operand.
if (IsParam(node)) {
x_ = node;
return;
}
if (IsCNode(node)) {
CNodePtr cnode = node->cast<CNodePtr>();
// A call whose first input is the ZerosLike primitive counts as a zero operand.
if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) {
is_zero_ = true;
return;
}
x_ = node;
return;
}
// Neither a Parameter nor a CNode. The cast result is used without a null check; presumably the
// node is always a ValueNode at this point - TODO confirm.
auto value = node->cast<ValueNodePtr>()->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
}
x_ = node;
}
// Same constant check as above for callers that already hold a ValueNodePtr.
void Visit(const ValueNodePtr &vnode) override {
auto value = vnode->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
}
x_ = vnode;
}
// Restores the per-node matching state.
void Reset() {
x_ = nullptr;
is_zero_ = false;
}
private:
// NOTE(review): the next two lines both declare is_zero_, and is_one_ is not used anywhere in
// this class. They look like the removed and added sides of the same diff hunk; only one
// is_zero_ declaration can remain for the class to compile - confirm which line is current.
bool is_zero_{false}, is_one_{false};
bool is_zero_{false};
// Initialised from MakeValue(0) in the constructor; not read by any method shown in this class.
ValuePtr zero_;
// NOTE(review): the code shown just before this class already ends with an x_ declaration of the
// same type, presumably in TensorMultiplyBase; this line may be diff residue - confirm.
AnfNodePtr x_{nullptr};
};
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
// Visitor for Mul nodes where one operand is recognised by CheckTensorConstant(1). On a match the
// Mul node is replaced by NewTensorFilledWithData(node, x_), where x_ is the other operand.
class TensorMultiplyByOne : public TensorMultiplyBase {
 public:
  TensorMultiplyByOne() {}
  ~TensorMultiplyByOne() override = default;

  // Returns the replacement for `node`, or nullptr when no operand was recognised as one.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimMul)(node);
    if (!is_one_) {
      return nullptr;
    }
    return NewTensorFilledWithData(node, x_);
  }

  // Classifies one input of the Mul node. Inputs seen after the "one" operand, Parameters and
  // CNodes are always recorded in x_; anything else is tested as a constant first.
  void Visit(const AnfNodePtr &node) override {
    const bool needs_constant_check = !is_one_ && !IsParam(node) && !IsCNode(node);
    if (needs_constant_check) {
      auto constant = node->cast<ValueNodePtr>()->value();
      if (CheckTensorConstant(1).IsTensorConstant(constant)) {
        is_one_ = true;
        return;
      }
    }
    x_ = node;
  }

  // Same constant test for callers that already hold a ValueNodePtr.
  void Visit(const ValueNodePtr &vnode) override {
    auto constant = vnode->value();
    if (CheckTensorConstant(1).IsTensorConstant(constant)) {
      is_one_ = true;
    } else {
      x_ = vnode;
    }
  }

  // Restores the per-node matching state.
  void Reset() {
    is_one_ = false;
    x_ = nullptr;
  }

 private:
  bool is_one_{false};
};
// {prim::kPrimScalarAdd, X, 0}
......@@ -699,7 +743,7 @@ class ArithmeticSimplify {
public:
ArithmeticSimplify()
: multiply_by_zero_or_one_(),
tensor_multiply_by_zero_or_one_(),
tensor_multiply_by_one_(),
add_by_zero_(),
tensor_add_by_zero_(),
identity_(prim::kPrimIdentity),
......@@ -707,7 +751,7 @@ class ArithmeticSimplify {
constant_duplicate_mul_(),
power_one_() {
eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_one_);
eliminaters_.emplace_back(add_by_zero_);
eliminaters_.emplace_back(tensor_add_by_zero_);
eliminaters_.emplace_back(identity_);
......@@ -730,7 +774,7 @@ class ArithmeticSimplify {
private:
MultiplyByZeroOrOne multiply_by_zero_or_one_;
TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_;
TensorMultiplyByOne tensor_multiply_by_one_;
AddByZero add_by_zero_;
TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_;
......@@ -739,6 +783,32 @@ class ArithmeticSimplify {
PowerOneEliminate power_one_;
std::vector<TransformFuncType> eliminaters_{};
};
// Arithmetic simplifications that must run after step_parallel.
// Example: Mul(0, weight), where weight is a parameter, is simplified to a constant tensor with
// shape(weight). step_parallel may change the shape of weight, in which case the shape of the
// constant tensor has to change as well. This pass is therefore separated from
// ArithmeticSimplify and deferred until after step_parallel.
class ArithmeticSimplify2 {
 public:
  ArithmeticSimplify2() : tensor_multiply_by_zero_() {
    eliminaters_.emplace_back(tensor_multiply_by_zero_);
  }
  ~ArithmeticSimplify2() = default;

  // Tries every registered eliminater in order and returns the first non-null replacement;
  // returns nullptr when none of them rewrites `node`.
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
    for (auto &transform : eliminaters_) {
      AnfNodePtr replacement = transform(optimizer, node);
      if (replacement != nullptr) {
        return replacement;
      }
    }
    return nullptr;
  }

 private:
  TensorMultiplyByZero tensor_multiply_by_zero_;
  std::vector<TransformFuncType> eliminaters_{};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
......@@ -70,6 +70,45 @@ class GetitemEliminater : public AnfVisitor {
CNodePtr tuple_{nullptr};
};
// (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b
// (a, b, c, ...)[-1] => last element
// {prim::kPrimTupleGetItem, C1, C}
//
// Folds a tuple_getitem whose tuple and index are both constant value nodes into a value node
// holding the selected element.
class GetitemConstEliminater : public AnfVisitor {
 public:
  // Returns a new value node for the selected element, or nullptr when `node` is not a
  // tuple_getitem over a constant ValueTuple with an in-range constant Int32 index.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node);

    if (is_match_) {
      return NewValueNode((*tuple_)[id_]);
    }
    return nullptr;
  }

  // Called for the value-node inputs of the matched node. The tuple has to be seen before the
  // index: the index is only examined once tuple_ is set.
  void Visit(const ValueNodePtr &vnode) override {
    if (IsValueNode<ValueTuple>(vnode)) {
      tuple_ = GetValueNode<ValueTuplePtr>(vnode);
    }
    if (tuple_ == nullptr || !IsValueNode<Int32Imm>(vnode)) {
      return;
    }

    const int index = GetValue<int>(vnode->value());
    const size_t tuple_size = tuple_->size();
    if (index >= 0) {
      id_ = IntToSize(index);
    } else {
      // Negative indices count from the end of the tuple (-1 is the last element). They are
      // normalised here instead of being handed to IntToSize as a negative value.
      // -(index + 1) + 1 equals -index but cannot overflow, even for the smallest int.
      const size_t distance_from_end = static_cast<size_t>(-(index + 1)) + 1;
      if (distance_from_end > tuple_size) {
        return;  // Out of range: leave the node untouched.
      }
      id_ = tuple_size - distance_from_end;
    }

    if (id_ < tuple_size) {
      is_match_ = true;
    }
  }

  // Restores the per-node matching state.
  void Reset() {
    id_ = 0;
    tuple_ = nullptr;
    is_match_ = false;
  }

 private:
  bool is_match_{false};
  size_t id_{0};
  ValueTuplePtr tuple_{nullptr};
};
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
......@@ -225,8 +264,13 @@ class GetitemDependReorder : public AnfVisitor {
class ItemTupleEliminater {
public:
ItemTupleEliminater()
: get_item_eliminater_(), set_item_eliminater_(), get_set_item_eliminater_(), get_item_depend_reorder_() {
: get_item_eliminater_(),
get_item_const_eliminater_(),
set_item_eliminater_(),
get_set_item_eliminater_(),
get_item_depend_reorder_() {
eliminaters_.emplace_back(get_item_eliminater_);
eliminaters_.emplace_back(get_item_const_eliminater_);
eliminaters_.emplace_back(set_item_eliminater_);
eliminaters_.emplace_back(get_set_item_eliminater_);
eliminaters_.emplace_back(get_item_depend_reorder_);
......@@ -246,6 +290,7 @@ class ItemTupleEliminater {
private:
GetitemEliminater get_item_eliminater_;
GetitemConstEliminater get_item_const_eliminater_;
SetitemEliminater set_item_eliminater_;
GetSetitemEliminater get_set_item_eliminater_;
GetitemDependReorder get_item_depend_reorder_;
......
......@@ -114,6 +114,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.depend_value_elim_,
});
opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.arithmetic_simplify2_,
irpass.same_eliminate_,
irpass.check_bprop_eliminate_,
irpass.replace_applicator_,
......
......@@ -20,9 +20,12 @@
#include "common/py_func_graph_fetcher.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "pipeline/resource.h"
#include "debug/draw.h"
......@@ -343,9 +346,26 @@ TEST_F(TestOptLib, test_tuple_getitem) {
FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_getitem", "after_1");
FuncGraphPtr make_get_const = std::make_shared<FuncGraph>();
auto value_node_1 = NewValueNode(1);
auto value_node_2 = NewValueNode(2);
std::vector<int> vec{1, 2};
auto value_node_tuple = NewValueNode(MakeValue(vec));
std::vector<AnfNodePtr> node_list{
NewValueNode(prim::kPrimTupleGetItem),
value_node_tuple,
value_node_1
};
auto get_item = make_get_const->NewCNode(node_list);
make_get_const->set_output(get_item);
FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
after_2->set_output(value_node_2);
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
}
TEST_F(TestOptLib, test_tuple_setitem) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册