Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
589cd878
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
589cd878
编写于
3月 24, 2020
作者:
C
cc
提交者:
GitHub
3月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Post_training_quantizaion supports min_max methon (#23078)
* Post_training_quantizaion supports min_max methon
上级
194a22c5
变更
4
展开全部
隐藏空白更改
内联
并排
Showing
4 changed file
with
268 addition
and
184 deletion
+268
-184
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
...d/contrib/slim/quantization/post_training_quantization.py
+193
-110
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+47
-66
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
...slim/tests/test_post_training_quantization_mobilenetv1.py
+25
-6
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
...ib/slim/tests/test_post_training_quantization_resnet50.py
+3
-2
未找到文件。
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
浏览文件 @
589cd878
此差异已折叠。
点击以展开。
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
589cd878
...
...
@@ -35,6 +35,10 @@ _fake_dequant_op_list = [
'fake_dequantize_max_abs'
,
'fake_channel_wise_dequantize_max_abs'
]
_fake_quant_dequant_op_list
=
[
'fake_quantize_dequantize_moving_average_abs_max'
]
_out_scale_op_list
=
[
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"elementwise_add"
,
"elementwise_mul"
,
...
...
@@ -44,7 +48,7 @@ _out_scale_op_list = [
# list op real input and output names, to avoid processing input such as AxisTensor.
_op_real_in_out_name
=
{
"conv2d"
:
[[
"Input"
,
"Filter"
],
[
"Output"
]],
"depthwise_conv2d"
:
[[
"Input"
],
[
"Output"
]],
"depthwise_conv2d"
:
[[
"Input"
,
"Filter"
],
[
"Output"
]],
"mul"
:
[[
"X"
,
"Y"
],
[
"Out"
]],
"matmul"
:
[[
"X"
,
"Y"
],
[
"Out"
]],
"pool2d"
:
[[
"X"
],
[
"Out"
]],
...
...
@@ -236,6 +240,7 @@ class QuantizationTransformPass(object):
op_node
.
op
().
_set_attr
(
"skip_quant"
,
True
)
def
_transform_forward
(
graph
,
op
):
op
.
op
().
_set_attr
(
"quantization_type"
,
"qat_with_weight"
)
for
var_node
in
op
.
inputs
:
if
var_node
.
name
()
not
in
op
.
input_arg_names
():
continue
...
...
@@ -290,7 +295,7 @@ class QuantizationTransformPass(object):
# The loop for transforming the forward graph:
for
op
in
ops
:
if
op
.
name
()
in
self
.
_quantizable_ops
:
if
not
QuantizationTransformPass
.
_is_skip_quant
(
graph
,
op
):
if
not
self
.
_is_skip_quant
(
graph
,
op
):
_transform_forward
(
graph
,
op
)
# The loop for renaming the inputs of backward op.
for
op
in
ops
:
...
...
@@ -636,8 +641,7 @@ class QuantizationTransformPass(object):
"""
return
"%s.scale"
%
(
var_name
)
@
staticmethod
def
_is_skip_quant
(
graph
,
op_node
):
def
_is_skip_quant
(
self
,
graph
,
op_node
):
"""
Analyse whether the op node skips quantization.
"""
...
...
@@ -650,20 +654,20 @@ class QuantizationTransformPass(object):
if
op_node
.
name
()
in
[
"mul"
,
"matmul"
]
and
\
_is_input_all_not_persistable
(
graph
,
op_node
):
is_skip
=
True
if
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_without_weight"
:
is_skip
=
True
return
is_skip
class
QuantizationFreezePass
(
object
):
_supported_quantizable_op_type
=
\
QuantizationTransformPass
.
_supported_quantizable_op_type
def
__init__
(
self
,
scope
,
place
,
weight_bits
=
8
,
activation_bits
=
8
,
weight_quantize_type
=
'abs_max'
,
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
):
quantizable_op_type
=
None
):
"""
The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be frozen into
...
...
@@ -679,9 +683,8 @@ class QuantizationFreezePass(object):
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
quantizable_op_type(list[str]): This input param will be removed latter. The pass
will process all quantized op, so it is not necessary to set the input param.
"""
assert
scope
is
not
None
,
\
'The scope cannot be set None.'
...
...
@@ -692,16 +695,12 @@ class QuantizationFreezePass(object):
self
.
_weight_bits
=
weight_bits
self
.
_activation_bits
=
activation_bits
self
.
_weight_quantize_type
=
weight_quantize_type
self
.
_quantizable_ops
=
quantizable_op_type
for
op
in
self
.
_quantizable_ops
:
assert
op
in
QuantizationFreezePass
.
_supported_quantizable_op_type
,
\
op
+
" is not supported for quantization."
self
.
_conv_ops
=
[
'conv2d'
,
'depthwise_conv2d'
]
self
.
_fake_quant_op_names
=
_fake_quant_op_list
self
.
_fake_dequant_op_names
=
_fake_dequant_op_list
self
.
_op_input_rename_map
=
collections
.
OrderedDict
()
self
.
_op_output_rename_map
=
collections
.
OrderedDict
()
self
.
_var_scale_map
=
collections
.
OrderedDict
()
self
.
_
quant_
var_scale_map
=
collections
.
OrderedDict
()
def
apply
(
self
,
graph
):
"""
...
...
@@ -712,6 +711,7 @@ class QuantizationFreezePass(object):
Returns:
None
"""
# Get input scales in fake quant op and process weights
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
...
...
@@ -733,7 +733,7 @@ class QuantizationFreezePass(object):
else
:
scale_v
=
self
.
_load_var
(
op_node
.
output
(
'OutScale'
)[
0
])[
0
]
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_
quant_
var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
# quantize weight and restore
param_v
=
self
.
_load_var
(
input_arg_name
)
...
...
@@ -743,32 +743,29 @@ class QuantizationFreezePass(object):
else
:
scale_v
=
graph
.
_find_node_by_name
(
op_node
.
outputs
,
op_node
.
output
(
'OutScale'
)[
0
])
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_
quant_
var_scale_map
[
input_arg_name
]
=
scale_v
# Remove all fake dequant op
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
op_name
=
op_node
.
name
()
if
op_name
in
self
.
_fake_dequant_op_names
:
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
# Insert post dequant op
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
op_name
=
op_node
.
name
()
if
op_name
in
self
.
_quantizable_ops
:
# only process the node that is quantized by QuantizationTransformPass
is_op_node_quantized
=
False
for
var_node
in
op_node
.
inputs
:
var_name
=
var_node
.
name
()
if
var_name
.
endswith
(
'.dequantized'
):
is_op_node_quantized
=
True
if
is_op_node_quantized
:
if
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
and
op_name
in
self
.
_conv_ops
:
self
.
_insert_post_channel_dequant_op
(
graph
,
op_node
)
else
:
self
.
_insert_post_dequant_op
(
graph
,
op_node
)
op_node_desc
=
op_node
.
op
()
if
op_node_desc
.
has_attr
(
"quantization_type"
)
and
\
op_node_desc
.
attr
(
"quantization_type"
)
==
"qat_with_weight"
:
if
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
\
and
op_node
.
name
()
in
self
.
_conv_ops
:
self
.
_insert_post_channel_dequant_op
(
graph
,
op_node
)
else
:
self
.
_insert_post_dequant_op
(
graph
,
op_node
)
# Rename inputs of the followed ops after inserting dequant_op after fc/conv
for
op_node
in
ops
:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
for
var_node
in
op_node
.
inputs
:
if
var_node
.
node
in
self
.
_op_output_rename_map
:
old_in
=
var_node
...
...
@@ -802,7 +799,7 @@ class QuantizationFreezePass(object):
new_in
.
clear_outputs
()
graph
.
update_input_link
(
old_in
,
new_in
,
op_node
)
original_var_name
=
self
.
_original_var_name
(
name
)
scale_v
=
self
.
_var_scale_map
[
original_var_name
]
scale_v
=
self
.
_
quant_
var_scale_map
[
original_var_name
]
if
original_var_name
in
persistable_vars
:
assert
isinstance
(
scale_v
,
...
...
@@ -811,7 +808,7 @@ class QuantizationFreezePass(object):
channel_scale
=
np
.
array
(
scale_v
)
else
:
assert
isinstance
(
scale_v
,
IrNode
)
scale_var_node
=
self
.
_var_scale_map
[
original_var_name
]
scale_var_node
=
self
.
_
quant_
var_scale_map
[
original_var_name
]
if
len
(
op_node
.
output_arg_names
())
!=
1
:
raise
ValueError
(
"Only support one output, but op %s has"
...
...
@@ -867,7 +864,7 @@ class QuantizationFreezePass(object):
new_in
.
clear_outputs
()
graph
.
update_input_link
(
old_in
,
new_in
,
op_node
)
original_var_name
=
self
.
_original_var_name
(
name
)
scale_v
=
self
.
_var_scale_map
[
original_var_name
]
scale_v
=
self
.
_
quant_
var_scale_map
[
original_var_name
]
if
original_var_name
in
persistable_vars
:
assert
self
.
_is_float
(
scale_v
),
'The scale of parameter %s is not a float.'
%
(
...
...
@@ -876,7 +873,7 @@ class QuantizationFreezePass(object):
else
:
max_range
*=
act_range
assert
isinstance
(
scale_v
,
IrNode
)
scale_var_node
=
self
.
_var_scale_map
[
original_var_name
]
scale_var_node
=
self
.
_
quant_
var_scale_map
[
original_var_name
]
if
len
(
op_node
.
output_arg_names
())
!=
1
:
raise
ValueError
(
"Only support one output, but op %s has"
...
...
@@ -963,13 +960,7 @@ class QuantizationFreezePass(object):
class
ConvertToInt8Pass
(
object
):
_supported_quantizable_op_type
=
\
QuantizationTransformPass
.
_supported_quantizable_op_type
def
__init__
(
self
,
scope
,
place
,
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]):
def
__init__
(
self
,
scope
,
place
,
quantizable_op_type
=
None
):
"""
Convert the weights into int8_t type.
...
...
@@ -977,9 +968,8 @@ class ConvertToInt8Pass(object):
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and QuantizationFreezePass must be the same as this.
quantizable_op_type(list[str]): This input param will be removed latter. The pass
will process all quantized op, so it is not necessary to set the input param.
"""
assert
scope
is
not
None
,
\
'The scope cannot be set None.'
...
...
@@ -987,10 +977,6 @@ class ConvertToInt8Pass(object):
'The place cannot be set None.'
self
.
_scope
=
scope
self
.
_place
=
place
self
.
_quantizable_ops
=
quantizable_op_type
for
op
in
self
.
_quantizable_ops
:
assert
op
in
ConvertToInt8Pass
.
_supported_quantizable_op_type
,
\
op
+
" is not supported for quantization."
def
apply
(
self
,
graph
):
"""
...
...
@@ -1006,10 +992,8 @@ class ConvertToInt8Pass(object):
ops
=
graph
.
all_op_nodes
()
input_map
=
{}
for
op_node
in
ops
:
op_name
=
op_node
.
name
()
if
op_name
in
self
.
_quantizable_ops
:
if
QuantizationTransformPass
.
_is_skip_quant
(
graph
,
op_node
):
continue
if
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_with_weight"
:
for
var_node
in
op_node
.
inputs
:
name
=
var_node
.
name
()
if
name
in
persistable_vars
:
...
...
@@ -1259,9 +1243,9 @@ class AddQuantDequantPass(object):
"equal"
,
"gather"
,
"greater_equal"
,
"greater_than"
,
"less_equal"
,
"less_than"
,
"mean"
,
"not_equal"
,
"reshape"
,
"reshape2"
,
"bilinear_interp"
,
"nearest_interp"
,
"trilinear_interp"
,
"slice"
,
"squeeze"
,
"elementwise_sub"
,
"mul"
,
"matmul"
"squeeze"
,
"elementwise_sub"
,
"mul"
,
"matmul"
,
"relu"
,
"relu6"
,
"leaky_relu"
,
"tanh"
,
"swish"
]
_activation_type
=
[
"relu"
,
"relu6"
,
"leaky_relu"
,
"tanh"
,
"swish"
]
def
__init__
(
self
,
scope
=
None
,
...
...
@@ -1307,8 +1291,7 @@ class AddQuantDequantPass(object):
else
:
self
.
_quantizable_op_type
=
quantizable_op_type
for
op_type
in
quantizable_op_type
:
assert
op_type
in
AddQuantDequantPass
.
_supported_quantizable_op_type
+
\
AddQuantDequantPass
.
_activation_type
,
\
assert
op_type
in
AddQuantDequantPass
.
_supported_quantizable_op_type
,
\
op_type
+
" is not supported for quantization."
self
.
_quantizable_grad_op_type
=
[
'%s_grad'
%
(
op
)
for
op
in
self
.
_quantizable_op_type
...
...
@@ -1343,17 +1326,15 @@ class AddQuantDequantPass(object):
elif
isinstance
(
self
.
_skip_pattern
,
str
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
is_op_node_quantized
=
False
for
var_node
in
op_node
.
inputs
:
var_name
=
var_node
.
name
()
if
var_name
.
endswith
(
'.dequantized'
):
is_op_node_quantized
=
True
if
is_skip
or
is_op_node_quantized
or
\
is_quantized
=
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_with_weight"
if
is_skip
or
is_quantized
or
\
(
not
_is_input_all_not_persistable
(
graph
,
op_node
)):
continue
op_node
.
op
().
_set_attr
(
"quantization_type"
,
"qat_without_weight"
)
op_node
.
op
().
_set_attr
(
"activation_bits"
,
self
.
_quant_bits
)
input_name_list
=
_op_real_in_out_name
[
op_node
.
name
()][
0
]
arg_names
=
[]
for
input_name
in
input_name_list
:
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
浏览文件 @
589cd878
...
...
@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq
.
save_quantized_model
(
self
.
int8_model
)
def
run_test
(
self
,
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
):
is_full_quantize
,
is_use_cache_file
,
diff_threshold
):
infer_iterations
=
self
.
infer_iterations
batch_size
=
self
.
batch_size
sample_iterations
=
self
.
sample_iterations
...
...
@@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
sys
.
stdout
.
flush
()
delta_value
=
fp32_acc1
-
int8_acc1
self
.
assertLess
(
delta_value
,
0.025
)
self
.
assertLess
(
delta_value
,
diff_threshold
)
class
TestPostTrainingForMobilenetv1
(
TestPostTrainingQuantization
):
def
test_post_training_mobilenetv1
(
self
):
class
TestPostTraining
KL
ForMobilenetv1
(
TestPostTrainingQuantization
):
def
test_post_training_
kl_
mobilenetv1
(
self
):
model
=
"MobileNet-V1"
algo
=
"KL"
data_urls
=
[
...
...
@@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
,
"pool2d"
,
"elementwise_add"
]
is_full_quantize
=
Tru
e
is_full_quantize
=
Fals
e
is_use_cache_file
=
False
diff_threshold
=
0.025
self
.
run_test
(
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
)
is_full_quantize
,
is_use_cache_file
,
diff_threshold
)
class
TestPostTrainingAbsMaxForMobilenetv1
(
TestPostTrainingQuantization
):
def
test_post_training_abs_max_mobilenetv1
(
self
):
model
=
"MobileNet-V1"
algo
=
"abs_max"
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s
=
[
'13892b0716d26443a8cdea15b3c6438b'
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
,
"pool2d"
,
"elementwise_add"
]
is_full_quantize
=
False
is_use_cache_file
=
False
diff_threshold
=
0.05
self
.
run_test
(
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
diff_threshold
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
浏览文件 @
589cd878
...
...
@@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantiza
class
TestPostTrainingForResnet50
(
TestPostTrainingQuantization
):
def
test_post_training_resnet50
(
self
):
model
=
"ResNet-50"
algo
=
"
direct
"
algo
=
"
min_max
"
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
...
...
@@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
quantizable_op_type
=
[
"conv2d"
,
"mul"
]
is_full_quantize
=
False
is_use_cache_file
=
False
diff_threshold
=
0.025
self
.
run_test
(
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
)
is_full_quantize
,
is_use_cache_file
,
diff_threshold
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录