Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
347bf904
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
347bf904
编写于
9月 20, 2018
作者:
T
Tomasz Patejko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN conv + elementwise_add fusion: bias is also handled
上级
bf95ac36
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
12 deletion
+22
-12
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
...uid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
+12
-3
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
...mework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+10
-9
未找到文件。
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
347bf904
...
...
@@ -49,6 +49,7 @@ struct Pattern : public PatternBase {
struct
Conv
{
std
::
string
op_name
()
const
{
return
"conv2d"
;
}
std
::
string
input_name
()
const
{
return
"Input"
;
}
std
::
string
bias_name
()
const
{
return
"Bias"
;
}
std
::
string
filter_name
()
const
{
return
"Filter"
;
}
std
::
string
residual_data_name
()
const
{
return
"ResidualData"
;
}
std
::
string
output_name
()
const
{
return
"Output"
;
}
...
...
@@ -60,13 +61,16 @@ struct Conv {
auto
input_var
=
pattern
->
new_node
(
input_name
())
->
assert_is_op_input
(
op_name
(),
input_name
());
auto
bias_var
=
pattern
->
new_node
(
bias_name
())
->
assert_is_op_input
(
op_name
(),
bias_name
());
auto
filter_var
=
pattern
->
new_node
(
filter_name
())
->
assert_is_op_input
(
op_name
(),
filter_name
());
auto
output_var
=
pattern
->
new_node
(
output_name
())
->
assert_is_op_output
(
op_name
(),
output_name
());
conv_op
->
LinksFrom
({
input_var
,
filter_var
});
conv_op
->
LinksFrom
({
input_var
,
bias_var
,
filter_var
});
conv_op
->
LinksTo
({
output_var
});
return
output_var
;
...
...
@@ -178,13 +182,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output
->
AsIntermediate
();
auto
fuse_conv
=
[
&
conv_pattern
](
Graph
*
g
,
Node
*
conv_input
,
auto
fuse_conv
=
[
&
conv_pattern
](
Graph
*
g
,
Node
*
conv_input
,
Node
*
conv_bias
,
Node
*
conv_filter
,
Node
*
conv_output
,
Node
*
elementwise_add_x
)
{
OpDesc
op_desc
;
op_desc
.
SetType
(
conv_pattern
.
op_name
());
op_desc
.
SetInput
(
conv_pattern
.
input_name
(),
{
conv_input
->
Name
()});
op_desc
.
SetInput
(
conv_pattern
.
bias_name
(),
{
conv_bias
->
Name
()});
op_desc
.
SetInput
(
conv_pattern
.
filter_name
(),
{
conv_filter
->
Name
()});
op_desc
.
SetInput
(
conv_pattern
.
residual_data_name
(),
{
elementwise_add_x
->
Name
()});
...
...
@@ -196,6 +201,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
auto
fused_conv_op
=
g
->
CreateOpNode
(
&
op_desc
);
patterns
::
LinkNodes
(
conv_input
,
fused_conv_op
);
patterns
::
LinkNodes
(
conv_bias
,
fused_conv_op
);
patterns
::
LinkNodes
(
conv_filter
,
fused_conv_op
);
patterns
::
LinkNodes
(
elementwise_add_x
,
fused_conv_op
);
patterns
::
LinkNodes
(
fused_conv_op
,
conv_output
);
...
...
@@ -208,6 +214,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_pattern
.
op_name
());
auto
conv_input
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
input_name
());
auto
conv_bias
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
bias_name
());
auto
conv_filter
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
filter_name
());
auto
conv_output
=
patterns
::
GetNodeFromSubgraph
(
...
...
@@ -220,7 +228,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
auto
elementwise_add_out
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
out_name
());
fuse_conv
(
g
,
conv_input
,
conv_filter
,
conv_output
,
elementwise_add_x
);
fuse_conv
(
g
,
conv_input
,
conv_bias
,
conv_filter
,
conv_output
,
elementwise_add_x
);
patterns
::
CorrectGraphEdges
(
g
,
elementwise_add_out
,
conv_output
);
GraphSafeRemoveNodes
(
g
,
{
elementwise_add_out
,
conv_op
,
elementwise_add_op
});
};
...
...
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
浏览文件 @
347bf904
...
...
@@ -34,7 +34,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
if
(
type
==
"conv2d"
)
{
op
->
SetAttr
(
"use_mkldnn"
,
true
);
op
->
SetInput
(
"Input"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Filter"
,
{
inputs
[
1
]});
op
->
SetInput
(
"Bias"
,
{
inputs
[
1
]});
op
->
SetInput
(
"Filter"
,
{
inputs
[
2
]});
op
->
SetOutput
(
"Output"
,
outputs
);
}
else
if
(
type
==
"elementwise_add"
)
{
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
...
...
@@ -98,8 +99,8 @@ struct IsReachable {
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionWithElementwiseAddRelu
)
{
auto
build_program_desc
=
[
&
]()
->
ProgramDesc
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"weights"
,
"c"
,
"d"
,
"e
"
}))
{
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
(
{
"a"
,
"b"
,
"bias"
,
"weights"
,
"c"
,
"d"
,
"e"
,
"f
"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
if
(
v
==
"weights"
)
{
...
...
@@ -107,7 +108,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
}
}
SetOp
(
&
prog
,
"conv2d"
,
{
"a"
,
"weights"
},
{
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
{
"a"
,
"
bias"
,
"
weights"
},
{
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{
"c"
,
"b"
},
{
"d"
});
SetOp
(
&
prog
,
"relu"
,
{
"d"
},
{
"e"
});
...
...
@@ -150,7 +151,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionElementwiseAdd
)
{
auto
build_program_desc
=
[
&
]()
->
ProgramDesc
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"weights"
}))
{
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"
bias"
,
"
weights"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
if
(
v
==
"weights"
||
v
==
"bias"
)
{
...
...
@@ -158,7 +159,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
}
}
SetOp
(
&
prog
,
"conv2d"
,
{
"a"
,
"weights"
},
{
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
{
"a"
,
"
bias"
,
"
weights"
},
{
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{
"c"
,
"b"
},
{
"d"
});
return
prog
;
...
...
@@ -199,8 +200,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
SigmoidConvolutionAddElementwiseRelu
)
{
auto
build_program_desc
=
[
&
]()
->
ProgramDesc
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b
"
,
"weights"
,
"c"
,
"d"
,
"e"
,
"f"
}))
{
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
(
{
"a"
,
"b"
,
"bias
"
,
"weights"
,
"c"
,
"d"
,
"e"
,
"f"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
if
(
v
.
find
(
"weights"
))
{
...
...
@@ -209,7 +210,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
}
SetOp
(
&
prog
,
"sigmoid"
,
{
"a"
},
{
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
{
"b"
,
"weights"
},
{
"c"
});
SetOp
(
&
prog
,
"conv2d"
,
{
"b"
,
"
bias"
,
"
weights"
},
{
"c"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{
"d"
,
"c"
},
{
"e"
});
SetOp
(
&
prog
,
"relu"
,
{
"e"
},
{
"f"
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录