未验证 提交 bde12b69 编写于 作者: G Guanghua Yu 提交者: GitHub

add ppyoloe full quant demo (#1502)

上级 6ef0a6ba
...@@ -42,21 +42,6 @@ def argsparser(): ...@@ -42,21 +42,6 @@ def argsparser():
return parser return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def convert_numpy_data(data, metric): def convert_numpy_data(data, metric):
data_all = {} data_all = {}
data_all = {k: np.array(v) for k, v in data.items()} data_all = {k: np.array(v) for k, v in data.items()}
...@@ -89,12 +74,8 @@ def eval(): ...@@ -89,12 +74,8 @@ def eval():
data_all = convert_numpy_data(data, metric) data_all = convert_numpy_data(data, metric)
data_input = {} data_input = {}
for k, v in data.items(): for k, v in data.items():
if isinstance(global_config['input_list'], list): if k in feed_target_names:
if k in global_config['input_list']: data_input[k] = np.array(v)
data_input[k] = np.array(v)
elif isinstance(global_config['input_list'], dict):
if k in global_config['input_list'].keys():
data_input[global_config['input_list'][k]] = np.array(v)
outs = exe.run(val_program, outs = exe.run(val_program,
feed=data_input, feed=data_input,
......
# PP-YOLOE模型全量化示例
目录:
- [1.简介](#1简介)
- [2.Benchmark](#2Benchmark)
- [3.开始全量化](#全量化流程)
- [3.1 环境准备](#31-准备环境)
- [3.2 准备数据集](#32-准备数据集)
- [3.3 准备预测模型](#33-准备预测模型)
- [3.4 测试模型精度](#34-测试模型精度)
- [3.5 全量化并产出模型](#35-全量化并产出模型)
- [4.预测部署](#4预测部署)
- [5.FAQ](5FAQ)
## 1. 简介
本示例将以目标检测模型PP-YOLOE为例,介绍如何使用PaddleDetection中Inference部署模型进行全量化。本示例使用的全量化策略为全量化加蒸馏。
## 2.Benchmark
| 模型 | 策略 | mAP | TRT-FP32 | TRT-FP16 | TRT-INT8 | 模型 |
| :-------- |:-------- |:--------: | :----------------: | :----------------: | :---------------: | :---------------------: |
| PP-YOLOE-s-416 | Baseline | 39.1 | - | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_s_no_postprocess_416.tar) |
| PP-YOLOE-s-416 | 量化训练 | 38.5 | - | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_s_no_postprocess_416_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_s_quant_416_no_postprocess.onnx) |
- mAP的指标均在COCO val2017数据集中评测得到,IoU=0.5:0.95。
## 3. 全量化流程
#### 3.1 准备环境
- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
- PaddleSlim >= 2.3.4
- PaddleDet >= 2.5
- opencv-python
安装paddlepaddle:
```shell
# CPU
pip install paddlepaddle
# GPU
pip install paddlepaddle-gpu
```
安装paddleslim:
```shell
pip install paddleslim
```
安装paddledet:
```shell
pip install paddledet
```
注:安装PaddleDet的目的是为了直接使用PaddleDetection中的Dataloader组件。
#### 3.2 准备数据集
本案例默认以COCO数据进行全量化实验,如果自定义COCO数据,或者其他格式数据,请参考[PaddleDetection数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) 来准备数据。
如果数据集为非COCO格式数据,请修改[configs](./configs)中reader配置文件中的Dataset字段。
如果已经准备好数据集,请直接修改[./configs/yoloe_416_reader.yml]中`EvalDataset``dataset_dir`字段为自己数据集路径即可。
#### 3.3 准备预测模型
预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
根据[PaddleDetection文档](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/GETTING_STARTED_cn.md#8-%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) 导出Inference模型,具体可参考下方PP-YOLOE模型的导出示例:
- 下载代码
```
git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
- 导出预测模型
注意:PP-YOLOE默认导出640x640输入的模型,如果模型输入需要改为416x416,需要在导出时修改ppdet中[ppyoloe_reader.yml](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/_base_/ppyoloe_reader.yml#L2)`eval_height``eval_width`为416。
包含NMS:
```shell
python tools/export_model.py \
-c configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml \
-o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams \
trt=True \
```
不包含NMS:
```shell
python tools/export_model.py \
-c configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml \
-o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams \
trt=True exclude_post_process=True \
```
#### 3.4 全量化并产出模型
全量化示例通过auto_compress.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行全量化。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为:
- 单卡训练:
```
export CUDA_VISIBLE_DEVICES=0
python auto_compress.py --config_path=./configs/ppyoloe_s_416_qat_dis.yaml --save_dir='./output/'
```
- 多卡训练:
```
CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 auto_compress.py \
--config_path=./configs/ppyoloe_s_416_qat_dis.yaml --save_dir='./output/'
```
- 离线量化
```
python post_quant.py --config_path=./configs/ppyoloe_s_416_qat_dis.yaml
```
#### 3.5 测试模型精度
- 使用eval.py脚本得到模型的mAP:
```
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/ppyoloe_s_416_qat_dis.yaml
```
**注意**
- 要测试的模型路径可以在配置文件中`model_dir`字段下进行修改。
- 导出ONNX,使用ONNXRuntime测试模型精度:
首先导出onnx量化模型
```
paddle2onnx --model_dir=ptq_out/ --model_filename=model.pdmodel --params_filename=model.pdiparams --save_file=ppyoloe_s_quant_416 --deploy_backend=rkn
```
可以根据不同部署后端设置`--deploy_backend`
然后进行评估:
```shell
python3.7 onnxruntime_eval.py --reader_config=configs/yolo_416_reader.yml --model_path=ppyoloe_s_quant_416_no_postprocess.onnx
```
## 4.预测部署
## 5.FAQ
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
from tqdm import tqdm
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric
from paddleslim.common import load_config as load_slim_config
from paddleslim.common.dataloader import get_feed_vars
from paddleslim.quant.analysis import AnalysisQuant
from post_process import PPYOLOEPostProcess
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of analysis config.",
required=True)
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
yield in_dict
return gen
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric = global_config['metric']
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
data_input = {}
for k, v in data.items():
if isinstance(global_config['input_list'], list):
if k in test_feed_names:
data_input[k] = np.array(v)
elif isinstance(global_config['input_list'], dict):
if k in global_config['input_list'].keys():
data_input[global_config['input_list'][k]] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
if 'exclude_nms' in global_config and global_config['exclude_nms']:
postprocess = PPYOLOEPostProcess(
score_threshold=0.01, nms_threshold=0.6)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
map_res = metric.get_results()
metric.reset()
return map_res['bbox'][0]
def main():
global global_config
all_config = load_slim_config(FLAGS.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]
ptq_config = all_config['PTQ']
global_config['input_list'] = get_feed_vars(
global_config['model_dir'], global_config['model_filename'],
global_config['params_filename'])
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
train_loader = reader_wrapper(train_loader, global_config['input_list'])
dataset = reader_cfg['EvalDataset']
global val_loader
_eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=reader_cfg['EvalReader']['batch_size'])
val_loader = create('EvalReader')(dataset,
reader_cfg['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
global num_classes
num_classes = reader_cfg['num_classes']
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
global_config['metric'] = metric
analyzer = AnalysisQuant(
model_dir=global_config["model_dir"],
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
eval_function=eval_function,
data_loader=train_loader,
resume=True,
save_dir='output',
ptq_config=ptq_config)
# plot the boxplot of activations of quantizable weights
analyzer.plot_activation_distribution()
# get the rank of sensitivity of each quantized layer
# plot the histogram plot of best and worst activations and weights if plot_hist is True
analyzer.compute_quant_sensitivity(plot_hist=True)
# get the quantized model that satisfies target metric you set
analyzer.get_target_quant_model(target_metric=0.25)
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric
from paddleslim.common import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
from post_process import PPYOLOEPostProcess
from paddleslim.common.dataloader import get_feed_vars
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def convert_numpy_data(data, metric):
data_all = {}
data_all = {k: np.array(v) for k, v in data.items()}
if isinstance(metric, VOCMetric):
for k, v in data_all.items():
if not isinstance(v[0], np.ndarray):
tmp_list = []
for t in v:
tmp_list.append(np.array(t))
data_all[k] = np.array(tmp_list)
else:
data_all = {k: np.array(v) for k, v in data.items()}
return data_all
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric = global_config['metric']
for batch_id, data in enumerate(val_loader):
data_all = convert_numpy_data(data, metric)
data_input = {}
for k, v in data.items():
if isinstance(global_config['input_list'], list):
if k in test_feed_names:
data_input[k] = np.array(v)
elif isinstance(global_config['input_list'], dict):
if k in global_config['input_list'].keys():
data_input[global_config['input_list'][k]] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
if 'exclude_nms' in global_config and global_config['exclude_nms']:
postprocess = PPYOLOEPostProcess(
score_threshold=0.01, nms_threshold=0.6)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
map_res = metric.get_results()
metric.reset()
return map_res['bbox'][0]
def main():
global global_config
all_config = load_slim_config(FLAGS.config_path)
assert "Global" in all_config, "Key 'Global' not found in config file. \n{}".format(
all_config)
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
global_config['input_list'] = get_feed_vars(
global_config['model_dir'], global_config['model_filename'],
global_config['params_filename'])
train_loader = reader_wrapper(train_loader, global_config['input_list'])
if 'Evaluation' in global_config.keys() and global_config[
'Evaluation'] and paddle.distributed.get_rank() == 0:
eval_func = eval_function
dataset = reader_cfg['EvalDataset']
global val_loader
_eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=reader_cfg['EvalReader']['batch_size'])
val_loader = create('EvalReader')(dataset,
reader_cfg['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
metric = None
if reader_cfg['metric'] == 'COCO':
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
elif reader_cfg['metric'] == 'VOC':
metric = VOCMetric(
label_list=dataset.get_label_list(),
class_num=reader_cfg['num_classes'],
map_type=reader_cfg['map_type'])
else:
raise ValueError("metric currently only supports COCO and VOC.")
global_config['metric'] = metric
else:
eval_func = None
ac = AutoCompression(
model_dir=global_config["model_dir"],
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
save_dir=FLAGS.save_dir,
config=all_config,
train_dataloader=train_loader,
eval_callback=eval_func)
ac.compress()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
Global:
reader_config: configs/yolo_416_reader.yml
exclude_nms: True
Evaluation: True
model_dir: ./ppyoloe_s_no_postprocess_416/
model_filename: model.pdmodel
params_filename: model.pdiparams
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: True
onnx_format: False
batch_size: 10
batch_nums: 10
Distillation:
alpha: 1.0
loss: soft_label
Quantization:
onnx_format: true
use_pact: true
activation_quantize_type: 'moving_average_abs_max'
quantize_op_types:
- conv2d
- depthwise_conv2d
TrainConfig:
train_iter: 5000
eval_iter: 1000
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.00003
T_max: 6000
optimizer_builder:
optimizer:
type: SGD
weight_decay: 4.0e-05
Global:
reader_config: configs/yolo_reader.yml
exclude_nms: True
Evaluation: True
model_dir: ./ppyoloe_crn_s_300e_coco
model_filename: model.pdmodel
params_filename: model.pdiparams
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: True
onnx_format: False
batch_size: 10
batch_nums: 10
Distillation:
alpha: 1.0
loss: soft_label
Quantization:
onnx_format: true
use_pact: true
activation_quantize_type: 'moving_average_abs_max'
quantize_op_types:
- conv2d
- depthwise_conv2d
TrainConfig:
train_iter: 5000
eval_iter: 1000
learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.00003
T_max: 6000
optimizer_builder:
optimizer:
type: SGD
weight_decay: 4.0e-05
metric: COCO
num_classes: 80
# Datset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/
worker_num: 0
# preprocess reader in test
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [416, 416], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
metric: COCO
num_classes: 80
# Datset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/
worker_num: 0
# preprocess reader in test
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric
from paddleslim.common import load_config as load_slim_config
from post_process import PPYOLOEPostProcess
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
return parser
def eval():
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
global_config["model_dir"].rstrip('/'),
exe,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"])
print('Loaded model from: {}'.format(global_config["model_dir"]))
metric = global_config['metric']
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
data_input = {}
for k, v in data.items():
if k in feed_target_names:
data_input[k] = np.array(v)
outs = exe.run(val_program,
feed=data_input,
fetch_list=fetch_targets,
return_numpy=False)
res = {}
if 'exclude_nms' in global_config and global_config['exclude_nms']:
postprocess = PPYOLOEPostProcess(
score_threshold=0.01, nms_threshold=0.6)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
metric.reset()
def main():
global global_config
all_config = load_slim_config(FLAGS.config_path)
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
dataset = reader_cfg['EvalDataset']
global val_loader
val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
reader_cfg['worker_num'],
return_list=True)
metric = None
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
global_config['metric'] = metric
eval()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import time
import paddle
from ppdet.core.workspace import load_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric
import onnxruntime as ort
from post_process import PPYOLOEPostProcess
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--reader_config',
type=str,
default='configs/picodet_reader.yml',
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--model_path',
type=str,
default='onnx_file/picodet_s_416_npu_postprocessed.onnx',
help="onnx filepath")
parser.add_argument(
'--include_post_process',
type=bool,
default=False,
help="Whether include post_process or not.")
return parser
def eval(val_loader, metric, sess):
inputs_name = [a.name for a in sess.get_inputs()]
predict_time = 0.0
time_min = float("inf")
time_max = float("-inf")
sample_nums = len(val_loader)
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
data_input = {}
for k, v in data.items():
if k in inputs_name:
data_input[k] = np.array(v)
start_time = time.time()
outs = sess.run(None, data_input)
end_time = time.time()
timed = end_time - start_time
time_min = min(time_min, timed)
time_max = max(time_max, timed)
predict_time += timed
res = {}
if not FLAGS.include_post_process:
postprocess = PPYOLOEPostProcess(
score_threshold=0.01, nms_threshold=0.6)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
map_res = metric.get_results()
metric.reset()
time_avg = predict_time / sample_nums
print("[Benchmark]Inference time(ms): min={}, max={}, avg={}".format(
round(time_min * 1000, 2),
round(time_max * 1000, 1), round(time_avg * 1000, 1)))
print("[Benchmark] COCO mAP: {}".format(map_res["bbox"][0]))
sys.stdout.flush()
def main():
reader_cfg = load_config(FLAGS.reader_config)
dataset = reader_cfg['EvalDataset']
val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
reader_cfg['worker_num'],
return_list=True)
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
providers = ['CPUExecutionProvider']
sess_options = ort.SessionOptions()
sess_options.optimized_model_filepath = "./optimize_model.onnx"
sess = ort.InferenceSession(
FLAGS.model_path, providers=providers, sess_options=sess_options)
eval(val_loader, metric, sess)
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
# DataLoader need run on cpu
paddle.set_device("cpu")
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
iou_threshold: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
picked: a list of indexes of the kept boxes
"""
scores = box_scores[:, -1]
boxes = box_scores[:, :-1]
picked = []
indexes = np.argsort(scores)
indexes = indexes[-candidate_size:]
while len(indexes) > 0:
current = indexes[-1]
picked.append(current)
if 0 < top_k == len(picked) or len(indexes) == 1:
break
current_box = boxes[current, :]
indexes = indexes[:-1]
rest_boxes = boxes[indexes, :]
iou = iou_of(
rest_boxes,
np.expand_dims(
current_box, axis=0), )
indexes = indexes[iou <= iou_threshold]
return box_scores[picked, :]
def iou_of(boxes0, boxes1, eps=1e-5):
"""Return intersection-over-union (Jaccard index) of boxes.
Args:
boxes0 (N, 4): ground truth boxes.
boxes1 (N or 1, 4): predicted boxes.
eps: a small number to avoid 0 as denominator.
Returns:
iou (N): IoU values.
"""
overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
overlap_area = area_of(overlap_left_top, overlap_right_bottom)
area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
return overlap_area / (area0 + area1 - overlap_area + eps)
def area_of(left_top, right_bottom):
"""Compute the areas of rectangles given two corners.
Args:
left_top (N, 2): left top corner.
right_bottom (N, 2): right bottom corner.
Returns:
area (N): return the area.
"""
hw = np.clip(right_bottom - left_top, 0.0, None)
return hw[..., 0] * hw[..., 1]
class PPYOLOEPostProcess(object):
"""
Args:
input_shape (int): network input image size
scale_factor (float): scale factor of ori image
"""
def __init__(self,
score_threshold=0.4,
nms_threshold=0.5,
nms_top_k=10000,
keep_top_k=300):
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
def _non_max_suppression(self, prediction, scale_factor):
batch_size = prediction.shape[0]
out_boxes_list = []
box_num_list = []
for batch_id in range(batch_size):
bboxes, confidences = prediction[batch_id][..., :4], prediction[
batch_id][..., 4:]
# nms
picked_box_probs = []
picked_labels = []
for class_index in range(0, confidences.shape[1]):
probs = confidences[:, class_index]
mask = probs > self.score_threshold
probs = probs[mask]
if probs.shape[0] == 0:
continue
subset_boxes = bboxes[mask, :]
box_probs = np.concatenate(
[subset_boxes, probs.reshape(-1, 1)], axis=1)
box_probs = hard_nms(
box_probs,
iou_threshold=self.nms_threshold,
top_k=self.nms_top_k)
picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.shape[0])
if len(picked_box_probs) == 0:
out_boxes_list.append(np.empty((0, 4)))
else:
picked_box_probs = np.concatenate(picked_box_probs)
# resize output boxes
picked_box_probs[:, 0] /= scale_factor[batch_id][1]
picked_box_probs[:, 2] /= scale_factor[batch_id][1]
picked_box_probs[:, 1] /= scale_factor[batch_id][0]
picked_box_probs[:, 3] /= scale_factor[batch_id][0]
# clas score box
out_box = np.concatenate(
[
np.expand_dims(
np.array(picked_labels), axis=-1), np.expand_dims(
picked_box_probs[:, 4], axis=-1),
picked_box_probs[:, :4]
],
axis=1)
if out_box.shape[0] > self.keep_top_k:
out_box = out_box[out_box[:, 1].argsort()[::-1]
[:self.keep_top_k]]
out_boxes_list.append(out_box)
box_num_list.append(out_box.shape[0])
out_boxes_list = np.concatenate(out_boxes_list, axis=0)
box_num_list = np.array(box_num_list)
return out_boxes_list, box_num_list
def __call__(self, outs, scale_factor):
out_boxes_list, box_num_list = self._non_max_suppression(outs,
scale_factor)
return {'bbox': out_boxes_list, 'bbox_num': box_num_list}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from paddleslim.quant import quant_post_static
from paddleslim.common import load_config as load_slim_config
from paddleslim.common.dataloader import get_feed_vars
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='ptq_out',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
parser.add_argument(
'--algo', type=str, default='avg', help="post quant algo.")
return parser
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
def main():
all_config = load_slim_config(FLAGS.config_path)
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
global_config['input_list'] = get_feed_vars(
global_config['model_dir'], global_config['model_filename'],
global_config['params_filename'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
train_loader = reader_wrapper(train_loader, global_config['input_list'])
ptq_config = all_config['PTQ']
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place)
quant_post_static(
executor=exe,
model_dir=global_config["model_dir"],
quantize_model_path=FLAGS.save_dir,
data_loader=train_loader,
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
quantizable_op_type=ptq_config['quantizable_op_type'],
activation_quantize_type=ptq_config['activation_quantize_type'],
batch_size=ptq_config['batch_size'],
batch_nums=ptq_config['batch_nums'],
algo=FLAGS.algo,
hist_percent=0.999,
is_full_quantize=ptq_config['is_full_quantize'],
bias_correction=False,
onnx_format=ptq_config['onnx_format'],
skip_tensor_list=None)
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册