Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
dbc4fcd7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dbc4fcd7
编写于
11月 08, 2018
作者:
T
Tomasz Patejko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN residual connections fuse pass: unit tests enabled and added
上级
42240893
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
67 addition
and
70 deletion
+67
-70
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
...mework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+67
-70
未找到文件。
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
浏览文件 @
dbc4fcd7
...
@@ -40,7 +40,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
...
@@ -40,7 +40,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op
->
SetOutput
(
output
.
first
,
{
output
.
second
});
op
->
SetOutput
(
output
.
first
,
{
output
.
second
});
}
}
struct
IsReachable
{
struct
Test
IsReachable
{
using
func
=
std
::
function
<
bool
(
const
std
::
string
&
,
const
std
::
string
&
)
>
;
using
func
=
std
::
function
<
bool
(
const
std
::
string
&
,
const
std
::
string
&
)
>
;
auto
operator
()(
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
)
->
func
{
auto
operator
()(
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
)
->
func
{
...
@@ -89,7 +89,9 @@ struct IsReachable {
...
@@ -89,7 +89,9 @@ struct IsReachable {
}
}
};
};
void
AssertOpsCount
(
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
)
{
void
AssertOpsCount
(
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
,
int
expected_conv_count
,
int
expected_elementwise_add_count
=
0
)
{
int
conv_count
=
0
;
int
conv_count
=
0
;
int
elementwise_add_count
=
0
;
int
elementwise_add_count
=
0
;
...
@@ -101,8 +103,8 @@ void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
...
@@ -101,8 +103,8 @@ void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
++
elementwise_add_count
;
++
elementwise_add_count
;
}
}
}
}
EXPECT_EQ
(
conv_count
,
1
);
EXPECT_EQ
(
conv_count
,
expected_conv_count
);
EXPECT_EQ
(
elementwise_add_count
,
0
);
EXPECT_EQ
(
elementwise_add_count
,
expected_elementwise_add_count
);
}
}
ProgramDesc
BuildProgramDesc
(
const
std
::
vector
<
std
::
string
>&
transient_vars
,
ProgramDesc
BuildProgramDesc
(
const
std
::
vector
<
std
::
string
>&
transient_vars
,
...
@@ -127,22 +129,13 @@ ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
...
@@ -127,22 +129,13 @@ ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
return
prog
;
return
prog
;
}
}
}
// namespace
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionWithElementwiseAddRelu
)
{
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"f"
},
{
"bias"
,
"weights"
});
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"a"
},
{
"Bias"
,
"bias"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"b"
},
{
"Y"
,
"c"
}},
{
"Out"
,
"d"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"d"
}},
{
"Out"
,
"e"
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
void
RunPassAndAssert
(
ProgramDesc
*
prog
,
const
std
::
string
&
from
,
const
std
::
string
&
to
,
int
expected_conv_num
)
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
*
prog
));
IsReachable
is_reachable
;
Test
IsReachable
is_reachable
;
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
EXPECT_TRUE
(
is_reachable
(
graph
)(
from
,
to
));
auto
pass
=
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
...
@@ -150,82 +143,87 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
...
@@ -150,82 +143,87 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
EXPECT_TRUE
(
is_reachable
(
graph
)(
from
,
to
));
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
current_nodes_num
);
AssertOpsCount
(
graph
);
AssertOpsCount
(
graph
,
expected_conv_num
);
}
}
}
// namespace
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionAsYWithElementwiseAddRelu
)
{
ConvolutionWithElementwiseAddReluNoBias
)
{
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
},
{
"bias"
,
"weights"
});
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
},
{
"weights"
});
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"a"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"b"
},
{
"Y"
,
"c"
}},
{
"Out"
,
"d"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"d"
}},
{
"Out"
,
"e"
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
IsReachable
is_reachable
;
SetOp
(
&
prog
,
"sigmoid"
,
{{
"X"
,
"a"
}},
{
"Out"
,
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"b"
},
{
"Bias"
,
"bias"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"c"
});
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"a"
},
{
"Y"
,
"c"
}},
{
"Out"
,
"d"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"d"
}},
{
"Out"
,
"e"
});
auto
pass
=
RunPassAndAssert
(
&
prog
,
"a"
,
"relu"
,
1
);
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
}
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionAsYWithElementwiseAddReluNoBias
)
{
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
},
{
"weights"
});
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
SetOp
(
&
prog
,
"sigmoid"
,
{{
"X"
,
"a"
}},
{
"Out"
,
"b"
});
current_nodes_num
);
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"b"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"c"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"a"
},
{
"Y"
,
"c"
}},
{
"Out"
,
"d"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"d"
}},
{
"Out"
,
"e"
});
AssertOpsCount
(
graph
);
RunPassAndAssert
(
&
prog
,
"a"
,
"relu"
,
1
);
}
}
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionElementwiseAdd
)
{
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionAsXWithElementwiseAddRelu
)
{
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
},
{
"bias"
,
"weights"
});
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
},
{
"bias"
,
"weights"
});
SetOp
(
&
prog
,
"sigmoid"
,
{{
"X"
,
"a"
}},
{
"Out"
,
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"a"
},
{
"Bias"
,
"bias"
},
{
"Filter"
,
"weights"
}},
{{
"Input"
,
"b"
},
{
"Bias"
,
"bias"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"b"
});
{
"Output"
,
"c"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"b"
},
{
"Y"
,
"c"
}},
{
"Out"
,
"d"
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"c"
},
{
"Y"
,
"a"
}},
{
"Out"
,
"d"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"d"
}},
{
"Out"
,
"e"
});
IsReachable
is_reachable
;
RunPassAndAssert
(
&
prog
,
"a"
,
"relu"
,
1
)
;
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"d"
));
}
auto
pass
=
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
ConvolutionAsXWithElementwiseAddReluNoBias
)
{
int
original_nodes_num
=
graph
->
Nodes
().
size
();
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
},
{
"weights"
});
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_FALSE
(
is_reachable
(
graph
)(
"a"
,
"d"
));
SetOp
(
&
prog
,
"sigmoid"
,
{{
"X"
,
"a"
}},
{
"Out"
,
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"b"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"c"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"c"
},
{
"Y"
,
"a"
}},
{
"Out"
,
"d"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"d"
}},
{
"Out"
,
"e"
});
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
RunPassAndAssert
(
&
prog
,
"a"
,
"relu"
,
1
);
current_nodes_num
);
AssertOpsCount
(
graph
);
}
}
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
SigmoidConvolutionAddElementwiseRelu
)
{
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
NoFusion
)
{
auto
prog
=
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"f"
},
{
"bias"
,
"weights"
});
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"f"
,
"g"
},
{
"weights"
});
SetOp
(
&
prog
,
"sigmoid"
,
{{
"X"
,
"a"
}},
{
"Out"
,
"b"
});
SetOp
(
&
prog
,
"sigmoid"
,
{{
"X"
,
"a"
}},
{
"Out"
,
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"b"
},
{
"Filter"
,
"weights"
}},
{{
"Input"
,
"b"
},
{
"Bias"
,
"bias"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"c"
});
{
"Output"
,
"c"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"c"
},
{
"Y"
,
"d"
}},
{
"Out"
,
"e"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"e"
}},
{
"Out"
,
"f"
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"d"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"e"
});
IsReachable
is_reachable
;
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"c"
},
{
"Y"
,
"e"
}},
{
"Out"
,
"f"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"f"
}},
{
"Out"
,
"g"
});
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"f"
));
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
TestIsReachable
is_reachable
;
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"g"
));
auto
pass
=
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
...
@@ -233,11 +231,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
...
@@ -233,11 +231,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"f"
));
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"g"
));
EXPECT_EQ
(
original_nodes_num
,
current_nodes_num
);
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
AssertOpsCount
(
graph
,
2
,
1
);
current_nodes_num
);
AssertOpsCount
(
graph
);
}
}
}
// namespace ir
}
// namespace ir
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录