Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
262e4fc0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
262e4fc0
编写于
6月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2585 Replace TransformFuncType with OptimizerCaller
Merge pull request !2585 from Giancarlo/remove_transformfunc
上级
19f79cd7
aabec55c
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
226 addition
and
207 deletion
+226
-207
mindspore/ccsrc/ir/optimizer_caller.h
mindspore/ccsrc/ir/optimizer_caller.h
+11
-1
mindspore/ccsrc/optimizer/irpass.cc
mindspore/ccsrc/optimizer/irpass.cc
+89
-75
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+28
-31
mindspore/ccsrc/optimizer/irpass/cast_eliminate.h
mindspore/ccsrc/optimizer/irpass/cast_eliminate.h
+3
-3
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
+16
-14
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
+15
-12
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
+16
-17
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
+2
-2
mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
+6
-5
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
+18
-18
mindspore/ccsrc/optimizer/opt.cc
mindspore/ccsrc/optimizer/opt.cc
+9
-10
mindspore/ccsrc/optimizer/opt.h
mindspore/ccsrc/optimizer/opt.h
+9
-15
tests/ut/cpp/optimizer/opt_test.cc
tests/ut/cpp/optimizer/opt_test.cc
+4
-4
未找到文件。
mindspore/ccsrc/ir/optimizer_caller.h
浏览文件 @
262e4fc0
...
...
@@ -17,13 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#include <memory>
#include "ir/anf.h"
#include "optimizer/opt.h"
namespace
mindspore
{
namespace
opt
{
class
Optimizer
;
using
OptimizerPtr
=
std
::
shared_ptr
<
Optimizer
>
;
using
OptimizerWeakPtr
=
std
::
weak_ptr
<
Optimizer
>
;
using
PredicateFuncType
=
std
::
function
<
bool
(
const
AnfNodePtr
&
)
>
;
}
// namespace opt
class
OptimizerCaller
{
public:
virtual
AnfNodePtr
operator
()(
const
opt
::
OptimizerPtr
&
,
const
AnfNodePtr
&
)
{
return
nullptr
;
}
};
using
OptimizerCallerPtr
=
std
::
shared_ptr
<
OptimizerCaller
>
;
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
mindspore/ccsrc/optimizer/irpass.cc
浏览文件 @
262e4fc0
...
...
@@ -14,140 +14,154 @@
* limitations under the License.
*/
#include "optimizer/irpass.h"
#include <string>
#include "optimizer/irpass
/symbol_resolver
.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/arithmetic_simplify.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/branch_culling.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/inline.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/
grad_var_prepare
.h"
#include "optimizer/irpass/
param_replac
e.h"
#include "optimizer/irpass/
incorporate_getitem
.h"
#include "optimizer/irpass/
item_tuple_eliminat
e.h"
#include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/opt.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
OptimizeIRPassLib
::
OptimizeIRPassLib
()
{
arithmetic_simplify_
=
MakeSubstitution
(
ArithmeticSimplify
(),
"arithmetic_simplify"
,
arithmetic_simplify_
=
MakeSubstitution
(
std
::
make_shared
<
ArithmeticSimplify
>
(),
"arithmetic_simplify"
,
{
prim
::
kPrimScalarAdd
,
prim
::
kPrimScalarMul
,
prim
::
kPrimTensorAdd
,
prim
::
kPrimIdentity
,
prim
::
kPrimMomentum
,
prim
::
kPrimMul
,
prim
::
kPrimPow
});
arithmetic_simplify2_
=
MakeSubstitution
(
ArithmeticSimplify2
(),
"arithmetic_simplify2"
,
{
prim
::
kPrimMul
});
arithmetic_simplify2_
=
MakeSubstitution
(
std
::
make_shared
<
ArithmeticSimplify2
>
(),
"arithmetic_simplify2"
,
{
prim
::
kPrimMul
});
special_op_eliminate_
=
MakeSubstitution
(
SpecialOpEliminater
(),
"special_op_eliminate"
,
MakeSubstitution
(
std
::
make_shared
<
SpecialOpEliminater
>
(),
"special_op_eliminate"
,
{
prim
::
kPrimInsertGradientOf
,
prim
::
kPrimStopGradient
,
prim
::
kPrimHookBackward
,
prim
::
kPrimPrintShapeType
,
prim
::
kPrimGetRefKey
,
prim
::
kPrimMirror
,
prim
::
kPrimVirtualDiv
});
zero_like_fill_zero_
=
MakeSubstitution
(
ZeroLikeFillZero
(),
"zero_like_fill_zero"
,
prim
::
kPrimZerosLike
);
adjust_all_reduce_mul_add_
=
MakeSubstitution
(
AdjustAllReduceMulAdd
(),
"adjust_all_reduce_mul_add"
,
prim
::
kPrimAddN
);
zero_like_fill_zero_
=
MakeSubstitution
(
std
::
make_shared
<
ZeroLikeFillZero
>
(),
"zero_like_fill_zero"
,
prim
::
kPrimZerosLike
);
adjust_all_reduce_mul_add_
=
MakeSubstitution
(
std
::
make_shared
<
AdjustAllReduceMulAdd
>
(),
"adjust_all_reduce_mul_add"
,
prim
::
kPrimAddN
);
// ops eliminate
item_tuple_eliminate_
=
MakeSubstitution
(
ItemTupleEliminater
(),
"item_tuple_eliminate"
,
{
prim
::
kPrimTupleGetItem
,
prim
::
kPrimTupleSetItem
});
tile_eliminate_
=
MakeSubstitution
(
TileMultiplyByOne
(),
"tile_eliminate"
,
prim
::
kPrimTile
);
cast_eliminate_
=
MakeSubstitution
(
CastEliminater
(),
"cast_eliminate"
,
prim
::
kPrimCast
);
reshape_eliminate_
=
MakeSubstitution
(
ReshapeEliminater
(),
"reshape_eliminate"
,
prim
::
kPrimReshape
);
transpose_eliminate_
=
MakeSubstitution
(
TransposeSameIOEliminater
(),
"transpose_eliminate"
,
prim
::
kPrimTranspose
);
item_tuple_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
ItemTupleEliminater
>
(),
"item_tuple_eliminate"
,
{
prim
::
kPrimTupleGetItem
,
prim
::
kPrimTupleSetItem
});
tile_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
TileMultiplyByOne
>
(),
"tile_eliminate"
,
prim
::
kPrimTile
);
cast_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
CastEliminater
>
(),
"cast_eliminate"
,
prim
::
kPrimCast
);
reshape_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
ReshapeEliminater
>
(),
"reshape_eliminate"
,
prim
::
kPrimReshape
);
transpose_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
TransposeSameIOEliminater
>
(),
"transpose_eliminate"
,
prim
::
kPrimTranspose
);
reduce_eliminate_
=
MakeSubstitution
(
ReduceOneEliminater
(),
"reduce_eliminate"
,
std
::
make_shared
<
ReduceOneEliminater
>
(),
"reduce_eliminate"
,
{
prim
::
kPrimReduceMean
,
prim
::
kPrimReduceAll
,
prim
::
kPrimReduceSum
,
prim
::
kPrimReduceMax
,
prim
::
kPrimReduceMin
});
partial_eliminate_
=
MakeSubstitution
(
PartialEliminater
(),
"partial_eliminate"
,
IsCNodeDup
);
same_eliminate_
=
MakeSubstitution
(
SameEliminater
(),
"same_eliminate"
,
prim
::
kPrimSameTypeShape
);
check_bprop_eliminate_
=
MakeSubstitution
(
CheckBpropEliminater
(),
"check_bprop_eliminate"
,
prim
::
kPrimCheckBprop
);
reset_defer_inline_
=
MakeSubstitution
(
ResetDeferInline
(),
"reset_defer_inline"
,
IsValueNode
<
FuncGraph
>
);
depend_value_elim_
=
MakeSubstitution
(
DependValueElim
(),
"depend_value_elim"
,
prim
::
kPrimDepend
);
partial_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
PartialEliminater
>
(),
"partial_eliminate"
,
IsCNodeDup
);
same_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
SameEliminater
>
(),
"same_eliminate"
,
prim
::
kPrimSameTypeShape
);
check_bprop_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
CheckBpropEliminater
>
(),
"check_bprop_eliminate"
,
prim
::
kPrimCheckBprop
);
reset_defer_inline_
=
MakeSubstitution
(
std
::
make_shared
<
ResetDeferInline
>
(),
"reset_defer_inline"
,
IsValueNode
<
FuncGraph
>
);
depend_value_elim_
=
MakeSubstitution
(
std
::
make_shared
<
DependValueElim
>
(),
"depend_value_elim"
,
prim
::
kPrimDepend
);
// Env Item Eliminate
env_get_item_eliminate_
=
MakeSubstitution
(
EnvGetItemEliminater
(),
"env_get_item_eliminate"
,
prim
::
kPrimEnvGetItem
);
new_env_get_item_
=
MakeSubstitution
(
NewEnvGetItem
(),
"new_env_get_item"
,
prim
::
kPrimEnvGetItem
);
env_get_item_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
EnvGetItemEliminater
>
(),
"env_get_item_eliminate"
,
prim
::
kPrimEnvGetItem
);
new_env_get_item_
=
MakeSubstitution
(
std
::
make_shared
<
NewEnvGetItem
>
(),
"new_env_get_item"
,
prim
::
kPrimEnvGetItem
);
incorporate_env_getitem_
=
MakeSubstitution
(
IncorporateEnvGetitem
(),
"incorporate_env_get_item"
,
prim
::
kPrimEnvGetItem
);
incorporate_env_getitem_switch_
=
MakeSubstitution
(
IncorporateEnvGetitemSwitch
(),
"incorporate_env_getitem_switch"
,
prim
::
kPrimEnvGetItem
);
MakeSubstitution
(
std
::
make_shared
<
IncorporateEnvGetitem
>
(),
"incorporate_env_get_item"
,
prim
::
kPrimEnvGetItem
);
incorporate_env_getitem_switch_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateEnvGetitemSwitch
>
(),
"incorporate_env_getitem_switch"
,
prim
::
kPrimEnvGetItem
);
// Ref eliminate
make_ref_eliminate_
=
MakeSubstitution
(
MakeRefEliminater
(),
"make_ref_eliminate"
,
prim
::
kPrimMakeRef
);
get_ref_param_eliminate_
=
MakeSubstitution
(
GetRefParamEliminater
(),
"get_ref_param_eliminate"
,
make_ref_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
MakeRefEliminater
>
(),
"make_ref_eliminate"
,
prim
::
kPrimMakeRef
);
get_ref_param_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
GetRefParamEliminater
>
(),
"get_ref_param_eliminate"
,
{
prim
::
kPrimGetRefValue
,
prim
::
kPrimGetRefOrigin
});
get_make_ref_eliminate_
=
MakeSubstitution
(
GetMakeRefEliminater
(),
"get_make_ref_eliminate"
,
get_make_ref_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
GetMakeRefEliminater
>
(),
"get_make_ref_eliminate"
,
{
prim
::
kPrimGetRefKey
,
prim
::
kPrimGetRefValue
,
prim
::
kPrimGetRefOrigin
});
replace_refkey_by_param_
=
MakeSubstitution
(
ReplaceRefkeyByParam
(),
"replace_refkey_by_param"
,
IsValueNode
<
RefKey
>
,
opt
::
FORCE_RENORM
);
replace_old_param_
=
MakeSubstitution
(
ReplaceOldParam
(),
"replace_old_param"
,
IsParam
);
replace_refkey_by_param_
=
MakeSubstitution
(
std
::
make_shared
<
ReplaceRefkeyByParam
>
(),
"replace_refkey_by_param"
,
IsValueNode
<
RefKey
>
,
opt
::
FORCE_RENORM
);
replace_old_param_
=
MakeSubstitution
(
std
::
make_shared
<
ReplaceOldParam
>
(),
"replace_old_param"
,
IsParam
);
// Gradient transforms
expand_jprim_
=
MakeSubstitution
(
ExpandJPrim
(),
"expand_jprim"
,
prim
::
kPrimJ
);
minmaximum_grad_
=
MakeSubstitution
(
MinMaximumGrad
(),
"minmaximum_grad"
,
prim
::
kPrimTupleGetItem
);
expand_jprim_
=
MakeSubstitution
(
std
::
make_shared
<
ExpandJPrim
>
(),
"expand_jprim"
,
prim
::
kPrimJ
);
minmaximum_grad_
=
MakeSubstitution
(
std
::
make_shared
<
MinMaximumGrad
>
(),
"minmaximum_grad"
,
prim
::
kPrimTupleGetItem
);
// branch culling
switch_simplify_
=
MakeSubstitution
(
SwitchSimplify
(),
"switch_simplify"
,
prim
::
kPrimSwitch
);
float_tuple_getitem_switch_
=
MakeSubstitution
(
FloatTupleGetItemSwitch
(),
"float_tuple_getitem_switch"
,
prim
::
kPrimTupleGetItem
);
switch_simplify_
=
MakeSubstitution
(
std
::
make_shared
<
SwitchSimplify
>
(),
"switch_simplify"
,
prim
::
kPrimSwitch
);
float_tuple_getitem_switch_
=
MakeSubstitution
(
std
::
make_shared
<
FloatTupleGetItemSwitch
>
(),
"float_tuple_getitem_switch"
,
prim
::
kPrimTupleGetItem
);
float_env_getitem_switch_
=
MakeSubstitution
(
FloatEnvGetItemSwitch
(),
"float_env_getitem_switch"
,
prim
::
kPrimEnvGetItem
);
convert_switch_replacement_
=
MakeSubstitution
(
ConvertSwitchReplacement
(),
"convert_switch_replacement"
,
IsCNodeDup
);
MakeSubstitution
(
std
::
make_shared
<
FloatEnvGetItemSwitch
>
(),
"float_env_getitem_switch"
,
prim
::
kPrimEnvGetItem
);
convert_switch_replacement_
=
MakeSubstitution
(
std
::
make_shared
<
ConvertSwitchReplacement
>
(),
"convert_switch_replacement"
,
IsCNodeDup
);
// Addn
merge_addn_
=
MakeSubstitution
(
MergeAddN
(),
"merge_addn"
,
prim
::
kPrimAddN
);
addn_zero_filter_
=
MakeSubstitution
(
AddNZeroFilter
(),
"addn_zero_filter"
,
prim
::
kPrimAddN
);
merge_addn_
=
MakeSubstitution
(
std
::
make_shared
<
MergeAddN
>
(),
"merge_addn"
,
prim
::
kPrimAddN
);
addn_zero_filter_
=
MakeSubstitution
(
std
::
make_shared
<
AddNZeroFilter
>
(),
"addn_zero_filter"
,
prim
::
kPrimAddN
);
// inline
inline_
=
MakeSubstitution
(
Inliner
(),
"inline"
,
IsCNodeGraph
);
replace_applicator_
=
MakeSubstitution
(
ReplaceApplicator
(),
"replace_applicator"
,
IsValueNode
<
FuncGraph
>
);
specialize_transform_
=
MakeSubstitution
(
SpecializeOnGraphArguments
(),
"specialize_transform"
,
IsCNodeGraph
);
inline_
=
MakeSubstitution
(
std
::
make_shared
<
Inliner
>
(),
"inline"
,
IsCNodeGraph
);
replace_applicator_
=
MakeSubstitution
(
std
::
make_shared
<
ReplaceApplicator
>
(),
"replace_applicator"
,
IsValueNode
<
FuncGraph
>
);
specialize_transform_
=
MakeSubstitution
(
std
::
make_shared
<
SpecializeOnGraphArguments
>
(),
"specialize_transform"
,
IsCNodeGraph
);
// Incorporation
incorporate_getitem_set_
=
MakeSubstitution
(
IncorporateGetitemSet
(),
"incorporate_getitem_set"
,
prim
::
kPrimTupleGetItem
);
incorporate_getitem_from_param_
=
MakeSubstitution
(
IncorporateGetitemFromParam
(),
"incorporate_getitem_from_param"
,
IsCNodeGraphKernel
);
incorporate_call_
=
MakeSubstitution
(
IncorporateCall
(),
"incorporate_call"
,
IsCNodeDup
);
incorporate_call_switch_
=
MakeSubstitution
(
IncorporateCallSwitch
(),
"incorporate_call_switch"
,
IsCNodeDup
);
MakeSubstitution
(
std
::
make_shared
<
IncorporateGetitemSet
>
(),
"incorporate_getitem_set"
,
prim
::
kPrimTupleGetItem
);
incorporate_getitem_from_param_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateGetitemFromParam
>
(),
"incorporate_getitem_from_param"
,
IsCNodeGraphKernel
);
incorporate_call_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateCall
>
(),
"incorporate_call"
,
IsCNodeDup
);
incorporate_call_switch_
=
MakeSubstitution
(
std
::
make_shared
<
IncorporateCallSwitch
>
(),
"incorporate_call_switch"
,
IsCNodeDup
);
// Virtual Dataset
virtual_dataset_eliminate_
=
MakeSubstitution
(
VirtualDatasetEliminater
(),
"virtual_dataset_eliminate"
,
prim
::
kPrimVirtualDataset
);
virtual_dataset_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
VirtualDatasetEliminater
>
(),
"virtual_dataset_eliminate"
,
prim
::
kPrimVirtualDataset
);
// Convert
print_tuple_wrapper_
=
MakeSubstitution
(
PrintTupleWrapper
(),
"print_tuple_wrapper"
,
prim
::
kPrimPrint
);
print_tuple_wrapper_
=
MakeSubstitution
(
std
::
make_shared
<
PrintTupleWrapper
>
(),
"print_tuple_wrapper"
,
prim
::
kPrimPrint
);
// Unused parameter eliminate
unused_parameter_eliminate_
=
MakeSubstitution
(
UnusedParasEliminater
(),
"unused_parameter_eliminate"
,
IsCNodeGraphKernel
);
unused_output_eliminate_
=
MakeSubstitution
(
UnusedOutputEliminater
(),
"unused_output_eliminate"
,
IsCNodeGraphKernel
);
MakeSubstitution
(
std
::
make_shared
<
UnusedParasEliminater
>
(),
"unused_parameter_eliminate"
,
IsCNodeGraphKernel
);
unused_output_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
UnusedOutputEliminater
>
(),
"unused_output_eliminate"
,
IsCNodeGraphKernel
);
// AddN eliminate
addn_eliminate_
=
MakeSubstitution
(
AddNEliminater
(),
"addn_eliminate"
,
IsCNodeGraphKernel
);
addn_eliminate_
=
MakeSubstitution
(
std
::
make_shared
<
AddNEliminater
>
(),
"addn_eliminate"
,
IsCNodeGraphKernel
);
// Mark interface fusion
mark_interface_fusion_
=
MakeSubstitution
(
MarkInterfaceFusion
(),
"mark_interface_fusion"
,
prim
::
kPrimSelect
);
mark_interface_fusion_
=
MakeSubstitution
(
std
::
make_shared
<
MarkInterfaceFusion
>
(),
"mark_interface_fusion"
,
prim
::
kPrimSelect
);
}
ResolveIRPassLib
::
ResolveIRPassLib
()
{
resolver_resolve_
=
MakeSubstitution
(
ResolverResolve
(),
"resolver_resolve"
,
prim
::
kPrimResolve
);
resolver_getattr_
=
MakeSubstitution
(
ResolverGetattr
(),
"resolver_getattr"
,
prim
::
kPrimGetAttr
);
resolver_resolve_
=
MakeSubstitution
(
std
::
make_shared
<
ResolverResolve
>
(),
"resolver_resolve"
,
prim
::
kPrimResolve
);
resolver_getattr_
=
MakeSubstitution
(
std
::
make_shared
<
ResolverGetattr
>
(),
"resolver_getattr"
,
prim
::
kPrimGetAttr
);
}
InferenceOptPrepareLib
::
InferenceOptPrepareLib
()
{
grad_var_prepare_
=
MakeSubstitution
(
GradVarPrepare
(),
"grad_var_prepare"
,
IsCNode
);
grad_var_prepare_
=
MakeSubstitution
(
std
::
make_shared
<
GradVarPrepare
>
(),
"grad_var_prepare"
,
IsCNode
);
}
}
// namespace irpass
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
浏览文件 @
262e4fc0
...
...
@@ -17,15 +17,16 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
FuncGraphPtr
all_reduce_fg_
{
nullptr
};
};
class
ArithmeticSimplify
{
class
ArithmeticSimplify
:
public
OptimizerCaller
{
public:
ArithmeticSimplify
()
:
multiply_by_zero_or_one_
(),
tensor_multiply_by_one_
(),
add_by_zero_
(),
tensor_add_by_zero_
(),
identity_
(
prim
::
kPrimIdentity
),
opt_update_zero_tensor_
(),
constant_duplicate_mul_
(),
power_one_
()
{
:
multiply_by_zero_or_one_
(
std
::
make_shared
<
MultiplyByZeroOrOne
>
()
),
tensor_multiply_by_one_
(
std
::
make_shared
<
TensorMultiplyByOne
>
()
),
add_by_zero_
(
std
::
make_shared
<
AddByZero
>
()
),
tensor_add_by_zero_
(
std
::
make_shared
<
TensorAddByZero
>
()
),
identity_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimIdentity
)
),
opt_update_zero_tensor_
(
std
::
make_shared
<
OptUpdateZeroTensor
>
()
),
constant_duplicate_mul_
(
std
::
make_shared
<
ConstantDuplicateMul
>
()
),
power_one_
(
std
::
make_shared
<
PowerOneEliminate
>
()
)
{
eliminaters_
.
emplace_back
(
multiply_by_zero_or_one_
);
eliminaters_
.
emplace_back
(
tensor_multiply_by_one_
);
eliminaters_
.
emplace_back
(
add_by_zero_
);
...
...
@@ -761,10 +762,10 @@ class ArithmeticSimplify {
}
~
ArithmeticSimplify
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
...
...
@@ -773,15 +774,9 @@ class ArithmeticSimplify {
}
private:
MultiplyByZeroOrOne
multiply_by_zero_or_one_
;
TensorMultiplyByOne
tensor_multiply_by_one_
;
AddByZero
add_by_zero_
;
TensorAddByZero
tensor_add_by_zero_
;
PrimEliminater
identity_
;
OptUpdateZeroTensor
opt_update_zero_tensor_
;
ConstantDuplicateMul
constant_duplicate_mul_
;
PowerOneEliminate
power_one_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
OptimizerCallerPtr
multiply_by_zero_or_one_
,
tensor_multiply_by_one_
,
add_by_zero_
,
tensor_add_by_zero_
,
identity_
,
opt_update_zero_tensor_
,
constant_duplicate_mul_
,
power_one_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
// Arithmetic Simplifications should be done after step_parallel.
...
...
@@ -789,15 +784,17 @@ class ArithmeticSimplify {
// with shape(weight), but after step_parallel, shape of weight may be changed, so the
// shape of the constant tensor should also be changed. So this pass is seperated from
// ArithmeticSimplify and deferred until step_parallel.
class
ArithmeticSimplify2
{
class
ArithmeticSimplify2
:
public
OptimizerCaller
{
public:
ArithmeticSimplify2
()
:
tensor_multiply_by_zero_
()
{
eliminaters_
.
emplace_back
(
tensor_multiply_by_zero_
);
}
ArithmeticSimplify2
()
:
tensor_multiply_by_zero_
(
std
::
make_shared
<
TensorMultiplyByZero
>
())
{
eliminaters_
.
emplace_back
(
tensor_multiply_by_zero_
);
}
~
ArithmeticSimplify2
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
...
...
@@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
}
private:
TensorMultiplyByZero
tensor_multiply_by_zero_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
OptimizerCallerPtr
tensor_multiply_by_zero_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
}
// namespace irpass
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/cast_eliminate.h
浏览文件 @
262e4fc0
...
...
@@ -17,9 +17,9 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#include "ir/visitor.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
AnfNodePtr
x_
{
nullptr
},
t_
{
nullptr
};
};
class
CastEliminater
{
class
CastEliminater
:
public
OptimizerCaller
{
public:
CastEliminater
()
:
cast_same_type_eliminater_
(),
two_cast_eliminater_
()
{}
~
CastEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
auto
new_node
=
cast_same_type_eliminater_
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
...
...
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
浏览文件 @
262e4fc0
...
...
@@ -17,18 +17,19 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#include <vector>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "utils/symbolic.h"
namespace
mindspore
{
...
...
@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
bool
is_match_
{
false
};
};
class
EnvGetItemEliminater
{
class
EnvGetItemEliminater
:
public
OptimizerCaller
{
public:
EnvGetItemEliminater
()
:
new_env_get_item_
(),
add_env_get_item_
(),
env_get_set_item_
()
{
EnvGetItemEliminater
()
:
new_env_get_item_
(
std
::
make_shared
<
NewEnvGetItem
>
()),
add_env_get_item_
(
std
::
make_shared
<
AddEnvGetItem
>
()),
env_get_set_item_
(
std
::
make_shared
<
EnvGetSetItem
>
())
{
eliminaters_
.
emplace_back
(
new_env_get_item_
);
eliminaters_
.
emplace_back
(
add_env_get_item_
);
eliminaters_
.
emplace_back
(
env_get_set_item_
);
}
~
EnvGetItemEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
...
...
@@ -246,10 +250,8 @@ class EnvGetItemEliminater {
}
private:
NewEnvGetItem
new_env_get_item_
;
AddEnvGetItem
add_env_get_item_
;
EnvGetSetItem
env_get_set_item_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
OptimizerCallerPtr
new_env_get_item_
,
add_env_get_item_
,
env_get_set_item_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
...
...
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
浏览文件 @
262e4fc0
...
...
@@ -17,18 +17,20 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#include <vector>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
...
...
@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal
::
GetitemTransform
getitem_transform_
;
};
class
IncorporateGetitemSet
{
class
IncorporateGetitemSet
:
public
OptimizerCaller
{
public:
IncorporateGetitemSet
()
:
incorporate_getitem_
(),
incorporate_getitem_switch_
()
{
IncorporateGetitemSet
()
:
incorporate_getitem_
(
std
::
make_shared
<
IncorporateGetitem
>
()),
incorporate_getitem_switch_
(
std
::
make_shared
<
IncorporateGetitemSwitch
>
())
{
eliminaters_
.
emplace_back
(
incorporate_getitem_
);
eliminaters_
.
emplace_back
(
incorporate_getitem_switch_
);
}
~
IncorporateGetitemSet
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
...
...
@@ -403,9 +407,8 @@ class IncorporateGetitemSet {
}
private:
IncorporateGetitem
incorporate_getitem_
;
IncorporateGetitemSwitch
incorporate_getitem_switch_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
OptimizerCallerPtr
incorporate_getitem_
,
incorporate_getitem_switch_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
}
// namespace irpass
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
浏览文件 @
262e4fc0
...
...
@@ -17,13 +17,15 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr
x_
{
nullptr
},
y_
{
nullptr
},
c_
{
nullptr
};
};
class
ItemTupleEliminater
{
class
ItemTupleEliminater
:
public
OptimizerCaller
{
public:
ItemTupleEliminater
()
:
get_item_eliminater_
(),
get_item_const_eliminater_
(),
set_item_eliminater_
(),
get_set_item_eliminater_
(),
get_item_depend_reorder_
()
{
:
get_item_eliminater_
(
std
::
make_shared
<
GetitemEliminater
>
()
),
get_item_const_eliminater_
(
std
::
make_shared
<
GetitemConstEliminater
>
()
),
set_item_eliminater_
(
std
::
make_shared
<
SetitemEliminater
>
()
),
get_set_item_eliminater_
(
std
::
make_shared
<
GetSetitemEliminater
>
()
),
get_item_depend_reorder_
(
std
::
make_shared
<
GetitemDependReorder
>
()
)
{
eliminaters_
.
emplace_back
(
get_item_eliminater_
);
eliminaters_
.
emplace_back
(
get_item_const_eliminater_
);
eliminaters_
.
emplace_back
(
set_item_eliminater_
);
...
...
@@ -277,10 +279,10 @@ class ItemTupleEliminater {
}
~
ItemTupleEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
...
...
@@ -289,12 +291,9 @@ class ItemTupleEliminater {
}
private:
GetitemEliminater
get_item_eliminater_
;
GetitemConstEliminater
get_item_const_eliminater_
;
SetitemEliminater
set_item_eliminater_
;
GetSetitemEliminater
get_set_item_eliminater_
;
GetitemDependReorder
get_item_depend_reorder_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
OptimizerCallerPtr
get_item_eliminater_
,
get_item_const_eliminater_
,
set_item_eliminater_
,
get_set_item_eliminater_
,
get_item_depend_reorder_
;
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
}
// namespace irpass
}
// namespace opt
...
...
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
浏览文件 @
262e4fc0
...
...
@@ -19,9 +19,9 @@
#include <memory>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
...
...
mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
浏览文件 @
262e4fc0
...
...
@@ -19,11 +19,12 @@
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "pipeline/static_analysis/dshape.h"
namespace
mindspore
{
...
...
@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
AnfNodePtr
x_
{
nullptr
},
shape_
{
nullptr
};
};
class
ReshapeEliminater
{
class
ReshapeEliminater
:
public
OptimizerCaller
{
public:
ReshapeEliminater
()
:
reshape_same_shape_eliminater_
(),
two_reshape_eliminater_
()
{}
~
ReshapeEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
auto
new_node
=
reshape_same_shape_eliminater_
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
...
...
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
浏览文件 @
262e4fc0
...
...
@@ -18,31 +18,31 @@
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#include <securec.h>
#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "
optimizer/irpass/prim_eliminate
.h"
#include "
ir/pattern_matcher
.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
class
SpecialOpEliminater
{
class
SpecialOpEliminater
:
public
OptimizerCaller
{
public:
SpecialOpEliminater
()
:
insert_gradient_of_
(
prim
::
kPrimInsertGradientOf
),
stop_gradient_
(
prim
::
kPrimStopGradient
),
hook_backward_
(
prim
::
kPrimHookBackward
),
print_shape_type_
(
prim
::
kPrimPrintShapeType
),
get_ref_value_
(
prim
::
kPrimGetRefValue
),
mirror_
(
prim
::
kPrimMirror
),
virtual_div_
(
prim
::
kPrimVirtualDiv
)
{
:
insert_gradient_of_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimInsertGradientOf
)
),
stop_gradient_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimStopGradient
)
),
hook_backward_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimHookBackward
)
),
print_shape_type_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimPrintShapeType
)
),
get_ref_value_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimGetRefValue
)
),
mirror_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimMirror
)
),
virtual_div_
(
std
::
make_shared
<
PrimEliminater
>
(
prim
::
kPrimVirtualDiv
)
)
{
eliminaters_
.
emplace_back
(
insert_gradient_of_
);
eliminaters_
.
emplace_back
(
stop_gradient_
);
eliminaters_
.
emplace_back
(
hook_backward_
);
...
...
@@ -53,10 +53,10 @@ class SpecialOpEliminater {
}
~
SpecialOpEliminater
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
AnfNodePtr
new_node
;
for
(
auto
&
eliminater
:
eliminaters_
)
{
new_node
=
eliminater
(
optimizer
,
node
);
new_node
=
(
*
eliminater
)
(
optimizer
,
node
);
if
(
new_node
!=
nullptr
)
{
return
new_node
;
}
...
...
@@ -65,9 +65,9 @@ class SpecialOpEliminater {
}
private:
PrimEliminate
r
insert_gradient_of_
,
stop_gradient_
,
hook_backward_
,
print_shape_type_
,
get_ref_value_
,
mirror_
,
OptimizerCallerPt
r
insert_gradient_of_
,
stop_gradient_
,
hook_backward_
,
print_shape_type_
,
get_ref_value_
,
mirror_
,
virtual_div_
;
std
::
vector
<
TransformFuncType
>
eliminaters_
{};
std
::
vector
<
OptimizerCallerPtr
>
eliminaters_
{};
};
// {PrimVirtualDataset, X} -> X
...
...
mindspore/ccsrc/optimizer/opt.cc
浏览文件 @
262e4fc0
...
...
@@ -16,28 +16,27 @@
#include "optimizer/opt.h"
#include <algorithm>
#include <deque>
#include <memory>
#include <unordered_set>
#include <deque>
#include <algorithm>
#include "ir/anf.h"
#include "ir/manager.h"
#include "utils/ordered_set.h"
#include "utils/log_adapter.h"
#include "optimizer/optimizer.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
namespace
mindspore
{
/* namespace to support opt */
namespace
opt
{
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
const
PrimitivePtr
&
prim
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PrimitivePtr
&
prim
,
const
RenormAction
&
renorm_action
)
{
auto
fn
=
[
prim
](
const
AnfNodePtr
&
node
)
->
bool
{
return
IsPrimitiveCNode
(
node
,
prim
);
};
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
fn
,
renorm_action
);
}
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
std
::
vector
<
PrimitivePtr
>
&
prims
,
const
RenormAction
&
renorm_action
)
{
auto
fn
=
[
prims
](
const
AnfNodePtr
&
node
)
->
bool
{
if
(
!
node
->
isa
<
CNode
>
())
{
...
...
@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
fn
,
renorm_action
);
}
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PredicateFuncType
&
predicate
,
const
RenormAction
&
renorm_action
)
{
return
std
::
make_shared
<
Substitution
>
(
transform
,
name
,
predicate
,
renorm_action
);
}
AnfNodePtr
Substitution
::
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
const
{
AnfNodePtr
Substitution
::
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
{
#ifdef ENABLE_PROFILE
double
t
=
GetTime
();
#endif
AnfNodePtr
result
=
transform_
(
optimizer
,
node
);
AnfNodePtr
result
=
(
*
transform_
)
(
optimizer
,
node
);
#ifdef ENABLE_PROFILE
if
(
optimizer
!=
nullptr
)
{
auto
time
=
GetTime
();
...
...
mindspore/ccsrc/optimizer/opt.h
浏览文件 @
262e4fc0
...
...
@@ -17,24 +17,18 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#include <vector>
#include <string>
#include <memory>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "operator/ops.h"
namespace
mindspore
{
/* namespace to support opt */
namespace
opt
{
class
Optimizer
;
using
OptimizerPtr
=
std
::
shared_ptr
<
Optimizer
>
;
using
OptimizerWeakPtr
=
std
::
weak_ptr
<
Optimizer
>
;
using
PredicateFuncType
=
std
::
function
<
bool
(
const
AnfNodePtr
&
)
>
;
using
TransformFuncType
=
std
::
function
<
AnfNodePtr
(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
)
>
;
// Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
...
...
@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
class
Substitution
{
public:
TransformFuncType
transform_
{
nullptr
}
;
OptimizerCallerPtr
transform_
;
std
::
string
name_
;
PredicateFuncType
predicate_
{
nullptr
};
// an enum to mark this Substitution relation to renormalize pass
RenormAction
renorm_action_
;
Substitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
const
PredicateFuncType
&
predicate
,
Substitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PredicateFuncType
&
predicate
,
const
RenormAction
&
renorm_action
)
:
transform_
(
transform
),
name_
(
name
),
predicate_
(
predicate
),
renorm_action_
(
renorm_action
)
{}
~
Substitution
()
=
default
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
const
;
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
);
};
using
SubstitutionPtr
=
std
::
shared_ptr
<
Substitution
>
;
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
const
PrimitivePtr
&
prim
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PrimitivePtr
&
prim
,
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
std
::
vector
<
PrimitivePtr
>
&
prims
,
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
SubstitutionPtr
MakeSubstitution
(
const
TransformFuncType
&
transform
,
const
std
::
string
&
name
,
SubstitutionPtr
MakeSubstitution
(
const
OptimizerCallerPtr
&
transform
,
const
std
::
string
&
name
,
const
PredicateFuncType
&
predicate
,
const
RenormAction
&
action_renorm
=
CHECK_RENORM
);
class
SubstitutionList
{
...
...
tests/ut/cpp/optimizer/opt_test.cc
浏览文件 @
262e4fc0
...
...
@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
};
void
SetUp
()
{
elim_Z
=
MakeSubstitution
(
irpass
::
AddByZero
(),
"elim_Z"
,
prim
::
kPrimScalarAdd
);
elim_R
=
MakeSubstitution
(
irpass
::
PrimEliminater
(
R
),
"elim_R"
,
R
);
idempotent_P
=
MakeSubstitution
(
IdempotentEliminater
(),
"idempotent_P"
,
P
);
Qct_to_P
=
MakeSubstitution
(
QctToP
(),
"Qct_to_P"
,
Q
);
elim_Z
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
AddByZero
>
(),
"elim_Z"
,
prim
::
kPrimScalarAdd
);
elim_R
=
MakeSubstitution
(
std
::
make_shared
<
irpass
::
PrimEliminater
>
(
R
),
"elim_R"
,
R
);
idempotent_P
=
MakeSubstitution
(
std
::
make_shared
<
IdempotentEliminater
>
(),
"idempotent_P"
,
P
);
Qct_to_P
=
MakeSubstitution
(
std
::
make_shared
<
QctToP
>
(),
"Qct_to_P"
,
Q
);
}
bool
CheckTransform
(
FuncGraphPtr
gbefore
,
FuncGraphPtr
gafter
,
const
SubstitutionList
&
transform
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录