提交 038bd6cb 编写于 作者: L lvliang

enable-the-identity-op-can-be-eliminated-in-pynative-mode

上级 641601f7
......@@ -20,9 +20,6 @@ namespace mindspore {
namespace opt {
namespace irpass {
AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
return nullptr;
}
PatternNode x, y, z, xs;
PConstant one_(node, false, 1);
PConstant one_scalar_(node, false, 1, true);
......@@ -32,14 +29,16 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
PConstant const_2(node);
PConstant any_const(node);
MATCH_REPLACE(node, x + zero_, x); // Add by zero
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one
if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
MATCH_REPLACE(node, x + zero_, x); // Add by zero
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x); // Scalar Mul by one
// Scalar Mul by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
// Scalar Mul by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
}
// Prim Eliminate (identity)
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册