Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d0000082
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d0000082
编写于
9月 25, 2018
作者:
T
Tao Luo
提交者:
GitHub
9月 25, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13552 from sfraczek/sfraczek/conv-relu-update
little update to conv relu fuse pass (fix)
上级
cc20867d
e5d1bd1e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
17 addition
and
45 deletion
+17
-45
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc
+8
-26
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
...e/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
+7
-10
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+1
-7
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+1
-2
未找到文件。
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc
浏览文件 @
d0000082
...
...
@@ -26,8 +26,6 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
PADDLE_ENFORCE
(
graph
.
get
());
FusePassBase
::
Init
(
"conv_relu_mkldnn_fuse"
,
graph
.
get
());
std
::
unordered_set
<
Node
*>
nodes2delete
;
GraphPatternDetector
gpd
;
auto
*
conv_input
=
gpd
.
mutable_pattern
()
->
NewNode
(
"conv_relu_mkldnn_fuse/conv_input"
)
...
...
@@ -42,36 +40,20 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle ConvReLU fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
conv_weight
,
conv_weight
,
conv_relu_pattern
);
// Filter
GET_IR_NODE_FROM_SUBGRAPH
(
conv_bias
,
conv_bias
,
conv_relu_pattern
);
// Bias
GET_IR_NODE_FROM_SUBGRAPH
(
conv_out
,
conv_out
,
conv_relu_pattern
);
// tmp
conv_relu_pattern
);
// Filter
GET_IR_NODE_FROM_SUBGRAPH
(
conv_out
,
conv_out
,
conv_relu_pattern
);
// tmp
GET_IR_NODE_FROM_SUBGRAPH
(
conv
,
conv
,
conv_relu_pattern
);
// CONV op
GET_IR_NODE_FROM_SUBGRAPH
(
relu_out
,
relu_out
,
conv_relu_pattern
);
// Out
GET_IR_NODE_FROM_SUBGRAPH
(
relu
,
relu
,
conv_relu_pattern
);
// ReLU op
// Create an ConvReLU Node.
OpDesc
desc
;
std
::
string
conv_relu_i_in
=
subgraph
.
at
(
conv_input
)
->
Name
();
std
::
string
conv_relu_w_in
=
conv_weight
->
Name
();
std
::
string
conv_relu_b_in
=
conv_bias
->
Name
();
std
::
string
conv_relu_out
=
relu_out
->
Name
();
desc
.
SetInput
(
"Input"
,
std
::
vector
<
std
::
string
>
({
conv_relu_i_in
}));
desc
.
SetInput
(
"Filter"
,
std
::
vector
<
std
::
string
>
({
conv_relu_w_in
}));
desc
.
SetInput
(
"Bias"
,
std
::
vector
<
std
::
string
>
({
conv_relu_b_in
}));
desc
.
SetOutput
(
"Output"
,
std
::
vector
<
std
::
string
>
({
conv_relu_out
}));
desc
.
SetType
(
"conv2d"
);
for
(
auto
&
attr
:
conv
->
Op
()
->
GetAttrMap
())
{
desc
.
SetAttr
(
attr
.
first
,
attr
.
second
);
}
desc
.
SetAttr
(
"fuse_relu"
,
true
);
auto
conv_relu_node
=
g
->
CreateOpNode
(
&
desc
);
// OpDesc will be copied.
GraphSafeRemoveNodes
(
graph
.
get
(),
{
conv
,
relu
,
conv_out
});
// Transform Conv node into ConvReLU node.
OpDesc
*
desc
=
conv
->
Op
();
desc
->
SetOutput
(
"Output"
,
std
::
vector
<
std
::
string
>
({
relu_out
->
Name
()}));
desc
->
SetAttr
(
"fuse_relu"
,
true
);
GraphSafeRemoveNodes
(
graph
.
get
(),
{
relu
,
conv_out
});
PADDLE_ENFORCE
(
subgraph
.
count
(
conv_input
));
IR_NODE_LINK_TO
(
subgraph
.
at
(
conv_input
),
conv_relu_node
);
IR_NODE_LINK_TO
(
conv_weight
,
conv_relu_node
);
IR_NODE_LINK_TO
(
conv_bias
,
conv_relu_node
);
IR_NODE_LINK_TO
(
conv_relu_node
,
relu_out
);
IR_NODE_LINK_TO
(
conv
,
relu_out
);
found_conv_relu_count
++
;
};
...
...
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
浏览文件 @
d0000082
...
...
@@ -85,16 +85,13 @@ TEST(ConvReLUFusePass, basic) {
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"conv2d"
)
{
if
(
node
->
Op
()
->
HasAttr
(
"use_mkldnn"
))
{
bool
use_mkldnn
=
boost
::
get
<
bool
>
(
node
->
Op
()
->
GetAttr
(
"use_mkldnn"
));
if
(
use_mkldnn
)
{
if
(
node
->
Op
()
->
HasAttr
(
"fuse_relu"
))
{
bool
fuse_relu
=
boost
::
get
<
bool
>
(
node
->
Op
()
->
GetAttr
(
"fuse_relu"
));
if
(
fuse_relu
)
{
++
conv_relu_count
;
}
}
}
auto
*
op
=
node
->
Op
();
ASSERT_TRUE
(
op
->
HasAttr
(
"use_mkldnn"
));
EXPECT_TRUE
(
boost
::
get
<
bool
>
(
op
->
GetAttr
(
"use_mkldnn"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"fuse_relu"
));
bool
fuse_relu
=
boost
::
get
<
bool
>
(
op
->
GetAttr
(
"fuse_relu"
));
if
(
fuse_relu
)
{
++
conv_relu_count
;
}
}
}
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
d0000082
...
...
@@ -638,11 +638,6 @@ PDNode *patterns::ConvReLU::operator()(
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"conv2d"
,
"Filter"
);
// Bias
auto
*
conv_bias_var
=
pattern
->
NewNode
(
conv_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"conv2d"
,
"Bias"
);
// intermediate variable, will be removed in the IR after fuse.
auto
*
conv_out_var
=
pattern
->
NewNode
(
conv_out_repr
())
->
AsIntermediate
()
...
...
@@ -653,8 +648,7 @@ PDNode *patterns::ConvReLU::operator()(
->
AsOutput
()
->
assert_is_op_output
(
"relu"
);
conv_op
->
LinksFrom
({
conv_input
,
conv_weight_var
,
conv_bias_var
})
.
LinksTo
({
conv_out_var
});
conv_op
->
LinksFrom
({
conv_input
,
conv_weight_var
}).
LinksTo
({
conv_out_var
});
relu_op
->
LinksFrom
({
conv_out_var
}).
LinksTo
({
relu_out_var
});
return
relu_out_var
;
}
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
d0000082
...
...
@@ -379,7 +379,7 @@ struct PatternBase {
// op: conv + relu
// named nodes:
// conv_input, conv_weight,
// conv_
bias, conv_
out, conv,
// conv_out, conv,
// relu_out, relu
struct
ConvReLU
:
public
PatternBase
{
ConvReLU
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
...
...
@@ -392,7 +392,6 @@ struct ConvReLU : public PatternBase {
PATTERN_DECL_NODE
(
relu
);
// declare variable node's name
PATTERN_DECL_NODE
(
conv_weight
);
PATTERN_DECL_NODE
(
conv_bias
);
PATTERN_DECL_NODE
(
conv_out
);
PATTERN_DECL_NODE
(
relu_out
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录