Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
47459e98
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
47459e98
编写于
3月 11, 2022
作者:
S
Sylwester Fraczek
提交者:
GitHub
3月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor conv+relementwise_add (residual) (#40005)
上级
c0e29233
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
177 addition
and
307 deletion
+177
-307
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
...mework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
+175
-217
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
...amework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
+2
-90
未找到文件。
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
47459e98
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -135,157 +136,9 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
...
@@ -135,157 +136,9 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
.
End
();
.
End
();
}
}
ResidualConnectionMKLDNNFusePass
::
IdentityFuseHandle
::
IdentityFuseHandle
(
const
ResidualConnectionMKLDNNFusePass
::
CanFuseFunc
&
can_fuse_func
,
const
ResidualConnectionMKLDNNFusePass
::
IdentityConvFunc
&
get_node_from_conv_op
,
const
ResidualConnectionMKLDNNFusePass
::
IdentityElementwiseAddFunc
&
get_node_from_elementwise_add_op
,
const
ResidualConnectionMKLDNNFusePass
*
pass
)
:
fusion_stats
{
std
::
make_shared
<
int
>
(
0
)},
can_fuse_func
{
can_fuse_func
},
get_node_from_conv_op
{
get_node_from_conv_op
},
get_node_from_elementwise_add_op
{
get_node_from_elementwise_add_op
},
pass_
{
pass
}
{}
void
ResidualConnectionMKLDNNFusePass
::
IdentityFuseHandle
::
operator
()(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
Node
*
conv_op
;
Node
*
conv_input
;
Node
*
conv_filter
;
Node
*
conv_output
;
Node
*
elementwise_add_op
;
Node
*
elementwise_add_identity
;
Node
*
elementwise_add_out
;
std
::
tie
(
conv_op
,
conv_input
,
conv_filter
,
conv_output
)
=
get_node_from_conv_op
(
subgraph
);
std
::
tie
(
elementwise_add_op
,
elementwise_add_identity
,
elementwise_add_out
)
=
get_node_from_elementwise_add_op
(
subgraph
);
if
(
!
can_fuse_func
(
conv_op
,
elementwise_add_op
))
return
;
if
(
!
IsReachable
(
graph
,
elementwise_add_identity
,
conv_output
))
return
;
if
(
HasFusedActivation
(
conv_op
))
return
;
if
(
!
pass_
->
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"conv_elementwise_add_mkldnn_fuse_pass in op compat failed."
;
return
;
}
conv_op
->
Op
()
->
SetInput
(
"ResidualData"
,
{
elementwise_add_identity
->
Name
()});
conv_op
->
Op
()
->
SetOutput
(
"Output"
,
{
elementwise_add_out
->
Name
()});
conv_op
->
Op
()
->
SetAttr
(
"fuse_residual_connection"
,
true
);
GraphSafeRemoveNodes
(
graph
,
{
conv_output
,
elementwise_add_op
});
IR_NODE_LINK_TO
(
elementwise_add_identity
,
conv_op
);
IR_NODE_LINK_TO
(
conv_op
,
elementwise_add_out
);
(
*
fusion_stats
)
++
;
}
ResidualConnectionMKLDNNFusePass
::
ProjectionFuseHandle
::
ProjectionFuseHandle
(
const
ResidualConnectionMKLDNNFusePass
::
CanFuseFunc
&
can_fuse_func
,
const
ResidualConnectionMKLDNNFusePass
::
ProjectionConvFunc
&
get_node_from_conv_x_op
,
const
ResidualConnectionMKLDNNFusePass
::
ProjectionConvFunc
&
get_node_from_conv_y_op
,
const
ResidualConnectionMKLDNNFusePass
::
ProjectionElementwiseAddFunc
&
get_node_from_elementwise_add_op
,
const
ResidualConnectionMKLDNNFusePass
*
pass
)
:
fusion_stats
{
std
::
make_shared
<
int
>
(
0
)},
can_fuse_func
{
can_fuse_func
},
get_node_from_conv_x_op
{
get_node_from_conv_x_op
},
get_node_from_conv_y_op
{
get_node_from_conv_y_op
},
get_node_from_elementwise_add_op
{
get_node_from_elementwise_add_op
},
pass_
{
pass
}
{}
void
ResidualConnectionMKLDNNFusePass
::
ProjectionFuseHandle
::
operator
()(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
Node
*
conv_x_op
;
Node
*
conv_x_input
;
Node
*
conv_x_filter
;
Node
*
conv_x_output
;
Node
*
conv_y_op
;
Node
*
conv_y_input
;
Node
*
conv_y_filter
;
Node
*
conv_y_output
;
Node
*
elementwise_add_op
;
Node
*
elementwise_add_out
;
if
(
!
pass_
->
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"conv_elementwise_add_mkldnn_fuse_pass in op compat failed."
;
return
;
}
std
::
tie
(
conv_x_op
,
conv_x_input
,
conv_x_filter
,
conv_x_output
)
=
get_node_from_conv_x_op
(
subgraph
);
std
::
tie
(
conv_y_op
,
conv_y_input
,
conv_y_filter
,
conv_y_output
)
=
get_node_from_conv_y_op
(
subgraph
);
std
::
tie
(
elementwise_add_op
,
elementwise_add_out
)
=
get_node_from_elementwise_add_op
(
subgraph
);
if
(
!
can_fuse_func
(
conv_x_op
,
elementwise_add_op
))
return
;
if
(
!
can_fuse_func
(
conv_y_op
,
elementwise_add_op
))
return
;
Node
*
projection_node
;
Node
*
residual_conv_op
;
Node
*
residual_conv_output
;
if
(
IsReachable
(
graph
,
conv_x_input
,
conv_y_output
))
{
projection_node
=
conv_x_output
;
residual_conv_op
=
conv_y_op
;
residual_conv_output
=
conv_y_output
;
}
else
if
(
IsReachable
(
graph
,
conv_y_input
,
conv_x_output
))
{
projection_node
=
conv_y_output
;
residual_conv_op
=
conv_x_op
;
residual_conv_output
=
conv_x_output
;
}
else
{
return
;
}
if
(
HasFusedActivation
(
residual_conv_op
))
return
;
residual_conv_op
->
Op
()
->
SetInput
(
"ResidualData"
,
{
projection_node
->
Name
()});
residual_conv_op
->
Op
()
->
SetOutput
(
"Output"
,
{
elementwise_add_out
->
Name
()});
residual_conv_op
->
Op
()
->
SetAttr
(
"fuse_residual_connection"
,
true
);
GraphSafeRemoveNodes
(
graph
,
{
residual_conv_output
,
elementwise_add_op
});
IR_NODE_LINK_TO
(
projection_node
,
residual_conv_op
);
IR_NODE_LINK_TO
(
residual_conv_op
,
elementwise_add_out
);
(
*
fusion_stats
)
++
;
}
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>
ResidualConnectionMKLDNNFusePass
::
GetNodesFromConv
(
const
patterns
::
Conv
&
conv_pattern
,
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
const
{
GET_IR_NODE_FROM_SUBGRAPH
(
conv_op
,
conv_op
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_input
,
conv_input
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_filter
,
conv_filter
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_output
,
conv_output
,
conv_pattern
);
return
std
::
make_tuple
(
conv_op
,
conv_input
,
conv_filter
,
conv_output
);
}
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseConvAsX
(
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseConvAsX
(
const
std
::
string
&
name_scope
,
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
{
const
GraphWithStats
&
graph_with_stats
)
const
{
ir
::
Graph
*
graph
;
int
stats
;
std
::
tie
(
graph
,
stats
)
=
graph_with_stats
;
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
pattern
=
gpd
.
mutable_pattern
();
auto
pattern
=
gpd
.
mutable_pattern
();
...
@@ -298,26 +151,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
...
@@ -298,26 +151,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
pattern
->
NewNode
(
elementwise_add_pattern
.
elementwise_add_y_repr
()));
pattern
->
NewNode
(
elementwise_add_pattern
.
elementwise_add_y_repr
()));
conv_output
->
AsIntermediate
();
conv_output
->
AsIntermediate
();
auto
get_node_from_elementwise_add
=
[
&
elementwise_add_pattern
](
int
found_conv_as_x_count
=
0
;
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
->
std
::
tuple
<
Node
*
,
Node
*
,
Node
*>
{
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
Graph
*
g
)
{
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_op
,
conv_op
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_y
,
elementwise_add_y
,
GET_IR_NODE_FROM_SUBGRAPH
(
conv_input
,
conv_input
,
conv_pattern
);
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_filter
,
conv_filter
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
conv_output
,
conv_output
,
conv_pattern
);
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
return
std
::
make_tuple
(
elementwise_add_op
,
elementwise_add_y
,
elementwise_add_pattern
);
elementwise_add_out
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_identity
,
elementwise_add_y
,
};
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
return
ExecuteHandleOnGraph
<
IdentityFuseHandle
>
(
elementwise_add_pattern
);
&
gpd
,
graph_with_stats
,
[
this
,
&
conv_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
{
if
(
FindFuseOption
(
*
conv_op
,
*
elementwise_add_op
)
!=
FUSE_MKLDNN
)
return
;
return
GetNodesFromConv
(
conv_pattern
,
subgraph
);
},
if
(
!
IsReachable
(
g
,
elementwise_add_identity
,
conv_output
))
return
;
get_node_from_elementwise_add
,
this
);
if
(
HasFusedActivation
(
conv_op
))
return
;
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"conv_elementwise_add_mkldnn_fuse_pass in op compat failed."
;
return
;
}
conv_op
->
Op
()
->
SetInput
(
"ResidualData"
,
{
elementwise_add_identity
->
Name
()});
conv_op
->
Op
()
->
SetOutput
(
"Output"
,
{
elementwise_add_out
->
Name
()});
conv_op
->
Op
()
->
SetAttr
(
"fuse_residual_connection"
,
true
);
GraphSafeRemoveNodes
(
g
,
{
conv_output
,
elementwise_add_op
});
IR_NODE_LINK_TO
(
elementwise_add_identity
,
conv_op
);
IR_NODE_LINK_TO
(
conv_op
,
elementwise_add_out
);
found_conv_as_x_count
++
;
};
gpd
(
graph_with_stats
.
first
,
handler
);
if
(
!
Has
(
"disable_logs"
)
||
!
Get
<
bool
>
(
"disable_logs"
))
{
std
::
stringstream
msg_ss
;
msg_ss
<<
"--- Fused "
<<
found_conv_as_x_count
<<
" conv (as x) + elementwise_add patterns"
;
paddle
::
string
::
PrettyLogDetail
(
msg_ss
.
str
().
c_str
());
}
return
std
::
make_pair
(
graph_with_stats
.
first
,
found_conv_as_x_count
+
graph_with_stats
.
second
);
}
}
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseConvAsY
(
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseConvAsY
(
...
@@ -335,26 +218,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
...
@@ -335,26 +218,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
conv_output
);
conv_output
);
conv_output
->
AsIntermediate
();
conv_output
->
AsIntermediate
();
auto
get_node_from_elementwise_add
=
[
&
elementwise_add_pattern
](
int
found_conv_as_y_count
=
0
;
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
->
std
::
tuple
<
Node
*
,
Node
*
,
Node
*>
{
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
Graph
*
g
)
{
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_op
,
conv_op
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_x
,
elementwise_add_x
,
GET_IR_NODE_FROM_SUBGRAPH
(
conv_input
,
conv_input
,
conv_pattern
);
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_filter
,
conv_filter
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
conv_output
,
conv_output
,
conv_pattern
);
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
return
std
::
make_tuple
(
elementwise_add_op
,
elementwise_add_x
,
elementwise_add_pattern
);
elementwise_add_out
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_x
,
elementwise_add_x
,
};
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
return
ExecuteHandleOnGraph
<
IdentityFuseHandle
>
(
elementwise_add_pattern
);
&
gpd
,
graph_with_stats
,
[
this
,
&
conv_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
{
if
(
FindFuseOption
(
*
conv_op
,
*
elementwise_add_op
)
!=
FUSE_MKLDNN
)
return
;
return
GetNodesFromConv
(
conv_pattern
,
subgraph
);
},
if
(
!
IsReachable
(
g
,
elementwise_add_x
,
conv_output
))
return
;
get_node_from_elementwise_add
,
this
);
if
(
HasFusedActivation
(
conv_op
))
return
;
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"conv_elementwise_add_mkldnn_fuse_pass in op compat failed."
;
return
;
}
conv_op
->
Op
()
->
SetInput
(
"ResidualData"
,
{
elementwise_add_x
->
Name
()});
conv_op
->
Op
()
->
SetOutput
(
"Output"
,
{
elementwise_add_out
->
Name
()});
conv_op
->
Op
()
->
SetAttr
(
"fuse_residual_connection"
,
true
);
GraphSafeRemoveNodes
(
g
,
{
conv_output
,
elementwise_add_op
});
IR_NODE_LINK_TO
(
elementwise_add_x
,
conv_op
);
IR_NODE_LINK_TO
(
conv_op
,
elementwise_add_out
);
found_conv_as_y_count
++
;
};
gpd
(
graph_with_stats
.
first
,
handler
);
if
(
!
Has
(
"disable_logs"
)
||
!
Get
<
bool
>
(
"disable_logs"
))
{
std
::
stringstream
msg_ss
;
msg_ss
<<
"--- Fused "
<<
found_conv_as_y_count
<<
" conv (as y) + elementwise_add patterns"
;
paddle
::
string
::
PrettyLogDetail
(
msg_ss
.
str
().
c_str
());
}
return
std
::
make_pair
(
graph_with_stats
.
first
,
found_conv_as_y_count
+
graph_with_stats
.
second
);
}
}
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseProjectionConv
(
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseProjectionConv
(
...
@@ -374,39 +287,84 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
...
@@ -374,39 +287,84 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
conv_x_output
->
AsIntermediate
();
conv_x_output
->
AsIntermediate
();
conv_y_output
->
AsIntermediate
();
conv_y_output
->
AsIntermediate
();
auto
get_node_from_elementwise_add
=
[
&
elementwise_add_pattern
](
int
found_projection_conv_count
=
0
;
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
->
std
::
tuple
<
Node
*
,
Node
*>
{
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
Graph
*
g
)
{
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_x_op
,
conv_op
,
conv_x_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
GET_IR_NODE_FROM_SUBGRAPH
(
conv_x_input
,
conv_input
,
conv_x_pattern
);
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_x_filter
,
conv_filter
,
conv_x_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_x_output
,
conv_output
,
conv_x_pattern
);
return
std
::
make_tuple
(
elementwise_add_op
,
elementwise_add_out
);
};
GET_IR_NODE_FROM_SUBGRAPH
(
conv_y_op
,
conv_op
,
conv_y_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_y_input
,
conv_input
,
conv_y_pattern
);
return
ExecuteHandleOnGraph
<
ProjectionFuseHandle
>
(
GET_IR_NODE_FROM_SUBGRAPH
(
conv_y_filter
,
conv_filter
,
conv_y_pattern
);
&
gpd
,
graph_with_stats
,
GET_IR_NODE_FROM_SUBGRAPH
(
conv_y_output
,
conv_output
,
conv_y_pattern
);
[
this
,
&
conv_x_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
{
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
return
GetNodesFromConv
(
conv_x_pattern
,
subgraph
);
elementwise_add_pattern
);
},
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
[
this
,
elementwise_add_pattern
);
&
conv_y_pattern
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
{
return
GetNodesFromConv
(
conv_y_pattern
,
subgraph
);
if
(
!
IsCompat
(
subgraph
,
g
))
{
},
LOG
(
WARNING
)
get_node_from_elementwise_add
,
this
);
<<
"conv_elementwise_add_mkldnn_fuse_pass in op compat failed."
;
return
;
}
if
(
FindFuseOption
(
*
conv_x_op
,
*
elementwise_add_op
)
!=
FUSE_MKLDNN
)
return
;
if
(
FindFuseOption
(
*
conv_y_op
,
*
elementwise_add_op
)
!=
FUSE_MKLDNN
)
return
;
Node
*
projection_node
;
Node
*
residual_conv_op
;
Node
*
residual_conv_output
;
if
(
IsReachable
(
g
,
conv_x_input
,
conv_y_output
))
{
projection_node
=
conv_x_output
;
residual_conv_op
=
conv_y_op
;
residual_conv_output
=
conv_y_output
;
}
else
if
(
IsReachable
(
g
,
conv_y_input
,
conv_x_output
))
{
projection_node
=
conv_y_output
;
residual_conv_op
=
conv_x_op
;
residual_conv_output
=
conv_x_output
;
}
else
{
return
;
}
if
(
HasFusedActivation
(
residual_conv_op
))
return
;
residual_conv_op
->
Op
()
->
SetInput
(
"ResidualData"
,
{
projection_node
->
Name
()});
residual_conv_op
->
Op
()
->
SetOutput
(
"Output"
,
{
elementwise_add_out
->
Name
()});
residual_conv_op
->
Op
()
->
SetAttr
(
"fuse_residual_connection"
,
true
);
GraphSafeRemoveNodes
(
g
,
{
residual_conv_output
,
elementwise_add_op
});
IR_NODE_LINK_TO
(
projection_node
,
residual_conv_op
);
IR_NODE_LINK_TO
(
residual_conv_op
,
elementwise_add_out
);
found_projection_conv_count
++
;
};
gpd
(
graph_with_stats
.
first
,
handler
);
if
(
!
Has
(
"disable_logs"
)
||
!
Get
<
bool
>
(
"disable_logs"
))
{
std
::
stringstream
msg_ss
;
msg_ss
<<
"--- Fused "
<<
found_projection_conv_count
<<
" projection conv (as y) + elementwise_add patterns"
;
paddle
::
string
::
PrettyLogDetail
(
msg_ss
.
str
().
c_str
());
}
return
std
::
make_pair
(
graph_with_stats
.
first
,
found_projection_conv_count
+
graph_with_stats
.
second
);
}
}
void
ResidualConnectionMKLDNNFusePass
::
ApplyImpl
(
graph_ptr
graph
)
const
{
void
ResidualConnectionMKLDNNFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
fused_graph_with_stats
=
FuseConvAsY
(
auto
graph_with_stats
=
name_scope_
,
FuseProjectionConv
(
name_scope_
,
std
::
make_pair
(
graph
,
0
));
FuseConvAsX
(
name_scope_
,
graph_with_stats
=
FuseConvAsX
(
name_scope_
,
graph_with_stats
);
FuseProjectionConv
(
name_scope_
,
std
::
make_pair
(
graph
,
0
)))
);
graph_with_stats
=
FuseConvAsY
(
name_scope_
,
graph_with_stats
);
LOG
(
INFO
)
<<
"Fused graph "
<<
fused_graph_with_stats
.
second
<<
"
\n
"
;
AddStatis
(
graph_with_stats
.
second
);
AddStatis
(
fused_graph_with_stats
.
second
);
}
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
浏览文件 @
47459e98
...
@@ -28,19 +28,9 @@ namespace paddle {
...
@@ -28,19 +28,9 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
class
Graph
;
class
GraphPatternDetector
;
class
Node
;
namespace
patterns
{
struct
Conv
;
}
// namespace patterns
using
graph_ptr
=
ir
::
Graph
*
;
using
GraphWithStats
=
std
::
pair
<
ir
::
Graph
*
,
int
>
;
using
GraphWithStats
=
std
::
pair
<
ir
::
Graph
*
,
int
>
;
void
CorrectGraphEdges
(
Graph
*
graph
,
Node
*
from
,
Node
*
to
);
bool
IsReachable
(
ir
::
Graph
*
graph
,
Node
*
from
,
Node
*
to
);
bool
IsReachable
(
ir
::
Graph
*
graph
,
Node
*
from
,
Node
*
to
);
paddle
::
optional
<
Node
*>
HasBias
(
const
Node
&
op
,
const
std
::
string
&
bias_name
);
class
ResidualConnectionMKLDNNFusePass
:
public
FusePassBase
{
class
ResidualConnectionMKLDNNFusePass
:
public
FusePassBase
{
private:
private:
...
@@ -52,91 +42,13 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
...
@@ -52,91 +42,13 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
const
std
::
string
&
name_scope
,
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
;
const
GraphWithStats
&
graph_with_stats
)
const
;
template
<
typename
RetType
>
using
GetNodeFunc
=
std
::
function
<
RetType
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
>
;
using
IdentityConvFunc
=
GetNodeFunc
<
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>>
;
using
IdentityElementwiseAddFunc
=
GetNodeFunc
<
std
::
tuple
<
Node
*
,
Node
*
,
Node
*>>
;
using
ProjectionConvFunc
=
IdentityConvFunc
;
using
ProjectionElementwiseAddFunc
=
GetNodeFunc
<
std
::
tuple
<
Node
*
,
Node
*>>
;
using
CanFuseFunc
=
std
::
function
<
bool
(
Node
*
,
Node
*
)
>
;
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>
GetNodesFromConv
(
const
patterns
::
Conv
&
conv_pattern
,
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
const
;
std
::
tuple
<
Node
*
,
Node
*
,
Node
*
,
Node
*>
GetNodesFromProjectionConv
(
const
patterns
::
Conv
&
conv_pattern
,
const
GraphPatternDetector
::
subgraph_t
&
subgraph
)
const
;
template
<
typename
HandleType
,
typename
...
OpFuncs
>
GraphWithStats
ExecuteHandleOnGraph
(
GraphPatternDetector
*
gpd
,
const
GraphWithStats
&
graph_with_stats
,
OpFuncs
&&
...
op_funcs
)
const
{
ir
::
Graph
*
graph
;
int
stats
;
std
::
tie
(
graph
,
stats
)
=
graph_with_stats
;
auto
can_fuse
=
[
this
](
Node
*
op1
,
Node
*
op2
)
->
bool
{
return
this
->
FindFuseOption
(
*
op1
,
*
op2
)
==
FUSE_MKLDNN
;
};
auto
fuse_handle
=
HandleType
{
can_fuse
,
std
::
forward
<
OpFuncs
>
(
op_funcs
)...};
(
*
gpd
)(
graph
,
fuse_handle
);
return
std
::
make_pair
(
graph
,
stats
+
fuse_handle
.
get_stats
());
}
struct
IdentityFuseHandle
{
IdentityFuseHandle
(
const
CanFuseFunc
&
can_fuse_func
,
const
IdentityConvFunc
&
get_node_from_conv_op
,
const
IdentityElementwiseAddFunc
&
get_node_from_elementwise_add_op
,
const
ResidualConnectionMKLDNNFusePass
*
pass
);
void
operator
()(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
);
int
get_stats
()
const
{
return
*
fusion_stats
;
}
private:
std
::
shared_ptr
<
int
>
fusion_stats
;
CanFuseFunc
can_fuse_func
;
IdentityConvFunc
get_node_from_conv_op
;
IdentityElementwiseAddFunc
get_node_from_elementwise_add_op
;
const
ResidualConnectionMKLDNNFusePass
*
pass_
;
};
struct
ProjectionFuseHandle
{
ProjectionFuseHandle
(
const
CanFuseFunc
&
can_fuse_func
,
const
ProjectionConvFunc
&
get_node_from_conv_x_op
,
const
ProjectionConvFunc
&
get_node_from_conv_y_op
,
const
ProjectionElementwiseAddFunc
&
get_node_from_elementwise_add_op
,
const
ResidualConnectionMKLDNNFusePass
*
pass
);
void
operator
()(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
);
int
get_stats
()
const
{
return
*
fusion_stats
;
}
private:
std
::
shared_ptr
<
int
>
fusion_stats
;
CanFuseFunc
can_fuse_func
;
ProjectionConvFunc
get_node_from_conv_x_op
;
ProjectionConvFunc
get_node_from_conv_y_op
;
ProjectionElementwiseAddFunc
get_node_from_elementwise_add_op
;
const
ResidualConnectionMKLDNNFusePass
*
pass_
;
};
public:
public:
ResidualConnectionMKLDNNFusePass
();
ResidualConnectionMKLDNNFusePass
();
virtual
~
ResidualConnectionMKLDNNFusePass
()
{}
virtual
~
ResidualConnectionMKLDNNFusePass
()
{}
protected:
protected:
void
ApplyImpl
(
graph_ptr
graph
)
const
;
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
;
static
bool
HasFusedActivation
(
Node
*
conv_node
)
{
static
bool
HasFusedActivation
(
Node
*
conv_node
)
{
return
!
(
conv_node
->
Op
()
return
!
(
conv_node
->
Op
()
->
GetAttrIfExists
<
std
::
string
>
(
"fuse_activation"
)
->
GetAttrIfExists
<
std
::
string
>
(
"fuse_activation"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录