Unverified commit e263e885, authored by Guanghua Yu, committed by GitHub

update picodet full quant demo (#1460)

Parent 38f6f578
@@ -23,7 +23,7 @@
| Model | Strategy | mAP | TRT-FP32 | TRT-FP16 | TRT-INT8 | Config | Model |
| :-------- |:-------- |:--------: | :----------------: | :----------------: | :---------------: | :----------------------: | :---------------------: |
| PicoDet-S-NPU | Baseline | 30.1 | - | - | - | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco_npu.yml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar) |
-| PicoDet-S-NPU | Quantization-aware training | 29.7 | - | - | - | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/full_quantization/detection/configs/picodet_s_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_npu_quant.tar) |
+| PicoDet-S-NPU | Quantization-aware training | 29.7 | - | - | - | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/full_quantization/picodet/configs/picodet_npu_with_postprocess.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_npu_quant.tar) |

- mAP is evaluated on the COCO val2017 dataset, IoU=0.5:0.95.
@@ -31,7 +31,7 @@
#### 3.1 Prepare the environment
- PaddlePaddle >= 2.3 (can be installed from the [PaddlePaddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
-- PaddleSlim >= 2.3
+- PaddleSlim >= 2.3.4
- PaddleDet >= 2.4
- opencv-python
@@ -67,9 +67,6 @@ pip install paddledet
The inference model consists of two files, `model.pdmodel` and `model.pdiparams`: the file with the `pdmodel` suffix is the model graph, and the file with the `pdiparams` suffix holds the weights.
-Note: the legacy `__model__` and `__params__` files correspond to the `model.pdmodel` and `model.pdiparams` files, respectively.
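As a quick sanity check, the exported pair can be loaded back in static-graph mode. A minimal sketch (the model directory path is whatever the export step produced; the printed names are what the configs in this demo suggest, not guaranteed):

```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# Loads the graph from <prefix>.pdmodel and the weights from <prefix>.pdiparams.
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    './picodet_s_416_coco_npu/model', exe)
print(feed_names)          # e.g. ['image', 'scale_factor'] for the export with post-processing
print(len(fetch_targets))  # 2 (bbox, bbox_num) with post-processing; 8 raw head tensors without
```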
Export the inference model following the [PaddleDetection documentation](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/GETTING_STARTED_cn.md#8-%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA); the PicoDet-S-NPU export example below can be used as a reference:
@@ -77,13 +74,20 @@
- Download the code
```
git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
- Export the inference model
-The PicoDet-S-NPU model, including NMS: for a quick start, you can directly download the exported [PicoDet-S-NPU model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar)
+The PicoDet-S-NPU model, including post-processing: for a quick start, you can directly download the exported [PicoDet-S-NPU model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar)
```shell
python tools/export_model.py \
        -c configs/picodet/picodet_s_416_coco_npu.yml \
        -o weights=https://paddledet.bj.bcebos.com/models/picodet_s_416_coco_npu.pdparams
```
Export the PicoDet-S-NPU model without post-processing:
```shell
python tools/export_model.py \
        -c configs/picodet/picodet_s_416_coco_npu.yml \
        -o weights=https://paddledet.bj.bcebos.com/models/picodet_s_416_coco_npu.pdparams \
        export.benchmark=True
```
#### 3.4 Run full quantization and export the model
@@ -92,14 +96,20 @@ python tools/export_model.py \
- Single-GPU training:
```
export CUDA_VISIBLE_DEVICES=0
-python run.py --config_path=./configs/picodet_s_qat_dis.yaml --save_dir='./output/'
+python run.py --config_path=./configs/picodet_npu_with_postprocess.yaml --save_dir='./output/'
```
- Multi-GPU training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
-          --config_path=./configs/picodet_s_qat_dis.yaml --save_dir='./output/'
+          --config_path=./configs/picodet_npu_with_postprocess.yaml --save_dir='./output/'
```
- Training the PicoDet model without post-processing:
```
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/picodet_npu.yaml --save_dir='./output/'
```
#### 3.5 Test model accuracy
@@ -107,7 +117,7 @@ python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
Use the eval.py script to obtain the model's mAP:
```
export CUDA_VISIBLE_DEVICES=0
-python eval.py --config_path=./configs/picodet_s_qat_dis.yaml
+python eval.py --config_path=./configs/picodet_npu_with_postprocess.yaml
```
**Note**
......
Global:
  reader_config: ./configs/picodet_reader.yml
  input_list: ['image']
  include_post_process: False
  Evaluation: True
  model_dir: ./picodet_s_416_coco_npu
  model_filename: model.pdmodel
  params_filename: model.pdiparams

Distillation:
  alpha: 1.0
  loss: l2

Quantization:
  use_pact: true
  activation_quantize_type: 'moving_average_abs_max'
  weight_bits: 8
  activation_bits: 8
  quantize_op_types:
  - conv2d
  - depthwise_conv2d

TrainConfig:
  train_iter: 8000
  eval_iter: 1000
  learning_rate:
    type: CosineAnnealingDecay
    learning_rate: 0.00001
    T_max: 8000
  optimizer_builder:
    optimizer:
      type: SGD
    weight_decay: 4.0e-05
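For intuition, the `TrainConfig` block above maps roughly onto the following Paddle optimizer setup. This is a hedged sketch: the actual construction happens inside PaddleSlim's auto-compression machinery, and the `paddle.nn.Linear` here is only a stand-in for the real trainable parameters:

```python
import paddle

# Stand-in parameters; in the demo these belong to the quantized student model.
stub = paddle.nn.Linear(10, 10)

# learning_rate: CosineAnnealingDecay, lr=0.00001, T_max=8000
lr = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00001, T_max=8000)
# optimizer_builder: SGD with weight_decay 4.0e-05
opt = paddle.optimizer.SGD(learning_rate=lr,
                           weight_decay=4.0e-05,
                           parameters=stub.parameters())
```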
Global:
  reader_config: ./configs/picodet_reader.yml
  input_list: ['image', 'scale_factor']
+ include_post_process: True
  Evaluation: True
- model_dir: ./picodet_s_416_coco_npu/
+ model_dir: ./picodet_s_416_coco_npu
  model_filename: model.pdmodel
  params_filename: model.pdiparams
......
@@ -7,26 +7,33 @@ TrainDataset:
  !COCODataSet
    image_dir: train2017
    anno_path: annotations/instances_train2017.json
-   dataset_dir: /paddle/dataset/coco/
+   dataset_dir: dataset/coco/

EvalDataset:
  !COCODataSet
    image_dir: val2017
    anno_path: annotations/instances_val2017.json
-   dataset_dir: /paddle/dataset/coco/
+   dataset_dir: dataset/coco/

-worker_num: 6
+worker_num: 0
eval_height: &eval_height 416
eval_width: &eval_width 416
eval_size: &eval_size [*eval_height, *eval_width]

-EvalReader:
+TrainReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
  - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
  - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: 32}
  batch_size: 8
  shuffle: false

+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
+  - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+  - Permute: {}
+  batch_size: 1
+  shuffle: false
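The scripts in this demo build their data loaders from this file through ppdet's workspace API; a minimal sketch mirroring eval.py (assumes the COCO dataset is laid out under `dataset/coco/` as configured above):

```python
from ppdet.core.workspace import load_config, create

reader_cfg = load_config('./configs/picodet_reader.yml')
# 'EvalReader' picks up batch_size=1 and the 416x416 resize defined above.
val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
                                  reader_cfg['worker_num'],
                                  return_list=True)
for data in val_loader:
    print(data['image'].shape)  # [1, 3, 416, 416]
    break
```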
@@ -22,6 +22,8 @@ from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.common import load_config as load_slim_config

+from post_process import PicoDetPostProcess

def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
@@ -40,37 +42,7 @@ def argsparser():
    return parser
-def reader_wrapper(reader, input_list):
-    def gen():
-        for data in reader:
-            in_dict = {}
-            if isinstance(input_list, list):
-                for input_name in input_list:
-                    in_dict[input_name] = data[input_name]
-            elif isinstance(input_list, dict):
-                for input_name in input_list.keys():
-                    in_dict[input_list[input_name]] = data[input_name]
-            yield in_dict
-    return gen
-
-def convert_numpy_data(data, metric):
-    data_all = {}
-    data_all = {k: np.array(v) for k, v in data.items()}
-    if isinstance(metric, VOCMetric):
-        for k, v in data_all.items():
-            if not isinstance(v[0], np.ndarray):
-                tmp_list = []
-                for t in v:
-                    tmp_list.append(np.array(t))
-                data_all[k] = np.array(tmp_list)
-    else:
-        data_all = {k: np.array(v) for k, v in data.items()}
-    return data_all
-
-def eval():
+def eval(metric):
    place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
    exe = paddle.static.Executor(place)
@@ -82,30 +54,46 @@ def eval():
        params_filename=global_config["params_filename"])
    print('Loaded model from: {}'.format(global_config["model_dir"]))

-    metric = global_config['metric']
    for batch_id, data in enumerate(val_loader):
-        data_all = convert_numpy_data(data, metric)
+        data_all = {k: np.array(v) for k, v in data.items()}
+        batch_size = data_all['image'].shape[0]
        data_input = {}
        for k, v in data.items():
-            if isinstance(global_config['input_list'], list):
-                if k in global_config['input_list']:
-                    data_input[k] = np.array(v)
-            elif isinstance(global_config['input_list'], dict):
-                if k in global_config['input_list'].keys():
-                    data_input[global_config['input_list'][k]] = np.array(v)
+            if k in feed_target_names:
+                data_input[k] = np.array(v)
        outs = exe.run(val_program,
                       feed=data_input,
                       fetch_list=fetch_targets,
                       return_numpy=False)
-        res = {}
-        for out in outs:
-            v = np.array(out)
-            if len(v.shape) > 1:
-                res['bbox'] = v
-            else:
-                res['bbox_num'] = v
+        if not global_config['include_post_process']:
+            # The raw model emits 8 tensors: 4 per-stride score maps first,
+            # then 4 per-stride box-distribution maps.
+            np_score_list, np_boxes_list = [], []
+            for i, out in enumerate(outs):
+                np_out = np.array(out)
+                if i < 4:
+                    num_classes = np_out.shape[-1]
+                    np_score_list.append(
+                        np_out.reshape(batch_size, -1, num_classes))
+                else:
+                    box_reg_shape = np_out.shape[-1]
+                    np_boxes_list.append(
+                        np_out.reshape(batch_size, -1, box_reg_shape))
+            post_processor = PicoDetPostProcess(
+                data_all['image'].shape[2:],
+                data_all['im_shape'],
+                data_all['scale_factor'],
+                score_threshold=0.01,
+                nms_threshold=0.6)
+            res = post_processor(np_score_list, np_boxes_list)
+        else:
+            res = {}
+            for out in outs:
+                v = np.array(out)
+                if len(v.shape) > 1:
+                    res['bbox'] = v
+                else:
+                    res['bbox_num'] = v
        metric.update(data_all, res)
        if batch_id % 100 == 0:
            print('Eval iter:', batch_id)
@@ -125,26 +113,15 @@ def main():
    val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
                                      reader_cfg['worker_num'],
                                      return_list=True)
+    global num_classes
+    num_classes = reader_cfg['num_classes']
    metric = None
-    if reader_cfg['metric'] == 'COCO':
-        clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
-        anno_file = dataset.get_anno()
-        metric = COCOMetric(
-            anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
-    elif reader_cfg['metric'] == 'VOC':
-        metric = VOCMetric(
-            label_list=dataset.get_label_list(),
-            class_num=reader_cfg['num_classes'],
-            map_type=reader_cfg['map_type'])
-    elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
-        anno_file = dataset.get_anno()
-        metric = KeyPointTopDownCOCOEval(anno_file,
-                                         len(dataset), 17, 'output_eval')
-    else:
-        raise ValueError("metric currently only supports COCO and VOC.")
-    global_config['metric'] = metric
-    eval()
+    clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
+    anno_file = dataset.get_anno()
+    metric = COCOMetric(
+        anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
+
+    eval(metric)

if __name__ == '__main__':
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric
import onnxruntime as ort

from post_process import PicoDetPostProcess


def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--reader_config',
        type=str,
        default='configs/picodet_reader.yml',
        help="path of the data reader config.",
        required=True)
    parser.add_argument(
        '--model_path',
        type=str,
        default='onnx_file/picodet_s_416_npu_postprocessed.onnx',
        help="onnx filepath")
    parser.add_argument(
        '--include_post_process',
        type=bool,
        default=False,
        help="Whether include post_process or not.")
    return parser


def eval(val_loader, metric, sess):
    inputs_name = [a.name for a in sess.get_inputs()]
    for batch_id, data in enumerate(val_loader):
        data_all = {k: np.array(v) for k, v in data.items()}
        batch_size = data_all['image'].shape[0]
        data_input = {}
        for k, v in data.items():
            if k in inputs_name:
                data_input[k] = np.array(v)
        outs = sess.run(None, data_input)
        if not FLAGS.include_post_process:
            np_score_list, np_boxes_list = [], []
            for i, out in enumerate(outs):
                np_out = np.array(out)
                if i < 4:
                    num_classes = np_out.shape[-1]
                    np_score_list.append(
                        np_out.reshape(batch_size, -1, num_classes))
                else:
                    box_reg_shape = np_out.shape[-1]
                    np_boxes_list.append(
                        np_out.reshape(batch_size, -1, box_reg_shape))
            post_processor = PicoDetPostProcess(
                data_all['image'].shape[2:],
                data_all['im_shape'],
                data_all['scale_factor'],
                score_threshold=0.01,
                nms_threshold=0.6)
            res = post_processor(np_score_list, np_boxes_list)
        else:
            res = {}
            for out in outs:
                v = np.array(out)
                if len(v.shape) > 1:
                    res['bbox'] = v
                else:
                    res['bbox_num'] = v
        metric.update(data_all, res)
        if batch_id % 100 == 0:
            print('Eval iter:', batch_id)
    metric.accumulate()
    metric.log()
    metric.reset()


def main():
    reader_cfg = load_config(FLAGS.reader_config)
    dataset = reader_cfg['EvalDataset']
    val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
                                      reader_cfg['worker_num'],
                                      return_list=True)
    clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
    anno_file = dataset.get_anno()
    metric = COCOMetric(
        anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')

    providers = ['CPUExecutionProvider']
    sess_options = ort.SessionOptions()
    sess_options.optimized_model_filepath = "./optimize_model.onnx"
    sess = ort.InferenceSession(
        FLAGS.model_path, providers=providers, sess_options=sess_options)
    eval(val_loader, metric, sess)


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()

    # The DataLoader needs to run on CPU.
    paddle.set_device("cpu")

    main()
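To produce the ONNX file this script expects, the exported Paddle model can be converted with paddle2onnx first. A hedged example; the `eval_onnx.py` filename is an assumption for where the script above is saved:

```shell
paddle2onnx --model_dir ./picodet_s_416_coco_npu \
            --model_filename model.pdmodel \
            --params_filename model.pdiparams \
            --save_file onnx_file/picodet_s_416_npu_postprocessed.onnx
python eval_onnx.py --reader_config configs/picodet_reader.yml \
                    --model_path onnx_file/picodet_s_416_npu_postprocessed.onnx
```

Note that `--include_post_process` is declared with `type=bool`, so argparse treats any non-empty string (even `False`) as true; omit the flag entirely when evaluating a model without post-processing.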
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from scipy.special import softmax


def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
    """
    Args:
        box_scores (N, 5): boxes in corner-form and probabilities.
        iou_threshold: intersection over union threshold.
        top_k: keep top_k results. If k <= 0, keep all the results.
        candidate_size: only consider the candidates with the highest scores.
    Returns:
        picked: the kept rows of box_scores.
    """
    scores = box_scores[:, -1]
    boxes = box_scores[:, :-1]
    picked = []
    indexes = np.argsort(scores)
    indexes = indexes[-candidate_size:]
    while len(indexes) > 0:
        current = indexes[-1]
        picked.append(current)
        if 0 < top_k == len(picked) or len(indexes) == 1:
            break
        current_box = boxes[current, :]
        indexes = indexes[:-1]
        rest_boxes = boxes[indexes, :]
        iou = iou_of(
            rest_boxes,
            np.expand_dims(
                current_box, axis=0), )
        indexes = indexes[iou <= iou_threshold]

    return box_scores[picked, :]


def iou_of(boxes0, boxes1, eps=1e-5):
    """Return intersection-over-union (Jaccard index) of boxes.
    Args:
        boxes0 (N, 4): ground truth boxes.
        boxes1 (N or 1, 4): predicted boxes.
        eps: a small number to avoid 0 as denominator.
    Returns:
        iou (N): IoU values.
    """
    overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
    overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])

    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
    return overlap_area / (area0 + area1 - overlap_area + eps)


def area_of(left_top, right_bottom):
    """Compute the areas of rectangles given two corners.
    Args:
        left_top (N, 2): left top corner.
        right_bottom (N, 2): right bottom corner.
    Returns:
        area (N): return the area.
    """
    hw = np.clip(right_bottom - left_top, 0.0, None)
    return hw[..., 0] * hw[..., 1]


class PicoDetPostProcess(object):
    """
    Args:
        input_shape (tuple): network input image size (h, w).
        ori_shape (ndarray): original image shapes before padding, per image.
        scale_factor (ndarray): scale factors of the original images.
    """

    def __init__(self,
                 input_shape,
                 ori_shape,
                 scale_factor,
                 strides=[8, 16, 32, 64],
                 score_threshold=0.4,
                 nms_threshold=0.5,
                 nms_top_k=1000,
                 keep_top_k=100):
        self.ori_shape = ori_shape
        self.input_shape = input_shape
        self.scale_factor = scale_factor
        self.strides = strides
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k

    def warp_boxes(self, boxes, ori_shape):
        """Apply transform to boxes."""
        width, height = ori_shape[1], ori_shape[0]
        n = len(boxes)
        if n:
            # warp points
            xy = np.ones((n * 4, 3))
            xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
                n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            # xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            xy = np.concatenate(
                (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
            # clip boxes
            xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
            xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
            return xy.astype(np.float32)
        else:
            return boxes

    def __call__(self, scores, raw_boxes):
        batch_size = raw_boxes[0].shape[0]
        reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
        out_boxes_num = []
        out_boxes_list = []
        for batch_id in range(batch_size):
            # generate centers
            decode_boxes = []
            select_scores = []
            for stride, box_distribute, score in zip(self.strides, raw_boxes,
                                                     scores):
                box_distribute = box_distribute[batch_id]
                score = score[batch_id]
                # centers
                fm_h = self.input_shape[0] / stride
                fm_w = self.input_shape[1] / stride
                h_range = np.arange(fm_h)
                w_range = np.arange(fm_w)
                ww, hh = np.meshgrid(w_range, h_range)
                ct_row = (hh.flatten() + 0.5) * stride
                ct_col = (ww.flatten() + 0.5) * stride
                center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)

                # box distribution to distance
                reg_range = np.arange(reg_max + 1)
                box_distance = box_distribute.reshape((-1, reg_max + 1))
                box_distance = softmax(box_distance, axis=1)
                box_distance = box_distance * np.expand_dims(reg_range, axis=0)
                box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
                box_distance = box_distance * stride

                # top K candidate
                topk_idx = np.argsort(score.max(axis=1))[::-1]
                topk_idx = topk_idx[:self.nms_top_k]
                center = center[topk_idx]
                score = score[topk_idx]
                box_distance = box_distance[topk_idx]

                # decode box
                decode_box = center + [-1, -1, 1, 1] * box_distance

                select_scores.append(score)
                decode_boxes.append(decode_box)

            # nms
            bboxes = np.concatenate(decode_boxes, axis=0)
            confidences = np.concatenate(select_scores, axis=0)
            picked_box_probs = []
            picked_labels = []
            for class_index in range(0, confidences.shape[1]):
                probs = confidences[:, class_index]
                mask = probs > self.score_threshold
                probs = probs[mask]
                if probs.shape[0] == 0:
                    continue
                subset_boxes = bboxes[mask, :]
                box_probs = np.concatenate(
                    [subset_boxes, probs.reshape(-1, 1)], axis=1)
                box_probs = hard_nms(
                    box_probs,
                    iou_threshold=self.nms_threshold,
                    top_k=self.keep_top_k, )
                picked_box_probs.append(box_probs)
                picked_labels.extend([class_index] * box_probs.shape[0])

            if len(picked_box_probs) == 0:
                out_boxes_list.append(np.empty((0, 4)))
                out_boxes_num.append(0)
            else:
                picked_box_probs = np.concatenate(picked_box_probs)

                # resize output boxes
                picked_box_probs[:, :4] = self.warp_boxes(
                    picked_box_probs[:, :4], self.ori_shape[batch_id])
                im_scale = np.concatenate([
                    self.scale_factor[batch_id][::-1],
                    self.scale_factor[batch_id][::-1]
                ])
                picked_box_probs[:, :4] /= im_scale
                # rows of [class, score, x1, y1, x2, y2]
                out_boxes_list.append(
                    np.concatenate(
                        [
                            np.expand_dims(
                                np.array(picked_labels), axis=-1),
                            np.expand_dims(picked_box_probs[:, 4], axis=-1),
                            picked_box_probs[:, :4]
                        ],
                        axis=1))
                out_boxes_num.append(len(picked_labels))

        out_boxes_list = np.concatenate(out_boxes_list, axis=0)
        out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
        return {'bbox': out_boxes_list, 'bbox_num': out_boxes_num}
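A self-contained smoke test of `PicoDetPostProcess` on random head outputs. The shapes assume a 416x416 input, strides [8, 16, 32, 64], and reg_max=7 (32 box channels); 4 classes are used only to keep the example small:

```python
import numpy as np
from post_process import PicoDetPostProcess

np.random.seed(0)
strides = [8, 16, 32, 64]
# Feature-map points per stride for a 416x416 input: 52*52, 26*26, 13*13, 7*7.
num_points = [int(np.ceil(416 / s))**2 for s in strides]
scores = [np.random.rand(1, p, 4).astype(np.float32) * 0.5 for p in num_points]
raw_boxes = [np.random.rand(1, p, 32).astype(np.float32) for p in num_points]

post = PicoDetPostProcess(
    input_shape=(416, 416),
    ori_shape=np.array([[416, 416]]),
    scale_factor=np.array([[1.0, 1.0]], dtype=np.float32),
    score_threshold=0.3,
    nms_threshold=0.6)
res = post(scores, raw_boxes)
# res['bbox'] rows are [class, score, x1, y1, x2, y2]; res['bbox_num'] counts per image.
print(res['bbox'].shape, res['bbox_num'])
```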
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from paddleslim.quant import quant_post_static
from paddleslim.common import load_config as load_slim_config


def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--config_path',
        type=str,
        default=None,
        help="path of compression strategy config.",
        required=True)
    parser.add_argument(
        '--save_dir',
        type=str,
        default='ptq_out',
        help="directory to save compressed model.")
    parser.add_argument(
        '--devices',
        type=str,
        default='gpu',
        help="which device used to compress.")
    parser.add_argument(
        '--algo', type=str, default='avg', help="post quant algo.")
    return parser


def reader_wrapper(reader, input_list):
    def gen():
        for data in reader:
            in_dict = {}
            if isinstance(input_list, list):
                for input_name in input_list:
                    in_dict[input_name] = data[input_name]
            elif isinstance(input_list, dict):
                for input_name in input_list.keys():
                    in_dict[input_list[input_name]] = data[input_name]
            yield in_dict

    return gen


def main():
    all_config = load_slim_config(FLAGS.config_path)
    global_config = all_config["Global"]
    reader_cfg = load_config(global_config['reader_config'])

    train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
                                        reader_cfg['worker_num'],
                                        return_list=True)
    train_loader = reader_wrapper(train_loader, global_config['input_list'])

    place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
    exe = paddle.static.Executor(place)

    quant_post_static(
        executor=exe,
        model_dir=global_config["model_dir"],
        quantize_model_path=FLAGS.save_dir,
        data_loader=train_loader,
        model_filename=global_config["model_filename"],
        params_filename=global_config["params_filename"],
        batch_size=32,
        batch_nums=10,
        algo=FLAGS.algo,
        hist_percent=0.999,
        is_full_quantize=False,
        bias_correction=False,
        onnx_format=True,
        skip_tensor_list=None)


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
    paddle.set_device(FLAGS.devices)

    main()
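Assuming the script above is saved as `post_quant.py` (the filename and output directory are assumptions), offline post-training quantization can be run against the same configs:

```shell
export CUDA_VISIBLE_DEVICES=0
python post_quant.py --config_path=./configs/picodet_npu_with_postprocess.yaml \
                     --save_dir=ptq_out --algo=avg
```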
@@ -24,6 +24,8 @@ from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.common import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression

+from post_process import PicoDetPostProcess

def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
@@ -62,48 +64,48 @@ def reader_wrapper(reader, input_list):
    return gen

-def convert_numpy_data(data, metric):
-    data_all = {}
-    data_all = {k: np.array(v) for k, v in data.items()}
-    if isinstance(metric, VOCMetric):
-        for k, v in data_all.items():
-            if not isinstance(v[0], np.ndarray):
-                tmp_list = []
-                for t in v:
-                    tmp_list.append(np.array(t))
-                data_all[k] = np.array(tmp_list)
-    else:
-        data_all = {k: np.array(v) for k, v in data.items()}
-    return data_all
-
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    metric = global_config['metric']
    with tqdm(
            total=len(val_loader),
            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80) as t:
-        for batch_id, data in enumerate(val_loader):
-            data_all = convert_numpy_data(data, metric)
+        for data in val_loader:
+            data_all = {k: np.array(v) for k, v in data.items()}
+            batch_size = data_all['image'].shape[0]
            data_input = {}
            for k, v in data.items():
-                if isinstance(global_config['input_list'], list):
-                    if k in test_feed_names:
-                        data_input[k] = np.array(v)
-                elif isinstance(global_config['input_list'], dict):
-                    if k in global_config['input_list'].keys():
-                        data_input[global_config['input_list'][k]] = np.array(v)
+                if k in test_feed_names:
+                    data_input[k] = np.array(v)
            outs = exe.run(compiled_test_program,
                           feed=data_input,
                           fetch_list=test_fetch_list,
                           return_numpy=False)
-            res = {}
-            for out in outs:
-                v = np.array(out)
-                if len(v.shape) > 1:
-                    res['bbox'] = v
-                else:
-                    res['bbox_num'] = v
+            if not global_config['include_post_process']:
+                np_score_list, np_boxes_list = [], []
+                for i, out in enumerate(outs):
+                    if i < 4:
+                        np_score_list.append(
+                            np.array(out).reshape(batch_size, -1, num_classes))
+                    else:
+                        np_boxes_list.append(
+                            np.array(out).reshape(batch_size, -1, 32))
+                post_processor = PicoDetPostProcess(
+                    data_all['image'].shape[2:],
+                    data_all['im_shape'],
+                    data_all['scale_factor'],
+                    score_threshold=0.01,
+                    nms_threshold=0.6)
+                res = post_processor(np_score_list, np_boxes_list)
+            else:
+                res = {}
+                for out in outs:
+                    v = np.array(out)
+                    if len(v.shape) > 1:
+                        res['bbox'] = v
+                    else:
+                        res['bbox_num'] = v
            metric.update(data_all, res)
            t.update()
@@ -111,9 +113,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    metric.log()
    map_res = metric.get_results()
    metric.reset()
-    map_key = 'keypoint' if 'arch' in global_config and global_config[
-        'arch'] == 'keypoint' else 'bbox'
-    return map_res[map_key][0]
+    return map_res['bbox'][0]

def main():
@@ -123,9 +123,9 @@ def main():
    global_config = all_config["Global"]
    reader_cfg = load_config(global_config['reader_config'])

-    train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
-                                        reader_cfg['worker_num'],
-                                        return_list=True)
+    train_loader = create('TrainReader')(reader_cfg['TrainDataset'],
+                                         reader_cfg['worker_num'],
+                                         return_list=True)
    train_loader = reader_wrapper(train_loader, global_config['input_list'])

    if 'Evaluation' in global_config.keys() and global_config[
@@ -139,23 +139,12 @@ def main():
                                              reader_cfg['worker_num'],
                                              batch_sampler=_eval_batch_sampler,
                                              return_list=True)
-        metric = None
-        if reader_cfg['metric'] == 'COCO':
-            clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
-            anno_file = dataset.get_anno()
-            metric = COCOMetric(
-                anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
-        elif reader_cfg['metric'] == 'VOC':
-            metric = VOCMetric(
-                label_list=dataset.get_label_list(),
-                class_num=reader_cfg['num_classes'],
-                map_type=reader_cfg['map_type'])
-        elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
-            anno_file = dataset.get_anno()
-            metric = KeyPointTopDownCOCOEval(anno_file,
-                                             len(dataset), 17, 'output_eval')
-        else:
-            raise ValueError("metric currently only supports COCO and VOC.")
+        global num_classes
+        num_classes = reader_cfg['num_classes']
+        clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
+        anno_file = dataset.get_anno()
+        metric = COCOMetric(
+            anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
        global_config['metric'] = metric
    else:
        eval_func = None
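One detail worth noting in `eval_function` above: the raw box outputs are reshaped to `(batch_size, -1, 32)`. The 32 channels encode a per-side box distribution, consistent with how `PicoDetPostProcess` recovers `reg_max`:

```python
# 32 channels = 4 box sides * (reg_max + 1) distribution bins
box_channels = 32
reg_max = box_channels // 4 - 1
print(reg_max)  # 7, matching reg_max = int(raw_boxes[0].shape[-1] / 4 - 1) in post_process.py
```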
......