diff --git a/example/full_quantization/detection/README.md b/example/full_quantization/picodet/README.md
similarity index 79%
rename from example/full_quantization/detection/README.md
rename to example/full_quantization/picodet/README.md
index 1d5631467be691b1668d406c9e6894e1dfeff0bf..4957801528380611d5c132a91f35b527a82741f2 100644
--- a/example/full_quantization/detection/README.md
+++ b/example/full_quantization/picodet/README.md
@@ -23,7 +23,7 @@
 | Model | Strategy | mAP | TRT-FP32 | TRT-FP16 | TRT-INT8 | Config | Model |
 | :-------- |:-------- |:--------: | :----------------: | :----------------: | :---------------: | :----------------------: | :---------------------: |
 | PicoDet-S-NPU | Baseline | 30.1 | - | - | - | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco_npu.yml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar) |
-| PicoDet-S-NPU | Quantization-aware training | 29.7 | - | - | - | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/full_quantization/detection/configs/picodet_s_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_npu_quant.tar) |
+| PicoDet-S-NPU | Quantization-aware training | 29.7 | - | - | - | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/full_quantization/picodet/configs/picodet_npu_with_postprocess.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_npu_quant.tar) |

 - All mAP figures are measured on the COCO val2017 dataset at IoU=0.5:0.95.

@@ -31,7 +31,7 @@
 #### 3.1 Prepare the environment

 - PaddlePaddle >= 2.3 (can be installed from the [PaddlePaddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
-- PaddleSlim >= 2.3
+- PaddleSlim >= 2.3.4
 - PaddleDet >= 2.4
 - opencv-python

@@ -67,9 +67,6 @@ pip install paddledet

 An inference model consists of two files: `model.pdmodel` (the network file) and `model.pdiparams` (the weights file).

-Note: legacy `__model__` and `__params__` files correspond to `model.pdmodel` and `model.pdiparams`, respectively.
-
-
 Export the inference model following the [PaddleDetection documentation](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/GETTING_STARTED_cn.md#8-%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA); the PicoDet-S-NPU export example below can be used as a reference:

 - Download the code
 ```
 git clone https://github.com/PaddlePaddle/PaddleDetection.git
@@ -77,13 +74,20 @@ git clone https://github.com/PaddlePaddle/PaddleDetection.git
 ```

 - Export the inference model

-PicoDet-S-NPU model with NMS included; for a quick start, download the exported [PicoDet-S-NPU model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar) directly.
+PicoDet-S-NPU model with post-processing included; for a quick start, download the exported [PicoDet-S-NPU model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar) directly.

 ```shell
 python tools/export_model.py \
        -c configs/picodet/picodet_s_416_coco_npu.yml \
        -o weights=https://paddledet.bj.bcebos.com/models/picodet_s_416_coco_npu.pdparams
 ```
+Export the PicoDet-S-NPU model without post-processing:
+```shell
+python tools/export_model.py \
+       -c configs/picodet/picodet_s_416_coco_npu.yml \
+       -o weights=https://paddledet.bj.bcebos.com/models/picodet_s_416_coco_npu.pdparams \
+       export.benchmark=True
+```

 #### 3.4 Full quantization and model export

@@ -92,14 +96,20 @@

 - Single-GPU training:
 ```
 export CUDA_VISIBLE_DEVICES=0
-python run.py --config_path=./configs/picodet_s_qat_dis.yaml --save_dir='./output/'
+python run.py --config_path=./configs/picodet_npu_with_postprocess.yaml --save_dir='./output/'
 ```

 - Multi-GPU training:
 ```
 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
-          --config_path=./configs/picodet_s_qat_dis.yaml --save_dir='./output/'
+          --config_path=./configs/picodet_npu_with_postprocess.yaml --save_dir='./output/'
+```
+
+- Training the PicoDet model without post-processing:
+```
+export CUDA_VISIBLE_DEVICES=0
+python run.py --config_path=./configs/picodet_npu.yaml --save_dir='./output/'
 ```

 #### 3.5 Test model accuracy

@@ -107,7 +117,7 @@
 Use the eval.py script to obtain the model's mAP:
 ```
 export CUDA_VISIBLE_DEVICES=0
-python eval.py --config_path=./configs/picodet_s_qat_dis.yaml
+python eval.py --config_path=./configs/picodet_npu_with_postprocess.yaml
 ```

 **Note**:
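The two export variants above differ in their input/output signatures: with post-processing the graph takes `image` and `scale_factor` and returns `bbox`/`bbox_num`, while `export.benchmark=True` strips the post-processing so the graph takes only `image` and returns the raw head tensors. A minimal sketch for checking this locally (the directory name is an assumption based on the download link above; adjust it to wherever you extracted the model):

```python
import numpy as np
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# expects model.pdmodel / model.pdiparams under ./picodet_s_416_coco_npu/
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "./picodet_s_416_coco_npu/model", exe)

feed = {"image": np.random.rand(1, 3, 416, 416).astype("float32")}
if "scale_factor" in feed_names:  # present only in the post-processing variant
    feed["scale_factor"] = np.ones((1, 2), dtype="float32")

outs = exe.run(program, feed=feed, fetch_list=fetch_targets, return_numpy=False)
# post-processing variant: 2 outputs (bbox, bbox_num);
# benchmark variant: 8 outputs (4 score maps + 4 box-regression maps)
print(len(outs), [np.array(o).shape for o in outs])
```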
diff --git a/example/full_quantization/picodet/configs/picodet_npu.yaml b/example/full_quantization/picodet/configs/picodet_npu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..37f20d7b70c2869cd51e2587e58cd7a822ea3f31
--- /dev/null
+++ b/example/full_quantization/picodet/configs/picodet_npu.yaml
@@ -0,0 +1,35 @@
+Global:
+  reader_config: ./configs/picodet_reader.yml
+  input_list: ['image']
+  include_post_process: False
+  Evaluation: True
+  model_dir: ./picodet_s_416_coco_npu
+  model_filename: model.pdmodel
+  params_filename: model.pdiparams
+
+Distillation:
+  alpha: 1.0
+  loss: l2
+
+Quantization:
+  use_pact: true
+  activation_quantize_type: 'moving_average_abs_max'
+  weight_bits: 8
+  activation_bits: 8
+  quantize_op_types:
+  - conv2d
+  - depthwise_conv2d
+
+TrainConfig:
+  train_iter: 8000
+  eval_iter: 1000
+  learning_rate:
+    type: CosineAnnealingDecay
+    learning_rate: 0.00001
+    T_max: 8000
+  optimizer_builder:
+    optimizer:
+      type: SGD
+    weight_decay: 4.0e-05
+
+
diff --git a/example/full_quantization/detection/configs/picodet_s_qat_dis.yaml b/example/full_quantization/picodet/configs/picodet_npu_with_postprocess.yaml
similarity index 90%
rename from example/full_quantization/detection/configs/picodet_s_qat_dis.yaml
rename to example/full_quantization/picodet/configs/picodet_npu_with_postprocess.yaml
index 8d9091569ca49138625a4f00f54a8db94cde818c..4064df0db2063b38203b4617db771df492d18fc8 100644
--- a/example/full_quantization/detection/configs/picodet_s_qat_dis.yaml
+++ b/example/full_quantization/picodet/configs/picodet_npu_with_postprocess.yaml
@@ -1,8 +1,9 @@
 Global:
   reader_config: ./configs/picodet_reader.yml
   input_list: ['image', 'scale_factor']
+  include_post_process: True
   Evaluation: True
-  model_dir: ./picodet_s_416_coco_npu/
+  model_dir: ./picodet_s_416_coco_npu
   model_filename: model.pdmodel
   params_filename: model.pdiparams

diff --git a/example/full_quantization/detection/configs/picodet_reader.yml b/example/full_quantization/picodet/configs/picodet_reader.yml
similarity index 65%
rename from example/full_quantization/detection/configs/picodet_reader.yml
rename to example/full_quantization/picodet/configs/picodet_reader.yml
index 7d2ae4d1ef8f3120fa13758af606f58ddcef2c9e..88d0ddb58ad21ccc1d3f8221a236706fc9264c82 100644
--- a/example/full_quantization/detection/configs/picodet_reader.yml
+++ b/example/full_quantization/picodet/configs/picodet_reader.yml
@@ -7,26 +7,33 @@ TrainDataset:
   !COCODataSet
     image_dir: train2017
     anno_path: annotations/instances_train2017.json
-    dataset_dir: /paddle/dataset/coco/
+    dataset_dir: dataset/coco/

 EvalDataset:
   !COCODataSet
     image_dir: val2017
     anno_path: annotations/instances_val2017.json
-    dataset_dir: /paddle/dataset/coco/
+    dataset_dir: dataset/coco/

-worker_num: 6
+worker_num: 0

 eval_height: &eval_height 416
 eval_width: &eval_width 416
 eval_size: &eval_size [*eval_height, *eval_width]

-EvalReader:
+TrainReader:
   sample_transforms:
   - Decode: {}
   - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
   - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
   - Permute: {}
-  batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
   batch_size: 8
   shuffle: false
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
+  - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+  - Permute: {}
+  batch_size: 1
+  shuffle: false
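The only substantive differences between the two quantization configs are `input_list` and `include_post_process`: the graph exported without post-processing declares `image` as its only feed, while the full export also needs `scale_factor`. The scripts below filter each reader sample down to exactly those keys; a runnable illustration of that filtering (the helper name here is illustrative — the real logic lives in `reader_wrapper` in post_quant.py and run.py):

```python
import numpy as np

def filter_feed(sample, input_list):
    # keep only the keys the static graph declares as inputs
    return {name: np.array(sample[name]) for name in input_list}

sample = {
    "image": np.zeros((1, 3, 416, 416), np.float32),
    "scale_factor": np.ones((1, 2), np.float32),
    "im_shape": np.full((1, 2), 416, np.float32),
}
print(sorted(filter_feed(sample, ["image"])))                  # picodet_npu.yaml
print(sorted(filter_feed(sample, ["image", "scale_factor"])))  # picodet_npu_with_postprocess.yaml
```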
diff --git a/example/full_quantization/detection/eval.py b/example/full_quantization/picodet/eval.py
similarity index 55%
rename from example/full_quantization/detection/eval.py
rename to example/full_quantization/picodet/eval.py
index d6c7d49daf8ccc43ad914eb56dd7727ae3e1f00b..3c3ff5501fb27218879cac2e63c15aa7f92cb0fc 100644
--- a/example/full_quantization/detection/eval.py
+++ b/example/full_quantization/picodet/eval.py
@@ -22,6 +22,8 @@ from ppdet.core.workspace import create
 from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
 from paddleslim.common import load_config as load_slim_config

+from post_process import PicoDetPostProcess
+

 def argsparser():
     parser = argparse.ArgumentParser(description=__doc__)
@@ -40,37 +42,7 @@ def argsparser():
     return parser


-def reader_wrapper(reader, input_list):
-    def gen():
-        for data in reader:
-            in_dict = {}
-            if isinstance(input_list, list):
-                for input_name in input_list:
-                    in_dict[input_name] = data[input_name]
-            elif isinstance(input_list, dict):
-                for input_name in input_list.keys():
-                    in_dict[input_list[input_name]] = data[input_name]
-            yield in_dict
-
-    return gen
-
-
-def convert_numpy_data(data, metric):
-    data_all = {}
-    data_all = {k: np.array(v) for k, v in data.items()}
-    if isinstance(metric, VOCMetric):
-        for k, v in data_all.items():
-            if not isinstance(v[0], np.ndarray):
-                tmp_list = []
-                for t in v:
-                    tmp_list.append(np.array(t))
-                data_all[k] = np.array(tmp_list)
-    else:
-        data_all = {k: np.array(v) for k, v in data.items()}
-    return data_all
-
-
-def eval():
+def eval(metric):
     place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
     exe = paddle.static.Executor(place)

@@ -82,30 +54,46 @@ def eval():
         params_filename=global_config["params_filename"])
     print('Loaded model from: {}'.format(global_config["model_dir"]))

-    metric = global_config['metric']
     for batch_id, data in enumerate(val_loader):
-        data_all = convert_numpy_data(data, metric)
+        data_all = {k: np.array(v) for k, v in data.items()}
+        batch_size = data_all['image'].shape[0]
         data_input = {}
         for k, v in data.items():
-            if isinstance(global_config['input_list'], list):
-                if k in global_config['input_list']:
-                    data_input[k] = np.array(v)
-            elif isinstance(global_config['input_list'], dict):
-                if k in global_config['input_list'].keys():
-                    data_input[global_config['input_list'][k]] = np.array(v)
+            if k in feed_target_names:
+                data_input[k] = np.array(v)

         outs = exe.run(val_program,
                        feed=data_input,
                        fetch_list=fetch_targets,
                        return_numpy=False)
-        res = {}
-
-        for out in outs:
-            v = np.array(out)
-            if len(v.shape) > 1:
-                res['bbox'] = v
-            else:
-                res['bbox_num'] = v
+        if not global_config['include_post_process']:
+            np_score_list, np_boxes_list = [], []
+            for i, out in enumerate(outs):
+                np_out = np.array(out)
+                if i < 4:
+                    num_classes = np_out.shape[-1]
+                    np_score_list.append(
+                        np_out.reshape(batch_size, -1, num_classes))
+                else:
+                    box_reg_shape = np_out.shape[-1]
+                    np_boxes_list.append(
+                        np_out.reshape(batch_size, -1, box_reg_shape))
+            post_processor = PicoDetPostProcess(
+                data_all['image'].shape[2:],
+                data_all['im_shape'],
+                data_all['scale_factor'],
+                score_threshold=0.01,
+                nms_threshold=0.6)
+            res = post_processor(np_score_list, np_boxes_list)
+        else:
+            res = {}
+            for out in outs:
+                v = np.array(out)
+                if len(v.shape) > 1:
+                    res['bbox'] = v
+                else:
+                    res['bbox_num'] = v
+
         metric.update(data_all, res)
         if batch_id % 100 == 0:
             print('Eval iter:', batch_id)
@@ -125,26 +113,15 @@ def main():
     val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
                                       reader_cfg['worker_num'],
                                       return_list=True)
+    global num_classes
+    num_classes = reader_cfg['num_classes']
     metric = None
-    if reader_cfg['metric'] == 'COCO':
-        clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
-        anno_file = dataset.get_anno()
-        metric = COCOMetric(
-            anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
-    elif reader_cfg['metric'] == 'VOC':
-        metric = VOCMetric(
-            label_list=dataset.get_label_list(),
-            class_num=reader_cfg['num_classes'],
-            map_type=reader_cfg['map_type'])
-    elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
-        anno_file = dataset.get_anno()
-        metric = KeyPointTopDownCOCOEval(anno_file,
-                                         len(dataset), 17, 'output_eval')
-    else:
-        raise ValueError("metric currently only supports COCO and VOC.")
-    global_config['metric'] = metric
-
-    eval()
+    clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
+    anno_file = dataset.get_anno()
+    metric = COCOMetric(
+        anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
+
+    eval(metric)


 if __name__ == '__main__':
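Whichever branch produces `res`, `metric.update(...)` receives PaddleDetection's standard detection layout: an `(N, 6)` array of `[class_id, score, x1, y1, x2, y2]` rows plus a per-image detection count. A small self-check of that contract with hand-made values:

```python
import numpy as np

res = {
    # one row per detection: [class_id, score, x1, y1, x2, y2]
    "bbox": np.array([[0, 0.92, 10., 20., 110., 220.],
                      [2, 0.55, 30., 40., 90., 180.]], dtype=np.float32),
    # number of detections for each image in the batch
    "bbox_num": np.array([2], dtype=np.int32),
}
assert res["bbox"].shape[1] == 6
assert res["bbox_num"].sum() == len(res["bbox"])
```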
diff --git a/example/full_quantization/picodet/onnxruntime_eval.py b/example/full_quantization/picodet/onnxruntime_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4d6c791ec710e03997bfa5427c5c92739db6999
--- /dev/null
+++ b/example/full_quantization/picodet/onnxruntime_eval.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import ast
+import numpy as np
+import argparse
+import paddle
+from ppdet.core.workspace import load_config
+from ppdet.core.workspace import create
+from ppdet.metrics import COCOMetric
+import onnxruntime as ort
+
+from post_process import PicoDetPostProcess
+
+
+def argsparser():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--reader_config',
+        type=str,
+        default='configs/picodet_reader.yml',
+        help="path of the reader config.",
+        required=True)
+    parser.add_argument(
+        '--model_path',
+        type=str,
+        default='onnx_file/picodet_s_416_npu_postprocessed.onnx',
+        help="path of the ONNX model file.")
+    parser.add_argument(
+        '--include_post_process',
+        # parse 'True'/'False' literally; plain type=bool would treat any
+        # non-empty string (including 'False') as True
+        type=ast.literal_eval,
+        default=False,
+        help="Whether the model includes post-processing or not.")
+
+    return parser
+
+
+def eval(val_loader, metric, sess):
+    inputs_name = [a.name for a in sess.get_inputs()]
+
+    for batch_id, data in enumerate(val_loader):
+        data_all = {k: np.array(v) for k, v in data.items()}
+        batch_size = data_all['image'].shape[0]
+        data_input = {}
+        for k, v in data.items():
+            if k in inputs_name:
+                data_input[k] = np.array(v)
+
+        outs = sess.run(None, data_input)
+
+        if not FLAGS.include_post_process:
+            np_score_list, np_boxes_list = [], []
+            for i, out in enumerate(outs):
+                np_out = np.array(out)
+                if i < 4:
+                    num_classes = np_out.shape[-1]
+                    np_score_list.append(
+                        np_out.reshape(batch_size, -1, num_classes))
+                else:
+                    box_reg_shape = np_out.shape[-1]
+                    np_boxes_list.append(
+                        np_out.reshape(batch_size, -1, box_reg_shape))
+            post_processor = PicoDetPostProcess(
+                data_all['image'].shape[2:],
+                data_all['im_shape'],
+                data_all['scale_factor'],
+                score_threshold=0.01,
+                nms_threshold=0.6)
+            res = post_processor(np_score_list, np_boxes_list)
+        else:
+            res = {}
+            for out in outs:
+                v = np.array(out)
+                if len(v.shape) > 1:
+                    res['bbox'] = v
+                else:
+                    res['bbox_num'] = v
+
+        metric.update(data_all, res)
+        if batch_id % 100 == 0:
+            print('Eval iter:', batch_id)
+    metric.accumulate()
+    metric.log()
+    metric.reset()
+
+
+def main():
+
+    reader_cfg = load_config(FLAGS.reader_config)
+
+    dataset = reader_cfg['EvalDataset']
+    val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
+                                      reader_cfg['worker_num'],
+                                      return_list=True)
+    clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
+    anno_file = dataset.get_anno()
+    metric = COCOMetric(
+        anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
+
+    providers = ['CPUExecutionProvider']
+    sess_options = ort.SessionOptions()
+    sess_options.optimized_model_filepath = "./optimize_model.onnx"
+    sess = ort.InferenceSession(
+        FLAGS.model_path, providers=providers, sess_options=sess_options)
+    eval(val_loader, metric, sess)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    parser = argsparser()
+    FLAGS = parser.parse_args()
+
+    # the DataLoader needs to run on CPU
+    paddle.set_device("cpu")
+
+    main()
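Before running the full COCO evaluation, it can save time to check which inputs and outputs the ONNX file actually exposes, and therefore how `--include_post_process` should be set. A quick sketch, reusing the script's default model path:

```python
import onnxruntime as ort

sess = ort.InferenceSession(
    "onnx_file/picodet_s_416_npu_postprocessed.onnx",
    providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output:", out.name, out.shape, out.type)
# 2 outputs -> post-processing lives inside the graph;
# 8 outputs -> run PicoDetPostProcess (below) on the raw head tensors
```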
diff --git a/example/full_quantization/picodet/post_process.py b/example/full_quantization/picodet/post_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cc6874a1f1c8ef4606b264b16e29f6b70c8a1c1
--- /dev/null
+++ b/example/full_quantization/picodet/post_process.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from scipy.special import softmax
+
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+    """
+    Args:
+        box_scores (N, 5): boxes in corner-form and probabilities.
+        iou_threshold: intersection over union threshold.
+        top_k: keep top_k results. If k <= 0, keep all the results.
+        candidate_size: only consider the candidates with the highest scores.
+    Returns:
+        picked: a list of indexes of the kept boxes
+    """
+    scores = box_scores[:, -1]
+    boxes = box_scores[:, :-1]
+    picked = []
+    indexes = np.argsort(scores)
+    indexes = indexes[-candidate_size:]
+    while len(indexes) > 0:
+        current = indexes[-1]
+        picked.append(current)
+        if 0 < top_k == len(picked) or len(indexes) == 1:
+            break
+        current_box = boxes[current, :]
+        indexes = indexes[:-1]
+        rest_boxes = boxes[indexes, :]
+        iou = iou_of(
+            rest_boxes,
+            np.expand_dims(
+                current_box, axis=0), )
+        indexes = indexes[iou <= iou_threshold]
+
+    return box_scores[picked, :]
+
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+    """Return intersection-over-union (Jaccard index) of boxes.
+    Args:
+        boxes0 (N, 4): ground truth boxes.
+        boxes1 (N or 1, 4): predicted boxes.
+        eps: a small number to avoid 0 as denominator.
+    Returns:
+        iou (N): IoU values.
+    """
+    overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
+    overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
+
+    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+    return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
+def area_of(left_top, right_bottom):
+    """Compute the areas of rectangles given two corners.
+    Args:
+        left_top (N, 2): left top corner.
+        right_bottom (N, 2): right bottom corner.
+    Returns:
+        area (N): return the area.
+    """
+    hw = np.clip(right_bottom - left_top, 0.0, None)
+    return hw[..., 0] * hw[..., 1]
+
+
+class PicoDetPostProcess(object):
+    """
+    Args:
+        input_shape (list): network input image size, [h, w]
+        ori_shape (np.ndarray): original image shapes before resizing
+        scale_factor (np.ndarray): per-image scale factors applied at resize time
+    """
+
+    def __init__(self,
+                 input_shape,
+                 ori_shape,
+                 scale_factor,
+                 strides=[8, 16, 32, 64],
+                 score_threshold=0.4,
+                 nms_threshold=0.5,
+                 nms_top_k=1000,
+                 keep_top_k=100):
+        self.ori_shape = ori_shape
+        self.input_shape = input_shape
+        self.scale_factor = scale_factor
+        self.strides = strides
+        self.score_threshold = score_threshold
+        self.nms_threshold = nms_threshold
+        self.nms_top_k = nms_top_k
+        self.keep_top_k = keep_top_k
+
+    def warp_boxes(self, boxes, ori_shape):
+        """Apply transform to boxes
+        """
+        width, height = ori_shape[1], ori_shape[0]
+        n = len(boxes)
+        if n:
+            # warp points
+            xy = np.ones((n * 4, 3))
+            xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
+                n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
+            # xy = xy @ M.T  # transform
+            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
+            # create new boxes
+            x = xy[:, [0, 2, 4, 6]]
+            y = xy[:, [1, 3, 5, 7]]
+            xy = np.concatenate(
+                (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+            # clip boxes
+            xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
+            xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
+            return xy.astype(np.float32)
+        else:
+            return boxes
+ """ + hw = np.clip(right_bottom - left_top, 0.0, None) + return hw[..., 0] * hw[..., 1] + + +class PicoDetPostProcess(object): + """ + Args: + input_shape (int): network input image size + ori_shape (int): ori image shape of before padding + scale_factor (float): scale factor of ori image + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + input_shape, + ori_shape, + scale_factor, + strides=[8, 16, 32, 64], + score_threshold=0.4, + nms_threshold=0.5, + nms_top_k=1000, + keep_top_k=100): + self.ori_shape = ori_shape + self.input_shape = input_shape + self.scale_factor = scale_factor + self.strides = strides + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + + def warp_boxes(self, boxes, ori_shape): + """Apply transform to boxes + """ + width, height = ori_shape[1], ori_shape[0] + n = len(boxes) + if n: + # warp points + xy = np.ones((n * 4, 3)) + xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( + n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + # xy = xy @ M.T # transform + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale + # create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + # clip boxes + xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) + xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) + return xy.astype(np.float32) + else: + return boxes + + def __call__(self, scores, raw_boxes): + batch_size = raw_boxes[0].shape[0] + reg_max = int(raw_boxes[0].shape[-1] / 4 - 1) + out_boxes_num = [] + out_boxes_list = [] + for batch_id in range(batch_size): + # generate centers + decode_boxes = [] + select_scores = [] + for stride, box_distribute, score in zip(self.strides, raw_boxes, + scores): + box_distribute = box_distribute[batch_id] + score = score[batch_id] + # centers + fm_h = self.input_shape[0] / stride + fm_w = self.input_shape[1] / stride + h_range = np.arange(fm_h) + w_range = np.arange(fm_w) + ww, hh = np.meshgrid(w_range, h_range) + ct_row = (hh.flatten() + 0.5) * stride + ct_col = (ww.flatten() + 0.5) * stride + center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1) + + # box distribution to distance + reg_range = np.arange(reg_max + 1) + box_distance = box_distribute.reshape((-1, reg_max + 1)) + box_distance = softmax(box_distance, axis=1) + box_distance = box_distance * np.expand_dims(reg_range, axis=0) + box_distance = np.sum(box_distance, axis=1).reshape((-1, 4)) + box_distance = box_distance * stride + + # top K candidate + topk_idx = np.argsort(score.max(axis=1))[::-1] + topk_idx = topk_idx[:self.nms_top_k] + center = center[topk_idx] + score = score[topk_idx] + box_distance = box_distance[topk_idx] + + # decode box + decode_box = center + [-1, -1, 1, 1] * box_distance + + select_scores.append(score) + decode_boxes.append(decode_box) + + # nms + bboxes = np.concatenate(decode_boxes, axis=0) + confidences = np.concatenate(select_scores, axis=0) + picked_box_probs = [] + picked_labels = [] + for class_index in range(0, confidences.shape[1]): + probs = confidences[:, class_index] + mask = probs > self.score_threshold + probs = probs[mask] + if probs.shape[0] == 0: + continue + subset_boxes = bboxes[mask, :] + box_probs = np.concatenate( + [subset_boxes, probs.reshape(-1, 1)], axis=1) + box_probs = hard_nms( + box_probs, + iou_threshold=self.nms_threshold, + top_k=self.keep_top_k, ) + picked_box_probs.append(box_probs) + picked_labels.extend([class_index] * 
diff --git a/example/full_quantization/picodet/post_quant.py b/example/full_quantization/picodet/post_quant.py
new file mode 100644
index 0000000000000000000000000000000000000000..afbe3dbaf4951cf6236c8db6bae58d22b77aba6a
--- /dev/null
+++ b/example/full_quantization/picodet/post_quant.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+import argparse
+import paddle
+from ppdet.core.workspace import load_config, merge_config
+from ppdet.core.workspace import create
+from paddleslim.quant import quant_post_static
+from paddleslim.common import load_config as load_slim_config
+
+
+def argsparser():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        default=None,
+        help="path of compression strategy config.",
+        required=True)
+    parser.add_argument(
+        '--save_dir',
+        type=str,
+        default='ptq_out',
+        help="directory to save compressed model.")
+    parser.add_argument(
+        '--devices',
+        type=str,
+        default='gpu',
+        help="which device used to compress.")
+    parser.add_argument(
+        '--algo', type=str, default='avg', help="post quant algo.")
+
+    return parser
+
+
+def reader_wrapper(reader, input_list):
+    def gen():
+        for data in reader:
+            in_dict = {}
+            if isinstance(input_list, list):
+                for input_name in input_list:
+                    in_dict[input_name] = data[input_name]
+            elif isinstance(input_list, dict):
+                for input_name in input_list.keys():
+                    in_dict[input_list[input_name]] = data[input_name]
+            yield in_dict
+
+    return gen
+
+
+def main():
+    all_config = load_slim_config(FLAGS.config_path)
+    global_config = all_config["Global"]
+    reader_cfg = load_config(global_config['reader_config'])
+
+    train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
+                                        reader_cfg['worker_num'],
+                                        return_list=True)
+    train_loader = reader_wrapper(train_loader, global_config['input_list'])
+
+    place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
+    exe = paddle.static.Executor(place)
+    quant_post_static(
+        executor=exe,
+        model_dir=global_config["model_dir"],
+        quantize_model_path=FLAGS.save_dir,
+        data_loader=train_loader,
+        model_filename=global_config["model_filename"],
+        params_filename=global_config["params_filename"],
+        batch_size=32,
+        batch_nums=10,
+        algo=FLAGS.algo,
+        hist_percent=0.999,
+        is_full_quantize=False,
+        bias_correction=False,
+        onnx_format=True,
+        skip_tensor_list=None)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    parser = argsparser()
+    FLAGS = parser.parse_args()
+
+    assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
+    paddle.set_device(FLAGS.devices)
+
+    main()
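post_quant.py writes the calibrated, quantized program to `--save_dir`. A quick sanity check that the result still loads and runs; the `model` prefix assumes `quant_post_static` saved `model.pdmodel`/`model.pdiparams` (older PaddleSlim releases used different default file names, so adjust the prefix if needed):

```python
import numpy as np
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "./ptq_out/model", exe)  # assumed output prefix, see note above
feed = {"image": np.random.rand(1, 3, 416, 416).astype("float32")}
if "scale_factor" in feed_names:  # only in the post-processing variant
    feed["scale_factor"] = np.ones((1, 2), dtype="float32")
outs = exe.run(program, feed=feed, fetch_list=fetch_targets, return_numpy=False)
print([np.array(o).shape for o in outs])
```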
diff --git a/example/full_quantization/detection/run.py b/example/full_quantization/picodet/run.py
similarity index 64%
rename from example/full_quantization/detection/run.py
rename to example/full_quantization/picodet/run.py
index b05b921fa81c98ee3b8bfa2f8654f9744acc475d..094e67d21d3a88cc81a02cef4072c13058f4906d 100644
--- a/example/full_quantization/detection/run.py
+++ b/example/full_quantization/picodet/run.py
@@ -24,6 +24,8 @@ from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
 from paddleslim.common import load_config as load_slim_config
 from paddleslim.auto_compression import AutoCompression

+from post_process import PicoDetPostProcess
+

 def argsparser():
     parser = argparse.ArgumentParser(description=__doc__)
@@ -62,48 +64,48 @@ def reader_wrapper(reader, input_list):
     return gen


-def convert_numpy_data(data, metric):
-    data_all = {}
-    data_all = {k: np.array(v) for k, v in data.items()}
-    if isinstance(metric, VOCMetric):
-        for k, v in data_all.items():
-            if not isinstance(v[0], np.ndarray):
-                tmp_list = []
-                for t in v:
-                    tmp_list.append(np.array(t))
-                data_all[k] = np.array(tmp_list)
-    else:
-        data_all = {k: np.array(v) for k, v in data.items()}
-    return data_all
-
-
 def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
     metric = global_config['metric']
     with tqdm(
             total=len(val_loader),
             bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
             ncols=80) as t:
-        for batch_id, data in enumerate(val_loader):
-            data_all = convert_numpy_data(data, metric)
+        for data in val_loader:
+            data_all = {k: np.array(v) for k, v in data.items()}
+            batch_size = data_all['image'].shape[0]
             data_input = {}
             for k, v in data.items():
-                if isinstance(global_config['input_list'], list):
-                    if k in test_feed_names:
-                        data_input[k] = np.array(v)
-                elif isinstance(global_config['input_list'], dict):
-                    if k in global_config['input_list'].keys():
-                        data_input[global_config['input_list'][k]] = np.array(v)
+                if k in test_feed_names:
+                    data_input[k] = np.array(v)
+
             outs = exe.run(compiled_test_program,
                            feed=data_input,
                            fetch_list=test_fetch_list,
                            return_numpy=False)
-            res = {}
-            for out in outs:
-                v = np.array(out)
-                if len(v.shape) > 1:
-                    res['bbox'] = v
-                else:
-                    res['bbox_num'] = v
+            if not global_config['include_post_process']:
+                np_score_list, np_boxes_list = [], []
+                for i, out in enumerate(outs):
+                    if i < 4:
+                        np_score_list.append(
+                            np.array(out).reshape(batch_size, -1, num_classes))
+                    else:
+                        np_boxes_list.append(
+                            np.array(out).reshape(batch_size, -1, 32))
+                post_processor = PicoDetPostProcess(
+                    data_all['image'].shape[2:],
+                    data_all['im_shape'],
+                    data_all['scale_factor'],
+                    score_threshold=0.01,
+                    nms_threshold=0.6)
+                res = post_processor(np_score_list, np_boxes_list)
+            else:
+                res = {}
+                for out in outs:
+                    v = np.array(out)
+                    if len(v.shape) > 1:
+                        res['bbox'] = v
+                    else:
+                        res['bbox_num'] = v

             metric.update(data_all, res)
             t.update()
@@ -111,9 +113,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
     metric.log()
     map_res = metric.get_results()
     metric.reset()
-    map_key = 'keypoint' if 'arch' in global_config and global_config[
-        'arch'] == 'keypoint' else 'bbox'
-    return map_res[map_key][0]
+    return map_res['bbox'][0]


 def main():
@@ -123,9 +123,9 @@ def main():
     global_config = all_config["Global"]
     reader_cfg = load_config(global_config['reader_config'])

-    train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
-                                        reader_cfg['worker_num'],
-                                        return_list=True)
+    train_loader = create('TrainReader')(reader_cfg['TrainDataset'],
+                                         reader_cfg['worker_num'],
+                                         return_list=True)
     train_loader = reader_wrapper(train_loader, global_config['input_list'])

     if 'Evaluation' in global_config.keys() and global_config[
@@ -139,23 +139,12 @@ def main():
                                     reader_cfg['worker_num'],
                                     batch_sampler=_eval_batch_sampler,
                                     return_list=True)
-        metric = None
-        if reader_cfg['metric'] == 'COCO':
-            clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
-            anno_file = dataset.get_anno()
-            metric = COCOMetric(
-                anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
-        elif reader_cfg['metric'] == 'VOC':
-            metric = VOCMetric(
-                label_list=dataset.get_label_list(),
-                class_num=reader_cfg['num_classes'],
-                map_type=reader_cfg['map_type'])
-        elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
-            anno_file = dataset.get_anno()
-            metric = KeyPointTopDownCOCOEval(anno_file,
-                                             len(dataset), 17, 'output_eval')
-        else:
-            raise ValueError("metric currently only supports COCO and VOC.")
+        global num_classes
+        num_classes = reader_cfg['num_classes']
+        clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
+        anno_file = dataset.get_anno()
+        metric = COCOMetric(
+            anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
         global_config['metric'] = metric
     else:
         eval_func = None
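The patch ends inside run.py's `main()`; the remaining lines are not shown in this diff. For context, the full-quantization examples finish by handing the model, config and loaders to ACT, roughly as below. This is a sketch of that continuation based on the sibling PaddleSlim examples, not part of this diff; the names refer to variables already defined in `main()`, and `AutoCompression` is the class imported at the top of run.py:

```python
    # continuation sketch inside main(), after the Evaluation branch above
    ac = AutoCompression(
        model_dir=global_config["model_dir"],
        model_filename=global_config["model_filename"],
        params_filename=global_config["params_filename"],
        save_dir=FLAGS.save_dir,
        config=all_config,
        train_dataloader=train_loader,
        eval_callback=eval_func)
    ac.compress()
```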