Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
fbbd8208
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fbbd8208
编写于
6月 14, 2019
作者:
S
sangoly
提交者:
GitHub
6月 14, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix conv_bias_relu fuse bug & x86 conv kernel bug (#18098)
fix conv_bias_relu fuse bug,x86 conv kernel bug
上级
d8ba9626
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
50 addition
and
35 deletion
+50
-35
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
...luid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
+4
-1
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
...lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
+2
-2
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
...d/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
+23
-18
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
...id/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
+3
-0
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+1
-0
paddle/fluid/lite/kernels/x86/conv_compute.cc
paddle/fluid/lite/kernels/x86/conv_compute.cc
+1
-1
paddle/fluid/lite/kernels/x86/relu_compute.cc
paddle/fluid/lite/kernels/x86/relu_compute.cc
+1
-1
paddle/fluid/lite/operators/conv_op.h
paddle/fluid/lite/operators/conv_op.h
+12
-9
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+2
-2
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+1
-1
未找到文件。
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
浏览文件 @
fbbd8208
...
...
@@ -24,8 +24,11 @@ namespace mir {
void
ConvElementwiseAddReLUFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAddReLUFuser
fuser
;
fusion
::
ConvElementwiseAddReLUFuser
fuser
(
"conv2d"
)
;
fuser
(
graph
.
get
());
fusion
::
ConvElementwiseAddReLUFuser
depthwise_fuser
(
"depthwise_conv2d"
);
depthwise_fuser
(
graph
.
get
());
}
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
浏览文件 @
fbbd8208
...
...
@@ -85,7 +85,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
add_1
->
SetAttr
(
"axis"
,
1
);
relu_1
->
SetType
(
"relu"
);
relu_1
->
SetInput
(
"
Input
"
,
{
"add_1_out"
});
relu_1
->
SetInput
(
"
X
"
,
{
"add_1_out"
});
relu_1
->
SetOutput
(
"Out"
,
{
"relu_1_out"
});
conv2d_2
->
SetType
(
"conv2d"
);
...
...
@@ -105,7 +105,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
add_2
->
SetAttr
(
"axis"
,
1
);
relu_2
->
SetType
(
"relu"
);
relu_2
->
SetInput
(
"
Input
"
,
{
"add_2_out"
});
relu_2
->
SetInput
(
"
X
"
,
{
"add_2_out"
});
relu_2
->
SetOutput
(
"Out"
,
{
"out"
});
program_desc
->
Flush
();
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
浏览文件 @
fbbd8208
...
...
@@ -23,21 +23,33 @@ namespace fusion {
void
ConvElementwiseAddReLUFuser
::
BuildPattern
()
{
// create input nodes.
auto
*
input
=
VarNode
(
"input"
);
auto
*
filter
=
VarNode
(
"filter"
);
auto
*
bias
=
VarNode
(
"bias"
);
auto
*
input
=
VarNode
(
"input"
)
->
assert_is_op_input
(
conv_type_
,
"Input"
)
->
AsInput
();
auto
*
filter
=
VarNode
(
"filter"
)
->
assert_is_op_input
(
conv_type_
,
"Filter"
)
->
AsInput
();
auto
*
bias
=
VarNode
(
"bias"
)
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
AsInput
();
// create op nodes
auto
*
conv2d
=
OpNode
(
"conv2d"
,
"conv2d"
);
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
);
auto
*
relu
=
OpNode
(
"relu"
,
"relu"
);
auto
*
conv2d
=
OpNode
(
"conv2d"
,
conv_type_
)
->
assert_is_op
(
conv_type_
)
->
AsIntermediate
();
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
)
->
AsIntermediate
();
auto
*
relu
=
OpNode
(
"relu"
,
"relu"
)
->
assert_is_op
(
"relu"
)
->
AsIntermediate
();
// create intermediate nodes
auto
*
conv2d_out
=
VarNode
(
"conv2d_out"
);
auto
*
add_out
=
VarNode
(
"add_out"
);
auto
*
conv2d_out
=
VarNode
(
"conv2d_out"
)
->
assert_is_op_output
(
conv_type_
,
"Output"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
AsIntermediate
();
auto
*
add_out
=
VarNode
(
"add_out"
)
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
AsIntermediate
();
// create output node
auto
*
out
=
VarNode
(
"output"
);
auto
*
out
=
VarNode
(
"output"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
AsOutput
()
;
// create topology.
std
::
vector
<
PMNode
*>
conv2d_inputs
{
filter
,
input
};
...
...
@@ -45,19 +57,12 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
conv2d_inputs
>>
*
conv2d
>>
*
conv2d_out
;
add_inputs
>>
*
add
>>
*
add_out
;
*
add_out
>>
*
relu
>>
*
out
;
// Some op specialities.
conv2d_out
->
AsIntermediate
();
add_out
->
AsIntermediate
();
conv2d
->
AsIntermediate
();
add
->
AsIntermediate
();
relu
->
AsIntermediate
();
}
void
ConvElementwiseAddReLUFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
"conv2d"
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
conv_type_
);
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
auto
*
scope
=
conv_old
->
scope
();
auto
&
valid_places
=
conv_old
->
valid_places
();
...
...
@@ -75,7 +80,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
auto
*
desc
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"conv2d"
);
op_desc
.
SetType
(
conv_type_
);
op_desc
.
SetInput
(
"Input"
,
{
matched
.
at
(
"input"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Filter"
,
{
matched
.
at
(
"filter"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Bias"
,
{
matched
.
at
(
"bias"
)
->
arg
()
->
name
});
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
浏览文件 @
fbbd8208
...
...
@@ -25,11 +25,14 @@ namespace fusion {
class
ConvElementwiseAddReLUFuser
:
public
FuseBase
{
public:
explicit
ConvElementwiseAddReLUFuser
(
const
std
::
string
&
conv_type
)
:
conv_type_
(
conv_type
)
{}
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
std
::
string
conv_type_
;
};
}
// namespace fusion
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
fbbd8208
...
...
@@ -32,3 +32,4 @@ USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS
(
argument_type_display_pass
);
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
paddle/fluid/lite/kernels/x86/conv_compute.cc
浏览文件 @
fbbd8208
...
...
@@ -105,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
param
.
x
->
raw_tensor
().
Slice
(
i
,
i
+
1
).
Resize
(
input_shape
.
data
()));
lite
::
Tensor
out_batch
;
out_batch
.
ShareDataWith
(
param
.
output
->
raw_tensor
().
Slice
(
i
,
i
+
1
).
Resize
(
input
_shape
.
data
()));
output_matrix
_shape
.
data
()));
for
(
int
g
=
0
;
g
<
param
.
groups
;
g
++
)
{
lite
::
Tensor
in_slice
;
...
...
paddle/fluid/lite/kernels/x86/relu_compute.cc
浏览文件 @
fbbd8208
...
...
@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
relu
,
kX86
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
x86
::
ReluCompute
<
float
>
,
def
)
.
BindInput
(
"
Input
"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"
X
"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
paddle/fluid/lite/operators/conv_op.h
浏览文件 @
fbbd8208
...
...
@@ -73,19 +73,22 @@ class ConvOpLite : public OpLite {
std
::
vector
<
std
::
string
>
input_arg_names
=
op_desc
.
InputArgumentNames
();
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"Bias"
)
!=
input_arg_names
.
end
())
{
auto
bias_var
=
scope
->
FindVar
(
op_desc
.
Input
(
"Bias"
).
front
());
if
(
bias_var
!=
nullptr
)
{
param_
.
bias
=
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
auto
bias_arguments
=
op_desc
.
Input
(
"Bias"
);
if
(
bias_arguments
.
size
()
!=
0
)
{
auto
bias_var
=
scope
->
FindVar
(
bias_arguments
.
front
());
if
(
bias_var
!=
nullptr
)
{
param_
.
bias
=
&
bias_var
->
Get
<
lite
::
Tensor
>
();
}
}
}
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"ResidualData"
)
!=
input_arg_names
.
end
())
{
auto
residual_data_var
=
scope
->
FindVar
(
op_desc
.
Input
(
"ResidualData"
).
front
());
if
(
residual_data_var
!=
nullptr
)
{
param_
.
residualData
=
const_cast
<
lite
::
Tensor
*>
(
&
(
residual_data_var
->
Get
<
lite
::
Tensor
>
()));
auto
res_argument
=
op_desc
.
Input
(
"ResidualData"
);
if
(
res_argument
.
size
()
!=
0
)
{
auto
residual_data_var
=
scope
->
FindVar
(
res_argument
.
front
());
if
(
residual_data_var
!=
nullptr
)
{
param_
.
residualData
=
&
residual_data_var
->
Get
<
lite
::
Tensor
>
();
}
}
}
...
...
paddle/fluid/lite/operators/op_params.h
浏览文件 @
fbbd8208
...
...
@@ -124,8 +124,8 @@ struct ConcatParam {
struct
ConvParam
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
filter
{};
lite
::
Tensor
*
bias
{};
lite
::
Tensor
*
residualData
{};
const
lite
::
Tensor
*
bias
{};
const
lite
::
Tensor
*
residualData
{};
lite
::
Tensor
*
output
{};
std
::
vector
<
int
>
strides
{
1
,
1
};
std
::
vector
<
int
>
paddings
{
0
,
0
};
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
fbbd8208
...
...
@@ -32,7 +32,7 @@ bool ReluOp::InferShape() const {
bool
ReluOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
input
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"
Input
"
).
front
())
->
Get
<
lite
::
Tensor
>
());
&
scope
->
FindVar
(
opdesc
.
Input
(
"
X
"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
output
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
param_
.
input
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录