Commit 5eb53798 (magicwindyyd/mindspore, forked from MindSpore/mindspore)
Authored on April 17, 2020 by YuJianfeng

Add AdamApplyOne fusion pass

Parent: 80333e9f
Showing 4 changed files with 321 additions and 8 deletions
Changed files:
  mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc            +54   -2
  mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h             +44   -6
  tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc          +151  -0
  tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py  +72   -0
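For orientation: the subgraph this pass fuses into a single AdamApplyOne node is the elementwise Adam-style update spelled out by the test graphs below. A minimal NumPy reference of that computation, as an illustrative sketch only (not the fused Ascend kernel):

  import numpy as np

  def adam_apply_one_reference(input0, input1, input2, input3, input4,
                               mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
      # Mirrors the "before" graphs: add1/add0 are the moment updates,
      # sub0 is the parameter update; the fused op keeps all three outputs.
      add1 = mul2_x * input1 + mul3_x * np.square(input0)
      add0 = mul0_x * input2 + mul1_x * input0
      sub0 = input3 - input4 * (add0 / (np.sqrt(add1) + add2_y))
      return add1, add0, sub0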
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
@@ -42,17 +42,69 @@ AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_g
const BaseRef AdamApplyOneFusion::DefinePattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
-  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
-  VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
}
const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
}
const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]});
  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
}
const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
}
const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
}
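All four Cond variants match the same arithmetic as the base pattern; they differ only in the operand order of three commutative nodes, which the pattern engine matches literally: whether TensorAdd takes (sqrt0, add2_y_) or (add2_y_, sqrt0), whether Square appears as the left or right Mul operand in mul3, and whether input_vars_[4] multiplies true_div0 from the left or the right. A quick NumPy check, illustration only, that the four operand orderings are numerically identical:

  import numpy as np

  rng = np.random.default_rng(0)
  (input0, input1, input2, input3, input4,
   mul0_x, mul1_x, mul2_x, mul3_x, add2_y) = rng.random((10, 4))

  def update(swap_add2, swap_mul3, swap_mul4):
      # Same dataflow as DefinePattern above, with the three commutative
      # operand orders made explicit.
      square0 = np.square(input0)
      mul3 = square0 * mul3_x if swap_mul3 else mul3_x * square0
      add1 = mul2_x * input1 + mul3
      add2 = add2_y + np.sqrt(add1) if swap_add2 else np.sqrt(add1) + add2_y
      true_div0 = (mul0_x * input2 + mul1_x * input0) / add2
      mul4 = true_div0 * input4 if swap_mul4 else input4 * true_div0
      return input3 - mul4

  base = update(False, False, False)      # AdamApplyOneFusion
  for flags in [(True, False, False),     # Cond1
                (False, True, True),      # Cond2
                (False, False, True),     # Cond3
                (True, False, True)]:     # Cond4
      assert np.allclose(base, update(*flags))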
const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                             const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  ...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h
@@ -18,21 +18,23 @@
 #include <vector>
 #include <memory>
+#include <string>
 #include "pre_activate/common/optimizer.h"
 #include "utils/utils.h"

 namespace mindspore {
 namespace opt {
-constexpr size_t kAdamApplyOneInputNum = 5;
-constexpr size_t kAdamApplyOneMulInputNum = 4;
+constexpr size_t kAdamApplyOneInputVarNum = 5;
+constexpr size_t kAdamApplyOneMulInputVarNum = 4;

 class AdamApplyOneFusion : public PatternProcessPass {
  public:
-  explicit AdamApplyOneFusion(bool multigraph = true) : PatternProcessPass("adam_apply_one_fusion", multigraph) {
-    for (size_t i = 0; i < kAdamApplyOneInputNum; ++i) {
+  explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true)
+      : PatternProcessPass(name, multigraph) {
+    for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) {
       input_vars_.push_back(std::make_shared<Var>());
     }
-    for (size_t i = 0; i < kAdamApplyOneMulInputNum; ++i) {
+    for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) {
       mul_x_input_vars_.push_back(std::make_shared<Var>());
     }
     add2_y_ = std::make_shared<Var>();
...
@@ -44,7 +46,7 @@ class AdamApplyOneFusion : public PatternProcessPass {
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

- private:
+ protected:
   AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
   std::vector<VarPtr> input_vars_;
   std::vector<VarPtr> mul_x_input_vars_;
...
@@ -52,6 +54,42 @@ class AdamApplyOneFusion : public PatternProcessPass {
  VarPtr add0_var_;
  VarPtr add1_var_;
};

class AdamApplyOneCond1Fusion : public AdamApplyOneFusion {
 public:
  explicit AdamApplyOneCond1Fusion(bool multigraph = true)
      : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {}
  ~AdamApplyOneCond1Fusion() override = default;
  const BaseRef DefinePattern() const override;
};

class AdamApplyOneCond2Fusion : public AdamApplyOneFusion {
 public:
  explicit AdamApplyOneCond2Fusion(bool multigraph = true)
      : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {}
  ~AdamApplyOneCond2Fusion() override = default;
  const BaseRef DefinePattern() const override;
};

class AdamApplyOneCond3Fusion : public AdamApplyOneFusion {
 public:
  explicit AdamApplyOneCond3Fusion(bool multigraph = true)
      : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {}
  ~AdamApplyOneCond3Fusion() override = default;
  const BaseRef DefinePattern() const override;
};

class AdamApplyOneCond4Fusion : public AdamApplyOneFusion {
 public:
  explicit AdamApplyOneCond4Fusion(bool multigraph = true)
      : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {}
  ~AdamApplyOneCond4Fusion() override = default;
  const BaseRef DefinePattern() const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_
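The one-line Cond subclasses above are what the header refactor buys: the base constructor now takes the pass name, and CreateAdamApplyOneNode plus the pattern variables move from private to protected, so each variant overrides only DefinePattern and inherits the shared Process logic. A minimal Python sketch of that template-method shape (hypothetical names, not MindSpore code):

  class FusionPass:
      """Base pass: shared replacement step, variant-specific pattern."""
      def __init__(self, name="adam_apply_one_fusion"):
          self.name = name  # mirrors PatternProcessPass(name, multigraph)

      def define_pattern(self):
          raise NotImplementedError  # the only hook a variant overrides

      def process(self, expr):
          # Stand-in for the shared Process/CreateAdamApplyOneNode logic.
          return ("AdamApplyOne",) if expr == self.define_pattern() else expr

  class Cond1Fusion(FusionPass):
      def __init__(self):
          super().__init__("adam_apply_one_cond1_fusion")

      def define_pattern(self):
          # Cond1's operand order: Add(add2_y, sqrt0) rather than Add(sqrt0, add2_y).
          return ("Sub", "input3", ("Mul", "input4",
                  ("RealDiv", "add0", ("Add", "add2_y", "sqrt0"))))

  pass_ = Cond1Fusion()
  print(pass_.process(pass_.define_pattern()))  # ('AdamApplyOne',)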
tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc
@@ -66,5 +66,156 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_fusion) {
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond1_fusion) {
  /*
   * def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
   *   square0 = Square(input0)
   *   mul1 = Mul(mul1_x, input0)
   *   mul0 = Mul(mul0_x, input2)
   *   mul2 = Mul(mul2_x, input1)
   *   mul3 = Mul(mul3_x, square0)
   *   add0 = Add(mul0, mul1)
   *   add1 = Add(mul2, mul3)
   *   sqrt0 = Sqrt(add1)
   *   add2 = Add(add2_y, sqrt0)
   *   true_div0 = RealDiv(add0, add2)
   *   mul4 = Mul(input4, true_div0)
   *   sub0 = Sub(input3, mul4)
   *   outputs = make_tuple(add1, add0, sub0)
   *   output = tuple_getitem(outputs, 0)
   *   return output
   */
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond1");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 10; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneCond1Fusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond2_fusion) {
  /*
   * def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
   *   square0 = Square(input0)
   *   mul1 = Mul(mul1_x, input0)
   *   mul0 = Mul(mul0_x, input2)
   *   mul2 = Mul(mul2_x, input1)
   *   mul3 = Mul(square0, mul3_x)
   *   add0 = Add(mul0, mul1)
   *   add1 = Add(mul2, mul3)
   *   sqrt0 = Sqrt(add1)
   *   add2 = Add(sqrt0, add2_y)
   *   true_div0 = RealDiv(add0, add2)
   *   mul4 = Mul(true_div0, input4)
   *   sub0 = Sub(input3, mul4)
   *   outputs = make_tuple(add1, add0, sub0)
   *   output = tuple_getitem(outputs, 0)
   *   return output
   */
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond2");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 10; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneCond2Fusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond3_fusion) {
  /*
   * def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
   *   square0 = Square(input0)
   *   mul1 = Mul(mul1_x, input0)
   *   mul0 = Mul(mul0_x, input2)
   *   mul2 = Mul(mul2_x, input1)
   *   mul3 = Mul(mul3_x, square0)
   *   add0 = Add(mul0, mul1)
   *   add1 = Add(mul2, mul3)
   *   sqrt0 = Sqrt(add1)
   *   add2 = Add(sqrt0, add2_y)
   *   true_div0 = RealDiv(add0, add2)
   *   mul4 = Mul(true_div0, input4)
   *   sub0 = Sub(input3, mul4)
   *   outputs = make_tuple(add1, add0, sub0)
   *   output = tuple_getitem(outputs, 0)
   *   return output
   */
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond3");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 10; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneCond3Fusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) {
  /*
   * def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
   *   square0 = Square(input0)
   *   mul1 = Mul(mul1_x, input0)
   *   mul0 = Mul(mul0_x, input2)
   *   mul2 = Mul(mul2_x, input1)
   *   mul3 = Mul(mul3_x, square0)
   *   add0 = Add(mul0, mul1)
   *   add1 = Add(mul2, mul3)
   *   sqrt0 = Sqrt(add1)
   *   add2 = Add(add2_y, sqrt0)
   *   true_div0 = RealDiv(add0, add2)
   *   mul4 = Mul(true_div0, input4)
   *   sub0 = Sub(input3, mul4)
   *   outputs = make_tuple(add1, add0, sub0)
   *   output = tuple_getitem(outputs, 0)
   *   return output
   */
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond4");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 10; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto fg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::AdamApplyOneCond4Fusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
}  // namespace opt
}  // namespace mindspore
tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_fusion_test.py
@@ -58,6 +58,78 @@ def test_adam_apply_one_fusion(tag):
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
        square0 = Square(input0)
        mul1 = Mul(mul1_x, input0)
        mul0 = Mul(mul0_x, input2)
        mul2 = Mul(mul2_x, input1)
        mul3 = Mul(mul3_x, square0)
        add0 = Add(mul0, mul1)
        add1 = Add(mul2, mul3)
        sqrt0 = Sqrt(add1)
        add2 = Add(add2_y, sqrt0)
        true_div0 = RealDiv(add0, add2)
        mul4 = Mul(input4, true_div0)
        sub0 = Sub(input3, mul4)
        outputs = make_tuple(add1, add0, sub0)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
        square0 = Square(input0)
        mul1 = Mul(mul1_x, input0)
        mul0 = Mul(mul0_x, input2)
        mul2 = Mul(mul2_x, input1)
        mul3 = Mul(square0, mul3_x)
        add0 = Add(mul0, mul1)
        add1 = Add(mul2, mul3)
        sqrt0 = Sqrt(add1)
        add2 = Add(sqrt0, add2_y)
        true_div0 = RealDiv(add0, add2)
        mul4 = Mul(true_div0, input4)
        sub0 = Sub(input3, mul4)
        outputs = make_tuple(add1, add0, sub0)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
        square0 = Square(input0)
        mul1 = Mul(mul1_x, input0)
        mul0 = Mul(mul0_x, input2)
        mul2 = Mul(mul2_x, input1)
        mul3 = Mul(mul3_x, square0)
        add0 = Add(mul0, mul1)
        add1 = Add(mul2, mul3)
        sqrt0 = Sqrt(add1)
        add2 = Add(sqrt0, add2_y)
        true_div0 = RealDiv(add0, add2)
        mul4 = Mul(true_div0, input4)
        sub0 = Sub(input3, mul4)
        outputs = make_tuple(add1, add0, sub0)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
        square0 = Square(input0)
        mul1 = Mul(mul1_x, input0)
        mul0 = Mul(mul0_x, input2)
        mul2 = Mul(mul2_x, input1)
        mul3 = Mul(mul3_x, square0)
        add0 = Add(mul0, mul1)
        add1 = Add(mul2, mul3)
        sqrt0 = Sqrt(add1)
        add2 = Add(add2_y, sqrt0)
        true_div0 = RealDiv(add0, add2)
        mul4 = Mul(true_div0, input4)
        sub0 = Sub(input3, mul4)
        outputs = make_tuple(add1, add0, sub0)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
        adam_apply_one = AdamApplyOne(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y)
        ...
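The before_cond functions above are written against the test DSL's op wrappers. To run one outside the harness, NumPy ufuncs can stand in for the ops; these bindings are an assumption for illustration, not the test framework's definitions:

  import numpy as np

  # Hypothetical NumPy stand-ins for the DSL names used in the graphs above.
  Square, Sqrt = np.square, np.sqrt
  Mul, Add, Sub, RealDiv = np.multiply, np.add, np.subtract, np.divide

  def make_tuple(*xs):
      return xs

  def tuple_getitem(xs, i):
      return xs[i]

  # Any before_cond body now executes as plain Python, e.g. before_cond1:
  def before_cond1(input0, input1, input2, input3, input4,
                   mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
      square0 = Square(input0)
      add0 = Add(Mul(mul0_x, input2), Mul(mul1_x, input0))
      add1 = Add(Mul(mul2_x, input1), Mul(mul3_x, square0))
      add2 = Add(add2_y, Sqrt(add1))
      sub0 = Sub(input3, Mul(input4, RealDiv(add0, add2)))
      return tuple_getitem(make_tuple(add1, add0, sub0), 0)

  args = np.random.default_rng(0).random((10, 2, 3))
  print(before_cond1(*args).shape)  # (2, 3): returns the add1 output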