Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bd68761a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bd68761a
编写于
6月 25, 2021
作者:
W
Wangzheee
提交者:
GitHub
6月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ pass_enhance ]quant_conv2d_dequant_fuse_pass (#33737)
上级
3ad6630f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
287 addition
and
11 deletion
+287
-11
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+212
-6
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
+8
-3
paddle/fluid/operators/compat/conv2d.pbtxt
paddle/fluid/operators/compat/conv2d.pbtxt
+16
-0
paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt
...erators/compat/fake_channel_wise_dequantize_max_abs.pbtxt
+47
-0
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+4
-2
未找到文件。
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
浏览文件 @
bd68761a
...
@@ -21,11 +21,209 @@
...
@@ -21,11 +21,209 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
/// Constructor: registers one op-compat description per op type this pass
/// deals with (quant ops, dequant ops, and the quantized compute ops).
///
/// Each description lists the op's inputs, outputs and attributes together
/// with the constraints placed on them. The fuse handlers in this file call
/// IsCompat(subgraph, g) and skip the match when that check fails.
QuantDequantFusePass::QuantDequantFusePass() {
  // ---- quant op: range abs-max ----
  AddOpCompat(OpCompat("fake_quantize_range_abs_max"))
      .AddInput("X").IsTensor().End()
      .AddInput("InScale").IsTensor().End()
      .AddInput("Iter").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddOutput("OutScale").IsTensor().End()
      .AddOutput("OutScales").IsTensor().End()
      .AddAttr("window_size").IsType<int>().IsNumGT(0).End()
      .AddAttr("bit_length").IsIntIn({8, 16}).End();

  // ---- quant op: moving-average abs-max ----
  // The accumulator/state inputs and outputs are declared optional.
  AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max"))
      .AddInput("X").IsTensor().End()
      .AddInput("InScale").IsTensor().End()
      .AddInput("InAccum").IsTensor().IsOptional().End()
      .AddInput("InState").IsTensor().IsOptional().End()
      .AddOutput("Out").IsTensor().End()
      .AddOutput("OutScale").IsTensor().End()
      .AddOutput("OutState").IsTensor().IsOptional().End()
      .AddOutput("OutAccum").IsTensor().IsOptional().End()
      .AddAttr("moving_rate").IsType<float>().IsNumGT(0.0f).End()
      .AddAttr("bit_length").IsIntIn({8, 16}).End();

  // ---- dequant op: per-tensor max-abs ----
  AddOpCompat(OpCompat("fake_dequantize_max_abs"))
      .AddInput("X").IsTensor().End()
      .AddInput("Scale").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("max_range").IsType<float>().IsNumGT(0.0f).End();

  // ---- dequant op: channel-wise max-abs ----
  AddOpCompat(OpCompat("fake_channel_wise_dequantize_max_abs"))
      .AddInput("X").IsTensor().End()
      .AddInput("Scales")  // "Scales" is a vector with at most two tensors
      .End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("quant_bits").IsType<std::vector<int>>().End()
      .AddAttr("quant_axis").IsIntIn({0, 1}).IsOptional().End();

  // ---- quantized op: conv2d ----
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input").IsTensor().End()
      .AddInput("Filter").IsTensor().End()
      .AddInput("Bias").IsTensor().IsOptional().End()
      .AddInput("ResidualData").IsTensor().IsOptional().End()
      .AddOutput("Output").IsTensor().End()
      .AddAttr("strides").IsType<std::vector<int>>().End()
      .AddAttr("paddings").IsType<std::vector<int>>().End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("groups").IsNumGE(1).End()
      .AddAttr("dilations").IsType<std::vector<int>>().End()
      .AddAttr("data_format").IsStringIn({"NCHW", "NHWC", "AnyLayout"}).End();

  // ---- quantized op: mul ----
  AddOpCompat(OpCompat("mul"))
      .AddInput("X").IsTensor().End()
      .AddInput("Y").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("x_num_col_dims").IsNumGE(1).End()
      .AddAttr("y_num_col_dims").IsNumEQ(1).End();

  // ---- quantized op: fc ----
  AddOpCompat(OpCompat("fc"))
      .AddInput("Input").IsTensor().End()
      .AddInput("W").IsTensor().End()
      .AddInput("Bias").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("in_num_col_dims").IsNumGE(1).End()
      .AddAttr("activation_type").IsStringIn({"relu", ""}).End();

  // ---- quantized op: conv2d_transpose ----
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input").IsTensor().End()
      .AddInput("Filter").IsTensor().End()
      .AddInput("Bias").IsTensor().IsOptional().End()
      .AddOutput("Output").IsTensor().End()
      .AddAttr("strides").IsType<std::vector<int>>().End()
      .AddAttr("paddings").IsType<std::vector<int>>().End()
      .AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .IsOptional()
      .End()
      .AddAttr("groups").IsNumGE(1).End()
      .AddAttr("dilations").IsType<std::vector<int>>().End()
      .AddAttr("data_format").IsStringIn({"NCHW", "NHWC", "AnyLayout"}).End();
}
// Delete quant op before quantized ops, and set input scale in the attr of
// Delete quant op before quantized ops, and set input scale in the attr of
// quantized ops
// quantized ops
void
DeleteQuant
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
void
QuantDequantFusePass
::
DeleteQuant
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
string
&
quant_type
)
{
const
std
::
string
&
quant_type
)
const
{
const
std
::
string
pattern_name
=
"delete_quant_fuse"
;
const
std
::
string
pattern_name
=
"delete_quant_fuse"
;
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
input_act_node
=
gpd
.
mutable_pattern
()
auto
*
input_act_node
=
gpd
.
mutable_pattern
()
...
@@ -41,6 +239,10 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
...
@@ -41,6 +239,10 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
// ops linked from it
// ops linked from it
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"Pass in op compat failed."
;
return
;
}
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
subgraph
.
count
(
input_act_node
),
true
,
subgraph
.
count
(
input_act_node
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
...
@@ -103,9 +305,9 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
...
@@ -103,9 +305,9 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
// Delete dequant op after quantized ops, and convert weight from fp32 range to
// Delete dequant op after quantized ops, and convert weight from fp32 range to
// int8 range
// int8 range
void
FuseDequant
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
void
QuantDequantFusePass
::
FuseDequant
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
string
&
quantized_op_type
,
const
std
::
string
&
quantized_op_type
,
const
std
::
string
&
dequant_type
)
{
const
std
::
string
&
dequant_type
)
const
{
std
::
string
weight_name
=
""
;
std
::
string
weight_name
=
""
;
std
::
string
input_name
=
""
;
std
::
string
input_name
=
""
;
if
(
quantized_op_type
==
"conv2d"
||
if
(
quantized_op_type
==
"conv2d"
||
...
@@ -142,6 +344,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
...
@@ -142,6 +344,10 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
// Create new op desc
// Create new op desc
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"Pass in op compat failed."
;
return
;
}
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
subgraph
.
count
(
quantized_op_input
),
true
,
subgraph
.
count
(
quantized_op_input
),
true
,
platform
::
errors
::
NotFound
(
"Quantized op input node(%s) did not find "
platform
::
errors
::
NotFound
(
"Quantized op input node(%s) did not find "
...
...
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h
浏览文件 @
bd68761a
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
#include <memory>
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -25,14 +24,20 @@ namespace ir {
...
@@ -25,14 +24,20 @@ namespace ir {
///
///
/// Fuse quant + conv2d/depthwise_conv2d/mul/fc + dequant
/// Fuse quant + conv2d/depthwise_conv2d/mul/fc + dequant
///
///
class
Graph
;
class
QuantDequantFusePass
:
public
FusePassBase
{
class
QuantDequantFusePass
:
public
FusePassBase
{
public:
public:
QuantDequantFusePass
();
virtual
~
QuantDequantFusePass
()
{}
virtual
~
QuantDequantFusePass
()
{}
protected:
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
void
DeleteQuant
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
string
&
quant_type
)
const
;
void
FuseDequant
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
string
&
quantized_op_type
,
const
std
::
string
&
dequant_type
)
const
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/operators/compat/conv2d.pbtxt
浏览文件 @
bd68761a
...
@@ -41,6 +41,22 @@ def {
...
@@ -41,6 +41,22 @@ def {
}
}
}
}
extra {
extra {
attrs {
name: "Input_scale"
type: FLOAT
}
attrs {
name: "quantization_type"
type: STRING
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "out_threshold"
type: FLOAT
}
attrs {
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
type: BOOLEAN
...
...
paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt
0 → 100644
浏览文件 @
bd68761a
# Op description for the "fake_channel_wise_dequantize_max_abs" operator.
# NOTE(review): presumably `def` carries the op's core definition and `extra`
# the auxiliary attributes -- confirm against the op-compat schema.
type: "fake_channel_wise_dequantize_max_abs"
# Core definition: two inputs (X, Scales), one output (Out), two attributes.
def {
inputs {
name: "X"
}
inputs {
name: "Scales"
}
outputs {
name: "Out"
}
# quant_bits is a list of ints.
attrs {
name: "quant_bits"
type: INTS
}
# quant_axis is a single int.
attrs {
name: "quant_axis"
type: INT
}
}
# Additional attributes listed outside the core definition.
extra {
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
bd68761a
...
@@ -1183,7 +1183,8 @@ class QuantizationFreezePass(object):
...
@@ -1183,7 +1183,8 @@ class QuantizationFreezePass(object):
if
op_node_desc
.
has_attr
(
"quantization_type"
)
and
\
if
op_node_desc
.
has_attr
(
"quantization_type"
)
and
\
op_node_desc
.
attr
(
"quantization_type"
)
==
"qat_with_weight"
:
op_node_desc
.
attr
(
"quantization_type"
)
==
"qat_with_weight"
:
if
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
:
if
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
:
self
.
_insert_post_channel_dequant_op
(
graph
,
op_node
)
self
.
_insert_post_channel_dequant_op
(
graph
,
op_node
,
quant_axis
)
else
:
else
:
self
.
_insert_post_dequant_op
(
graph
,
op_node
)
self
.
_insert_post_dequant_op
(
graph
,
op_node
)
...
@@ -1210,7 +1211,7 @@ class QuantizationFreezePass(object):
...
@@ -1210,7 +1211,7 @@ class QuantizationFreezePass(object):
v
.
node
]
v
.
node
]
graph
.
safe_remove_nodes
(
op_node
)
graph
.
safe_remove_nodes
(
op_node
)
def
_insert_post_channel_dequant_op
(
self
,
graph
,
op_node
):
def
_insert_post_channel_dequant_op
(
self
,
graph
,
op_node
,
quant_axis
):
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
for
var_node
in
op_node
.
inputs
:
for
var_node
in
op_node
.
inputs
:
name
=
var_node
.
name
()
name
=
var_node
.
name
()
...
@@ -1258,6 +1259,7 @@ class QuantizationFreezePass(object):
...
@@ -1258,6 +1259,7 @@ class QuantizationFreezePass(object):
op_type
=
'fake_channel_wise_dequantize_max_abs'
,
op_type
=
'fake_channel_wise_dequantize_max_abs'
,
attrs
=
{
attrs
=
{
'quant_bits'
:
[
self
.
_weight_bits
,
self
.
_activation_bits
],
'quant_bits'
:
[
self
.
_weight_bits
,
self
.
_activation_bits
],
'quant_axis'
:
quant_axis
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
},
},
inputs
=
{
inputs
=
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录