Commit 62bbf560 authored by biffex

constant duplicate mul for momentum

Parent cc75cb35
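The new `ConstantDuplicateMul` pattern below reassociates a multiplication whose operand is itself a multiplication by a constant tensor: `{Mul, Tensor1, {Mul, Tensor2, X}}` becomes `{Mul, X, {Mul, Tensor1, Tensor2}}`, so the two constants land in a `Mul` of their own that can be folded at compile time (via the new `Mul.infer_value`) and deduplicated by CSE (via the tensor value comparison). A minimal NumPy sketch of the identity the rewrite relies on, with invented stand-in values:

```python
import numpy as np

# Stand-ins: t1, t2 play the constant tensors; x plays the non-constant subtree.
t1 = np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32)
t2 = np.array([[2.2, 3.1], [3.2, 4.2]], dtype=np.float32)
x = np.sqrt(np.array([[2, 2], [2, 3]], dtype=np.float32))

# Elementwise Mul is commutative and associative, so the constant product
# t1 * t2 can be hoisted out and evaluated once at compile time.
assert np.allclose(t1 * (t2 * x), x * (t1 * t2))
```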
...
@@ -45,9 +45,9 @@ namespace mindspore {
 namespace opt {
 namespace irpass {
 OptimizeIRPassLib::OptimizeIRPassLib() {
-  arithmetic_simplify_ = MakeSubstitution(
-    ArithmeticSimplify(), "arithmetic_simplify",
-    {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum});
+  arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
+                                          {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
+                                           prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
   special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
                                            {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
                                             prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
...
@@ -179,6 +179,55 @@ class OptUpdateZeroTensor : public AnfVisitor {
   }
 };
+
+// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
+// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
+class ConstantDuplicateMul : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    // {prim::kPrimMul, Tensor1, {...}}
+    AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
+    if (vnode_ == nullptr || cnode_ == nullptr) {
+      return nullptr;
+    }
+    auto tensor1 = vnode_;
+    auto mul = cnode_;
+
+    Reset();
+    // {prim::kPrimMul, Tensor2, {...}}
+    AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
+    if (vnode_ == nullptr || cnode_ == nullptr) {
+      return nullptr;
+    }
+    auto tensor2 = vnode_;
+    auto cnode = cnode_;
+
+    auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
+    auto fg = node->func_graph();
+    auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
+    return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg);
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (IsValueNode<tensor::Tensor>(node)) {
+      vnode_ = node;
+    }
+
+    if (IsCNode(node)) {
+      cnode_ = node->cast<CNodePtr>();
+    }
+  }
+
+  void Reset() {
+    vnode_ = nullptr;
+    cnode_ = nullptr;
+  }
+
+ private:
+  AnfNodePtr vnode_;
+  CNodePtr cnode_;
+};
+
 class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
@@ -186,12 +235,14 @@ class ArithmeticSimplify {
         add_by_zero_(),
         tensor_add_by_zero_(),
         identity_(prim::kPrimIdentity),
-        opt_update_zero_tensor_() {
+        opt_update_zero_tensor_(),
+        constant_duplicate_mul_() {
     eliminaters_.emplace_back(multiply_by_zero_or_one_);
     eliminaters_.emplace_back(add_by_zero_);
     eliminaters_.emplace_back(tensor_add_by_zero_);
     eliminaters_.emplace_back(identity_);
     eliminaters_.emplace_back(opt_update_zero_tensor_);
+    eliminaters_.emplace_back(constant_duplicate_mul_);
   }
   ~ArithmeticSimplify() = default;
@@ -212,6 +263,7 @@ class ArithmeticSimplify {
   TensorAddByZero tensor_add_by_zero_;
   PrimEliminater identity_;
   OptUpdateZeroTensor opt_update_zero_tensor_;
+  ConstantDuplicateMul constant_duplicate_mul_;
   std::vector<TransformFuncType> eliminaters_{};
 };
 } // namespace irpass
...
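A note on the matcher in `ConstantDuplicateMul`: `Match(prim::kPrimMul, {IsNode, IsNode})` accepts both operands and lets `Visit` record whichever one is a constant-tensor value node (`vnode_`) and whichever is a CNode (`cnode_`), so the pattern matches regardless of operand order. A hypothetical Python restatement of that capture step (helper names invented for illustration):

```python
def split_mul(lhs, rhs, is_const_tensor, is_cnode):
    """Return (constant-tensor operand, CNode operand) from a Mul, in either order."""
    tensor = lhs if is_const_tensor(lhs) else (rhs if is_const_tensor(rhs) else None)
    subtree = lhs if is_cnode(lhs) else (rhs if is_cnode(rhs) else None)
    # Mirrors the nullptr checks above: both must be found or there is no match.
    return (tensor, subtree) if tensor is not None and subtree is not None else (None, None)
```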
@@ -400,6 +400,8 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu
   auto a2 = GetValueNode(node2);
   if (a1->isa<Primitive>() && a2->isa<Primitive>()) {
     return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name();
+  } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) {
+    return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>()));
   } else {
     return *a1 == *a2;
   }
...
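This CSE change is what makes the rewrite pay off across duplicated subgraphs: two constant tensors built separately (for example, two folded copies of `tensor1 * tensor2`) are distinct value nodes, and the fallback `*a1 == *a2` comparison would not identify them. Comparing by `ValueEqual` merges value nodes with equal contents. Roughly the intended semantics, restated in NumPy (a sketch, not the MindSpore API):

```python
import numpy as np

def tensors_value_equal(a, b):
    # Same dtype, same shape, same elementwise data => same constant.
    return a.dtype == b.dtype and a.shape == b.shape and np.array_equal(a, b)

a = np.array([[2.64, 6.51]], dtype=np.float32)
b = np.array([[2.64, 6.51]], dtype=np.float32)
assert a is not b and tensors_value_equal(a, b)
```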
@@ -771,6 +771,14 @@ class Mul(_MathBinaryOp):
     >>> mul(input_x, input_y)
     [4, 10, 18]
     """
 
+    def infer_value(self, x, y):
+        if x is not None and y is not None:
+            x = x.asnumpy()
+            y = y.asnumpy()
+            out = x * y
+            out = np.array(out, x.dtype)
+            return Tensor(out)
+        return None
 
 class Square(PrimitiveWithInfer):
...
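`infer_value` is the constant-folding hook: when both inputs are known at compile time, the frontend evaluates the product and replaces the `Mul` node with a constant, so the `Mul(tensor1, tensor2)` emitted by `ConstantDuplicateMul` never runs on device. Note that `np.array(out, x.dtype)` pins the result to the first input's dtype. A plain-NumPy restatement of the folded computation:

```python
import numpy as np

x = np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32)
y = np.array([[2.2, 3.1], [3.2, 4.2]], dtype=np.float32)
out = np.array(x * y, x.dtype)  # same expression as the infer_value body
assert out.dtype == np.float32
```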
@@ -543,5 +543,18 @@ TEST_F(TestOptLib, test_print_tuple_wrapper) {
   ASSERT_TRUE(CheckOpt(before2, after2, patterns));
   ASSERT_TRUE(CheckOpt(before3, before3, patterns));
 }
+
+TEST_F(TestOptLib, test_constant_duplicate_mul) {
+  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforell");
+  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforelr");
+  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerl");
+  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerr");
+  FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "after");
+
+  auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_});
+  ASSERT_TRUE(CheckOpt(beforell, after, patterns));
+  ASSERT_TRUE(CheckOpt(beforelr, after, patterns));
+  ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
+  ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
+}
 } // namespace opt
 } // namespace mindspore
...
@@ -16,6 +16,8 @@
 from mindspore.ops import Primitive, PrimitiveWithInfer
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
+from mindspore import Tensor
+import numpy as np
 
 # pylint: disable=unused-variable
@@ -903,3 +905,34 @@ def test_print_tuple_wrapper(tag):
         return print_(make_tuple(x, y, z))
 
     return fns[tag]
+
+
+def test_constant_duplicate_mul(tag):
+    fns = FnDict()
+    Mul = Primitive('Mul')
+    Sqrt = Primitive('Sqrt')
+
+    x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
+    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
+    tensor2 = Tensor(np.array([[2.2, 3.1], [3.2, 4.2]]).astype('float32'))
+
+    @fns
+    def beforell():
+        return Mul(tensor1, Mul(tensor2, Sqrt(x)))
+
+    @fns
+    def beforelr():
+        return Mul(tensor1, Mul(Sqrt(x), tensor2))
+
+    @fns
+    def beforerl():
+        return Mul(Mul(tensor2, Sqrt(x)), tensor1)
+
+    @fns
+    def beforerr():
+        return Mul(Mul(Sqrt(x), tensor2), tensor1)
+
+    @fns
+    def after():
+        return Mul(Sqrt(x), Mul(tensor1, tensor2))
+
+    return fns[tag]
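The four `before*` graphs enumerate every placement of the constant tensor in the outer and inner `Mul` (constant on the left or right of each), and the `CheckOpt` assertions in the C++ test above require all of them to normalize to the single canonical form `Mul(Sqrt(x), Mul(tensor1, tensor2))`.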