Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
53da846d
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
53da846d
编写于
11月 15, 2018
作者:
T
Tomasz Patejko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN residual connections fuse pass: initial implementation of fusion for projection pass
test=develop
上级
dbc4fcd7
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
206 addition
and
39 deletion
+206
-39
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
...uid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
+147
-27
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h
...luid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h
+59
-12
未找到文件。
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
53da846d
...
...
@@ -120,17 +120,18 @@ boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
return
boost
::
none
;
}
ResidualConnectionMKLDNNFusePass
::
FuseHandler
::
FuseHandler
(
const
ResidualConnectionMKLDNNFusePass
::
ConvFunc
&
get_node_from_conv_op
,
const
ResidualConnectionMKLDNNFusePass
::
ElementwiseAddFunc
&
get_node_from_elementwise_add_op
,
const
ResidualConnectionMKLDNNFusePass
::
CanFuseFunc
&
can_fuse_func
)
ResidualConnectionMKLDNNFusePass
::
IdentityFuseHandle
::
IdentityFuseHandle
(
const
ResidualConnectionMKLDNNFusePass
::
CanFuseFunc
&
can_fuse_func
,
const
ResidualConnectionMKLDNNFusePass
::
IdentityConvFunc
&
get_node_from_conv_op
,
const
ResidualConnectionMKLDNNFusePass
::
IdentityElementwiseAddFunc
&
get_node_from_elementwise_add_op
)
:
fusion_stats
{
std
::
make_shared
<
int
>
(
0
)},
can_fuse_func
{
can_fuse_func
},
get_node_from_conv_op
{
get_node_from_conv_op
},
get_node_from_elementwise_add_op
{
get_node_from_elementwise_add_op
},
can_fuse_func
{
can_fuse_func
}
{}
get_node_from_elementwise_add_op
{
get_node_from_elementwise_add_op
}
{}
void
ResidualConnectionMKLDNNFusePass
::
FuseHandler
::
operator
()(
void
ResidualConnectionMKLDNNFusePass
::
IdentityFuseHandle
::
operator
()(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
Node
*
conv_op
;
Node
*
conv_input
;
...
...
@@ -187,6 +188,104 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
(
*
fusion_stats
)
++
;
}
ResidualConnectionMKLDNNFusePass
::
ProjectionFuseHandle
::
ProjectionFuseHandle
(
const
ResidualConnectionMKLDNNFusePass
::
CanFuseFunc
&
can_fuse_func
,
const
ResidualConnectionMKLDNNFusePass
::
ProjectionConvFunc
&
get_node_from_conv_x_op
,
const
ResidualConnectionMKLDNNFusePass
::
ProjectionConvFunc
&
get_node_from_conv_y_op
,
const
ResidualConnectionMKLDNNFusePass
::
ProjectionElementwiseAddFunc
&
get_node_from_elementwise_add_op
)
:
fusion_stats
{
std
::
make_shared
<
int
>
(
0
)},
can_fuse_func
{
can_fuse_func
},
get_node_from_conv_x_op
{
get_node_from_conv_x_op
},
get_node_from_conv_y_op
{
get_node_from_conv_y_op
},
get_node_from_elementwise_add_op
{
get_node_from_elementwise_add_op
}
{}
void
ResidualConnectionMKLDNNFusePass
::
ProjectionFuseHandle
::
operator
()(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
Node
*
conv_x_op
;
Node
*
conv_x_input
;
Node
*
conv_x_filter
;
Node
*
conv_x_output
;
Node
*
conv_y_op
;
Node
*
conv_y_input
;
Node
*
conv_y_filter
;
Node
*
conv_y_output
;
Node
*
elementwise_add_op
;
Node
*
elementwise_add_out
;
std
::
tie
(
conv_x_op
,
conv_x_input
,
conv_x_filter
,
conv_x_output
)
=
get_node_from_conv_x_op
(
subgraph
);
std
::
tie
(
conv_y_op
,
conv_y_input
,
conv_y_filter
,
conv_y_output
)
=
get_node_from_conv_y_op
(
subgraph
);
std
::
tie
(
elementwise_add_op
,
elementwise_add_out
)
=
get_node_from_elementwise_add_op
(
subgraph
);
if
(
!
can_fuse_func
(
conv_x_op
,
elementwise_add_op
))
return
;
if
(
!
can_fuse_func
(
conv_y_op
,
elementwise_add_op
))
return
;
Node
*
projection_node
;
Node
*
residual_conv_op
;
Node
*
residual_conv_input
;
Node
*
residual_conv_filter
;
Node
*
residual_conv_output
;
if
(
IsReachable
(
graph
,
conv_x_input
,
conv_y_output
))
{
projection_node
=
conv_x_output
;
residual_conv_op
=
conv_y_op
;
residual_conv_input
=
conv_y_input
;
residual_conv_filter
=
conv_y_filter
;
residual_conv_output
=
conv_y_output
;
}
else
if
(
IsReachable
(
graph
,
conv_y_input
,
conv_x_output
))
{
projection_node
=
conv_y_output
;
residual_conv_op
=
conv_x_op
;
residual_conv_input
=
conv_x_input
;
residual_conv_filter
=
conv_x_filter
;
residual_conv_output
=
conv_x_output
;
}
else
{
return
;
}
OpDesc
op_desc
;
op_desc
.
SetType
(
"conv2d"
);
op_desc
.
SetInput
(
"Input"
,
{
residual_conv_input
->
Name
()});
op_desc
.
SetInput
(
"Filter"
,
{
residual_conv_filter
->
Name
()});
op_desc
.
SetInput
(
"ResidualData"
,
{
projection_node
->
Name
()});
op_desc
.
SetOutput
(
"Output"
,
{
residual_conv_output
->
Name
()});
auto
residual_conv_bias
=
HasBias
(
*
residual_conv_op
,
"Bias"
);
if
(
residual_conv_bias
)
{
op_desc
.
SetInput
(
"Bias"
,
{(
*
residual_conv_bias
)
->
Name
()});
}
for
(
const
auto
&
attr
:
residual_conv_op
->
Op
()
->
GetAttrMap
())
{
op_desc
.
SetAttr
(
attr
.
first
,
attr
.
second
);
}
op_desc
.
SetAttr
(
"fuse_residual_connection"
,
true
);
auto
fused_conv_op
=
graph
->
CreateOpNode
(
&
op_desc
);
IR_NODE_LINK_TO
(
residual_conv_input
,
fused_conv_op
);
IR_NODE_LINK_TO
(
residual_conv_filter
,
fused_conv_op
);
IR_NODE_LINK_TO
(
projection_node
,
fused_conv_op
);
IR_NODE_LINK_TO
(
fused_conv_op
,
residual_conv_output
);
if
(
residual_conv_bias
)
{
IR_NODE_LINK_TO
((
*
residual_conv_bias
),
fused_conv_op
);
}
CorrectGraphEdges
(
graph
,
elementwise_add_out
,
residual_conv_output
);
GraphSafeRemoveNodes
(
graph
,
{
elementwise_add_out
,
residual_conv_op
,
elementwise_add_op
});
(
*
fusion_stats
)
++
;
}
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>
ResidualConnectionMKLDNNFusePass
::
GetNodesFromConv
(
const
patterns
::
Conv
&
conv_pattern
,
...
...
@@ -233,7 +332,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
elementwise_add_out
);
};
return
ExecuteHandle
rOnGraph
(
return
ExecuteHandle
OnGraph
<
IdentityFuseHandle
>
(
&
gpd
,
graph_with_stats
,
[
this
,
&
conv_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
{
return
GetNodesFromConv
(
conv_pattern
,
subgraph
);
...
...
@@ -270,7 +369,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
elementwise_add_out
);
};
return
ExecuteHandle
rOnGraph
(
return
ExecuteHandle
OnGraph
<
IdentityFuseHandle
>
(
&
gpd
,
graph_with_stats
,
[
this
,
&
conv_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
{
return
GetNodesFromConv
(
conv_pattern
,
subgraph
);
...
...
@@ -278,33 +377,54 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
get_node_from_elementwise_add
);
}
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
ExecuteHandlerOnGraph
(
GraphPatternDetector
*
gpd
,
const
GraphWithStats
&
graph_with_stats
,
const
ResidualConnectionMKLDNNFusePass
::
ConvFunc
&
get_node_from_conv
,
const
ResidualConnectionMKLDNNFusePass
::
ElementwiseAddFunc
&
get_node_from_elementwise_add
)
const
{
ir
::
Graph
*
graph
;
int
stats
;
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseProjectionConv
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
{
GraphPatternDetector
gpd
;
auto
pattern
=
gpd
.
mutable_pattern
();
std
::
tie
(
graph
,
stats
)
=
graph_with_stats
;
patterns
::
Conv
conv_x_pattern
{
pattern
,
name_scope
};
auto
conv_x_output
=
conv_x_pattern
();
auto
can_fuse
=
[
this
](
Node
*
op1
,
Node
*
op2
)
->
bool
{
return
this
->
FindFuseOption
(
*
op1
,
*
op2
)
==
FUSE_MKLDNN
;
};
patterns
::
Conv
conv_y_pattern
{
pattern
,
name_scope
};
auto
conv_y_output
=
conv_y_pattern
();
auto
fuse_handler
=
FuseHandler
{
get_node_from_conv
,
get_node_from_elementwise_add
,
can_fuse
};
patterns
::
ElementwiseAdd
elementwise_add_pattern
{
pattern
,
name_scope
};
elementwise_add_pattern
(
conv_x_output
,
conv_y_output
);
conv_x_output
->
AsIntermediate
();
conv_y_output
->
AsIntermediate
();
(
*
gpd
)(
graph
,
fuse_handler
);
auto
get_node_from_elementwise_add
=
[
&
elementwise_add_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
->
std
::
tuple
<
Node
*
,
Node
*>
{
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
elementwise_add_pattern
);
return
std
::
make_pair
(
graph
,
stats
+
fuse_handler
.
get_stats
());
return
std
::
make_tuple
(
elementwise_add_op
,
elementwise_add_out
);
};
return
ExecuteHandleOnGraph
<
ProjectionFuseHandle
>
(
&
gpd
,
graph_with_stats
,
[
this
,
&
conv_x_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
{
return
GetNodesFromConv
(
conv_x_pattern
,
subgraph
);
},
[
this
,
&
conv_y_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
{
return
GetNodesFromConv
(
conv_y_pattern
,
subgraph
);
},
get_node_from_elementwise_add
);
}
graph_ptr
ResidualConnectionMKLDNNFusePass
::
ApplyImpl
(
graph_ptr
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
auto
fused_graph_with_stats
=
FuseConvAsY
(
name_scope_
,
FuseConvAsX
(
name_scope_
,
std
::
make_pair
(
graph
.
get
(),
0
)));
name_scope_
,
FuseConvAsX
(
name_scope_
,
FuseProjectionConv
(
name_scope_
,
std
::
make_pair
(
graph
.
get
(),
0
))));
std
::
cout
<<
"Fused graph "
<<
fused_graph_with_stats
.
second
<<
std
::
endl
;
AddStatis
(
fused_graph_with_stats
.
second
);
...
...
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h
浏览文件 @
53da846d
...
...
@@ -40,27 +40,73 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
const
GraphWithStats
&
graph_with_stats
)
const
;
GraphWithStats
FuseConvAsY
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
;
GraphWithStats
FuseProjectionConv
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
;
template
<
typename
RetType
>
using
GetNodeFunc
=
std
::
function
<
RetType
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
>
;
using
ConvFunc
=
GetNodeFunc
<
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>>
;
using
ElementwiseAddFunc
=
GetNodeFunc
<
std
::
tuple
<
Node
*
,
Node
*
,
Node
*>>
;
using
IdentityConvFunc
=
GetNodeFunc
<
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>>
;
using
IdentityElementwiseAddFunc
=
GetNodeFunc
<
std
::
tuple
<
Node
*
,
Node
*
,
Node
*>>
;
using
ProjectionConvFunc
=
IdentityConvFunc
;
using
ProjectionElementwiseAddFunc
=
GetNodeFunc
<
std
::
tuple
<
Node
*
,
Node
*>>
;
using
CanFuseFunc
=
std
::
function
<
bool
(
Node
*
,
Node
*
)
>
;
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>
GetNodesFromConv
(
const
patterns
::
Conv
&
conv_pattern
,
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
const
;
GraphWithStats
ExecuteHandlerOnGraph
(
GraphPatternDetector
*
gpd
,
const
GraphWithStats
&
graph_with_stats
,
const
ConvFunc
&
get_node_from_conv
,
const
ElementwiseAddFunc
&
get_node_from_elementwise_add
)
const
;
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>
GetNodesFromProjectionConv
(
const
patterns
::
Conv
&
conv_pattern
,
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
const
;
template
<
typename
HandleType
,
typename
...
OpFuncs
>
GraphWithStats
ExecuteHandleOnGraph
(
GraphPatternDetector
*
gpd
,
const
GraphWithStats
&
graph_with_stats
,
OpFuncs
&&
...
op_funcs
)
const
{
ir
::
Graph
*
graph
;
int
stats
;
std
::
tie
(
graph
,
stats
)
=
graph_with_stats
;
auto
can_fuse
=
[
this
](
Node
*
op1
,
Node
*
op2
)
->
bool
{
return
this
->
FindFuseOption
(
*
op1
,
*
op2
)
==
FUSE_MKLDNN
;
};
auto
fuse_handle
=
HandleType
{
can_fuse
,
std
::
forward
<
OpFuncs
>
(
op_funcs
)...};
(
*
gpd
)(
graph
,
fuse_handle
);
return
std
::
make_pair
(
graph
,
stats
+
fuse_handle
.
get_stats
());
}
struct
IdentityFuseHandle
{
IdentityFuseHandle
(
const
CanFuseFunc
&
can_fuse_func
,
const
IdentityConvFunc
&
get_node_from_conv_op
,
const
IdentityElementwiseAddFunc
&
get_node_from_elementwise_add_op
);
void
operator
()(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
);
int
get_stats
()
const
{
return
*
fusion_stats
;
}
private:
std
::
shared_ptr
<
int
>
fusion_stats
;
CanFuseFunc
can_fuse_func
;
IdentityConvFunc
get_node_from_conv_op
;
IdentityElementwiseAddFunc
get_node_from_elementwise_add_op
;
};
struct
FuseHandler
{
FuseHandler
(
const
ConvFunc
&
get_node_from_conv_op
,
const
ElementwiseAddFunc
&
get_node_from_elementwise_add_op
,
const
CanFuseFunc
&
can_fuse_func
);
struct
ProjectionFuseHandle
{
ProjectionFuseHandle
(
const
CanFuseFunc
&
can_fuse_func
,
const
ProjectionConvFunc
&
get_node_from_conv_x_op
,
const
ProjectionConvFunc
&
get_node_from_conv_y_op
,
const
ProjectionElementwiseAddFunc
&
get_node_from_elementwise_add_op
);
void
operator
()(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
);
...
...
@@ -68,9 +114,10 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
std
::
shared_ptr
<
int
>
fusion_stats
;
ConvFunc
get_node_from_conv_op
;
ElementwiseAddFunc
get_node_from_elementwise_add_op
;
CanFuseFunc
can_fuse_func
;
ProjectionConvFunc
get_node_from_conv_x_op
;
ProjectionConvFunc
get_node_from_conv_y_op
;
ProjectionElementwiseAddFunc
get_node_from_elementwise_add_op
;
};
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录