From fceab30879c5faf9648db08f76e1d6cb939b3653 Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Tue, 27 Jul 2021 18:25:35 +0800
Subject: [PATCH] Add GFL model and PicoDet (#3620)

* add gfl model and PicoDet
---
 configs/gfl/README.md | 37 ++
 configs/gfl/_base_/gfl_r50_fpn.yml | 51 ++
 configs/gfl/_base_/gfl_reader.yml | 40 ++
 configs/gfl/_base_/gflv2_r50_fpn.yml | 56 +++
 configs/gfl/_base_/optimizer_1x.yml | 19 +
 configs/gfl/gfl_r50_fpn_1x_coco.yml | 10 +
 configs/gfl/gflv2_r50_fpn_1x_coco.yml | 10 +
 configs/picodet/README.md | 48 ++
 configs/picodet/_base_/optimizer_280e.yml | 18 +
 configs/picodet/_base_/picodet_320_reader.yml | 45 ++
 configs/picodet/_base_/picodet_416_reader.yml | 45 ++
 configs/picodet/_base_/picodet_mbv3_0_5x.yml | 51 ++
 .../_base_/picodet_shufflenetv2_1x.yml | 49 ++
 configs/picodet/picodet_m_mbv3_320_coco.yml | 40 ++
 configs/picodet/picodet_m_mbv3_416_coco.yml | 49 ++
 configs/picodet/picodet_m_r18_320_coco.yml | 55 ++
 .../picodet_m_shufflenetv2_320_coco.yml | 35 ++
 .../picodet_m_shufflenetv2_416_coco.yml | 38 ++
 configs/picodet/picodet_s_mbv3_320_coco.yml | 13 +
 configs/picodet/picodet_s_mbv3_416_coco.yml | 13 +
 .../picodet_s_shufflenetv2_320_coco.yml | 13 +
 .../picodet_s_shufflenetv2_416_coco.yml | 13 +
 .../quant/picodet_s_mbv3_320_quant_coco.yml | 23 +
 deploy/python/infer.py | 2 +
 docs/MODEL_ZOO_cn.md | 8 +
 ppdet/data/transform/atss_assigner.py | 266 ++++
 ppdet/data/transform/batch_operators.py | 136 ++++-
 ppdet/engine/export_utils.py | 7 +-
 ppdet/modeling/architectures/__init__.py | 4 +
 ppdet/modeling/architectures/gfl.py | 87 ++++
 ppdet/modeling/architectures/picodet.py | 91 ++++
 ppdet/modeling/backbones/__init__.py | 2 +
 ppdet/modeling/backbones/shufflenet_v2.py | 277 +++++++++++
 ppdet/modeling/bbox_utils.py | 44 ++
 ppdet/modeling/heads/__init__.py | 4 +
 ppdet/modeling/heads/gfl_head.py | 470 ++++++++++++++++++
 ppdet/modeling/heads/pico_head.py | 328 ++++++++++++
 ppdet/modeling/losses/__init__.py | 2 +
 ppdet/modeling/losses/gfocal_loss.py | 214 ++++++++
 ppdet/modeling/necks/__init__.py | 2 +
 ppdet/modeling/necks/pan.py | 135 +++++
 ppdet/modeling/tests/test_architectures.py | 10 +
 42 files changed, 2854 insertions(+), 6 deletions(-)
 create mode 100644 configs/gfl/README.md
 create mode 100644 configs/gfl/_base_/gfl_r50_fpn.yml
 create mode 100644 configs/gfl/_base_/gfl_reader.yml
 create mode 100644 configs/gfl/_base_/gflv2_r50_fpn.yml
 create mode 100644 configs/gfl/_base_/optimizer_1x.yml
 create mode 100644 configs/gfl/gfl_r50_fpn_1x_coco.yml
 create mode 100644 configs/gfl/gflv2_r50_fpn_1x_coco.yml
 create mode 100644 configs/picodet/README.md
 create mode 100644 configs/picodet/_base_/optimizer_280e.yml
 create mode 100644 configs/picodet/_base_/picodet_320_reader.yml
 create mode 100644 configs/picodet/_base_/picodet_416_reader.yml
 create mode 100644 configs/picodet/_base_/picodet_mbv3_0_5x.yml
 create mode 100644 configs/picodet/_base_/picodet_shufflenetv2_1x.yml
 create mode 100644 configs/picodet/picodet_m_mbv3_320_coco.yml
 create mode 100644 configs/picodet/picodet_m_mbv3_416_coco.yml
 create mode 100644 configs/picodet/picodet_m_r18_320_coco.yml
 create mode 100644 configs/picodet/picodet_m_shufflenetv2_320_coco.yml
 create mode 100644 configs/picodet/picodet_m_shufflenetv2_416_coco.yml
 create mode 100644 configs/picodet/picodet_s_mbv3_320_coco.yml
 create mode 100644 configs/picodet/picodet_s_mbv3_416_coco.yml
 create mode 100644 configs/picodet/picodet_s_shufflenetv2_320_coco.yml
 create mode 100644 configs/picodet/picodet_s_shufflenetv2_416_coco.yml
 create mode 100644 configs/slim/quant/picodet_s_mbv3_320_quant_coco.yml
 create mode 100644 ppdet/data/transform/atss_assigner.py
 create mode 100644 ppdet/modeling/architectures/gfl.py
 create mode 100644 ppdet/modeling/architectures/picodet.py
 create mode 100644 ppdet/modeling/backbones/shufflenet_v2.py
 create mode 100644 ppdet/modeling/heads/gfl_head.py
 create mode 100644 ppdet/modeling/heads/pico_head.py
 create mode 100644 ppdet/modeling/losses/gfocal_loss.py
 create mode 100644 ppdet/modeling/necks/pan.py

diff --git a/configs/gfl/README.md b/configs/gfl/README.md
new file mode 100644
index 000000000..c812581f3
--- /dev/null
+++ b/configs/gfl/README.md
@@ -0,0 +1,37 @@
+# Generalized Focal Loss Model (GFL)
+
+## Introduction
+
+[Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection](https://arxiv.org/abs/2006.04388) and [Generalized Focal Loss V2](https://arxiv.org/pdf/2011.12885.pdf)
+
+
+## Model Zoo
+
+| Backbone | Model | images/GPU | lr schedule | FPS | Box AP | download | config |
+| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
+| ResNet50-FPN | GFL | 2 | 1x | ---- | 40.1 | [download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r50_fpn_1x_coco.yml) |
+| ResNet50-FPN | GFLv2 | 2 | 1x | ---- | 40.4 | [download](https://paddledet.bj.bcebos.com/models/gflv2_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gflv2_r50_fpn_1x_coco.yml) |
+
+**Notes:**
+
+- GFL is trained on the COCO train2017 dataset and evaluated on val2017; results are reported as `mAP(IoU=0.5:0.95)`.
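A minimal NumPy sketch of the Quality Focal Loss these configs rely on (`QualityFocalLoss` in the files below), following the formula in the GFL paper; this is illustrative only, not the implementation in `ppdet/modeling/losses/gfocal_loss.py`:

```python
import numpy as np

def quality_focal_loss(pred_sigmoid, quality_target, beta=2.0):
    """QFL(sigma) = -|y - sigma|^beta * BCE(sigma, y), with y the IoU soft label."""
    sigma = np.clip(pred_sigmoid, 1e-6, 1 - 1e-6)
    y = quality_target
    # binary cross entropy against the soft label y ...
    bce = -(y * np.log(sigma) + (1 - y) * np.log(1 - sigma))
    # ... scaled by a modulating factor that down-weights easy, accurate cases
    return np.abs(y - sigma) ** beta * bce

# a prediction close to its quality target incurs almost no loss
print(quality_focal_loss(0.88, 0.90))  # ~0.0001
print(quality_focal_loss(0.88, 0.10))  # ~1.17
```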
+
+## Citations
+```
+@article{li2020generalized,
+  title={Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection},
+  author={Li, Xiang and Wang, Wenhai and Wu, Lijun and Chen, Shuo and Hu, Xiaolin and Li, Jun and Tang, Jinhui and Yang, Jian},
+  journal={arXiv preprint arXiv:2006.04388},
+  year={2020}
+}
+
+@article{li2020gflv2,
+  title={Generalized Focal Loss V2: Learning Reliable Localization Quality Estimation for Dense Object Detection},
+  author={Li, Xiang and Wang, Wenhai and Hu, Xiaolin and Li, Jun and Tang, Jinhui and Yang, Jian},
+  journal={arXiv preprint arXiv:2011.12885},
+  year={2020}
+}
+
+```

diff --git a/configs/gfl/_base_/gfl_r50_fpn.yml b/configs/gfl/_base_/gfl_r50_fpn.yml
new file mode 100644
index 000000000..8130b5ca8
--- /dev/null
+++ b/configs/gfl/_base_/gfl_r50_fpn.yml
@@ -0,0 +1,51 @@
+architecture: GFL
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
+
+GFL:
+  backbone: ResNet
+  neck: FPN
+  head: GFLHead
+
+ResNet:
+  depth: 50
+  variant: b
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [1,2,3]
+  num_stages: 4
+
+FPN:
+  out_channel: 256
+  spatial_scales: [0.125, 0.0625, 0.03125]
+  extra_stage: 2
+  has_extra_convs: true
+  use_c5: false
+
+GFLHead:
+  conv_feat:
+    name: FCOSFeat
+    feat_in: 256
+    feat_out: 256
+    num_convs: 4
+    norm_type: "gn"
+    use_dcn: false
+  fpn_stride: [8, 16, 32, 64, 128]
+  prior_prob: 0.01
+  reg_max: 16
+  loss_qfl:
+    name: QualityFocalLoss
+    use_sigmoid: True
+    beta: 2.0
+    loss_weight: 1.0
+  loss_dfl:
+    name: DistributionFocalLoss
+    loss_weight: 0.25
+  loss_bbox:
+    name: GIoULoss
+    loss_weight: 2.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.025
+    nms_threshold: 0.6

diff --git a/configs/gfl/_base_/gfl_reader.yml b/configs/gfl/_base_/gfl_reader.yml
new file mode 100644
index 000000000..2de54ff94
--- /dev/null
+++ b/configs/gfl/_base_/gfl_reader.yml
@@ -0,0 +1,40 @@
+worker_num: 2
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomFlip: {prob: 0.5}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  - Gt2GFLTarget:
+      downsample_ratios: [8, 16, 32, 64, 128]
+      grid_cell_scale: 8
+  batch_size: 2
+  shuffle: true
+  drop_last: true
+
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 2
+  shuffle: false
+
+
+TestReader:
+  sample_transforms:
+  - Decode: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false

diff --git a/configs/gfl/_base_/gflv2_r50_fpn.yml b/configs/gfl/_base_/gflv2_r50_fpn.yml
new file mode 100644
index 000000000..691dde035
--- /dev/null
+++ b/configs/gfl/_base_/gflv2_r50_fpn.yml
@@ -0,0 +1,56 @@
+architecture: GFL
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
+
+GFL:
+  backbone: ResNet
+  neck: FPN
+  head: GFLHead
+
+ResNet:
+  depth: 50
+  variant: b
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [1,2,3]
+  num_stages: 4
+
+FPN:
+  out_channel: 256
+  spatial_scales: [0.125, 0.0625, 0.03125]
+  extra_stage: 2
+  has_extra_convs: true
+  use_c5: false
+
+GFLHead:
+  conv_feat:
+    name: FCOSFeat
+    feat_in: 256
+    feat_out: 256
+    num_convs: 4
+    norm_type: "gn"
+    use_dcn: false
+  fpn_stride: [8, 16, 32, 64, 128]
+  prior_prob: 0.01
+  reg_max: 16
+  dgqp_module:
+    name: DGQP
+    reg_topk: 4
+    reg_channels: 64
+    add_mean: True
+  loss_qfl:
+    name: QualityFocalLoss
+    use_sigmoid: False
+    beta: 2.0
+    loss_weight: 1.0
+  loss_dfl:
+    name: DistributionFocalLoss
+    loss_weight: 0.25
+  loss_bbox:
+    name: GIoULoss
+    loss_weight: 2.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.025
+    nms_threshold: 0.6

diff --git a/configs/gfl/_base_/optimizer_1x.yml b/configs/gfl/_base_/optimizer_1x.yml
new file mode 100644
index 000000000..6a3284799
--- /dev/null
+++ b/configs/gfl/_base_/optimizer_1x.yml
@@ -0,0 +1,19 @@
+epoch: 12
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [8, 11]
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2

diff --git a/configs/gfl/gfl_r50_fpn_1x_coco.yml b/configs/gfl/gfl_r50_fpn_1x_coco.yml
new file mode 100644
index 000000000..2e17b23d8
--- /dev/null
+++ b/configs/gfl/gfl_r50_fpn_1x_coco.yml
@@ -0,0 +1,10 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/gfl_r50_fpn.yml',
+  '_base_/optimizer_1x.yml',
+  '_base_/gfl_reader.yml',
+]
+
+weights: output/gfl_r50_fpn_1x_coco/model_final
+find_unused_parameters: True

diff --git a/configs/gfl/gflv2_r50_fpn_1x_coco.yml b/configs/gfl/gflv2_r50_fpn_1x_coco.yml
new file mode 100644
index 000000000..73a69215c
--- /dev/null
+++ b/configs/gfl/gflv2_r50_fpn_1x_coco.yml
@@ -0,0 +1,10 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/gflv2_r50_fpn.yml',
+  '_base_/optimizer_1x.yml',
+  '_base_/gfl_reader.yml',
+]
+
+weights: output/gflv2_r50_fpn_1x_coco/model_final
+find_unused_parameters: True

diff --git a/configs/picodet/README.md b/configs/picodet/README.md
new file mode 100644
index 000000000..7b36047d5
--- /dev/null
+++ b/configs/picodet/README.md
@@ -0,0 +1,48 @@
+# PicoDet
+
+## Introduction
+
+We have developed a series of lightweight models for mobile devices, named `PicoDet`.
+
+Optimization methods we use:
+- [Generalized Focal Loss V2](https://arxiv.org/pdf/2011.12885.pdf)
+- LR cosine decay (see the schedule sketch after the notes below)
+
+
+## Model Zoo
+
+### PicoDet-S
+
+| Backbone | Input size | images/GPU | lr schedule | Box AP | FLOPS | Inference Time | download | config |
+| :------------------------ | :-------: | :-------: | :-----------: | :---: | :-----: | :-----: | :-------------------------------------------------: | :-----: |
+| ShuffleNetv2-1x | 320*320 | 128 | 280e | 21.9 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_320_coco.yml) |
+| MobileNetv3-large-0.5x | 320*320 | 128 | 280e | 20.4 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_s_mbv3_320_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_mbv3_320_coco.yml) |
+| ShuffleNetv2-1x | 416*416 | 96 | 280e | 24.0 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_416_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_416_coco.yml) |
+| MobileNetv3-large-0.5x | 416*416 | 96 | 280e | 23.3 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_s_mbv3_416_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_mbv3_416_coco.yml) |
+
+### PicoDet-M
+
+| Backbone | Input size | images/GPU | lr schedule | Box AP | FLOPS | Inference Time | download | config |
+| :------------------------ | :-------: | :-------: | :-----------: | :---: | :-----: | :-----: | :-------------------------------------------------: | :-----: |
+| ShuffleNetv2-1.5x | 320*320 | 128 | 280e | 24.9 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_320_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_320_coco.yml) |
+| MobileNetv3-large-1x | 320*320 | 128 | 280e | 26.4 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_320_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_320_coco.yml) |
+| ShuffleNetv2-1.5x | 416*416 | 128 | 280e | 27.4 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_416_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_416_coco.yml) |
+| MobileNetv3-large-1x | 416*416 | 128 | 280e | 29.2 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_416_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_416_coco.yml) |
+
+**Notes:**
+
+- PicoDet inference speed is tested on a Kirin 980 with 4 threads (armv8, FP16).
+- PicoDet is trained on the COCO train2017 dataset and evaluated on val2017; results are reported as `mAP(IoU=0.5:0.95)`.
+- PicoDet is trained on 4 GPUs with a mini-batch size of 128 or 96 per GPU.
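The cosine schedule referenced above is configured by the `CosineDecay` + `LinearWarmup` schedulers in `optimizer_280e.yml` below. A rough Python sketch of the resulting curve, assuming the common warmup-then-cosine form (the exact PaddleDetection scheduler internals may differ slightly):

```python
import math

def picodet_lr(step, steps_per_epoch, base_lr=0.4, max_epochs=280,
               warmup_steps=300, warmup_start_factor=0.1):
    """Approximate learning rate at a global step, under optimizer_280e.yml."""
    if step < warmup_steps:
        # linear ramp from 0.1 * base_lr up to base_lr over the first 300 steps
        frac = step / warmup_steps
        return base_lr * (warmup_start_factor + (1.0 - warmup_start_factor) * frac)
    # half-cosine decay from base_lr toward 0 over the whole 280-epoch run
    progress = step / float(max_epochs * steps_per_epoch)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```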
+
+## Citations
+```
+@article{li2020gflv2,
+  title={Generalized Focal Loss V2: Learning Reliable Localization Quality Estimation for Dense Object Detection},
+  author={Li, Xiang and Wang, Wenhai and Hu, Xiaolin and Li, Jun and Tang, Jinhui and Yang, Jian},
+  journal={arXiv preprint arXiv:2011.12885},
+  year={2020}
+}
+
+```

diff --git a/configs/picodet/_base_/optimizer_280e.yml b/configs/picodet/_base_/optimizer_280e.yml
new file mode 100644
index 000000000..a72bc7d24
--- /dev/null
+++ b/configs/picodet/_base_/optimizer_280e.yml
@@ -0,0 +1,18 @@
+epoch: 280
+
+LearningRate:
+  base_lr: 0.4
+  schedulers:
+  - !CosineDecay
+    max_epochs: 280
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 300
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2

diff --git a/configs/picodet/_base_/picodet_320_reader.yml b/configs/picodet/_base_/picodet_320_reader.yml
new file mode 100644
index 000000000..ca96183b5
--- /dev/null
+++ b/configs/picodet/_base_/picodet_320_reader.yml
@@ -0,0 +1,45 @@
+worker_num: 6
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomDistort: {}
+  - RandomCrop: {}
+  - RandomFlip: {prob: 0.5}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {target_size: [320, 320], keep_ratio: False, interp: 1}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  - Gt2GFLTarget:
+      downsample_ratios: [8, 16, 32]
+      grid_cell_scale: 5
+      cell_offset: 0.5
+  batch_size: 128
+  shuffle: true
+  drop_last: true
+
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {interp: 1, target_size: [320, 320], keep_ratio: False}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 8
+  shuffle: false
+
+
+TestReader:
+  inputs_def:
+    image_shape: [3, 320, 320]
+  sample_transforms:
+  - Decode: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {interp: 1, target_size: [320, 320], keep_ratio: False}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false

diff --git a/configs/picodet/_base_/picodet_416_reader.yml b/configs/picodet/_base_/picodet_416_reader.yml
new file mode 100644
index 000000000..cb9b027f4
--- /dev/null
+++ b/configs/picodet/_base_/picodet_416_reader.yml
@@ -0,0 +1,45 @@
+worker_num: 6
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomDistort: {}
+  - RandomCrop: {}
+  - RandomFlip: {prob: 0.5}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {target_size: [416, 416], keep_ratio: False, interp: 1}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  - Gt2GFLTarget:
+      downsample_ratios: [8, 16, 32]
+      grid_cell_scale: 5
+      cell_offset: 0.5
+  batch_size: 96
+  shuffle: true
+  drop_last: true
+
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {interp: 1, target_size: [416, 416], keep_ratio: False}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 8
+  shuffle: false
+
+
+TestReader:
+  inputs_def:
+    image_shape: [3, 416, 416]
+  sample_transforms:
+  - Decode: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+  - Resize: {interp: 1, target_size: [416, 416], keep_ratio: False}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false
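In the readers above, `Gt2GFLTarget` lays out one grid cell per feature-map location. For a 320x320 input with `downsample_ratios: [8, 16, 32]` that gives 40*40 + 20*20 + 10*10 = 2100 candidate cells, mirroring the `math.ceil(h / stride)` logic in the operator below; a quick check (illustrative):

```python
import math

h = w = 320
strides = (8, 16, 32)
# total grid cells across the three pyramid levels
print(sum(math.ceil(h / s) * math.ceil(w / s) for s in strides))  # -> 2100
```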
diff --git a/configs/picodet/_base_/picodet_mbv3_0_5x.yml b/configs/picodet/_base_/picodet_mbv3_0_5x.yml
new file mode 100644
index 000000000..aacd32d37
--- /dev/null
+++ b/configs/picodet/_base_/picodet_mbv3_0_5x.yml
@@ -0,0 +1,51 @@
+architecture: PicoDet
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
+
+PicoDet:
+  backbone: MobileNetV3
+  neck: PAN
+  head: PicoHead
+
+MobileNetV3:
+  model_name: large
+  scale: 0.5
+  with_extra_blocks: false
+  extra_block_filters: []
+  feature_maps: [7, 13, 16]
+
+PAN:
+  out_channel: 96
+  start_level: 0
+  end_level: 3
+  spatial_scales: [0.125, 0.0625, 0.03125]
+
+PicoHead:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 96
+    feat_out: 96
+    num_convs: 2
+    norm_type: bn
+    share_cls_reg: True
+  fpn_stride: [8, 16, 32]
+  feat_in_chan: 96
+  prior_prob: 0.01
+  reg_max: 7
+  cell_offset: 0.5
+  loss_qfl:
+    name: QualityFocalLoss
+    use_sigmoid: True
+    beta: 2.0
+    loss_weight: 1.0
+  loss_dfl:
+    name: DistributionFocalLoss
+    loss_weight: 0.25
+  loss_bbox:
+    name: GIoULoss
+    loss_weight: 2.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.025
+    nms_threshold: 0.6

diff --git a/configs/picodet/_base_/picodet_shufflenetv2_1x.yml b/configs/picodet/_base_/picodet_shufflenetv2_1x.yml
new file mode 100644
index 000000000..25517f9fb
--- /dev/null
+++ b/configs/picodet/_base_/picodet_shufflenetv2_1x.yml
@@ -0,0 +1,49 @@
+architecture: PicoDet
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_0_pretrained.pdparams
+
+PicoDet:
+  backbone: ShuffleNetV2
+  neck: PAN
+  head: PicoHead
+
+ShuffleNetV2:
+  scale: 1.0
+  feature_maps: [5, 13, 17]
+  act: leaky_relu
+
+PAN:
+  out_channel: 96
+  start_level: 0
+  end_level: 3
+  spatial_scales: [0.125, 0.0625, 0.03125]
+
+PicoHead:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 96
+    feat_out: 96
+    num_convs: 2
+    norm_type: bn
+    share_cls_reg: True
+  fpn_stride: [8, 16, 32]
+  feat_in_chan: 96
+  prior_prob: 0.01
+  reg_max: 7
+  cell_offset: 0.5
+  loss_qfl:
+    name: QualityFocalLoss
+    use_sigmoid: True
+    beta: 2.0
+    loss_weight: 1.0
+  loss_dfl:
+    name: DistributionFocalLoss
+    loss_weight: 0.25
+  loss_bbox:
+    name: GIoULoss
+    loss_weight: 2.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.025
+    nms_threshold: 0.6

diff --git a/configs/picodet/picodet_m_mbv3_320_coco.yml b/configs/picodet/picodet_m_mbv3_320_coco.yml
new file mode 100644
index 000000000..43b2faf9a
--- /dev/null
+++ b/configs/picodet/picodet_m_mbv3_320_coco.yml
@@ -0,0 +1,40 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_mbv3_0_5x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_320_reader.yml',
+]
+
+weights: output/picodet_m_mbv3_320_coco/model_final
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10
+
+MobileNetV3:
+  model_name: large
+  scale: 1.0
+  with_extra_blocks: false
+  extra_block_filters: []
+  feature_maps: [7, 13, 16]
+
+PAN:
+  out_channel: 128
+  start_level: 0
+  end_level: 3
+  spatial_scales: [0.125, 0.0625, 0.03125]
+
+PicoHead:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 128
+    feat_out: 128
+    num_convs: 2
+    norm_type: bn
+    share_cls_reg: True
+  feat_in_chan: 128
+
+TrainReader:
+  batch_size: 88
diff --git a/configs/picodet/picodet_m_mbv3_416_coco.yml b/configs/picodet/picodet_m_mbv3_416_coco.yml
new file mode 100644
index 000000000..2660a914c
--- /dev/null
+++ b/configs/picodet/picodet_m_mbv3_416_coco.yml
@@ -0,0 +1,49 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_mbv3_0_5x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_416_reader.yml',
+]
+
+weights: output/picodet_m_mbv3_416_coco/model_final
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10
+
+MobileNetV3:
+  model_name: large
+  scale: 1.0
+  with_extra_blocks: false
+  extra_block_filters: []
+  feature_maps: [7, 13, 16]
+
+PAN:
+  out_channel: 128
+  start_level: 0
+  end_level: 3
+  spatial_scales: [0.125, 0.0625, 0.03125]
+
+PicoHead:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 128
+    feat_out: 128
+    num_convs: 2
+    norm_type: bn
+    share_cls_reg: True
+  feat_in_chan: 128
+
+TrainReader:
+  batch_size: 56
+
+LearningRate:
+  base_lr: 0.3
+  schedulers:
+  - !CosineDecay
+    max_epochs: 280
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 300

diff --git a/configs/picodet/picodet_m_r18_320_coco.yml b/configs/picodet/picodet_m_r18_320_coco.yml
new file mode 100644
index 000000000..52d405f74
--- /dev/null
+++ b/configs/picodet/picodet_m_r18_320_coco.yml
@@ -0,0 +1,55 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_mbv3_0_5x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_320_reader.yml',
+]
+
+weights: output/picodet_m_r18_320_coco/model_final
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10
+
+PicoDet:
+  backbone: ResNet
+  neck: PAN
+  head: PicoHead
+
+ResNet:
+  depth: 18
+  variant: d
+  return_idx: [1, 2, 3]
+  freeze_at: -1
+  freeze_norm: false
+  norm_decay: 0.
+
+PAN:
+  out_channel: 128
+  start_level: 0
+  end_level: 3
+  spatial_scales: [0.125, 0.0625, 0.03125]
+
+PicoHead:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 128
+    feat_out: 128
+    num_convs: 2
+    norm_type: bn
+    share_cls_reg: True
+  feat_in_chan: 128
+
+TrainReader:
+  batch_size: 56
+
+LearningRate:
+  base_lr: 0.3
+  schedulers:
+  - !CosineDecay
+    max_epochs: 280
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 300

diff --git a/configs/picodet/picodet_m_shufflenetv2_320_coco.yml b/configs/picodet/picodet_m_shufflenetv2_320_coco.yml
new file mode 100644
index 000000000..e8b8fe327
--- /dev/null
+++ b/configs/picodet/picodet_m_shufflenetv2_320_coco.yml
@@ -0,0 +1,35 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_shufflenetv2_1x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_320_reader.yml',
+]
+
+weights: output/picodet_m_shufflenetv2_320_coco/model_final
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_5_pretrained.pdparams
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10
+
+ShuffleNetV2:
+  scale: 1.5
+  feature_maps: [5, 13, 17]
+  act: leaky_relu
+
+PAN:
+  out_channel: 128
+  start_level: 0
+  end_level: 3
+  spatial_scales: [0.125, 0.0625, 0.03125]
+
+PicoHead:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 128
+    feat_out: 128
+    num_convs: 2
+    norm_type: bn
+    share_cls_reg: True
+  feat_in_chan: 128

diff --git a/configs/picodet/picodet_m_shufflenetv2_416_coco.yml b/configs/picodet/picodet_m_shufflenetv2_416_coco.yml
new file mode 100644
index 000000000..f30d7b848
--- /dev/null
+++ b/configs/picodet/picodet_m_shufflenetv2_416_coco.yml
@@ -0,0 +1,38 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_shufflenetv2_1x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_416_reader.yml',
+]
+
+weights: output/picodet_m_shufflenetv2_416_coco/model_final
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_5_pretrained.pdparams
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10
+
+ShuffleNetV2:
+  scale: 1.5
+  feature_maps: [5, 13, 17]
+  act: leaky_relu
+
+PAN:
+  out_channel: 128
+  start_level: 0
+  end_level: 3
+  spatial_scales: [0.125, 0.0625, 0.03125]
+
+PicoHead:
+  conv_feat:
+    name: PicoFeat
+    feat_in: 128
+    feat_out: 128
+    num_convs: 2
+    norm_type: bn
+    share_cls_reg: True
+  feat_in_chan: 128
+
+TrainReader:
+  batch_size: 88

diff --git a/configs/picodet/picodet_s_mbv3_320_coco.yml b/configs/picodet/picodet_s_mbv3_320_coco.yml
new file mode 100644
index 000000000..c41a364ed
--- /dev/null
+++ b/configs/picodet/picodet_s_mbv3_320_coco.yml
@@ -0,0 +1,13 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_mbv3_0_5x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_320_reader.yml',
+]
+
+weights: output/picodet_s_mbv3_320_coco/model_final
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10

diff --git a/configs/picodet/picodet_s_mbv3_416_coco.yml b/configs/picodet/picodet_s_mbv3_416_coco.yml
new file mode 100644
index 000000000..7428cde80
--- /dev/null
+++ b/configs/picodet/picodet_s_mbv3_416_coco.yml
@@ -0,0 +1,13 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_mbv3_0_5x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_416_reader.yml',
+]
+
+weights: output/picodet_s_mbv3_416_coco/model_final
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10
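These short variant configs work because PaddleDetection deep-merges each file onto its `_BASE_` list. A toy sketch of that layering (illustrative only; the real logic lives in `ppdet.core.workspace` and also handles YAML tags and lists):

```python
def merge_config(base: dict, override: dict) -> dict:
    """Recursively overlay an override config onto a base config."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)  # merge sub-sections
        else:
            merged[key] = value  # leaf values in the child file win
    return merged

# e.g. picodet_m_mbv3_416_coco.yml only re-declares the wider PAN/PicoHead
# channels and TrainReader.batch_size; everything else is inherited.
base = {"TrainReader": {"batch_size": 96, "shuffle": True}}
print(merge_config(base, {"TrainReader": {"batch_size": 56}}))
# -> {'TrainReader': {'batch_size': 56, 'shuffle': True}}
```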
diff --git a/configs/picodet/picodet_s_shufflenetv2_320_coco.yml b/configs/picodet/picodet_s_shufflenetv2_320_coco.yml
new file mode 100644
index 000000000..0eb2b54f0
--- /dev/null
+++ b/configs/picodet/picodet_s_shufflenetv2_320_coco.yml
@@ -0,0 +1,13 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_shufflenetv2_1x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_320_reader.yml',
+]
+
+weights: output/picodet_s_shufflenetv2_320_coco/model_final
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10

diff --git a/configs/picodet/picodet_s_shufflenetv2_416_coco.yml b/configs/picodet/picodet_s_shufflenetv2_416_coco.yml
new file mode 100644
index 000000000..b0ad5cb63
--- /dev/null
+++ b/configs/picodet/picodet_s_shufflenetv2_416_coco.yml
@@ -0,0 +1,13 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/picodet_shufflenetv2_1x.yml',
+  '_base_/optimizer_280e.yml',
+  '_base_/picodet_416_reader.yml',
+]
+
+weights: output/picodet_s_shufflenetv2_416_coco/model_final
+find_unused_parameters: True
+use_ema: true
+ema_decay: 0.9998
+snapshot_epoch: 10

diff --git a/configs/slim/quant/picodet_s_mbv3_320_quant_coco.yml b/configs/slim/quant/picodet_s_mbv3_320_quant_coco.yml
new file mode 100644
index 000000000..3ecf1a723
--- /dev/null
+++ b/configs/slim/quant/picodet_s_mbv3_320_quant_coco.yml
@@ -0,0 +1,23 @@
+pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_s_mbv3_320_coco.pdparams
+slim: QAT
+
+QAT:
+  quant_config: {
+    'activation_preprocess_type': 'PACT',
+    'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max',
+    'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9,
+    'quantizable_layer_type': ['Conv2D', 'Linear']}
+  print_model: False
+
+epoch: 50
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones:
+    - 35
+    - 45
+  - !LinearWarmup
+    start_factor: 0.
+    steps: 1000

diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index 7c590d552..396f570b1 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -42,6 +42,8 @@ SUPPORT_MODELS = {
     'JDE',
     'FairMOT',
     'DeepSORT',
+    'GFL',
+    'PicoDet',
 }

diff --git a/docs/MODEL_ZOO_cn.md b/docs/MODEL_ZOO_cn.md
index 61a5e3437..e76ea0ba1 100644
--- a/docs/MODEL_ZOO_cn.md
+++ b/docs/MODEL_ZOO_cn.md
@@ -79,6 +79,14 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
 请参考[Res2Net](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/res2net/)
 
+### GFL
+
+请参考[GFL](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl)
+
+### PicoDet
+
+请参考[PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet)
+
 ## 旋转框检测

diff --git a/ppdet/data/transform/atss_assigner.py b/ppdet/data/transform/atss_assigner.py
new file mode 100644
index 000000000..967e39889
--- /dev/null
+++ b/ppdet/data/transform/atss_assigner.py
@@ -0,0 +1,266 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from ppdet.utils.logger import setup_logger
+logger = setup_logger(__name__)
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
+    """Calculate overlap between two sets of bboxes.
+    If ``is_aligned`` is ``False``, then calculate the overlaps between each
+    bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
+    pair of bboxes1 and bboxes2.
+    Args:
+        bboxes1 (Tensor): shape (B, m, 4) in (x1, y1, x2, y2) format or empty.
+        bboxes2 (Tensor): shape (B, n, 4) in (x1, y1, x2, y2) format or empty.
+            B indicates the batch dim, in shape (B1, B2, ..., Bn).
+            If ``is_aligned`` is ``True``, then m and n must be equal.
+        mode (str): "iou" (intersection over union) or "iof" (intersection over
+            foreground).
+        is_aligned (bool, optional): If True, then m and n must be equal.
+            Default False.
+        eps (float, optional): A value added to the denominator for numerical
+            stability. Default 1e-6.
+    Returns:
+        Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
+    """
+    assert mode in ['iou', 'iof', 'giou'], 'Unsupported mode {}'.format(mode)
+    # Either the boxes are empty or the length of the boxes' last dimension is 4
+    assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0)
+    assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0)
+
+    # Batch dim must be the same
+    # Batch dim: (B1, B2, ... Bn)
+    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
+    batch_shape = bboxes1.shape[:-2]
+
+    rows = bboxes1.shape[-2] if bboxes1.shape[0] > 0 else 0
+    cols = bboxes2.shape[-2] if bboxes2.shape[0] > 0 else 0
+    if is_aligned:
+        assert rows == cols
+
+    if rows * cols == 0:
+        if is_aligned:
+            return np.random.random(batch_shape + (rows, ))
+        else:
+            return np.random.random(batch_shape + (rows, cols))
+
+    area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
+        bboxes1[..., 3] - bboxes1[..., 1])
+    area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
+        bboxes2[..., 3] - bboxes2[..., 1])
+
+    if is_aligned:
+        lt = np.maximum(bboxes1[..., :2], bboxes2[..., :2])  # [B, rows, 2]
+        rb = np.minimum(bboxes1[..., 2:], bboxes2[..., 2:])  # [B, rows, 2]
+
+        wh = (rb - lt).clip(min=0)  # [B, rows, 2]
+        overlap = wh[..., 0] * wh[..., 1]
+
+        if mode in ['iou', 'giou']:
+            union = area1 + area2 - overlap
+        else:
+            union = area1
+        if mode == 'giou':
+            enclosed_lt = np.minimum(bboxes1[..., :2], bboxes2[..., :2])
+            enclosed_rb = np.maximum(bboxes1[..., 2:], bboxes2[..., 2:])
+    else:
+        lt = np.maximum(bboxes1[..., :, None, :2],
+                        bboxes2[..., None, :, :2])  # [B, rows, cols, 2]
+        rb = np.minimum(bboxes1[..., :, None, 2:],
+                        bboxes2[..., None, :, 2:])  # [B, rows, cols, 2]
+
+        wh = (rb - lt).clip(min=0)  # [B, rows, cols, 2]
+        overlap = wh[..., 0] * wh[..., 1]
+
+        if mode in ['iou', 'giou']:
+            union = area1[..., None] + area2[..., None, :] - overlap
+        else:
+            union = area1[..., None]
+        if mode == 'giou':
+            enclosed_lt = np.minimum(bboxes1[..., :, None, :2],
+                                     bboxes2[..., None, :, :2])
+            enclosed_rb = np.maximum(bboxes1[..., :, None, 2:],
+                                     bboxes2[..., None, :, 2:])
+
+    eps = np.array([eps])
+    union = np.maximum(union, eps)
+    ious = overlap / union
+    if mode in ['iou', 'iof']:
+        return ious
+    # calculate gious
+    enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0)
+    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
+    enclose_area = np.maximum(enclose_area, eps)
+    gious = ious - (enclose_area - union) / enclose_area
+    return gious
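A quick sanity check for `bbox_overlaps` above (illustrative; assumes the function is imported):

```python
import numpy as np

# two 2x2 boxes overlapping by half: inter = 2, union = 6, IoU = 1/3
b1 = np.array([[0., 0., 2., 2.]])
b2 = np.array([[1., 0., 3., 2.]])
print(bbox_overlaps(b1, b2))  # -> [[0.333...]]
```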
+
+
+def topk_(input, k, axis=1, largest=True):
+    x = -input if largest else input
+    if axis == 0:
+        row_index = np.arange(input.shape[1 - axis])
+        topk_index = np.argpartition(x, k, axis=axis)[0:k, :]
+        topk_data = x[topk_index, row_index]
+
+        topk_index_sort = np.argsort(topk_data, axis=axis)
+        topk_data_sort = topk_data[topk_index_sort, row_index]
+        topk_index_sort = topk_index[0:k, :][topk_index_sort, row_index]
+    else:
+        column_index = np.arange(x.shape[1 - axis])[:, None]
+        topk_index = np.argpartition(x, k, axis=axis)[:, 0:k]
+        topk_data = x[column_index, topk_index]
+        topk_data = -topk_data if largest else topk_data
+        topk_index_sort = np.argsort(topk_data, axis=axis)
+        topk_data_sort = topk_data[column_index, topk_index_sort]
+        topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
+
+    return topk_data_sort, topk_index_sort
+
+
+class ATSSAssigner(object):
+    """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned with `0` or a positive integer
+    indicating the ground truth index.
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        topk (int): number of boxes selected on each level
+    """
+
+    def __init__(self, topk=9):
+        self.topk = topk
+
+    def __call__(self,
+                 bboxes,
+                 num_level_bboxes,
+                 gt_bboxes,
+                 gt_bboxes_ignore=None,
+                 gt_labels=None):
+        """Assign gt to bboxes.
+        The assignment is done in the following steps
+        1. compute iou between all bboxes (over all pyramid levels) and gts
+        2. compute center distance between all bboxes and gts
+        3. on each pyramid level, for each gt, select the k boxes whose
+           centers are closest to the gt center, so we select k*l boxes in
+           total as candidates for each gt
+        4. get the corresponding iou for these candidates, and compute their
+           mean and std; set mean + std as the iou threshold
+        5. select candidates whose iou is greater than or equal to the
+           threshold as positive
+        6. limit the positive samples' centers to lie inside the gt
+        Args:
+            bboxes (np.array): Bounding boxes to be assigned, shape(n, 4).
+            num_level_bboxes (List): num of bboxes in each level
+            gt_bboxes (np.array): Groundtruth boxes, shape (k, 4).
+            gt_bboxes_ignore (np.array, optional): Ground truth bboxes that are
+                labelled as `ignored`, e.g., crowd boxes in COCO.
+            gt_labels (np.array, optional): Label of gt_bboxes, shape (k, ).
+        """
+        bboxes = bboxes[:, :4]
+        num_gt, num_bboxes = gt_bboxes.shape[0], bboxes.shape[0]
+        # compute iou between all bboxes and gts
+        overlaps = bbox_overlaps(bboxes, gt_bboxes)
+
+        # assign 0 by default
+        assigned_gt_inds = np.zeros((num_bboxes, ), dtype=np.int64)
+
+        if num_gt == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            max_overlaps = np.zeros((num_bboxes, ))
+            if num_gt == 0:
+                # No truth, assign everything to background
+                assigned_gt_inds[:] = 0
+            if not np.any(gt_labels):
+                assigned_labels = None
+            else:
+                assigned_labels = -np.ones((num_bboxes, ), dtype=np.int64)
+            return assigned_gt_inds, max_overlaps, assigned_labels
+
+        # compute center distance between all bboxes and gts
+        gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
+        gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
+        gt_points = np.stack((gt_cx, gt_cy), axis=1)
+
+        bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
+        bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
+        bboxes_points = np.stack((bboxes_cx, bboxes_cy), axis=1)
+
+        distances = np.sqrt(
+            np.power((bboxes_points[:, None, :] - gt_points[None, :, :]), 2)
+            .sum(-1))
+
+        # select candidates based on the center distance
+        candidate_idxs = []
+        start_idx = 0
+        for bboxes_per_level in num_level_bboxes:
+            # on each pyramid level, for each gt,
+            # select the k boxes whose centers are closest to the gt center
+            end_idx = start_idx + bboxes_per_level
+            distances_per_level = distances[start_idx:end_idx, :]
+            selectable_k = min(self.topk, bboxes_per_level)
+            _, topk_idxs_per_level = topk_(
+                distances_per_level, selectable_k, axis=0, largest=False)
+            candidate_idxs.append(topk_idxs_per_level + start_idx)
+            start_idx = end_idx
+        candidate_idxs = np.concatenate(candidate_idxs, axis=0)
+
+        # get the corresponding iou for these candidates, and compute the
+        # mean and std; set mean + std as the iou threshold
+        candidate_overlaps = overlaps[candidate_idxs, np.arange(num_gt)]
+        overlaps_mean_per_gt = candidate_overlaps.mean(0)
+        overlaps_std_per_gt = candidate_overlaps.std(0)
+        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
+
+        is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
+
+        # limit the positive samples' centers to lie inside the gt
+        for gt_idx in range(num_gt):
+            candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+        ep_bboxes_cx = np.broadcast_to(
+            bboxes_cx.reshape(1, -1), [num_gt, num_bboxes]).reshape(-1)
+        ep_bboxes_cy = np.broadcast_to(
+            bboxes_cy.reshape(1, -1), [num_gt, num_bboxes]).reshape(-1)
+        candidate_idxs = candidate_idxs.reshape(-1)
+
+        # calculate the left, top, right, bottom distance between positive
+        # bbox center and gt side
+        l_ = ep_bboxes_cx[candidate_idxs].reshape(-1, num_gt) - gt_bboxes[:, 0]
+        t_ = ep_bboxes_cy[candidate_idxs].reshape(-1, num_gt) - gt_bboxes[:, 1]
+        r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].reshape(-1, num_gt)
+        b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].reshape(-1, num_gt)
+        is_in_gts = np.stack([l_, t_, r_, b_], axis=1).min(axis=1) > 0.01
+        is_pos = is_pos & is_in_gts
+
+        # if an anchor box is assigned to multiple gts,
+        # the one with the highest IoU will be selected.
+        overlaps_inf = -np.inf * np.ones_like(overlaps).T.reshape(-1)
+        index = candidate_idxs.reshape(-1)[is_pos.reshape(-1)]
+        overlaps_inf[index] = overlaps.T.reshape(-1)[index]
+        overlaps_inf = overlaps_inf.reshape(num_gt, -1).T
+
+        max_overlaps = overlaps_inf.max(axis=1)
+        argmax_overlaps = overlaps_inf.argmax(axis=1)
+        assigned_gt_inds[max_overlaps !=
+                         -np.inf] = argmax_overlaps[max_overlaps != -np.inf] + 1
+
+        return assigned_gt_inds, max_overlaps
diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py
index 81b5ef728..e72635898 100644
--- a/ppdet/data/transform/batch_operators.py
+++ b/ppdet/data/transform/batch_operators.py
@@ -22,9 +22,11 @@ except Exception:
     from collections import Sequence
 
 import cv2
+import math
 import numpy as np
 from .operators import register_op, BaseOperator, Resize
 from .op_helper import jaccard_overlap, gaussian2D
+from .atss_assigner import ATSSAssigner
 from scipy import ndimage
 
 from ppdet.modeling import bbox_utils
@@ -33,7 +35,8 @@ logger = setup_logger(__name__)
 
 __all__ = [
     'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
-    'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseRCNNTarget', 'PadMaskBatch'
+    'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseRCNNTarget', 'PadMaskBatch',
+    'Gt2GFLTarget'
 ]
 
@@ -177,8 +180,6 @@ class Gt2YoloTarget(BaseOperator):
         h, w = samples[0]['image'].shape[1:3]
         an_hw = np.array(self.anchors) / np.array([[w, h]])
         for sample in samples:
-            # im, gt_bbox, gt_class, gt_score = sample
-            im = sample['image']
             gt_bbox = sample['gt_bbox']
             gt_class = sample['gt_class']
             if 'gt_score' not in sample:
@@ -367,7 +368,6 @@ class Gt2FCOSTarget(BaseOperator):
             "object_sizes_of_interest', and 'downsample_ratios' should have same length."
 
         for sample in samples:
-            # im, gt_bbox, gt_class, gt_score = sample
             im = sample['image']
             bboxes = sample['gt_bbox']
             gt_class = sample['gt_class']
@@ -466,6 +466,134 @@ class Gt2FCOSTarget(BaseOperator):
         return samples
 
 
+@register_op
+class Gt2GFLTarget(BaseOperator):
+    """
+    Generate GFocal loss targets from ground truth data
+    """
+
+    def __init__(self,
+                 num_classes=80,
+                 downsample_ratios=[8, 16, 32, 64, 128],
+                 grid_cell_scale=4,
+                 cell_offset=0):
+        super(Gt2GFLTarget, self).__init__()
+        self.num_classes = num_classes
+        self.downsample_ratios = downsample_ratios
+        self.grid_cell_scale = grid_cell_scale
+        self.cell_offset = cell_offset
+
+        self.assigner = ATSSAssigner()
+
+    def get_grid_cells(self, featmap_size, scale, stride, offset=0):
+        """
+        Generate grid cells of a feature map for target assignment.
+        Args:
+            featmap_size: Size of a single-level feature map.
+            scale: Grid cell scale.
+            stride: Downsample stride of the feature map.
+            offset: Offset of grid cells.
+        return:
+            Grid cells in xyxy format, shape [feat_w * feat_h, 4].
+        """
+        cell_size = stride * scale
+        h, w = featmap_size
+        x_range = (np.arange(w, dtype=np.float32) + offset) * stride
+        y_range = (np.arange(h, dtype=np.float32) + offset) * stride
+        x, y = np.meshgrid(x_range, y_range)
+        y = y.flatten()
+        x = x.flatten()
+        grid_cells = np.stack(
+            [
+                x - 0.5 * cell_size, y - 0.5 * cell_size, x + 0.5 * cell_size,
+                y + 0.5 * cell_size
+            ],
+            axis=-1)
+        return grid_cells
+
+    def get_sample(self, assign_gt_inds, gt_bboxes):
+        pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0])
+        neg_inds = np.unique(np.nonzero(assign_gt_inds == 0)[0])
+        pos_assigned_gt_inds = assign_gt_inds[pos_inds] - 1
+
+        if gt_bboxes.size == 0:
+            # hack for index error case
+            assert pos_assigned_gt_inds.size == 0
+            pos_gt_bboxes = np.empty_like(gt_bboxes).reshape(-1, 4)
+        else:
+            if len(gt_bboxes.shape) < 2:
+                gt_bboxes = gt_bboxes.resize(-1, 4)
+            pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
+        return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds
+
+    def __call__(self, samples, context=None):
+        assert len(samples) > 0
+        batch_size = len(samples)
+        # get grid cells of the image
+        h, w = samples[0]['image'].shape[1:3]
+        multi_level_grid_cells = []
+        for stride in self.downsample_ratios:
+            featmap_size = (int(math.ceil(h / stride)),
+                            int(math.ceil(w / stride)))
+            multi_level_grid_cells.append(
+                self.get_grid_cells(featmap_size, self.grid_cell_scale, stride,
+                                    self.cell_offset))
+        mlvl_grid_cells_list = [
+            multi_level_grid_cells for i in range(batch_size)
+        ]
+        # number of grid cells on each level of the multi-level feature maps
+        num_level_cells = [
+            grid_cells.shape[0] for grid_cells in mlvl_grid_cells_list[0]
+        ]
+        num_level_cells_list = [num_level_cells] * batch_size
+        # concat all level cells into a single array
+        for i in range(batch_size):
+            mlvl_grid_cells_list[i] = np.concatenate(mlvl_grid_cells_list[i])
+        # target assignment on all images
+        for sample, grid_cells, num_level_cells in zip(
+                samples, mlvl_grid_cells_list, num_level_cells_list):
+            gt_bboxes = sample['gt_bbox']
+            gt_labels = sample['gt_class'].squeeze()
+            if gt_labels.size == 1:
+                gt_labels = np.array([gt_labels]).astype(np.int32)
+            gt_bboxes_ignore = None
+            assign_gt_inds, _ = self.assigner(grid_cells, num_level_cells,
+                                              gt_bboxes, gt_bboxes_ignore,
+                                              gt_labels)
+            pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.get_sample(
+                assign_gt_inds, gt_bboxes)
+
+            num_cells = grid_cells.shape[0]
+            bbox_targets = np.zeros_like(grid_cells)
+            bbox_weights = np.zeros_like(grid_cells)
+            labels = np.ones([num_cells], dtype=np.int64) * self.num_classes
+            label_weights = np.zeros([num_cells], dtype=np.float32)
+
+            if len(pos_inds) > 0:
+                pos_bbox_targets = pos_gt_bboxes
+                bbox_targets[pos_inds, :] = pos_bbox_targets
+                bbox_weights[pos_inds, :] = 1.0
+                if not np.any(gt_labels):
+                    labels[pos_inds] = 0
+                else:
+                    labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
+
+                label_weights[pos_inds] = 1.0
+            if len(neg_inds) > 0:
+                label_weights[neg_inds] = 1.0
+            sample['grid_cells'] = grid_cells
+            sample['labels'] = labels
+            sample['label_weights'] = label_weights
+            sample['bbox_targets'] = bbox_targets
+            sample['pos_num'] = max(pos_inds.size, 1)
+            sample.pop('is_crowd', None)
+            sample.pop('difficult', None)
+            sample.pop('gt_class', None)
+            sample.pop('gt_bbox', None)
+            sample.pop('gt_score', None)
+        return samples
+
+
 @register_op
 class Gt2TTFTarget(BaseOperator):
     __shared__ = ['num_classes']

diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py
index 0fe932af4..b3c26970f 100644
--- a/ppdet/engine/export_utils.py
+++ b/ppdet/engine/export_utils.py
@@ -42,6 +42,8 @@ TRT_MIN_SUBGRAPH = {
     'DeepSORT': 3,
     'JDE': 10,
     'FairMOT': 5,
+    'GFL': 16,
+    'PicoDet': 3,
 }
 
 KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
@@ -116,8 +118,9 @@ def _dump_infer_config(config, path, image_shape, model):
             break
     if not arch_state:
         logger.error(
-            'Architecture: {} is not supported for exporting model now'.format(
-                infer_arch))
+            'Architecture: {} is not supported for exporting model now.\n'.
+            format(infer_arch) +
+            'Please set TRT_MIN_SUBGRAPH in ppdet/engine/export_utils.py')
         os._exit(0)
     if 'Mask' in infer_arch:
         infer_cfg['mask'] = True

diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index 24df865b7..278d72000 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -21,6 +21,8 @@ from . import jde
 from . import deepsort
 from . import fairmot
 from . import centernet
+from . import gfl
+from . import picodet
 from . import detr
 from . import sparse_rcnn
 
@@ -41,5 +43,7 @@ from .deepsort import *
 from .fairmot import *
 from .centernet import *
 from .blazeface import *
+from .gfl import *
+from .picodet import *
 from .detr import *
 from .sparse_rcnn import *
diff --git a/ppdet/modeling/architectures/gfl.py b/ppdet/modeling/architectures/gfl.py
new file mode 100644
index 000000000..91c13077f
--- /dev/null
+++ b/ppdet/modeling/architectures/gfl.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from ppdet.core.workspace import register, create
+from .meta_arch import BaseArch
+
+__all__ = ['GFL']
+
+
+@register
+class GFL(BaseArch):
+    """
+    Generalized Focal Loss network, see https://arxiv.org/abs/2006.04388
+
+    Args:
+        backbone (object): backbone instance
+        neck (object): 'FPN' instance
+        head (object): 'GFLHead' instance
+    """
+
+    __category__ = 'architecture'
+
+    def __init__(self, backbone, neck, head='GFLHead'):
+        super(GFL, self).__init__()
+        self.backbone = backbone
+        self.neck = neck
+        self.head = head
+
+    @classmethod
+    def from_config(cls, cfg, *args, **kwargs):
+        backbone = create(cfg['backbone'])
+
+        kwargs = {'input_shape': backbone.out_shape}
+        neck = create(cfg['neck'], **kwargs)
+
+        kwargs = {'input_shape': neck.out_shape}
+        head = create(cfg['head'], **kwargs)
+
+        return {
+            'backbone': backbone,
+            'neck': neck,
+            "head": head,
+        }
+
+    def _forward(self):
+        body_feats = self.backbone(self.inputs)
+        fpn_feats = self.neck(body_feats)
+        head_outs = self.head(fpn_feats)
+        if not self.training:
+            im_shape = self.inputs['im_shape']
+            scale_factor = self.inputs['scale_factor']
+            bboxes, bbox_num = self.head.post_process(head_outs, im_shape,
+                                                      scale_factor)
+            return bboxes, bbox_num
+        else:
+            return head_outs
+
+    def get_loss(self, ):
+        loss = {}
+
+        head_outs = self._forward()
+        loss_gfl = self.head.get_loss(head_outs, self.inputs)
+        loss.update(loss_gfl)
+        total_loss = paddle.add_n(list(loss.values()))
+        loss.update({'loss': total_loss})
+        return loss
+
+    def get_pred(self):
+        bbox_pred, bbox_num = self._forward()
+        output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
+        return output
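`get_loss` above assumes every head loss is a scalar tensor, so `paddle.add_n` can fold the loss dict into a single total. A toy sketch of that aggregation pattern (illustrative values):

```python
import paddle

loss = {'loss_qfl': paddle.to_tensor(0.8),
        'loss_bbox': paddle.to_tensor(1.2),
        'loss_dfl': paddle.to_tensor(0.3)}
loss['loss'] = paddle.add_n(list(loss.values()))  # element-wise sum of all entries
print(float(loss['loss']))  # -> 2.3
```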
diff --git a/ppdet/modeling/architectures/picodet.py b/ppdet/modeling/architectures/picodet.py
new file mode 100644
index 000000000..3bb551e24
--- /dev/null
+++ b/ppdet/modeling/architectures/picodet.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from ppdet.core.workspace import register, create
+from .meta_arch import BaseArch
+
+__all__ = ['PicoDet']
+
+
+@register
+class PicoDet(BaseArch):
+    """
+    PicoDet, a lightweight detector based on Generalized Focal Loss,
+    see https://arxiv.org/abs/2006.04388
+
+    Args:
+        backbone (object): backbone instance
+        neck (object): 'FPN' instance
+        head (object): 'PicoHead' instance
+    """
+
+    __category__ = 'architecture'
+
+    def __init__(self, backbone, neck, head='PicoHead'):
+        super(PicoDet, self).__init__()
+        self.backbone = backbone
+        self.neck = neck
+        self.head = head
+        self.deploy = False
+
+    @classmethod
+    def from_config(cls, cfg, *args, **kwargs):
+        backbone = create(cfg['backbone'])
+
+        kwargs = {'input_shape': backbone.out_shape}
+        neck = create(cfg['neck'], **kwargs)
+
+        kwargs = {'input_shape': neck.out_shape}
+        head = create(cfg['head'], **kwargs)
+
+        return {
+            'backbone': backbone,
+            'neck': neck,
+            "head": head,
+        }
+
+    def _forward(self):
+        body_feats = self.backbone(self.inputs)
+        fpn_feats = self.neck(body_feats)
+        head_outs = self.head(fpn_feats)
+        if self.training or self.deploy:
+            return head_outs
+        else:
+            im_shape = self.inputs['im_shape']
+            scale_factor = self.inputs['scale_factor']
+            bboxes, bbox_num = self.head.post_process(head_outs, im_shape,
+                                                      scale_factor)
+            return bboxes, bbox_num
+
+    def get_loss(self, ):
+        loss = {}
+
+        head_outs = self._forward()
+        loss_gfl = self.head.get_loss(head_outs, self.inputs)
+        loss.update(loss_gfl)
+        total_loss = paddle.add_n(list(loss.values()))
+        loss.update({'loss': total_loss})
+        return loss
+
+    def get_pred(self):
+        if self.deploy:
+            return {'picodet': self._forward()[0]}
+        else:
+            bbox_pred, bbox_num = self._forward()
+            output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
+            return output

diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py
index 6d66690f2..c6e1c0c2c 100644
--- a/ppdet/modeling/backbones/__init__.py
+++ b/ppdet/modeling/backbones/__init__.py
@@ -23,6 +23,7 @@ from . import ghostnet
 from . import senet
 from . import res2net
 from . import dla
+from . import shufflenet_v2
 
 from .vgg import *
 from .resnet import *
@@ -35,3 +36,4 @@ from .ghostnet import *
 from .senet import *
 from .res2net import *
 from .dla import *
+from .shufflenet_v2 import *
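The file below implements ShuffleNetV2, whose signature op is `channel_shuffle`: a reshape/transpose/reshape that interleaves channels across groups so information flows between the two branches. A tiny NumPy sketch with toy shapes (illustrative only):

```python
import numpy as np

x = np.arange(8).reshape(1, 4, 1, 2)         # NCHW; channels 0..3
n, c, h, w = x.shape
groups = 2
y = x.reshape(n, groups, c // groups, h, w)  # split C into (groups, C/groups)
y = y.transpose(0, 2, 1, 3, 4)               # swap the group and channel axes
y = y.reshape(n, c, h, w)                    # channel order is now 0, 2, 1, 3
print(y[0, :, 0, 0])                         # -> [0 4 2 6]
```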
diff --git a/ppdet/modeling/backbones/shufflenet_v2.py b/ppdet/modeling/backbones/shufflenet_v2.py
new file mode 100644
index 000000000..75cd6e38d
--- /dev/null
+++ b/ppdet/modeling/backbones/shufflenet_v2.py
@@ -0,0 +1,277 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+from paddle import ParamAttr
+from paddle.nn import Conv2D, MaxPool2D, AdaptiveAvgPool2D, BatchNorm
+from paddle.nn.initializer import KaimingNormal
+
+from ppdet.core.workspace import register, serializable
+from numbers import Integral
+from ..shape_spec import ShapeSpec
+
+__all__ = ['ShuffleNetV2']
+
+
+def channel_shuffle(x, groups):
+    batch_size, num_channels, height, width = x.shape[0:4]
+    channels_per_group = num_channels // groups
+
+    # reshape
+    x = paddle.reshape(
+        x=x, shape=[batch_size, groups, channels_per_group, height, width])
+
+    # transpose
+    x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
+
+    # flatten
+    x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width])
+    return x
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 groups=1,
+                 act=None):
+        super(ConvBNLayer, self).__init__()
+        self._conv = Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups,
+            weight_attr=ParamAttr(initializer=KaimingNormal()),
+            bias_attr=False)
+
+        self._batch_norm = BatchNorm(out_channels, act=act)
+
+    def forward(self, inputs):
+        y = self._conv(inputs)
+        y = self._batch_norm(y)
+        return y
+
+
+class InvertedResidual(nn.Layer):
+    def __init__(self, in_channels, out_channels, stride, act="relu"):
+        super(InvertedResidual, self).__init__()
+        self._conv_pw = ConvBNLayer(
+            in_channels=in_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act)
+        self._conv_dw = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=out_channels // 2,
+            act=None)
+        self._conv_linear = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act)
+
+    def forward(self, inputs):
+        x1, x2 = paddle.split(
+            inputs,
+            num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
+            axis=1)
+        x2 = self._conv_pw(x2)
+        x2 = self._conv_dw(x2)
+        x2 = self._conv_linear(x2)
+        out = paddle.concat([x1, x2], axis=1)
+        return channel_shuffle(out, 2)
+
+
+class InvertedResidualDS(nn.Layer):
+    def __init__(self, in_channels, out_channels, stride, act="relu"):
+        super(InvertedResidualDS, self).__init__()
+
+        # branch1
+        self._conv_dw_1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=in_channels,
+            act=None)
+        self._conv_linear_1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act)
+        # branch2
+        self._conv_pw_2 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act)
+        self._conv_dw_2 = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=out_channels // 2,
+            act=None)
+        self._conv_linear_2 = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act)
+
+    def forward(self, inputs):
+        x1 = self._conv_dw_1(inputs)
+        x1 = self._conv_linear_1(x1)
+        x2 = self._conv_pw_2(inputs)
+        x2 = self._conv_dw_2(x2)
self._conv_dw_2(x2) + x2 = self._conv_linear_2(x2) + out = paddle.concat([x1, x2], axis=1) + + return channel_shuffle(out, 2) + + +@register +@serializable +class ShuffleNetV2(nn.Layer): + def __init__(self, + scale=1.0, + act="relu", + feature_maps=[5, 13, 17], + with_last_conv=False): + super(ShuffleNetV2, self).__init__() + self.scale = scale + self.with_last_conv = with_last_conv + if isinstance(feature_maps, Integral): + feature_maps = [feature_maps] + self.feature_maps = feature_maps + stage_repeats = [4, 8, 4] + + if scale == 0.25: + stage_out_channels = [-1, 24, 24, 48, 96, 512] + elif scale == 0.33: + stage_out_channels = [-1, 24, 32, 64, 128, 512] + elif scale == 0.5: + stage_out_channels = [-1, 24, 48, 96, 192, 1024] + elif scale == 1.0: + stage_out_channels = [-1, 24, 116, 232, 464, 1024] + elif scale == 1.5: + stage_out_channels = [-1, 24, 176, 352, 704, 1024] + elif scale == 2.0: + stage_out_channels = [-1, 24, 224, 488, 976, 2048] + else: + raise NotImplementedError("This scale size:[" + str(scale) + + "] is not implemented!") + + self._out_channels = [] + self._feature_idx = 0 + # 1. conv1 + self._conv1 = ConvBNLayer( + in_channels=3, + out_channels=stage_out_channels[1], + kernel_size=3, + stride=2, + padding=1, + act=act) + self._max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1) + self._feature_idx += 1 + + # 2. bottleneck sequences + self._block_list = [] + for stage_id, num_repeat in enumerate(stage_repeats): + for i in range(num_repeat): + if i == 0: + block = self.add_sublayer( + name=str(stage_id + 2) + '_' + str(i + 1), + sublayer=InvertedResidualDS( + in_channels=stage_out_channels[stage_id + 1], + out_channels=stage_out_channels[stage_id + 2], + stride=2, + act=act)) + else: + block = self.add_sublayer( + name=str(stage_id + 2) + '_' + str(i + 1), + sublayer=InvertedResidual( + in_channels=stage_out_channels[stage_id + 2], + out_channels=stage_out_channels[stage_id + 2], + stride=1, + act=act)) + self._block_list.append(block) + self._feature_idx += 1 + self._update_out_channels(stage_out_channels[stage_id + 2], + self._feature_idx, self.feature_maps) + + if self.with_last_conv: + # last_conv + self._last_conv = ConvBNLayer( + in_channels=stage_out_channels[-2], + out_channels=stage_out_channels[-1], + kernel_size=1, + stride=1, + padding=0, + act=act) + self._feature_idx += 1 + self._update_out_channels(stage_out_channels[-1], self._feature_idx, + self.feature_maps) + + def _update_out_channels(self, channel, feature_idx, feature_maps): + if feature_idx in feature_maps: + self._out_channels.append(channel) + + def forward(self, inputs): + y = self._conv1(inputs['image']) + y = self._max_pool(y) + outs = [] + for i, inv in enumerate(self._block_list): + y = inv(y) + if i + 2 in self.feature_maps: + outs.append(y) + + if self.with_last_conv: + y = self._last_conv(y) + outs.append(y) + return outs + + @property + def out_shape(self): + return [ShapeSpec(channels=c) for c in self._out_channels] diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index 4c4acd9dd..df8eda94d 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -601,3 +601,47 @@ def bbox_iou_np_expand(box1, box2, x1y1x2y2=True, eps=1e-16): ious = inter_area / (b1_area + b2_area - inter_area + eps) return ious + + +def bbox2distance(points, bbox, max_dis=None, eps=0.1): + """Decode bounding box based on distances. + Args: + points (Tensor): Shape (n, 2), [x, y]. 
+        bbox (Tensor): Shape (n, 4), "xyxy" format
+        max_dis (float): Upper bound of the distance.
+        eps (float): A small value to keep each target strictly below
+            max_dis (target < max_dis rather than <=).
+    Returns:
+        Tensor: Distances from the points to the four box boundaries
+            (left, top, right, bottom), shape (n, 4).
+    """
+    left = points[:, 0] - bbox[:, 0]
+    top = points[:, 1] - bbox[:, 1]
+    right = bbox[:, 2] - points[:, 0]
+    bottom = bbox[:, 3] - points[:, 1]
+    if max_dis is not None:
+        left = left.clip(min=0, max=max_dis - eps)
+        top = top.clip(min=0, max=max_dis - eps)
+        right = right.clip(min=0, max=max_dis - eps)
+        bottom = bottom.clip(min=0, max=max_dis - eps)
+    return paddle.stack([left, top, right, bottom], -1)
+
+
+def distance2bbox(points, distance, max_shape=None):
+    """Decode distance predictions to bounding boxes.
+    Args:
+        points (Tensor): Shape (n, 2), [x, y].
+        distance (Tensor): Distance from the given point to 4
+            boundaries (left, top, right, bottom).
+        max_shape (tuple): (h, w) of the image, used to clip the decoded
+            boxes to the image region.
+    Returns:
+        Tensor: Decoded bboxes in "xyxy" format, shape (n, 4).
+    """
+    x1 = points[:, 0] - distance[:, 0]
+    y1 = points[:, 1] - distance[:, 1]
+    x2 = points[:, 0] + distance[:, 2]
+    y2 = points[:, 1] + distance[:, 3]
+    if max_shape is not None:
+        x1 = x1.clip(min=0, max=max_shape[1])
+        y1 = y1.clip(min=0, max=max_shape[0])
+        x2 = x2.clip(min=0, max=max_shape[1])
+        y2 = y2.clip(min=0, max=max_shape[0])
+    return paddle.stack([x1, y1, x2, y2], -1)
diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py
index a21cc3c97..dd6b1dcc2 100644
--- a/ppdet/modeling/heads/__init__.py
+++ b/ppdet/modeling/heads/__init__.py
@@ -25,6 +25,8 @@ from . import face_head
 from . import s2anet_head
 from . import keypoint_hrhrnet_head
 from . import centernet_head
+from . import gfl_head
+from . import pico_head
 from . import detr_head
 from . import sparsercnn_head
 
@@ -41,5 +43,7 @@ from .face_head import *
 from .s2anet_head import *
 from .keypoint_hrhrnet_head import *
 from .centernet_head import *
+from .gfl_head import *
+from .pico_head import *
 from .detr_head import *
 from .sparsercnn_head import *
diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py
new file mode 100644
index 000000000..89bdb1f2b
--- /dev/null
+++ b/ppdet/modeling/heads/gfl_head.py
@@ -0,0 +1,470 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn.initializer import Normal, Constant
+
+from ppdet.core.workspace import register
+from ppdet.modeling.layers import ConvNormLayer
+from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance
+from ppdet.data.transform.atss_assigner import bbox_overlaps
+
+
+class ScaleReg(nn.Layer):
+    """
+    Learnable scale parameter for the regression outputs.
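+
+    A minimal usage sketch (shapes illustrative): one ScaleReg is created
+    per FPN level, each holding a single learnable scalar initialized to 1:
+
+        scale = ScaleReg()
+        reg_out = scale(paddle.rand([2, 68, 40, 40]))  # elementwise * scalar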
+ """ + + def __init__(self): + super(ScaleReg, self).__init__() + self.scale_reg = self.create_parameter( + shape=[1], + attr=ParamAttr(initializer=Constant(value=1.)), + dtype="float32") + + def forward(self, inputs): + out = inputs * self.scale_reg + return out + + +class Integral(nn.Layer): + """A fixed layer for calculating integral result from distribution. + This layer calculates the target location by :math: `sum{P(y_i) * y_i}`, + P(y_i) denotes the softmax vector that represents the discrete distribution + y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max} + + Args: + reg_max (int): The maximal value of the discrete set. Default: 16. You + may want to reset it according to your new dataset or related + settings. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + paddle.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + Args: + x (Tensor): Features of the regression head, shape (N, 4*(n+1)), + n is self.reg_max. + Returns: + x (Tensor): Integral result of box locations, i.e., distance + offsets from the box center in four directions, shape (N, 4). + """ + x = F.softmax(x.reshape([-1, self.reg_max + 1]), axis=1) + x = F.linear(x, self.project).reshape([-1, 4]) + return x + + +@register +class DGQP(nn.Layer): + """Distribution-Guided Quality Predictor of GFocal head + + Args: + reg_topk (int): top-k statistics of distribution to guide LQE + reg_channels (int): hidden layer unit to generate LQE + add_mean (bool): Whether to calculate the mean of top-k statistics + """ + + def __init__(self, reg_topk=4, reg_channels=64, add_mean=True): + super(DGQP, self).__init__() + self.reg_topk = reg_topk + self.reg_channels = reg_channels + self.add_mean = add_mean + self.total_dim = reg_topk + if add_mean: + self.total_dim += 1 + self.reg_conv1 = self.add_sublayer( + 'dgqp_reg_conv1', + nn.Conv2D( + in_channels=4 * self.total_dim, + out_channels=self.reg_channels, + kernel_size=1, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr(initializer=Constant(value=0)))) + self.reg_conv2 = self.add_sublayer( + 'dgqp_reg_conv2', + nn.Conv2D( + in_channels=self.reg_channels, + out_channels=1, + kernel_size=1, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr(initializer=Constant(value=0)))) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + Args: + x (Tensor): Features of the regression head, shape (N, 4*(n+1)), + n is self.reg_max. + Returns: + x (Tensor): Integral result of box locations, i.e., distance + offsets from the box center in four directions, shape (N, 4). 
+ """ + N, _, H, W = x.shape[:] + prob = F.softmax(x.reshape([N, 4, -1, H, W]), axis=2) + prob_topk, _ = prob.topk(self.reg_topk, axis=2) + if self.add_mean: + stat = paddle.concat( + [prob_topk, prob_topk.mean( + axis=2, keepdim=True)], axis=2) + else: + stat = prob_topk + y = F.relu(self.reg_conv1(stat.reshape([N, -1, H, W]))) + y = F.sigmoid(self.reg_conv2(y)) + return y + + +@register +class GFLHead(nn.Layer): + """ + GFLHead + Args: + conv_feat (object): Instance of 'FCOSFeat' + num_classes (int): Number of classes + fpn_stride (list): The stride of each FPN Layer + prior_prob (float): Used to set the bias init for the class prediction layer + loss_qfl (object): + loss_dfl (object): + loss_bbox (object): + reg_max: Max value of integral set :math: `{0, ..., reg_max}` + n QFL setting. Default: 16. + """ + __inject__ = [ + 'conv_feat', 'dgqp_module', 'loss_qfl', 'loss_dfl', 'loss_bbox', 'nms' + ] + __shared__ = ['num_classes'] + + def __init__(self, + conv_feat='FCOSFeat', + dgqp_module=None, + num_classes=80, + fpn_stride=[8, 16, 32, 64, 128], + prior_prob=0.01, + loss_qfl='QualityFocalLoss', + loss_dfl='DistributionFocalLoss', + loss_bbox='GIoULoss', + reg_max=16, + feat_in_chan=256, + nms=None, + nms_pre=1000, + cell_offset=0): + super(GFLHead, self).__init__() + self.conv_feat = conv_feat + self.dgqp_module = dgqp_module + self.num_classes = num_classes + self.fpn_stride = fpn_stride + self.prior_prob = prior_prob + self.loss_qfl = loss_qfl + self.loss_dfl = loss_dfl + self.loss_bbox = loss_bbox + self.reg_max = reg_max + self.feat_in_chan = feat_in_chan + self.nms = nms + self.nms_pre = nms_pre + self.cell_offset = cell_offset + self.use_sigmoid = self.loss_qfl.use_sigmoid + if self.use_sigmoid: + self.cls_out_channels = self.num_classes + else: + self.cls_out_channels = self.num_classes + 1 + + conv_cls_name = "gfl_head_cls" + bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob) + self.gfl_head_cls = self.add_sublayer( + conv_cls_name, + nn.Conv2D( + in_channels=self.feat_in_chan, + out_channels=self.cls_out_channels, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr( + initializer=Constant(value=bias_init_value)))) + + conv_reg_name = "gfl_head_reg" + self.gfl_head_reg = self.add_sublayer( + conv_reg_name, + nn.Conv2D( + in_channels=self.feat_in_chan, + out_channels=4 * (self.reg_max + 1), + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr(initializer=Constant(value=0)))) + + self.scales_regs = [] + for i in range(len(self.fpn_stride)): + lvl = int(math.log(int(self.fpn_stride[i]), 2)) + feat_name = 'p{}_feat'.format(lvl) + scale_reg = self.add_sublayer(feat_name, ScaleReg()) + self.scales_regs.append(scale_reg) + + self.distribution_project = Integral(self.reg_max) + + def forward(self, fpn_feats): + assert len(fpn_feats) == len( + self.fpn_stride + ), "The size of fpn_feats is not equal to size of fpn_stride" + cls_logits_list = [] + bboxes_reg_list = [] + for scale_reg, fpn_feat in zip(self.scales_regs, fpn_feats): + conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat) + cls_logits = self.gfl_head_cls(conv_cls_feat) + bbox_reg = scale_reg(self.gfl_head_reg(conv_reg_feat)) + if self.dgqp_module: + quality_score = self.dgqp_module(bbox_reg) + cls_logits = F.sigmoid(cls_logits) * quality_score + cls_logits_list.append(cls_logits) + bboxes_reg_list.append(bbox_reg) + + return (cls_logits_list, 
+
+    def _images_to_levels(self, target, num_level_anchors):
+        """
+        Convert targets by image to targets by feature level.
+        """
+        level_targets = []
+        start = 0
+        for n in num_level_anchors:
+            end = start + n
+            level_targets.append(target[:, start:end].squeeze(0))
+            start = end
+        return level_targets
+
+    def _grid_cells_to_center(self, grid_cells):
+        """
+        Get the center location of each grid cell.
+        Args:
+            grid_cells: grid cells of a feature map
+        Returns:
+            center points
+        """
+        cells_cx = (grid_cells[:, 2] + grid_cells[:, 0]) / 2
+        cells_cy = (grid_cells[:, 3] + grid_cells[:, 1]) / 2
+        return paddle.stack([cells_cx, cells_cy], axis=-1)
+
+    def get_loss(self, gfl_head_outs, gt_meta):
+        cls_logits, bboxes_reg = gfl_head_outs
+        num_level_anchors = [
+            featmap.shape[-2] * featmap.shape[-1] for featmap in cls_logits
+        ]
+        grid_cells_list = self._images_to_levels(gt_meta['grid_cells'],
+                                                 num_level_anchors)
+        labels_list = self._images_to_levels(gt_meta['labels'],
+                                             num_level_anchors)
+        label_weights_list = self._images_to_levels(gt_meta['label_weights'],
+                                                    num_level_anchors)
+        bbox_targets_list = self._images_to_levels(gt_meta['bbox_targets'],
+                                                   num_level_anchors)
+        num_total_pos = sum(gt_meta['pos_num'])
+
+        loss_bbox_list, loss_dfl_list, loss_qfl_list, avg_factor = [], [], [], []
+        for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip(
+                cls_logits, bboxes_reg, grid_cells_list, labels_list,
+                label_weights_list, bbox_targets_list, self.fpn_stride):
+            grid_cells = grid_cells.reshape([-1, 4])
+            cls_score = cls_score.transpose([0, 2, 3, 1]).reshape(
+                [-1, self.cls_out_channels])
+            bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape(
+                [-1, 4 * (self.reg_max + 1)])
+            bbox_targets = bbox_targets.reshape([-1, 4])
+            labels = labels.reshape([-1])
+            label_weights = label_weights.reshape([-1])
+
+            bg_class_ind = self.num_classes
+            pos_inds = paddle.nonzero(
+                paddle.logical_and((labels >= 0), (labels < bg_class_ind)),
+                as_tuple=False).squeeze(1)
+            # quality (IoU) target for QFL, filled for positives below
+            score = np.zeros(labels.shape)
+            if len(pos_inds) > 0:
+                pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0)
+                pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0)
+                pos_grid_cells = paddle.gather(grid_cells, pos_inds, axis=0)
+                pos_grid_cell_centers = self._grid_cells_to_center(
+                    pos_grid_cells) / stride
+
+                weight_targets = F.sigmoid(cls_score.detach())
+                weight_targets = paddle.gather(
+                    weight_targets.max(axis=1), pos_inds, axis=0)
+                pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)
+                pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers,
+                                                     pos_bbox_pred_corners)
+                pos_decode_bbox_targets = pos_bbox_targets / stride
+                bbox_iou = bbox_overlaps(
+                    pos_decode_bbox_pred.detach().numpy(),
+                    pos_decode_bbox_targets.detach().numpy(),
+                    is_aligned=True)
+                score[pos_inds.numpy()] = bbox_iou
+                pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1])
+                target_corners = bbox2distance(pos_grid_cell_centers,
+                                               pos_decode_bbox_targets,
+                                               self.reg_max).reshape([-1])
+                # regression loss
+                loss_bbox = paddle.sum(
+                    self.loss_bbox(pos_decode_bbox_pred,
+                                   pos_decode_bbox_targets) *
+                    weight_targets.mean(axis=-1))
+
+                # dfl loss
+                loss_dfl = self.loss_dfl(
+                    pred_corners,
+                    target_corners,
+                    weight=weight_targets.unsqueeze(-1).expand([-1, 4]).reshape(
+                        [-1]),
+                    avg_factor=4.0)
+            else:
+                loss_bbox = bbox_pred.sum() * 0
+                loss_dfl = bbox_pred.sum() * 0
+                weight_targets = paddle.to_tensor([0])
+
+            # qfl loss
+            score = paddle.to_tensor(score)
+            loss_qfl = self.loss_qfl(
+                cls_score,
(labels, score), + weight=label_weights, + avg_factor=num_total_pos) + loss_bbox_list.append(loss_bbox) + loss_dfl_list.append(loss_dfl) + loss_qfl_list.append(loss_qfl) + avg_factor.append(weight_targets.sum()) + + avg_factor = sum(avg_factor) + if avg_factor <= 0: + loss_qfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + loss_bbox = paddle.to_tensor( + 0, dtype='float32', stop_gradient=False) + loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + else: + losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list)) + losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list)) + loss_qfl = sum(loss_qfl_list) + loss_bbox = sum(losses_bbox) + loss_dfl = sum(losses_dfl) + + loss_states = dict( + loss_qfl=loss_qfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl) + + return loss_states + + def get_single_level_center_point(self, featmap_size, stride, + cell_offset=0): + """ + Generate pixel centers of a single stage feature map. + Args: + featmap_size: height and width of the feature map + stride: down sample stride of the feature map + Returns: + y and x of the center points + """ + h, w = featmap_size + x_range = (paddle.arange(w, dtype='float32') + cell_offset) * stride + y_range = (paddle.arange(h, dtype='float32') + cell_offset) * stride + y, x = paddle.meshgrid(y_range, x_range) + y = y.flatten() + x = x.flatten() + return y, x + + def get_bboxes_single(self, + cls_scores, + bbox_preds, + img_shape, + scale_factor, + rescale=True, + cell_offset=0): + assert len(cls_scores) == len(bbox_preds) + mlvl_bboxes = [] + mlvl_scores = [] + for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores, + bbox_preds): + featmap_size = cls_score.shape[-2:] + y, x = self.get_single_level_center_point( + featmap_size, stride, cell_offset=cell_offset) + center_points = paddle.stack([x, y], axis=-1) + scores = F.sigmoid( + cls_score.transpose([1, 2, 0]).reshape( + [-1, self.cls_out_channels])) + bbox_pred = bbox_pred.transpose([1, 2, 0]) + bbox_pred = self.distribution_project(bbox_pred) * stride + + if scores.shape[0] > self.nms_pre: + max_scores = scores.max(axis=1) + _, topk_inds = max_scores.topk(self.nms_pre) + center_points = center_points.gather(topk_inds) + bbox_pred = bbox_pred.gather(topk_inds) + scores = scores.gather(topk_inds) + + bboxes = distance2bbox( + center_points, bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_bboxes = paddle.concat(mlvl_bboxes) + if rescale: + # [h_scale, w_scale] to [w_scale, h_scale, w_scale, h_scale] + im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]]) + mlvl_bboxes /= im_scale + mlvl_scores = paddle.concat(mlvl_scores) + if self.use_sigmoid: + # add a dummy background class to the backend when use_sigmoid + padding = paddle.zeros([mlvl_scores.shape[0], 1]) + mlvl_scores = paddle.concat([mlvl_scores, padding], axis=1) + mlvl_scores = mlvl_scores.transpose([1, 0]) + return mlvl_bboxes, mlvl_scores + + def decode(self, cls_scores, bbox_preds, im_shape, scale_factor, + cell_offset): + batch_bboxes = [] + batch_scores = [] + for img_id in range(cls_scores[0].shape[0]): + num_levels = len(cls_scores) + cls_score_list = [cls_scores[i][img_id] for i in range(num_levels)] + bbox_pred_list = [bbox_preds[i][img_id] for i in range(num_levels)] + bboxes, scores = self.get_bboxes_single( + cls_score_list, + bbox_pred_list, + im_shape[img_id], + scale_factor[img_id], + cell_offset=cell_offset) + batch_bboxes.append(bboxes) + batch_scores.append(scores) + batch_bboxes = 
paddle.stack(batch_bboxes, axis=0)
+        batch_scores = paddle.stack(batch_scores, axis=0)
+
+        return batch_bboxes, batch_scores
+
+    def post_process(self, gfl_head_outs, im_shape, scale_factor):
+        cls_scores, bboxes_reg = gfl_head_outs
+        bboxes, score = self.decode(cls_scores, bboxes_reg, im_shape,
+                                    scale_factor, self.cell_offset)
+        bbox_pred, bbox_num, _ = self.nms(bboxes, score)
+        return bbox_pred, bbox_num
diff --git a/ppdet/modeling/heads/pico_head.py b/ppdet/modeling/heads/pico_head.py
new file mode 100644
index 000000000..db409beed
--- /dev/null
+++ b/ppdet/modeling/heads/pico_head.py
@@ -0,0 +1,328 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn.initializer import Normal, Constant
+
+from ppdet.core.workspace import register
+from ppdet.modeling.layers import ConvNormLayer
+from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance
+from ppdet.data.transform.atss_assigner import bbox_overlaps
+from .gfl_head import GFLHead
+
+
+@register
+class PicoFeat(nn.Layer):
+    """
+    PicoFeat of PicoDet
+
+    Args:
+        feat_in (int): The channel number of input Tensor.
+        feat_out (int): The channel number of output Tensor.
+        num_convs (int): The number of depthwise + pointwise conv pairs in
+            each branch of PicoFeat.
+        norm_type (str): Normalization type, 'bn'/'sync_bn'/'gn'.
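+        num_fpn_stride (int): The number of FPN levels served by this
+            module; one conv branch is built per level.
+        share_cls_reg (bool): Whether the classification and regression
+            branches share the same convs (the head then splits the fused
+            output). Default: False.
+
+    A rough usage sketch (channel sizes illustrative; `fpn_feats` stands
+    for the list of neck outputs):
+
+        feat = PicoFeat(feat_in=96, feat_out=96, num_fpn_stride=3)
+        cls_feat, reg_feat = feat(fpn_feats[0], 0)  # features for level 0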
+ """ + + def __init__(self, + feat_in=256, + feat_out=96, + num_fpn_stride=3, + num_convs=2, + norm_type='bn', + share_cls_reg=False): + super(PicoFeat, self).__init__() + self.num_convs = num_convs + self.norm_type = norm_type + self.share_cls_reg = share_cls_reg + self.cls_convs = [] + self.reg_convs = [] + for stage_idx in range(num_fpn_stride): + cls_subnet_convs = [] + reg_subnet_convs = [] + for i in range(self.num_convs): + in_c = feat_in if i == 0 else feat_out + cls_conv_dw = self.add_sublayer( + 'cls_conv_dw{}.{}'.format(stage_idx, i), + ConvNormLayer( + ch_in=in_c, + ch_out=feat_out, + filter_size=3, + stride=1, + groups=feat_out, + norm_type=norm_type, + bias_on=False, + lr_scale=2.)) + cls_subnet_convs.append(cls_conv_dw) + cls_conv_pw = self.add_sublayer( + 'cls_conv_pw{}.{}'.format(stage_idx, i), + ConvNormLayer( + ch_in=in_c, + ch_out=feat_out, + filter_size=1, + stride=1, + norm_type=norm_type, + bias_on=False, + lr_scale=2.)) + cls_subnet_convs.append(cls_conv_pw) + + if not self.share_cls_reg: + reg_conv_dw = self.add_sublayer( + 'reg_conv_dw{}.{}'.format(stage_idx, i), + ConvNormLayer( + ch_in=in_c, + ch_out=feat_out, + filter_size=3, + stride=1, + groups=feat_out, + norm_type=norm_type, + bias_on=False, + lr_scale=2.)) + reg_subnet_convs.append(reg_conv_dw) + reg_conv_pw = self.add_sublayer( + 'reg_conv_pw{}.{}'.format(stage_idx, i), + ConvNormLayer( + ch_in=in_c, + ch_out=feat_out, + filter_size=1, + stride=1, + norm_type=norm_type, + bias_on=False, + lr_scale=2.)) + reg_subnet_convs.append(reg_conv_pw) + self.cls_convs.append(cls_subnet_convs) + self.reg_convs.append(reg_subnet_convs) + + def forward(self, fpn_feat, stage_idx): + assert stage_idx < len(self.cls_convs) + cls_feat = fpn_feat + reg_feat = fpn_feat + for i in range(len(self.cls_convs[stage_idx])): + cls_feat = F.leaky_relu(self.cls_convs[stage_idx][i](cls_feat), 0.1) + if not self.share_cls_reg: + reg_feat = F.leaky_relu(self.reg_convs[stage_idx][i](reg_feat), + 0.1) + return cls_feat, reg_feat + + +@register +class PicoHead(GFLHead): + """ + PicoHead + Args: + conv_feat (object): Instance of 'LiteGFLFeat' + num_classes (int): Number of classes + fpn_stride (list): The stride of each FPN Layer + prior_prob (float): Used to set the bias init for the class prediction layer + loss_qfl (object): + loss_dfl (object): + loss_bbox (object): + reg_max: Max value of integral set :math: `{0, ..., reg_max}` + n QFL setting. Default: 16. 
+ """ + __inject__ = [ + 'conv_feat', 'dgqp_module', 'loss_qfl', 'loss_dfl', 'loss_bbox', 'nms' + ] + __shared__ = ['num_classes'] + + def __init__(self, + conv_feat='PicoFeat', + dgqp_module=None, + num_classes=80, + fpn_stride=[8, 16, 32], + prior_prob=0.01, + loss_qfl='QualityFocalLoss', + loss_dfl='DistributionFocalLoss', + loss_bbox='GIoULoss', + reg_max=16, + feat_in_chan=96, + nms=None, + nms_pre=1000, + cell_offset=0): + super(PicoHead, self).__init__( + conv_feat=conv_feat, + dgqp_module=dgqp_module, + num_classes=num_classes, + fpn_stride=fpn_stride, + prior_prob=prior_prob, + loss_qfl=loss_qfl, + loss_dfl=loss_dfl, + loss_bbox=loss_bbox, + reg_max=reg_max, + feat_in_chan=feat_in_chan, + nms=nms, + nms_pre=nms_pre, + cell_offset=cell_offset) + self.conv_feat = conv_feat + self.num_classes = num_classes + self.fpn_stride = fpn_stride + self.prior_prob = prior_prob + self.loss_qfl = loss_qfl + self.loss_dfl = loss_dfl + self.loss_bbox = loss_bbox + self.reg_max = reg_max + self.feat_in_chan = feat_in_chan + self.nms = nms + self.nms_pre = nms_pre + self.cell_offset = cell_offset + self.use_sigmoid = self.loss_qfl.use_sigmoid + if self.use_sigmoid: + self.cls_out_channels = self.num_classes + else: + self.cls_out_channels = self.num_classes + 1 + bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob) + # Clear the super class initialization + self.gfl_head_cls = None + self.gfl_head_reg = None + self.scales_regs = None + + self.head_cls_list = [] + self.head_reg_list = [] + for i in range(len(fpn_stride)): + head_cls = self.add_sublayer( + "head_cls" + str(i), + nn.Conv2D( + in_channels=self.feat_in_chan, + out_channels=self.cls_out_channels + 4 * (self.reg_max + 1) + if self.conv_feat.share_cls_reg else self.cls_out_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr( + initializer=Constant(value=bias_init_value)))) + self.head_cls_list.append(head_cls) + if not self.conv_feat.share_cls_reg: + head_reg = self.add_sublayer( + "head_reg" + str(i), + nn.Conv2D( + in_channels=self.feat_in_chan, + out_channels=4 * (self.reg_max + 1), + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr(initializer=Constant(value=0)))) + self.head_reg_list.append(head_reg) + + def forward(self, fpn_feats): + assert len(fpn_feats) == len( + self.fpn_stride + ), "The size of fpn_feats is not equal to size of fpn_stride" + cls_logits_list = [] + bboxes_reg_list = [] + for i, fpn_feat in enumerate(fpn_feats): + conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, i) + if self.conv_feat.share_cls_reg: + cls_logits = self.head_cls_list[i](conv_cls_feat) + cls_score, bbox_pred = paddle.split( + cls_logits, + [self.cls_out_channels, 4 * (self.reg_max + 1)], + axis=1) + else: + cls_score = self.head_cls_list[i](conv_cls_feat) + bbox_pred = self.head_reg_list[i](conv_reg_feat) + if self.dgqp_module: + quality_score = self.dgqp_module(bbox_pred) + cls_score = F.sigmoid(cls_score) * quality_score + + if not self.training: + cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1])) + bbox_pred = self.distribution_project( + bbox_pred.transpose([0, 2, 3, 1])) * self.fpn_stride[i] + + cls_logits_list.append(cls_score) + bboxes_reg_list.append(bbox_pred) + + return (cls_logits_list, bboxes_reg_list) + + def get_bboxes_single(self, + cls_scores, + bbox_preds, + img_shape, + scale_factor, + rescale=True, + cell_offset=0): + assert 
len(cls_scores) == len(bbox_preds) + mlvl_bboxes = [] + mlvl_scores = [] + for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores, + bbox_preds): + featmap_size = cls_score.shape[0:2] + y, x = self.get_single_level_center_point( + featmap_size, stride, cell_offset=cell_offset) + center_points = paddle.stack([x, y], axis=-1) + scores = cls_score.reshape([-1, self.cls_out_channels]) + + if scores.shape[0] > self.nms_pre: + max_scores = scores.max(axis=1) + _, topk_inds = max_scores.topk(self.nms_pre) + center_points = center_points.gather(topk_inds) + bbox_pred = bbox_pred.gather(topk_inds) + scores = scores.gather(topk_inds) + + bboxes = distance2bbox( + center_points, bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_bboxes = paddle.concat(mlvl_bboxes) + if rescale: + # [h_scale, w_scale] to [w_scale, h_scale, w_scale, h_scale] + im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]]) + mlvl_bboxes /= im_scale + mlvl_scores = paddle.concat(mlvl_scores) + mlvl_scores = mlvl_scores.transpose([1, 0]) + return mlvl_bboxes, mlvl_scores + + def decode(self, cls_scores, bbox_preds, im_shape, scale_factor, + cell_offset): + batch_bboxes = [] + batch_scores = [] + batch_size = cls_scores[0].shape[0] + for img_id in range(batch_size): + num_levels = len(cls_scores) + cls_score_list = [cls_scores[i][img_id] for i in range(num_levels)] + bbox_pred_list = [ + bbox_preds[i].reshape([batch_size, -1, 4])[img_id] + for i in range(num_levels) + ] + bboxes, scores = self.get_bboxes_single( + cls_score_list, + bbox_pred_list, + im_shape[img_id], + scale_factor[img_id], + cell_offset=cell_offset) + batch_bboxes.append(bboxes) + batch_scores.append(scores) + batch_bboxes = paddle.stack(batch_bboxes, axis=0) + batch_scores = paddle.stack(batch_scores, axis=0) + + return batch_bboxes, batch_scores + + def post_process(self, gfl_head_outs, im_shape, scale_factor): + cls_scores, bboxes_reg = gfl_head_outs + bboxes, score = self.decode(cls_scores, bboxes_reg, im_shape, + scale_factor, self.cell_offset) + bbox_pred, bbox_num, _ = self.nms(bboxes, score) + return bbox_pred, bbox_num diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 8c424da35..83389c08e 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -22,6 +22,7 @@ from . import ctfocal_loss from . import keypoint_loss from . import jde_loss from . import fairmot_loss +from . import gfocal_loss from . import detr_loss from . import sparsercnn_loss @@ -35,5 +36,6 @@ from .ctfocal_loss import * from .keypoint_loss import * from .jde_loss import * from .fairmot_loss import * +from .gfocal_loss import * from .detr_loss import * from .sparsercnn_loss import * diff --git a/ppdet/modeling/losses/gfocal_loss.py b/ppdet/modeling/losses/gfocal_loss.py new file mode 100644 index 000000000..149d30bf8 --- /dev/null +++ b/ppdet/modeling/losses/gfocal_loss.py @@ -0,0 +1,214 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register, serializable +from ppdet.modeling import ops + +__all__ = ['QualityFocalLoss', 'DistributionFocalLoss'] + + +def quality_focal_loss(pred, target, beta=2.0, use_sigmoid=True): + """ + Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning + Qualified and Distributed Bounding Boxes for Dense Object Detection + `_. + Args: + pred (Tensor): Predicted joint representation of classification + and quality (IoU) estimation with shape (N, C), C is the number of + classes. + target (tuple([Tensor])): Target category label with shape (N,) + and target quality label with shape (N,). + beta (float): The beta parameter for calculating the modulating factor. + Defaults to 2.0. + Returns: + Tensor: Loss tensor with shape (N,). + """ + assert len(target) == 2, """target for QFL must be a tuple of two elements, + including category label and quality label, respectively""" + # label denotes the category id, score denotes the quality score + label, score = target + if use_sigmoid: + func = F.binary_cross_entropy_with_logits + else: + func = F.binary_cross_entropy + + # negatives are supervised by 0 quality score + pred_sigmoid = F.sigmoid(pred) if use_sigmoid else pred + scale_factor = pred_sigmoid + zerolabel = paddle.zeros(pred.shape, dtype='float32') + loss = func(pred, zerolabel, reduction='none') * scale_factor.pow(beta) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = pred.shape[1] + pos = paddle.logical_and((label >= 0), + (label < bg_class_ind)).nonzero().squeeze(1) + if pos.shape[0] == 0: + return loss.sum(axis=1) + pos_label = paddle.gather(label, pos, axis=0) + pos_mask = np.zeros(pred.shape, dtype=np.int32) + pos_mask[pos.numpy(), pos_label.numpy()] = 1 + pos_mask = paddle.to_tensor(pos_mask, dtype='bool') + score = score.unsqueeze(-1).expand([-1, pred.shape[1]]).cast('float32') + # positives are supervised by bbox quality (IoU) score + scale_factor_new = score - pred_sigmoid + + loss_pos = func( + pred, score, reduction='none') * scale_factor_new.abs().pow(beta) + loss = loss * paddle.logical_not(pos_mask) + loss_pos * pos_mask + loss = loss.sum(axis=1) + return loss + + +def distribution_focal_loss(pred, label): + """Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning + Qualified and Distributed Bounding Boxes for Dense Object Detection + `_. + Args: + pred (Tensor): Predicted general distribution of bounding boxes + (before softmax) with shape (N, n+1), n is the max value of the + integral set `{0, ..., n}` in paper. + label (Tensor): Target distance label for bounding boxes with + shape (N,). + Returns: + Tensor: Loss tensor with shape (N,). + """ + dis_left = label.cast('int64') + dis_right = dis_left + 1 + weight_left = dis_right.cast('float32') - label + weight_right = label - dis_left.cast('float32') + loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \ + + F.cross_entropy(pred, dis_right, reduction='none') * weight_right + return loss + + +@register +@serializable +class QualityFocalLoss(nn.Layer): + r"""Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss: + Learning Qualified and Distributed Bounding Boxes for Dense Object + Detection `_. 
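+    (paper: https://arxiv.org/abs/2006.04388)
+
+    For reference, the loss for a prediction sigma with soft quality
+    target y is
+        QFL(sigma) = -|y - sigma|^beta * ((1 - y) * log(1 - sigma)
+                     + y * log(sigma)),
+    which falls back to the standard focal loss when y is binary.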
+ Args: + use_sigmoid (bool): Whether sigmoid operation is conducted in QFL. + Defaults to True. + beta (float): The beta parameter for calculating the modulating factor. + Defaults to 2.0. + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Loss weight of current loss. + """ + + def __init__(self, + use_sigmoid=True, + beta=2.0, + reduction='mean', + loss_weight=1.0): + super(QualityFocalLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.beta = beta + assert reduction in ('none', 'mean', 'sum') + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, pred, target, weight=None, avg_factor=None): + """Forward function. + Args: + pred (Tensor): Predicted joint representation of + classification and quality (IoU) estimation with shape (N, C), + C is the number of classes. + target (tuple([Tensor])): Target category label with shape + (N,) and target quality label with shape (N,). + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + + loss = self.loss_weight * quality_focal_loss( + pred, target, beta=self.beta, use_sigmoid=self.use_sigmoid) + + if weight is not None: + loss = loss * weight + if avg_factor is None: + if self.reduction == 'none': + return loss + elif self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: + # if reduction is mean, then average the loss by avg_factor + if self.reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif self.reduction != 'none': + raise ValueError( + 'avg_factor can not be used with reduction="sum"') + return loss + + +@register +@serializable +class DistributionFocalLoss(nn.Layer): + """Distribution Focal Loss (DFL) is a variant of `Generalized Focal Loss: + Learning Qualified and Distributed Bounding Boxes for Dense Object + Detection `_. + Args: + reduction (str): Options are `'none'`, `'mean'` and `'sum'`. + loss_weight (float): Loss weight of current loss. + """ + + def __init__(self, reduction='mean', loss_weight=1.0): + super(DistributionFocalLoss, self).__init__() + assert reduction in ('none', 'mean', 'sum') + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, pred, target, weight=None, avg_factor=None): + """Forward function. + Args: + pred (Tensor): Predicted general distribution of bounding + boxes (before softmax) with shape (N, n+1), n is the max value + of the integral set `{0, ..., n}` in paper. + target (Tensor): Target distance label for bounding boxes + with shape (N,). + weight (Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. 
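+
+        A minimal call sketch (shapes illustrative):
+
+            dfl = DistributionFocalLoss()
+            pred = paddle.rand([8, 17])     # distribution logits, reg_max=16
+            target = paddle.rand([8]) * 15  # continuous labels in [0, reg_max)
+            loss = dfl(pred, target)        # scalar under 'mean' reduction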
+ """ + loss = self.loss_weight * distribution_focal_loss(pred, target) + if weight is not None: + loss = loss * weight + if avg_factor is None: + if self.reduction == 'none': + return loss + elif self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: + # if reduction is mean, then average the loss by avg_factor + if self.reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif self.reduction != 'none': + raise ValueError( + 'avg_factor can not be used with reduction="sum"') + return loss diff --git a/ppdet/modeling/necks/__init__.py b/ppdet/modeling/necks/__init__.py index 7a7e3af40..742dff293 100644 --- a/ppdet/modeling/necks/__init__.py +++ b/ppdet/modeling/necks/__init__.py @@ -17,6 +17,7 @@ from . import yolo_fpn from . import hrfpn from . import ttf_fpn from . import centernet_fpn +from . import pan from .fpn import * from .yolo_fpn import * @@ -24,3 +25,4 @@ from .hrfpn import * from .ttf_fpn import * from .centernet_fpn import * from .blazeface_fpn import * +from .pan import * diff --git a/ppdet/modeling/necks/pan.py b/ppdet/modeling/necks/pan.py new file mode 100644 index 000000000..c5f5b90df --- /dev/null +++ b/ppdet/modeling/necks/pan.py @@ -0,0 +1,135 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn.initializer import XavierUniform +from paddle.regularizer import L2Decay +from ppdet.core.workspace import register, serializable +from ppdet.modeling.layers import ConvNormLayer +from ..shape_spec import ShapeSpec + +__all__ = ['PAN'] + + +@register +@serializable +class PAN(nn.Layer): + """ + Path Aggregation Network, see https://arxiv.org/abs/1803.01534 + + Args: + in_channels (list[int]): input channels of each level which can be + derived from the output shape of backbone by from_config + out_channel (list[int]): output channel of each level + spatial_scales (list[float]): the spatial scales between input feature + maps and original input image which can be derived from the output + shape of backbone by from_config + has_extra_convs (bool): whether to add extra conv to the last level. + default False + extra_stage (int): the number of extra stages added to the last level. + default 1 + use_c5 (bool): Whether to use c5 as the input of extra stage, + otherwise p5 is used. default True + norm_type (string|None): The normalization type in FPN module. If + norm_type is None, norm will not be used after conv and if + norm_type is string, bn, gn, sync_bn are available. default None + norm_decay (float): weight decay for normalization layer weights. + default 0. + freeze_norm (bool): whether to freeze normalization layer. + default False + relu_before_extra_convs (bool): whether to add relu before extra convs. 
default False
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channel,
+                 spatial_scales=[0.125, 0.0625, 0.03125],
+                 start_level=0,
+                 end_level=-1,
+                 norm_type=None,
+                 norm_decay=0.,
+                 freeze_norm=False):
+        super(PAN, self).__init__()
+        self.out_channel = out_channel
+        self.num_ins = len(in_channels)
+        self.spatial_scales = spatial_scales
+        if end_level == -1:
+            self.end_level = self.num_ins
+        else:
+            # if end_level < inputs, no extra level is allowed
+            assert end_level <= len(in_channels)
+            self.end_level = end_level
+        self.start_level = start_level
+        self.norm_type = norm_type
+        # kept as attributes so the ConvNormLayer branch below can use them
+        self.norm_decay = norm_decay
+        self.freeze_norm = freeze_norm
+        self.lateral_convs = []
+
+        for i in range(self.start_level, self.end_level):
+            in_c = in_channels[i - self.start_level]
+            if self.norm_type is not None:
+                lateral = self.add_sublayer(
+                    'pan_lateral' + str(i),
+                    ConvNormLayer(
+                        ch_in=in_c,
+                        ch_out=self.out_channel,
+                        filter_size=1,
+                        stride=1,
+                        norm_type=self.norm_type,
+                        norm_decay=self.norm_decay,
+                        freeze_norm=self.freeze_norm,
+                        initializer=XavierUniform(fan_out=in_c)))
+            else:
+                lateral = self.add_sublayer(
+                    'pan_lateral' + str(i),
+                    nn.Conv2D(
+                        in_channels=in_c,
+                        out_channels=self.out_channel,
+                        kernel_size=1,
+                        weight_attr=ParamAttr(
+                            initializer=XavierUniform(fan_out=in_c))))
+            self.lateral_convs.append(lateral)
+
+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        return {'in_channels': [i.channels for i in input_shape], }
+
+    def forward(self, body_feats):
+        laterals = []
+        for i, lateral_conv in enumerate(self.lateral_convs):
+            laterals.append(lateral_conv(body_feats[i + self.start_level]))
+        num_levels = len(laterals)
+        # top-down path: upsample deeper levels and add to shallower ones
+        for i in range(1, num_levels):
+            lvl = num_levels - i
+            upsample = F.interpolate(
+                laterals[lvl],
+                scale_factor=2.,
+                mode='bilinear', )
+            laterals[lvl - 1] += upsample
+
+        outs = [laterals[i] for i in range(num_levels)]
+        # bottom-up path aggregation: downsample and add back
+        for i in range(0, num_levels - 1):
+            outs[i + 1] += F.interpolate(
+                outs[i], scale_factor=0.5, mode='bilinear')
+
+        return outs
+
+    @property
+    def out_shape(self):
+        return [
+            ShapeSpec(
+                channels=self.out_channel, stride=1. / s)
+            for s in self.spatial_scales
+        ]
diff --git a/ppdet/modeling/tests/test_architectures.py b/ppdet/modeling/tests/test_architectures.py
index 95cb21203..693a27e1a 100644
--- a/ppdet/modeling/tests/test_architectures.py
+++ b/ppdet/modeling/tests/test_architectures.py
@@ -55,5 +55,15 @@ class TestSSD(TestFasterRCNN):
         self.cfg_file = 'configs/ssd/ssd_vgg16_300_240e_voc.yml'
 
 
+class TestGFL(TestFasterRCNN):
+    def set_config(self):
+        self.cfg_file = 'configs/gfl/gfl_r50_fpn_1x_coco.yml'
+
+
+class TestPicoDet(TestFasterRCNN):
+    def set_config(self):
+        self.cfg_file = 'configs/picodet/picodet_s_shufflenetv2_320_coco.yml'
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab