Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8a777519
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8a777519
编写于
8月 08, 2020
作者:
Y
yujianfeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add AdamApplyOneAssign and AdamApplyOneWithDecayAssign fusion pass
上级
fe2c2e83
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
434 addition
and
28 deletion
+434
-28
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
...c/backend/optimizer/ascend/ascend_backend_optimization.cc
+4
-0
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc
...ckend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc
+142
-27
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h
...ackend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h
+49
-1
mindspore/ccsrc/utils/utils.h
mindspore/ccsrc/utils/utils.h
+1
-0
tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc
...e_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc
+100
-0
tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py
...ut/gtest_input/pre_activate/adam_apply_one_fusion_test.py
+138
-0
未找到文件。
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
浏览文件 @
8a777519
...
...
@@ -124,6 +124,10 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambNextMVRuleCond4
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambNextRightRule
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
LambUpdateWithLrV2
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AdamApplyOneAssignCond1Fusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AdamApplyOneAssignCond2Fusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AdamApplyOneAssignCond3Fusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AdamApplyOneAssignCond4Fusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AdamApplyOneCond1Fusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AdamApplyOneCond2Fusion
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
AdamApplyOneCond3Fusion
>
());
...
...
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc
浏览文件 @
8a777519
...
...
@@ -15,30 +15,9 @@
*/
#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace
mindspore
{
namespace
opt
{
AnfNodePtr
AdamApplyOneFusion
::
CreateAdamApplyOneNode
(
const
FuncGraphPtr
&
func_graph
,
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
equiv
);
auto
prim
=
std
::
make_shared
<
Primitive
>
(
kAdamApplyOneOpName
);
std
::
vector
<
AnfNodePtr
>
new_node_inputs
=
{
NewValueNode
(
prim
)};
for
(
const
auto
&
input_var
:
input_vars_
)
{
auto
input_node
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
input_var
]);
MS_EXCEPTION_IF_NULL
(
input_node
);
new_node_inputs
.
push_back
(
input_node
);
}
for
(
const
auto
&
mul_x_input_var
:
mul_x_input_vars_
)
{
auto
mul_x_input_node
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
mul_x_input_var
]);
MS_EXCEPTION_IF_NULL
(
mul_x_input_node
);
new_node_inputs
.
push_back
(
mul_x_input_node
);
}
auto
add2_y_node
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
add2_y_
]);
MS_EXCEPTION_IF_NULL
(
add2_y_node
);
new_node_inputs
.
push_back
(
add2_y_node
);
auto
new_node
=
func_graph
->
NewCNode
(
new_node_inputs
);
return
new_node
;
}
const
BaseRef
AdamApplyOneFusion
::
DefinePattern
()
const
{
const
auto
prim_sqrt
=
std
::
make_shared
<
Primitive
>
(
kSqrtOpName
);
const
auto
prim_real_div
=
std
::
make_shared
<
Primitive
>
(
kRealDivOpName
);
...
...
@@ -104,16 +83,152 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
return
VectorRef
({
prim
::
kPrimSub
,
input_vars_
[
3
],
VectorRef
({
prim
::
kPrimMul
,
true_div0
,
input_vars_
[
4
]})});
}
const
BaseRef
AdamApplyOneAssignFusion
::
DefinePattern
()
const
{
const
auto
prim_sqrt
=
std
::
make_shared
<
Primitive
>
(
kSqrtOpName
);
const
auto
prim_real_div
=
std
::
make_shared
<
Primitive
>
(
kRealDivOpName
);
VectorRef
mul2
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
2
],
input_vars_
[
1
]});
VectorRef
mul3
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
3
],
VectorRef
({
prim
::
kPrimSquare
,
input_vars_
[
0
]})});
VectorRef
add1
=
VectorRef
({
add1_var_
,
mul2
,
mul3
});
VectorRef
sqrt0
=
VectorRef
({
prim_sqrt
,
add1
});
VectorRef
mul1
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
1
],
input_vars_
[
0
]});
VectorRef
mul0
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
0
],
input_vars_
[
2
]});
VectorRef
add0
=
VectorRef
({
add0_var_
,
mul0
,
mul1
});
VectorRef
true_div0
=
VectorRef
({
prim_real_div
,
add0
,
VectorRef
({
prim
::
kPrimTensorAdd
,
sqrt0
,
add2_y_
})});
VectorRef
sub0
=
VectorRef
({
sub0_var_
,
input_vars_
[
3
],
VectorRef
({
prim
::
kPrimMul
,
input_vars_
[
4
],
true_div0
})});
VectorRef
assign0
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
3
],
sub0
});
VectorRef
depend0
=
VectorRef
({
prim
::
kPrimDepend
,
sub0
,
assign0
});
VectorRef
assign1
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
2
],
add0
});
VectorRef
depend1
=
VectorRef
({
prim
::
kPrimDepend
,
depend0
,
assign1
});
VectorRef
assign2
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
1
],
add1
});
return
VectorRef
({
prim
::
kPrimDepend
,
depend1
,
assign2
});
}
const
BaseRef
AdamApplyOneAssignCond1Fusion
::
DefinePattern
()
const
{
const
auto
prim_sqrt
=
std
::
make_shared
<
Primitive
>
(
kSqrtOpName
);
const
auto
prim_real_div
=
std
::
make_shared
<
Primitive
>
(
kRealDivOpName
);
VectorRef
mul2
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
2
],
input_vars_
[
1
]});
VectorRef
mul3
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
3
],
VectorRef
({
prim
::
kPrimSquare
,
input_vars_
[
0
]})});
VectorRef
add1
=
VectorRef
({
add1_var_
,
mul2
,
mul3
});
VectorRef
sqrt0
=
VectorRef
({
prim_sqrt
,
add1
});
VectorRef
mul1
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
1
],
input_vars_
[
0
]});
VectorRef
mul0
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
0
],
input_vars_
[
2
]});
VectorRef
add0
=
VectorRef
({
add0_var_
,
mul0
,
mul1
});
VectorRef
true_div0
=
VectorRef
({
prim_real_div
,
add0
,
VectorRef
({
prim
::
kPrimTensorAdd
,
add2_y_
,
sqrt0
})});
VectorRef
sub0
=
VectorRef
({
sub0_var_
,
input_vars_
[
3
],
VectorRef
({
prim
::
kPrimMul
,
input_vars_
[
4
],
true_div0
})});
VectorRef
assign0
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
3
],
sub0
});
VectorRef
depend0
=
VectorRef
({
prim
::
kPrimDepend
,
sub0
,
assign0
});
VectorRef
assign1
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
2
],
add0
});
VectorRef
depend1
=
VectorRef
({
prim
::
kPrimDepend
,
depend0
,
assign1
});
VectorRef
assign2
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
1
],
add1
});
return
VectorRef
({
prim
::
kPrimDepend
,
depend1
,
assign2
});
}
const
BaseRef
AdamApplyOneAssignCond2Fusion
::
DefinePattern
()
const
{
const
auto
prim_sqrt
=
std
::
make_shared
<
Primitive
>
(
kSqrtOpName
);
const
auto
prim_real_div
=
std
::
make_shared
<
Primitive
>
(
kRealDivOpName
);
VectorRef
mul2
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
2
],
input_vars_
[
1
]});
VectorRef
mul3
=
VectorRef
({
prim
::
kPrimMul
,
VectorRef
({
prim
::
kPrimSquare
,
input_vars_
[
0
]}),
mul_x_input_vars_
[
3
]});
VectorRef
add1
=
VectorRef
({
add1_var_
,
mul2
,
mul3
});
VectorRef
sqrt0
=
VectorRef
({
prim_sqrt
,
add1
});
VectorRef
mul1
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
1
],
input_vars_
[
0
]});
VectorRef
mul0
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
0
],
input_vars_
[
2
]});
VectorRef
add0
=
VectorRef
({
add0_var_
,
mul0
,
mul1
});
VectorRef
true_div0
=
VectorRef
({
prim_real_div
,
add0
,
VectorRef
({
prim
::
kPrimTensorAdd
,
sqrt0
,
add2_y_
})});
VectorRef
sub0
=
VectorRef
({
sub0_var_
,
input_vars_
[
3
],
VectorRef
({
prim
::
kPrimMul
,
true_div0
,
input_vars_
[
4
]})});
VectorRef
assign0
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
3
],
sub0
});
VectorRef
depend0
=
VectorRef
({
prim
::
kPrimDepend
,
sub0
,
assign0
});
VectorRef
assign1
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
2
],
add0
});
VectorRef
depend1
=
VectorRef
({
prim
::
kPrimDepend
,
depend0
,
assign1
});
VectorRef
assign2
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
1
],
add1
});
return
VectorRef
({
prim
::
kPrimDepend
,
depend1
,
assign2
});
}
const
BaseRef
AdamApplyOneAssignCond3Fusion
::
DefinePattern
()
const
{
const
auto
prim_sqrt
=
std
::
make_shared
<
Primitive
>
(
kSqrtOpName
);
const
auto
prim_real_div
=
std
::
make_shared
<
Primitive
>
(
kRealDivOpName
);
VectorRef
mul2
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
2
],
input_vars_
[
1
]});
VectorRef
mul3
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
3
],
VectorRef
({
prim
::
kPrimSquare
,
input_vars_
[
0
]})});
VectorRef
add1
=
VectorRef
({
add1_var_
,
mul2
,
mul3
});
VectorRef
sqrt0
=
VectorRef
({
prim_sqrt
,
add1
});
VectorRef
mul1
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
1
],
input_vars_
[
0
]});
VectorRef
mul0
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
0
],
input_vars_
[
2
]});
VectorRef
add0
=
VectorRef
({
add0_var_
,
mul0
,
mul1
});
VectorRef
true_div0
=
VectorRef
({
prim_real_div
,
add0
,
VectorRef
({
prim
::
kPrimTensorAdd
,
sqrt0
,
add2_y_
})});
VectorRef
sub0
=
VectorRef
({
sub0_var_
,
input_vars_
[
3
],
VectorRef
({
prim
::
kPrimMul
,
true_div0
,
input_vars_
[
4
]})});
VectorRef
assign0
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
3
],
sub0
});
VectorRef
depend0
=
VectorRef
({
prim
::
kPrimDepend
,
sub0
,
assign0
});
VectorRef
assign1
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
2
],
add0
});
VectorRef
depend1
=
VectorRef
({
prim
::
kPrimDepend
,
depend0
,
assign1
});
VectorRef
assign2
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
1
],
add1
});
return
VectorRef
({
prim
::
kPrimDepend
,
depend1
,
assign2
});
}
const
BaseRef
AdamApplyOneAssignCond4Fusion
::
DefinePattern
()
const
{
const
auto
prim_sqrt
=
std
::
make_shared
<
Primitive
>
(
kSqrtOpName
);
const
auto
prim_real_div
=
std
::
make_shared
<
Primitive
>
(
kRealDivOpName
);
VectorRef
mul2
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
2
],
input_vars_
[
1
]});
VectorRef
mul3
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
3
],
VectorRef
({
prim
::
kPrimSquare
,
input_vars_
[
0
]})});
VectorRef
add1
=
VectorRef
({
add1_var_
,
mul2
,
mul3
});
VectorRef
sqrt0
=
VectorRef
({
prim_sqrt
,
add1
});
VectorRef
mul1
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
1
],
input_vars_
[
0
]});
VectorRef
mul0
=
VectorRef
({
prim
::
kPrimMul
,
mul_x_input_vars_
[
0
],
input_vars_
[
2
]});
VectorRef
add0
=
VectorRef
({
add0_var_
,
mul0
,
mul1
});
VectorRef
true_div0
=
VectorRef
({
prim_real_div
,
add0
,
VectorRef
({
prim
::
kPrimTensorAdd
,
add2_y_
,
sqrt0
})});
VectorRef
sub0
=
VectorRef
({
sub0_var_
,
input_vars_
[
3
],
VectorRef
({
prim
::
kPrimMul
,
true_div0
,
input_vars_
[
4
]})});
VectorRef
assign0
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
3
],
sub0
});
VectorRef
depend0
=
VectorRef
({
prim
::
kPrimDepend
,
sub0
,
assign0
});
VectorRef
assign1
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
2
],
add0
});
VectorRef
depend1
=
VectorRef
({
prim
::
kPrimDepend
,
depend0
,
assign1
});
VectorRef
assign2
=
VectorRef
({
prim
::
kPrimAssign
,
input_vars_
[
1
],
add1
});
return
VectorRef
({
prim
::
kPrimDepend
,
depend1
,
assign2
});
}
AnfNodePtr
AdamApplyOneFusion
::
CreateAdamApplyOneNode
(
const
FuncGraphPtr
&
func_graph
,
const
EquivPtr
&
equiv
,
const
AnfNodePtr
&
final_node
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
equiv
);
PrimitivePtr
prim
=
nullptr
;
if
(
AnfAlgo
::
CheckPrimitiveType
(
final_node
,
prim
::
kPrimDepend
))
{
prim
=
std
::
make_shared
<
Primitive
>
(
kAdamApplyOneAssignOpName
);
}
else
{
prim
=
std
::
make_shared
<
Primitive
>
(
kAdamApplyOneOpName
);
}
std
::
vector
<
AnfNodePtr
>
new_node_inputs
=
{
NewValueNode
(
prim
)};
for
(
const
auto
&
input_var
:
input_vars_
)
{
auto
input_node
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
input_var
]);
MS_EXCEPTION_IF_NULL
(
input_node
);
new_node_inputs
.
push_back
(
input_node
);
}
for
(
const
auto
&
mul_x_input_var
:
mul_x_input_vars_
)
{
auto
mul_x_input_node
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
mul_x_input_var
]);
MS_EXCEPTION_IF_NULL
(
mul_x_input_node
);
new_node_inputs
.
push_back
(
mul_x_input_node
);
}
auto
add2_y_node
=
utils
::
cast
<
AnfNodePtr
>
((
*
equiv
)[
add2_y_
]);
MS_EXCEPTION_IF_NULL
(
add2_y_node
);
new_node_inputs
.
push_back
(
add2_y_node
);
auto
new_node
=
func_graph
->
NewCNode
(
new_node_inputs
);
return
new_node
;
}
const
AnfNodePtr
AdamApplyOneFusion
::
Process
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
node
,
const
EquivPtr
&
equiv
)
const
{
MS_EXCEPTION_IF_NULL
(
func_graph
);
MS_EXCEPTION_IF_NULL
(
node
);
if
(
!
CheckSupportDataType
(
node
,
kFloatDataTypeSet
))
{
auto
sub0
=
node
;
if
(
AnfAlgo
::
CheckPrimitiveType
(
node
,
prim
::
kPrimDepend
))
{
auto
iter_sub0
=
(
*
equiv
).
find
(
sub0_var_
);
if
(
iter_sub0
==
(
*
equiv
).
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"The equiv map is expected to contains the sub0 var after matched."
;
}
sub0
=
utils
::
cast
<
AnfNodePtr
>
(
iter_sub0
->
second
);
}
MS_EXCEPTION_IF_NULL
(
sub0
);
if
(
!
CheckSupportDataType
(
sub0
,
kFloatDataTypeSet
))
{
return
nullptr
;
}
auto
new_node
=
CreateAdamApplyOneNode
(
func_graph
,
equiv
);
auto
new_node
=
CreateAdamApplyOneNode
(
func_graph
,
equiv
,
node
);
MS_EXCEPTION_IF_NULL
(
new_node
);
new_node
->
set_scope
(
node
->
scope
());
new_node
->
set_scope
(
sub0
->
scope
());
// Set abstract of new node
AbstractBasePtrList
new_node_abstract_list
;
auto
iter_add0
=
(
*
equiv
).
find
(
add0_var_
);
...
...
@@ -130,7 +245,7 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
MS_EXCEPTION_IF_NULL
(
add1
);
new_node_abstract_list
.
push_back
(
add1
->
abstract
());
new_node_abstract_list
.
push_back
(
add0
->
abstract
());
new_node_abstract_list
.
push_back
(
node
->
abstract
());
new_node_abstract_list
.
push_back
(
sub0
->
abstract
());
auto
abstract_tuple
=
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
new_node_abstract_list
);
new_node
->
set_abstract
(
abstract_tuple
);
// Create tuple_getitem node for outputs
...
...
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h
浏览文件 @
8a777519
...
...
@@ -40,6 +40,7 @@ class AdamApplyOneFusion : public PatternProcessPass {
add2_y_
=
std
::
make_shared
<
Var
>
();
add0_var_
=
std
::
make_shared
<
Var
>
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimTensorAdd
->
name
()));
add1_var_
=
std
::
make_shared
<
Var
>
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimTensorAdd
->
name
()));
sub0_var_
=
std
::
make_shared
<
Var
>
(
std
::
make_shared
<
Primitive
>
(
prim
::
kPrimSub
->
name
()));
}
~
AdamApplyOneFusion
()
override
=
default
;
...
...
@@ -47,12 +48,14 @@ class AdamApplyOneFusion : public PatternProcessPass {
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
protected:
AnfNodePtr
CreateAdamApplyOneNode
(
const
FuncGraphPtr
&
func_graph
,
const
EquivPtr
&
equiv
)
const
;
AnfNodePtr
CreateAdamApplyOneNode
(
const
FuncGraphPtr
&
func_graph
,
const
EquivPtr
&
equiv
,
const
AnfNodePtr
&
final_node
)
const
;
std
::
vector
<
VarPtr
>
input_vars_
;
std
::
vector
<
VarPtr
>
mul_x_input_vars_
;
VarPtr
add2_y_
;
VarPtr
add0_var_
;
VarPtr
add1_var_
;
VarPtr
sub0_var_
;
};
class
AdamApplyOneCond1Fusion
:
public
AdamApplyOneFusion
{
...
...
@@ -90,6 +93,51 @@ class AdamApplyOneCond4Fusion : public AdamApplyOneFusion {
~
AdamApplyOneCond4Fusion
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
};
class
AdamApplyOneAssignFusion
:
public
AdamApplyOneFusion
{
public:
explicit
AdamApplyOneAssignFusion
(
bool
multigraph
=
true
)
:
AdamApplyOneFusion
(
"adam_apply_one_assign_fusion"
,
multigraph
)
{}
~
AdamApplyOneAssignFusion
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
};
class
AdamApplyOneAssignCond1Fusion
:
public
AdamApplyOneFusion
{
public:
explicit
AdamApplyOneAssignCond1Fusion
(
bool
multigraph
=
true
)
:
AdamApplyOneFusion
(
"adam_apply_one_assign_cond1_fusion"
,
multigraph
)
{}
~
AdamApplyOneAssignCond1Fusion
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
};
class
AdamApplyOneAssignCond2Fusion
:
public
AdamApplyOneFusion
{
public:
explicit
AdamApplyOneAssignCond2Fusion
(
bool
multigraph
=
true
)
:
AdamApplyOneFusion
(
"adam_apply_one_assign_cond2_fusion"
,
multigraph
)
{}
~
AdamApplyOneAssignCond2Fusion
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
};
class
AdamApplyOneAssignCond3Fusion
:
public
AdamApplyOneFusion
{
public:
explicit
AdamApplyOneAssignCond3Fusion
(
bool
multigraph
=
true
)
:
AdamApplyOneFusion
(
"adam_apply_one_assign_cond3_fusion"
,
multigraph
)
{}
~
AdamApplyOneAssignCond3Fusion
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
};
class
AdamApplyOneAssignCond4Fusion
:
public
AdamApplyOneFusion
{
public:
explicit
AdamApplyOneAssignCond4Fusion
(
bool
multigraph
=
true
)
:
AdamApplyOneFusion
(
"adam_apply_one_assign_cond4_fusion"
,
multigraph
)
{}
~
AdamApplyOneAssignCond4Fusion
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_
mindspore/ccsrc/utils/utils.h
浏览文件 @
8a777519
...
...
@@ -119,6 +119,7 @@ constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay";
constexpr
auto
kBatchNormGradOpName
=
"BatchNormGrad"
;
constexpr
auto
kBNInferOpName
=
"BNInfer"
;
constexpr
auto
kAdamApplyOneOpName
=
"AdamApplyOne"
;
constexpr
auto
kAdamApplyOneAssignOpName
=
"AdamApplyOneAssign"
;
constexpr
auto
kResizeNearestNeighborGradOpName
=
"ResizeNearestNeighborGrad"
;
constexpr
auto
kFusedMulAddOpName
=
"FusedMulAdd"
;
constexpr
auto
kFusedMulAddNOpName
=
"FusedMulAddN"
;
...
...
tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc
浏览文件 @
8a777519
...
...
@@ -217,5 +217,105 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) {
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_fusion"
,
"after"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
TEST_F
(
TestHWAdamApplyOneFusion
,
test_adam_apply_one_assign_fusion
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"before"
);
std
::
vector
<
int
>
shp
{
2
,
32
,
224
,
224
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
10
;
++
i
)
{
args_spec_list
.
push_back
(
x_abstract
);
}
auto
fg
=
GetKernelGraph
(
g
,
args_spec_list
);
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamApplyOneAssignFusion
>
());
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
fg
);
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"after"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
TEST_F
(
TestHWAdamApplyOneFusion
,
test_adam_apply_one_assign_cond1_fusion
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"before_cond1"
);
std
::
vector
<
int
>
shp
{
2
,
32
,
224
,
224
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
10
;
++
i
)
{
args_spec_list
.
push_back
(
x_abstract
);
}
auto
fg
=
GetKernelGraph
(
g
,
args_spec_list
);
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamApplyOneAssignCond1Fusion
>
());
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
fg
);
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"after"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
TEST_F
(
TestHWAdamApplyOneFusion
,
test_adam_apply_one_assign_cond2_fusion
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"before_cond2"
);
std
::
vector
<
int
>
shp
{
2
,
32
,
224
,
224
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
10
;
++
i
)
{
args_spec_list
.
push_back
(
x_abstract
);
}
auto
fg
=
GetKernelGraph
(
g
,
args_spec_list
);
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamApplyOneAssignCond2Fusion
>
());
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
fg
);
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"after"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
TEST_F
(
TestHWAdamApplyOneFusion
,
test_adam_apply_one_assign_cond3_fusion
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"before_cond3"
);
std
::
vector
<
int
>
shp
{
2
,
32
,
224
,
224
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
10
;
++
i
)
{
args_spec_list
.
push_back
(
x_abstract
);
}
auto
fg
=
GetKernelGraph
(
g
,
args_spec_list
);
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamApplyOneAssignCond3Fusion
>
());
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
fg
);
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"after"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
TEST_F
(
TestHWAdamApplyOneFusion
,
test_adam_apply_one_assign_cond4_fusion
)
{
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"before_cond4"
);
std
::
vector
<
int
>
shp
{
2
,
32
,
224
,
224
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
10
;
++
i
)
{
args_spec_list
.
push_back
(
x_abstract
);
}
auto
fg
=
GetKernelGraph
(
g
,
args_spec_list
);
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamApplyOneAssignCond4Fusion
>
());
optimizer
->
AddPassManager
(
pm
);
FuncGraphPtr
new_graph
=
optimizer
->
Optimize
(
fg
);
FuncGraphPtr
g_after
=
get_py_fun_
.
CallAndParseRet
(
"test_adam_apply_one_assign_fusion"
,
"after"
);
EXPECT_TRUE
(
CheckEqualGraph
(
g_after
,
new_graph
));
}
}
// namespace opt
}
// namespace mindspore
tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py
浏览文件 @
8a777519
...
...
@@ -14,6 +14,7 @@
# ============================================================================
from
mindspore.ops
import
Primitive
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
Add
=
P
.
TensorAdd
()
Sub
=
P
.
Sub
()
...
...
@@ -21,9 +22,11 @@ Mul = P.Mul()
RealDiv
=
P
.
RealDiv
()
Sqrt
=
P
.
Sqrt
()
Square
=
P
.
Square
()
Assign
=
P
.
Assign
()
make_tuple
=
Primitive
(
'make_tuple'
)
tuple_getitem
=
Primitive
(
'tuple_getitem'
)
AdamApplyOne
=
Primitive
(
'AdamApplyOne'
)
AdamApplyOneAssign
=
Primitive
(
'AdamApplyOneAssign'
)
class
FnDict
:
...
...
@@ -139,3 +142,138 @@ def test_adam_apply_one_fusion(tag):
return
make_tuple
(
output
)
return
fns
[
tag
]
def
test_adam_apply_one_assign_fusion
(
tag
):
fns
=
FnDict
()
@
fns
def
before
(
input0
,
input1
,
input2
,
input3
,
input4
,
mul0_x
,
mul1_x
,
mul2_x
,
mul3_x
,
add2_y
):
square0
=
Square
(
input0
)
mul1
=
Mul
(
mul1_x
,
input0
)
mul0
=
Mul
(
mul0_x
,
input2
)
mul2
=
Mul
(
mul2_x
,
input1
)
mul3
=
Mul
(
mul3_x
,
square0
)
add0
=
Add
(
mul0
,
mul1
)
add1
=
Add
(
mul2
,
mul3
)
sqrt0
=
Sqrt
(
add1
)
add2
=
Add
(
sqrt0
,
add2_y
)
true_div0
=
RealDiv
(
add0
,
add2
)
mul4
=
Mul
(
input4
,
true_div0
)
sub0
=
Sub
(
input3
,
mul4
)
assign0
=
Assign
(
input3
,
sub0
)
depend0
=
F
.
depend
(
sub0
,
assign0
)
assign1
=
Assign
(
input2
,
add0
)
depend1
=
F
.
depend
(
depend0
,
assign1
)
assign2
=
Assign
(
input1
,
add1
)
depend2
=
F
.
depend
(
depend1
,
assign2
)
outputs
=
make_tuple
(
add1
,
add0
,
depend2
)
output
=
tuple_getitem
(
outputs
,
0
)
return
output
@
fns
def
before_cond1
(
input0
,
input1
,
input2
,
input3
,
input4
,
mul0_x
,
mul1_x
,
mul2_x
,
mul3_x
,
add2_y
):
square0
=
Square
(
input0
)
mul1
=
Mul
(
mul1_x
,
input0
)
mul0
=
Mul
(
mul0_x
,
input2
)
mul2
=
Mul
(
mul2_x
,
input1
)
mul3
=
Mul
(
mul3_x
,
square0
)
add0
=
Add
(
mul0
,
mul1
)
add1
=
Add
(
mul2
,
mul3
)
sqrt0
=
Sqrt
(
add1
)
add2
=
Add
(
add2_y
,
sqrt0
)
true_div0
=
RealDiv
(
add0
,
add2
)
mul4
=
Mul
(
input4
,
true_div0
)
sub0
=
Sub
(
input3
,
mul4
)
assign0
=
Assign
(
input3
,
sub0
)
depend0
=
F
.
depend
(
sub0
,
assign0
)
assign1
=
Assign
(
input2
,
add0
)
depend1
=
F
.
depend
(
depend0
,
assign1
)
assign2
=
Assign
(
input1
,
add1
)
depend2
=
F
.
depend
(
depend1
,
assign2
)
outputs
=
make_tuple
(
add1
,
add0
,
depend2
)
output
=
tuple_getitem
(
outputs
,
0
)
return
output
@
fns
def
before_cond2
(
input0
,
input1
,
input2
,
input3
,
input4
,
mul0_x
,
mul1_x
,
mul2_x
,
mul3_x
,
add2_y
):
square0
=
Square
(
input0
)
mul1
=
Mul
(
mul1_x
,
input0
)
mul0
=
Mul
(
mul0_x
,
input2
)
mul2
=
Mul
(
mul2_x
,
input1
)
mul3
=
Mul
(
square0
,
mul3_x
)
add0
=
Add
(
mul0
,
mul1
)
add1
=
Add
(
mul2
,
mul3
)
sqrt0
=
Sqrt
(
add1
)
add2
=
Add
(
sqrt0
,
add2_y
)
true_div0
=
RealDiv
(
add0
,
add2
)
mul4
=
Mul
(
true_div0
,
input4
)
sub0
=
Sub
(
input3
,
mul4
)
assign0
=
Assign
(
input3
,
sub0
)
depend0
=
F
.
depend
(
sub0
,
assign0
)
assign1
=
Assign
(
input2
,
add0
)
depend1
=
F
.
depend
(
depend0
,
assign1
)
assign2
=
Assign
(
input1
,
add1
)
depend2
=
F
.
depend
(
depend1
,
assign2
)
outputs
=
make_tuple
(
add1
,
add0
,
depend2
)
output
=
tuple_getitem
(
outputs
,
0
)
return
output
@
fns
def
before_cond3
(
input0
,
input1
,
input2
,
input3
,
input4
,
mul0_x
,
mul1_x
,
mul2_x
,
mul3_x
,
add2_y
):
square0
=
Square
(
input0
)
mul1
=
Mul
(
mul1_x
,
input0
)
mul0
=
Mul
(
mul0_x
,
input2
)
mul2
=
Mul
(
mul2_x
,
input1
)
mul3
=
Mul
(
mul3_x
,
square0
)
add0
=
Add
(
mul0
,
mul1
)
add1
=
Add
(
mul2
,
mul3
)
sqrt0
=
Sqrt
(
add1
)
add2
=
Add
(
sqrt0
,
add2_y
)
true_div0
=
RealDiv
(
add0
,
add2
)
mul4
=
Mul
(
true_div0
,
input4
)
sub0
=
Sub
(
input3
,
mul4
)
assign0
=
Assign
(
input3
,
sub0
)
depend0
=
F
.
depend
(
sub0
,
assign0
)
assign1
=
Assign
(
input2
,
add0
)
depend1
=
F
.
depend
(
depend0
,
assign1
)
assign2
=
Assign
(
input1
,
add1
)
depend2
=
F
.
depend
(
depend1
,
assign2
)
outputs
=
make_tuple
(
add1
,
add0
,
depend2
)
output
=
tuple_getitem
(
outputs
,
0
)
return
output
@
fns
def
before_cond4
(
input0
,
input1
,
input2
,
input3
,
input4
,
mul0_x
,
mul1_x
,
mul2_x
,
mul3_x
,
add2_y
):
square0
=
Square
(
input0
)
mul1
=
Mul
(
mul1_x
,
input0
)
mul0
=
Mul
(
mul0_x
,
input2
)
mul2
=
Mul
(
mul2_x
,
input1
)
mul3
=
Mul
(
mul3_x
,
square0
)
add0
=
Add
(
mul0
,
mul1
)
add1
=
Add
(
mul2
,
mul3
)
sqrt0
=
Sqrt
(
add1
)
add2
=
Add
(
add2_y
,
sqrt0
)
true_div0
=
RealDiv
(
add0
,
add2
)
mul4
=
Mul
(
true_div0
,
input4
)
sub0
=
Sub
(
input3
,
mul4
)
assign0
=
Assign
(
input3
,
sub0
)
depend0
=
F
.
depend
(
sub0
,
assign0
)
assign1
=
Assign
(
input2
,
add0
)
depend1
=
F
.
depend
(
depend0
,
assign1
)
assign2
=
Assign
(
input1
,
add1
)
depend2
=
F
.
depend
(
depend1
,
assign2
)
outputs
=
make_tuple
(
add1
,
add0
,
depend2
)
output
=
tuple_getitem
(
outputs
,
0
)
return
output
@
fns
def
after
(
input0
,
input1
,
input2
,
input3
,
input4
,
mul0_x
,
mul1_x
,
mul2_x
,
mul3_x
,
add2_y
):
adam_apply_one_assign
=
AdamApplyOneAssign
(
input0
,
input1
,
input2
,
input3
,
input4
,
mul0_x
,
mul1_x
,
mul2_x
,
mul3_x
,
add2_y
)
outputs
=
make_tuple
(
tuple_getitem
(
adam_apply_one_assign
,
0
),
tuple_getitem
(
adam_apply_one_assign
,
1
),
tuple_getitem
(
adam_apply_one_assign
,
2
))
output
=
tuple_getitem
(
outputs
,
0
)
return
make_tuple
(
output
)
return
fns
[
tag
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录