From 7d625608a174afabf88466b3d31d0313ec847582 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Mon, 18 Oct 2021 18:42:06 +0800 Subject: [PATCH] update PicoDet Architecture and config (#4323) * update PicoDet Architecture and config --- configs/gfl/_base_/gfl_r50_fpn.yml | 2 +- configs/gfl/_base_/gflv2_r50_fpn.yml | 2 +- configs/picodet/README.md | 38 +- configs/picodet/_base_/picodet_320_reader.yml | 9 +- configs/picodet/_base_/picodet_416_reader.yml | 7 +- configs/picodet/_base_/picodet_640_reader.yml | 41 ++ ...odet_mobilenetv3.yml => picodet_esnet.yml} | 40 +- .../_base_/picodet_shufflenetv2_1x.yml | 49 -- .../more_config/picodet_lcnet_416_coco.yml | 37 ++ .../picodet_mobilenetv3_416_coco.yml | 38 ++ .../more_config/picodet_r18_640_coco.yml | 39 ++ .../picodet_shufflenetv2_416_coco.yml | 38 ++ ...18_320_coco.yml => picodet_l_320_coco.yml} | 36 +- configs/picodet/picodet_l_416_coco.yml | 43 ++ configs/picodet/picodet_l_640_coco.yml | 43 ++ ...v2_320_coco.yml => picodet_m_320_coco.yml} | 4 +- ...v2_416_coco.yml => picodet_m_416_coco.yml} | 4 +- configs/picodet/picodet_m_mbv3_320_coco.yml | 16 - configs/picodet/picodet_m_mbv3_416_coco.yml | 25 - .../picodet_m_shufflenetv2_416_coco.yml | 38 -- ...v2_320_coco.yml => picodet_s_320_coco.yml} | 29 +- configs/picodet/picodet_s_416_coco.yml | 34 ++ configs/picodet/picodet_s_lcnet_320_coco.yml | 23 - configs/picodet/picodet_s_lcnet_416_coco.yml | 23 - configs/picodet/picodet_xs_lcnet_320_coco.yml | 23 - ppdet/engine/trainer.py | 10 +- ppdet/modeling/assigners/__init__.py | 2 + ppdet/modeling/assigners/simota_assigner.py | 272 ++++++++++ ppdet/modeling/backbones/__init__.py | 2 + ppdet/modeling/backbones/esnet.py | 290 ++++++++++ ppdet/modeling/bbox_utils.py | 94 ++++ ppdet/modeling/heads/__init__.py | 2 + ppdet/modeling/heads/gfl_head.py | 12 +- ppdet/modeling/heads/pico_head.py | 48 +- ppdet/modeling/heads/simota_head.py | 513 ++++++++++++++++++ ppdet/modeling/losses/varifocal_loss.py | 158 ++++++ ppdet/modeling/necks/__init__.py | 2 + ppdet/modeling/necks/csp_pan.py | 361 ++++++++++++ 38 files changed, 2125 insertions(+), 322 deletions(-) create mode 100644 configs/picodet/_base_/picodet_640_reader.yml rename configs/picodet/_base_/{picodet_mobilenetv3.yml => picodet_esnet.yml} (53%) delete mode 100644 configs/picodet/_base_/picodet_shufflenetv2_1x.yml create mode 100644 configs/picodet/more_config/picodet_lcnet_416_coco.yml create mode 100644 configs/picodet/more_config/picodet_mobilenetv3_416_coco.yml create mode 100644 configs/picodet/more_config/picodet_r18_640_coco.yml create mode 100644 configs/picodet/more_config/picodet_shufflenetv2_416_coco.yml rename configs/picodet/{picodet_l_r18_320_coco.yml => picodet_l_320_coco.yml} (54%) create mode 100644 configs/picodet/picodet_l_416_coco.yml create mode 100644 configs/picodet/picodet_l_640_coco.yml rename configs/picodet/{picodet_s_shufflenetv2_320_coco.yml => picodet_m_320_coco.yml} (68%) rename configs/picodet/{picodet_s_shufflenetv2_416_coco.yml => picodet_m_416_coco.yml} (68%) delete mode 100644 configs/picodet/picodet_m_mbv3_320_coco.yml delete mode 100644 configs/picodet/picodet_m_mbv3_416_coco.yml delete mode 100644 configs/picodet/picodet_m_shufflenetv2_416_coco.yml rename configs/picodet/{picodet_m_shufflenetv2_320_coco.yml => picodet_s_320_coco.yml} (50%) create mode 100644 configs/picodet/picodet_s_416_coco.yml delete mode 100644 configs/picodet/picodet_s_lcnet_320_coco.yml delete mode 100644 configs/picodet/picodet_s_lcnet_416_coco.yml delete mode 100644 
configs/picodet/picodet_xs_lcnet_320_coco.yml create mode 100644 ppdet/modeling/assigners/simota_assigner.py create mode 100644 ppdet/modeling/backbones/esnet.py create mode 100644 ppdet/modeling/heads/simota_head.py create mode 100644 ppdet/modeling/losses/varifocal_loss.py create mode 100644 ppdet/modeling/necks/csp_pan.py diff --git a/configs/gfl/_base_/gfl_r50_fpn.yml b/configs/gfl/_base_/gfl_r50_fpn.yml index 8130b5ca8..488bec61e 100644 --- a/configs/gfl/_base_/gfl_r50_fpn.yml +++ b/configs/gfl/_base_/gfl_r50_fpn.yml @@ -32,7 +32,7 @@ GFLHead: fpn_stride: [8, 16, 32, 64, 128] prior_prob: 0.01 reg_max: 16 - loss_qfl: + loss_class: name: QualityFocalLoss use_sigmoid: True beta: 2.0 diff --git a/configs/gfl/_base_/gflv2_r50_fpn.yml b/configs/gfl/_base_/gflv2_r50_fpn.yml index 691dde035..e9708d86a 100644 --- a/configs/gfl/_base_/gflv2_r50_fpn.yml +++ b/configs/gfl/_base_/gflv2_r50_fpn.yml @@ -37,7 +37,7 @@ GFLHead: reg_topk: 4 reg_channels: 64 add_mean: True - loss_qfl: + loss_class: name: QualityFocalLoss use_sigmoid: False beta: 2.0 diff --git a/configs/picodet/README.md b/configs/picodet/README.md index 85cd197a0..b2c75a517 100644 --- a/configs/picodet/README.md +++ b/configs/picodet/README.md @@ -5,41 +5,29 @@ We developed a series of lightweight models, which we named `PicoDet`. Because of its excellent performance, it is very suitable for deployment on mobile or CPU. -Optimizing method of we use: -- [ATSS](https://arxiv.org/abs/1912.02424) -- [Generalized Focal Loss](https://arxiv.org/abs/2006.04388) -- Lr Cosine Decay and cycle-EMA -- lightweight head +- 🌟 Higher mAP: the **first** model within 1M parameters to reach 30+ mAP. +- 🚀 Lower latency: 114 FPS on mobile ARM CPU. +- 😊 Deploy friendly: supports PaddleLite/MNN/NCNN/OpenVINO and provides C++/Python/Android implementations. +- 😍 Advanced algorithms: adopts and improves on state-of-the-art techniques such as ESNet, CSP-PAN, and SimOTA with VFL (sketched below). + ## Requirements -- PaddlePaddle == 2.1.2 +- PaddlePaddle >= 2.1.2 - PaddleSlim >= 2.1.1 ## Coming soon -- [ ] Benchmark of PicoDet. -- [ ] deploy for most platforms, such as PaddleLite、MNN、ncnn、openvino etc. -- [ ] PicoDet-XS and PicoDet-L series of model. -- [ ] Slim for PicoDet. +- [ ] More model variants, such as smaller and larger models. +- [ ] Pretrained models for more scenarios. - [ ] More features in need.
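The "SimOTA with VFL" bullet above refers to using Varifocal Loss as the classification cost and training objective. As a rough illustration of the idea, here is a minimal sketch written from the VarifocalNet paper's formula, not the `varifocal_loss.py` implementation this patch adds; the function name and default hyperparameters are illustrative:

```python
import paddle
import paddle.nn.functional as F

def varifocal_loss_sketch(pred_logits, quality_targets, alpha=0.75, gamma=2.0):
    # Positives carry an IoU-aware quality target q in (0, 1] and are
    # weighted by q itself; negatives (q == 0) are focally down-weighted
    # by alpha * p^gamma, per the VarifocalNet formulation.
    pred = F.sigmoid(pred_logits)
    weight = paddle.where(quality_targets > 0.0, quality_targets,
                          alpha * pred.pow(gamma))
    loss = F.binary_cross_entropy(pred, quality_targets,
                                  reduction='none') * weight
    return loss.sum()
```

The quality target q is typically the IoU between a predicted box and its assigned ground truth, which is exactly the per-prior quantity the SimOTA assigner in this patch produces.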
## Model Zoo -### PicoDet-S - -| Backbone | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config | +| Model | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | log | config | | :------------------------ | :-------: | :-------: | :------: | :---: | :---: | :---: | :------------: | :-------------------------------------------------: | :-----: | :-----: | -| ShuffleNetv2-1x | 320*320 | 280e | 22.8 | 37.7 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_320_coco.yml) | -| ShuffleNetv2-1x | 416*416 | 280e | 25.3 | 41.1 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_416_coco.yml) | - - -### PicoDet-M - -| Backbone | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config | -| :------------------------ | :-------: | :-------: | :-----------: | :---: | :---: | :---: | :-----: | :-------------------------------------------------: | :-----: | -| ShuffleNetv2-1.5x | 320*320 | 280e | 25.3 | 41.2 | -- | 8.1M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_shufflenetv2_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_320_coco.yml) | -| MobileNetv3-large-1x | 320*320 | 280e | 26.7 | 44.1 | -- | 11.6M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_mbv3_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_320_coco.yml) | -| ShuffleNetv2-1.5x | 416*416 | 280e | 28.0 | 44.3 | -- | 8.1M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_shufflenetv2_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_416_coco.yml) | -| MobileNetv3-large-1x | 416*416 | 280e | 29.3 | 47.2 | -- | 11.6M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_mbv3_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_416_coco.yml) | +| PicoDet-S | 320*320 | 300e | 27.1 | 41.4 | -- | 3.9M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_320_coco.yml) | +| PicoDet-S | 416*416 | 300e | 30.6 | 45.5 | -- | 3.9M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_416_coco.log) | 
[config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco.yml) | +| PicoDet-M | 320*320 | 300e | - | 41.2 | -- | 8.4M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_320_coco.yml) | +| PicoDet-M | 416*416 | 300e | 34.3 | 49.8 | -- | 8.4M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_416_coco.yml) | **Notes:** diff --git a/configs/picodet/_base_/picodet_320_reader.yml b/configs/picodet/_base_/picodet_320_reader.yml index 469184529..2ce5bca66 100644 --- a/configs/picodet/_base_/picodet_320_reader.yml +++ b/configs/picodet/_base_/picodet_320_reader.yml @@ -1,4 +1,4 @@ -worker_num: 8 +worker_num: 6 TrainReader: sample_transforms: - Decode: {} @@ -9,13 +9,10 @@ TrainReader: - BatchRandomResize: {target_size: [256, 288, 320, 352, 384], random_size: True, random_interp: True, keep_ratio: False} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} - - Gt2GFLTarget: - downsample_ratios: [8, 16, 32] - grid_cell_scale: 5 - cell_offset: 0.5 batch_size: 128 shuffle: true drop_last: true + collate_batch: false EvalReader: @@ -32,7 +29,7 @@ EvalReader: TestReader: inputs_def: - image_shape: [3, 320, 320] + image_shape: [1, 3, 320, 320] sample_transforms: - Decode: {} - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} diff --git a/configs/picodet/_base_/picodet_416_reader.yml b/configs/picodet/_base_/picodet_416_reader.yml index 58b6607dc..12070a4be 100644 --- a/configs/picodet/_base_/picodet_416_reader.yml +++ b/configs/picodet/_base_/picodet_416_reader.yml @@ -9,13 +9,10 @@ TrainReader: - BatchRandomResize: {target_size: [352, 384, 416, 448, 480], random_size: True, random_interp: True, keep_ratio: False} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} - - Gt2GFLTarget: - downsample_ratios: [8, 16, 32] - grid_cell_scale: 5 - cell_offset: 0.5 batch_size: 80 shuffle: true drop_last: true + collate_batch: false EvalReader: @@ -32,7 +29,7 @@ EvalReader: TestReader: inputs_def: - image_shape: [3, 416, 416] + image_shape: [1, 3, 416, 416] sample_transforms: - Decode: {} - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} diff --git a/configs/picodet/_base_/picodet_640_reader.yml b/configs/picodet/_base_/picodet_640_reader.yml new file mode 100644 index 000000000..a931f2a76 --- /dev/null +++ b/configs/picodet/_base_/picodet_640_reader.yml @@ -0,0 +1,41 @@ +worker_num: 6 +TrainReader: + sample_transforms: + - Decode: {} + - RandomCrop: {} + - RandomFlip: {prob: 0.5} + - RandomDistort: {} + batch_transforms: + - BatchRandomResize: {target_size: [576, 608, 640, 672, 704], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_size: 56 + shuffle: true + drop_last: true + collate_batch: false + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [640, 640], keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - 
Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 8 + shuffle: false + + +TestReader: + inputs_def: + image_shape: [1, 3, 640, 640] + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [640, 640], keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: false diff --git a/configs/picodet/_base_/picodet_mobilenetv3.yml b/configs/picodet/_base_/picodet_esnet.yml similarity index 53% rename from configs/picodet/_base_/picodet_mobilenetv3.yml rename to configs/picodet/_base_/picodet_esnet.yml index 934fe0a5c..aa099fca1 100644 --- a/configs/picodet/_base_/picodet_mobilenetv3.yml +++ b/configs/picodet/_base_/picodet_esnet.yml @@ -1,41 +1,41 @@ architecture: PicoDet -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams PicoDet: - backbone: MobileNetV3 - neck: PAN + backbone: ESNet + neck: CSPPAN head: PicoHead -MobileNetV3: - model_name: large +ESNet: scale: 1.0 - with_extra_blocks: false - extra_block_filters: [] - feature_maps: [7, 13, 16] + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75] -PAN: - out_channel: 128 - start_level: 0 - end_level: 3 - spatial_scales: [0.125, 0.0625, 0.03125] +CSPPAN: + out_channels: 128 + use_depthwise: True + num_csp_blocks: 1 + num_features: 4 PicoHead: conv_feat: name: PicoFeat feat_in: 128 feat_out: 128 - num_convs: 2 + num_convs: 4 + num_fpn_stride: 4 norm_type: bn share_cls_reg: True - fpn_stride: [8, 16, 32] + fpn_stride: [8, 16, 32, 64] feat_in_chan: 128 prior_prob: 0.01 reg_max: 7 cell_offset: 0.5 - loss_qfl: - name: QualityFocalLoss + loss_class: + name: VarifocalLoss use_sigmoid: True - beta: 2.0 + iou_weighted: True loss_weight: 1.0 loss_dfl: name: DistributionFocalLoss @@ -43,6 +43,10 @@ PicoHead: loss_bbox: name: GIoULoss loss_weight: 2.0 + assigner: + name: SimOTAAssigner + candidate_topk: 10 + iou_weight: 6 nms: name: MultiClassNMS nms_top_k: 1000 diff --git a/configs/picodet/_base_/picodet_shufflenetv2_1x.yml b/configs/picodet/_base_/picodet_shufflenetv2_1x.yml deleted file mode 100644 index 25517f9fb..000000000 --- a/configs/picodet/_base_/picodet_shufflenetv2_1x.yml +++ /dev/null @@ -1,49 +0,0 @@ -architecture: PicoDet -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_0_pretrained.pdparams - -PicoDet: - backbone: ShuffleNetV2 - neck: PAN - head: PicoHead - -ShuffleNetV2: - scale: 1.0 - feature_maps: [5, 13, 17] - act: leaky_relu - -PAN: - out_channel: 96 - start_level: 0 - end_level: 3 - spatial_scales: [0.125, 0.0625, 0.03125] - -PicoHead: - conv_feat: - name: PicoFeat - feat_in: 96 - feat_out: 96 - num_convs: 2 - norm_type: bn - share_cls_reg: True - fpn_stride: [8, 16, 32] - feat_in_chan: 96 - prior_prob: 0.01 - reg_max: 7 - cell_offset: 0.5 - loss_qfl: - name: QualityFocalLoss - use_sigmoid: True - beta: 2.0 - loss_weight: 1.0 - loss_dfl: - name: DistributionFocalLoss - loss_weight: 0.25 - loss_bbox: - name: GIoULoss - loss_weight: 2.0 - nms: - name: MultiClassNMS - nms_top_k: 1000 - keep_top_k: 100 - score_threshold: 0.025 - nms_threshold: 0.6 diff --git a/configs/picodet/more_config/picodet_lcnet_416_coco.yml 
b/configs/picodet/more_config/picodet_lcnet_416_coco.yml new file mode 100644 index 000000000..cdfc59919 --- /dev/null +++ b/configs/picodet/more_config/picodet_lcnet_416_coco.yml @@ -0,0 +1,37 @@ +_BASE_: [ + '../../datasets/coco_detection.yml', + '../../runtime.yml', + '../_base_/picodet_esnet.yml', + '../_base_/optimizer_300e.yml', + '../_base_/picodet_416_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_0_pretrained.pdparams +weights: output/picodet_lcnet_416_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +PicoDet: + backbone: LCNet + neck: CSPPAN + head: PicoHead + +LCNet: + scale: 1.0 + feature_maps: [3, 4, 5] + +CSPPAN: + out_channels: 96 + +PicoHead: + conv_feat: + name: PicoFeat + feat_in: 96 + feat_out: 96 + num_convs: 2 + num_fpn_stride: 4 + norm_type: bn + share_cls_reg: True + feat_in_chan: 96 diff --git a/configs/picodet/more_config/picodet_mobilenetv3_416_coco.yml b/configs/picodet/more_config/picodet_mobilenetv3_416_coco.yml new file mode 100644 index 000000000..fb4cc96ee --- /dev/null +++ b/configs/picodet/more_config/picodet_mobilenetv3_416_coco.yml @@ -0,0 +1,38 @@ +_BASE_: [ + '../../datasets/coco_detection.yml', + '../../runtime.yml', + '../_base_/picodet_esnet.yml', + '../_base_/optimizer_300e.yml', + '../_base_/picodet_416_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams +weights: output/picodet_mobilenetv3_416_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +PicoDet: + backbone: MobileNetV3 + neck: CSPPAN + head: PicoHead + +MobileNetV3: + model_name: large + scale: 1.0 + with_extra_blocks: false + extra_block_filters: [] + feature_maps: [7, 13, 16] + +TrainReader: + batch_size: 56 + +LearningRate: + base_lr: 0.3 + schedulers: + - !CosineDecay + max_epochs: 300 + - !LinearWarmup + start_factor: 0.1 + steps: 300 diff --git a/configs/picodet/more_config/picodet_r18_640_coco.yml b/configs/picodet/more_config/picodet_r18_640_coco.yml new file mode 100644 index 000000000..276a933d9 --- /dev/null +++ b/configs/picodet/more_config/picodet_r18_640_coco.yml @@ -0,0 +1,39 @@ +_BASE_: [ + '../../datasets/coco_detection.yml', + '../../runtime.yml', + '../_base_/picodet_esnet.yml', + '../_base_/optimizer_300e.yml', + '../_base_/picodet_640_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams +weights: output/picodet_r18_640_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +PicoDet: + backbone: ResNet + neck: CSPPAN + head: PicoHead + +ResNet: + depth: 18 + variant: d + return_idx: [1, 2, 3] + freeze_at: -1 + freeze_norm: false + norm_decay: 0. 
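The TrainReader/LearningRate block that follows (like the one in the MobileNetV3 config above) pairs a 300-step LinearWarmup with start_factor 0.1 against a 300-epoch CosineDecay. A minimal sketch of the resulting per-step schedule; the step arithmetic and total_steps value are illustrative assumptions, not ppdet's exact scheduler composition:

```python
import math

def picodet_lr_sketch(step, base_lr=0.3, warmup_steps=300,
                      start_factor=0.1, total_steps=300 * 2000):
    # total_steps ~= epochs * iters_per_epoch; 2000 is a placeholder.
    if step < warmup_steps:
        # LinearWarmup: ramp from start_factor * base_lr up to base_lr.
        frac = step / warmup_steps
        return base_lr * (start_factor + (1.0 - start_factor) * frac)
    # CosineDecay: anneal from base_lr toward 0 over the remaining steps.
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```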
+ +TrainReader: + batch_size: 56 + +LearningRate: + base_lr: 0.3 + schedulers: + - !CosineDecay + max_epochs: 300 + - !LinearWarmup + start_factor: 0.1 + steps: 300 diff --git a/configs/picodet/more_config/picodet_shufflenetv2_416_coco.yml b/configs/picodet/more_config/picodet_shufflenetv2_416_coco.yml new file mode 100644 index 000000000..cefcf18ef --- /dev/null +++ b/configs/picodet/more_config/picodet_shufflenetv2_416_coco.yml @@ -0,0 +1,38 @@ +_BASE_: [ + '../../datasets/coco_detection.yml', + '../../runtime.yml', + '../_base_/picodet_esnet.yml', + '../_base_/optimizer_300e.yml', + '../_base_/picodet_416_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_0_pretrained.pdparams +weights: output/picodet_shufflenetv2_416_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +PicoDet: + backbone: ShuffleNetV2 + neck: CSPPAN + head: PicoHead + +ShuffleNetV2: + scale: 1.0 + feature_maps: [5, 13, 17] + act: leaky_relu + +CSPPAN: + out_channels: 96 + +PicoHead: + conv_feat: + name: PicoFeat + feat_in: 96 + feat_out: 96 + num_convs: 2 + num_fpn_stride: 4 + norm_type: bn + share_cls_reg: True + feat_in_chan: 96 diff --git a/configs/picodet/picodet_l_r18_320_coco.yml b/configs/picodet/picodet_l_320_coco.yml similarity index 54% rename from configs/picodet/picodet_l_r18_320_coco.yml rename to configs/picodet/picodet_l_320_coco.yml index d627a975b..23320cd36 100644 --- a/configs/picodet/picodet_l_r18_320_coco.yml +++ b/configs/picodet/picodet_l_320_coco.yml @@ -1,45 +1,33 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', - '_base_/picodet_mbv3_0_5x.yml', + '_base_/picodet_esnet.yml', '_base_/optimizer_300e.yml', '_base_/picodet_320_reader.yml', ] -weights: output/picodet_l_r18_320_coco/model_final -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_25_pretrained.pdparams +weights: output/picodet_l_320_coco/model_final find_unused_parameters: True use_ema: true cycle_epoch: 40 snapshot_epoch: 10 -PicoDet: - backbone: ResNet - neck: PAN - head: PicoHead - -ResNet: - depth: 18 - variant: d - return_idx: [1, 2, 3] - freeze_at: -1 - freeze_norm: false - norm_decay: 0. 
- -PAN: - out_channel: 128 - start_level: 0 - end_level: 3 - spatial_scales: [0.125, 0.0625, 0.03125] +ESNet: + scale: 1.25 + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75] PicoHead: conv_feat: name: PicoFeat feat_in: 128 feat_out: 128 - num_convs: 2 + num_convs: 4 + num_fpn_stride: 4 norm_type: bn - share_cls_reg: True + share_cls_reg: False feat_in_chan: 128 TrainReader: @@ -49,7 +37,7 @@ LearningRate: base_lr: 0.3 schedulers: - !CosineDecay - max_epochs: 280 + max_epochs: 300 - !LinearWarmup start_factor: 0.1 steps: 300 diff --git a/configs/picodet/picodet_l_416_coco.yml b/configs/picodet/picodet_l_416_coco.yml new file mode 100644 index 000000000..720125179 --- /dev/null +++ b/configs/picodet/picodet_l_416_coco.yml @@ -0,0 +1,43 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/picodet_esnet.yml', + '_base_/optimizer_300e.yml', + '_base_/picodet_416_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_25_pretrained.pdparams +weights: output/picodet_l_416_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +ESNet: + scale: 1.25 + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75] + +PicoHead: + conv_feat: + name: PicoFeat + feat_in: 128 + feat_out: 128 + num_convs: 4 + num_fpn_stride: 4 + norm_type: bn + share_cls_reg: False + feat_in_chan: 128 + +TrainReader: + batch_size: 48 + +LearningRate: + base_lr: 0.3 + schedulers: + - !CosineDecay + max_epochs: 300 + - !LinearWarmup + start_factor: 0.1 + steps: 300 diff --git a/configs/picodet/picodet_l_640_coco.yml b/configs/picodet/picodet_l_640_coco.yml new file mode 100644 index 000000000..94cab2548 --- /dev/null +++ b/configs/picodet/picodet_l_640_coco.yml @@ -0,0 +1,43 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/picodet_esnet.yml', + '_base_/optimizer_300e.yml', + '_base_/picodet_640_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_25_pretrained.pdparams +weights: output/picodet_l_640_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +ESNet: + scale: 1.25 + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75] + +PicoHead: + conv_feat: + name: PicoFeat + feat_in: 128 + feat_out: 128 + num_convs: 4 + num_fpn_stride: 4 + norm_type: bn + share_cls_reg: False + feat_in_chan: 128 + +TrainReader: + batch_size: 48 + +LearningRate: + base_lr: 0.3 + schedulers: + - !CosineDecay + max_epochs: 300 + - !LinearWarmup + start_factor: 0.1 + steps: 300 diff --git a/configs/picodet/picodet_s_shufflenetv2_320_coco.yml b/configs/picodet/picodet_m_320_coco.yml similarity index 68% rename from configs/picodet/picodet_s_shufflenetv2_320_coco.yml rename to configs/picodet/picodet_m_320_coco.yml index 009e994aa..54fc6e605 100644 --- a/configs/picodet/picodet_s_shufflenetv2_320_coco.yml +++ b/configs/picodet/picodet_m_320_coco.yml @@ -1,12 +1,12 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', - '_base_/picodet_shufflenetv2_1x.yml', + '_base_/picodet_esnet.yml', '_base_/optimizer_300e.yml', '_base_/picodet_320_reader.yml', ] -weights: output/picodet_s_shufflenetv2_320_coco/model_final +weights: output/picodet_m_320_coco/model_final 
find_unused_parameters: True use_ema: true cycle_epoch: 40 diff --git a/configs/picodet/picodet_s_shufflenetv2_416_coco.yml b/configs/picodet/picodet_m_416_coco.yml similarity index 68% rename from configs/picodet/picodet_s_shufflenetv2_416_coco.yml rename to configs/picodet/picodet_m_416_coco.yml index 9c551f7a6..c53b6dcc5 100644 --- a/configs/picodet/picodet_s_shufflenetv2_416_coco.yml +++ b/configs/picodet/picodet_m_416_coco.yml @@ -1,12 +1,12 @@ _BASE_: [ '../datasets/coco_detection.yml', '../runtime.yml', - '_base_/picodet_shufflenetv2_1x.yml', + '_base_/picodet_esnet.yml', '_base_/optimizer_300e.yml', '_base_/picodet_416_reader.yml', ] -weights: output/picodet_s_shufflenetv2_416_coco/model_final +weights: output/picodet_m_416_coco/model_final find_unused_parameters: True use_ema: true cycle_epoch: 40 diff --git a/configs/picodet/picodet_m_mbv3_320_coco.yml b/configs/picodet/picodet_m_mbv3_320_coco.yml deleted file mode 100644 index 9e4055b8a..000000000 --- a/configs/picodet/picodet_m_mbv3_320_coco.yml +++ /dev/null @@ -1,16 +0,0 @@ -_BASE_: [ - '../datasets/coco_detection.yml', - '../runtime.yml', - '_base_/picodet_mobilenetv3.yml', - '_base_/optimizer_300e.yml', - '_base_/picodet_320_reader.yml', -] - -weights: output/picodet_m_mbv3_320_coco/model_final -find_unused_parameters: True -use_ema: true -cycle_epoch: 40 -snapshot_epoch: 10 - -TrainReader: - batch_size: 88 diff --git a/configs/picodet/picodet_m_mbv3_416_coco.yml b/configs/picodet/picodet_m_mbv3_416_coco.yml deleted file mode 100644 index f2e9653dd..000000000 --- a/configs/picodet/picodet_m_mbv3_416_coco.yml +++ /dev/null @@ -1,25 +0,0 @@ -_BASE_: [ - '../datasets/coco_detection.yml', - '../runtime.yml', - '_base_/picodet_mobilenetv3.yml', - '_base_/optimizer_300e.yml', - '_base_/picodet_416_reader.yml', -] - -weights: output/picodet_m_mbv3_416_coco/model_final -find_unused_parameters: True -use_ema: true -cycle_epoch: 40 -snapshot_epoch: 10 - -TrainReader: - batch_size: 56 - -LearningRate: - base_lr: 0.3 - schedulers: - - !CosineDecay - max_epochs: 280 - - !LinearWarmup - start_factor: 0.1 - steps: 300 diff --git a/configs/picodet/picodet_m_shufflenetv2_416_coco.yml b/configs/picodet/picodet_m_shufflenetv2_416_coco.yml deleted file mode 100644 index 0726ab8e2..000000000 --- a/configs/picodet/picodet_m_shufflenetv2_416_coco.yml +++ /dev/null @@ -1,38 +0,0 @@ -_BASE_: [ - '../datasets/coco_detection.yml', - '../runtime.yml', - '_base_/picodet_shufflenetv2_1x.yml', - '_base_/optimizer_300e.yml', - '_base_/picodet_416_reader.yml', -] - -weights: output/picodet_m_shufflenetv2_416_coco/model_final -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_5_pretrained.pdparams -find_unused_parameters: True -use_ema: true -cycle_epoch: 40 -snapshot_epoch: 10 - -ShuffleNetV2: - scale: 1.5 - feature_maps: [5, 13, 17] - act: leaky_relu - -PAN: - out_channel: 128 - start_level: 0 - end_level: 3 - spatial_scales: [0.125, 0.0625, 0.03125] - -PicoHead: - conv_feat: - name: PicoFeat - feat_in: 128 - feat_out: 128 - num_convs: 2 - norm_type: bn - share_cls_reg: True - feat_in_chan: 128 - -TrainReader: - batch_size: 88 diff --git a/configs/picodet/picodet_m_shufflenetv2_320_coco.yml b/configs/picodet/picodet_s_320_coco.yml similarity index 50% rename from configs/picodet/picodet_m_shufflenetv2_320_coco.yml rename to configs/picodet/picodet_s_320_coco.yml index 168b36dfc..c54ab14df 100644 --- a/configs/picodet/picodet_m_shufflenetv2_320_coco.yml +++ b/configs/picodet/picodet_s_320_coco.yml @@ -1,35 +1,34 @@ _BASE_: 
[ '../datasets/coco_detection.yml', '../runtime.yml', - '_base_/picodet_shufflenetv2_1x.yml', + '_base_/picodet_esnet.yml', '_base_/optimizer_300e.yml', '_base_/picodet_320_reader.yml', ] -weights: output/picodet_m_shufflenetv2_320_coco/model_final -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_5_pretrained.pdparams +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x0_75_pretrained.pdparams +weights: output/picodet_s_320_coco/model_final find_unused_parameters: True use_ema: true cycle_epoch: 40 snapshot_epoch: 10 -ShuffleNetV2: - scale: 1.5 - feature_maps: [5, 13, 17] - act: leaky_relu +ESNet: + scale: 0.75 + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] -PAN: - out_channel: 128 - start_level: 0 - end_level: 3 - spatial_scales: [0.125, 0.0625, 0.03125] +CSPPAN: + out_channels: 96 PicoHead: conv_feat: name: PicoFeat - feat_in: 128 - feat_out: 128 + feat_in: 96 + feat_out: 96 num_convs: 2 + num_fpn_stride: 4 norm_type: bn share_cls_reg: True - feat_in_chan: 128 + feat_in_chan: 96 diff --git a/configs/picodet/picodet_s_416_coco.yml b/configs/picodet/picodet_s_416_coco.yml new file mode 100644 index 000000000..f28e166cc --- /dev/null +++ b/configs/picodet/picodet_s_416_coco.yml @@ -0,0 +1,34 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/picodet_esnet.yml', + '_base_/optimizer_300e.yml', + '_base_/picodet_416_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x0_75_pretrained.pdparams +weights: output/picodet_s_416_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 + +ESNet: + scale: 0.75 + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + +CSPPAN: + out_channels: 96 + +PicoHead: + conv_feat: + name: PicoFeat + feat_in: 96 + feat_out: 96 + num_convs: 2 + num_fpn_stride: 4 + norm_type: bn + share_cls_reg: True + feat_in_chan: 96 diff --git a/configs/picodet/picodet_s_lcnet_320_coco.yml b/configs/picodet/picodet_s_lcnet_320_coco.yml deleted file mode 100644 index 762ae1d90..000000000 --- a/configs/picodet/picodet_s_lcnet_320_coco.yml +++ /dev/null @@ -1,23 +0,0 @@ -_BASE_: [ - '../datasets/coco_detection.yml', - '../runtime.yml', - '_base_/picodet_shufflenetv2_1x.yml', - '_base_/optimizer_300e.yml', - '_base_/picodet_320_reader.yml', -] - -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_0_pretrained.pdparams -weights: output/picodet_s_lcnet_320_coco/model_final -find_unused_parameters: True -use_ema: true -cycle_epoch: 40 -snapshot_epoch: 10 - -PicoDet: - backbone: LCNet - neck: PAN - head: PicoHead - -LCNet: - scale: 1.0 - feature_maps: [3, 4, 5] diff --git a/configs/picodet/picodet_s_lcnet_416_coco.yml b/configs/picodet/picodet_s_lcnet_416_coco.yml deleted file mode 100644 index f638b2a48..000000000 --- a/configs/picodet/picodet_s_lcnet_416_coco.yml +++ /dev/null @@ -1,23 +0,0 @@ -_BASE_: [ - '../datasets/coco_detection.yml', - '../runtime.yml', - '_base_/picodet_shufflenetv2_1x.yml', - '_base_/optimizer_300e.yml', - '_base_/picodet_416_reader.yml', -] - -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_0_pretrained.pdparams -weights: output/picodet_s_lcnet_416_coco/model_final -find_unused_parameters: True -use_ema: true -cycle_epoch: 40 -snapshot_epoch: 10 - -PicoDet: - backbone: 
LCNet - neck: PAN - head: PicoHead - -LCNet: - scale: 1.0 - feature_maps: [3, 4, 5] diff --git a/configs/picodet/picodet_xs_lcnet_320_coco.yml b/configs/picodet/picodet_xs_lcnet_320_coco.yml deleted file mode 100644 index ab286963d..000000000 --- a/configs/picodet/picodet_xs_lcnet_320_coco.yml +++ /dev/null @@ -1,23 +0,0 @@ -_BASE_: [ - '../datasets/coco_detection.yml', - '../runtime.yml', - '_base_/picodet_shufflenetv2_1x.yml', - '_base_/optimizer_280e.yml', - '_base_/picodet_320_reader.yml', -] - -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x0_25_pretrained.pdparams -weights: output/picodet_s_shufflenetv2_320_coco/model_final -find_unused_parameters: True -use_ema: true -cycle_epoch: 40 -snapshot_epoch: 10 - -PicoDet: - backbone: LCNet - neck: PAN - head: PicoHead - -LCNet: - scale: 0.25 - feature_maps: [3, 4, 5] diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index a97d0f16a..436c22f76 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -543,6 +543,8 @@ class Trainer(object): def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True): image_shape = None + im_shape = [None, 2] + scale_factor = [None, 2] if self.cfg.architecture in MOT_ARCH: test_reader_name = 'TestMOTReader' else: @@ -553,8 +555,12 @@ class Trainer(object): # set image_shape=[None, 3, -1, -1] as default if image_shape is None: image_shape = [None, 3, -1, -1] + if len(image_shape) == 3: image_shape = [None] + image_shape + else: + im_shape = [image_shape[0], 2] + scale_factor = [image_shape[0], 2] if hasattr(self.model, 'deploy'): self.model.deploy = True @@ -571,9 +577,9 @@ class Trainer(object): "image": InputSpec( shape=image_shape, name='image'), "im_shape": InputSpec( - shape=[None, 2], name='im_shape'), + shape=im_shape, name='im_shape'), "scale_factor": InputSpec( - shape=[None, 2], name='scale_factor') + shape=scale_factor, name='scale_factor') }] if self.cfg.architecture == 'DeepSORT': input_spec[0].update({ diff --git a/ppdet/modeling/assigners/__init__.py b/ppdet/modeling/assigners/__init__.py index 108ac7ecd..be5bb04d3 100644 --- a/ppdet/modeling/assigners/__init__.py +++ b/ppdet/modeling/assigners/__init__.py @@ -15,7 +15,9 @@ from . import utils from . import task_aligned_assigner from . import atss_assigner +from . import simota_assigner from .utils import * from .task_aligned_assigner import * from .atss_assigner import * +from .simota_assigner import * diff --git a/ppdet/modeling/assigners/simota_assigner.py b/ppdet/modeling/assigners/simota_assigner.py new file mode 100644 index 000000000..de4e89c8c --- /dev/null +++ b/ppdet/modeling/assigners/simota_assigner.py @@ -0,0 +1,272 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
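The new `simota_assigner.py` below implements SimOTA label assignment: each ground truth receives a dynamic number k of positive priors, where k is derived from the sum of that ground truth's top-10 candidate IoUs, and priors are then matched by lowest cost. A toy numpy sketch of just the dynamic-k step, with illustrative names; the real implementation follows in the diff:

```python
import numpy as np

def dynamic_k_matching_sketch(cost, ious, candidate_topk=10):
    """cost, ious: [num_priors, num_gt]. Returns a 0/1 matching matrix."""
    num_priors, num_gt = cost.shape
    matching = np.zeros_like(cost)
    # top-k IoUs per ground truth; their sum (at least 1) becomes k
    topk = np.sort(ious, axis=0)[::-1][:candidate_topk]
    dynamic_ks = np.clip(topk.sum(0).astype(int), 1, None)
    for g in range(num_gt):
        # assign the k lowest-cost priors to ground truth g
        pos = np.argsort(cost[:, g])[:dynamic_ks[g]]
        matching[pos, g] = 1.0
    # a prior matched to several ground truths keeps only its cheapest match
    multi = matching.sum(1) > 1
    if multi.any():
        best = cost[multi].argmin(1)
        matching[multi] = 0.0
        matching[multi, best] = 1.0
    return matching
```

The ppdet version additionally restricts candidates with a center prior (only priors inside a ground-truth box or its center region participate) and builds the cost matrix from Varifocal and GIoU losses, as the code below shows.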
+ +import paddle +import numpy as np +import paddle.nn.functional as F +import paddle.nn as nn + +from ppdet.modeling.losses.varifocal_loss import varifocal_loss +from ppdet.modeling.bbox_utils import batch_bbox_overlaps +from ppdet.core.workspace import register + + +@register +class SimOTAAssigner(object): + """Computes matching between predictions and ground truth. + + Args: + center_radius (int | float, optional): Ground truth center size + to judge whether a prior is in center. Default 2.5. + candidate_topk (int, optional): The candidate top-k which used to + get top-k ious to calculate dynamic-k. Default 10. + iou_weight (int | float, optional): The scale factor for regression + iou cost. Default 3.0. + cls_weight (int | float, optional): The scale factor for classification + cost. Default 1.0. + num_classes (int): The num_classes of dataset. + use_vfl (int): Whether to use varifocal_loss when calculating the cost matrix. + """ + __shared__ = ['num_classes'] + + def __init__(self, + center_radius=2.5, + candidate_topk=10, + iou_weight=3.0, + cls_weight=1.0, + num_classes=80, + use_vfl=True): + self.center_radius = center_radius + self.candidate_topk = candidate_topk + self.iou_weight = iou_weight + self.cls_weight = cls_weight + self.num_classes = num_classes + self.use_vfl = use_vfl + + def get_in_gt_and_in_center_info(self, priors, gt_bboxes): + num_gt = gt_bboxes.shape[0] + + repeated_x = priors[:, 0].unsqueeze(1).tile([1, num_gt]) + repeated_y = priors[:, 1].unsqueeze(1).tile([1, num_gt]) + repeated_stride_x = priors[:, 2].unsqueeze(1).tile([1, num_gt]) + repeated_stride_y = priors[:, 3].unsqueeze(1).tile([1, num_gt]) + + # is prior centers in gt bboxes, shape: [n_prior, n_gt] + l_ = repeated_x - gt_bboxes[:, 0] + t_ = repeated_y - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - repeated_x + b_ = gt_bboxes[:, 3] - repeated_y + + deltas = paddle.stack([l_, t_, r_, b_], axis=1) + is_in_gts = deltas.min(axis=1) > 0 + is_in_gts_all = is_in_gts.sum(axis=1) > 0 + + # is prior centers in gt centers + gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 + gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 + ct_box_l = gt_cxs - self.center_radius * repeated_stride_x + ct_box_t = gt_cys - self.center_radius * repeated_stride_y + ct_box_r = gt_cxs + self.center_radius * repeated_stride_x + ct_box_b = gt_cys + self.center_radius * repeated_stride_y + + cl_ = repeated_x - ct_box_l + ct_ = repeated_y - ct_box_t + cr_ = ct_box_r - repeated_x + cb_ = ct_box_b - repeated_y + + ct_deltas = paddle.stack([cl_, ct_, cr_, cb_], axis=1) + is_in_cts = ct_deltas.min(axis=1) > 0 + is_in_cts_all = is_in_cts.sum(axis=1) > 0 + + # in boxes or in centers, shape: [num_priors] + is_in_gts_or_centers = paddle.logical_or(is_in_gts_all, is_in_cts_all) + + is_in_gts_or_centers_inds = paddle.nonzero( + is_in_gts_or_centers).squeeze(1) + + # both in boxes and centers, shape: [num_fg, num_gt] + is_in_boxes_and_centers = paddle.logical_and( + paddle.gather( + is_in_gts.cast('int'), is_in_gts_or_centers_inds, + axis=0).cast('bool'), + paddle.gather( + is_in_cts.cast('int'), is_in_gts_or_centers_inds, + axis=0).cast('bool')) + return is_in_gts_or_centers, is_in_boxes_and_centers + + def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask): + matching_matrix = np.zeros_like(cost.numpy()) + # select candidate topk ious for dynamic-k calculation + topk_ious, _ = paddle.topk(pairwise_ious, self.candidate_topk, axis=0) + # calculate dynamic k for each gt + dynamic_ks = paddle.clip(topk_ious.sum(0).cast('int'), min=1) + 
for gt_idx in range(num_gt): + _, pos_idx = paddle.topk( + cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False) + matching_matrix[:, gt_idx][pos_idx.numpy()] = 1.0 + + del topk_ious, dynamic_ks, pos_idx + + prior_match_gt_mask = matching_matrix.sum(1) > 1 + if prior_match_gt_mask.sum() > 0: + cost = cost.numpy() + cost_argmin = np.argmin(cost[prior_match_gt_mask, :], axis=1) + matching_matrix[prior_match_gt_mask, :] *= 0.0 + matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0 + # get foreground mask inside box and center prior + fg_mask_inboxes = matching_matrix.sum(1) > 0.0 + valid_mask[valid_mask.copy()] = fg_mask_inboxes + + matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1) + matched_pred_ious = (matching_matrix * + pairwise_ious.numpy()).sum(1)[fg_mask_inboxes] + + matched_pred_ious = paddle.to_tensor( + matched_pred_ious, place=pairwise_ious.place) + matched_gt_inds = paddle.to_tensor( + matched_gt_inds, place=pairwise_ious.place) + + return matched_pred_ious, matched_gt_inds, valid_mask + + def get_sample(self, assign_gt_inds, gt_bboxes): + pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0]) + neg_inds = np.unique(np.nonzero(assign_gt_inds == 0)[0]) + pos_assigned_gt_inds = assign_gt_inds[pos_inds] - 1 + + if gt_bboxes.size == 0: + # hack for index error case + assert pos_assigned_gt_inds.size == 0 + pos_gt_bboxes = np.empty_like(gt_bboxes).reshape(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.resize(-1, 4) + pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :] + return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds + + def __call__(self, + pred_scores, + priors, + decoded_bboxes, + gt_bboxes, + gt_labels, + eps=1e-7): + """Assign gt to priors using SimOTA. + TODO: add comment. + Returns: + assign_result: The assigned result. 
+ """ + + INF = 100000000 + num_gt = gt_bboxes.shape[0] + num_bboxes = decoded_bboxes.shape[0] + + # assign 0 by default + assigned_gt_inds = paddle.full( + (num_bboxes, ), 0, dtype=paddle.int64).numpy() + if num_gt == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = decoded_bboxes.new_zeros((num_bboxes, )) + if num_gt == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = paddle.full( + (num_bboxes, ), -1, dtype=paddle.int64) + return + + valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info( + priors, gt_bboxes) + + valid_mask_inds = paddle.nonzero(valid_mask).squeeze(1) + valid_decoded_bbox = decoded_bboxes[valid_mask_inds] + valid_pred_scores = pred_scores[valid_mask_inds] + num_valid = valid_decoded_bbox.shape[0] + + pairwise_ious = batch_bbox_overlaps(valid_decoded_bbox, gt_bboxes) + iou_cost = -paddle.log(pairwise_ious + eps) + if self.use_vfl: + gt_vfl_labels = gt_labels.squeeze(-1).unsqueeze(0).tile( + [num_valid, 1]).reshape([-1]) + valid_pred_scores = valid_pred_scores.unsqueeze(1).tile( + [1, num_gt, 1]).reshape([-1, self.num_classes]) + vfl_score = np.zeros(valid_pred_scores.shape) + vfl_score[np.arange(0, vfl_score.shape[0]), gt_vfl_labels.numpy( + )] = pairwise_ious.reshape([-1]) + vfl_score = paddle.to_tensor(vfl_score) + losses_vfl = varifocal_loss( + valid_pred_scores, vfl_score, + use_sigmoid=False).reshape([num_valid, num_gt]) + losses_giou = batch_bbox_overlaps( + valid_decoded_bbox, gt_bboxes, mode='giou') + cost_matrix = ( + losses_vfl * self.cls_weight + losses_giou * self.iou_weight + + paddle.logical_not(is_in_boxes_and_center).cast('float32') * INF + ) + else: + gt_onehot_label = (F.one_hot( + gt_labels.squeeze(-1).cast(paddle.int64), + pred_scores.shape[-1]).cast('float32').unsqueeze(0).tile( + [num_valid, 1, 1])) + + valid_pred_scores = valid_pred_scores.unsqueeze(1).tile( + [1, num_gt, 1]) + cls_cost = F.binary_cross_entropy( + valid_pred_scores, gt_onehot_label, reduction='none').sum(-1) + + cost_matrix = ( + cls_cost * self.cls_weight + iou_cost * self.iou_weight + + paddle.logical_not(is_in_boxes_and_center).cast('float32') * INF + ) + + matched_pred_ious, matched_gt_inds, valid_mask = \ + self.dynamic_k_matching( + cost_matrix, pairwise_ious, num_gt, valid_mask.numpy()) + + # assign results + gt_labels = gt_labels.numpy() + priors = priors.numpy() + matched_gt_inds = matched_gt_inds.numpy() + gt_bboxes = gt_bboxes.numpy() + + assigned_gt_inds[valid_mask] = matched_gt_inds + 1 + assigned_labels = np.full((num_bboxes, ), self.num_classes) + assigned_labels[valid_mask] = gt_labels.squeeze(-1)[matched_gt_inds] + + pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds \ + = self.get_sample(assigned_gt_inds, gt_bboxes) + + num_cells = priors.shape[0] + bbox_targets = np.zeros_like(priors) + bbox_weights = np.zeros_like(priors) + labels = np.ones([num_cells], dtype=np.int64) * self.num_classes + label_weights = np.zeros([num_cells], dtype=np.float32) + + if len(pos_inds) > 0: + pos_bbox_targets = pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + if not np.any(gt_labels): + labels[pos_inds] = 0 + else: + labels[pos_inds] = gt_labels.squeeze(-1)[pos_assigned_gt_inds] + + label_weights[pos_inds] = 1.0 + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + pos_num = max(pos_inds.size, 1) + + return priors, labels, label_weights, bbox_targets, pos_num diff 
--git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index 138b64935..3f415e6a5 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -28,6 +28,7 @@ from . import shufflenet_v2 from . import swin_transformer from . import lcnet from . import hardnet +from . import esnet from .vgg import * from .resnet import * @@ -45,3 +46,4 @@ from .shufflenet_v2 import * from .swin_transformer import * from .lcnet import * from .hardnet import * +from .esnet import * diff --git a/ppdet/modeling/backbones/esnet.py b/ppdet/modeling/backbones/esnet.py new file mode 100644 index 000000000..2b3f3c54a --- /dev/null +++ b/ppdet/modeling/backbones/esnet.py @@ -0,0 +1,290 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn import Conv2D, MaxPool2D, AdaptiveAvgPool2D, BatchNorm +from paddle.nn.initializer import KaimingNormal +from paddle.regularizer import L2Decay + +from ppdet.core.workspace import register, serializable +from numbers import Integral +from ..shape_spec import ShapeSpec +from ppdet.modeling.ops import channel_shuffle +from ppdet.modeling.backbones.shufflenet_v2 import ConvBNLayer + +__all__ = ['ESNet'] + + +def make_divisible(v, divisor=16, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class SEModule(nn.Layer): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(), + bias_attr=ParamAttr()) + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(), + bias_attr=ParamAttr()) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = F.hardsigmoid(outputs) + return paddle.multiply(x=inputs, y=outputs) + + +class InvertedResidual(nn.Layer): + def __init__(self, + in_channels, + mid_channels, + out_channels, + stride, + act="relu"): + super(InvertedResidual, self).__init__() + self._conv_pw = ConvBNLayer( + in_channels=in_channels // 2, + out_channels=mid_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act) + self._conv_dw = ConvBNLayer( + in_channels=mid_channels // 2, + out_channels=mid_channels // 2, + kernel_size=3, + stride=stride, + padding=1, + groups=mid_channels // 2, + act=None) + self._se = SEModule(mid_channels) + + self._conv_linear = ConvBNLayer( + 
in_channels=mid_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act) + + def forward(self, inputs): + x1, x2 = paddle.split( + inputs, + num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2], + axis=1) + x2 = self._conv_pw(x2) + x3 = self._conv_dw(x2) + x3 = paddle.concat([x2, x3], axis=1) + x3 = self._se(x3) + x3 = self._conv_linear(x3) + out = paddle.concat([x1, x3], axis=1) + return channel_shuffle(out, 2) + + +class InvertedResidualDS(nn.Layer): + def __init__(self, + in_channels, + mid_channels, + out_channels, + stride, + act="relu"): + super(InvertedResidualDS, self).__init__() + + # branch1 + self._conv_dw_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + act=None) + self._conv_linear_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act) + # branch2 + self._conv_pw_2 = ConvBNLayer( + in_channels=in_channels, + out_channels=mid_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act) + self._conv_dw_2 = ConvBNLayer( + in_channels=mid_channels // 2, + out_channels=mid_channels // 2, + kernel_size=3, + stride=stride, + padding=1, + groups=mid_channels // 2, + act=None) + self._se = SEModule(mid_channels // 2) + self._conv_linear_2 = ConvBNLayer( + in_channels=mid_channels // 2, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act=act) + self._conv_dw_mv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + groups=out_channels, + act="hard_swish") + self._conv_pw_mv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + act="hard_swish") + + def forward(self, inputs): + x1 = self._conv_dw_1(inputs) + x1 = self._conv_linear_1(x1) + x2 = self._conv_pw_2(inputs) + x2 = self._conv_dw_2(x2) + x2 = self._se(x2) + x2 = self._conv_linear_2(x2) + out = paddle.concat([x1, x2], axis=1) + out = self._conv_dw_mv1(out) + out = self._conv_pw_mv1(out) + + return out + + +@register +@serializable +class ESNet(nn.Layer): + def __init__(self, + scale=1.0, + act="hard_swish", + feature_maps=[4, 11, 14], + channel_ratio=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]): + super(ESNet, self).__init__() + self.scale = scale + if isinstance(feature_maps, Integral): + feature_maps = [feature_maps] + self.feature_maps = feature_maps + stage_repeats = [3, 7, 3] + + stage_out_channels = [ + -1, 24, make_divisible(128 * scale), make_divisible(256 * scale), + make_divisible(512 * scale), 1024 + ] + + self._out_channels = [] + self._feature_idx = 0 + # 1. conv1 + self._conv1 = ConvBNLayer( + in_channels=3, + out_channels=stage_out_channels[1], + kernel_size=3, + stride=2, + padding=1, + act=act) + self._max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1) + self._feature_idx += 1 + + # 2. 
bottleneck sequences + self._block_list = [] + arch_idx = 0 + for stage_id, num_repeat in enumerate(stage_repeats): + for i in range(num_repeat): + channels_scales = channel_ratio[arch_idx] + mid_c = make_divisible( + int(stage_out_channels[stage_id + 2] * channels_scales), + divisor=8) + if i == 0: + block = self.add_sublayer( + name=str(stage_id + 2) + '_' + str(i + 1), + sublayer=InvertedResidualDS( + in_channels=stage_out_channels[stage_id + 1], + mid_channels=mid_c, + out_channels=stage_out_channels[stage_id + 2], + stride=2, + act=act)) + else: + block = self.add_sublayer( + name=str(stage_id + 2) + '_' + str(i + 1), + sublayer=InvertedResidual( + in_channels=stage_out_channels[stage_id + 2], + mid_channels=mid_c, + out_channels=stage_out_channels[stage_id + 2], + stride=1, + act=act)) + self._block_list.append(block) + arch_idx += 1 + self._feature_idx += 1 + self._update_out_channels(stage_out_channels[stage_id + 2], + self._feature_idx, self.feature_maps) + + def _update_out_channels(self, channel, feature_idx, feature_maps): + if feature_idx in feature_maps: + self._out_channels.append(channel) + + def forward(self, inputs): + y = self._conv1(inputs['image']) + y = self._max_pool(y) + outs = [] + for i, inv in enumerate(self._block_list): + y = inv(y) + if i + 2 in self.feature_maps: + outs.append(y) + + return outs + + @property + def out_shape(self): + return [ShapeSpec(channels=c) for c in self._out_channels] diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index 05d7cef2f..789d06467 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -143,6 +143,100 @@ def bbox_overlaps(boxes1, boxes2): return overlaps + +def batch_bbox_overlaps(bboxes1, + bboxes2, + mode='iou', + is_aligned=False, + eps=1e-6): + """Calculate overlap between two sets of bboxes. + If ``is_aligned`` is ``False``, then calculate the overlaps between each + bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned + pair of bboxes1 and bboxes2. + Args: + bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty. + bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty. + B indicates the batch dim, in shape (B1, B2, ..., Bn). + If ``is_aligned`` is ``True``, then m and n must be equal. + mode (str): "iou" (intersection over union) or "iof" (intersection over + foreground). + is_aligned (bool, optional): If True, then m and n must be equal. + Default False. + eps (float, optional): A value added to the denominator for numerical + stability. Default 1e-6. + Returns: + Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,) + """ + assert mode in ['iou', 'iof', 'giou'], 'Unsupported mode {}'.format(mode) + # Either the boxes are empty or the length of boxes' last dimension is 4 + assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0) + assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0) + + # Batch dim must be the same + # Batch dim: (B1, B2, ...
Bn) + assert bboxes1.shape[:-2] == bboxes2.shape[:-2] + batch_shape = bboxes1.shape[:-2] + + rows = bboxes1.shape[-2] if bboxes1.shape[0] > 0 else 0 + cols = bboxes2.shape[-2] if bboxes2.shape[0] > 0 else 0 + if is_aligned: + assert rows == cols + + if rows * cols == 0: + if is_aligned: + return paddle.full(batch_shape + (rows, ), 1) + else: + return paddle.full(batch_shape + (rows, cols), 1) + + area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1]) + area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1]) + + if is_aligned: + lt = paddle.maximum(bboxes1[:, :2], bboxes2[:, :2]) # [B, rows, 2] + rb = paddle.minimum(bboxes1[:, 2:], bboxes2[:, 2:]) # [B, rows, 2] + + wh = (rb - lt).clip(min=0) # [B, rows, 2] + overlap = wh[:, 0] * wh[:, 1] + + if mode in ['iou', 'giou']: + union = area1 + area2 - overlap + else: + union = area1 + if mode == 'giou': + enclosed_lt = paddle.minimum(bboxes1[:, :2], bboxes2[:, :2]) + enclosed_rb = paddle.maximum(bboxes1[:, 2:], bboxes2[:, 2:]) + else: + lt = paddle.maximum(bboxes1[:, :2].reshape([rows, 1, 2]), + bboxes2[:, :2]) # [B, rows, cols, 2] + rb = paddle.minimum(bboxes1[:, 2:].reshape([rows, 1, 2]), + bboxes2[:, 2:]) # [B, rows, cols, 2] + + wh = (rb - lt).clip(min=0) # [B, rows, cols, 2] + overlap = wh[:, :, 0] * wh[:, :, 1] + + if mode in ['iou', 'giou']: + union = area1.reshape([rows,1]) \ + + area2.reshape([1,cols]) - overlap + else: + union = area1[:, None] + if mode == 'giou': + enclosed_lt = paddle.minimum(bboxes1[:, :2].reshape([rows, 1, 2]), + bboxes2[:, :2]) + enclosed_rb = paddle.maximum(bboxes1[:, 2:].reshape([rows, 1, 2]), + bboxes2[:, 2:]) + + eps = paddle.to_tensor([eps]) + union = paddle.maximum(union, eps) + ious = overlap / union + if mode in ['iou', 'iof']: + return ious + # calculate gious + enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0) + enclose_area = enclose_wh[:, :, 0] * enclose_wh[:, :, 1] + enclose_area = paddle.maximum(enclose_area, eps) + gious = ious - (enclose_area - union) / enclose_area + return 1 - gious + + def xywh2xyxy(box): x, y, w, h = box x1 = x - w * 0.5 diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index 55b9d907d..b6b928608 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -26,6 +26,7 @@ from . import s2anet_head from . import keypoint_hrhrnet_head from . import centernet_head from . import gfl_head +from . import simota_head from . import pico_head from . import detr_head from . import sparsercnn_head @@ -45,6 +46,7 @@ from .s2anet_head import * from .keypoint_hrhrnet_head import * from .centernet_head import * from .gfl_head import * +from .simota_head import * from .pico_head import * from .detr_head import * from .sparsercnn_head import * diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py index 54ae68e27..e5b0377a6 100644 --- a/ppdet/modeling/heads/gfl_head.py +++ b/ppdet/modeling/heads/gfl_head.py @@ -150,14 +150,14 @@ class GFLHead(nn.Layer): num_classes (int): Number of classes fpn_stride (list): The stride of each FPN Layer prior_prob (float): Used to set the bias init for the class prediction layer - loss_qfl (object): - loss_dfl (object): - loss_bbox (object): + loss_class (object): Instance of QualityFocalLoss. + loss_dfl (object): Instance of DistributionFocalLoss. + loss_bbox (object): Instance of bbox loss. reg_max: Max value of integral set :math: `{0, ..., reg_max}` n QFL setting. Default: 16. 
""" __inject__ = [ - 'conv_feat', 'dgqp_module', 'loss_qfl', 'loss_dfl', 'loss_bbox', 'nms' + 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', 'nms' ] __shared__ = ['num_classes'] @@ -167,7 +167,7 @@ class GFLHead(nn.Layer): num_classes=80, fpn_stride=[8, 16, 32, 64, 128], prior_prob=0.01, - loss_qfl='QualityFocalLoss', + loss_class='QualityFocalLoss', loss_dfl='DistributionFocalLoss', loss_bbox='GIoULoss', reg_max=16, @@ -181,7 +181,7 @@ class GFLHead(nn.Layer): self.num_classes = num_classes self.fpn_stride = fpn_stride self.prior_prob = prior_prob - self.loss_qfl = loss_qfl + self.loss_qfl = loss_class self.loss_dfl = loss_dfl self.loss_bbox = loss_bbox self.reg_max = reg_max diff --git a/ppdet/modeling/heads/pico_head.py b/ppdet/modeling/heads/pico_head.py index b51dfe941..7cfd24c3c 100644 --- a/ppdet/modeling/heads/pico_head.py +++ b/ppdet/modeling/heads/pico_head.py @@ -26,9 +26,7 @@ from paddle.nn.initializer import Normal, Constant from ppdet.core.workspace import register from ppdet.modeling.layers import ConvNormLayer -from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance -from ppdet.data.transform.atss_assigner import bbox_overlaps -from .gfl_head import GFLHead +from .simota_head import OTAVFLHead @register @@ -49,11 +47,13 @@ class PicoFeat(nn.Layer): num_fpn_stride=3, num_convs=2, norm_type='bn', - share_cls_reg=False): + share_cls_reg=False, + act='hard_swish'): super(PicoFeat, self).__init__() self.num_convs = num_convs self.norm_type = norm_type self.share_cls_reg = share_cls_reg + self.act = act self.cls_convs = [] self.reg_convs = [] for stage_idx in range(num_fpn_stride): @@ -112,35 +112,43 @@ class PicoFeat(nn.Layer): self.cls_convs.append(cls_subnet_convs) self.reg_convs.append(reg_subnet_convs) + def act_func(self, x): + if self.act == "leaky_relu": + x = F.leaky_relu(x) + elif self.act == "hard_swish": + x = F.hardswish(x) + return x + def forward(self, fpn_feat, stage_idx): assert stage_idx < len(self.cls_convs) cls_feat = fpn_feat reg_feat = fpn_feat for i in range(len(self.cls_convs[stage_idx])): - cls_feat = F.leaky_relu(self.cls_convs[stage_idx][i](cls_feat), 0.1) + cls_feat = self.act_func(self.cls_convs[stage_idx][i](cls_feat)) if not self.share_cls_reg: - reg_feat = F.leaky_relu(self.reg_convs[stage_idx][i](reg_feat), - 0.1) + reg_feat = self.act_func(self.reg_convs[stage_idx][i](reg_feat)) return cls_feat, reg_feat @register -class PicoHead(GFLHead): +class PicoHead(OTAVFLHead): """ PicoHead Args: - conv_feat (object): Instance of 'LiteGFLFeat' + conv_feat (object): Instance of 'PicoFeat' num_classes (int): Number of classes fpn_stride (list): The stride of each FPN Layer prior_prob (float): Used to set the bias init for the class prediction layer - loss_qfl (object): - loss_dfl (object): - loss_bbox (object): + loss_class (object): Instance of VariFocalLoss. + loss_dfl (object): Instance of DistributionFocalLoss. + loss_bbox (object): Instance of bbox loss. + assigner (object): Instance of label assigner. reg_max: Max value of integral set :math: `{0, ..., reg_max}` - n QFL setting. Default: 16. + n QFL setting. Default: 7. 
""" __inject__ = [ - 'conv_feat', 'dgqp_module', 'loss_qfl', 'loss_dfl', 'loss_bbox', 'nms' + 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', + 'assigner', 'nms' ] __shared__ = ['num_classes'] @@ -150,9 +158,10 @@ class PicoHead(GFLHead): num_classes=80, fpn_stride=[8, 16, 32], prior_prob=0.01, - loss_qfl='QualityFocalLoss', + loss_class='VariFocalLoss', loss_dfl='DistributionFocalLoss', loss_bbox='GIoULoss', + assigner='SimOTAAssigner', reg_max=16, feat_in_chan=96, nms=None, @@ -164,9 +173,10 @@ class PicoHead(GFLHead): num_classes=num_classes, fpn_stride=fpn_stride, prior_prob=prior_prob, - loss_qfl=loss_qfl, + loss_class=loss_class, loss_dfl=loss_dfl, loss_bbox=loss_bbox, + assigner=assigner, reg_max=reg_max, feat_in_chan=feat_in_chan, nms=nms, @@ -176,15 +186,17 @@ class PicoHead(GFLHead): self.num_classes = num_classes self.fpn_stride = fpn_stride self.prior_prob = prior_prob - self.loss_qfl = loss_qfl + self.loss_vfl = loss_class self.loss_dfl = loss_dfl self.loss_bbox = loss_bbox + self.assigner = assigner self.reg_max = reg_max self.feat_in_chan = feat_in_chan self.nms = nms self.nms_pre = nms_pre self.cell_offset = cell_offset - self.use_sigmoid = self.loss_qfl.use_sigmoid + + self.use_sigmoid = self.loss_vfl.use_sigmoid if self.use_sigmoid: self.cls_out_channels = self.num_classes else: diff --git a/ppdet/modeling/heads/simota_head.py b/ppdet/modeling/heads/simota_head.py new file mode 100644 index 000000000..e312726da --- /dev/null +++ b/ppdet/modeling/heads/simota_head.py @@ -0,0 +1,513 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +from functools import partial +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn.initializer import Normal, Constant + +from ppdet.core.workspace import register + +from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance +from ppdet.data.transform.atss_assigner import bbox_overlaps + +from .gfl_head import GFLHead, ScaleReg, Integral + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. 
+
+    Args:
+        func (Function): A function that will be applied to a list of
+            arguments
+
+    Returns:
+        tuple(list): A tuple containing multiple lists, where each list \
+            contains one kind of result returned by the function
+    """
+    pfunc = partial(func, **kwargs) if kwargs else func
+    map_results = map(pfunc, *args)
+    return tuple(map(list, zip(*map_results)))
+
+
+@register
+class OTAHead(GFLHead):
+    """
+    OTAHead
+    Args:
+        conv_feat (object): Instance of 'FCOSFeat'
+        num_classes (int): Number of classes
+        fpn_stride (list): The stride of each FPN Layer
+        prior_prob (float): Used to set the bias init for the class prediction layer
+        loss_class (object): Instance of QualityFocalLoss.
+        loss_dfl (object): Instance of DistributionFocalLoss.
+        loss_bbox (object): Instance of bbox loss.
+        assigner (object): Instance of label assigner.
+        reg_max: Max value of integral set :math: `{0, ..., reg_max}`
+            in QFL setting. Default: 16.
+    """
+    __inject__ = [
+        'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
+        'assigner', 'nms'
+    ]
+    __shared__ = ['num_classes']
+
+    def __init__(self,
+                 conv_feat='FCOSFeat',
+                 dgqp_module=None,
+                 num_classes=80,
+                 fpn_stride=[8, 16, 32, 64, 128],
+                 prior_prob=0.01,
+                 loss_class='QualityFocalLoss',
+                 loss_dfl='DistributionFocalLoss',
+                 loss_bbox='GIoULoss',
+                 assigner='SimOTAAssigner',
+                 reg_max=16,
+                 feat_in_chan=256,
+                 nms=None,
+                 nms_pre=1000,
+                 cell_offset=0):
+        super(OTAHead, self).__init__(
+            conv_feat=conv_feat,
+            dgqp_module=dgqp_module,
+            num_classes=num_classes,
+            fpn_stride=fpn_stride,
+            prior_prob=prior_prob,
+            loss_class=loss_class,
+            loss_dfl=loss_dfl,
+            loss_bbox=loss_bbox,
+            reg_max=reg_max,
+            feat_in_chan=feat_in_chan,
+            nms=nms,
+            nms_pre=nms_pre,
+            cell_offset=cell_offset)
+        self.conv_feat = conv_feat
+        self.dgqp_module = dgqp_module
+        self.num_classes = num_classes
+        self.fpn_stride = fpn_stride
+        self.prior_prob = prior_prob
+        self.loss_qfl = loss_class
+        self.loss_dfl = loss_dfl
+        self.loss_bbox = loss_bbox
+        self.reg_max = reg_max
+        self.feat_in_chan = feat_in_chan
+        self.nms = nms
+        self.nms_pre = nms_pre
+        self.cell_offset = cell_offset
+        self.use_sigmoid = self.loss_qfl.use_sigmoid
+
+        self.assigner = assigner
+
+    def _get_target_single(self, cls_preds, centors, decoded_bboxes, gt_bboxes,
+                           gt_labels):
+        """Compute targets for priors in a single image.
+ """ + num_gts = gt_labels.shape[0] + # No target + if num_gts == 0: + pass + + centors, labels, label_weights, bbox_targets, pos_num = self.assigner( + F.sigmoid(cls_preds), centors, decoded_bboxes, gt_bboxes, gt_labels) + + return (centors, labels, label_weights, bbox_targets, pos_num) + + def get_loss(self, head_outs, gt_meta): + cls_scores, bbox_preds = head_outs + num_level_anchors = [ + featmap.shape[-2] * featmap.shape[-1] for featmap in cls_scores + ] + num_imgs = gt_meta['im_id'].shape[0] + featmap_sizes = [[featmap.shape[-2], featmap.shape[-1]] + for featmap in cls_scores] + + decode_bbox_preds = [] + mlvl_centors = [] + with_stride = True + for featmap_size, stride, bbox_pred in zip(featmap_sizes, + self.fpn_stride, bbox_preds): + + yy, xx = self.get_single_level_center_point(featmap_size, stride, + self.cell_offset) + if with_stride: + stride_w = paddle.full((len(xx), ), stride) + stride_h = paddle.full((len(yy), ), stride) + centers = paddle.stack([xx, yy, stride_w, stride_h], -1).tile( + [num_imgs, 1, 1]) + mlvl_centors.append(centers) + centers_in_feature = centers.reshape([-1, 4])[:, :-2] / stride + bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( + [num_imgs, -1, 4 * (self.reg_max + 1)]) + pred_corners = self.distribution_project(bbox_pred) + decode_bbox_pred = distance2bbox( + centers_in_feature, pred_corners).reshape([num_imgs, -1, 4]) + decode_bbox_preds.append(decode_bbox_pred * stride) + + flatten_cls_preds = [ + cls_pred.transpose([0, 2, 3, 1]).reshape( + [num_imgs, -1, self.cls_out_channels]) + for cls_pred in cls_scores + ] + flatten_cls_preds = paddle.concat(flatten_cls_preds, axis=1) + flatten_bboxes = paddle.concat(decode_bbox_preds, axis=1) + flatten_centors = paddle.concat(mlvl_centors, axis=1) + + gt_box, gt_labels = gt_meta['gt_bbox'], gt_meta['gt_class'] + (centors, labels, label_weights, bbox_targets, pos_num) = multi_apply( + self._get_target_single, + flatten_cls_preds.detach(), + flatten_centors.detach(), + flatten_bboxes.detach(), gt_box, gt_labels) + + centors = paddle.to_tensor(np.stack(centors, axis=0)) + labels = paddle.to_tensor(np.stack(labels, axis=0)) + label_weights = paddle.to_tensor(np.stack(label_weights, axis=0)) + bbox_targets = paddle.to_tensor(np.stack(bbox_targets, axis=0)) + + centors_list = self._images_to_levels(centors, num_level_anchors) + labels_list = self._images_to_levels(labels, num_level_anchors) + label_weights_list = self._images_to_levels(label_weights, + num_level_anchors) + bbox_targets_list = self._images_to_levels(bbox_targets, + num_level_anchors) + num_total_pos = sum(pos_num) + try: + num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone( + )) / paddle.distributed.get_world_size() + except: + num_total_pos = max(num_total_pos, 1) + + loss_bbox_list, loss_dfl_list, loss_qfl_list, avg_factor = [], [], [], [] + for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip( + cls_scores, bbox_preds, centors_list, labels_list, + label_weights_list, bbox_targets_list, self.fpn_stride): + grid_cells = grid_cells.reshape([-1, 4]) + cls_score = cls_score.transpose([0, 2, 3, 1]).reshape( + [-1, self.cls_out_channels]) + bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( + [-1, 4 * (self.reg_max + 1)]) + bbox_targets = bbox_targets.reshape([-1, 4]) + labels = labels.reshape([-1]) + label_weights = label_weights.reshape([-1]) + + bg_class_ind = self.num_classes + pos_inds = paddle.nonzero( + paddle.logical_and((labels >= 0), (labels < bg_class_ind)), + 
as_tuple=False).squeeze(1) + score = np.zeros(labels.shape) + + if len(pos_inds) > 0: + pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0) + pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0) + pos_grid_cell_centers = paddle.gather( + grid_cells[:, :-2], pos_inds, axis=0) / stride + + weight_targets = F.sigmoid(cls_score.detach()) + weight_targets = paddle.gather( + weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0) + pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred) + pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers, + pos_bbox_pred_corners) + pos_decode_bbox_targets = pos_bbox_targets / stride + bbox_iou = bbox_overlaps( + pos_decode_bbox_pred.detach().numpy(), + pos_decode_bbox_targets.detach().numpy(), + is_aligned=True) + score[pos_inds.numpy()] = bbox_iou + + pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1]) + target_corners = bbox2distance(pos_grid_cell_centers, + pos_decode_bbox_targets, + self.reg_max).reshape([-1]) + # regression loss + loss_bbox = paddle.sum( + self.loss_bbox(pos_decode_bbox_pred, + pos_decode_bbox_targets) * weight_targets) + + # dfl loss + loss_dfl = self.loss_dfl( + pred_corners, + target_corners, + weight=weight_targets.expand([-1, 4]).reshape([-1]), + avg_factor=4.0) + else: + loss_bbox = bbox_pred.sum() * 0 + loss_dfl = bbox_pred.sum() * 0 + weight_targets = paddle.to_tensor([0], dtype='float32') + + # qfl loss + score = paddle.to_tensor(score) + loss_qfl = self.loss_qfl( + cls_score, (labels, score), + weight=label_weights, + avg_factor=num_total_pos) + loss_bbox_list.append(loss_bbox) + loss_dfl_list.append(loss_dfl) + loss_qfl_list.append(loss_qfl) + avg_factor.append(weight_targets.sum()) + + avg_factor = sum(avg_factor) + try: + avg_factor = paddle.distributed.all_reduce(avg_factor.clone()) + avg_factor = paddle.clip( + avg_factor / paddle.distributed.get_world_size(), min=1) + except: + avg_factor = max(avg_factor.item(), 1) + if avg_factor <= 0: + loss_qfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + loss_bbox = paddle.to_tensor( + 0, dtype='float32', stop_gradient=False) + loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + else: + losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list)) + losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list)) + loss_qfl = sum(loss_qfl_list) + loss_bbox = sum(losses_bbox) + loss_dfl = sum(losses_dfl) + + loss_states = dict( + loss_qfl=loss_qfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl) + + return loss_states + + +@register +class OTAVFLHead(OTAHead): + __inject__ = [ + 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', + 'assigner', 'nms' + ] + __shared__ = ['num_classes'] + + def __init__(self, + conv_feat='FCOSFeat', + dgqp_module=None, + num_classes=80, + fpn_stride=[8, 16, 32, 64, 128], + prior_prob=0.01, + loss_class='VarifocalLoss', + loss_dfl='DistributionFocalLoss', + loss_bbox='GIoULoss', + assigner='SimOTAAssigner', + reg_max=16, + feat_in_chan=256, + nms=None, + nms_pre=1000, + cell_offset=0): + super(OTAVFLHead, self).__init__( + conv_feat=conv_feat, + dgqp_module=dgqp_module, + num_classes=num_classes, + fpn_stride=fpn_stride, + prior_prob=prior_prob, + loss_class=loss_class, + loss_dfl=loss_dfl, + loss_bbox=loss_bbox, + reg_max=reg_max, + feat_in_chan=feat_in_chan, + nms=nms, + nms_pre=nms_pre, + cell_offset=cell_offset) + self.conv_feat = conv_feat + self.dgqp_module = dgqp_module + self.num_classes = num_classes + self.fpn_stride = fpn_stride + self.prior_prob = 
prior_prob + self.loss_vfl = loss_class + self.loss_dfl = loss_dfl + self.loss_bbox = loss_bbox + self.reg_max = reg_max + self.feat_in_chan = feat_in_chan + self.nms = nms + self.nms_pre = nms_pre + self.cell_offset = cell_offset + self.use_sigmoid = self.loss_vfl.use_sigmoid + + self.assigner = assigner + + def get_loss(self, head_outs, gt_meta): + cls_scores, bbox_preds = head_outs + num_level_anchors = [ + featmap.shape[-2] * featmap.shape[-1] for featmap in cls_scores + ] + num_imgs = gt_meta['im_id'].shape[0] + featmap_sizes = [[featmap.shape[-2], featmap.shape[-1]] + for featmap in cls_scores] + + decode_bbox_preds = [] + mlvl_centors = [] + with_stride = True + for featmap_size, stride, bbox_pred in zip(featmap_sizes, + self.fpn_stride, bbox_preds): + + yy, xx = self.get_single_level_center_point(featmap_size, stride, + self.cell_offset) + if with_stride: + stride_w = paddle.full((len(xx), ), stride) + stride_h = paddle.full((len(yy), ), stride) + centers = paddle.stack([xx, yy, stride_w, stride_h], -1).tile( + [num_imgs, 1, 1]) + mlvl_centors.append(centers) + centers_in_feature = centers.reshape([-1, 4])[:, :-2] / stride + bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( + [num_imgs, -1, 4 * (self.reg_max + 1)]) + pred_corners = self.distribution_project(bbox_pred) + decode_bbox_pred = distance2bbox( + centers_in_feature, pred_corners).reshape([num_imgs, -1, 4]) + decode_bbox_preds.append(decode_bbox_pred * stride) + + flatten_cls_preds = [ + cls_pred.transpose([0, 2, 3, 1]).reshape( + [num_imgs, -1, self.cls_out_channels]) + for cls_pred in cls_scores + ] + flatten_cls_preds = paddle.concat(flatten_cls_preds, axis=1) + flatten_bboxes = paddle.concat(decode_bbox_preds, axis=1) + flatten_centors = paddle.concat(mlvl_centors, axis=1) + + gt_box, gt_labels = gt_meta['gt_bbox'], gt_meta['gt_class'] + (centors, labels, label_weights, bbox_targets, pos_num) = multi_apply( + self._get_target_single, + flatten_cls_preds.detach(), + flatten_centors.detach(), + flatten_bboxes.detach(), gt_box, gt_labels) + + centors = paddle.to_tensor(np.stack(centors, axis=0)) + labels = paddle.to_tensor(np.stack(labels, axis=0)) + label_weights = paddle.to_tensor(np.stack(label_weights, axis=0)) + bbox_targets = paddle.to_tensor(np.stack(bbox_targets, axis=0)) + + centors_list = self._images_to_levels(centors, num_level_anchors) + labels_list = self._images_to_levels(labels, num_level_anchors) + label_weights_list = self._images_to_levels(label_weights, + num_level_anchors) + bbox_targets_list = self._images_to_levels(bbox_targets, + num_level_anchors) + num_total_pos = sum(pos_num) + try: + num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone( + )) / paddle.distributed.get_world_size() + except: + num_total_pos = max(num_total_pos, 1) + + loss_bbox_list, loss_dfl_list, loss_vfl_list, avg_factor = [], [], [], [] + for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip( + cls_scores, bbox_preds, centors_list, labels_list, + label_weights_list, bbox_targets_list, self.fpn_stride): + grid_cells = grid_cells.reshape([-1, 4]) + cls_score = cls_score.transpose([0, 2, 3, 1]).reshape( + [-1, self.cls_out_channels]) + bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( + [-1, 4 * (self.reg_max + 1)]) + bbox_targets = bbox_targets.reshape([-1, 4]) + labels = labels.reshape([-1]) + label_weights = label_weights.reshape([-1]) + + bg_class_ind = self.num_classes + pos_inds = paddle.nonzero( + paddle.logical_and((labels >= 0), (labels < bg_class_ind)), + 
as_tuple=False).squeeze(1) + # vfl + vfl_score = np.zeros(cls_score.shape) + + if len(pos_inds) > 0: + pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0) + pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0) + pos_grid_cell_centers = paddle.gather( + grid_cells[:, :-2], pos_inds, axis=0) / stride + + weight_targets = F.sigmoid(cls_score.detach()) + weight_targets = paddle.gather( + weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0) + pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred) + pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers, + pos_bbox_pred_corners) + pos_decode_bbox_targets = pos_bbox_targets / stride + bbox_iou = bbox_overlaps( + pos_decode_bbox_pred.detach().numpy(), + pos_decode_bbox_targets.detach().numpy(), + is_aligned=True) + + # vfl + pos_labels = paddle.gather(labels, pos_inds, axis=0) + vfl_score[pos_inds.numpy(), pos_labels] = bbox_iou + + pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1]) + target_corners = bbox2distance(pos_grid_cell_centers, + pos_decode_bbox_targets, + self.reg_max).reshape([-1]) + # regression loss + loss_bbox = paddle.sum( + self.loss_bbox(pos_decode_bbox_pred, + pos_decode_bbox_targets) * weight_targets) + + # dfl loss + loss_dfl = self.loss_dfl( + pred_corners, + target_corners, + weight=weight_targets.expand([-1, 4]).reshape([-1]), + avg_factor=4.0) + else: + loss_bbox = bbox_pred.sum() * 0 + loss_dfl = bbox_pred.sum() * 0 + weight_targets = paddle.to_tensor([0], dtype='float32') + + # vfl loss + num_pos_avg_per_gpu = num_total_pos + vfl_score = paddle.to_tensor(vfl_score) + loss_vfl = self.loss_vfl( + cls_score, vfl_score, avg_factor=num_pos_avg_per_gpu) + + loss_bbox_list.append(loss_bbox) + loss_dfl_list.append(loss_dfl) + loss_vfl_list.append(loss_vfl) + avg_factor.append(weight_targets.sum()) + + avg_factor = sum(avg_factor) + try: + avg_factor = paddle.distributed.all_reduce(avg_factor.clone()) + avg_factor = paddle.clip( + avg_factor / paddle.distributed.get_world_size(), min=1) + except: + avg_factor = max(avg_factor.item(), 1) + if avg_factor <= 0: + loss_vfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + loss_bbox = paddle.to_tensor( + 0, dtype='float32', stop_gradient=False) + loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + else: + losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list)) + losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list)) + loss_vfl = sum(loss_vfl_list) + loss_bbox = sum(losses_bbox) + loss_dfl = sum(losses_dfl) + + loss_states = dict( + loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl) + + return loss_states diff --git a/ppdet/modeling/losses/varifocal_loss.py b/ppdet/modeling/losses/varifocal_loss.py new file mode 100644 index 000000000..220c3b072 --- /dev/null +++ b/ppdet/modeling/losses/varifocal_loss.py @@ -0,0 +1,158 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
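For reviewers of the file being added here (`ppdet/modeling/losses/varifocal_loss.py`, from VarifocalNet, arXiv:2008.13367): below is a minimal, self-contained sketch of the IoU-weighted focal weighting that `varifocal_loss` applies in its default `iou_weighted=True` case. The helper name `varifocal_weight` and the toy tensors are illustrative only, not part of this patch.

```python
import paddle
import paddle.nn.functional as F


def varifocal_weight(pred_logits, target, alpha=0.75, gamma=2.0):
    # Positives (target > 0) keep their IoU-aware score q as the weight;
    # negatives are down-weighted by alpha * |p - q|^gamma.
    p = F.sigmoid(pred_logits)
    pos = (target > 0.0).cast('float32')
    neg = (target <= 0.0).cast('float32')
    return target * pos + alpha * (p - target).abs().pow(gamma) * neg


# Toy check: 3 anchors, 2 classes; one positive with IoU 0.8.
logits = paddle.to_tensor([[2.0, -1.0], [-2.0, -2.0], [0.5, 0.3]])
target = paddle.zeros_like(logits)
target[0, 0] = 0.8  # IoU-aware classification target of the positive
w = varifocal_weight(logits, target)
loss = (F.binary_cross_entropy_with_logits(
    logits, target, reduction='none') * w).sum(axis=1)
print(loss.shape)  # [3]
```

Unlike plain focal loss, the positive term is not down-weighted at all; it is scaled up by the IoU target, which is what lets VFL train the IoU-aware classification score used by the SimOTA head above.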
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ppdet.core.workspace import register, serializable
+from ppdet.modeling import ops
+
+__all__ = ['VarifocalLoss']
+
+
+def varifocal_loss(pred,
+                   target,
+                   alpha=0.75,
+                   gamma=2.0,
+                   iou_weighted=True,
+                   use_sigmoid=True):
+    """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
+
+    Args:
+        pred (Tensor): The prediction with shape (N, C), C is the
+            number of classes
+        target (Tensor): The learning target of the iou-aware
+            classification score with shape (N, C), C is the number of classes.
+        alpha (float, optional): A balance factor for the negative part of
+            Varifocal Loss, which is different from the alpha of Focal Loss.
+            Defaults to 0.75.
+        gamma (float, optional): The gamma for calculating the modulating
+            factor. Defaults to 2.0.
+        iou_weighted (bool, optional): Whether to weight the loss of the
+            positive examples with the iou target. Defaults to True.
+        use_sigmoid (bool, optional): Whether the prediction is normalized
+            with sigmoid before computing the loss. Defaults to True.
+    """
+    # pred and target should be of the same size
+    assert pred.shape == target.shape
+    if use_sigmoid:
+        pred_new = F.sigmoid(pred)
+    else:
+        pred_new = pred
+    target = target.cast(pred.dtype)
+    if iou_weighted:
+        focal_weight = target * (target > 0.0).cast('float32') + \
+            alpha * (pred_new - target).abs().pow(gamma) * \
+            (target <= 0.0).cast('float32')
+    else:
+        focal_weight = (target > 0.0).cast('float32') + \
+            alpha * (pred_new - target).abs().pow(gamma) * \
+            (target <= 0.0).cast('float32')
+
+    if use_sigmoid:
+        loss = F.binary_cross_entropy_with_logits(
+            pred, target, reduction='none') * focal_weight
+    else:
+        loss = F.binary_cross_entropy(
+            pred, target, reduction='none') * focal_weight
+    loss = loss.sum(axis=1)
+    return loss
+
+
+@register
+@serializable
+class VarifocalLoss(nn.Layer):
+    def __init__(self,
+                 use_sigmoid=True,
+                 alpha=0.75,
+                 gamma=2.0,
+                 iou_weighted=True,
+                 reduction='mean',
+                 loss_weight=1.0):
+        """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
+
+        Args:
+            use_sigmoid (bool, optional): Whether the prediction is
+                normalized with sigmoid (rather than softmax). Defaults to True.
+            alpha (float, optional): A balance factor for the negative part of
+                Varifocal Loss, which is different from the alpha of Focal
+                Loss. Defaults to 0.75.
+            gamma (float, optional): The gamma for calculating the modulating
+                factor. Defaults to 2.0.
+            iou_weighted (bool, optional): Whether to weight the loss of the
+                positive examples with the iou target. Defaults to True.
+            reduction (str, optional): The method used to reduce the loss into
+                a scalar. Defaults to 'mean'. Options are "none", "mean" and
+                "sum".
+            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+        """
+        super(VarifocalLoss, self).__init__()
+        assert alpha >= 0.0
+        self.use_sigmoid = use_sigmoid
+        self.alpha = alpha
+        self.gamma = gamma
+        self.iou_weighted = iou_weighted
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        """Forward function.
+
+        Args:
+            pred (Tensor): The prediction.
+            target (Tensor): The learning target of the prediction.
+            weight (Tensor, optional): The weight of loss for each
+                prediction. Defaults to None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+ Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + loss = self.loss_weight * varifocal_loss( + pred, + target, + alpha=self.alpha, + gamma=self.gamma, + iou_weighted=self.iou_weighted, + use_sigmoid=self.use_sigmoid) + + if weight is not None: + loss = loss * weight + if avg_factor is None: + if self.reduction == 'none': + return loss + elif self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: + # if reduction is mean, then average the loss by avg_factor + if self.reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif self.reduction != 'none': + raise ValueError( + 'avg_factor can not be used with reduction="sum"') + return loss diff --git a/ppdet/modeling/necks/__init__.py b/ppdet/modeling/necks/__init__.py index 4745d8091..79deaf94c 100644 --- a/ppdet/modeling/necks/__init__.py +++ b/ppdet/modeling/necks/__init__.py @@ -19,6 +19,7 @@ from . import ttf_fpn from . import centernet_fpn from . import pan from . import bifpn +from . import csp_pan from .fpn import * from .yolo_fpn import * @@ -28,3 +29,4 @@ from .centernet_fpn import * from .blazeface_fpn import * from .pan import * from .bifpn import * +from .csp_pan import * diff --git a/ppdet/modeling/necks/csp_pan.py b/ppdet/modeling/necks/csp_pan.py new file mode 100644 index 000000000..fccd25b75 --- /dev/null +++ b/ppdet/modeling/necks/csp_pan.py @@ -0,0 +1,361 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay +from ppdet.core.workspace import register, serializable +from ..shape_spec import ShapeSpec + +__all__ = ['CSPPAN'] + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channel=96, + out_channel=96, + kernel_size=3, + stride=1, + groups=1, + act='leaky_relu'): + super(ConvBNLayer, self).__init__() + initializer = nn.initializer.KaimingUniform() + self.act = act + assert self.act in ['leaky_relu', "hard_swish"] + self.conv = nn.Conv2D( + in_channels=in_channel, + out_channels=out_channel, + kernel_size=kernel_size, + groups=groups, + padding=(kernel_size - 1) // 2, + stride=stride, + weight_attr=ParamAttr(initializer=initializer), + bias_attr=False) + self.bn = nn.BatchNorm2D(out_channel) + + def forward(self, x): + x = self.bn(self.conv(x)) + if self.act == "leaky_relu": + x = F.leaky_relu(x) + elif self.act == "hard_swish": + x = F.hardswish(x) + return x + + +class DPModule(nn.Layer): + """ + Depth-wise and point-wise module. + Args: + in_channel (int): The input channels of this Module. + out_channel (int): The output channels of this Module. + kernel_size (int): The conv2d kernel size of this Module. + stride (int): The conv2d's stride of this Module. 
+        act (str): The activation function of this Module.
+            Supported options are `leaky_relu` and `hard_swish`.
+    """
+
+    def __init__(self,
+                 in_channel=96,
+                 out_channel=96,
+                 kernel_size=3,
+                 stride=1,
+                 act='leaky_relu'):
+        super(DPModule, self).__init__()
+        initializer = nn.initializer.KaimingUniform()
+        self.act = act
+        self.dwconv = nn.Conv2D(
+            in_channels=in_channel,
+            out_channels=out_channel,
+            kernel_size=kernel_size,
+            groups=out_channel,
+            padding=(kernel_size - 1) // 2,
+            stride=stride,
+            weight_attr=ParamAttr(initializer=initializer),
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm2D(out_channel)
+        self.pwconv = nn.Conv2D(
+            in_channels=out_channel,
+            out_channels=out_channel,
+            kernel_size=1,
+            groups=1,
+            padding=0,
+            weight_attr=ParamAttr(initializer=initializer),
+            bias_attr=False)
+        self.bn2 = nn.BatchNorm2D(out_channel)
+
+    def act_func(self, x):
+        if self.act == "leaky_relu":
+            x = F.leaky_relu(x)
+        elif self.act == "hard_swish":
+            x = F.hardswish(x)
+        return x
+
+    def forward(self, x):
+        x = self.act_func(self.bn1(self.dwconv(x)))
+        x = self.act_func(self.bn2(self.pwconv(x)))
+        return x
+
+
+class DarknetBottleneck(nn.Layer):
+    """The basic bottleneck block used in Darknet.
+
+    Each block consists of two conv layers and the input is added to the
+    final output. Each conv layer is composed of Conv, BN, and act.
+    The first conv layer has a 1x1 kernel and the second one a
+    kernel_size x kernel_size kernel.
+
+    Args:
+        in_channels (int): The input channels of this Module.
+        out_channels (int): The output channels of this Module.
+        kernel_size (int): The kernel size of the second convolution.
+            Default: 3
+        expansion (float): The expand ratio of the hidden channels.
+            Default: 0.5
+        add_identity (bool): Whether to add identity to the out.
+            Default: True
+        use_depthwise (bool): Whether to use depthwise separable convolution.
+            Default: False
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 expansion=0.5,
+                 add_identity=True,
+                 use_depthwise=False,
+                 act="leaky_relu"):
+        super(DarknetBottleneck, self).__init__()
+        hidden_channels = int(out_channels * expansion)
+        conv_func = DPModule if use_depthwise else ConvBNLayer
+        self.conv1 = ConvBNLayer(
+            in_channel=in_channels,
+            out_channel=hidden_channels,
+            kernel_size=1,
+            act=act)
+        self.conv2 = conv_func(
+            in_channel=hidden_channels,
+            out_channel=out_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            act=act)
+        self.add_identity = \
+            add_identity and in_channels == out_channels
+
+    def forward(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+
+        if self.add_identity:
+            return out + identity
+        else:
+            return out
+
+
+class CSPLayer(nn.Layer):
+    """Cross Stage Partial Layer.
+
+    Args:
+        in_channels (int): The input channels of the CSP layer.
+        out_channels (int): The output channels of the CSP layer.
+        expand_ratio (float): Ratio to adjust the number of channels of the
+            hidden layer. Default: 0.5
+        num_blocks (int): Number of blocks. Default: 1
+        add_identity (bool): Whether to add identity in blocks.
+            Default: True
+        use_depthwise (bool): Whether to use depthwise separable convolutions
+            in blocks.
Default: False
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 expand_ratio=0.5,
+                 num_blocks=1,
+                 add_identity=True,
+                 use_depthwise=False,
+                 act="leaky_relu"):
+        super().__init__()
+        mid_channels = int(out_channels * expand_ratio)
+        self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
+        self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
+        self.final_conv = ConvBNLayer(
+            2 * mid_channels, out_channels, 1, act=act)
+
+        self.blocks = nn.Sequential(* [
+            DarknetBottleneck(
+                mid_channels,
+                mid_channels,
+                kernel_size,
+                1.0,
+                add_identity,
+                use_depthwise,
+                act=act) for _ in range(num_blocks)
+        ])
+
+    def forward(self, x):
+        x_short = self.short_conv(x)
+
+        x_main = self.main_conv(x)
+        x_main = self.blocks(x_main)
+
+        x_final = paddle.concat((x_main, x_short), axis=1)
+        return self.final_conv(x_final)
+
+
+class Channel_T(nn.Layer):
+    def __init__(self,
+                 in_channels=[116, 232, 464],
+                 out_channels=96,
+                 act="leaky_relu"):
+        super(Channel_T, self).__init__()
+        self.convs = nn.LayerList()
+        for i in range(len(in_channels)):
+            self.convs.append(
+                ConvBNLayer(
+                    in_channels[i], out_channels, 1, act=act))
+
+    def forward(self, x):
+        outs = [self.convs[i](x[i]) for i in range(len(x))]
+        return outs
+
+
+@register
+@serializable
+class CSPPAN(nn.Layer):
+    """Path Aggregation Network with CSP module.
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        kernel_size (int): The conv2d kernel size of this Module.
+        num_features (int): Number of output features of CSPPAN module.
+        num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
+        use_depthwise (bool): Whether to use depthwise separable convolutions
+            in blocks.
Default: True + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=5, + num_features=3, + num_csp_blocks=1, + use_depthwise=True, + act='hard_swish', + spatial_scales=[0.125, 0.0625, 0.03125]): + super(CSPPAN, self).__init__() + self.conv_t = Channel_T(in_channels, out_channels, act=act) + in_channels = [out_channels] * len(spatial_scales) + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_scales = spatial_scales + self.num_features = num_features + conv_func = DPModule if use_depthwise else ConvBNLayer + + if self.num_features == 4: + self.first_top_conv = conv_func( + in_channels[0], in_channels[0], kernel_size, stride=2, act=act) + self.second_top_conv = conv_func( + in_channels[0], in_channels[0], kernel_size, stride=2, act=act) + self.spatial_scales.append(self.spatial_scales[-1] / 2) + + # build top-down blocks + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.top_down_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1, 0, -1): + self.top_down_blocks.append( + CSPLayer( + in_channels[idx - 1] * 2, + in_channels[idx - 1], + kernel_size=kernel_size, + num_blocks=num_csp_blocks, + add_identity=False, + use_depthwise=use_depthwise, + act=act)) + + # build bottom-up blocks + self.downsamples = nn.LayerList() + self.bottom_up_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + conv_func( + in_channels[idx], + in_channels[idx], + kernel_size=kernel_size, + stride=2, + act=act)) + self.bottom_up_blocks.append( + CSPLayer( + in_channels[idx] * 2, + in_channels[idx + 1], + kernel_size=kernel_size, + num_blocks=num_csp_blocks, + add_identity=False, + use_depthwise=use_depthwise, + act=act)) + + def forward(self, inputs): + """ + Args: + inputs (tuple[Tensor]): input features. + + Returns: + tuple[Tensor]: CSPPAN features. + """ + assert len(inputs) == len(self.in_channels) + inputs = self.conv_t(inputs) + + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + + upsample_feat = self.upsample(feat_heigh) + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + paddle.concat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx](paddle.concat( + [downsample_feat, feat_height], 1)) + outs.append(out) + + top_features = None + if self.num_features == 4: + top_features = self.first_top_conv(inputs[-1]) + top_features = top_features + self.second_top_conv(outs[-1]) + outs.append(top_features) + + return tuple(outs) + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self.out_channels, stride=1. / s) + for s in self.spatial_scales + ] + + @classmethod + def from_config(cls, cfg, input_shape): + return {'in_channels': [i.channels for i in input_shape], } -- GitLab
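As an end-to-end sanity check of the new neck, here is a minimal usage sketch of `CSPPAN` as added above. It assumes a PaddlePaddle environment with this patch applied; the batch size, input resolution (320x320, so strides 8/16/32 give 40/20/10 feature maps), and the ESNet-like C3/C4/C5 channel counts (taken from the docstring defaults) are illustrative only.

```python
import paddle
from ppdet.modeling.necks.csp_pan import CSPPAN

# Dummy C3/C4/C5 features for a 320x320 input (strides 8, 16, 32).
feats = [
    paddle.rand([1, 116, 40, 40]),
    paddle.rand([1, 232, 20, 20]),
    paddle.rand([1, 464, 10, 10]),
]

# num_features=4 appends an extra stride-64 level built from the two
# top convs (first_top_conv on the last input, second_top_conv on the
# last PAN output).
neck = CSPPAN(in_channels=[116, 232, 464], out_channels=96, num_features=4)

for feat in neck(feats):
    print(feat.shape)
# [1, 96, 40, 40], [1, 96, 20, 20], [1, 96, 10, 10], [1, 96, 5, 5]
```

Note that `Channel_T` first projects every input scale to `out_channels`, so both the top-down and bottom-up fusions work on equal-width features and can use concat + `CSPLayer` instead of element-wise sums, which is the CSP-PAN design this patch advertises.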