Commit 42240893
Authored Nov 08, 2018 by Tomasz Patejko

MKLDNN residual connections fuse pass: Maybe removed and boost::optional used where it makes sense

Parent: 86fd3b32

Showing 2 changed files with 81 additions and 88 deletions (+81 −88)

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc  +70 −55
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h   +11 −33
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
@@ -99,7 +99,7 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
   return false;
 }
 
-std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
+boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
   auto bias_input_names = op.Op()->Inputs();
   auto bias_it = bias_input_names.find(bias_name);
@@ -113,11 +113,11 @@ std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
         [&bias_names](Node* n) -> bool { return n->Name() == bias_names[0]; });
 
-    return std::make_pair(has_bias, *bias_names_it);
+    return *bias_names_it;
   }
 }
 
-  return std::make_pair(false, nullptr);
+  return boost::none;
 }
 
 ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
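The two hunks above change HasBias from returning std::pair<bool, Node*> to boost::optional<Node*>, with boost::none standing in for the old {false, nullptr}. Below is a minimal standalone sketch of that convention, not the Paddle code itself; the Item struct and find_item function are hypothetical stand-ins for Node and HasBias.

// Sketch: returning boost::optional<T*> instead of std::pair<bool, T*>.
#include <boost/optional.hpp>
#include <iostream>
#include <string>
#include <vector>

struct Item {
  std::string name;
};

// Absence is expressed as boost::none rather than {false, nullptr}.
boost::optional<Item*> find_item(std::vector<Item>& items,
                                 const std::string& name) {
  for (auto& item : items) {
    if (item.name == name) return &item;  // engaged optional
  }
  return boost::none;  // disengaged optional
}

int main() {
  std::vector<Item> items{{"bias"}, {"weights"}};
  auto found = find_item(items, "bias");
  if (found) {                            // same test the pass does on conv_bias
    std::cout << (*found)->name << "\n";  // dereference to reach the pointer
  }
  return 0;
}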
@@ -125,7 +125,8 @@ ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
     const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
         get_node_from_elementwise_add_op,
     const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
-    : get_node_from_conv_op{get_node_from_conv_op},
+    : fusion_stats{std::make_shared<int>(0)},
+      get_node_from_conv_op{get_node_from_conv_op},
       get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
       can_fuse_func{can_fuse_func} {}
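The constructor change above adds a fusion_stats member initialized with std::make_shared<int>(0), and a later hunk increments it with (*fusion_stats)++ and reads it back through fuse_handler.get_stats(). The shared_ptr suggests the handler may be copied (for example by the pattern detector) while all copies must still bump one counter. A minimal sketch of that pattern, with a hypothetical CountingHandler in place of FuseHandler:

// Sketch: a copyable callable whose copies share one counter via shared_ptr.
#include <iostream>
#include <memory>

struct CountingHandler {
  CountingHandler() : fusion_stats{std::make_shared<int>(0)} {}

  void operator()() { (*fusion_stats)++; }  // mirrors (*fusion_stats)++ in the pass

  int get_stats() const { return *fusion_stats; }

  std::shared_ptr<int> fusion_stats;  // shared across copies of the handler
};

int main() {
  CountingHandler handler;
  CountingHandler copy = handler;  // the copy a detector might take by value
  copy();                          // increments the shared counter
  copy();
  std::cout << handler.get_stats() << "\n";  // prints 2: both see the same count
  return 0;
}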
@@ -157,13 +158,10 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
   op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
   op_desc.SetOutput("Output", {conv_output->Name()});
 
-  bool has_bias;
-  Node* conv_bias;
-
-  std::tie(has_bias, conv_bias) = HasBias(*conv_op, "Bias");
+  auto conv_bias = HasBias(*conv_op, "Bias");
 
-  if (has_bias) {
-    op_desc.SetInput("Bias", {conv_bias->Name()});
+  if (conv_bias) {
+    op_desc.SetInput("Bias", {(*conv_bias)->Name()});
   }
 
   for (const auto& attr : conv_op->Op()->GetAttrMap()) {
@@ -179,40 +177,48 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
   IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
   IR_NODE_LINK_TO(fused_conv_op, conv_output);
 
-  if (has_bias) {
-    IR_NODE_LINK_TO(conv_bias, fused_conv_op);
+  if (conv_bias) {
+    IR_NODE_LINK_TO((*conv_bias), fused_conv_op);
   }
 
   CorrectGraphEdges(graph, elementwise_add_out, conv_output);
   GraphSafeRemoveNodes(graph,
                        {elementwise_add_out, conv_op, elementwise_add_op});
+  (*fusion_stats)++;
 }
 
+std::tuple<Node*, Node*, Node*, Node*>
+ResidualConnectionMKLDNNFusePass::GetNodesFromConv(
+    const patterns::Conv& conv_pattern,
+    const GraphPatternDetector::subgraph_t& subgraph) const {
+  GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
+  GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
+  GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
+  GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
+
+  return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
+}
+
-graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
-    const std::string& name_scope_, graph_ptr graph) const {
+GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
+    const std::string& name_scope,
+    const GraphWithStats& graph_with_stats) const {
+  ir::Graph* graph;
+  int stats;
+  std::tie(graph, stats) = graph_with_stats;
+
   GraphPatternDetector gpd;
   auto pattern = gpd.mutable_pattern();
 
-  patterns::Conv conv_pattern{pattern, name_scope_};
+  patterns::Conv conv_pattern{pattern, name_scope};
   auto conv_output = conv_pattern();
 
-  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
+  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
   elementwise_add_pattern(
       conv_output,
       pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
   conv_output->AsIntermediate();
 
-  auto get_node_from_conv = [&conv_pattern](
-      const GraphPatternDetector::subgraph_t& subgraph)
-      -> std::tuple<Node*, Node*, Node*, Node*> {
-    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
-
-    return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
-  };
-
   auto get_node_from_elementwise_add = [&elementwise_add_pattern](
       const GraphPatternDetector::subgraph_t& subgraph)
       -> std::tuple<Node*, Node*, Node*> {
@@ -227,43 +233,29 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
         elementwise_add_out);
   };
 
-  auto can_fuse = [this](Node* op1, Node* op2) -> bool {
-    return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
-  };
-
-  auto fuse_handler =
-      FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
-
-  gpd(graph.get(), fuse_handler);
-
-  return graph;
+  return ExecuteHandlerOnGraph(
+      &gpd, graph_with_stats,
+      [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
+        return GetNodesFromConv(conv_pattern, subgraph);
+      },
+      get_node_from_elementwise_add);
 }
 
-graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
-    const std::string& name_scope_, graph_ptr graph) const {
+GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
+    const std::string& name_scope,
+    const GraphWithStats& graph_with_stats) const {
   GraphPatternDetector gpd;
   auto pattern = gpd.mutable_pattern();
 
-  patterns::Conv conv_pattern{pattern, name_scope_};
+  patterns::Conv conv_pattern{pattern, name_scope};
   auto conv_output = conv_pattern();
 
-  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
+  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
   elementwise_add_pattern(
       pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
       conv_output);
   conv_output->AsIntermediate();
 
-  auto get_node_from_conv = [&conv_pattern](
-      const GraphPatternDetector::subgraph_t& subgraph)
-      -> std::tuple<Node*, Node*, Node*, Node*> {
-    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
-
-    return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
-  };
-
   auto get_node_from_elementwise_add = [&elementwise_add_pattern](
      const GraphPatternDetector::subgraph_t& subgraph)
      -> std::tuple<Node*, Node*, Node*> {
@@ -278,6 +270,24 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
         elementwise_add_out);
   };
 
+  return ExecuteHandlerOnGraph(
+      &gpd, graph_with_stats,
+      [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
+        return GetNodesFromConv(conv_pattern, subgraph);
+      },
+      get_node_from_elementwise_add);
+}
+
+GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph(
+    GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
+    const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv,
+    const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
+        get_node_from_elementwise_add) const {
+  ir::Graph* graph;
+  int stats;
+  std::tie(graph, stats) = graph_with_stats;
+
   auto can_fuse = [this](Node* op1, Node* op2) -> bool {
     return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
   };
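The hunk above pulls the detector-driving code that FuseConvAsX and FuseConvAsY used to duplicate into ExecuteHandlerOnGraph, which receives the node-extraction step as a callable. A minimal sketch of that shape, using hypothetical Subgraph, NodeSet, and run_detector stand-ins rather than Paddle's GraphPatternDetector API:

// Sketch: factor a shared driver and inject the per-pass step as a callable.
#include <functional>
#include <iostream>
#include <string>
#include <utility>

struct Subgraph { std::string tag; };
using NodeSet = std::pair<std::string, std::string>;
using GetNodesFunc = std::function<NodeSet(const Subgraph&)>;

// The two callers differ only in how nodes are pulled from the match,
// so that step is passed in instead of being duplicated.
int run_detector(const Subgraph& match, const GetNodesFunc& get_nodes) {
  NodeSet nodes = get_nodes(match);
  std::cout << "fusing " << nodes.first << " + " << nodes.second << "\n";
  return 1;  // one fusion performed
}

int main() {
  Subgraph as_x{"conv-as-x"};
  Subgraph as_y{"conv-as-y"};

  // Each caller supplies its own extraction lambda, like the
  // [this, &conv_pattern](...) { return GetNodesFromConv(...); } lambdas above.
  int total = 0;
  total += run_detector(as_x, [](const Subgraph& s) {
    return NodeSet{"conv_op(" + s.tag + ")", "elementwise_add_y"};
  });
  total += run_detector(as_y, [](const Subgraph& s) {
    return NodeSet{"conv_op(" + s.tag + ")", "elementwise_add_x"};
  });
  std::cout << "total fusions: " << total << "\n";
  return 0;
}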
@@ -285,15 +295,20 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
   auto fuse_handler =
       FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
 
-  gpd(graph.get(), fuse_handler);
+  (*gpd)(graph, fuse_handler);
 
-  return graph;
+  return std::make_pair(graph, stats + fuse_handler.get_stats());
 }
 
 graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   FusePassBase::Init(name_scope_, graph.get());
 
-  return FuseConvAsY(name_scope_, FuseConvAsX(name_scope_, std::move(graph)));
+  auto fused_graph_with_stats = FuseConvAsY(
+      name_scope_, FuseConvAsX(name_scope_, std::make_pair(graph.get(), 0)));
+
+  std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
+  AddStatis(fused_graph_with_stats.second);
+  return graph;
 }
 
 }  // namespace ir
 }  // namespace framework
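ApplyImpl above now seeds the pipeline with std::make_pair(graph.get(), 0) and lets each fuse step return the graph together with an updated fusion count, read back from .second at the end. A minimal sketch of threading such a GraphWithStats pair through chained passes, with hypothetical fuse_as_x/fuse_as_y stages standing in for the real ones:

// Sketch: chain (graph, count) through stages and accumulate the count.
#include <iostream>
#include <utility>

struct Graph {};  // stand-in for ir::Graph

using GraphWithStats = std::pair<Graph*, int>;

// Each stage reports how many rewrites it performed by adding to the count.
GraphWithStats fuse_as_x(const GraphWithStats& in) {
  return std::make_pair(in.first, in.second + 3);  // pretend 3 fusions
}

GraphWithStats fuse_as_y(const GraphWithStats& in) {
  return std::make_pair(in.first, in.second + 2);  // pretend 2 more fusions
}

int main() {
  Graph graph;
  // Same chaining shape as ApplyImpl: start at 0, pipe through both passes.
  auto fused = fuse_as_y(fuse_as_x(std::make_pair(&graph, 0)));
  std::cout << "Fused graph " << fused.second << "\n";  // total: 5
  return 0;
}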
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h
@@ -27,43 +27,12 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-// poor replacement for C++17 std::optional and Boost.Optional
-struct InPlace {};
-InPlace in_place;
-
-template <typename T>
-class Maybe {
- private:
-  typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
-  bool is_initialized{false};
-
- public:
-  template <typename... Args>
-  explicit Maybe(InPlace, Args&&... args) {
-    new (&data) T(std::forward<Args>(args)...);
-    is_initialized = true;
-  }
-
-  Maybe() {}
-
-  operator bool() { return is_initialized; }
-
-  T& value() { return *reinterpret_cast<T*>(&data); }
-
-  ~Maybe() { reinterpret_cast<T*>(&data)->~T(); }
-};
-
-template <typename T, typename... Args>
-Maybe<T> MakeMaybe(Args&&... args) {
-  return Maybe<T>(in_place, std::forward<Args>(args)...);
-}
-
 using graph_ptr = std::unique_ptr<ir::Graph>;
-using GraphWithStats = std::pair<ir::Graph*, Maybe<int>>;
+using GraphWithStats = std::pair<ir::Graph*, int>;
 
 void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
 bool IsReachable(ir::Graph* graph, Node* from, Node* to);
-std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name);
+boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name);
 
 class ResidualConnectionMKLDNNFusePass : public FusePassBase {
  private:
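The header hunk above deletes the hand-rolled Maybe<T>/MakeMaybe helpers in favor of boost::optional, which already covers in-place construction, the boolean test, and value access. A minimal sketch of those boost::optional features, with a hypothetical Widget type rather than anything from the pass:

// Sketch: boost::optional stands in for the removed Maybe/MakeMaybe helpers.
#include <boost/optional.hpp>
#include <iostream>
#include <string>

struct Widget {
  Widget(std::string n, int v) : name(std::move(n)), value(v) {}
  std::string name;
  int value;
};

int main() {
  boost::optional<Widget> maybe_widget;  // disengaged, like a default Maybe()
  std::cout << std::boolalpha << static_cast<bool>(maybe_widget) << "\n";  // false

  maybe_widget.emplace("conv", 42);  // in-place construction, like Maybe(in_place, ...)
  if (maybe_widget) {
    std::cout << maybe_widget->name << " " << maybe_widget->value << "\n";
  }

  maybe_widget = boost::none;  // reset; the destructor runs only if engaged
  return 0;
}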
@@ -79,6 +48,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
   using ElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
   using CanFuseFunc = std::function<bool(Node*, Node*)>;
 
+  std::tuple<Node*, Node*, Node*, Node*> GetNodesFromConv(
+      const patterns::Conv& conv_pattern,
+      const GraphPatternDetector::subgraph_t& subgraph) const;
+
+  GraphWithStats ExecuteHandlerOnGraph(
+      GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
+      const ConvFunc& get_node_from_conv,
+      const ElementwiseAddFunc& get_node_from_elementwise_add) const;
+
   struct FuseHandler {
     FuseHandler(const ConvFunc& get_node_from_conv_op,
                 const ElementwiseAddFunc& get_node_from_elementwise_add_op,