Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
d1fe71fb
M
models
项目概览
PaddlePaddle
/
models
1 年多 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d1fe71fb
编写于
1月 04, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
1月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add PACT tipc docs (#5456)
上级
e78fb7e2
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
351 addition
and
2 deletion
+351
-2
tutorials/mobilenetv3_prod/Step6/docs/train_pact_infer_python.md
...ls/mobilenetv3_prod/Step6/docs/train_pact_infer_python.md
+83
-0
tutorials/mobilenetv3_prod/Step6/requirements.txt
tutorials/mobilenetv3_prod/Step6/requirements.txt
+1
-0
tutorials/mobilenetv3_prod/Step6/train.py
tutorials/mobilenetv3_prod/Step6/train.py
+54
-0
tutorials/tipc/images/quant_aware_training_guide.png
tutorials/tipc/images/quant_aware_training_guide.png
+0
-0
tutorials/tipc/ptq_infer_python/ptq_infer_python.md
tutorials/tipc/ptq_infer_python/ptq_infer_python.md
+0
-2
tutorials/tipc/train_pact_infer_python/README.md
tutorials/tipc/train_pact_infer_python/README.md
+11
-0
tutorials/tipc/train_pact_infer_python/test_ptq_infer_python.md
...als/tipc/train_pact_infer_python/test_ptq_infer_python.md
+7
-0
tutorials/tipc/train_pact_infer_python/train_pact_infer_python.md
...s/tipc/train_pact_infer_python/train_pact_infer_python.md
+195
-0
未找到文件。
tutorials/mobilenetv3_prod/Step6/docs/train_pact_infer_python.md
0 → 100644
浏览文件 @
d1fe71fb
# MobileNetV3
## 目录
-
[
1. 简介
](
#1
)
-
[
2. PACT量化训练
](
#2
)
-
[
2.1 检查环境
](
#2.1
)
-
[
2.2 开始量化训练
](
#2.2
)
-
[
2.3 验证量化模型指标
](
#2.3
)
-
[
3. FAQ
](
#3
)
<a
name=
"1"
></a>
## 1. 简介
Paddle 量化训练(Quant-aware Training, QAT)是指在训练过程中对模型的权重及激活做模拟量化,并且产出量化训练校准后的量化模型,使用该量化模型进行预测,可以减少计算量、降低计算内存、减小模型大小。
本文档主要基于Paddle的MobileNetV3模型进量化训练。
更多关于PaddleSlim 量化的介绍,可以参考
[
PaddleSlim 量化训练官网教程
](
https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/dygraph/quanter/qat.rst#%E9%87%8F%E5%8C%96%E8%AE%AD%E7%BB%83
)
。
<a
name=
"2"
></a>
## 2. PACT量化训练
<a
name=
"2.1"
></a>
### 2.1 准备环境
-
PACT量化训练依赖于PaddleSlim,需要事先安装PaddlePaddle和PaddleSlim:
```
shell
pip
install
paddlepaddle-gpu
==
2.2.0
pip
install
paddleslim
==
2.2.1
```
-
数据准备请参考
[
数据准备文档
](
https://github.com/PaddlePaddle/models/tree/release/2.2/tutorials/mobilenetv3_prod/Step6#32-%E5%87%86%E5%A4%87%E6%95%B0%E6%8D%AE
)
。
-
准备训好的FP32模型
可以通过
[
模型训练文档
](
https://github.com/PaddlePaddle/models/tree/release/2.2/tutorials/mobilenetv3_prod/Step6#41-%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83
)
准备好训好的模型权重,也可以直接下载预训练模型:
```
shell
# 下载预训练模型
wget https://paddle-model-ecology.bj.bcebos.com/model/mobilenetv3_reprod/mobilenet_v3_small_pretrained.pdparams
```
<a
name=
"2.2"
></a>
### 2.2 开始量化训练
使用如下命令开启单机单卡混合精度O1训练:
```
bash
python3 train.py
--data-path
=
./ILSVRC2012
\
--lr
=
0.001
--batch-size
=
64
\
--epochs
=
10
\
--pretrained
=
mobilenet_v3_small_pretrained.pdparams
\
--pact_quant
```
部分训练日志如下:
```
[Epoch 0, iter: 0] top1: 0.43750, top5: 0.73438, lr: 0.00100, loss: 2.49211, avg_reader_cost: 2.37156 sec, avg_batch_cost: 3.47348 sec, avg_samples: 64.0, avg_ips: 18.42531 images/sec.
[Epoch 0, iter: 10] top1: 0.48594, top5: 0.70625, lr: 0.00100, loss: 2.24239, avg_reader_cost: 0.00026 sec, avg_batch_cost: 0.41364 sec, avg_samples: 64.0, avg_ips: 154.72471 images/sec.
[Epoch 0, iter: 20] top1: 0.45781, top5: 0.69531, lr: 0.00100, loss: 2.36400, avg_reader_cost: 0.00056 sec, avg_batch_cost: 0.42063 sec, avg_samples: 64.0, avg_ips: 152.15310 images/sec.
```
<a
name=
"2.3"
></a>
### 2.3 验证量化模型指标
训练完成量化模型,会在
`output_dir`
路径下生成
`qat_inference.pdmodel`
和
`qat_inference.pdiparams`
的Inference模型,可以直接使用Paddle Inference进行预测部署,或者导出Paddle Lite格式进行部署。
为了验证量化后的模型精度或指标,可以参考
[
量化模型精度验证文档
](
https://github.com/PaddlePaddle/models/tree/release/2.2/tutorials/mobilenetv3_prod/Step6/deploy/ptq_python#23-%E9%AA%8C%E8%AF%81%E6%8E%A8%E7%90%86%E7%BB%93%E6%9E%9C
)
进行指标或模型效果的验证。
** 注意:**
需要将
`--model_filename`
指定为
`qat_inference.pdmodel`
,将
`--params_filename`
指定为
`qat_inference.pdiparams`
。
<a
name=
"3"
></a>
## 3. FAQ
tutorials/mobilenetv3_prod/Step6/requirements.txt
浏览文件 @
d1fe71fb
reprod-log>=1.0.1
paddleslim>=2.2.1
tutorials/mobilenetv3_prod/Step6/train.py
100755 → 100644
浏览文件 @
d1fe71fb
...
...
@@ -5,6 +5,7 @@ import time
import
paddle
from
paddle
import
nn
from
paddle.static
import
InputSpec
import
paddlevision
import
presets
...
...
@@ -190,6 +191,42 @@ def main(args):
model
=
paddlevision
.
models
.
__dict__
[
args
.
model
](
pretrained
=
args
.
pretrained
)
if
args
.
pact_quant
:
try
:
from
paddleslim.dygraph.quant
import
QAT
except
Exception
as
e
:
print
(
'Unable to QAT, please install paddleslim, for example: `pip install paddleslim`'
)
return
quant_config
=
{
# activation preprocess type, default is None and no preprocessing is performed.
'activation_preprocess_type'
:
'PACT'
,
# weight preprocess type, default is None and no preprocessing is performed.
'weight_preprocess_type'
:
None
,
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type'
:
'channel_wise_abs_max'
,
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type'
:
'moving_average_abs_max'
,
# weight quantize bit num, default is 8
'weight_bits'
:
8
,
# activation quantize bit num, default is 8
'activation_bits'
:
8
,
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype'
:
'int8'
,
# window size for 'range_abs_max' quantization. default is 10000
'window_size'
:
10000
,
# The decay coefficient of moving average, default is 0.9
'moving_rate'
:
0.9
,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
'quantizable_layer_type'
:
[
'Conv2D'
,
'Linear'
],
}
quanter
=
QAT
(
config
=
quant_config
)
quanter
.
quantize
(
model
)
print
(
"Quanted model"
)
criterion
=
nn
.
CrossEntropyLoss
()
lr_scheduler
=
paddle
.
optimizer
.
lr
.
StepDecay
(
...
...
@@ -265,6 +302,17 @@ def main(args):
paddle
.
save
(
optimizer
.
state_dict
(),
os
.
path
.
join
(
args
.
output_dir
,
'best.pdopt'
))
if
args
.
pact_quant
:
input_spec
=
[
InputSpec
(
shape
=
[
None
,
3
,
224
,
224
],
dtype
=
'float32'
)
]
quanter
.
save_quantized_model
(
model
,
os
.
path
.
join
(
args
.
output_dir
,
"qat_inference"
),
input_spec
=
input_spec
)
print
(
"QAT inference model saved in {args.output_dir}"
)
total_time
=
time
.
time
()
-
start_time
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
print
(
'Training time {}'
.
format
(
total_time_str
))
...
...
@@ -362,6 +410,12 @@ def get_args_parser(add_help=True):
'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
)
# Quant aware training parameters
parser
.
add_argument
(
'--pact_quant'
,
action
=
'store_true'
,
help
=
'Use pact for quant aware training'
)
return
parser
...
...
tutorials/tipc/images/quant_aware_training_guide.png
0 → 100644
浏览文件 @
d1fe71fb
314.5 KB
tutorials/tipc/ptq_infer_python/ptq_infer_python.md
浏览文件 @
d1fe71fb
...
...
@@ -192,7 +192,6 @@ predictor = paddle_infer.create_predictor(config)
-
Step2:配置预测库输入输出
```
python
```
python
input_names
=
predictor
.
get_input_names
()
input_handle
=
predictor
.
get_input_handle
(
input_names
[
0
])
...
...
@@ -202,7 +201,6 @@ output_handle = predictor.get_output_handle(output_names[0])
-
Step3:开始预测并检验结果正确性
```
python
```
python
input_handle
.
copy_from_cpu
(
img_np
)
predictor
.
run
()
...
...
tutorials/tipc/train_pact_infer_python/README.md
0 → 100644
浏览文件 @
d1fe71fb
# Linux GPU/CPU PACT量化训练开发文档
# 目录
-
[
1. 简介
](
#1
)
-
[
2. Linux GPU/CPU 量化训练功能开发文档
](
#2
)
-
[
2.1 开发流程
](
#2.1
)
-
[
2.2 核验点
](
#2.2
)
-
[
3. Linux GPU/CPU 量化训练测试开发与规范
](
#3
)
-
[
3.1 开发流程
](
#3.1
)
-
[
3.2 核验点
](
#3.2
)
tutorials/tipc/train_pact_infer_python/test_ptq_infer_python.md
0 → 100644
浏览文件 @
d1fe71fb
# Linux GPU/CPU PACT量化训练测试开发与规范
# 目录
-
[
1. 简介
](
#1
)
-
[
2. 测试流程
](
#2
)
-
[
3. FAQ
](
#3
)
tutorials/tipc/train_pact_infer_python/train_pact_infer_python.md
0 → 100644
浏览文件 @
d1fe71fb
# Linux GPU/CPU PACT量化训练功能开发文档
# 目录
-
[
1. 简介
](
#1
)
-
[
2. 量化训练功能开发
](
#2
)
-
[
2.1 准备数据和环境
](
#2.1
)
-
[
2.2 准备待量化模型
](
#2.2
)
-
[
2.3 开始量化训练及保存模型
](
#2.3
)
-
[
2.4 验证推理结果正确性
](
#2.4
)
-
[
3. FAQ
](
#3
)
-
[
3.1 通用问题
](
#3.1
)
<a
name=
"1"
></a>
## 1. 简介
Paddle 量化训练(Quant-aware Training, QAT)是指在训练过程中对模型的权重及激活做模拟量化,并且产出量化训练校准后的量化模型,使用该量化模型进行预测,可以减少计算量、降低计算内存、减小模型大小。
更多关于PaddleSlim 量化的介绍,可以参考
[
PaddleSlim 量化训练官网教程
](
https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/dygraph/quanter/qat.rst#%E9%87%8F%E5%8C%96%E8%AE%AD%E7%BB%83
)
。
<a
name=
"2"
></a>
## 2. 量化训练功能开发
Paddle 混合精度训练开发可以分为4个步骤,如下图所示。
<div
align=
"center"
>
<img
src=
"../images/quant_aware_training_guide.png"
width=
"600"
>
</div>
其中设置了2个核验点,分别为:
*
准备待量化模型
*
验证量化模型推理结果正确性
<a
name=
"2.1"
></a>
### 2.1 准备数据和环境
**【准备校准数据】**
将
``ImageNet``
数据集解压在
``data``
文件夹下,解压后
``data/ILSVRC2012``
文件夹下应包含以下文件:
-
``'train'``
文件夹,训练图片
-
``'train_list.txt'``
文件
-
``'val'``
文件夹,验证图片
-
``'val_list.txt'``
文件
选择适量训练集或验证集
**【准备开发环境】**
-
确定已安装paddle,通过pip安装linux版本paddle命令如下,更多的版本安装方法可查看飞桨
[
官网
](
https://www.paddlepaddle.org.cn/
)
-
确定已安装paddleslim,通过pip安装linux版本paddle命令如下,更多的版本安装方法可查看
[
PaddleSlim
](
https://github.com/PaddlePaddle/PaddleSlim
)
```
pip install paddlepaddle-gpu==2.2.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
pip install paddleslim==2.2.1
```
<a
name=
"2.2"
></a>
### 2.2 准备待量化模型
**【基本流程】**
-
Step1:定义继承自
`paddle.nn.Layer`
的网络模型
**【实战】**
模型组网可以参考
[
mobilenet_v3
](
https://github.com/PaddlePaddle/models/blob/release/2.2/tutorials/mobilenetv3_prod/Step6/paddlevision/models/mobilenet_v3.py
)
```
python
fp32_model
=
mobilenet_v3_small
()
```
<a
name=
"2.3"
></a>
### 2.3 开始量化训练及保存模型
**【基本流程】**
使用飞桨PaddleSlim中的
`QAT`
接口开始进行量化训练:
-
Step1:配置量化训练参数。
```
python
quant_config
=
{
'weight_preprocess_type'
:
None
,
'activation_preprocess_type'
:
None
,
'weight_quantize_type'
:
'channel_wise_abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'weight_bits'
:
8
,
'activation_bits'
:
8
,
'dtype'
:
'int8'
,
'window_size'
:
10000
,
'moving_rate'
:
0.9
,
'quantizable_layer_type'
:
[
'Conv2D'
,
'Linear'
],
}
```
-
`activation_preprocess_type`
':代表对量化模型激活值预处理的方法,目前支持PACT方法,如需使用可以改为'PACT';默认为None,代表不对激活值进行任何预处理。
-
`weight_preprocess_type`
:代表对量化模型权重参数预处理的方法;默认为None,代表不对权重进行任何预处理。
-
`weight_quantize_type`
:代表模型权重的量化方式,可选的有['abs_max', 'moving_average_abs_max', 'channel_wise_abs_max'],默认为channel_wise_abs_max
-
`activation_quantize_type`
:代表模型激活值的量化方式,可选的有['abs_max', 'moving_average_abs_max'],默认为moving_average_abs_max
-
`quantizable_layer_type`
:代表量化OP的类型,目前支持Conv2D和Linear
-
Step2:插入量化算子,得到量化训练模型
```
python
from
paddleslim.dygraph.quant
import
QAT
quanter
=
QAT
(
config
=
quant_config
)
quanter
.
quantize
(
net
)
```
-
Step3:开始训练。
-
Step4:量化训练结束,保存量化模型
```
python
quanter
.
save_quantized_model
(
net
,
'save_dir'
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
224
,
224
],
dtype
=
'float32'
)])
```
**【实战】**
量化训练配置、训练及保存量化模型请参考
[
MobileNetv3量化训练文档
](
https://github.com/PaddlePaddle/models/tree/release/2.2/tutorials/mobilenetv3_prod/Step6/docs/train_pact_infer_python.md
)
<a
name=
"2.4"
></a>
### 2.4 验证推理结果正确性
**【基本流程】**
使用Paddle Inference库测试离线量化模型,确保模型精度符合预期。
-
Step1:初始化
`paddle.inference`
库并配置相应参数
```
python
import
paddle.inference
as
paddle_infer
model_file
=
os
.
path
.
join
(
'quant_model'
,
'qat_inference.pdmodel'
)
params_file
=
os
.
path
.
join
(
'quant_model'
,
'qat_inference.pdiparams'
)
config
=
paddle_infer
.
Config
(
model_file
,
params_file
)
if
FLAGS
.
use_gpu
:
config
.
enable_use_gpu
(
1000
,
0
)
if
not
FLAGS
.
ir_optim
:
config
.
switch_ir_optim
(
False
)
predictor
=
paddle_infer
.
create_predictor
(
config
)
```
-
Step2:配置预测库输入输出
```
python
input_names
=
predictor
.
get_input_names
()
input_handle
=
predictor
.
get_input_handle
(
input_names
[
0
])
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
```
-
Step3:开始预测并检验结果正确性
```
python
input_handle
.
copy_from_cpu
(
img_np
)
predictor
.
run
()
output_data
=
output_handle
.
copy_to_cpu
()
```
**【实战】**
1)初始化
`paddle.inference`
库并配置相应参数:
具体可以参考MobileNetv3
[
Inference模型测试代码
](
https://github.com/PaddlePaddle/models/tree/release/2.2/tutorials/mobilenetv3_prod/Step6/deploy/ptq_python/eval.py
)
2)配置预测库输入输出:
具体可以参考MobileNetv3
[
Inference模型测试代码
](
https://github.com/PaddlePaddle/models/tree/release/2.2/tutorials/mobilenetv3_prod/Step6/deploy/ptq_python/eval.py
)
3)开始预测:
具体可以参考MobileNetv3
[
Inference模型测试代码
](
https://github.com/PaddlePaddle/models/tree/release/2.2/tutorials/mobilenetv3_prod/Step6/deploy/ptq_python/eval.py
)
4)测试单张图像预测结果是否正确,可参考
[
Inference预测文档
](
https://github.com/PaddlePaddle/models/blob/release/2.2/docs/tipc/train_infer_python/infer_python.md
)
5)同时也可以测试量化模型和FP32模型的精度,确保量化后模型精度损失符合预期。参考
[
MobileNet量化模型精度验证文档
](
https://github.com/PaddlePaddle/models/tree/release/2.2/tutorials/mobilenetv3_prod/Step6/deploy/ptq_python/README.md
)
<a
name=
"3"
></a>
## 3. FAQ
### 3.1 通用问题
如果您在使用该文档完成PACT量化训练的过程中遇到问题,可以给在
[
这里
](
https://github.com/PaddlePaddle/PaddleSlim/issues
)
提一个ISSUE,我们会高优跟进。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录