BaiXuePrincess / Paddle
Forked from PaddlePaddle / Paddle (in sync with the fork source)
Commit 926227c8
Authored Feb 13, 2020 by cc
Committed by GitHub on Feb 13, 2020
Post_training_quantization support set quant 8/16 bits (#22492) (#22577)
Post_training_quantization support set quant 8/16 bits
Parent: baec7a35
Showing 2 changed files with 37 additions and 28 deletions (+37 -28)
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py (+29 -20)
python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py (+8 -8)
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -54,17 +54,19 @@ def _set_variable_data(scope, place, var_name, np_value):
 class PostTrainingQuantization(object):
     def __init__(self,
-                 executor,
-                 sample_generator,
-                 model_dir,
+                 executor=None,
+                 scope=None,
+                 model_dir=None,
                  model_filename=None,
                  params_filename=None,
+                 sample_generator=None,
                  batch_size=10,
                  batch_nums=None,
-                 scope=None,
                  algo="KL",
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                  is_full_quantize=False,
+                 weight_bits=8,
+                 activation_bits=8,
                  is_use_cache_file=False,
                  cache_dir="./temp_post_training"):
         '''
@@ -76,9 +78,8 @@ class PostTrainingQuantization(object):
         Args:
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
-            sample_generator(Python Generator): The sample generator provides
-                calibrate data for DataLoader, and it only returns a sample every
-                time.
+            scope(fluid.Scope, optional): The scope of the program, use it to load
+                and save variables. If scope=None, get scope by global_scope().
             model_dir(str): The path of the fp32 model that will be quantized,
                 and the model and params files are under the path.
             model_filename(str, optional): The name of file to load the inference
@@ -88,12 +89,13 @@ class PostTrainingQuantization(object):
                 When all parameters were saved in a single binary file, set it
                 as the real filename. If parameters were saved in separate files,
                 set it as 'None'. Default is 'None'.
+            sample_generator(Python Generator): The sample generator provides
+                calibrate data for DataLoader, and it only returns a sample every
+                time.
             batch_size(int, optional): The batch size of DataLoader. Default is 10.
             batch_nums(int, optional): If batch_nums is not None, the number of
                 calibrate data is batch_size*batch_nums. If batch_nums is None, use
                 all data provided by sample_generator as calibrate data.
-            scope(fluid.Scope, optional): The scope of the program, use it to load
-                and save variables. If scope=None, get scope by global_scope().
             algo(str, optional): If algo=KL, use KL-divergenc method to
                 get the more precise scale factor. If algo='direct', use
                 abs_max methon to get the scale factor. Default is KL.
@@ -104,6 +106,8 @@ class PostTrainingQuantization(object):
                 apply quantization to all supported quantizable op type. If set
                 is_full_quantized as False, only apply quantization to the op type
                 according to the input quantizable_op_type.
+            weight_bits(int, optional): quantization bit number for weights.
+            activation_bits(int): quantization bit number for activation.
             is_use_cache_file(bool, optional): If set is_use_cache_file as False,
                 all temp data will be saved in memory. If set is_use_cache_file as True,
                 it will save temp data to disk. When the fp32 model is complex or
@@ -150,14 +154,20 @@ class PostTrainingQuantization(object):
             ptq.quantize()
             ptq.save_quantized_model(save_model_path)
         '''
+
+        assert executor is not None, "The executor cannot be None."
+        assert model_dir is not None, "The model_dir cannot be None."
+        assert sample_generator is not None, \
+            "The sample_generator cannot be None."
+
         self._executor = executor
-        self._sample_generator = sample_generator
+        self._scope = global_scope() if scope == None else scope
         self._model_dir = model_dir
         self._model_filename = model_filename
         self._params_filename = params_filename
+        self._sample_generator = sample_generator
         self._batch_size = batch_size
         self._batch_nums = batch_nums
-        self._scope = global_scope() if scope == None else scope
         self._algo = algo
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
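Review note: the constructor hunks above reorder the leading parameters and give all of them a `None` default, so a caller that still passes `executor, sample_generator, model_dir` positionally would bind the generator to `scope`. The sketch below illustrates that with a hypothetical stand-in class (`PostTrainingQuantizationStub`, `fake_samples` and the `./fp32_model` path are illustration names, not part of Paddle); it copies only the new argument order and the three asserts from the diff and does no quantization.

```python
class PostTrainingQuantizationStub(object):
    """Hypothetical stand-in: mirrors only the new signature and asserts."""

    def __init__(self,
                 executor=None,
                 scope=None,
                 model_dir=None,
                 model_filename=None,
                 params_filename=None,
                 sample_generator=None,
                 batch_size=10,
                 batch_nums=None,
                 algo="KL",
                 quantizable_op_type=("conv2d", "depthwise_conv2d", "mul"),
                 is_full_quantize=False,
                 weight_bits=8,
                 activation_bits=8,
                 is_use_cache_file=False,
                 cache_dir="./temp_post_training"):
        # Same three checks that the commit adds to the real constructor.
        assert executor is not None, "The executor cannot be None."
        assert model_dir is not None, "The model_dir cannot be None."
        assert sample_generator is not None, \
            "The sample_generator cannot be None."
        self.scope = scope
        self.model_dir = model_dir
        self.sample_generator = sample_generator
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits


def fake_samples():
    # Placeholder generator; a real one would yield calibration samples.
    yield [0.0]


exe = object()  # placeholder for a fluid.Executor

# Old-style positional call: (executor, sample_generator, model_dir).
# With the new order the generator is bound to `scope`, so
# `sample_generator` stays None and the third assert fires.
try:
    PostTrainingQuantizationStub(exe, fake_samples, "./fp32_model")
    positional_call_failed = False
    positional_error = ""
except AssertionError as err:
    positional_call_failed = True
    positional_error = str(err)

# Keyword call: unaffected by the reordering.
ptq = PostTrainingQuantizationStub(
    executor=exe,
    sample_generator=fake_samples,
    model_dir="./fp32_model",
    weight_bits=8,
    activation_bits=8)
```

Passing the arguments by keyword, as the second call does, is the form that works both before and after this change.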
@@ -604,7 +614,7 @@ class WeightQuantization(object):
                                save_model_filename=None,
                                save_params_filename=None,
                                quantizable_op_type=["conv2d", "mul"],
-                               quantize_weight_bits=8,
+                               weight_bits=8,
                                threshold_rate=0.0):
         '''
         In order to reduce the size of model, this api quantizes the weight
@@ -624,8 +634,8 @@ class WeightQuantization(object):
                 that will be quantized, and the quantized ops should be
                 contained in ["conv2d", "depthwise_conv2d", "mul"].
                 Default is ["conv2d","mul"].
-            quantize_weight_bits(int, optional): The bits for the quantized
-                weight, and it should be 8 or 16. Default is 8.
+            weight_bits(int, optional): The bits for the quantized weight,
+                and it should be 8 or 16. Default is 8.
             threshold_rate(float, optional): This api uses abs_max methd to
                 quantize the weight from float32 to int8/16, and the abs max
                 value is important for quantization diff. When the abs_max
@@ -637,10 +647,10 @@ class WeightQuantization(object):
             assert op_type in self._supported_quantizable_op_type, \
                 "input error:" + op_type + \
                 " is not supported for weight quantization."
-        assert quantize_weight_bits in [8, 16], \
-            "input error: quantize_weight_bits should be 8 or 16."
-        quantize_range = (1 << (quantize_weight_bits - 1)) - 1
-        save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16
+        assert weight_bits in [8, 16], \
+            "input error: weight_bits should be 8 or 16."
+        quantize_range = (1 << (weight_bits - 1)) - 1
+        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
 
         place = core.CPUPlace()
         exe = Executor(place)
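Review note: the `quantize_range` and `save_weight_dtype` expressions in this hunk can be tried in isolation. The following is a minimal NumPy sketch of abs_max weight quantization, not the Paddle implementation: the helper names are hypothetical, the scale is assumed to be the plain abs max of the whole tensor, and `threshold_rate` handling is left out because its details are elided in the diff.

```python
import numpy as np


def quantize_weight_abs_max(weight, weight_bits=8):
    """Quantize a float32 array to int8/int16 with one abs_max scale."""
    assert weight_bits in [8, 16], \
        "input error: weight_bits should be 8 or 16."
    # Same expressions as in the diff: 127 for 8 bits, 32767 for 16 bits.
    quantize_range = (1 << (weight_bits - 1)) - 1
    save_weight_dtype = np.int8 if weight_bits == 8 else np.int16

    # Assumption: a single scale equal to the largest absolute value.
    scale = float(np.max(np.abs(weight)))
    if scale == 0.0:
        # All-zero tensor: nothing to scale.
        return np.zeros(weight.shape, dtype=save_weight_dtype), scale
    quantized = np.round(weight / scale * quantize_range)
    return quantized.astype(save_weight_dtype), scale


def dequantize_weight(quantized, scale, weight_bits=8):
    """Map the integers back to float32 using the stored scale."""
    quantize_range = (1 << (weight_bits - 1)) - 1
    return quantized.astype(np.float32) * scale / quantize_range


weight = np.array([-1.0, -0.5, 0.0, 0.25, 1.0], dtype=np.float32)

q8, scale8 = quantize_weight_abs_max(weight, weight_bits=8)
q16, scale16 = quantize_weight_abs_max(weight, weight_bits=16)

# Round-trip error for each bit width.
err8 = np.max(np.abs(dequantize_weight(q8, scale8, 8) - weight))
err16 = np.max(np.abs(dequantize_weight(q16, scale16, 16) - weight))
```

In this sketch the 16-bit round trip leaves a smaller maximum error than the 8-bit one, at the cost of storing each weight in two bytes instead of one.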
@@ -677,8 +687,7 @@ class WeightQuantization(object):
                         _set_variable_data(scope, place, var_name,
                                            quantized_var_tensor_data)
                         op._set_attr(var_name + "_quant_scale", [scale])
-                        op._set_attr('quantize_weight_bits',
-                                     quantize_weight_bits)
+                        op._set_attr('quantize_weight_bits', weight_bits)
 
         io.save_inference_model(
             dirname=save_model_dir,
python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
@@ -42,8 +42,8 @@ class TestWeightQuantization(unittest.TestCase):
                                                           zip_path)
             os.system(cmd)
 
-    def run_test(self, model_name, model_data_url, model_data_md5,
-                 quantize_weight_bits, quantizable_op_type, threshold_rate):
+    def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
+                 quantizable_op_type, threshold_rate):
 
         model_dir = self.download_model(model_name, model_data_url,
                                         model_data_md5)
@@ -51,11 +51,11 @@ class TestWeightQuantization(unittest.TestCase):
         timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
         save_model_dir = os.path.join(
             os.getcwd(),
-            model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)
+            model_name + "_wq_" + str(weight_bits) + "_" + timestamp)
         weight_quant = WeightQuantization(model_dir=model_dir + "/model")
         weight_quant.quantize_weight_to_int(
             save_model_dir=save_model_dir,
-            quantize_weight_bits=quantize_weight_bits,
+            weight_bits=weight_bits,
             quantizable_op_type=quantizable_op_type,
             threshold_rate=threshold_rate)
         print("finish weight quantization for " + model_name + "\n")
@@ -73,18 +73,18 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
     model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
 
     def test_weight_quantization_mobilenetv1_8bit(self):
-        quantize_weight_bits = 8
+        weight_bits = 8
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         threshold_rate = 0.0
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      quantize_weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_16bit(self):
-        quantize_weight_bits = 16
+        weight_bits = 16
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         threshold_rate = 1e-9
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      quantize_weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, threshold_rate)
 
 
 if __name__ == '__main__':