Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4b5cbe5d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4b5cbe5d
编写于
6月 17, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 17, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2186 Optimization for opt
Merge pull request !2186 from Kang/opt
上级
b391eb2b
2974b906
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
93 addition
and
73 deletion
+93
-73
mindspore/ccsrc/optimizer/irpass.cc
mindspore/ccsrc/optimizer/irpass.cc
+5
-9
mindspore/ccsrc/optimizer/irpass.h
mindspore/ccsrc/optimizer/irpass.h
+9
-33
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
+27
-0
mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h
mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h
+0
-15
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
+25
-0
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
+4
-1
mindspore/ccsrc/optimizer/opt.cc
mindspore/ccsrc/optimizer/opt.cc
+11
-2
mindspore/ccsrc/pipeline/pass.cc
mindspore/ccsrc/pipeline/pass.cc
+10
-11
tests/ut/cpp/optimizer/lib_test.cc
tests/ut/cpp/optimizer/lib_test.cc
+2
-2
未找到文件。
mindspore/ccsrc/optimizer/irpass.cc
浏览文件 @
4b5cbe5d
...
...
@@ -51,8 +51,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
prim
::
kPrimIdentity
,
prim
::
kPrimMomentum
,
prim
::
kPrimMul
});
special_op_eliminate_
=
MakeSubstitution
(
SpecialOpEliminater
(),
"special_op_eliminate"
,
{
prim
::
kPrimInsertGradientOf
,
prim
::
kPrim
HookBackward
,
prim
::
kPrimPrintShapeType
,
prim
::
kPrimGetRefKey
,
prim
::
kPrimMirror
,
prim
::
kPrimVirtualDiv
});
{
prim
::
kPrimInsertGradientOf
,
prim
::
kPrim
StopGradient
,
prim
::
kPrimHookBackward
,
prim
::
kPrim
PrintShapeType
,
prim
::
kPrim
GetRefKey
,
prim
::
kPrimMirror
,
prim
::
kPrimVirtualDiv
});
zero_like_fill_zero_
=
MakeSubstitution
(
ZeroLikeFillZero
(),
"zero_like_fill_zero"
,
prim
::
kPrimZerosLike
);
adjust_all_reduce_mul_add_
=
MakeSubstitution
(
AdjustAllReduceMulAdd
(),
"adjust_all_reduce_mul_add"
,
prim
::
kPrimAddN
);
...
...
@@ -72,9 +72,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
reset_defer_inline_
=
MakeSubstitution
(
ResetDeferInline
(),
"reset_defer_inline"
,
IsValueNode
<
FuncGraph
>
);
// Env Item Eliminate
env_get_item_eliminate_
=
MakeSubstitution
(
EnvGetItemEliminater
(),
"env_get_item_eliminate"
,
prim
::
kPrimEnvGetItem
);
new_env_get_item_
=
MakeSubstitution
(
NewEnvGetItem
(),
"new_env_get_item"
,
prim
::
kPrimEnvGetItem
);
add_env_get_item_
=
MakeSubstitution
(
AddEnvGetItem
(),
"add_env_get_item"
,
prim
::
kPrimEnvGetItem
);
env_get_set_item_
=
MakeSubstitution
(
EnvGetSetItem
(),
"env_get_set_item"
,
prim
::
kPrimEnvGetItem
);
incorporate_env_getitem_
=
MakeSubstitution
(
IncorporateEnvGetitem
(),
"incorporate_env_get_item"
,
prim
::
kPrimEnvGetItem
);
incorporate_env_getitem_switch_
=
...
...
@@ -91,8 +90,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Gradient transforms
expand_jprim_
=
MakeSubstitution
(
ExpandJPrim
(),
"expand_jprim"
,
prim
::
kPrimJ
);
stop_gradient_eliminate_
=
MakeSubstitution
(
StopGradientEliminater
(),
"stop_gradient_eliminate"
,
prim
::
kPrimStopGradient
);
minmaximum_grad_
=
MakeSubstitution
(
MinMaximumGrad
(),
"minmaximum_grad"
,
prim
::
kPrimTupleGetItem
);
// branch culling
...
...
@@ -113,9 +110,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
specialize_transform_
=
MakeSubstitution
(
SpecializeOnGraphArguments
(),
"specialize_transform"
,
IsCNodeGraph
);
// Incorporation
incorporate_getitem_
=
MakeSubstitution
(
IncorporateGetitem
(),
"incorporate_getitem"
,
prim
::
kPrimTupleGetItem
);
incorporate_getitem_switch_
=
MakeSubstitution
(
IncorporateGetitemSwitch
(),
"incorporate_getitem_switch"
,
prim
::
kPrimTupleGetItem
);
incorporate_getitem_set_
=
MakeSubstitution
(
IncorporateGetitemSet
(),
"incorporate_getitem_set"
,
prim
::
kPrimTupleGetItem
);
incorporate_call_
=
MakeSubstitution
(
IncorporateCall
(),
"incorporate_call"
,
IsCNodeDup
);
incorporate_call_switch_
=
MakeSubstitution
(
IncorporateCallSwitch
(),
"incorporate_call_switch"
,
IsCNodeDup
);
...
...
mindspore/ccsrc/optimizer/irpass.h
浏览文件 @
4b5cbe5d
...
...
@@ -50,9 +50,8 @@ class OptimizeIRPassLib {
SubstitutionPtr
reset_defer_inline_
;
// Env Item Eliminate
SubstitutionPtr
env_get_item_eliminate_
;
SubstitutionPtr
new_env_get_item_
;
SubstitutionPtr
add_env_get_item_
;
SubstitutionPtr
env_get_set_item_
;
SubstitutionPtr
incorporate_env_getitem_
;
SubstitutionPtr
incorporate_env_getitem_switch_
;
...
...
@@ -74,7 +73,6 @@ class OptimizeIRPassLib {
// Gradient irpasses
SubstitutionPtr
expand_jprim_
;
SubstitutionPtr
stop_gradient_eliminate_
;
SubstitutionPtr
minmaximum_grad_
;
// inline
...
...
@@ -83,8 +81,7 @@ class OptimizeIRPassLib {
SubstitutionPtr
specialize_transform_
;
// Incorporation
SubstitutionPtr
incorporate_getitem_
;
SubstitutionPtr
incorporate_getitem_switch_
;
SubstitutionPtr
incorporate_getitem_set_
;
SubstitutionPtr
incorporate_call_
;
SubstitutionPtr
incorporate_call_switch_
;
...
...
@@ -115,51 +112,30 @@ class InferenceOptPrepareLib {
// predicate functions
inline
bool
IsNode
(
const
AnfNodePtr
&
)
{
return
true
;
}
inline
bool
IsCNode
(
const
AnfNodePtr
&
node
)
{
if
(
node
!=
nullptr
)
{
return
node
->
isa
<
CNode
>
();
}
return
false
;
}
inline
bool
IsCNode
(
const
AnfNodePtr
&
node
)
{
return
node
->
isa
<
CNode
>
();
}
inline
bool
IsVNode
(
const
AnfNodePtr
&
node
)
{
if
(
node
!=
nullptr
)
{
return
node
->
isa
<
ValueNode
>
();
}
return
false
;
}
inline
bool
IsVNode
(
const
AnfNodePtr
&
node
)
{
return
node
->
isa
<
ValueNode
>
();
}
inline
bool
IsParam
(
const
AnfNodePtr
&
node
)
{
if
(
node
!=
nullptr
)
{
return
node
->
isa
<
Parameter
>
();
}
return
false
;
}
inline
bool
IsParam
(
const
AnfNodePtr
&
node
)
{
return
node
->
isa
<
Parameter
>
();
}
// Check if CNode Input 0 is Func Graph
inline
bool
IsCNodeGraph
(
const
AnfNodePtr
&
node
)
{
if
(
node
==
nullptr
||
!
node
->
isa
<
CNode
>
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
return
false
;
}
auto
inp0
=
node
->
cast
<
CNodePtr
>
()
->
input
(
0
);
if
(
IsValueNode
<
FuncGraph
>
(
inp0
))
{
return
true
;
}
return
false
;
return
IsValueNode
<
FuncGraph
>
(
inp0
);
}
// Check if CNode Input 0 is CNode
inline
bool
IsCNodeDup
(
const
AnfNodePtr
&
node
)
{
if
(
node
==
nullptr
||
!
node
->
isa
<
CNode
>
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
return
false
;
}
auto
inp0
=
node
->
cast
<
CNodePtr
>
()
->
input
(
0
);
if
(
inp0
!=
nullptr
&&
inp0
->
isa
<
CNode
>
())
{
return
true
;
}
return
false
;
return
(
inp0
!=
nullptr
)
&&
inp0
->
isa
<
CNode
>
();
}
}
// namespace irpass
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
浏览文件 @
4b5cbe5d
...
...
@@ -225,6 +225,33 @@ class EnvGetSetItem : public AnfVisitor {
bool
is_match_
{
false
};
};
class
EnvGetItemEliminater
{
public:
EnvGetItemEliminater
()
:
new_env_get_item_
(),
add_env_get_item_
(),
env_get_set_item_
()
{
eliminaters_
.
emplace_back
(
new_env_get_item_
);
eliminaters_
.
emplace_back
(
add_env_get_item_
);
eliminaters_
.
emplace_back
(
env_get_set_item_
);
}
~
EnvGetItemEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
}
return
nullptr
;
}
private:
NewEnvGetItem
new_env_get_item_
;
AddEnvGetItem
add_env_get_item_
;
EnvGetSetItem
env_get_set_item_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class
IncorporateEnvGetitem
:
public
AnfVisitor
{
public:
...
...
mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h
浏览文件 @
4b5cbe5d
...
...
@@ -55,21 +55,6 @@ class ExpandJPrim : public AnfVisitor {
private:
ValueNodePtr
x_
{
nullptr
};
};
// stop_gradient(x) ==> x
class
StopGradientEliminater
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
x_
=
nullptr
;
AnfVisitor
::
Match
(
prim
::
kPrimStopGradient
)(
node
);
return
x_
;
}
void
Visit
(
const
AnfNodePtr
&
node
)
override
{
x_
=
node
;
}
private:
AnfNodePtr
x_
{
nullptr
};
};
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
浏览文件 @
4b5cbe5d
...
...
@@ -197,6 +197,31 @@ class IncorporateGetitemSwitch : public AnfVisitor {
std
::
vector
<
AnfNodePtr
>
args_
{};
internal
::
GetitemTransform
getitem_transform_
;
};
class
IncorporateGetitemSet
{
public:
IncorporateGetitemSet
()
:
incorporate_getitem_
(),
incorporate_getitem_switch_
()
{
eliminaters_
.
emplace_back
(
incorporate_getitem_
);
eliminaters_
.
emplace_back
(
incorporate_getitem_switch_
);
}
~
IncorporateGetitemSet
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
}
return
nullptr
;
}
private:
IncorporateGetitem
incorporate_getitem_
;
IncorporateGetitemSwitch
incorporate_getitem_switch_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
...
...
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
浏览文件 @
4b5cbe5d
...
...
@@ -35,12 +35,14 @@ class SpecialOpEliminater {
public:
SpecialOpEliminater
()
:
insert_gradient_of_
(
prim
::
kPrimInsertGradientOf
),
stop_gradient_
(
prim
::
kPrimStopGradient
),
hook_backward_
(
prim
::
kPrimHookBackward
),
print_shape_type_
(
prim
::
kPrimPrintShapeType
),
get_ref_value_
(
prim
::
kPrimGetRefValue
),
mirror_
(
prim
::
kPrimMirror
),
virtual_div_
(
prim
::
kPrimVirtualDiv
)
{
eliminaters_
.
emplace_back
(
insert_gradient_of_
);
eliminaters_
.
emplace_back
(
stop_gradient_
);
eliminaters_
.
emplace_back
(
hook_backward_
);
eliminaters_
.
emplace_back
(
print_shape_type_
);
eliminaters_
.
emplace_back
(
get_ref_value_
);
...
...
@@ -61,7 +63,8 @@ class SpecialOpEliminater {
}
private:
PrimEliminater
insert_gradient_of_
,
hook_backward_
,
print_shape_type_
,
get_ref_value_
,
mirror_
,
virtual_div_
;
PrimEliminater
insert_gradient_of_
,
stop_gradient_
,
hook_backward_
,
print_shape_type_
,
get_ref_value_
,
mirror_
,
virtual_div_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
};
...
...
mindspore/ccsrc/optimizer/opt.cc
浏览文件 @
4b5cbe5d
...
...
@@ -44,8 +44,17 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return
false
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
inp0
=
cnode
->
input
(
0
);
auto
prim0
=
GetValueNode
<
PrimitivePtr
>
(
inp0
);
if
(
prim0
==
nullptr
)
{
return
false
;
}
auto
hash
=
prim0
->
Hash
();
auto
const
&
name
=
prim0
->
name
();
for
(
auto
&
prim
:
prims
)
{
if
(
IsPrimitiveCNode
(
node
,
prim
))
{
if
(
hash
==
prim
->
Hash
()
&&
name
==
prim
->
name
(
))
{
return
true
;
}
}
...
...
@@ -172,7 +181,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
}
#ifdef ENABLE_PROFILE
MsProfile
::
StatTime
(
"opt.transform
"
,
GetTime
()
-
start
);
MsProfile
::
StatTime
(
"opt.transform
."
+
optimizer
->
name
()
,
GetTime
()
-
start
);
#endif
return
changes
;
}
...
...
mindspore/ccsrc/pipeline/pass.cc
浏览文件 @
4b5cbe5d
...
...
@@ -79,16 +79,9 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Specialization
irpass
.
specialize_transform_
,
// Arithmetic simplifications
irpass
.
arithmetic_simplify_
,
irpass
.
addn_zero_filter_
,
irpass
.
adjust_all_reduce_mul_add_
,
// Miscellaneous
irpass
.
item_tuple_eliminate_
,
irpass
.
env_get_set_item_
,
irpass
.
new_env_get_item_
,
irpass
.
add_env_get_item_
,
irpass
.
env_get_item_eliminate_
,
irpass
.
cast_eliminate_
,
irpass
.
reshape_eliminate_
,
irpass
.
reduce_eliminate_
,
...
...
@@ -96,13 +89,20 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass
.
transpose_eliminate_
,
irpass
.
minmaximum_grad_
,
irpass
.
get_make_ref_eliminate_
,
// Arithmetic simplifications
irpass
.
arithmetic_simplify_
,
irpass
.
addn_zero_filter_
,
irpass
.
adjust_all_reduce_mul_add_
,
// Safe inlining
irpass
.
inline_
,
});
opt
::
OptPassConfig
a_2
=
opt
::
OptPassConfig
({
irpass
.
merge_addn_
,
irpass
.
float_tuple_getitem_switch_
,
irpass
.
float_env_getitem_switch_
,
irpass
.
incorporate_getitem_
,
irpass
.
incorporate_getitem_switch_
,
irpass
.
incorporate_getitem_set_
,
irpass
.
incorporate_call_
,
irpass
.
incorporate_call_switch_
,
irpass
.
incorporate_env_getitem_
,
...
...
@@ -145,7 +145,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass
.
reset_defer_inline_
,
irpass
.
inline_
,
irpass
.
special_op_eliminate_
,
irpass
.
stop_gradient_eliminate_
,
irpass
.
get_make_ref_eliminate_
,
});
opt
::
OptPassConfig
b_2
=
opt
::
OptPassConfig
({
...
...
tests/ut/cpp/optimizer/lib_test.cc
浏览文件 @
4b5cbe5d
...
...
@@ -401,7 +401,7 @@ TEST_F(TestOptLib, test_incorporate_getitem) {
FuncGraphPtr
after1
=
getPyFun
.
CallAndParseRet
(
"test_incorporate_getitem"
,
"after1"
);
FuncGraphPtr
after2
=
getPyFun
.
CallAndParseRet
(
"test_incorporate_getitem"
,
"after2"
);
auto
patterns
=
std
::
vector
<
SubstitutionPtr
>
({
irpass
.
incorporate_getitem_
});
auto
patterns
=
std
::
vector
<
SubstitutionPtr
>
({
irpass
.
incorporate_getitem_
set_
});
ASSERT_TRUE
(
CheckOpt
(
before1
,
after1
,
patterns
));
ASSERT_TRUE
(
CheckOpt
(
before2
,
after2
,
patterns
));
...
...
@@ -411,7 +411,7 @@ TEST_F(TestOptLib, test_incorporate_getitem_through_switch) {
FuncGraphPtr
before
=
getPyFun
.
CallAndParseRet
(
"test_incorporate_getitem_through_switch"
,
"before"
);
FuncGraphPtr
after
=
getPyFun
.
CallAndParseRet
(
"test_incorporate_getitem_through_switch"
,
"after"
);
auto
patterns
=
std
::
vector
<
SubstitutionPtr
>
({
irpass
.
incorporate_getitem_s
witch
_
});
auto
patterns
=
std
::
vector
<
SubstitutionPtr
>
({
irpass
.
incorporate_getitem_s
et
_
});
ASSERT_TRUE
(
CheckOpt
(
before
,
after
,
patterns
));
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录