Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
cc1416bf
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cc1416bf
编写于
4月 09, 2020
作者:
B
biffex
提交者:
高东海
4月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
constant duplicate mul for momentum
上级
d87fc50e
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
112 addition
and
4 deletion
+112
-4
mindspore/ccsrc/optimizer/irpass.cc
mindspore/ccsrc/optimizer/irpass.cc
+3
-3
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+53
-1
mindspore/ccsrc/utils/graph_utils.cc
mindspore/ccsrc/utils/graph_utils.cc
+2
-0
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+8
-0
tests/ut/cpp/optimizer/lib_test.cc
tests/ut/cpp/optimizer/lib_test.cc
+13
-0
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
+33
-0
未找到文件。
mindspore/ccsrc/optimizer/irpass.cc
浏览文件 @
cc1416bf
...
...
@@ -45,9 +45,9 @@ namespace mindspore {
namespace
opt
{
namespace
irpass
{
OptimizeIRPassLib
::
OptimizeIRPassLib
()
{
arithmetic_simplify_
=
MakeSubstitution
(
ArithmeticSimplify
(),
"arithmetic_simplify"
,
{
prim
::
kPrimScalarAdd
,
prim
::
kPrimScalarMul
,
prim
::
kPrimTensorAdd
,
prim
::
kPrimIdentity
,
prim
::
kPrimMomentum
});
arithmetic_simplify_
=
MakeSubstitution
(
ArithmeticSimplify
(),
"arithmetic_simplify"
,
{
prim
::
kPrimScalarAdd
,
prim
::
kPrimScalarMul
,
prim
::
kPrimTensorAdd
,
prim
::
kPrimIdentity
,
prim
::
kPrimMomentum
,
prim
::
kPrimMul
});
special_op_eliminate_
=
MakeSubstitution
(
SpecialOpEliminater
(),
"special_op_eliminate"
,
{
prim
::
kPrimInsertGradientOf
,
prim
::
kPrimPrintShapeType
,
prim
::
kPrimGetRefKey
,
prim
::
kPrimMirror
,
prim
::
kPrimVirtualDiv
});
...
...
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
浏览文件 @
cc1416bf
...
...
@@ -179,6 +179,55 @@ class OptUpdateZeroTensor : public AnfVisitor {
}
};
// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
class
ConstantDuplicateMul
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
Reset
();
// {prim::kPrimMul, Tensor1, {...}}
AnfVisitor
::
Match
(
prim
::
kPrimMul
,
{
IsNode
,
IsNode
})(
node
);
if
(
vnode_
==
nullptr
||
cnode_
==
nullptr
)
{
return
nullptr
;
}
auto
tensor1
=
vnode_
;
auto
mul
=
cnode_
;
Reset
();
// {prim::kPrimMul, Tensor2, {...}}
AnfVisitor
::
Match
(
prim
::
kPrimMul
,
{
IsNode
,
IsNode
})(
mul
);
if
(
vnode_
==
nullptr
||
cnode_
==
nullptr
)
{
return
nullptr
;
}
auto
tensor2
=
vnode_
;
auto
cnode
=
cnode_
;
auto
PrimMul
=
GetValueNode
<
PrimitivePtr
>
(
mul
->
input
(
0
));
auto
fg
=
node
->
func_graph
();
auto
ttmul
=
NewCNode
({
NewValueNode
(
PrimMul
),
tensor1
,
tensor2
},
fg
);
return
NewCNode
({
NewValueNode
(
PrimMul
),
cnode
,
ttmul
},
fg
);
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
if
(
IsValueNode
<
tensor
::
Tensor
>
(
node
))
{
vnode_
=
node
;
}
if
(
IsCNode
(
node
))
{
cnode_
=
node
->
cast
<
CNodePtr
>
();
}
}
void
Reset
()
{
vnode_
=
nullptr
;
cnode_
=
nullptr
;
}
private:
AnfNodePtr
vnode_
;
CNodePtr
cnode_
;
};
class
ArithmeticSimplify
{
public:
ArithmeticSimplify
()
...
...
@@ -186,12 +235,14 @@ class ArithmeticSimplify {
add_by_zero_
(),
tensor_add_by_zero_
(),
identity_
(
prim
::
kPrimIdentity
),
opt_update_zero_tensor_
()
{
opt_update_zero_tensor_
(),
constant_duplicate_mul_
()
{
eliminaters_
.
emplace_back
(
multiply_by_zero_or_one_
);
eliminaters_
.
emplace_back
(
add_by_zero_
);
eliminaters_
.
emplace_back
(
tensor_add_by_zero_
);
eliminaters_
.
emplace_back
(
identity_
);
eliminaters_
.
emplace_back
(
opt_update_zero_tensor_
);
eliminaters_
.
emplace_back
(
constant_duplicate_mul_
);
}
~
ArithmeticSimplify
()
=
default
;
...
...
@@ -212,6 +263,7 @@ class ArithmeticSimplify {
TensorAddByZero
tensor_add_by_zero_
;
PrimEliminater
identity_
;
OptUpdateZeroTensor
opt_update_zero_tensor_
;
ConstantDuplicateMul
constant_duplicate_mul_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
}
// namespace irpass
...
...
mindspore/ccsrc/utils/graph_utils.cc
浏览文件 @
cc1416bf
...
...
@@ -400,6 +400,8 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu
auto
a2
=
GetValueNode
(
node2
);
if
(
a1
->
isa
<
Primitive
>
()
&&
a2
->
isa
<
Primitive
>
())
{
return
a1
->
cast
<
PrimitivePtr
>
()
->
name
()
==
a2
->
cast
<
PrimitivePtr
>
()
->
name
();
}
else
if
(
a1
->
isa
<
tensor
::
Tensor
>
()
&&
a2
->
isa
<
tensor
::
Tensor
>
())
{
return
a1
->
cast
<
tensor
::
TensorPtr
>
()
->
ValueEqual
(
*
(
a2
->
cast
<
tensor
::
TensorPtr
>
()));
}
else
{
return
*
a1
==
*
a2
;
}
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
cc1416bf
...
...
@@ -774,6 +774,14 @@ class Mul(_MathBinaryOp):
>>> mul(input_x, input_y)
[4, 10, 18]
"""
def
infer_value
(
self
,
x
,
y
):
if
x
is
not
None
and
y
is
not
None
:
x
=
x
.
asnumpy
()
y
=
y
.
asnumpy
()
out
=
x
*
y
out
=
np
.
array
(
out
,
x
.
dtype
)
return
Tensor
(
out
)
return
None
class
Square
(
PrimitiveWithInfer
):
...
...
tests/ut/cpp/optimizer/lib_test.cc
浏览文件 @
cc1416bf
...
...
@@ -543,5 +543,18 @@ TEST_F(TestOptLib, test_print_tuple_wrapper) {
ASSERT_TRUE
(
CheckOpt
(
before2
,
after2
,
patterns
));
ASSERT_TRUE
(
CheckOpt
(
before3
,
before3
,
patterns
));
}
TEST_F
(
TestOptLib
,
test_constant_duplicate_mul
)
{
FuncGraphPtr
beforell
=
getPyFun
.
CallAndParseRet
(
"test_constant_duplicate_mul"
,
"beforell"
);
FuncGraphPtr
beforelr
=
getPyFun
.
CallAndParseRet
(
"test_constant_duplicate_mul"
,
"beforelr"
);
FuncGraphPtr
beforerl
=
getPyFun
.
CallAndParseRet
(
"test_constant_duplicate_mul"
,
"beforerl"
);
FuncGraphPtr
beforerr
=
getPyFun
.
CallAndParseRet
(
"test_constant_duplicate_mul"
,
"beforerr"
);
FuncGraphPtr
after
=
getPyFun
.
CallAndParseRet
(
"test_constant_duplicate_mul"
,
"after"
);
auto
patterns
=
std
::
vector
<
SubstitutionPtr
>
({
irpass
.
arithmetic_simplify_
});
ASSERT_TRUE
(
CheckOpt
(
beforell
,
after
,
patterns
));
ASSERT_TRUE
(
CheckOpt
(
beforelr
,
after
,
patterns
));
ASSERT_TRUE
(
CheckOpt
(
beforerl
,
after
,
patterns
));
ASSERT_TRUE
(
CheckOpt
(
beforerr
,
after
,
patterns
));
}
}
// namespace opt
}
// namespace mindspore
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
浏览文件 @
cc1416bf
...
...
@@ -16,6 +16,8 @@
from
mindspore.ops
import
Primitive
,
PrimitiveWithInfer
from
mindspore.ops
import
operations
as
P
from
mindspore.ops.operations
import
_grad_ops
as
G
from
mindspore
import
Tensor
import
numpy
as
np
# pylint: disable=unused-variable
...
...
@@ -903,3 +905,34 @@ def test_print_tuple_wrapper(tag):
return
print_
(
make_tuple
(
x
,
y
,
z
))
return
fns
[
tag
]
def
test_constant_duplicate_mul
(
tag
):
fns
=
FnDict
()
Mul
=
Primitive
(
'Mul'
);
Sqrt
=
Primitive
(
'Sqrt'
);
x
=
Tensor
(
np
.
array
([[
2
,
2
],
[
2
,
3
]]).
astype
(
'float32'
))
tensor1
=
Tensor
(
np
.
array
([[
1.2
,
2.1
],
[
2.2
,
3.2
]]).
astype
(
'float32'
))
tensor2
=
Tensor
(
np
.
array
([[
2.2
,
3.1
],
[
3.2
,
4.2
]]).
astype
(
'float32'
))
@
fns
def
beforell
():
return
Mul
(
tensor1
,
Mul
(
tensor2
,
Sqrt
(
x
)))
@
fns
def
beforelr
():
return
Mul
(
tensor1
,
Mul
(
Sqrt
(
x
),
tensor2
))
@
fns
def
beforerl
():
return
Mul
(
Mul
(
Sqrt
(
x
),
tensor2
),
tensor1
)
@
fns
def
beforerr
():
return
Mul
(
Mul
(
Sqrt
(
x
),
tensor2
),
tensor1
)
@
fns
def
after
():
return
Mul
(
Sqrt
(
x
),
Mul
(
tensor1
,
tensor2
))
return
fns
[
tag
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录