magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit daccfef7
Authored on May 23, 2020 by mindspore-ci-bot; committed via Gitee on May 23, 2020.

!1361 Refactor multiple output pass
Merge pull request !1361 from huanghui/LambNextMVRule-fusion-pass

Parents: 20d71dfb, f16ff539

Showing 11 changed files with 264 additions and 212 deletions (+264 −212).

mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc                       +2  −2
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc                       +69 −84
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h                        +57 −9
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc            +45 −60
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h             +19 −19
mindspore/ccsrc/pre_activate/common/optimizer.cc                                         +15 −0
mindspore/ccsrc/pre_activate/common/optimizer.h                                          +19 −0
tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc                     +16 −16
tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc          +20 −20
tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py             +1  −1
tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py  +1  −1

mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc

@@ -99,11 +99,11 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
   ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
-  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
-  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>());
+  ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
   ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());

mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc

@@ -27,82 +27,12 @@
 namespace mindspore {
 namespace opt {
-namespace {
-std::tuple<CNodePtr, CNodePtr, AnfNodePtr> GetSharedNodesByPattern(const AnfNodePtr &node) {
-  auto add3_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kAddInputNum);
-  MS_EXCEPTION_IF_NULL(add3_cnode);
-  auto real_div2_cnode = CheckAnfNodeIfCNodeAndInputSize(add3_cnode->input(1), kMulInputNum);
-  MS_EXCEPTION_IF_NULL(real_div2_cnode);
-  auto real_div0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(1), kRealDivInputNum);
-  MS_EXCEPTION_IF_NULL(real_div0_cnode);
-  auto sqrt0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(2), kSqrtInputNum);
-  MS_EXCEPTION_IF_NULL(sqrt0_cnode);
-  auto add2_cnode = CheckAnfNodeIfCNodeAndInputSize(sqrt0_cnode->input(1), kAddInputNum);
-  MS_EXCEPTION_IF_NULL(add2_cnode);
-  auto real_div1_cnode = CheckAnfNodeIfCNodeAndInputSize(add2_cnode->input(1), kRealDivInputNum);
-  auto constant_add2_y = add2_cnode->input(2);
-  return std::make_tuple(real_div0_cnode, real_div1_cnode, constant_add2_y);
-}
-
-bool MatchRealDiv4(const AnfNodePtr &real_div4, const AnfNodePtr &real_div1, const AnfNodePtr &constant_add2_y) {
-  if (real_div4 == nullptr || !real_div4->isa<CNode>()) {
-    return false;
-  }
-  auto real_div4_cnode = real_div4->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div4_cnode);
-  if (AnfAlgo::GetCNodeName(real_div4_cnode) != kRealDivOpName ||
-      real_div4_cnode->inputs().size() < kRealDivInputNum) {
-    return false;
-  }
-  CNodePtr add4_cnode = nullptr;
-  if (!CheckIfCNodeAndInputSize(real_div4_cnode->input(2), kAddInputNum, &add4_cnode) ||
-      AnfAlgo::GetCNodeName(add4_cnode) != prim::kPrimTensorAdd->name()) {
-    return false;
-  }
-  CNodePtr sqrt1_cnode = nullptr;
-  if (!CheckIfCNodeAndInputSize(add4_cnode->input(1), kSqrtInputNum, &sqrt1_cnode) ||
-      AnfAlgo::GetCNodeName(sqrt1_cnode) != kSqrtOpName) {
-    return false;
-  }
-  MS_EXCEPTION_IF_NULL(add4_cnode->input(2));
-  MS_EXCEPTION_IF_NULL(constant_add2_y);
-  return sqrt1_cnode->input(1) == real_div1 && *(add4_cnode->input(2)) == *constant_add2_y;
-}
-}  // namespace
-
-const BaseRef LambNextMVRule::DefinePattern() const {
-  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_rsqrt);
-  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
-  MS_EXCEPTION_IF_NULL(prim_deal_div);
-  auto mul0 = VectorRef({prim::kPrimMul, input_varptr_[7], input_varptr_[4]});
-  auto mul1 = VectorRef({prim::kPrimMul, input_varptr_[8], input_varptr_[3]});
-  auto mul2 = VectorRef({prim::kPrimMul, input_varptr_[9], input_varptr_[1]});
-  auto mul3 = VectorRef({prim::kPrimMul, input_varptr_[10], input_varptr_[0]});
-  auto mul4 = VectorRef({prim::kPrimMul, input_varptr_[11], input_varptr_[6]});
-  auto add0 = VectorRef({prim::kPrimTensorAdd, mul0, mul1});
-  auto add1 = VectorRef({prim::kPrimTensorAdd, mul2, mul3});
-  auto real_div0 = VectorRef({prim_deal_div, add0, input_varptr_[5]});
-  auto real_div1 = VectorRef({prim_deal_div, add1, input_varptr_[2]});
-  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, input_varptr_[12]});
-  auto sqrt0 = VectorRef({prim_rsqrt, add2});
-  auto real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
-  return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
-}
-
-bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
                                    std::vector<AnfNodePtr> *old_pattern_outputs) const {
   MS_EXCEPTION_IF_NULL(func_graph);
-  CNodePtr real_div0 = nullptr;
-  CNodePtr real_div1 = nullptr;
-  AnfNodePtr constant_add2_y = nullptr;
-  std::tie(real_div0, real_div1, constant_add2_y) = GetSharedNodesByPattern(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_);
+  auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_);
   auto manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);

@@ -112,19 +42,17 @@ bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNode
   }
   AnfNodeIndexSet real_div0_outputs = users[real_div0];
   auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(),
-                           [&node, &real_div1, &constant_add2_y](const std::pair<AnfNodePtr, int> &node_index) {
-                             return node_index.first != node && node_index.second == 1 &&
-                                    MatchRealDiv4(node_index.first, real_div1, constant_add2_y);
+                           [&real_div2, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
+                             return node_index.first != real_div2 && node_index.second == 1 &&
+                                    MatchAnotherPattern(node_index.first, equiv);
                            });
   if (iter == real_div0_outputs.end()) {
     return false;
   }
-  auto add0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div0->input(1), kAddInputNum);
-  auto add1_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div1->input(1), kAddInputNum);
   (*old_pattern_outputs).push_back(node);
-  (*old_pattern_outputs).push_back(add0_cnode);
-  (*old_pattern_outputs).push_back(add1_cnode);
+  (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_));
+  (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_));
   (*old_pattern_outputs).push_back(iter->first);
   return true;

@@ -136,8 +64,19 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(func_graph);
   auto prim = std::make_shared<Primitive>(kLambNextMVOpName);
   std::vector<AnfNodePtr> lamb_next_mv_rule_inputs = {NewValueNode(prim)};
-  (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(lamb_next_mv_rule_inputs),
-                       [&equiv](const VarPtr &in) { return utils::cast<AnfNodePtr>((*equiv)[in]); });
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input0_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input1_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input2_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input3_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input4_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input5_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input6_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul0_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul2_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
+  lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
   auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
   MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);

@@ -162,14 +101,60 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
   return lamb_next_mv_rule_outputs[0];
 }

+bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
+  return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) &&
+         IsSameNode(equiv1, equiv2, add2_y_);
+}
+
 const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &equiv) const {
   std::vector<AnfNodePtr> old_pattern_outputs;
-  if (!IsRuleMatched(func_graph, node, &old_pattern_outputs)) {
+  if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) {
     return nullptr;
   }
   return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv);
 }
+
+const BaseRef LambNextMVRuleCond4::DefinePattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
+  auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
+  auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
+  auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
+  auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
+  auto add0 = VectorRef({add0_var_, mul0, mul1});
+  auto add1 = VectorRef({add1_var_, mul2, mul3});
+  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
+  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
+  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
+  auto sqrt0 = VectorRef({prim_rsqrt, add2});
+  auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});
+  return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+}
+
+BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_real_div);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  // Two patterns share: real_div0, real_div1, add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
+  VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
+  return real_div4;
+}
 }  // namespace opt
 }  // namespace mindspore
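
For orientation, the arithmetic encoded by the two patterns above can be written out as a scalar reference. This is an illustrative sketch inferred from the VectorRef nesting in LambNextMVRuleCond4::DefinePattern() and DefineAnotherPattern(); the struct, its double members, and main() are hypothetical illustration code, not part of the commit, and the real pass matches ANF graph nodes rather than computing values:

#include <cmath>
#include <cstdio>

// Scalar transcription of the two sub-patterns; names mirror the pattern vars.
struct LambNextMVPatternSketch {
  double input0, input1, input2, input3, input4, input5, input6;
  double mul0_x, mul1_sub, mul2_x, mul3_sub1, mul4_x, add2_y;

  void Run() const {
    double add0 = mul0_x * input4 + mul1_sub * input3;   // mul0 + mul1
    double add1 = mul2_x * input1 + mul3_sub1 * input0;  // mul2 + mul3
    double real_div0 = add0 / input5;
    double real_div1 = add1 / input2;
    // Main pattern root: TensorAdd(Mul(real_div0, Rsqrt(real_div1 + add2_y)), mul4)
    double add3 = real_div0 / std::sqrt(real_div1 + add2_y) + mul4_x * input6;
    // "Another" pattern rooted at a second consumer of real_div0:
    // RealDiv(real_div0, TensorAdd(Sqrt(real_div1), add2_y))
    double real_div4 = real_div0 / (std::sqrt(real_div1) + add2_y);
    std::printf("add3 = %g, real_div4 = %g\n", add3, real_div4);
  }
};

int main() {
  LambNextMVPatternSketch s{1, 2, 3, 4, 5, 6, 7, 0.9, 0.1, 0.999, 0.001, 0.02, 1e-6};
  s.Run();
  return 0;
}

The fusion only fires when both roots exist in the graph and bind real_div0, real_div1 and add2_y_ to the same nodes; that is exactly the condition IsShareNodes() enforces.
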
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h

@@ -29,23 +29,71 @@
 namespace mindspore {
 namespace opt {
-class LambNextMVRule : public PatternProcessPass {
+class LambNextMVRule : public MultipleOutputPatternProcessPass {
  public:
-  explicit LambNextMVRule(bool multigraph = true) : PatternProcessPass("lamb_next_mv_rule", multigraph) {
-    for (size_t i = 0; i < kLambNextMVRuleInputNum - 1; ++i) {
-      input_varptr_.push_back(std::make_shared<Var>());
-    }
+  explicit LambNextMVRule(const std::string &name = "", bool multigraph = true)
+      : MultipleOutputPatternProcessPass(name, multigraph) {
+    input0_ = std::make_shared<Var>();
+    input1_ = std::make_shared<Var>();
+    input2_ = std::make_shared<Var>();
+    input3_ = std::make_shared<Var>();
+    input4_ = std::make_shared<Var>();
+    input5_ = std::make_shared<Var>();
+    input6_ = std::make_shared<Var>();
+    mul0_x_ = std::make_shared<Var>();
+    mul1_sub_ = std::make_shared<Var>();
+    mul2_x_ = std::make_shared<Var>();
+    mul3_sub1_ = std::make_shared<Var>();
+    mul4_x_ = std::make_shared<Var>();
+    add2_y_ = std::make_shared<Var>();
+    real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
+    real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
+    real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
+    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
+    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
   }
   ~LambNextMVRule() override = default;
-  const BaseRef DefinePattern() const override;
+  const BaseRef DefinePattern() const override = 0;
+  BaseRef DefineAnotherPattern() const override = 0;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;

- private:
-  std::vector<VarPtr> input_varptr_;
-  bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+ protected:
+  bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
                      std::vector<AnfNodePtr> *old_pattern_outputs) const;
   AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &old_pattern_outputs,
                                   const EquivPtr &equiv) const;
+  VarPtr input0_;
+  VarPtr input1_;
+  VarPtr input2_;
+  VarPtr input3_;
+  VarPtr input4_;
+  VarPtr input5_;
+  VarPtr input6_;
+  VarPtr mul0_x_;
+  VarPtr mul1_sub_;
+  VarPtr mul2_x_;
+  VarPtr mul3_sub1_;
+  VarPtr mul4_x_;
+  VarPtr add2_y_;
+  // nodes which two patterns share, and add2_y_ also.
+  VarPtr real_div0_var_;
+  VarPtr real_div1_var_;
+  // part of output nodes
+  VarPtr add0_var_;
+  VarPtr add1_var_;
+  // other node
+  VarPtr real_div2_var_;
+};
+
+class LambNextMVRuleCond4 : public LambNextMVRule {
+ public:
+  explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {}
+  ~LambNextMVRuleCond4() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };
 }  // namespace opt
 }  // namespace mindspore

mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc

@@ -79,63 +79,6 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap
   return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
 }

-const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
-  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_sqrt);
-  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
-  MS_EXCEPTION_IF_NULL(prim_deal_div);
-  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
-  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
-  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
-  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
-  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
-  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
-  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
-  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
-  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
-  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
-  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
-  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
-  return add5;
-}
-
-const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const {
-  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
-  MS_EXCEPTION_IF_NULL(prim_rsqrt);
-  VarPtr Xs = std::make_shared<SeqVar>();
-  VarPtr Ys = std::make_shared<SeqVar>();
-  VarPtr Zs = std::make_shared<SeqVar>();
-  MS_EXCEPTION_IF_NULL(Xs);
-  MS_EXCEPTION_IF_NULL(Ys);
-  MS_EXCEPTION_IF_NULL(Zs);
-  // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
-  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
-  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
-  VectorRef mul4 = VectorRef({mul4_var_, Zs});
-  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
-  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
-  VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
-  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
-  return add3;
-}
-
-bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(equiv);
-  VarPtr fg = std::make_shared<Var>("RootG");
-  auto empty_equiv = std::make_shared<Equiv>();
-  MS_EXCEPTION_IF_NULL(child_primitive_vars_);
-  EquivPtr another_equiv =
-    child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
-                                *child_primitive_vars_, empty_equiv);
-  if (another_equiv != nullptr && !another_equiv->empty()) {
-    return IsShareNodes(equiv, another_equiv);
-  }
-  return false;
-}
-
 bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
   return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
          IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);

@@ -164,7 +107,7 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
   return nullptr;
 }

-const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();

@@ -205,7 +148,7 @@ const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
   return add5;
 }

-const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();

@@ -246,7 +189,7 @@ const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
   return add5;
 }

-const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
+BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
   const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
   MS_EXCEPTION_IF_NULL(prim_rsqrt);
   VarPtr Xs = std::make_shared<SeqVar>();

@@ -286,5 +229,47 @@ const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
   VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
   return add5;
 }
+
+BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
+  return add3;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_deal_div);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
+  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
+  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
+  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
+  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
+  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
+  return add5;
+}
 }  // namespace opt
 }  // namespace mindspore

mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h

@@ -24,15 +24,10 @@
 namespace mindspore {
 namespace opt {
-class LambNextMVWithDecayRule : public PatternProcessPass {
+class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass {
  public:
-  explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4",
-                                   bool multigraph = true)
-      : PatternProcessPass(name, multigraph),
-        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
-                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
-                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
-        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {
+  explicit LambNextMVWithDecayRule(const std::string &name = "", bool multigraph = true)
+      : MultipleOutputPatternProcessPass(name, multigraph) {
     for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
       input_vars_.push_back(std::make_shared<Var>());
     }

@@ -48,21 +43,16 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
   }
   ~LambNextMVWithDecayRule() override = default;
-  const BaseRef DefinePattern() const override;
-  virtual const BaseRef DefineAnotherPattern() const;
+  const BaseRef DefinePattern() const override = 0;
+  BaseRef DefineAnotherPattern() const override = 0;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;

  protected:
-  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
-  // check two patterns whether share the same nodes or not
-  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const;
   AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
                                           const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const;
   AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
                                            const AnfNodePtr &add5, const EquivPtr &equiv) const;
-  PatternEngine child_pattern_engine_;
-  PrimitiveVarMapPtr child_primitive_vars_;
   std::vector<VarPtr> input_vars_;
   std::vector<VarPtr> constant_mul_input_vars_;
   // nodes which two patterns share

@@ -82,7 +72,7 @@ class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule {
   ~LambNextMVWithDecayRuleCond1() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };

 class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {

@@ -92,7 +82,7 @@ class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
   ~LambNextMVWithDecayRuleCond2() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };

 class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {

@@ -102,7 +92,17 @@ class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
   ~LambNextMVWithDecayRuleCond3() override = default;
   const BaseRef DefinePattern() const override;
-  const BaseRef DefineAnotherPattern() const override;
+  BaseRef DefineAnotherPattern() const override;
 };
+
+class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule {
+ public:
+  explicit LambNextMVWithDecayRuleCond4(bool multigraph = true)
+      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond4", multigraph) {}
+  ~LambNextMVWithDecayRuleCond4() override = default;
+  const BaseRef DefinePattern() const override;
+  BaseRef DefineAnotherPattern() const override;
+};
 }  // namespace opt
 }  // namespace mindspore

mindspore/ccsrc/pre_activate/common/optimizer.cc

@@ -62,6 +62,21 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
   return nullptr;
 }

+bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  VarPtr fg = std::make_shared<Var>("RootG");
+  auto empty_equiv = std::make_shared<Equiv>();
+  MS_EXCEPTION_IF_NULL(child_primitive_vars_);
+  EquivPtr another_equiv =
+    child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
+                                *child_primitive_vars_, empty_equiv);
+  if (another_equiv != nullptr && !another_equiv->empty()) {
+    return IsShareNodes(equiv, another_equiv);
+  }
+  return false;
+}
+
 void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
   if (pass_manager != nullptr) {
     pass_managers_.push_back(pass_manager);
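
Read together with the lamb_next_mv_rule.cc hunks above, the call sequence for a multiple-output fusion now runs roughly as follows. The trace below is a sketch assembled from this diff (assuming PatternProcessPass::Run() dispatches to Process() once DefinePattern() matches, as the hunk header above suggests), not documentation shipped with the commit:

// Sketch of the matching flow for a multiple-output fusion pass.
// PatternProcessPass::Run(func_graph, node)           — DefinePattern() matched at `node`
//   -> LambNextMVRule::Process(func_graph, node, equiv)
//        -> IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)
//             // scans the users of the shared real_div0 node; for each candidate:
//             -> MultipleOutputPatternProcessPass::MatchAnotherPattern(candidate, equiv)
//                  -> child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), ...), candidate, ...)
//                  -> IsShareNodes(equiv, another_equiv)  // same real_div0 / real_div1 / add2_y_?
//        -> CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv)  // emit the fused op
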
mindspore/ccsrc/pre_activate/common/optimizer.h

@@ -51,6 +51,25 @@ class PatternProcessPass : public NodePass {
   PrimitiveVarMapPtr primitive_vars_;
 };

+class MultipleOutputPatternProcessPass : public PatternProcessPass {
+ public:
+  explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
+      : PatternProcessPass(name, multigraph),
+        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
+                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
+                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
+        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
+  ~MultipleOutputPatternProcessPass() override = default;
+  virtual BaseRef DefineAnotherPattern() const = 0;
+  // check two patterns whether share the same nodes or not
+  virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0;
+
+ protected:
+  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
+  PatternEngine child_pattern_engine_;
+  PrimitiveVarMapPtr child_primitive_vars_;
+};
+
 class GraphOptimizer {
  public:
   explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
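
To see how the new base class is meant to be extended, here is a minimal sketch of a concrete subclass, modeled on LambNextMVRuleCond4 from this commit. MyFusionPass and shared_var_ are hypothetical illustration names, not part of the change:

class MyFusionPass : public MultipleOutputPatternProcessPass {
 public:
  explicit MyFusionPass(bool multigraph = true)
      : MultipleOutputPatternProcessPass("my_fusion_pass", multigraph) {
    // Tie a Var to a concrete primitive so both patterns can bind the same node.
    shared_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
  }
  ~MyFusionPass() override = default;
  // First pattern, rooted at the node handed to Process().
  const BaseRef DefinePattern() const override;
  // Second pattern, rooted at another consumer of the shared node.
  BaseRef DefineAnotherPattern() const override;
  // The fusion is legal only if both matches bound shared_var_ to the same node.
  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override {
    return IsSameNode(equiv1, equiv2, shared_var_);
  }
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  VarPtr shared_var_;
};

The base class supplies child_pattern_engine_ and MatchAnotherPattern(), so a subclass only declares its two patterns and the node-sharing check.
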
tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc

@@ -30,7 +30,7 @@ class TestHWLambNextMVRule : public BackendCommon {
   UT::PyFuncGraphFetcher get_py_fun_;
 };

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_matched) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -54,7 +54,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -65,15 +65,15 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);

-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div4) {
   /*
    * def before_unmatched_real_div4(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -97,7 +97,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div4");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div4");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -109,14 +109,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div2) {
   /*
    * def before_unmatched_real_div2(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -140,7 +140,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div2");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div2");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -152,14 +152,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div0) {
   /*
    * def before_unmatched_real_div0(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -183,7 +183,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div0");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div0");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -195,14 +195,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
 }

-TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
+TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div1) {
   /*
    * def before_unmatched_real_div1(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
    *                                constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -226,7 +226,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div1");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div1");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -238,7 +238,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);

tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc

@@ -30,7 +30,7 @@ class TestHWLambNextMVWithDecayRule : public BackendCommon {
   UT::PyFuncGraphFetcher get_py_fun_;
 };

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_matched) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -55,7 +55,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -66,15 +66,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);

-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add3) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_add3) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -99,7 +99,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_add3");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_add3");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -111,15 +111,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));

-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul4) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_mul4) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -144,7 +144,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_mul4");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_mul4");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -156,15 +156,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));

-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div0) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div0) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -189,7 +189,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div0");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div0");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -201,15 +201,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));

-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }

-TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div1) {
+TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div1) {
   /*
    * def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
    *            constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):

@@ -234,7 +234,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
    * output = tuple_getitem(outputs, 0)
    * return output
    */
-  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div1");
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div1");
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;

@@ -246,11 +246,11 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
+  pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);
   EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));

-  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
   EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
 }

tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_rule_test.py

@@ -36,7 +36,7 @@ class FnDict:
         return self.fnDict[name]


-def test_lamb_next_mv_rule(tag):
+def test_lamb_next_mv_rule_cond4(tag):
     fns = FnDict()

     @fns

tests/ut/cpp/python_input/gtest_input/pre_activate/lamb_next_mv_with_decay_rule_test.py

@@ -34,7 +34,7 @@ class FnDict:
     def __getitem__(self, name):
         return self.fnDict[name]


-def test_lamb_next_mv_with_decay_rule(tag):
+def test_lamb_next_mv_with_decay_rule_cond4(tag):
     fns = FnDict()

     @fns