Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
27573ece
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
27573ece
编写于
9月 18, 2018
作者:
T
Tomasz Patejko
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN conv + elementwise_add fusion: trailing spaces removed
上级
7f5c8a95
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
117 addition
and
89 deletion
+117
-89
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
...uid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
+74
-63
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
...mework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+43
-26
未找到文件。
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
27573ece
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace
paddle
{
...
...
@@ -8,15 +24,14 @@ namespace patterns {
struct
Pattern
:
public
PatternBase
{
Pattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
{
pattern
,
name_scope
,
""
}
{
}
private:
:
PatternBase
{
pattern
,
name_scope
,
""
}
{}
private:
std
::
string
name_scope
()
{
return
name_scope_
;
}
std
::
string
repr
()
{
return
repr_
;
}
std
::
string
repr
()
{
return
repr_
;
}
size_t
id
()
{
return
id_
;
}
PDPattern
*
node_pattern
()
{
return
pattern
;
}
public:
std
::
string
node_name
(
std
::
string
op_name
)
{
return
PDNodeName
(
name_scope
(),
repr
(),
id
(),
op_name
);
...
...
@@ -37,22 +52,18 @@ struct Conv {
std
::
string
filter_name
()
{
return
"Filter"
;
}
std
::
string
output_name
()
{
return
"Output"
;
}
std
::
function
<
PDNode
*
()
>
operator
()(
std
::
shared_ptr
<
Pattern
>
pattern
)
{
std
::
function
<
PDNode
*
()
>
operator
()(
std
::
shared_ptr
<
Pattern
>
pattern
)
{
return
[
&
]()
->
PDNode
*
{
auto
conv_op
=
pattern
->
new_node
(
op_name
())
->
assert_is_op
(
"conv2d"
);
auto
conv_op
=
pattern
->
new_node
(
op_name
())
->
assert_is_op
(
"conv2d"
);
auto
input_var
=
pattern
->
new_node
(
input_name
())
->
assert_is_op_input
(
op_name
(),
input_name
());
->
assert_is_op_input
(
op_name
(),
input_name
());
auto
filter_var
=
pattern
->
new_node
(
filter_name
())
->
assert_is_op_input
(
op_name
(),
filter_name
());
->
assert_is_op_input
(
op_name
(),
filter_name
());
auto
output_var
=
pattern
->
new_node
(
output_name
())
->
assert_is_op_output
(
op_name
(),
output_name
());
->
assert_is_op_output
(
op_name
(),
output_name
());
conv_op
->
LinksFrom
({
input_var
,
filter_var
});
conv_op
->
LinksTo
({
output_var
});
...
...
@@ -68,22 +79,19 @@ struct ElementwiseAdd {
std
::
string
y_name
()
{
return
"Y"
;
}
std
::
string
out_name
()
{
return
"Out"
;
}
std
::
function
<
PDNode
*
(
PDNode
*
)
>
operator
()(
std
::
shared_ptr
<
Pattern
>
pattern
)
{
std
::
function
<
PDNode
*
(
PDNode
*
)
>
operator
()(
std
::
shared_ptr
<
Pattern
>
pattern
)
{
return
[
&
](
PDNode
*
conv_output
)
->
PDNode
*
{
auto
elementwise_add_op
=
pattern
->
new_node
(
op_name
())
->
assert_is_op
(
"elementwise_add"
);
auto
elementwise_add_op
=
pattern
->
new_node
(
op_name
())
->
assert_is_op
(
"elementwise_add"
);
auto
x_var
=
pattern
->
new_node
(
x_name
())
->
assert_is_op_input
(
op_name
(),
x_name
());
auto
x_var
=
pattern
->
new_node
(
x_name
())
->
assert_is_op_input
(
op_name
(),
x_name
());
conv_output
->
assert_is_op_input
(
op_name
(),
y_name
());
conv_output
->
assert_is_op_input
(
op_name
(),
y_name
());
auto
out_var
=
pattern
->
new_node
(
out_name
())
->
AsOutput
()
->
assert_is_op_output
(
op_name
(),
out_name
());
->
AsOutput
()
->
assert_is_op_output
(
op_name
(),
out_name
());
elementwise_add_op
->
LinksFrom
({
x_var
,
conv_output
});
elementwise_add_op
->
LinksTo
({
out_var
});
...
...
@@ -94,13 +102,13 @@ struct ElementwiseAdd {
};
Node
*
GetNodeFromSubgraph
(
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
std
::
shared_ptr
<
patterns
::
Pattern
>
pattern
,
const
std
::
string
&
op_name
)
{
std
::
shared_ptr
<
patterns
::
Pattern
>
pattern
,
const
std
::
string
&
op_name
)
{
PADDLE_ENFORCE
(
subgraph
.
count
(
pattern
->
retrieve_node
(
op_name
)),
"Node not found for PDNode %s"
,
pattern
->
node_name
(
op_name
));
Node
*
var
=
subgraph
.
at
(
pattern
->
retrieve_node
(
op_name
));
PADDLE_ENFORCE
(
var
,
"node %s not exists in the sub-graph"
);
return
var
;
}
...
...
@@ -109,10 +117,9 @@ void LinkNodes(Node* from, Node* to) {
to
->
inputs
.
push_back
(
from
);
}
template
<
typename
IT
,
typename
FindFunc
,
typename
ReplaceFunc
>
template
<
typename
IT
,
typename
FindFunc
,
typename
ReplaceFunc
>
void
ReplaceAllOccurances
(
IT
s
,
IT
e
,
FindFunc
f
,
ReplaceFunc
r
)
{
if
(
s
==
e
)
return
;
if
(
s
==
e
)
return
;
auto
it
=
std
::
find_if
(
s
,
e
,
f
);
...
...
@@ -126,8 +133,7 @@ void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
void
CorrectGraphEdges
(
Graph
*
graph
,
Node
*
from
,
Node
*
to
)
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
*
graph
))
{
auto
same
=
std
::
find_if
(
std
::
begin
(
node
.
inputs
),
std
::
end
(
node
.
inputs
),
auto
same
=
std
::
find_if
(
std
::
begin
(
node
.
inputs
),
std
::
end
(
node
.
inputs
),
[
from
](
Node
*
n
)
{
return
n
==
from
;
});
if
(
same
!=
std
::
end
(
node
.
inputs
))
{
...
...
@@ -137,17 +143,19 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
using
input_type
=
VariableNameMap
::
value_type
;
ReplaceAllOccurances
(
std
::
begin
(
inputs
),
std
::
end
(
inputs
),
[
from
](
const
input_type
&
i
)
->
bool
{
auto
params
=
i
.
second
;
auto
pi
=
std
::
find_if
(
std
::
begin
(
params
),
std
::
end
(
params
),
std
::
bind
(
std
::
equal_to
<
std
::
string
>
(),
from
->
Name
(),
std
::
placeholders
::
_1
));
return
pi
!=
std
::
end
(
params
);
},
[
to
,
&
node
](
const
input_type
&
i
)
{
node
.
Op
()
->
SetInput
(
i
.
first
,
{
to
->
Name
()});
});
ReplaceAllOccurances
(
std
::
begin
(
inputs
),
std
::
end
(
inputs
),
[
from
](
const
input_type
&
i
)
->
bool
{
auto
params
=
i
.
second
;
auto
pi
=
std
::
find_if
(
std
::
begin
(
params
),
std
::
end
(
params
),
std
::
bind
(
std
::
equal_to
<
std
::
string
>
(),
from
->
Name
(),
std
::
placeholders
::
_1
));
return
pi
!=
std
::
end
(
params
);
},
[
to
,
&
node
](
const
input_type
&
i
)
{
node
.
Op
()
->
SetInput
(
i
.
first
,
{
to
->
Name
()});
});
}
}
}
...
...
@@ -169,7 +177,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output
->
AsIntermediate
();
auto
fuse_conv
=
[](
Graph
*
g
,
Node
*
conv_input
,
Node
*
conv_filter
,
Node
*
conv_output
,
Node
*
elementwise_add_x
)
{
auto
fuse_conv
=
[](
Graph
*
g
,
Node
*
conv_input
,
Node
*
conv_filter
,
Node
*
conv_output
,
Node
*
elementwise_add_x
)
{
OpDesc
op_desc
;
op_desc
.
SetType
(
"conv2d"
);
...
...
@@ -189,22 +198,23 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
patterns
::
LinkNodes
(
fused_conv_op
,
conv_output
);
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
auto
conv_op
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
op_name
());
conv_pattern
.
op_name
());
auto
conv_input
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
input_name
());
auto
conv_filter
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
filter_name
());
auto
conv_output
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
output_name
());
auto
elementwise_add_op
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
op_name
());
auto
elementwise_add_x
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
x_name
());
auto
elementwise_add_out
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
out_name
());
conv_pattern
.
input_name
());
auto
conv_filter
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
filter_name
());
auto
conv_output
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
conv_pattern
.
output_name
());
auto
elementwise_add_op
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
op_name
());
auto
elementwise_add_x
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
x_name
());
auto
elementwise_add_out
=
patterns
::
GetNodeFromSubgraph
(
subgraph
,
pattern_ptr
,
elementwise_add_pattern
.
out_name
());
fuse_conv
(
g
,
conv_input
,
conv_filter
,
conv_output
,
elementwise_add_x
);
patterns
::
CorrectGraphEdges
(
g
,
elementwise_add_out
,
conv_output
);
...
...
@@ -219,4 +229,5 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
conv_elementwise_add_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
ConvElementwiseAddMKLDNNFusePass
);
REGISTER_PASS
(
conv_elementwise_add_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
ConvElementwiseAddMKLDNNFusePass
);
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
浏览文件 @
27573ece
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -33,10 +47,11 @@ void SetOp(ProgramDesc* prog, const std::string& type,
}
struct
IsReachable
{
using
func
=
std
::
function
<
bool
(
const
std
::
string
&
,
const
std
::
string
&
)
>
;
using
func
=
std
::
function
<
bool
(
const
std
::
string
&
,
const
std
::
string
&
)
>
;
auto
operator
()(
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
)
->
func
{
auto
find_node
=
[](
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
,
const
std
::
string
&
name
)
->
Node
*
{
auto
find_node
=
[](
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
,
const
std
::
string
&
name
)
->
Node
*
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
*
graph
))
{
if
(
name
==
node
.
Name
())
{
return
&
node
;
...
...
@@ -47,8 +62,7 @@ struct IsReachable {
};
return
[
&
](
std
::
string
from
,
const
std
::
string
to
)
->
bool
{
if
(
from
==
to
)
return
true
;
if
(
from
==
to
)
return
true
;
std
::
map
<
std
::
string
,
bool
>
visited
;
...
...
@@ -61,16 +75,14 @@ struct IsReachable {
std
::
list
<
std
::
string
>
queue
;
queue
.
push_back
(
from
);
while
(
!
queue
.
empty
())
{
while
(
!
queue
.
empty
())
{
auto
cur
=
find_node
(
graph
,
queue
.
front
());
queue
.
pop_front
();
if
(
cur
==
nullptr
)
return
false
;
if
(
cur
==
nullptr
)
return
false
;
for
(
auto
n
:
cur
->
outputs
)
{
if
(
n
->
Name
()
==
to
)
return
true
;
if
(
n
->
Name
()
==
to
)
return
true
;
if
(
!
visited
[
n
->
Name
()])
{
visited
[
n
->
Name
()]
=
true
;
...
...
@@ -87,14 +99,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto
build_program_desc
=
[
&
]()
->
ProgramDesc
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"weights"
,
"c"
,
"d"
,
"e"
}))
{
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"weights"
,
"c"
,
"d"
,
"e"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
if
(
v
==
"weights"
)
{
var
->
SetPersistable
(
true
);
}
}
SetOp
(
&
prog
,
"conv2d"
,
{
"a"
,
"weights"
},
{
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{
"c"
,
"b"
},
{
"d"
});
SetOp
(
&
prog
,
"relu"
,
{
"d"
},
{
"e"
});
...
...
@@ -109,14 +121,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
// Assert conv_relu op in newly generated graph
int
conv_count
=
0
;
int
elementwise_add_count
=
0
;
...
...
@@ -136,15 +150,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionElementwiseAdd
)
{
auto
build_program_desc
=
[
&
]()
->
ProgramDesc
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"weights"
}))
{
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"weights"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
if
(
v
==
"weights"
||
v
==
"bias"
)
{
var
->
SetPersistable
(
true
);
}
}
SetOp
(
&
prog
,
"conv2d"
,
{
"a"
,
"weights"
},
{
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{
"c"
,
"b"
},
{
"d"
});
...
...
@@ -157,14 +170,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
IsReachable
is_reachable
;
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"d"
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_FALSE
(
is_reachable
(
graph
)(
"a"
,
"d"
));
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
// Assert conv_relu op in newly generated graph
int
conv_count
=
0
;
int
elementwise_add_count
=
0
;
...
...
@@ -185,14 +200,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
auto
build_program_desc
=
[
&
]()
->
ProgramDesc
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"weights"
,
"c"
,
"d"
,
"e"
,
"f"
}))
{
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"weights"
,
"c"
,
"d"
,
"e"
,
"f"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
if
(
v
.
find
(
"weights"
))
{
var
->
SetPersistable
(
true
);
}
}
SetOp
(
&
prog
,
"sigmoid"
,
{
"a"
},
{
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
{
"b"
,
"weights"
},
{
"c"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{
"d"
,
"c"
},
{
"e"
});
...
...
@@ -208,14 +223,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"f"
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"f"
));
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
// Assert conv_relu op in newly generated graph
int
conv_count
=
0
;
int
elementwise_add_count
=
0
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录