From f03948d884a3d639880ebd7499ccee4d72f75a40 Mon Sep 17 00:00:00 2001
From: JYChen
Date: Sat, 13 Nov 2021 13:40:59 +0800
Subject: [PATCH] fix picodet error in det_keypoint_unite_infer (#4561)

* fix picodet error in det_keypoint_unite_infer

* Optimize configuration file path
---
 configs/keypoint/tiny_pose/README.md          |  20 +--
 .../{keypoint => }/tinypose_128x96.yml        |   2 +-
 .../{keypoint => }/tinypose_256x192.yml       |   2 +-
 .../picodet_s_192_pedestrian.yml              | 143 ++++++++++++++++++
 .../picodet_s_320_pedestrian.yml              |   0
 deploy/python/det_keypoint_unite_infer.py     |  27 ++--
 6 files changed, 170 insertions(+), 24 deletions(-)
 rename configs/keypoint/tiny_pose/{keypoint => }/tinypose_128x96.yml (99%)
 rename configs/keypoint/tiny_pose/{keypoint => }/tinypose_256x192.yml (99%)
 create mode 100644 configs/picodet/application/pedestrian_detection/picodet_s_192_pedestrian.yml
 rename configs/{keypoint/tiny_pose => picodet/application}/pedestrian_detection/picodet_s_320_pedestrian.yml (100%)

diff --git a/configs/keypoint/tiny_pose/README.md b/configs/keypoint/tiny_pose/README.md
index d276f5385..043fe7a50 100644
--- a/configs/keypoint/tiny_pose/README.md
+++ b/configs/keypoint/tiny_pose/README.md
@@ -11,14 +11,14 @@ PP-TinyPose is a real-time pose estimation model optimized by PaddleDetection f
 
 ### Pose Estimation Models
 | Model | Input Size | AP (COCO val) | Single-Person Inference Time (FP32) | Single-Person Inference Time (FP16) | Config | Model Weights | Inference Deployment Model | Paddle-Lite Deployment Model (FP32) | Paddle-Lite Deployment Model (FP16) |
 | :------------------------ | :-------: | :------: | :------: |:---: | :---: | :---: | :---: | :---: | :---: |
-| PP-TinyPose | 128*96 | 58.1 | 4.57ms | 3.27ms | [Config](./keypoint/tinypose_128x96.yml) |[Model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.pdparams) | [Inference deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.tar) | [Lite deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.nb) | [Lite deployment model (FP16)](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96_fp16.nb) |
-| PP-TinyPose | 256*192 | 68.8 | 14.07ms | 8.33ms | [Config](./keypoint/tinypose_256x192.yml) | [Model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_256x192.pdparams) | [Inference deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_256x192.nb) | [Lite deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_256x192.tar) | [Lite deployment model (FP16)](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_256x192_fp16.nb) |
+| PP-TinyPose | 128*96 | 58.1 | 4.57ms | 3.27ms | [Config](./tinypose_128x96.yml) |[Model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.pdparams) | [Inference deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.tar) | [Lite deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.nb) | [Lite deployment model (FP16)](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96_fp16.nb) |
+| PP-TinyPose | 256*192 | 68.8 | 14.07ms | 8.33ms | [Config](./tinypose_256x192.yml) | [Model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_256x192.pdparams) | [Inference deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_256x192.tar) | [Lite deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_256x192.nb) | [Lite deployment model (FP16)](https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_256x192_fp16.nb) |
 
 ### Pedestrian Detection Models
 | Model | Input Size | mAP (COCO val) | Average Inference Time (FP32) | Average Inference Time (FP16) | Config | Model Weights | Inference Deployment Model | Paddle-Lite Deployment Model (FP32) | Paddle-Lite Deployment Model (FP16) |
 | :------------------------ | :-------: | :------: | :------: | :---: | :---: | :---: | :---: | :---: | :---: |
-| PicoDet-S-Pedestrian | 192*192 | 29.0 | 4.30ms | 2.37ms | [Config](./pedestrian_detection/picodet_s_192_pedestrian.yml) |[Model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_192_pedestrian.pdparams) | [Inference deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_192_pedestrian.tar) | [Lite deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_192_pedestrian.nb) | [Lite deployment model (FP16)](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_192_pedestrian_fp16.nb) |
-| PicoDet-S-Pedestrian | 320*320 | 38.5 | 10.26ms | 6.30ms | [Config](./pedestrian_detection/picodet_s_320_pedestrian.yml) | [Model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_320_pedestrian.pdparams) | [Inference deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_320_pedestrian.tar) | [Lite deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_320_pedestrian.nb) | [Lite deployment model (FP16)](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_320_pedestrian_fp16.nb) |
+| PicoDet-S-Pedestrian | 192*192 | 29.0 | 4.30ms | 2.37ms | [Config](../../picodet/application/pedestrian_detection/picodet_s_192_pedestrian.yml) |[Model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_192_pedestrian.pdparams) | [Inference deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_192_pedestrian.tar) | [Lite deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_192_pedestrian.nb) | [Lite deployment model (FP16)](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_192_pedestrian_fp16.nb) |
+| PicoDet-S-Pedestrian | 320*320 | 38.5 | 10.26ms | 6.30ms | [Config](../../picodet/application/pedestrian_detection/picodet_s_320_pedestrian.yml) | [Model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_320_pedestrian.pdparams) | [Inference deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_320_pedestrian.tar) | [Lite deployment model](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_320_pedestrian.nb) | [Lite deployment model (FP16)](https://bj.bcebos.com/v1/paddledet/models/keypoint/picodet_s_320_pedestrian_fp16.nb) |
 
 **Notes**
@@ -86,19 +86,19 @@ AI Challenger Description:
 Train the models using the merged annotations converted to `COCO` format:
 ```bash
 # pose estimation model
-python3 -m paddle.distributed.launch tools/train.py -c keypoint/tinypose_128x96.yml
+python3 -m paddle.distributed.launch tools/train.py -c configs/keypoint/tiny_pose/tinypose_128x96.yml
 
 # pedestrian detection model
-python3 -m paddle.distributed.launch tools/train.py -c pedestrian_detection/picodet_s_320_pedestrian.yml
+python3 -m paddle.distributed.launch tools/train.py -c configs/picodet/application/pedestrian_detection/picodet_s_320_pedestrian.yml
 ```
 
 ## Deployment Pipeline
 ### Running Deployment Inference
 1. Export the trained models with the following commands:
 ```bash
-python3 tools/export_model.py -c keypoint/picodet_s_192_pedestrian.yml --output_dir=outut_inference -o weights=output/picodet_s_192_pedestrian/model_final
+python3 tools/export_model.py -c configs/picodet/application/pedestrian_detection/picodet_s_192_pedestrian.yml --output_dir=output_inference -o weights=output/picodet_s_192_pedestrian/model_final
 
-python3 tools/export_model.py -c keypoint/tinypose_128x96.yml --output_dir=outut_inference -o weights=output/tinypose_128x96/model_final
+python3 tools/export_model.py -c configs/keypoint/tiny_pose/tinypose_128x96.yml --output_dir=output_inference -o weights=output/tinypose_128x96/model_final
 ```
 The exported models look like:
 ```
@@ -147,9 +147,9 @@ python3 deploy/python/det_keypoint_unite_infer.py --det_model_dir=output_inferen
 If you want to deploy a model you trained yourself, follow these steps:
 1. Export the trained models
 ```bash
-python3 tools/export_model.py -c keypoint/picodet_s_192_pedestrian.yml --output_dir=outut_inference -o weights=output/picodet_s_192_pedestrian/model_final TestReader.fuse_normalize=true
+python3 tools/export_model.py -c configs/picodet/application/pedestrian_detection/picodet_s_192_pedestrian.yml --output_dir=output_inference -o weights=output/picodet_s_192_pedestrian/model_final TestReader.fuse_normalize=true
 
-python3 tools/export_model.py -c keypoint/tinypose_128x96.yml --output_dir=outut_inference -o weights=output/tinypose_128x96/model_final TestReader.fuse_normalize=true
+python3 tools/export_model.py -c configs/keypoint/tiny_pose/tinypose_128x96.yml --output_dir=output_inference -o weights=output/tinypose_128x96/model_final TestReader.fuse_normalize=true
 ```
 
 2. Convert to Paddle-Lite models (depends on [Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite))
diff --git a/configs/keypoint/tiny_pose/keypoint/tinypose_128x96.yml b/configs/keypoint/tiny_pose/tinypose_128x96.yml
similarity index 99%
rename from configs/keypoint/tiny_pose/keypoint/tinypose_128x96.yml
rename to configs/keypoint/tiny_pose/tinypose_128x96.yml
index bc003116d..e213c2990 100644
--- a/configs/keypoint/tiny_pose/keypoint/tinypose_128x96.yml
+++ b/configs/keypoint/tiny_pose/tinypose_128x96.yml
@@ -77,7 +77,7 @@ EvalDataset:
     trainsize: *trainsize
     pixel_std: *pixel_std
     use_gt_bbox: True
-    image_thre: 0.0
+    image_thre: 0.5
 
 TestDataset:
   !ImageFolder
diff --git a/configs/keypoint/tiny_pose/keypoint/tinypose_256x192.yml b/configs/keypoint/tiny_pose/tinypose_256x192.yml
similarity index 99%
rename from configs/keypoint/tiny_pose/keypoint/tinypose_256x192.yml
rename to configs/keypoint/tiny_pose/tinypose_256x192.yml
index bc986331d..9de2a635f 100644
--- a/configs/keypoint/tiny_pose/keypoint/tinypose_256x192.yml
+++ b/configs/keypoint/tiny_pose/tinypose_256x192.yml
@@ -77,7 +77,7 @@ EvalDataset:
     trainsize: *trainsize
     pixel_std: *pixel_std
     use_gt_bbox: True
-    image_thre: 0.0
+    image_thre: 0.5
 
 TestDataset:
   !ImageFolder
diff --git a/configs/picodet/application/pedestrian_detection/picodet_s_192_pedestrian.yml b/configs/picodet/application/pedestrian_detection/picodet_s_192_pedestrian.yml
new file mode 100644
index 000000000..5af9fca6c
--- /dev/null
+++ b/configs/picodet/application/pedestrian_detection/picodet_s_192_pedestrian.yml
@@ -0,0 +1,143 @@
+use_gpu: true
+log_iter: 20
+save_dir: output
+print_flops: false
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x0_75_pretrained.pdparams
+weights: output/picodet_s_192_pedestrian/model_final
+find_unused_parameters: True
+use_ema: true
+cycle_epoch: 40
+snapshot_epoch: 10
+epoch: 300
+metric: COCO
+num_classes: 1
+
+architecture: PicoDet
+
+PicoDet:
+  backbone: ESNet
+  neck: CSPPAN
+  head: PicoHead
+
+ESNet:
+  scale: 0.75
+  feature_maps: [4, 11, 14]
+  act: hard_swish
+  channel_ratio: [0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
+
+CSPPAN:
+  out_channels: 96
+  use_depthwise: True
+  num_csp_blocks: 1
+  num_features: 4
+
+PicoHead:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 96
+    feat_out: 96
+    num_convs: 2
+    num_fpn_stride: 4
+    norm_type: bn
+    share_cls_reg: True
+  fpn_stride: [8, 16, 32, 64]
+  feat_in_chan: 96
+  prior_prob: 0.01
+  reg_max: 7
+  cell_offset: 0.5
+  loss_class:
+    name: VarifocalLoss
+    use_sigmoid: True
+    iou_weighted: True
+    loss_weight: 1.0
+  loss_dfl:
+    name: DistributionFocalLoss
+    loss_weight: 0.25
+  loss_bbox:
+    name: GIoULoss
+    loss_weight: 2.0
+  assigner:
+    name: SimOTAAssigner
+    candidate_topk: 10
+    iou_weight: 6
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.025
+    nms_threshold: 0.6
+
+LearningRate:
+  base_lr: 0.4
+  schedulers:
+  - !CosineDecay
+    max_epochs: 300
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 300
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.00004
+    type: L2
+
+TrainDataset:
+  !COCODataSet
+    image_dir: ""
+    anno_path: aic_coco_train_cocoformat.json
+    dataset_dir: dataset
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+  !COCODataSet
+    image_dir: val2017
+    anno_path: annotations/instances_val2017.json
+    dataset_dir: dataset/coco
+
+TestDataset:
+  !ImageFolder
+    anno_path: annotations/instances_val2017.json
+
+worker_num: 8
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomCrop: {}
+  - RandomFlip: {prob: 0.5}
+  - RandomDistort: {}
+  batch_transforms:
+  - BatchRandomResize: {target_size: [128, 160, 192, 224, 256], random_size: True, random_interp: True, keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  batch_size: 128
+  shuffle: true
+  drop_last: true
+  collate_batch: false
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: [192, 192], keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 8
+  shuffle: false
+
+TestReader:
+  inputs_def:
+    image_shape: [1, 3, 192, 192]
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: [192, 192], keep_ratio: False}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false
+  fuse_normalize: true
diff --git a/configs/keypoint/tiny_pose/pedestrian_detection/picodet_s_320_pedestrian.yml b/configs/picodet/application/pedestrian_detection/picodet_s_320_pedestrian.yml
similarity index 100%
rename from configs/keypoint/tiny_pose/pedestrian_detection/picodet_s_320_pedestrian.yml
rename to configs/picodet/application/pedestrian_detection/picodet_s_320_pedestrian.yml
diff --git a/deploy/python/det_keypoint_unite_infer.py b/deploy/python/det_keypoint_unite_infer.py
index bfa5c9157..5be63a72b 100644
--- a/deploy/python/det_keypoint_unite_infer.py
+++ b/deploy/python/det_keypoint_unite_infer.py
@@ -21,7 +21,7 @@ import paddle
 from det_keypoint_unite_utils import argsparser
 from preprocess import decode_image
-from infer import Detector, PredictConfig, print_arguments, get_test_images
+from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images
 from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
 from visualize import draw_pose
 from benchmark_utils import PaddleInferBenchmark
@@ -217,17 +217,20 @@ def main():
     pred_config = PredictConfig(FLAGS.det_model_dir)
-    detector = Detector(
-        pred_config,
-        FLAGS.det_model_dir,
-        device=FLAGS.device,
-        run_mode=FLAGS.run_mode,
-        trt_min_shape=FLAGS.trt_min_shape,
-        trt_max_shape=FLAGS.trt_max_shape,
-        trt_opt_shape=FLAGS.trt_opt_shape,
-        trt_calib_mode=FLAGS.trt_calib_mode,
-        cpu_threads=FLAGS.cpu_threads,
-        enable_mkldnn=FLAGS.enable_mkldnn)
+    detector_func = 'Detector'
+    if pred_config.arch == 'PicoDet':
+        detector_func = 'DetectorPicoDet'
+
+    detector = eval(detector_func)(pred_config,
+                                   FLAGS.det_model_dir,
+                                   device=FLAGS.device,
+                                   run_mode=FLAGS.run_mode,
+                                   trt_min_shape=FLAGS.trt_min_shape,
+                                   trt_max_shape=FLAGS.trt_max_shape,
+                                   trt_opt_shape=FLAGS.trt_opt_shape,
+                                   trt_calib_mode=FLAGS.trt_calib_mode,
+                                   cpu_threads=FLAGS.cpu_threads,
+                                   enable_mkldnn=FLAGS.enable_mkldnn)
 
     pred_config = PredictConfig_KeyPoint(FLAGS.keypoint_model_dir)
     assert KEYPOINT_SUPPORT_MODELS[
-- 
GitLab
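
Note on the detector dispatch introduced in main() above: `eval(detector_func)` on a hard-coded class name works, but the same arch-based selection can be written without `eval` by looking the class up directly. The sketch below is illustrative and not part of the patch; the `build_detector` helper name is hypothetical, and it assumes `DetectorPicoDet` accepts the same constructor arguments as `Detector`, which is what the call in main() implies.

```python
# Hypothetical sketch (not part of this patch): arch-based detector selection
# without eval. Assumes the same imports used by det_keypoint_unite_infer.py.
from infer import Detector, DetectorPicoDet, PredictConfig


def build_detector(model_dir, **predictor_kwargs):
    # Read the exported model's inference config to learn its architecture.
    pred_config = PredictConfig(model_dir)
    # PicoDet models are handled by the dedicated DetectorPicoDet class
    # (mirroring the dispatch in main() above); all other supported
    # architectures go through the generic Detector.
    detector_cls = DetectorPicoDet if pred_config.arch == 'PicoDet' else Detector
    return detector_cls(pred_config, model_dir, **predictor_kwargs)
```

Called as `build_detector(FLAGS.det_model_dir, device=FLAGS.device, run_mode=FLAGS.run_mode, ...)`, this would select the same class as the eval-based dispatch for both PicoDet and non-PicoDet exports, while keeping the class lookup explicit.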