Commit 3703f63f authored by dongshuilong

fix slim bugs

Parent eafcc864
# global configs
Global:
  checkpoints: null
-  # pretrained_model: ./output/ResNet50_vd/epoch_29
-  pretrained_model: ./output/ResNet50_vd/best_model
+  pretrained_model: null
  output_dir: ./output/
  device: gpu
......@@ -15,19 +15,16 @@ Global:
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference

-# for paddleslim
+# for quantization or prune model
Slim:
-  # for quantization
-  # quant:
-  #   name: pact
  ## for prune
  prune:
-  name: fpgm
-  pruned_ratio: 0.3
+    name: fpgm
+    pruned_ratio: 0.3

# model architecture
Arch:
-  name: MobileNetV3_large_x1_0
+  name: ResNet50_vd
  class_num: 1000

# loss function config for training/eval process
......@@ -58,7 +55,7 @@ DataLoader:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
-      cls_label_path: ./dataset/ILSVRC2012/train.txt
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
......@@ -89,7 +86,7 @@ DataLoader:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
-      cls_label_path: ./dataset/ILSVRC2012/val.txt
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
......
# global configs
Global:
  checkpoints: null
-  pretrained_model: ./output/ResNet50_vd/best_model
+  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 30
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference

# for quantization or prune model
Slim:
  ## for quantization
  quant:
    name: pact

# model architecture
Arch:
  name: ResNet50_vd
  class_num: 1000

# loss function config for training/eval process
Loss:
  Train:
    - MixCELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.1
  regularizer:
    name: 'L2'
    coeff: 0.00007

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
      batch_transform_ops:
        - MixupOperator:
            alpha: 0.2
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True
  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/whl/demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
  Eval:
    - TopkAcc:
        topk: [1, 5]
## Introduction
Complex models improve performance on a task, but they also introduce redundancy. Model quantization reduces this redundancy by mapping full-precision values to fixed-point numbers, which lowers the model's computational complexity and improves its inference performance.
Quantization converts FP32 model parameters to Int8 precision with almost no loss of accuracy, shrinking the parameter size and accelerating computation; quantized models are especially attractive for deployment on mobile devices.
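As a minimal illustration of the idea behind symmetric int8 quantization (a toy NumPy sketch, not PaddleSlim's implementation):
```python
import numpy as np

# Toy symmetric int8 quantization, for illustration only.
w = np.array([0.52, -1.34, 0.03, 0.91], dtype=np.float32)

scale = np.abs(w).max() / 127.0               # one FP32 scale per tensor
w_int8 = np.round(w / scale).astype(np.int8)  # quantize: [49, -127, 3, 86]
w_back = w_int8.astype(np.float32) * scale    # dequantize for comparison

print(np.abs(w - w_back).max())               # the quantization error is small
```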
This tutorial shows how to compress PaddleClas models with PaddleSlim, the PaddlePaddle model compression library.
[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) integrates widely used, industry-leading compression techniques such as pruning, quantization (both quantization-aware training and post-training quantization), knowledge distillation, and neural architecture search; follow the project if you are interested.
Before starting this tutorial, it is recommended to be familiar with [the training of PaddleClas models](../../../docs/zh_CN/tutorials/quick_start.md) and with [PaddleSlim](https://paddleslim.readthedocs.io/zh_CN/latest/index.html).
## Quick Start
Quantization is mostly used to deploy lightweight models on mobile devices. Once a model has been trained, quantization can further compress its size and accelerate prediction.
Model quantization involves five main steps:
1. Install PaddleSlim
2. Prepare a trained model
3. Quantization-aware training
4. Export the quantized inference model
5. Deploy the quantized model
### 1. Install PaddleSlim
* Install via pip:
```bash
pip3.7 install paddleslim==2.0.0
```
* To get the latest PaddleSlim features, install from source:
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python3.7 setup.py install
```
### 2. Prepare a Trained Model
PaddleClas provides a series of trained [models](../../../docs/zh_CN/models/models_intro.md). If the model to be quantized is not in the list, train one with the [regular training](../../../docs/zh_CN/tutorials/getting_started.md) method.
### 3. Quantization-Aware Training
Quantization covers both post-training (offline) quantization and quantization-aware (online) training. Quantization-aware training works better: it loads a pretrained model and, once the quantization strategy is defined, quantizes the model during fine-tuning.
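A minimal sketch of that flow with PaddleSlim's dygraph `QAT` API, assuming paddleslim==2.0.0 (the config keys below are illustrative; the actual `quant_config` is defined in `deploy/slim/quant/quant.py`):
```python
import paddle
from paddleslim.dygraph.quant import QAT

# Illustrative settings; see quant_config in deploy/slim/quant/quant.py.
quant_config = {
    "weight_quantize_type": "channel_wise_abs_max",        # per-channel weight scales
    "activation_quantize_type": "moving_average_abs_max",  # running activation scales
    "weight_bits": 8,
    "activation_bits": 8,
}

# Stand-in network; in PaddleClas the trained classification model is used.
model = paddle.vision.models.resnet18(pretrained=False)

quanter = QAT(config=quant_config)
quanter.quantize(model)  # insert fake-quant ops into the model in place
# ...then fine-tune `model` with the usual training loop...
```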
The quantization-aware training code is in `deploy/slim/quant/quant.py`, and the training commands are as follows:
* Launch on CPU or a single GPU:
```bash
python3.7 deploy/slim/quant/quant.py \
    -c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
    -o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
```
* Launch on single-machine single-GPU / single-machine multi-GPU / multi-machine multi-GPU:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python3.7 -m paddle.distributed.launch \
    --gpus="0,1,2,3,4,5,6,7" \
    deploy/slim/quant/quant.py \
    -c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
    -o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
```
* The following example script quantizes the `MobileNetV3_large_x1_0` model:
```bash
# download the pretrained model
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x1_0_pretrained.pdparams
# start training; if GPU memory limits the batch size, scale the batch size
# and the learning rate down by the same factor
python3.7 -m paddle.distributed.launch \
    --gpus="0,1,2,3,4,5,6,7" \
    deploy/slim/quant/quant.py \
    -c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
    -o pretrained_model="./MobileNetV3_large_x1_0_pretrained" \
    -o LEARNING_RATE.params.lr=0.13 \
    -o epochs=100
```
### 4. Export the Model
After quantization-aware training has saved a model, it can be exported as an inference model for deployment:
```bash
python3.7 deploy/slim/quant/export_model.py \
    -m MobileNetV3_large_x1_0 \
    -p output/MobileNetV3_large_x1_0/best_model/ppcls \
    -o ./MobileNetV3_large_x1_0_infer/ \
    --img_size=224 \
    --class_dim=1000
```
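Conceptually, the export step boils down to the quanter saving the fine-tuned model under a fixed input spec. A sketch, assuming `quanter` and `model` from the QAT sketch above (the output path mirrors the command):
```python
from paddle.static import InputSpec

# Sketch only: export the quantization-aware-trained model as an
# inference model (inference.pdmodel / inference.pdiparams).
quanter.save_quantized_model(
    model,
    "./MobileNetV3_large_x1_0_infer/inference",
    input_spec=[InputSpec(shape=[None, 3, 224, 224], dtype="float32")])
```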
### 5. Deploy the Quantized Model
The parameters of the quantized model exported above are still stored as FP32, but their numerical range fits in int8; the exported model can be converted with Paddle-Lite's opt model conversion tool.
For quantized model deployment, refer to [mobile model deployment](../../lite/readme.md).
## Suggested Hyperparameters for Quantization-Aware Training
* Load the pretrained model obtained from regular training to speed up the convergence of quantization-aware training.
* Set the initial learning rate to `1/20 ~ 1/10` of the regular-training value and the number of epochs to `1/5 ~ 1/2`; add warmup to the learning-rate schedule and leave the other settings unchanged.
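As a worked example of these ranges (the base values below are assumptions for illustration, not PaddleClas defaults):
```python
# Rule-of-thumb ranges from the notes above; the base values are assumed.
base_lr, base_epochs = 1.3, 360

qat_lr_range = (base_lr / 20, base_lr / 10)             # (0.065, 0.13)
qat_epoch_range = (base_epochs // 5, base_epochs // 2)  # (72, 180)
print(qat_lr_range, qat_epoch_range)
```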
## Introduction
Generally, a more complex model achieves better performance on a task, but it also introduces some redundancy.
Quantization is a technique that reduces this redundancy by mapping full-precision values to fixed-point numbers,
so as to reduce model computational complexity and improve model inference performance.
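Concretely, a typical symmetric int8 scheme stores one scale `s = max(|x|) / 127` per tensor and keeps `q = round(x / s)` as an 8-bit integer, so that `s * q` approximately reconstructs the original values.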
This example uses the [quantization APIs](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/) provided by PaddleSlim to compress PaddleClas models.
It is recommended that you read the following pages before working through this example:
- [The training strategy of PaddleClas models](../../../docs/en/tutorials/quick_start_en.md)
- [PaddleSlim Document](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)
## Quick Start
Quantization is mostly suitable for deploying lightweight models on mobile devices.
After training, if you want to further compress the model size and accelerate prediction, you can quantize the model with the following steps.
1. Install PaddleSlim
2. Prepare a trained model
3. Quantization-aware training
4. Export the inference model
5. Deploy the quantized inference model
### 1. Install PaddleSlim
* Install via pip:
```bash
pip3.7 install paddleslim==2.0.0
```
* Install from source code to get the latest features:
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python setup.py install
```
### 2. Download a Pretrained Model
PaddleClas provides a series of trained [models](../../../docs/en/models/models_intro_en.md).
If the model to be quantized is not in the list, you need to follow the [regular training](../../../docs/en/tutorials/getting_started_en.md) method to get a trained model.
### 3. Quantization-Aware Training
Quantization covers both post-training (offline) quantization and quantization-aware (online) training.
Quantization-aware training is more effective: it loads a pretrained model and,
once the quantization strategy is defined, quantizes the model during fine-tuning.
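The quantization config earlier in this commit selects `pact` as the strategy: PACT clips each activation to a learnable range `[0, alpha]` before fake-quantization, which suppresses outliers. A schematic sketch of the clipping idea only, not PaddleSlim's implementation:
```python
import paddle

def pact_clip(x, alpha):
    # PACT: clip activations to the learnable range [0, alpha] before
    # quantization; in real QAT, alpha is trained along with the weights.
    return paddle.clip(x, min=0.0, max=alpha)

x = paddle.uniform([4], min=-2.0, max=5.0)
print(pact_clip(x, alpha=3.0))
```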
The code for quantization training is located in `deploy/slim/quant/quant.py`. The training commands are as follows:
* CPU/Single GPU training
```bash
python3.7 deploy/slim/quant/quant.py \
    -c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
    -o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
```
* Distributed training
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python3.7 -m paddle.distributed.launch \
    --gpus="0,1,2,3,4,5,6,7" \
    deploy/slim/quant/quant.py \
    -c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
    -o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
```
* The command for quantizing the `MobileNetV3_large_x1_0` model is as follows:
```bash
# download pre-trained model
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x1_0_pretrained.pdparams
# run training
python3.7 -m paddle.distributed.launch \
    --gpus="0,1,2,3,4,5,6,7" \
    deploy/slim/quant/quant.py \
    -c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
    -o pretrained_model="./MobileNetV3_large_x1_0_pretrained" \
    -o LEARNING_RATE.params.lr=0.13 \
    -o epochs=100
```
### 4. Export the Inference Model
After quantization-aware training, the model can be exported as an inference model for predictive deployment:
```bash
python3.7 deploy/slim/quant/export_model.py \
    -m MobileNetV3_large_x1_0 \
    -p output/MobileNetV3_large_x1_0/best_model/ppcls \
    -o ./MobileNetV3_large_x1_0_infer/ \
    --img_size=224 \
    --class_dim=1000
```
### 5. Deploy
The parameters of the quantized model exported above are still stored as FP32, but their numerical range fits in int8.
The exported model can be converted with the `opt` tool of Paddle-Lite.
For quantized model deployment, please refer to [mobile model deployment](../../lite/readme_en.md).
## Notes
* During quantization-aware training, it is suggested to load the pretrained model obtained from regular training to accelerate convergence.
* During quantization-aware training, it is suggested to set the initial learning rate to `1/20 ~ 1/10` of the regular-training value and the number of epochs to `1/5 ~ 1/2`. For the learning-rate schedule, it is better to train with warmup; other configuration is best left unchanged.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import os
import sys

import numpy as np
import paddle
import paddleslim
from paddle.jit import to_static
from paddleslim.analysis import dygraph_flops as flops

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

from paddleslim.dygraph.quant import QAT
from ppcls.data import build_dataloader
from ppcls.utils import config as conf
from ppcls.utils.logger import init_logger


def main():
    args = conf.parse_args()
    config = conf.get_config(args.config, overrides=args.override, show=False)
    # The exported inference model must already exist; post-training static
    # quantization runs on the saved inference.pdmodel/inference.pdiparams.
    assert os.path.exists(
        os.path.join(config["Global"]["save_inference_dir"],
                     'inference.pdmodel')) and os.path.exists(
                         os.path.join(config["Global"]["save_inference_dir"],
                                      'inference.pdiparams'))
    config["DataLoader"]["Train"]["sampler"]["batch_size"] = 1
    config["DataLoader"]["Train"]["loader"]["num_workers"] = 0
    init_logger()
    device = paddle.set_device("cpu")
    train_dataloader = build_dataloader(config["DataLoader"], "Train", device,
                                        False)

    def sample_generator(loader):
        # Wrap the dataloader as a generator that yields only the images;
        # quant_post_static uses these batches for calibration.
        def __reader__():
            for indx, data in enumerate(loader):
                images = np.array(data[0])
                yield images

        return __reader__

    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    paddleslim.quant.quant_post_static(
        executor=exe,
        model_dir=config["Global"]["save_inference_dir"],
        model_filename='inference.pdmodel',
        params_filename='inference.pdiparams',
        quantize_model_path=os.path.join(
            config["Global"]["save_inference_dir"], "quant_post_static_model"),
        sample_generator=sample_generator(train_dataloader),
        batch_nums=5)


if __name__ == "__main__":
    main()
......@@ -18,9 +18,11 @@ import os
import sys

import paddle
import numpy as np
import paddleslim
from paddle.jit import to_static
from paddleslim.analysis import dygraph_flops as flops
+import argparse

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
......@@ -29,6 +31,7 @@ from paddleslim.dygraph.quant import QAT
from ppcls.engine.trainer import Trainer
from ppcls.utils import config, logger
from ppcls.utils.save_load import load_dygraph_pretrain
+from ppcls.data import build_dataloader

quant_config = {
    # weight preprocess type, default is None and no preprocessing is performed.
......@@ -79,7 +82,7 @@ class Trainer_slim(Trainer):
        else:
            logger.info("FLOPs before pruning: {}GFLOPs".format(
                flops(self.model, [1] + self.config["Global"][
-                    "image_shape"]) / 1000000))
+                    "image_shape"]) / 1e9))
            self.model.eval()

        if prune_config["name"].lower() == "fpgm":
......@@ -96,11 +99,6 @@ class Trainer_slim(Trainer):
        if self.quanter is None and self.pruner is None:
            logger.info("Training without slim")

-    def train(self):
-        super().train()
-        if self.config["Global"].get("save_inference_dir", None):
-            self.export_inference_model()
-
    def export_inference_model(self):
        if os.path.exists(
                os.path.join(self.output_dir, self.config["Arch"]["name"],
......@@ -153,7 +151,7 @@ class Trainer_slim(Trainer):
        logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
            flops(self.model, [1] + self.config["Global"]["image_shape"]) /
-            1000000, plan.pruned_flops))
+            1e9, plan.pruned_flops))

        for param in self.model.parameters():
            if "conv2d" in param.name:
......@@ -162,9 +160,46 @@ class Trainer_slim(Trainer):
        self.model.train()
+def parse_args():
+    parser = argparse.ArgumentParser(
+        "generic-image-rec slim script, for train, eval and export inference model"
+    )
+    parser.add_argument(
+        '-c',
+        '--config',
+        type=str,
+        default='configs/config.yaml',
+        help='config file path')
+    parser.add_argument(
+        '-o',
+        '--override',
+        action='append',
+        default=[],
+        help='config options to be overridden')
+    parser.add_argument(
+        '-m',
+        '--mode',
+        type=str,
+        default='train',
+        choices=['train', 'eval', 'infer', 'export'],
+        help='the different function')
+    args = parser.parse_args()
+    return args


if __name__ == "__main__":
-    args = config.parse_args()
+    args = parse_args()
    config = config.get_config(
        args.config, overrides=args.override, show=False)
-    trainer = Trainer_slim(config, mode="train")
-    trainer.train()
+    if args.mode == 'train':
+        trainer = Trainer_slim(config, mode="train")
+        trainer.train()
+    elif args.mode == 'eval':
+        trainer = Trainer_slim(config, mode="eval")
+        trainer.eval()
+    elif args.mode == 'infer':
+        trainer = Trainer_slim(config, mode="infer")
+        trainer.infer()
+    else:
+        trainer = Trainer_slim(config, mode="train")
+        trainer.export_inference_model()