Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
930ca3f4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
930ca3f4
编写于
6月 18, 2021
作者:
W
Wangzheee
提交者:
GitHub
6月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pass enhance (#33661)
上级
39556a44
变更
6
隐藏空白更改
内联
并排
Showing
6 changed files
with
307 additions
and
7 deletions
+307
-7
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+256
-3
paddle/fluid/framework/ir/conv_bn_fuse_pass.h
paddle/fluid/framework/ir/conv_bn_fuse_pass.h
+6
-0
paddle/fluid/framework/ir/pass_tester_helper.h
paddle/fluid/framework/ir/pass_tester_helper.h
+25
-4
paddle/fluid/operators/compat/batch_norm.pbtxt
paddle/fluid/operators/compat/batch_norm.pbtxt
+4
-0
paddle/fluid/operators/compat/conv2d.pbtxt
paddle/fluid/operators/compat/conv2d.pbtxt
+8
-0
paddle/fluid/operators/compat/relu.pbtxt
paddle/fluid/operators/compat/relu.pbtxt
+8
-0
未找到文件。
paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
浏览文件 @
930ca3f4
...
...
@@ -140,6 +140,91 @@ void recompute_bias_and_weights(const Scope* scope,
}
}
// Registers operator-compatibility constraints for every op this pass may
// create or rewrite: conv2d, batch_norm, and the elementwise_add node that is
// emitted when the fused conv carries a bias.  IsCompat() consults these
// definitions in ApplyImpl() before and after the fusion is applied.
ConvBNFusePass::ConvBNFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")  // optional: conv may run without a bias term
      .IsOptional()
      .End()
      .AddInput("ResidualData")  // optional: only present for residual fusion
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .End()
      .AddAttr("paddings")
      .End()
      // NOTE: constraint order matches the other fuse-pass constructors in
      // this file (value check first, then optionality).
      .AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  AddOpCompat(OpCompat("batch_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddInput("Mean")
      .IsTensor()
      .End()
      .AddInput("Variance")
      .IsTensor()
      .End()
      .AddOutput("MeanOut")
      .IsTensor()
      .End()
      .AddOutput("VarianceOut")
      .IsTensor()
      .End()
      .AddOutput("SavedMean")
      .IsTensor()
      .End()
      .AddOutput("SavedVariance")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      // epsilon must lie in [0, 1e-3] for the fused weights to stay accurate.
      .AddAttr("epsilon")
      .IsNumLE(0.001f)
      .IsNumGE(0.0f)
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      // The pass itself emits elementwise_add with axis == 1 (see ApplyImpl),
      // so only that value is accepted.
      .AddAttr("axis")
      .IsNumEQ(1)
      .End();
}
void
ConvBNFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
...
...
@@ -161,8 +246,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
int
found_conv_bn_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"Pass in op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle "
+
conv_type
()
+
"BN fuse"
;
// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
...
...
@@ -236,6 +324,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
}
conv
->
Op
()
->
SetOutput
(
"Output"
,
std
::
vector
<
std
::
string
>
({
bn_out
->
Name
()}));
if
(
!
IsCompat
(
*
conv
->
Op
()))
{
LOG
(
WARNING
)
<<
"conv_bn fuse pass in out conv op compat failed."
;
return
;
}
GraphSafeRemoveNodes
(
graph
,
{
conv_out
,
bn_scale
,
bn_bias
,
bn_mean
,
bn_variance
,
batch_norm
,
...
...
@@ -251,6 +343,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
desc
.
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
({
bn_out
->
Name
()}));
desc
.
SetType
(
"elementwise_add"
);
desc
.
SetAttr
(
"axis"
,
1
);
if
(
!
IsCompat
(
desc
))
{
LOG
(
WARNING
)
<<
"conv_bn fuse pass in out elementwise_add op compat failed."
;
return
;
}
auto
eltwise_op
=
g
->
CreateOpNode
(
&
desc
);
// OpDesc will be copied.
GraphSafeRemoveNodes
(
graph
,
{
bn_scale
,
bn_bias
,
bn_mean
,
bn_variance
,
...
...
@@ -269,6 +366,91 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis
(
found_conv_bn_count
);
}
// Registers operator-compatibility constraints for the ops this pass touches:
// conv2d, batch_norm, and the elementwise_add that joins them.  IsCompat()
// consults these definitions in ApplyImpl() before and after fusion.
ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")  // optional: conv may run without a bias term
      .IsOptional()
      .End()
      .AddInput("ResidualData")  // optional: only present for residual fusion
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .End()
      .AddAttr("paddings")
      .End()
      .AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  AddOpCompat(OpCompat("batch_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddInput("Mean")
      .IsTensor()
      .End()
      .AddInput("Variance")
      .IsTensor()
      .End()
      .AddOutput("MeanOut")
      .IsTensor()
      .End()
      .AddOutput("VarianceOut")
      .IsTensor()
      .End()
      .AddOutput("SavedMean")
      .IsTensor()
      .End()
      .AddOutput("SavedVariance")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      // epsilon is restricted to [0, 1e-3].
      .AddAttr("epsilon")
      .IsNumLE(0.001f)
      .IsNumGE(0.0f)
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      // Only axis == 1 is accepted; ApplyImpl() sets this value explicitly
      // on the elementwise_add node it updates.
      .AddAttr("axis")
      .IsNumEQ(1)
      .End();
}
void
ConvEltwiseAddBNFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
...
...
@@ -290,8 +472,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
int
found_conv_bn_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"Pass in op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle "
+
conv_type
()
+
"BN fuse"
;
// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
...
...
@@ -361,7 +546,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
// Update the elementwise_add node
eltwise
->
Op
()
->
SetAttr
(
"axis"
,
1
);
eltwise
->
Op
()
->
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
({
bn_out
->
Name
()}));
if
(
!
IsCompat
(
*
eltwise
->
Op
()))
{
LOG
(
WARNING
)
<<
"conv_eltwise_bn fuse pass in out eltwise op compat failed."
;
return
;
}
GraphSafeRemoveNodes
(
graph
,
{
bn_scale
,
bn_bias
,
bn_mean
,
bn_variance
,
batch_norm
,
bn_mean_out
,
...
...
@@ -377,6 +566,70 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis
(
found_conv_bn_count
);
}
// Adds the conv2d_transpose compatibility definition on top of what the
// ConvBNFusePass base constructor already registered (conv2d, batch_norm,
// elementwise_add), since this subclass matches transposed convolutions.
ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")  // optional: transpose conv may have no bias
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .End()
      .AddAttr("paddings")
      .End()
      .AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();
}
// Adds the conv2d_transpose compatibility definition on top of what the
// ConvEltwiseAddBNFusePass base constructor already registered (conv2d,
// batch_norm, elementwise_add), since this subclass matches transposed
// convolutions.
ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")  // optional: transpose conv may have no bias
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .End()
      .AddAttr("paddings")
      .End()
      .AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
...
...
paddle/fluid/framework/ir/conv_bn_fuse_pass.h
浏览文件 @
930ca3f4
...
...
@@ -31,6 +31,7 @@ class Graph;
class
ConvBNFusePass
:
public
FusePassBase
{
public:
ConvBNFusePass
();
virtual
~
ConvBNFusePass
()
{}
virtual
std
::
string
conv_type
()
const
{
return
"conv2d"
;
}
...
...
@@ -41,6 +42,7 @@ class ConvBNFusePass : public FusePassBase {
class
ConvEltwiseAddBNFusePass
:
public
FusePassBase
{
public:
ConvEltwiseAddBNFusePass
();
virtual
~
ConvEltwiseAddBNFusePass
()
{}
virtual
std
::
string
conv_type
()
const
{
return
"conv2d"
;
}
...
...
@@ -51,11 +53,15 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
// Fuses conv2d_transpose + batch_norm.  Reuses the ConvBNFusePass machinery;
// only the targeted conv op type differs.
class ConvTransposeBNFusePass : public ConvBNFusePass {
 public:
  ConvTransposeBNFusePass();
  virtual ~ConvTransposeBNFusePass() {}
  // Overrides the base-class virtual hook so the detected pattern matches
  // conv2d_transpose instead of conv2d.  `override` added so the compiler
  // verifies the signature against the base declaration.
  std::string conv_type() const override { return "conv2d_transpose"; }
};
// Fuses conv2d_transpose + elementwise_add + batch_norm.  Reuses the
// ConvEltwiseAddBNFusePass machinery; only the targeted conv op type differs.
class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
 public:
  ConvTransposeEltwiseAddBNFusePass();
  virtual ~ConvTransposeEltwiseAddBNFusePass() {}
  // Overrides the base-class virtual hook so the detected pattern matches
  // conv2d_transpose instead of conv2d.  `override` added so the compiler
  // verifies the signature against the base declaration.
  std::string conv_type() const override { return "conv2d_transpose"; }
};
...
...
paddle/fluid/framework/ir/pass_tester_helper.h
浏览文件 @
930ca3f4
...
...
@@ -39,28 +39,49 @@ struct Layers {
}
VarDesc
*
conv2d
(
VarDesc
*
input
,
VarDesc
*
filter
,
VarDesc
*
bias
,
bool
use_cudnn
=
false
)
{
int
groups
=
1
,
std
::
vector
<
int
>
strides
=
{
1
,
1
},
std
::
vector
<
int
>
paddings
=
{
0
,
0
},
std
::
string
padding_algorithm
=
"EXPLICIT"
,
std
::
vector
<
int
>
dilations
=
{
1
,
1
},
std
::
string
data_format
=
"NCHW"
,
bool
use_cudnn
=
false
)
{
VarDesc
*
out
=
lod_tensor
(
unique_name
());
OpDesc
*
op
=
program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"conv2d"
);
op
->
SetInput
(
"Input"
,
{
input
->
Name
()});
op
->
SetInput
(
"Filter"
,
{
filter
->
Name
()});
op
->
SetInput
(
"Bias"
,
{
bias
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
out
->
Name
()});
op
->
SetOutput
(
"Out
put
"
,
{
out
->
Name
()});
op
->
SetAttr
(
"use_cudnn"
,
use_cudnn
);
op
->
SetAttr
(
"groups"
,
groups
);
op
->
SetAttr
(
"strides"
,
strides
);
op
->
SetAttr
(
"paddings"
,
paddings
);
op
->
SetAttr
(
"padding_algorithm"
,
padding_algorithm
);
op
->
SetAttr
(
"dilations"
,
dilations
);
op
->
SetAttr
(
"data_format"
,
data_format
);
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
return
out
;
}
VarDesc
*
conv2d_transpose
(
VarDesc
*
input
,
VarDesc
*
filter
,
VarDesc
*
bias
)
{
VarDesc
*
conv2d_transpose
(
VarDesc
*
input
,
VarDesc
*
filter
,
VarDesc
*
bias
,
int
groups
=
1
,
std
::
vector
<
int
>
strides
=
{
1
,
1
},
std
::
vector
<
int
>
paddings
=
{
0
,
0
},
std
::
string
padding_algorithm
=
"EXPLICIT"
,
std
::
vector
<
int
>
dilations
=
{
1
,
1
},
std
::
string
data_format
=
"NCHW"
)
{
VarDesc
*
out
=
lod_tensor
(
unique_name
());
OpDesc
*
op
=
program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"conv2d_transpose"
);
op
->
SetInput
(
"Input"
,
{
input
->
Name
()});
op
->
SetInput
(
"Filter"
,
{
filter
->
Name
()});
op
->
SetInput
(
"Bias"
,
{
bias
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
out
->
Name
()});
op
->
SetOutput
(
"Output"
,
{
out
->
Name
()});
op
->
SetAttr
(
"groups"
,
groups
);
op
->
SetAttr
(
"strides"
,
strides
);
op
->
SetAttr
(
"paddings"
,
paddings
);
op
->
SetAttr
(
"padding_algorithm"
,
padding_algorithm
);
op
->
SetAttr
(
"dilations"
,
dilations
);
op
->
SetAttr
(
"data_format"
,
data_format
);
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
return
out
;
...
...
paddle/fluid/operators/compat/batch_norm.pbtxt
浏览文件 @
930ca3f4
...
...
@@ -42,6 +42,10 @@ extra {
inputs {
name: "MomentumTensor"
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
...
...
paddle/fluid/operators/compat/conv2d.pbtxt
浏览文件 @
930ca3f4
...
...
@@ -41,6 +41,14 @@ def {
}
}
extra {
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "skip_quant"
type: BOOLEAN
}
attrs {
name: "is_test"
type: BOOLEAN
...
...
paddle/fluid/operators/compat/relu.pbtxt
浏览文件 @
930ca3f4
...
...
@@ -12,6 +12,14 @@ extra {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "out_threshold"
type: FLOAT
}
attrs {
name: "Out0_threshold"
type: FLOAT
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录