未验证 提交 04e3316f 编写于 作者: G Guanghua Yu 提交者: GitHub

add PP-YOLOE auto commpression (#1090)

上级 d2bd7c18
# 目标检测模型自动压缩
预测模型保存接口:
动态图使用``paddle.jit.save``保存;
静态图使用``paddle.static.save_inference_model``保存。
本示例将介绍如何使用PaddleDetection中预测模型进行蒸馏量化训练。
## 模型量化蒸馏训练流程
### 1. 准备COCO格式数据
参考[COCO数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md#coco%E6%95%B0%E6%8D%AE)
### 2. 准备需要量化的环境
- PaddlePaddle >= 2.2
- PaddleDet >= 2.3
```shell
pip install paddledet
```
#### 3 准备待量化模型
- 下载代码
```
git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
- 导出预测模型
```shell
python tools/export_model.py -c configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams
```
或直接下载:
```shell
wget https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_270e_coco.tar
tar -xf yolov3_mobilenet_v1_270e_coco.tar
```
#### 2.4 测试模型精度
拷贝``yolov3_mobilenet_v1_270e_coco``文件夹到``PaddleSlim/demo/auto-compression/``文件夹。
```
cd PaddleSlim/demo/auto-compression/
```
使用[run_main.py](run_main.py)脚本得到模型的mAP:
```
python3.7 run_main.py --config_path='./configs/yolov3_mbv1_qat_dis.yaml --eval=True
```
### 3. 进行多策略融合压缩
每一个小章节代表一种多策略融合压缩,不代表需要串行执行。
### 3.1 进行量化蒸馏压缩
蒸馏量化训练示例脚本为[run_main.py](run_main.py),使用接口``paddleslim.auto_compression.AutoCompression``对模型进行量化训练。运行命令为:
```
python run_main.py --config_path='./configs/yolov3_mbv1_qat_dis.yaml --save_dir='./output/' --devices='gpu'
```
......@@ -50,9 +50,9 @@ python tools/export_model.py \
-o Global.save_inference_dir=infermodel_mobilenetv2
```
#### 2.4 测试模型精度
拷贝``infermodel_mobilenetv2``文件夹到``PaddleSlim/demo/auto-compression/``文件夹。
拷贝``infermodel_mobilenetv2``文件夹到``PaddleSlim/demo/auto_compression/``文件夹。
```
cd PaddleSlim/demo/auto-compression/
cd PaddleSlim/demo/auto_compression/
```
使用[eval.py](../quant/quant_post/eval.py)脚本得到模型的分类精度,压缩后的模型也可以使用同一个脚本测试精度:
```
......
# 目标检测模型自动压缩
本示例将介绍如何使用PaddleDetection中Inference部署模型进行自动压缩。
## Benchmark
- PP-YOLOE模型:
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP32</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| PP-YOLOE-l | Base模型 | 640*640 | 50.9 | 11.2 | 7.7ms | - | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/ppyoloe_crn_l_300e_coco.tar) |
| PP-YOLOE-l | 量化+蒸馏 | 640*640 | 49.5 | - | - | 6.7ms | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/auto_compression/detection/configs/ppyoloe_l_qat_dist.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/ppyoloe_crn_l_300e_coco_quant.tar) |
- mAP的指标均在COCO val2017数据集中评测得到。
- PP-YOLOE模型在Tesla V100的GPU环境下测试,测试脚本是[benchmark demo](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/python)
## 环境准备
### 1. 准备数据
本案例默认以COCO数据进行自动压缩实验,如果自定义COCO数据,或者其他格式数据,请参考[PaddleDetection数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) 来准备数据。
如果数据集为非COCO格式数据,请修改[configs](./configs)中reader配置文件中的Dataset字段。
### 2. 准备需要量化的环境
- PaddlePaddle >= 2.2
- PaddleDet >= 2.4
```shell
pip install paddledet
```
注:安装PaddleDet的目的是为了直接使用PaddleDetection中的Dataloader组件。
### 3. 准备待量化的部署模型
如果已经准备好部署的`model.pdmodel``model.pdiparams`部署模型,跳过此步。
根据[PaddleDetection文档](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/GETTING_STARTED_cn.md#8-%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) 导出Inference模型,具体可参考下方PP-YOLOE模型的导出示例:
- 下载代码
```
git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
- 导出预测模型
```shell
python tools/export_model.py \
-c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml \
-o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams \
trt=True \
```
**注意**:PP-YOLOE导出时设置`trt=True`旨在优化在TensorRT上的性能,其他模型不需要设置`trt=True`
或直接下载:
```shell
wget https://bj.bcebos.com/v1/paddle-slim-models/detection/ppyoloe_crn_l_300e_coco.tar
tar -xf ppyoloe_crn_l_300e_coco.tar
```
### 4. 测试模型精度
使用[run_main.py](run_main.py)脚本得到模型的mAP:
```
python3.7 run_main.py --config_path=./configs/ppyoloe_l_qat_dist.yaml --eval=True
```
**注意**:TinyPose模型暂不支持精度测试。
## 开始自动压缩
### 进行量化蒸馏自动压缩
蒸馏量化自动压缩示例通过[run_main.py](run_main.py)脚本启动,会使用接口``paddleslim.auto_compression.AutoCompression``对模型进行量化训练。具体运行命令为:
```
python run_main.py --config_path=./configs/ppyoloe_l_qat_dist.yaml --save_dir='./output/' --devices='gpu'
```
## 部署
可以参考[PaddleDetection部署教程](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy)
- GPU上量化模型开启TensorRT并设置trt_int8模式进行部署;
- CPU上可参考[X86 CPU部署量化模型教程](https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/docs/optimize/paddle_x86_cpu_int8.md)
- 移动端请直接使用[Paddle Lite Demo](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/lite)部署。
Global:
reader_config: configs/yolo_reader.yml
input_list: ['image', 'scale_factor']
Evaluation: True
model_dir: ./ppyoloe_crn_l_300e_coco/
model_filename: model.pdmodel
params_filename: model.pdiparams
Distillation:
distill_lambda: 1.0
distill_loss: l2_loss
distill_node_pair:
- teacher_concat_15.tmp_0
- concat_15.tmp_0
- teacher_concat_14.tmp_0
- concat_14.tmp_0
merge_feed: true
teacher_model_dir: ./ppyoloe_crn_l_300e_coco/
teacher_model_filename: model.pdmodel
teacher_params_filename: model.pdiparams
Quantization:
use_pact: true
activation_bits: 8
weight_bits: 8
activation_quantize_type: 'range_abs_max'
weight_quantize_type: 'channel_wise_abs_max'
is_full_quantize: false
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
- depthwise_conv2d
TrainConfig:
epochs: 1
eval_iter: 1000
learning_rate: 0.00001
optimizer: SGD
optim_args:
weight_decay: 4.0e-05
......@@ -14,7 +14,7 @@ EvalDataset:
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/
worker_num: 8
worker_num: 4
# preprocess reader in test
EvalReader:
......
......@@ -2,7 +2,7 @@ Global:
reader_config: configs/yolo_reader.yml
input_list: ['image', 'im_shape', 'scale_factor']
Evaluation: True
model_dir: ./yolov3_mobilenet_v1_270e_coco/
model_dir: ./yolov3_mobilenet_v1_270e_coco/ # Model Link: https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_270e_coco.tar
model_filename: model.pdmodel
params_filename: model.pdiparams
......@@ -23,6 +23,7 @@ Distillation:
Quantization:
activation_bits: 8
weight_bits: 8
activation_quantize_type: 'range_abs_max'
weight_quantize_type: 'channel_wise_abs_max'
is_full_quantize: false
......@@ -31,7 +32,6 @@ Quantization:
quantize_op_types:
- conv2d
- depthwise_conv2d
weight_bits: 8
TrainConfig:
epochs: 1
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
import argparse
import functools
from functools import partial
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
......@@ -12,20 +23,36 @@ from ppdet.metrics import COCOMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
paddle.enable_static()
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('save_dir', str, 'output', "directory to save compressed model.")
add_arg('devices', str, 'gpu', "which device used to compress.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('config_path', str, None, "path of compression strategy config.")
add_arg('eval', bool, False, "whether to run evaluation.")
# yapf: enable
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
parser.add_argument(
'--eval', type=bool, default=False, help="whether to run evaluation.")
return parser
def print_arguments(args):
print('----------- Running Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------')
def reader_wrapper(reader, input_list):
......@@ -39,9 +66,9 @@ def reader_wrapper(reader, input_list):
return gen
def eval(args, compress_config):
def eval(compress_config):
place = paddle.CUDAPlace(0) if args.devices == 'gpu' else paddle.CPUPlace()
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
......@@ -114,8 +141,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
return map_res['bbox'][0]
def main(args):
compress_config, train_config = load_slim_config(args.config_path)
def main():
compress_config, train_config = load_slim_config(FLAGS.config_path)
reader_cfg = load_config(compress_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
......@@ -130,8 +157,8 @@ def main(args):
reader_cfg['worker_num'],
return_list=True)
if args.eval:
eval(args, compress_config)
if FLAGS.eval:
eval(compress_config)
sys.exit(0)
if 'Evaluation' in compress_config.keys() and compress_config['Evaluation']:
......@@ -143,7 +170,7 @@ def main(args):
model_dir=compress_config["model_dir"],
model_filename=compress_config["model_filename"],
params_filename=compress_config["params_filename"],
save_dir=args.save_dir,
save_dir=FLAGS.save_dir,
strategy_config=compress_config,
train_config=train_config,
train_dataloader=train_loader,
......@@ -153,7 +180,12 @@ def main(args):
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
paddle.enable_static()
main(args)
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
......@@ -490,4 +490,6 @@ class AutoCompression:
feeded_var_names=test_program_info.feed_target_names,
target_vars=test_program_info.fetch_targets,
executor=self._exe,
main_program=test_program)
main_program=test_program,
model_filename='model.pdmodel',
params_filename='model.pdiparams')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册