Commit e7a02b5c (unverified)
Authored Apr 16, 2021 by XGZhang; committed via GitHub on Apr 16, 2021.
changed post-quant methods (#713)
Parent: 8fad8d41

Showing 3 changed files with 24 additions and 10 deletions (+24 -10):
- demo/quant/quant_post/README.md (+3 -3)
- demo/quant/quant_post/quant_post.py (+7 -3)
- paddleslim/quant/quanter.py (+14 -4)
demo/quant/quant_post/README.md
````diff
@@ -43,7 +43,7 @@ python quant_post_static.py --model_path ./inference_model/MobileNet --save_path
 After running the command above, the quantized model and parameter files can be found under ``${save_path}``.
-> The quantization algorithm used is ``'KL'``; 160 images from the training set are used to calibrate the quantization parameters.
+> The quantization algorithm used is ``'hist'``; 32 images from the training set are used to calibrate the quantization parameters.
 ### Test accuracy
````
````diff
@@ -67,6 +67,6 @@ python eval.py --model_path ./quant_model_train/MobileNet --model_name __model__
 The accuracy output is
 ```
-top1_acc/top5_acc= [0.70141864 0.89086477]
+top1_acc/top5_acc= [0.70328485 0.89183184]
 ```
-From the accuracy comparison above, after post-training quantization of the ``mobilenet`` classification model on ``imagenet``, the ``top1`` accuracy loss is ``0.77%`` and the ``top5`` accuracy loss is ``0.46%``.
+From the accuracy comparison above, after post-training quantization of the ``mobilenet`` classification model on ``imagenet``, the ``top1`` accuracy loss is ``0.59%`` and the ``top5`` accuracy loss is ``0.36%``.
````
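The accuracy figures in the diff above can be reduced to a quick delta check. A minimal sketch, using only the top1/top5 numbers quoted in the diff (the helper name is ours):

```python
# Top1/top5 accuracy from the diff: 'KL' calibration (160 images) vs 'hist' (32 images).
kl_top1, kl_top5 = 0.70141864, 0.89086477
hist_top1, hist_top5 = 0.70328485, 0.89183184

def gain(new, old, ndigits=8):
    """Absolute accuracy gain of the new calibration over the old one."""
    return round(new - old, ndigits)

top1_gain = gain(hist_top1, kl_top1)
top5_gain = gain(hist_top5, kl_top5)
print(top1_gain, top5_gain)  # both positive: 'hist' improves top1 and top5 here
```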
demo/quant/quant_post/quant_post.py
```diff
@@ -19,13 +19,15 @@ _logger = get_logger(__name__, level=logging.INFO)
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('batch_size',       int,  16,                 "Minibatch size.")
-add_arg('batch_num',        int,  10,                 "Batch number")
+add_arg('batch_size',       int,  32,                 "Minibatch size.")
+add_arg('batch_num',        int,  1,                  "Batch number")
 add_arg('use_gpu',          bool, True,               "Whether to use GPU or not.")
 add_arg('model_path',       str,  "./inference_model/MobileNet/",  "model dir")
 add_arg('save_path',        str,  "./quant_model/MobileNet/",  "model dir to save quanted model")
 add_arg('model_filename',   str,  None,               "model file name")
 add_arg('params_filename',  str,  None,               "params file name")
+add_arg('algo',             str,  'hist',             "calibration algorithm")
+add_arg('hist_percent',     float, 0.9999,            "The percentile of algo:hist")
 # yapf: enable
```
```diff
@@ -46,7 +48,9 @@ def quantize(args):
         model_filename=args.model_filename,
         params_filename=args.params_filename,
         batch_size=args.batch_size,
-        batch_nums=args.batch_num)
+        batch_nums=args.batch_num,
+        algo=args.algo,
+        hist_percent=args.hist_percent)


 def main():
```
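The demo builds its CLI by currying a shared `add_arguments` utility with `functools.partial`, so every flag is declared on one line. A self-contained sketch of that pattern, assuming a simplified stand-in for `add_arguments` (the real helper lives in the demo's utility module and does more validation and printing):

```python
import argparse
import functools

def add_arguments(argname, type, default, help, argparser):
    """Simplified stand-in for the demo's utility: registers one CLI flag."""
    argparser.add_argument("--" + argname, type=type, default=default, help=help)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# Mirrors the defaults after this commit: batch_size 32, batch_num 1, algo 'hist'.
add_arg('batch_size',   int,   32,     "Minibatch size.")
add_arg('batch_num',    int,   1,      "Batch number")
add_arg('algo',         str,   'hist', "calibration algorithm")
add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")

args = parser.parse_args([])  # empty argv -> all defaults
print(args.algo, args.batch_size)
```

Any flag can then be overridden on the command line, e.g. `--algo KL --batch_num 10` to reproduce the pre-commit behavior.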
paddleslim/quant/quanter.py
```diff
@@ -313,7 +313,9 @@ def quant_post_static(
                       batch_size=16,
                       batch_nums=None,
                       scope=None,
-                      algo='KL',
+                      algo='hist',
+                      hist_percent=0.9999,
+                      bias_correction=False,
                       quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                       is_full_quantize=False,
                       weight_bits=8,
```
```diff
@@ -358,9 +360,15 @@ def quant_post_static(
             generated by sample_generator as calibrate data.
         scope(paddle.static.Scope, optional): The scope to run program, use it to load
             and save variables. If scope is None, will use paddle.static.global_scope().
-        algo(str, optional): If algo=KL, use KL-divergence method to
-            get the more precise scale factor. If algo='direct', use
-            abs_max method to get the scale factor. Default: 'KL'.
+        algo(str, optional): If algo='KL', use the KL-divergence method to
+            get the scale factor. If algo='hist', use the hist_percent percentile of
+            the histogram to get the scale factor. If algo='mse', search for the scale
+            factor that minimizes the MSE loss; one batch of data is enough for mse. If
+            algo='avg', use the average of abs_max values to get the scale factor. If
+            algo='abs_max', use the abs_max method to get the scale factor. Default: 'hist'.
+        hist_percent(float, optional): The percentile of the histogram for algo 'hist'. Default: 0.9999.
+        bias_correction(bool, optional): Bias correction method of https://arxiv.org/abs/1810.05723.
+            Default: False.
         quantizable_op_type(list[str], optional): The list of op types
             that will be quantized. Default: ["conv2d", "depthwise_conv2d",
             "mul"].
```
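The docstring above enumerates several ways to turn sampled activations into a quantization scale factor. A toy, framework-free sketch of three of them ('abs_max', 'avg', 'hist'); the function name, data layout, and percentile indexing are ours for illustration — the real implementation lives in Paddle's post-training quantization pass and works on per-tensor histograms:

```python
import math

def scale_factor(batches, algo="hist", hist_percent=0.9999):
    """Toy per-tensor scale estimation.

    batches: list of lists of floats (activation samples per calibration batch).
    """
    abs_vals = sorted(abs(v) for batch in batches for v in batch)
    batch_maxes = [max(abs(v) for v in batch) for batch in batches]
    if algo == "abs_max":
        return abs_vals[-1]                               # global maximum magnitude
    if algo == "avg":
        return sum(batch_maxes) / len(batch_maxes)        # average of per-batch maxima
    if algo == "hist":
        # Clip at the hist_percent percentile to suppress rare outliers.
        idx = min(len(abs_vals) - 1, math.ceil(hist_percent * len(abs_vals)) - 1)
        return abs_vals[idx]
    raise ValueError(f"unsupported algo: {algo}")

data = [[0.1, -0.5, 2.0], [0.3, -0.2, 100.0]]  # 100.0 is an outlier
print(scale_factor(data, "abs_max"))    # 100.0: dominated by the outlier
print(scale_factor(data, "avg"))        # 51.0: (2.0 + 100.0) / 2
print(scale_factor(data, "hist", 0.8))  # 2.0: the percentile clip ignores the outlier
```

This is why 'hist' can need far fewer calibration images than 'KL': a high percentile of the magnitude histogram is a robust statistic even on a small sample.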
```diff
@@ -397,6 +405,8 @@ def quant_post_static(
         batch_nums=batch_nums,
         scope=scope,
         algo=algo,
+        hist_percent=hist_percent,
+        bias_correction=bias_correction,
         quantizable_op_type=quantizable_op_type,
         is_full_quantize=is_full_quantize,
         weight_bits=weight_bits,
```
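The 'mse' option documented above searches for the scale factor that minimizes quantization error rather than taking a fixed statistic. A minimal grid-search sketch for symmetric int8 fake-quantization; the function names and the candidate grid are ours, not Paddle's:

```python
def quantize_dequantize(x, scale, bits=8):
    """Symmetric fake-quantization of one value with the given scale."""
    qmax = 2 ** (bits - 1) - 1                             # 127 for int8
    q = max(-qmax, min(qmax, round(x / scale * qmax)))     # quantize + clip
    return q * scale / qmax                                # dequantize

def mse_scale(samples, bits=8, steps=50):
    """Grid-search a scale in (0, abs_max] minimizing the quantization MSE."""
    abs_max = max(abs(v) for v in samples)
    best_scale, best_mse = abs_max, float("inf")
    for i in range(1, steps + 1):
        scale = abs_max * i / steps
        mse = sum((v - quantize_dequantize(v, scale, bits)) ** 2
                  for v in samples) / len(samples)
        if mse < best_mse:
            best_scale, best_mse = scale, mse
    return best_scale

samples = [0.01, -0.3, 0.5, 1.0, -0.9, 0.25]
print(mse_scale(samples))
```

Because the search only needs a stable MSE estimate, the docstring's note that one batch of data is enough for 'mse' is plausible, whereas histogram- and KL-based methods benefit from more samples.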