Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d7509d63
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d7509d63
编写于
10月 12, 2018
作者:
M
Michal Gallus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Conv+Bias: Support non-null bias
test=develop
上级
91e8fbac
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
82 addition
and
134 deletion
+82
-134
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+0
-1
paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
+79
-27
paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
+2
-0
paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass_tester.cc
...e/fluid/framework/ir/conv_bias_mkldnn_fuse_pass_tester.cc
+0
-106
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+1
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
d7509d63
...
@@ -56,6 +56,5 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
...
@@ -56,6 +56,5 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
cc_test
(
test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector
)
cc_test
(
test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector
)
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto
)
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto
)
if
(
WITH_MKLDNN
)
if
(
WITH_MKLDNN
)
cc_test
(
test_conv_bias_mkldnn_fuse_pass SRCS conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass
)
cc_test
(
test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass
)
endif
()
endif
()
paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
浏览文件 @
d7509d63
...
@@ -11,24 +11,48 @@
...
@@ -11,24 +11,48 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include <functional>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
template
<
typename
BinaryOperation
>
LoDTensor
tensor_apply_eltwise
(
const
LoDTensor
&
vec_a
,
const
LoDTensor
&
vec_b
,
BinaryOperation
f
)
{
PADDLE_ENFORCE_EQ
(
vec_a
.
dims
(),
vec_b
.
dims
());
LoDTensor
vec_y
;
vec_y
.
Resize
(
vec_a
.
dims
());
const
float
*
a
=
vec_a
.
data
<
float
>
();
const
float
*
b
=
vec_b
.
data
<
float
>
();
float
*
y
=
vec_y
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
vec_a
.
numel
();
i
++
)
{
y
[
i
]
=
f
(
a
[
i
],
b
[
i
]);
}
return
vec_y
;
}
std
::
unique_ptr
<
ir
::
Graph
>
ConvBiasFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ConvBiasFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
PADDLE_ENFORCE
(
graph
.
get
());
PADDLE_ENFORCE
(
graph
.
get
());
FusePassBase
::
Init
(
"conv_bias_mkldnn_fuse"
,
graph
.
get
());
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE
(
scope
);
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
conv_input
=
gpd
.
mutable_pattern
()
auto
*
conv_input
=
->
NewNode
(
"conv_bias_mkldnn_fuse/conv_input"
)
gpd
.
mutable_pattern
()
->
NewNode
(
patterns
::
PDNodeName
(
name_scope_
,
"conv_input"
))
->
AsInput
()
->
AsInput
()
->
assert_is_op_input
(
"conv2d"
,
"Input"
);
->
assert_is_op_input
(
"conv2d"
,
"Input"
);
patterns
::
ConvBias
conv_bias_pattern
(
gpd
.
mutable_pattern
(),
patterns
::
ConvBias
conv_bias_pattern
(
gpd
.
mutable_pattern
(),
name_scope_
);
"conv_bias_mkldnn_fuse"
);
conv_bias_pattern
(
conv_input
);
conv_bias_pattern
(
conv_input
);
int
found_conv_bias_count
=
0
;
int
found_conv_bias_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
...
@@ -44,27 +68,55 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
...
@@ -44,27 +68,55 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
GET_IR_NODE_FROM_SUBGRAPH
(
eltwise_out
,
eltwise_out
,
conv_bias_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltwise_out
,
eltwise_out
,
conv_bias_pattern
);
// elementwise_add op
// elementwise_add op
GET_IR_NODE_FROM_SUBGRAPH
(
eltwise
,
eltwise
,
conv_bias_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltwise
,
eltwise
,
conv_bias_pattern
);
// Create an ConvBias Node.
PADDLE_ENFORCE
(
subgraph
.
count
(
conv_input
));
auto
*
eltwise_bias_tensor
=
scope
->
FindVar
(
eltwise_bias
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
input_names
=
conv
->
Op
()
->
InputNames
();
bool
has_bias
=
std
::
find
(
input_names
.
begin
(),
input_names
.
end
(),
"Bias"
)
!=
input_names
.
end
();
if
(
has_bias
&&
conv
->
Op
()
->
Input
(
"Bias"
).
size
()
>
0
)
{
auto
conv_bias_names
=
conv
->
Op
()
->
Input
(
"Bias"
);
// add eltwise bias to existing conv bias
PADDLE_ENFORCE_EQ
(
conv_bias_names
.
size
(),
1
);
auto
*
conv_bias_var
=
scope
->
FindVar
(
conv_bias_names
[
0
]);
auto
*
conv_bias_tensor
=
conv_bias_var
->
GetMutable
<
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
conv_bias_tensor
->
dims
(),
eltwise_bias_tensor
->
dims
());
*
conv_bias_tensor
=
tensor_apply_eltwise
(
*
conv_bias_tensor
,
*
eltwise_bias_tensor
,
std
::
plus
<
float
>
());
conv
->
Op
()
->
SetOutput
(
"Output"
,
std
::
vector
<
std
::
string
>
({
eltwise_out
->
Name
()}));
GraphSafeRemoveNodes
(
graph
.
get
(),
{
eltwise
,
conv_out
});
IR_NODE_LINK_TO
(
conv
,
eltwise_out
);
}
else
{
// take eltwise bias as conv bias
OpDesc
desc
;
OpDesc
desc
;
std
::
string
conv_bias_i_in
=
subgraph
.
at
(
conv_input
)
->
Name
();
std
::
string
conv_bias_w_in
=
conv_weight
->
Name
();
desc
.
SetInput
(
std
::
string
conv_bias_b_in
=
eltwise_bias
->
Name
();
"Input"
,
std
::
vector
<
std
::
string
>
({
subgraph
.
at
(
conv_input
)
->
Name
()}));
std
::
string
conv_bias_out
=
eltwise_out
->
Name
();
desc
.
SetInput
(
"Filter"
,
std
::
vector
<
std
::
string
>
({
conv_weight
->
Name
()}));
desc
.
SetInput
(
"Input"
,
std
::
vector
<
std
::
string
>
({
conv_bias_i_in
}));
desc
.
SetInput
(
"Bias"
,
std
::
vector
<
std
::
string
>
({
eltwise_bias
->
Name
()}));
desc
.
SetInput
(
"Filter"
,
std
::
vector
<
std
::
string
>
({
conv_bias_w_in
}));
desc
.
SetOutput
(
"Output"
,
std
::
vector
<
std
::
string
>
({
eltwise_out
->
Name
()}));
desc
.
SetInput
(
"Bias"
,
std
::
vector
<
std
::
string
>
({
conv_bias_b_in
}));
desc
.
SetOutput
(
"Output"
,
std
::
vector
<
std
::
string
>
({
conv_bias_out
}));
desc
.
SetType
(
"conv2d"
);
desc
.
SetType
(
"conv2d"
);
for
(
auto
&
attr
:
conv
->
Op
()
->
GetAttrMap
())
{
for
(
auto
&
attr
:
conv
->
Op
()
->
GetAttrMap
())
{
desc
.
SetAttr
(
attr
.
first
,
attr
.
second
);
desc
.
SetAttr
(
attr
.
first
,
attr
.
second
);
}
}
auto
conv_bias_node
=
g
->
CreateOpNode
(
&
desc
);
// OpDesc will be copied.
auto
conv_bias_node
=
g
->
CreateOpNode
(
&
desc
);
GraphSafeRemoveNodes
(
graph
.
get
(),
{
conv
,
eltwise
,
conv_out
});
PADDLE_ENFORCE
(
subgraph
.
count
(
conv_input
));
IR_NODE_LINK_TO
(
subgraph
.
at
(
conv_input
),
conv_bias_node
);
IR_NODE_LINK_TO
(
subgraph
.
at
(
conv_input
),
conv_bias_node
);
IR_NODE_LINK_TO
(
conv_weight
,
conv_bias_node
);
IR_NODE_LINK_TO
(
conv_weight
,
conv_bias_node
);
IR_NODE_LINK_TO
(
eltwise_bias
,
conv_bias_node
);
IR_NODE_LINK_TO
(
eltwise_bias
,
conv_bias_node
);
IR_NODE_LINK_TO
(
conv_bias_node
,
eltwise_out
);
IR_NODE_LINK_TO
(
conv_bias_node
,
eltwise_out
);
GraphSafeRemoveNodes
(
graph
.
get
(),
{
conv
,
eltwise
,
conv_out
});
}
found_conv_bias_count
++
;
found_conv_bias_count
++
;
};
};
gpd
(
graph
.
get
(),
handler
);
gpd
(
graph
.
get
(),
handler
);
...
...
paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
浏览文件 @
d7509d63
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...
@@ -28,6 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
...
@@ -28,6 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"conv_bias_mkldnn_fuse"
};
};
};
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass_tester.cc
已删除
100644 → 0
浏览文件 @
91e8fbac
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
if
(
type
==
"conv2d"
)
{
op
->
SetAttr
(
"use_mkldnn"
,
true
);
op
->
SetInput
(
"Input"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Filter"
,
{
inputs
[
1
]});
}
else
if
(
type
==
"elementwise_add"
)
{
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Y"
,
{
inputs
[
1
]});
}
op
->
SetOutput
(
"Out"
,
outputs
);
}
// a->OP0->b
// b->OP1->c
// (c, weights)->conv->f
// (f, bias)->elementwise_add->g
ProgramDesc
BuildProgramDesc
()
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
({
"a"
,
"b"
,
"c"
,
"weights"
,
"bias"
,
"f"
,
"g"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
if
(
v
==
"weights"
||
v
==
"bias"
)
{
var
->
SetPersistable
(
true
);
}
}
SetOp
(
&
prog
,
"OP0"
,
std
::
vector
<
std
::
string
>
({
"a"
}),
std
::
vector
<
std
::
string
>
({
"b"
}));
SetOp
(
&
prog
,
"OP1"
,
std
::
vector
<
std
::
string
>
({
"b"
}),
std
::
vector
<
std
::
string
>
({
"c"
}));
SetOp
(
&
prog
,
"conv2d"
,
std
::
vector
<
std
::
string
>
({
"c"
,
"weights"
}),
std
::
vector
<
std
::
string
>
({
"f"
}));
SetOp
(
&
prog
,
"elementwise_add"
,
std
::
vector
<
std
::
string
>
({
"f"
,
"bias"
}),
std
::
vector
<
std
::
string
>
({
"g"
}));
return
prog
;
}
TEST
(
ConvBiasFusePass
,
basic
)
{
auto
prog
=
BuildProgramDesc
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_bias_mkldnn_fuse_pass"
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
// Remove 3 Nodes: conv, elementwise_add, conv_out
// Add 1 Node: ConvBias
EXPECT_EQ
(
original_nodes_num
-
2
,
current_nodes_num
);
// Assert conv_bias op in newly generated graph
int
conv_bias_count
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"conv2d"
)
{
if
(
node
->
Op
()
->
HasAttr
(
"use_mkldnn"
))
{
bool
use_mkldnn
=
boost
::
get
<
bool
>
(
node
->
Op
()
->
GetAttr
(
"use_mkldnn"
));
if
(
use_mkldnn
)
{
auto
names
=
node
->
Op
()
->
InputNames
();
if
(
std
::
find
(
names
.
begin
(),
names
.
end
(),
"Bias"
)
!=
names
.
end
())
{
conv_bias_count
++
;
}
}
}
}
}
EXPECT_EQ
(
conv_bias_count
,
1
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
conv_bias_mkldnn_fuse_pass
);
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
d7509d63
...
@@ -987,6 +987,7 @@ PDNode *patterns::ConvBias::operator()(
...
@@ -987,6 +987,7 @@ PDNode *patterns::ConvBias::operator()(
// Bias stored in elementwise_add
// Bias stored in elementwise_add
auto
*
eltwise_bias_var
=
pattern
->
NewNode
(
eltwise_bias_repr
())
auto
*
eltwise_bias_var
=
pattern
->
NewNode
(
eltwise_bias_repr
())
->
AsInput
()
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
// output
// output
auto
*
eltwise_out_var
=
pattern
->
NewNode
(
eltwise_out_repr
())
auto
*
eltwise_out_var
=
pattern
->
NewNode
(
eltwise_out_repr
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录