PaddlePaddle / Paddle
Commit 926227c8
Authored Feb 13, 2020 by cc; committed by GitHub on Feb 13, 2020 (unverified)
Post_training_quantization support set quant 8/16 bits (#22492) (#22577)
Post_training_quantization support set quant 8/16 bits
Parent: baec7a35
Showing 2 changed files with 37 additions and 28 deletions (+37 -28)
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py (+29 -20)
python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py (+8 -8)
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -54,17 +54,19 @@ def _set_variable_data(scope, place, var_name, np_value):
 class PostTrainingQuantization(object):
     def __init__(self,
-                 executor,
-                 sample_generator,
-                 model_dir,
+                 executor=None,
+                 scope=None,
+                 model_dir=None,
                  model_filename=None,
                  params_filename=None,
+                 sample_generator=None,
                  batch_size=10,
                  batch_nums=None,
-                 scope=None,
                  algo="KL",
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                  is_full_quantize=False,
+                 weight_bits=8,
+                 activation_bits=8,
                  is_use_cache_file=False,
                  cache_dir="./temp_post_training"):
         '''
@@ -76,9 +78,8 @@ class PostTrainingQuantization(object):
         Args:
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
-            sample_generator(Python Generator): The sample generator provides
-                calibrate data for DataLoader, and it only returns a sample every
-                time.
+            scope(fluid.Scope, optional): The scope of the program, use it to load
+                and save variables. If scope=None, get scope by global_scope().
             model_dir(str): The path of the fp32 model that will be quantized,
                 and the model and params files are under the path.
             model_filename(str, optional): The name of file to load the inference
@@ -88,12 +89,13 @@ class PostTrainingQuantization(object):
                 When all parameters were saved in a single binary file, set it
                 as the real filename. If parameters were saved in separate files,
                 set it as 'None'. Default is 'None'.
+            sample_generator(Python Generator): The sample generator provides
+                calibrate data for DataLoader, and it only returns a sample every
+                time.
             batch_size(int, optional): The batch size of DataLoader. Default is 10.
             batch_nums(int, optional): If batch_nums is not None, the number of
                 calibrate data is batch_size*batch_nums. If batch_nums is None, use
                 all data provided by sample_generator as calibrate data.
-            scope(fluid.Scope, optional): The scope of the program, use it to load
-                and save variables. If scope=None, get scope by global_scope().
             algo(str, optional): If algo=KL, use KL-divergenc method to
                 get the more precise scale factor. If algo='direct', use
                 abs_max methon to get the scale factor. Default is KL.
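The docstring in this hunk says the sample generator returns one sample at a time, and that `batch_size*batch_nums` samples are used for calibration when `batch_nums` is set. The following is a minimal sketch of that bookkeeping in plain Python with no Paddle dependency. `toy_sample_generator` and `take_calibration_batches` are hypothetical names used only for illustration; they are not part of the Paddle API. Whether the real DataLoader keeps or drops a trailing partial batch is not shown in this diff, so the sketch simply keeps it.

```python
import itertools


def toy_sample_generator(n):
    """Hypothetical generator: yields exactly one sample per iteration."""
    for i in range(n):
        yield [float(i)]


def take_calibration_batches(sample_generator, batch_size=10, batch_nums=None):
    """Group single samples into batches.

    If batch_nums is not None, stop after batch_nums batches, so at most
    batch_size * batch_nums samples are consumed. If batch_nums is None,
    consume everything the generator provides.
    """
    it = iter(sample_generator)
    batches = []
    while batch_nums is None or len(batches) < batch_nums:
        batch = list(itertools.islice(it, batch_size))
        if not batch:  # generator exhausted
            break
        batches.append(batch)
    return batches


# With batch_nums=3 only 3 * 10 samples are used, even though 95 exist.
limited = take_calibration_batches(toy_sample_generator(95), batch_size=10, batch_nums=3)
used = sum(len(b) for b in limited)

# With batch_nums=None every sample is used (assumption: partial batch kept).
everything = take_calibration_batches(toy_sample_generator(95), batch_size=10)
```
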
@@ -104,6 +106,8 @@ class PostTrainingQuantization(object):
                 apply quantization to all supported quantizable op type. If set
                 is_full_quantized as False, only apply quantization to the op type
                 according to the input quantizable_op_type.
+            weight_bits(int, optional): quantization bit number for weights.
+            activation_bits(int): quantization bit number for activation.
             is_use_cache_file(bool, optional): If set is_use_cache_file as False,
                 all temp data will be saved in memory. If set is_use_cache_file as True,
                 it will save temp data to disk. When the fp32 model is complex or
@@ -150,14 +154,20 @@ class PostTrainingQuantization(object):
             ptq.quantize()
             ptq.save_quantized_model(save_model_path)
         '''
+        assert executor is not None, "The executor cannot be None."
+        assert model_dir is not None, "The model_dir cannot be None."
+        assert sample_generator is not None, \
+            "The sample_generator cannot be None."
         self._executor = executor
-        self._sample_generator = sample_generator
+        self._scope = global_scope() if scope == None else scope
         self._model_dir = model_dir
         self._model_filename = model_filename
         self._params_filename = params_filename
+        self._sample_generator = sample_generator
         self._batch_size = batch_size
         self._batch_nums = batch_nums
-        self._scope = global_scope() if scope == None else scope
         self._algo = algo
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
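This commit reorders the constructor parameters: `scope` moves to second position and `sample_generator` moves after `params_filename`. All three formerly required arguments become keyword defaults guarded by asserts. The sketch below shows what that means for callers, using a hypothetical stand-in function `new_style_init` that copies only the parameter order and the three asserts from the diff. It is not the real class, and the string arguments are placeholders.

```python
def new_style_init(executor=None, scope=None, model_dir=None,
                   model_filename=None, params_filename=None,
                   sample_generator=None, batch_size=10, batch_nums=None):
    """Stand-in mirroring the new parameter order and the added asserts."""
    assert executor is not None, "The executor cannot be None."
    assert model_dir is not None, "The model_dir cannot be None."
    assert sample_generator is not None, \
        "The sample_generator cannot be None."
    return {"executor": executor, "scope": scope, "model_dir": model_dir,
            "sample_generator": sample_generator}


# An old-style positional call (executor, sample_generator, model_dir) now
# binds the generator to `scope` and leaves `sample_generator` unset.
try:
    new_style_init("exe", "gen", "./fp32_model")
    positional_error = None
except AssertionError as err:
    positional_error = str(err)

# Keyword arguments are unaffected by the reordering.
keyword_result = new_style_init(executor="exe",
                                sample_generator="gen",
                                model_dir="./fp32_model")
```

Passing arguments by keyword keeps call sites independent of the parameter order. Note that `assert` statements are skipped when Python runs with `-O`, so these checks are a development aid rather than a hard guarantee.
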
@@ -604,7 +614,7 @@ class WeightQuantization(object):
                                save_model_filename=None,
                                save_params_filename=None,
                                quantizable_op_type=["conv2d", "mul"],
-                               quantize_weight_bits=8,
+                               weight_bits=8,
                                threshold_rate=0.0):
         '''
         In order to reduce the size of model, this api quantizes the weight
@@ -624,8 +634,8 @@ class WeightQuantization(object):
                 that will be quantized, and the quantized ops should be
                 contained in ["conv2d", "depthwise_conv2d", "mul"].
                 Default is ["conv2d","mul"].
-            quantize_weight_bits(int, optional): The bits for the quantized
-                weight, and it should be 8 or 16. Default is 8.
+            weight_bits(int, optional): The bits for the quantized weight,
+                and it should be 8 or 16. Default is 8.
             threshold_rate(float, optional): This api uses abs_max methd to
                 quantize the weight from float32 to int8/16, and the abs max
                 value is important for quantization diff. When the abs_max
@@ -637,10 +647,10 @@ class WeightQuantization(object):
             assert op_type in self._supported_quantizable_op_type, \
                 "input error:" + op_type + \
                 " is not supported for weight quantization."
-        assert quantize_weight_bits in [8, 16], \
-            "input error: quantize_weight_bits should be 8 or 16."
-        quantize_range = (1 << (quantize_weight_bits - 1)) - 1
-        save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16
+        assert weight_bits in [8, 16], \
+            "input error: weight_bits should be 8 or 16."
+        quantize_range = (1 << (weight_bits - 1)) - 1
+        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
         place = core.CPUPlace()
         exe = Executor(place)
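The hunk above derives the integer range and storage dtype from `weight_bits`: 127 with `np.int8` for 8 bits, 32767 with `np.int16` for 16 bits. The docstring elsewhere in this file says the weights are quantized with an abs_max method. A minimal NumPy sketch of how those pieces could fit together follows. The helper name `quantize_weight_abs_max` and the scale convention `abs_max / quantize_range` are assumptions for illustration, since the diff does not show the body of the quantization loop. The `threshold_rate` clipping step is deliberately left out.

```python
import numpy as np


def quantize_weight_abs_max(weight, weight_bits=8):
    """Hypothetical helper: symmetric abs_max quantization of one tensor."""
    assert weight_bits in [8, 16], \
        "input error: weight_bits should be 8 or 16."
    # Same two expressions as in the diff.
    quantize_range = (1 << (weight_bits - 1)) - 1
    save_weight_dtype = np.int8 if weight_bits == 8 else np.int16

    abs_max = float(np.max(np.abs(weight)))
    if abs_max == 0.0:
        # All-zero tensor: nothing to scale.
        return np.zeros(weight.shape, dtype=save_weight_dtype), 0.0

    # Assumed convention: one float step per integer level.
    scale = abs_max / quantize_range
    quantized = np.round(weight / scale).astype(save_weight_dtype)
    return quantized, scale


w = np.array([-1.0, 0.5, 0.25], dtype=np.float32)
q8, s8 = quantize_weight_abs_max(w, weight_bits=8)
q16, s16 = quantize_weight_abs_max(w, weight_bits=16)

# Dequantize to see how much precision each bit width preserves.
err8 = float(np.max(np.abs(q8.astype(np.float32) * s8 - w)))
err16 = float(np.max(np.abs(q16.astype(np.float32) * s16 - w)))
```

Under these assumptions the largest-magnitude weight maps to the edge of the integer range, and the 16-bit variant has a smaller reconstruction error at twice the storage cost.
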
@@ -677,8 +687,7 @@ class WeightQuantization(object):
                     _set_variable_data(scope, place, var_name,
                                        quantized_var_tensor_data)
                     op._set_attr(var_name + "_quant_scale", [scale])
-                    op._set_attr('quantize_weight_bits',
-                                 quantize_weight_bits)
+                    op._set_attr('quantize_weight_bits', weight_bits)
         io.save_inference_model(
             dirname=save_model_dir,
python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
@@ -42,8 +42,8 @@ class TestWeightQuantization(unittest.TestCase):
             zip_path)
         os.system(cmd)
-    def run_test(self, model_name, model_data_url, model_data_md5,
-                 quantize_weight_bits, quantizable_op_type, threshold_rate):
+    def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
+                 quantizable_op_type, threshold_rate):
         model_dir = self.download_model(model_name, model_data_url,
                                         model_data_md5)
@@ -51,11 +51,11 @@ class TestWeightQuantization(unittest.TestCase):
         timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
         save_model_dir = os.path.join(
             os.getcwd(),
-            model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)
+            model_name + "_wq_" + str(weight_bits) + "_" + timestamp)
         weight_quant = WeightQuantization(model_dir=model_dir + "/model")
         weight_quant.quantize_weight_to_int(
             save_model_dir=save_model_dir,
-            quantize_weight_bits=quantize_weight_bits,
+            weight_bits=weight_bits,
             quantizable_op_type=quantizable_op_type,
             threshold_rate=threshold_rate)
         print("finish weight quantization for " + model_name + "\n")
@@ -73,18 +73,18 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
     model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
     def test_weight_quantization_mobilenetv1_8bit(self):
-        quantize_weight_bits = 8
+        weight_bits = 8
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         threshold_rate = 0.0
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      quantize_weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, threshold_rate)
     def test_weight_quantization_mobilenetv1_16bit(self):
-        quantize_weight_bits = 16
+        weight_bits = 16
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         threshold_rate = 1e-9
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      quantize_weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, threshold_rate)
 if __name__ == '__main__':