Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
589cd878
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
589cd878
编写于
3月 24, 2020
作者:
C
cc
提交者:
GitHub
3月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Post_training_quantizaion supports min_max methon (#23078)
* Post_training_quantizaion supports min_max methon
上级
194a22c5
变更
4
展开全部
隐藏空白更改
内联
并排
Showing
4 changed file
with
268 addition
and
184 deletion
+268
-184
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
...d/contrib/slim/quantization/post_training_quantization.py
+193
-110
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+47
-66
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
...slim/tests/test_post_training_quantization_mobilenetv1.py
+25
-6
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
...ib/slim/tests/test_post_training_quantization_resnet50.py
+3
-2
未找到文件。
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
浏览文件 @
589cd878
此差异已折叠。
点击以展开。
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
589cd878
...
@@ -35,6 +35,10 @@ _fake_dequant_op_list = [
...
@@ -35,6 +35,10 @@ _fake_dequant_op_list = [
'fake_dequantize_max_abs'
,
'fake_channel_wise_dequantize_max_abs'
'fake_dequantize_max_abs'
,
'fake_channel_wise_dequantize_max_abs'
]
]
_fake_quant_dequant_op_list
=
[
'fake_quantize_dequantize_moving_average_abs_max'
]
_out_scale_op_list
=
[
_out_scale_op_list
=
[
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
"depthwise_conv2d"
,
"mul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
"softmax"
,
"sigmoid"
,
"depthwise_conv2d"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"elementwise_add"
,
"elementwise_mul"
,
"batch_norm"
,
"concat"
,
"tanh"
,
"pad"
,
"elementwise_add"
,
"elementwise_mul"
,
...
@@ -44,7 +48,7 @@ _out_scale_op_list = [
...
@@ -44,7 +48,7 @@ _out_scale_op_list = [
# list op real input and output names, to avoid processing input such as AxisTensor.
# list op real input and output names, to avoid processing input such as AxisTensor.
_op_real_in_out_name
=
{
_op_real_in_out_name
=
{
"conv2d"
:
[[
"Input"
,
"Filter"
],
[
"Output"
]],
"conv2d"
:
[[
"Input"
,
"Filter"
],
[
"Output"
]],
"depthwise_conv2d"
:
[[
"Input"
],
[
"Output"
]],
"depthwise_conv2d"
:
[[
"Input"
,
"Filter"
],
[
"Output"
]],
"mul"
:
[[
"X"
,
"Y"
],
[
"Out"
]],
"mul"
:
[[
"X"
,
"Y"
],
[
"Out"
]],
"matmul"
:
[[
"X"
,
"Y"
],
[
"Out"
]],
"matmul"
:
[[
"X"
,
"Y"
],
[
"Out"
]],
"pool2d"
:
[[
"X"
],
[
"Out"
]],
"pool2d"
:
[[
"X"
],
[
"Out"
]],
...
@@ -236,6 +240,7 @@ class QuantizationTransformPass(object):
...
@@ -236,6 +240,7 @@ class QuantizationTransformPass(object):
op_node
.
op
().
_set_attr
(
"skip_quant"
,
True
)
op_node
.
op
().
_set_attr
(
"skip_quant"
,
True
)
def
_transform_forward
(
graph
,
op
):
def
_transform_forward
(
graph
,
op
):
op
.
op
().
_set_attr
(
"quantization_type"
,
"qat_with_weight"
)
for
var_node
in
op
.
inputs
:
for
var_node
in
op
.
inputs
:
if
var_node
.
name
()
not
in
op
.
input_arg_names
():
if
var_node
.
name
()
not
in
op
.
input_arg_names
():
continue
continue
...
@@ -290,7 +295,7 @@ class QuantizationTransformPass(object):
...
@@ -290,7 +295,7 @@ class QuantizationTransformPass(object):
# The loop for transforming the forward graph:
# The loop for transforming the forward graph:
for
op
in
ops
:
for
op
in
ops
:
if
op
.
name
()
in
self
.
_quantizable_ops
:
if
op
.
name
()
in
self
.
_quantizable_ops
:
if
not
QuantizationTransformPass
.
_is_skip_quant
(
graph
,
op
):
if
not
self
.
_is_skip_quant
(
graph
,
op
):
_transform_forward
(
graph
,
op
)
_transform_forward
(
graph
,
op
)
# The loop for renaming the inputs of backward op.
# The loop for renaming the inputs of backward op.
for
op
in
ops
:
for
op
in
ops
:
...
@@ -636,8 +641,7 @@ class QuantizationTransformPass(object):
...
@@ -636,8 +641,7 @@ class QuantizationTransformPass(object):
"""
"""
return
"%s.scale"
%
(
var_name
)
return
"%s.scale"
%
(
var_name
)
@
staticmethod
def
_is_skip_quant
(
self
,
graph
,
op_node
):
def
_is_skip_quant
(
graph
,
op_node
):
"""
"""
Analyse whether the op node skips quantization.
Analyse whether the op node skips quantization.
"""
"""
...
@@ -650,20 +654,20 @@ class QuantizationTransformPass(object):
...
@@ -650,20 +654,20 @@ class QuantizationTransformPass(object):
if
op_node
.
name
()
in
[
"mul"
,
"matmul"
]
and
\
if
op_node
.
name
()
in
[
"mul"
,
"matmul"
]
and
\
_is_input_all_not_persistable
(
graph
,
op_node
):
_is_input_all_not_persistable
(
graph
,
op_node
):
is_skip
=
True
is_skip
=
True
if
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_without_weight"
:
is_skip
=
True
return
is_skip
return
is_skip
class
QuantizationFreezePass
(
object
):
class
QuantizationFreezePass
(
object
):
_supported_quantizable_op_type
=
\
QuantizationTransformPass
.
_supported_quantizable_op_type
def
__init__
(
self
,
def
__init__
(
self
,
scope
,
scope
,
place
,
place
,
weight_bits
=
8
,
weight_bits
=
8
,
activation_bits
=
8
,
activation_bits
=
8
,
weight_quantize_type
=
'abs_max'
,
weight_quantize_type
=
'abs_max'
,
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
):
quantizable_op_type
=
None
):
"""
"""
The freeze pass is used to adjust the quantize operator order, for example:
The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be frozen into
1) `activation -> quant -> dequant -> conv2d` will be frozen into
...
@@ -679,9 +683,8 @@ class QuantizationFreezePass(object):
...
@@ -679,9 +683,8 @@ class QuantizationFreezePass(object):
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
since weights are fixed once the model is well trained.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
quantizable_op_type(list[str]): This input param will be removed latter. The pass
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
will process all quantized op, so it is not necessary to set the input param.
QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
"""
"""
assert
scope
is
not
None
,
\
assert
scope
is
not
None
,
\
'The scope cannot be set None.'
'The scope cannot be set None.'
...
@@ -692,16 +695,12 @@ class QuantizationFreezePass(object):
...
@@ -692,16 +695,12 @@ class QuantizationFreezePass(object):
self
.
_weight_bits
=
weight_bits
self
.
_weight_bits
=
weight_bits
self
.
_activation_bits
=
activation_bits
self
.
_activation_bits
=
activation_bits
self
.
_weight_quantize_type
=
weight_quantize_type
self
.
_weight_quantize_type
=
weight_quantize_type
self
.
_quantizable_ops
=
quantizable_op_type
for
op
in
self
.
_quantizable_ops
:
assert
op
in
QuantizationFreezePass
.
_supported_quantizable_op_type
,
\
op
+
" is not supported for quantization."
self
.
_conv_ops
=
[
'conv2d'
,
'depthwise_conv2d'
]
self
.
_conv_ops
=
[
'conv2d'
,
'depthwise_conv2d'
]
self
.
_fake_quant_op_names
=
_fake_quant_op_list
self
.
_fake_quant_op_names
=
_fake_quant_op_list
self
.
_fake_dequant_op_names
=
_fake_dequant_op_list
self
.
_fake_dequant_op_names
=
_fake_dequant_op_list
self
.
_op_input_rename_map
=
collections
.
OrderedDict
()
self
.
_op_input_rename_map
=
collections
.
OrderedDict
()
self
.
_op_output_rename_map
=
collections
.
OrderedDict
()
self
.
_op_output_rename_map
=
collections
.
OrderedDict
()
self
.
_var_scale_map
=
collections
.
OrderedDict
()
self
.
_
quant_
var_scale_map
=
collections
.
OrderedDict
()
def
apply
(
self
,
graph
):
def
apply
(
self
,
graph
):
"""
"""
...
@@ -712,6 +711,7 @@ class QuantizationFreezePass(object):
...
@@ -712,6 +711,7 @@ class QuantizationFreezePass(object):
Returns:
Returns:
None
None
"""
"""
# Get input scales in fake quant op and process weights
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
ops
=
graph
.
all_op_nodes
()
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
for
op_node
in
ops
:
...
@@ -733,7 +733,7 @@ class QuantizationFreezePass(object):
...
@@ -733,7 +733,7 @@ class QuantizationFreezePass(object):
else
:
else
:
scale_v
=
self
.
_load_var
(
scale_v
=
self
.
_load_var
(
op_node
.
output
(
'OutScale'
)[
0
])[
0
]
op_node
.
output
(
'OutScale'
)[
0
])[
0
]
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_
quant_
var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
# quantize weight and restore
# quantize weight and restore
param_v
=
self
.
_load_var
(
input_arg_name
)
param_v
=
self
.
_load_var
(
input_arg_name
)
...
@@ -743,32 +743,29 @@ class QuantizationFreezePass(object):
...
@@ -743,32 +743,29 @@ class QuantizationFreezePass(object):
else
:
else
:
scale_v
=
graph
.
_find_node_by_name
(
scale_v
=
graph
.
_find_node_by_name
(
op_node
.
outputs
,
op_node
.
output
(
'OutScale'
)[
0
])
op_node
.
outputs
,
op_node
.
output
(
'OutScale'
)[
0
])
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_
quant_
var_scale_map
[
input_arg_name
]
=
scale_v
# Remove all fake dequant op
ops
=
graph
.
all_op_nodes
()
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
for
op_node
in
ops
:
op_name
=
op_node
.
name
()
op_name
=
op_node
.
name
()
if
op_name
in
self
.
_fake_dequant_op_names
:
if
op_name
in
self
.
_fake_dequant_op_names
:
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
# Insert post dequant op
ops
=
graph
.
all_op_nodes
()
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
for
op_node
in
ops
:
op_name
=
op_node
.
name
()
op_node_desc
=
op_node
.
op
()
if
op_name
in
self
.
_quantizable_ops
:
if
op_node_desc
.
has_attr
(
"quantization_type"
)
and
\
# only process the node that is quantized by QuantizationTransformPass
op_node_desc
.
attr
(
"quantization_type"
)
==
"qat_with_weight"
:
is_op_node_quantized
=
False
if
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
\
for
var_node
in
op_node
.
inputs
:
and
op_node
.
name
()
in
self
.
_conv_ops
:
var_name
=
var_node
.
name
()
self
.
_insert_post_channel_dequant_op
(
graph
,
op_node
)
if
var_name
.
endswith
(
'.dequantized'
):
else
:
is_op_node_quantized
=
True
self
.
_insert_post_dequant_op
(
graph
,
op_node
)
if
is_op_node_quantized
:
if
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
and
op_name
in
self
.
_conv_ops
:
self
.
_insert_post_channel_dequant_op
(
graph
,
op_node
)
else
:
self
.
_insert_post_dequant_op
(
graph
,
op_node
)
# Rename inputs of the followed ops after inserting dequant_op after fc/conv
for
op_node
in
ops
:
for
op_node
in
ops
:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
for
var_node
in
op_node
.
inputs
:
for
var_node
in
op_node
.
inputs
:
if
var_node
.
node
in
self
.
_op_output_rename_map
:
if
var_node
.
node
in
self
.
_op_output_rename_map
:
old_in
=
var_node
old_in
=
var_node
...
@@ -802,7 +799,7 @@ class QuantizationFreezePass(object):
...
@@ -802,7 +799,7 @@ class QuantizationFreezePass(object):
new_in
.
clear_outputs
()
new_in
.
clear_outputs
()
graph
.
update_input_link
(
old_in
,
new_in
,
op_node
)
graph
.
update_input_link
(
old_in
,
new_in
,
op_node
)
original_var_name
=
self
.
_original_var_name
(
name
)
original_var_name
=
self
.
_original_var_name
(
name
)
scale_v
=
self
.
_var_scale_map
[
original_var_name
]
scale_v
=
self
.
_
quant_
var_scale_map
[
original_var_name
]
if
original_var_name
in
persistable_vars
:
if
original_var_name
in
persistable_vars
:
assert
isinstance
(
assert
isinstance
(
scale_v
,
scale_v
,
...
@@ -811,7 +808,7 @@ class QuantizationFreezePass(object):
...
@@ -811,7 +808,7 @@ class QuantizationFreezePass(object):
channel_scale
=
np
.
array
(
scale_v
)
channel_scale
=
np
.
array
(
scale_v
)
else
:
else
:
assert
isinstance
(
scale_v
,
IrNode
)
assert
isinstance
(
scale_v
,
IrNode
)
scale_var_node
=
self
.
_var_scale_map
[
original_var_name
]
scale_var_node
=
self
.
_
quant_
var_scale_map
[
original_var_name
]
if
len
(
op_node
.
output_arg_names
())
!=
1
:
if
len
(
op_node
.
output_arg_names
())
!=
1
:
raise
ValueError
(
"Only support one output, but op %s has"
raise
ValueError
(
"Only support one output, but op %s has"
...
@@ -867,7 +864,7 @@ class QuantizationFreezePass(object):
...
@@ -867,7 +864,7 @@ class QuantizationFreezePass(object):
new_in
.
clear_outputs
()
new_in
.
clear_outputs
()
graph
.
update_input_link
(
old_in
,
new_in
,
op_node
)
graph
.
update_input_link
(
old_in
,
new_in
,
op_node
)
original_var_name
=
self
.
_original_var_name
(
name
)
original_var_name
=
self
.
_original_var_name
(
name
)
scale_v
=
self
.
_var_scale_map
[
original_var_name
]
scale_v
=
self
.
_
quant_
var_scale_map
[
original_var_name
]
if
original_var_name
in
persistable_vars
:
if
original_var_name
in
persistable_vars
:
assert
self
.
_is_float
(
assert
self
.
_is_float
(
scale_v
),
'The scale of parameter %s is not a float.'
%
(
scale_v
),
'The scale of parameter %s is not a float.'
%
(
...
@@ -876,7 +873,7 @@ class QuantizationFreezePass(object):
...
@@ -876,7 +873,7 @@ class QuantizationFreezePass(object):
else
:
else
:
max_range
*=
act_range
max_range
*=
act_range
assert
isinstance
(
scale_v
,
IrNode
)
assert
isinstance
(
scale_v
,
IrNode
)
scale_var_node
=
self
.
_var_scale_map
[
original_var_name
]
scale_var_node
=
self
.
_
quant_
var_scale_map
[
original_var_name
]
if
len
(
op_node
.
output_arg_names
())
!=
1
:
if
len
(
op_node
.
output_arg_names
())
!=
1
:
raise
ValueError
(
"Only support one output, but op %s has"
raise
ValueError
(
"Only support one output, but op %s has"
...
@@ -963,13 +960,7 @@ class QuantizationFreezePass(object):
...
@@ -963,13 +960,7 @@ class QuantizationFreezePass(object):
class
ConvertToInt8Pass
(
object
):
class
ConvertToInt8Pass
(
object
):
_supported_quantizable_op_type
=
\
def
__init__
(
self
,
scope
,
place
,
quantizable_op_type
=
None
):
QuantizationTransformPass
.
_supported_quantizable_op_type
def
__init__
(
self
,
scope
,
place
,
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]):
"""
"""
Convert the weights into int8_t type.
Convert the weights into int8_t type.
...
@@ -977,9 +968,8 @@ class ConvertToInt8Pass(object):
...
@@ -977,9 +968,8 @@ class ConvertToInt8Pass(object):
scope(fluid.Scope): scope is used to get the weight tensor values.
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
8bits weight tensors.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
quantizable_op_type(list[str]): This input param will be removed latter. The pass
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
will process all quantized op, so it is not necessary to set the input param.
QuantizationTransformPass and QuantizationFreezePass must be the same as this.
"""
"""
assert
scope
is
not
None
,
\
assert
scope
is
not
None
,
\
'The scope cannot be set None.'
'The scope cannot be set None.'
...
@@ -987,10 +977,6 @@ class ConvertToInt8Pass(object):
...
@@ -987,10 +977,6 @@ class ConvertToInt8Pass(object):
'The place cannot be set None.'
'The place cannot be set None.'
self
.
_scope
=
scope
self
.
_scope
=
scope
self
.
_place
=
place
self
.
_place
=
place
self
.
_quantizable_ops
=
quantizable_op_type
for
op
in
self
.
_quantizable_ops
:
assert
op
in
ConvertToInt8Pass
.
_supported_quantizable_op_type
,
\
op
+
" is not supported for quantization."
def
apply
(
self
,
graph
):
def
apply
(
self
,
graph
):
"""
"""
...
@@ -1006,10 +992,8 @@ class ConvertToInt8Pass(object):
...
@@ -1006,10 +992,8 @@ class ConvertToInt8Pass(object):
ops
=
graph
.
all_op_nodes
()
ops
=
graph
.
all_op_nodes
()
input_map
=
{}
input_map
=
{}
for
op_node
in
ops
:
for
op_node
in
ops
:
op_name
=
op_node
.
name
()
if
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
if
op_name
in
self
.
_quantizable_ops
:
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_with_weight"
:
if
QuantizationTransformPass
.
_is_skip_quant
(
graph
,
op_node
):
continue
for
var_node
in
op_node
.
inputs
:
for
var_node
in
op_node
.
inputs
:
name
=
var_node
.
name
()
name
=
var_node
.
name
()
if
name
in
persistable_vars
:
if
name
in
persistable_vars
:
...
@@ -1259,9 +1243,9 @@ class AddQuantDequantPass(object):
...
@@ -1259,9 +1243,9 @@ class AddQuantDequantPass(object):
"equal"
,
"gather"
,
"greater_equal"
,
"greater_than"
,
"less_equal"
,
"equal"
,
"gather"
,
"greater_equal"
,
"greater_than"
,
"less_equal"
,
"less_than"
,
"mean"
,
"not_equal"
,
"reshape"
,
"reshape2"
,
"less_than"
,
"mean"
,
"not_equal"
,
"reshape"
,
"reshape2"
,
"bilinear_interp"
,
"nearest_interp"
,
"trilinear_interp"
,
"slice"
,
"bilinear_interp"
,
"nearest_interp"
,
"trilinear_interp"
,
"slice"
,
"squeeze"
,
"elementwise_sub"
,
"mul"
,
"matmul"
"squeeze"
,
"elementwise_sub"
,
"mul"
,
"matmul"
,
"relu"
,
"relu6"
,
"leaky_relu"
,
"tanh"
,
"swish"
]
]
_activation_type
=
[
"relu"
,
"relu6"
,
"leaky_relu"
,
"tanh"
,
"swish"
]
def
__init__
(
self
,
def
__init__
(
self
,
scope
=
None
,
scope
=
None
,
...
@@ -1307,8 +1291,7 @@ class AddQuantDequantPass(object):
...
@@ -1307,8 +1291,7 @@ class AddQuantDequantPass(object):
else
:
else
:
self
.
_quantizable_op_type
=
quantizable_op_type
self
.
_quantizable_op_type
=
quantizable_op_type
for
op_type
in
quantizable_op_type
:
for
op_type
in
quantizable_op_type
:
assert
op_type
in
AddQuantDequantPass
.
_supported_quantizable_op_type
+
\
assert
op_type
in
AddQuantDequantPass
.
_supported_quantizable_op_type
,
\
AddQuantDequantPass
.
_activation_type
,
\
op_type
+
" is not supported for quantization."
op_type
+
" is not supported for quantization."
self
.
_quantizable_grad_op_type
=
[
self
.
_quantizable_grad_op_type
=
[
'%s_grad'
%
(
op
)
for
op
in
self
.
_quantizable_op_type
'%s_grad'
%
(
op
)
for
op
in
self
.
_quantizable_op_type
...
@@ -1343,17 +1326,15 @@ class AddQuantDequantPass(object):
...
@@ -1343,17 +1326,15 @@ class AddQuantDequantPass(object):
elif
isinstance
(
self
.
_skip_pattern
,
str
):
elif
isinstance
(
self
.
_skip_pattern
,
str
):
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
is_skip
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
is_quantized
=
op_node
.
op
().
has_attr
(
"quantization_type"
)
and
\
is_op_node_quantized
=
False
op_node
.
op
().
attr
(
"quantization_type"
)
==
"qat_with_weight"
for
var_node
in
op_node
.
inputs
:
if
is_skip
or
is_quantized
or
\
var_name
=
var_node
.
name
()
if
var_name
.
endswith
(
'.dequantized'
):
is_op_node_quantized
=
True
if
is_skip
or
is_op_node_quantized
or
\
(
not
_is_input_all_not_persistable
(
graph
,
op_node
)):
(
not
_is_input_all_not_persistable
(
graph
,
op_node
)):
continue
continue
op_node
.
op
().
_set_attr
(
"quantization_type"
,
"qat_without_weight"
)
op_node
.
op
().
_set_attr
(
"activation_bits"
,
self
.
_quant_bits
)
input_name_list
=
_op_real_in_out_name
[
op_node
.
name
()][
0
]
input_name_list
=
_op_real_in_out_name
[
op_node
.
name
()][
0
]
arg_names
=
[]
arg_names
=
[]
for
input_name
in
input_name_list
:
for
input_name
in
input_name_list
:
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
浏览文件 @
589cd878
...
@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq
.
save_quantized_model
(
self
.
int8_model
)
ptq
.
save_quantized_model
(
self
.
int8_model
)
def
run_test
(
self
,
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
def
run_test
(
self
,
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
):
is_full_quantize
,
is_use_cache_file
,
diff_threshold
):
infer_iterations
=
self
.
infer_iterations
infer_iterations
=
self
.
infer_iterations
batch_size
=
self
.
batch_size
batch_size
=
self
.
batch_size
sample_iterations
=
self
.
sample_iterations
sample_iterations
=
self
.
sample_iterations
...
@@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
delta_value
=
fp32_acc1
-
int8_acc1
delta_value
=
fp32_acc1
-
int8_acc1
self
.
assertLess
(
delta_value
,
0.025
)
self
.
assertLess
(
delta_value
,
diff_threshold
)
class
TestPostTrainingForMobilenetv1
(
TestPostTrainingQuantization
):
class
TestPostTraining
KL
ForMobilenetv1
(
TestPostTrainingQuantization
):
def
test_post_training_mobilenetv1
(
self
):
def
test_post_training_
kl_
mobilenetv1
(
self
):
model
=
"MobileNet-V1"
model
=
"MobileNet-V1"
algo
=
"KL"
algo
=
"KL"
data_urls
=
[
data_urls
=
[
...
@@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
...
@@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
quantizable_op_type
=
[
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
,
"pool2d"
,
"elementwise_add"
"conv2d"
,
"depthwise_conv2d"
,
"mul"
,
"pool2d"
,
"elementwise_add"
]
]
is_full_quantize
=
Tru
e
is_full_quantize
=
Fals
e
is_use_cache_file
=
False
is_use_cache_file
=
False
diff_threshold
=
0.025
self
.
run_test
(
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
self
.
run_test
(
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
)
is_full_quantize
,
is_use_cache_file
,
diff_threshold
)
class
TestPostTrainingAbsMaxForMobilenetv1
(
TestPostTrainingQuantization
):
def
test_post_training_abs_max_mobilenetv1
(
self
):
model
=
"MobileNet-V1"
algo
=
"abs_max"
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s
=
[
'13892b0716d26443a8cdea15b3c6438b'
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
,
"pool2d"
,
"elementwise_add"
]
is_full_quantize
=
False
is_use_cache_file
=
False
diff_threshold
=
0.05
self
.
run_test
(
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
diff_threshold
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
浏览文件 @
589cd878
...
@@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantiza
...
@@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantiza
class
TestPostTrainingForResnet50
(
TestPostTrainingQuantization
):
class
TestPostTrainingForResnet50
(
TestPostTrainingQuantization
):
def
test_post_training_resnet50
(
self
):
def
test_post_training_resnet50
(
self
):
model
=
"ResNet-50"
model
=
"ResNet-50"
algo
=
"
direct
"
algo
=
"
min_max
"
data_urls
=
[
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
]
...
@@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
...
@@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
quantizable_op_type
=
[
"conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
diff_threshold
=
0.025
self
.
run_test
(
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
self
.
run_test
(
model
,
algo
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
)
is_full_quantize
,
is_use_cache_file
,
diff_threshold
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录