未验证 提交 3e4d5697 编写于 作者: G Guanghua Yu 提交者: GitHub

add PP-YOLOE Auto Compression demo (#6568)

上级 36e8a183
# 自动化压缩
目录:
- [1.简介](#1简介)
- [2.Benchmark](#2Benchmark)
- [3.开始自动压缩](#自动压缩流程)
- [3.1 环境准备](#31-准备环境)
- [3.2 准备数据集](#32-准备数据集)
- [3.3 准备预测模型](#33-准备预测模型)
- [3.4 测试模型精度](#34-测试模型精度)
- [3.5 自动压缩并产出模型](#35-自动压缩并产出模型)
- [4.预测部署](#4预测部署)
## 1. 简介
本示例使用PaddleDetection中Inference部署模型进行自动化压缩。本示例使用的自动化压缩策略为量化蒸馏。
## 2.Benchmark
### PP-YOLOE
| 模型 | Base mAP | 离线量化mAP | ACT量化mAP | TRT-FP32 | TRT-FP16 | TRT-INT8 | 配置文件 | 量化模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :----------------------: | :---------------------: |
| PP-YOLOE-l | 50.9 | - | 50.6 | 11.2ms | 7.7ms | **6.7ms** | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/deploy/auto_compression/configs/ppyoloe_l_qat_dis.yaml) | [Quant Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_crn_l_300e_coco_quant.tar) |
- mAP的指标均在COCO val2017数据集中评测得到,IoU=0.5:0.95。
- PP-YOLOE-l模型在Tesla V100的GPU环境下测试,并且开启TensorRT,batch_size=1,包含NMS,测试脚本是[benchmark demo](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/python)
## 3. 自动压缩流程
#### 3.1 准备环境
- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
- PaddleSlim >= 2.3
- PaddleDet >= 2.4
- opencv-python
安装paddlepaddle:
```shell
# CPU
pip install paddlepaddle
# GPU
pip install paddlepaddle-gpu
```
安装paddleslim:
```shell
pip install paddleslim
```
安装paddledet:
```shell
pip install paddledet
```
#### 3.2 准备数据集
本案例默认以COCO数据进行自动压缩实验,如果自定义COCO数据,或者其他格式数据,请参考[数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) 来准备数据。
如果数据集为非COCO格式数据,请修改[configs](./configs)中reader配置文件中的Dataset字段。
以PP-YOLOE模型为例,如果已经准备好数据集,请直接修改[./configs/yolo_reader.yml]中`EvalDataset``dataset_dir`字段为自己数据集路径即可。
#### 3.3 准备预测模型
预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
根据[PaddleDetection文档](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/GETTING_STARTED_cn.md#8-%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) 导出Inference模型,具体可参考下方PP-YOLOE模型的导出示例:
- 下载代码
```
git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
- 导出预测模型
PPYOLOE-l模型,包含NMS:如快速体验,可直接下载[PP-YOLOE-l导出模型](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_crn_l_300e_coco.tar)
```shell
python tools/export_model.py \
-c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml \
-o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams \
trt=True \
```
#### 3.4 自动压缩并产出模型
蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
- 单卡训练:
```
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/ppyoloe_l_qat_dis.yaml --save_dir='./output/'
```
- 多卡训练:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
--config_path=./configs/ppyoloe_l_qat_dis.yaml --save_dir='./output/'
```
#### 3.5 测试模型精度
使用eval.py脚本得到模型的mAP:
```
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/ppyoloe_l_qat_dis.yaml
```
**注意**
- 要测试的模型路径可以在配置文件中`model_dir`字段下进行修改。
## 4.预测部署
- 可以参考[PaddleDetection部署教程](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy),GPU上量化模型开启TensorRT并设置trt_int8模式进行部署。
Global:
reader_config: configs/yolo_reader.yml
input_list: ['image', 'scale_factor']
arch: YOLO
Evaluation: True
model_dir: ./ppyoloe_crn_l_300e_coco
model_filename: model.pdmodel
params_filename: model.pdiparams
Distillation:
alpha: 1.0
loss: soft_label
Quantization:
use_pact: true
activation_quantize_type: 'moving_average_abs_max'
quantize_op_types:
- conv2d
- depthwise_conv2d
TrainConfig:
train_iter: 5000
eval_iter: 1000
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.00003
T_max: 6000
optimizer_builder:
optimizer:
type: SGD
weight_decay: 4.0e-05
metric: COCO
num_classes: 80
# Datset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/
worker_num: 0
# preprocess reader in test
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 4
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from post_process import PPYOLOEPostProcess
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def convert_numpy_data(data, metric):
data_all = {}
data_all = {k: np.array(v) for k, v in data.items()}
if isinstance(metric, VOCMetric):
for k, v in data_all.items():
if not isinstance(v[0], np.ndarray):
tmp_list = []
for t in v:
tmp_list.append(np.array(t))
data_all[k] = np.array(tmp_list)
else:
data_all = {k: np.array(v) for k, v in data.items()}
return data_all
def eval():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
global_config["model_dir"].rstrip('/'),
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
print('Loaded model from: {}'.format(global_config["model_dir"]))
metric = global_config['metric']
for batch_id, data in enumerate(val_loader):
data_all = convert_numpy_data(data, metric)
data_input = {}
for k, v in data.items():
if isinstance(global_config['input_list'], list):
if k in global_config['input_list']:
data_input[k] = np.array(v)
elif isinstance(global_config['input_list'], dict):
if k in global_config['input_list'].keys():
data_input[global_config['input_list'][k]] = np.array(v)
outs = exe.run(val_program,
feed=data_input,
fetch_list=fetch_targets,
return_numpy=False)
res = {}
if 'arch' in global_config and global_config['arch'] == 'PPYOLOE':
postprocess = PPYOLOEPostProcess(
score_threshold=0.01, nms_threshold=0.6)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
metric.reset()
def main():
global global_config
all_config = load_slim_config(FLAGS.config_path)
assert "Global" in all_config, "Key 'Global' not found in config file."
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
dataset = reader_cfg['EvalDataset']
global val_loader
val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
reader_cfg['worker_num'],
return_list=True)
metric = None
if reader_cfg['metric'] == 'COCO':
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
elif reader_cfg['metric'] == 'VOC':
metric = VOCMetric(
label_list=dataset.get_label_list(),
class_num=reader_cfg['num_classes'],
map_type=reader_cfg['map_type'])
elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
anno_file = dataset.get_anno()
metric = KeyPointTopDownCOCOEval(anno_file,
len(dataset), 17, 'output_eval')
else:
raise ValueError("metric currently only supports COCO and VOC.")
global_config['metric'] = metric
eval()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
iou_threshold: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
picked: a list of indexes of the kept boxes
"""
scores = box_scores[:, -1]
boxes = box_scores[:, :-1]
picked = []
indexes = np.argsort(scores)
indexes = indexes[-candidate_size:]
while len(indexes) > 0:
current = indexes[-1]
picked.append(current)
if 0 < top_k == len(picked) or len(indexes) == 1:
break
current_box = boxes[current, :]
indexes = indexes[:-1]
rest_boxes = boxes[indexes, :]
iou = iou_of(
rest_boxes,
np.expand_dims(
current_box, axis=0), )
indexes = indexes[iou <= iou_threshold]
return box_scores[picked, :]
def iou_of(boxes0, boxes1, eps=1e-5):
"""Return intersection-over-union (Jaccard index) of boxes.
Args:
boxes0 (N, 4): ground truth boxes.
boxes1 (N or 1, 4): predicted boxes.
eps: a small number to avoid 0 as denominator.
Returns:
iou (N): IoU values.
"""
overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
overlap_area = area_of(overlap_left_top, overlap_right_bottom)
area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
return overlap_area / (area0 + area1 - overlap_area + eps)
def area_of(left_top, right_bottom):
"""Compute the areas of rectangles given two corners.
Args:
left_top (N, 2): left top corner.
right_bottom (N, 2): right bottom corner.
Returns:
area (N): return the area.
"""
hw = np.clip(right_bottom - left_top, 0.0, None)
return hw[..., 0] * hw[..., 1]
class PPYOLOEPostProcess(object):
"""
Args:
input_shape (int): network input image size
scale_factor (float): scale factor of ori image
"""
def __init__(self,
score_threshold=0.4,
nms_threshold=0.5,
nms_top_k=10000,
keep_top_k=300):
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
def _non_max_suppression(self, prediction, scale_factor):
batch_size = prediction.shape[0]
out_boxes_list = []
box_num_list = []
for batch_id in range(batch_size):
bboxes, confidences = prediction[batch_id][..., :4], prediction[
batch_id][..., 4:]
# nms
picked_box_probs = []
picked_labels = []
for class_index in range(0, confidences.shape[1]):
probs = confidences[:, class_index]
mask = probs > self.score_threshold
probs = probs[mask]
if probs.shape[0] == 0:
continue
subset_boxes = bboxes[mask, :]
box_probs = np.concatenate(
[subset_boxes, probs.reshape(-1, 1)], axis=1)
box_probs = hard_nms(
box_probs,
iou_threshold=self.nms_threshold,
top_k=self.nms_top_k)
picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.shape[0])
if len(picked_box_probs) == 0:
out_boxes_list.append(np.empty((0, 4)))
else:
picked_box_probs = np.concatenate(picked_box_probs)
# resize output boxes
picked_box_probs[:, 0] /= scale_factor[batch_id][1]
picked_box_probs[:, 2] /= scale_factor[batch_id][1]
picked_box_probs[:, 1] /= scale_factor[batch_id][0]
picked_box_probs[:, 3] /= scale_factor[batch_id][0]
# clas score box
out_box = np.concatenate(
[
np.expand_dims(
np.array(picked_labels), axis=-1), np.expand_dims(
picked_box_probs[:, 4], axis=-1),
picked_box_probs[:, :4]
],
axis=1)
if out_box.shape[0] > self.keep_top_k:
out_box = out_box[out_box[:, 1].argsort()[::-1]
[:self.keep_top_k]]
out_boxes_list.append(out_box)
box_num_list.append(out_box.shape[0])
out_boxes_list = np.concatenate(out_boxes_list, axis=0)
box_num_list = np.array(box_num_list)
return out_boxes_list, box_num_list
def __call__(self, outs, scale_factor):
out_boxes_list, box_num_list = self._non_max_suppression(outs,
scale_factor)
return {'bbox': out_boxes_list, 'bbox_num': box_num_list}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
from post_process import PPYOLOEPostProcess
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def convert_numpy_data(data, metric):
data_all = {}
data_all = {k: np.array(v) for k, v in data.items()}
if isinstance(metric, VOCMetric):
for k, v in data_all.items():
if not isinstance(v[0], np.ndarray):
tmp_list = []
for t in v:
tmp_list.append(np.array(t))
data_all[k] = np.array(tmp_list)
else:
data_all = {k: np.array(v) for k, v in data.items()}
return data_all
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric = global_config['metric']
for batch_id, data in enumerate(val_loader):
data_all = convert_numpy_data(data, metric)
data_input = {}
for k, v in data.items():
if isinstance(global_config['input_list'], list):
if k in test_feed_names:
data_input[k] = np.array(v)
elif isinstance(global_config['input_list'], dict):
if k in global_config['input_list'].keys():
data_input[global_config['input_list'][k]] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
if 'arch' in global_config and global_config['arch'] == 'PPYOLOE':
postprocess = PPYOLOEPostProcess(
score_threshold=0.01, nms_threshold=0.6)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
map_res = metric.get_results()
metric.reset()
map_key = 'keypoint' if 'arch' in global_config and global_config[
'arch'] == 'keypoint' else 'bbox'
return map_res[map_key][0]
def main():
global global_config
all_config = load_slim_config(FLAGS.config_path)
assert "Global" in all_config, "Key 'Global' not found in config file."
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
train_loader = reader_wrapper(train_loader, global_config['input_list'])
if 'Evaluation' in global_config.keys() and global_config[
'Evaluation'] and paddle.distributed.get_rank() == 0:
eval_func = eval_function
dataset = reader_cfg['EvalDataset']
global val_loader
_eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=reader_cfg['EvalReader']['batch_size'])
val_loader = create('EvalReader')(dataset,
reader_cfg['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
metric = None
if reader_cfg['metric'] == 'COCO':
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
elif reader_cfg['metric'] == 'VOC':
metric = VOCMetric(
label_list=dataset.get_label_list(),
class_num=reader_cfg['num_classes'],
map_type=reader_cfg['map_type'])
elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
anno_file = dataset.get_anno()
metric = KeyPointTopDownCOCOEval(anno_file,
len(dataset), 17, 'output_eval')
else:
raise ValueError("metric currently only supports COCO and VOC.")
global_config['metric'] = metric
else:
eval_func = None
ac = AutoCompression(
model_dir=global_config["model_dir"],
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
save_dir=FLAGS.save_dir,
config=all_config,
train_dataloader=train_loader,
eval_callback=eval_func)
ac.compress()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册