Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
efd76614
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
efd76614
编写于
9月 26, 2018
作者:
T
Tomasz Patejko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN conv + elementwise_add fusion: implementation changed to conform with Paddle API
上级
347bf904
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
101 addition
and
46 deletion
+101
-46
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
...uid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
+36
-46
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+39
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+26
-0
未找到文件。
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
efd76614
...
...
@@ -22,6 +22,7 @@ namespace framework {
namespace
ir
{
namespace
patterns
{
/*
struct Pattern : public PatternBase {
Pattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase{pattern, name_scope, ""} {}
...
...
@@ -45,7 +46,8 @@ struct Pattern : public PatternBase {
return node_pattern()->NewNode(node_name(op_name));
}
};
*/
/*
struct Conv {
std::string op_name() const { return "conv2d"; }
std::string input_name() const { return "Input"; }
...
...
@@ -105,7 +107,8 @@ struct ElementwiseAdd {
};
}
};
*/
/*
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
std::shared_ptr<patterns::Pattern> pattern,
const std::string& op_name) {
...
...
@@ -116,6 +119,7 @@ Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
return var;
}
*/
void
LinkNodes
(
Node
*
from
,
Node
*
to
)
{
from
->
outputs
.
push_back
(
to
);
...
...
@@ -172,64 +176,50 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
GraphPatternDetector
gpd
;
auto
pattern
=
gpd
.
mutable_pattern
();
auto
pattern_ptr
=
std
::
make_shared
<
patterns
::
Pattern
>
(
pattern
,
name_scope_
);
patterns
::
Conv
conv_pattern
;
auto
conv_output
=
conv_pattern
(
pattern_ptr
)(
);
patterns
::
Conv
conv_pattern
{
pattern
,
"skip_connections_fusion"
}
;
auto
conv_output
=
conv_pattern
();
patterns
::
ElementwiseAdd
elementwise_add_pattern
;
elementwise_add_pattern
(
pattern_ptr
)(
conv_output
);
patterns
::
ElementwiseAdd
elementwise_add_pattern
{
pattern
,
"skip_connections_fusion"
};
elementwise_add_pattern
(
conv_output
);
conv_output
->
AsIntermediate
();
auto
fuse_conv
=
[
&
conv_pattern
](
Graph
*
g
,
Node
*
conv_input
,
Node
*
conv_bias
,
Node
*
conv_filter
,
Node
*
conv_output
,
Node
*
elementwise_add_x
)
{
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_IR_NODE_FROM_SUBGRAPH
(
conv_op
,
conv_op
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_input
,
conv_input
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_bias
,
conv_bias
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_filter
,
conv_filter
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_output
,
conv_output
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_x
,
elementwise_add_x
,
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
elementwise_add_pattern
);
OpDesc
op_desc
;
op_desc
.
SetType
(
conv_pattern
.
op_name
()
);
op_desc
.
SetType
(
"conv2d"
);
op_desc
.
SetInput
(
conv_pattern
.
input_name
(),
{
conv_input
->
Name
()});
op_desc
.
SetInput
(
conv_pattern
.
bias_name
(),
{
conv_bias
->
Name
()});
op_desc
.
SetInput
(
conv_pattern
.
filter_name
(),
{
conv_filter
->
Name
()});
op_desc
.
SetInput
(
conv_pattern
.
residual_data_name
(),
{
elementwise_add_x
->
Name
()});
op_desc
.
SetOutput
(
conv_pattern
.
output_name
(),
{
conv_output
->
Name
()});
op_desc
.
SetInput
(
"Input"
,
{
conv_input
->
Name
()});
op_desc
.
SetInput
(
"Bias"
,
{
conv_bias
->
Name
()});
op_desc
.
SetInput
(
"Filter"
,
{
conv_filter
->
Name
()});
op_desc
.
SetInput
(
"ResidualData"
,
{
elementwise_add_x
->
Name
()});
op_desc
.
SetOutput
(
"Output"
,
{
conv_output
->
Name
()});
op_desc
.
SetAttr
(
"use_mkldnn"
,
true
);
op_desc
.
SetAttr
(
"fuse_eltwise"
,
true
);
auto
fused_conv_op
=
g
->
CreateOpNode
(
&
op_desc
);
patterns
::
LinkNodes
(
conv_input
,
fused_conv_op
);
patterns
::
LinkNodes
(
conv_bias
,
fused_conv_op
);
patterns
::
LinkNodes
(
conv_filter
,
fused_conv_op
);
patterns
::
LinkNodes
(
elementwise_add_x
,
fused_conv_op
);
patterns
::
LinkNodes
(
fused_conv_op
,
conv_output
);
};
IR_NODE_LINK_TO
(
conv_input
,
fused_conv_op
);
IR_NODE_LINK_TO
(
conv_bias
,
fused_conv_op
);
IR_NODE_LINK_TO
(
conv_filter
,
fused_conv_op
);
IR_NODE_LINK_TO
(
elementwise_add_x
,
fused_conv_op
);
IR_NODE_LINK_TO
(
fused_conv_op
,
conv_output
);
auto
handler
=
[
&
conv_pattern
,
&
elementwise_add_pattern
,
pattern_ptr
,
fuse_conv
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
auto
conv_op
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
op_name
());
auto
conv_input
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
input_name
());
auto
conv_bias
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
bias_name
());
auto
conv_filter
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
filter_name
());
auto
conv_output
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
output_name
());
auto
elementwise_add_op
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
op_name
());
auto
elementwise_add_x
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
x_name
());
auto
elementwise_add_out
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
out_name
());
fuse_conv
(
g
,
conv_input
,
conv_bias
,
conv_filter
,
conv_output
,
elementwise_add_x
);
patterns
::
CorrectGraphEdges
(
g
,
elementwise_add_out
,
conv_output
);
GraphSafeRemoveNodes
(
g
,
{
elementwise_add_out
,
conv_op
,
elementwise_add_op
});
};
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
efd76614
...
...
@@ -999,6 +999,45 @@ PDNode *patterns::ConvBias::operator()(
return
eltwise_out_var
;
}
PDNode
*
patterns
::
Conv
::
operator
()()
{
auto
conv_op
=
pattern
->
NewNode
(
conv_op_repr
())
->
assert_is_op
(
"conv2d"
);
auto
input_var
=
pattern
->
NewNode
(
conv_input_repr
())
->
assert_is_op_input
(
"conv2d"
,
"Input"
);
auto
bias_var
=
pattern
->
NewNode
(
conv_bias_repr
())
->
assert_is_op_input
(
"conv2d"
,
"Bias"
);
auto
filter_var
=
pattern
->
NewNode
(
conv_filter_repr
())
->
assert_is_op_input
(
"conv2d"
,
"Filter"
);
auto
output_var
=
pattern
->
NewNode
(
conv_output_repr
())
->
assert_is_op_output
(
"conv2d"
,
"Output"
);
conv_op
->
LinksFrom
({
input_var
,
bias_var
,
filter_var
});
conv_op
->
LinksTo
({
output_var
});
return
output_var
;
}
PDNode
*
patterns
::
ElementwiseAdd
::
operator
()(
PDNode
*
conv_output
)
{
auto
elementwise_add_op
=
pattern
->
NewNode
(
elementwise_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
x_var
=
pattern
->
NewNode
(
elementwise_add_x_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
conv_output
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
out_var
=
pattern
->
NewNode
(
elementwise_add_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
);
elementwise_add_op
->
LinksFrom
({
x_var
,
conv_output
});
elementwise_add_op
->
LinksTo
({
out_var
});
return
out_var
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
efd76614
...
...
@@ -599,6 +599,32 @@ struct ConvBias : public PatternBase {
PATTERN_DECL_NODE
(
eltwise_bias
);
PATTERN_DECL_NODE
(
eltwise_out
);
};
struct
Conv
:
public
PatternBase
{
Conv
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"convolution"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
conv_op
);
PATTERN_DECL_NODE
(
conv_input
);
PATTERN_DECL_NODE
(
conv_bias
);
PATTERN_DECL_NODE
(
conv_filter
);
PATTERN_DECL_NODE
(
conv_residual_data
);
PATTERN_DECL_NODE
(
conv_output
);
};
struct
ElementwiseAdd
:
public
PatternBase
{
ElementwiseAdd
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"elementwise_add"
)
{}
PDNode
*
operator
()(
PDNode
*
conv_output
);
PATTERN_DECL_NODE
(
elementwise_add_op
);
PATTERN_DECL_NODE
(
elementwise_add_x
);
PATTERN_DECL_NODE
(
elementwise_add_y
);
PATTERN_DECL_NODE
(
elementwise_add_out
);
};
}
// namespace patterns
// Link two ir::Nodes from each other.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录