Unverified commit 35388eb7, authored by Guanghua Yu, committed by GitHub

Update PicoDet and GFl (#4044)

* Update PicoDet and GFl
Parent 3b564170
......@@ -2,21 +2,22 @@
## Introduction
[Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection](https://arxiv.org/abs/2006.04388) and [Generalized Focal Loss V2](https://arxiv.org/pdf/2011.12885.pdf)
We reproduce the object detection results of the papers [Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection](https://arxiv.org/abs/2006.04388) and [Generalized Focal Loss V2](https://arxiv.org/pdf/2011.12885.pdf). We also use a better-performing pre-trained model and the ResNet-vd structure to improve mAP.
## Model Zoo
| Backbone | Model | images/GPU | lr schedule |FPS | Box AP | download | config |
| Backbone | Model | batch-size/GPU | lr schedule |FPS | Box AP | download | config |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| ResNet50-FPN | GFL | 2 | 1x | ---- | 40.1 | [download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r50_fpn_1x_coco.yml) |
| ResNet50-FPN | GFLv2 | 2 | 1x | ---- | 40.4 | [download](https://paddledet.bj.bcebos.com/models/gflv2_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gflv2_r50_fpn_1x_coco.yml) |
| ResNet50 | GFL | 2 | 1x | ---- | 41.0 | [model](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r50_fpn_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r50_fpn_1x_coco.yml) |
| ResNet101-vd | GFL | 2 | 2x | ---- | 46.8 | [model](https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r101vd_fpn_mstrain_2x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r101vd_fpn_mstrain_2x_coco.yml) |
| ResNet34-vd | GFL | 2 | 1x | ---- | 40.8 | [model](https://paddledet.bj.bcebos.com/models/gfl_r34vd_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r34vd_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r34vd_1x_coco.yml) |
| ResNet18-vd | GFL | 2 | 1x | ---- | 36.6 | [model](https://paddledet.bj.bcebos.com/models/gfl_r18vd_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gfl_r18vd_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gfl_r18vd_1x_coco.yml) |
| ResNet50 | GFLv2 | 2 | 1x | ---- | 41.2 | [model](https://paddledet.bj.bcebos.com/models/gflv2_r50_fpn_1x_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_gflv2_r50_fpn_1x_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gfl/gflv2_r50_fpn_1x_coco.yml) |
**Notes:**
- GFL is trained on the COCO train2017 dataset and evaluated on val2017; results are reported as `mAP(IoU=0.5:0.95)`.
- GFL is trained on the COCO train2017 dataset with 8 GPUs and evaluated on val2017; results are reported as `mAP(IoU=0.5:0.95)`.
## Citations
```
......
......@@ -3,8 +3,8 @@ TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......@@ -19,8 +19,8 @@ TrainReader:
EvalReader:
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......@@ -31,8 +31,8 @@ EvalReader:
TestReader:
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......
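The reader hunks above move `Resize` ahead of `NormalizeImage` but keep the normalization parameters unchanged. As a rough illustration of what those parameters compute for a single RGB pixel, here is a minimal sketch. It assumes `is_scale: true` means "divide by 255 before subtracting the mean", and the helper name `normalize_pixel` is hypothetical, not part of PaddleDetection.

```python
# Per-channel statistics copied from the NormalizeImage entries in the diff.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=MEAN, std=STD, is_scale=True):
    """Normalize one (R, G, B) pixel given as values in the 0..255 range.

    Assumption: is_scale=True rescales the pixel to 0..1 first.
    """
    scale = 255.0 if is_scale else 1.0
    return tuple((c / scale - m) / s for c, m, s in zip(rgb, mean, std))


if __name__ == "__main__":
    # A pure white and a pure black pixel, channel by channel.
    print(normalize_pixel((255, 255, 255)))
    print(normalize_pixel((0, 0, 0)))
```

After this step the values are floating point and roughly zero-centred, whereas `Resize` placed earlier in the list operates on the decoded image.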
......@@ -7,7 +7,7 @@ LearningRate:
gamma: 0.1
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.1
start_factor: 0.001
steps: 500
OptimizerBuilder:
......
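The optimizer hunk above lowers the warmup `start_factor` from 0.1 to 0.001, so training starts from a much smaller learning rate. The sketch below shows how such a schedule could be evaluated. It assumes the warmup factor is interpolated linearly from `start_factor` to 1 over `steps` iterations and that `PiecewiseDecay` multiplies by `gamma` at each milestone epoch; `base_lr=0.01` is taken from the r101vd config later in this diff, and the function name is hypothetical.

```python
def learning_rate(step, epoch, base_lr=0.01, start_factor=0.001,
                  warmup_steps=500, milestones=(8, 11), gamma=0.1):
    """Illustrative schedule: linear warmup by step, piecewise decay by epoch."""
    if step < warmup_steps:
        # Assumed interpolation: factor goes from start_factor to 1.0.
        alpha = step / warmup_steps
        factor = start_factor * (1.0 - alpha) + alpha
        return base_lr * factor
    # Multiply by gamma once for every milestone already reached.
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


if __name__ == "__main__":
    for step, epoch in [(0, 0), (250, 0), (500, 0), (100000, 8), (100000, 11)]:
        print(step, epoch, learning_rate(step, epoch))
```

With the old value `start_factor=0.1`, the first iteration would run at a tenth of `base_lr` instead of a thousandth.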
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/gfl_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/gfl_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams
weights: output/gfl_r101vd_fpn_mstrain_2x_coco/model_final
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
ResNet:
depth: 101
variant: d
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
epoch: 24
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [16, 22]
- !LinearWarmup
start_factor: 0.001
steps: 500
TrainReader:
sample_transforms:
- Decode: {}
- RandomResize: {target_size: [[480, 1333], [512, 1333], [544, 1333], [576, 1333], [608, 1333], [640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True}
- RandomFlip: {prob: 0.5}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- Gt2GFLTarget:
downsample_ratios: [8, 16, 32, 64, 128]
grid_cell_scale: 8
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/gfl_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/gfl_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams
weights: output/gfl_r18vd_1x_coco/model_final
find_unused_parameters: True
ResNet:
depth: 18
variant: d
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/gfl_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/gfl_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet34_vd_pretrained.pdparams
weights: output/gfl_r34vd_1x_coco/model_final
find_unused_parameters: True
ResNet:
depth: 34
variant: d
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
......@@ -6,5 +6,5 @@ _BASE_: [
'_base_/gfl_reader.yml',
]
weights: output/gfl_r50_fpn_1x_coco/model_final
weights: output/gflv2_r50_fpn_1x_coco/model_final
find_unused_parameters: True
# PicoDet
![](../../docs/images/picedet_demo.jpeg)
## Introduction
We developed a series of mobile models named `PicoDet`.
Optimization methods we use:
- [Generalized Focal Loss V2](https://arxiv.org/pdf/2011.12885.pdf)
- Lr Cosine Decay
We developed a series of lightweight models named `PicoDet`. Because of its excellent performance, it is well suited for deployment on mobile devices or CPUs.
Optimization methods we use:
- [ATSS](https://arxiv.org/abs/1912.02424)
- [Generalized Focal Loss](https://arxiv.org/abs/2006.04388)
- Lr Cosine Decay and cycle-EMA
- lightweight head
## Requirements
- PaddlePaddle == 2.1.2
- PaddleSlim >= 2.1.1
## Coming soon
- [ ] Benchmark of PicoDet.
- [ ] Deployment for most platforms, such as PaddleLite, MNN, ncnn, and OpenVINO.
- [ ] PicoDet-XS and PicoDet-L series of models.
- [ ] Slim for PicoDet.
- [ ] More features in need.
## Model Zoo
### PicoDet-S
| Backbone | Input size | images/GPU | lr schedule |Box AP | FLOPS | Inference Time | download | config |
| :------------------------ | :-------: | :-------: | :-----------: | :---: | :-----: | :-----: | :-------------------------------------------------: | :-----: |
| ShuffleNetv2-1x | 320*320 | 128 | 280e | 21.9 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_320_coco.yml) |
| MobileNetv3-large-0.5x | 320*320 | 128 | 280e | 20.4 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_s_mbv3_320_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_mbv3_320_coco.yml) |
| ShuffleNetv2-1x | 416*416 | 96 | 280e | 24.0 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_416_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_416_coco.yml) |
| MobileNetv3-large-0.5x | 416*416 | 96 | 280e | 23.3 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_s_mbv3_416_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_mbv3_416_coco.yml) |
| Backbone | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config |
| :------------------------ | :-------: | :-------: | :------: | :---: | :---: | :---: | :------------: | :-------------------------------------------------: | :-----: |
| ShuffleNetv2-1x | 320*320 | 280e | 22.3 | 36.8 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_320_coco.yml) |
| ShuffleNetv2-1x | 416*416 | 280e | 24.6 | 44.3 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_416_coco.yml) |
### PicoDet-M
| Backbone | Input size | images/GPU | lr schedule |Box AP | FLOPS | Inference Time | download | config |
| :------------------------ | :-------: | :-------: | :-----------: | :---: | :-----: | :-----: | :-------------------------------------------------: | :-----: |
| ShuffleNetv2-1.5x | 320*320 | 128 | 280e | 24.9 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_320_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_320_coco.yml) |
| MobileNetv3-large-1x | 320*320 | 128 | 280e | 26.4 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_320_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_320_coco.yml) |
| ShuffleNetv2-1.5x | 416*416 | 128 | 280e | 27.4 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_416_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_416_coco.yml) |
| MobileNetv3-large-1x | 416*416 | 128 | 280e | 29.2 | -- | -- | [download](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_416_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_416_coco.yml) |
| Backbone | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config |
| :------------------------ | :-------: | :-------: | :-----------: | :---: | :---: | :---: | :-----: | :-------------------------------------------------: | :-----: |
| ShuffleNetv2-1.5x | 320*320 | 280e | 25.3 | 41.2 | -- | 8.1M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_shufflenetv2_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_320_coco.yml) |
| MobileNetv3-large-1x | 320*320 | 280e | 26.7 | 44.1 | -- | 11.6M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_mbv3_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_320_coco.yml) |
| ShuffleNetv2-1.5x | 416*416 | 280e | 28.0 | 44.3 | -- | 8.1M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_shufflenetv2_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_416_coco.yml) |
| MobileNetv3-large-1x | 416*416 | 280e | 29.3 | 47.2 | -- | 11.6M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_mbv3_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_416_coco.yml) |
**Notes:**
- PicoDet inference speed is tested on a Kirin 980 with 4 threads, on ARMv8 with FP16.
- PicoDet is trained on the COCO train2017 dataset and evaluated on val2017; results are reported as `mAP(IoU=0.5:0.95)`.
- PicoDet is trained on the COCO train2017 dataset and evaluated on val2017.
- PicoDet was trained on 4 GPUs with a mini-batch size of 128 or 96 per GPU.
## Citations
```
@article{li2020gflv2,
title={Generalized Focal Loss V2: Learning Reliable Localization Quality Estimation for Dense Object Detection},
author={Li, Xiang and Wang, Wenhai and Hu, Xiaolin and Li, Jun and Tang, Jinhui and Yang, Jian},
journal={arXiv preprint arXiv:2011.12885},
@article{li2020generalized,
title={Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection},
author={Li, Xiang and Wang, Wenhai and Wu, Lijun and Chen, Shuo and Hu, Xiaolin and Li, Jun and Tang, Jinhui and Yang, Jian},
journal={arXiv preprint arXiv:2006.04388},
year={2020}
}
......
......@@ -2,11 +2,11 @@ worker_num: 6
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomCrop: {}
- RandomFlip: {prob: 0.5}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {target_size: [320, 320], keep_ratio: False, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- RandomDistort: {}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......@@ -22,8 +22,8 @@ TrainReader:
EvalReader:
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [320, 320], keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......@@ -36,8 +36,8 @@ TestReader:
image_shape: [3, 320, 320]
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [320, 320], keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......
......@@ -2,11 +2,11 @@ worker_num: 6
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomCrop: {}
- RandomFlip: {prob: 0.5}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {target_size: [416, 416], keep_ratio: False, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- RandomDistort: {}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......@@ -22,8 +22,8 @@ TrainReader:
EvalReader:
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [416, 416], keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......@@ -36,8 +36,8 @@ TestReader:
image_shape: [3, 416, 416]
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [416, 416], keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
......
architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams
PicoDet:
backbone: MobileNetV3
......@@ -8,13 +8,13 @@ PicoDet:
MobileNetV3:
model_name: large
scale: 0.5
scale: 1.0
with_extra_blocks: false
extra_block_filters: []
feature_maps: [7, 13, 16]
PAN:
out_channel: 96
out_channel: 128
start_level: 0
end_level: 3
spatial_scales: [0.125, 0.0625, 0.03125]
......@@ -22,13 +22,13 @@ PAN:
PicoHead:
conv_feat:
name: PicoFeat
feat_in: 96
feat_out: 96
feat_in: 128
feat_out: 128
num_convs: 2
norm_type: bn
share_cls_reg: True
fpn_stride: [8, 16, 32]
feat_in_chan: 96
feat_in_chan: 128
prior_prob: 0.01
reg_max: 7
cell_offset: 0.5
......
......@@ -10,7 +10,7 @@ weights: output/picodet_m_r18_320_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
cycle_epoch: 40
snapshot_epoch: 10
PicoDet:
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/picodet_mbv3_0_5x.yml',
'_base_/picodet_mobilenetv3.yml',
'_base_/optimizer_280e.yml',
'_base_/picodet_320_reader.yml',
]
weights: output/picodet_m_mbv3_320_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
cycle_epoch: 40
snapshot_epoch: 10
MobileNetV3:
model_name: large
scale: 1.0
with_extra_blocks: false
extra_block_filters: []
feature_maps: [7, 13, 16]
PAN:
out_channel: 128
start_level: 0
end_level: 3
spatial_scales: [0.125, 0.0625, 0.03125]
PicoHead:
conv_feat:
name: PicoFeat
feat_in: 128
feat_out: 128
num_convs: 2
norm_type: bn
share_cls_reg: True
feat_in_chan: 128
TrainReader:
batch_size: 88
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/picodet_mbv3_0_5x.yml',
'_base_/picodet_mobilenetv3.yml',
'_base_/optimizer_280e.yml',
'_base_/picodet_416_reader.yml',
]
weights: output/picodet_m_mbv3_320_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams
weights: output/picodet_m_mbv3_416_coco/model_final
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
cycle_epoch: 40
snapshot_epoch: 10
MobileNetV3:
model_name: large
scale: 1.0
with_extra_blocks: false
extra_block_filters: []
feature_maps: [7, 13, 16]
PAN:
out_channel: 128
start_level: 0
end_level: 3
spatial_scales: [0.125, 0.0625, 0.03125]
PicoHead:
conv_feat:
name: PicoFeat
feat_in: 128
feat_out: 128
num_convs: 2
norm_type: bn
share_cls_reg: True
feat_in_chan: 128
TrainReader:
batch_size: 56
......
......@@ -6,11 +6,11 @@ _BASE_: [
'_base_/picodet_320_reader.yml',
]
weights: output/picodet_s_shufflenetv2_320_coco/model_final
weights: output/picodet_m_shufflenetv2_320_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_5_pretrained.pdparams
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
cycle_epoch: 40
snapshot_epoch: 10
ShuffleNetV2:
......
......@@ -6,11 +6,11 @@ _BASE_: [
'_base_/picodet_416_reader.yml',
]
weights: output/picodet_s_shufflenetv2_320_coco/model_final
weights: output/picodet_m_shufflenetv2_416_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_5_pretrained.pdparams
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
cycle_epoch: 40
snapshot_epoch: 10
ShuffleNetV2:
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/picodet_mbv3_0_5x.yml',
'_base_/optimizer_280e.yml',
'_base_/picodet_320_reader.yml',
]
weights: output/picodet_s_mbv3_320_coco/model_final
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
snapshot_epoch: 10
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/picodet_mbv3_0_5x.yml',
'_base_/optimizer_280e.yml',
'_base_/picodet_416_reader.yml',
]
weights: output/picodet_s_mbv3_320_coco/model_final
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
snapshot_epoch: 10
......@@ -9,5 +9,5 @@ _BASE_: [
weights: output/picodet_s_shufflenetv2_320_coco/model_final
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
cycle_epoch: 40
snapshot_epoch: 10
......@@ -6,8 +6,8 @@ _BASE_: [
'_base_/picodet_416_reader.yml',
]
weights: output/picodet_s_shufflenetv2_320_coco/model_final
weights: output/picodet_s_shufflenetv2_416_coco/model_final
find_unused_parameters: True
use_ema: true
ema_decay: 0.9998
cycle_epoch: 40
snapshot_epoch: 10
pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_s_mbv3_320_coco.pdparams
pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams
slim: QAT
QAT:
......@@ -9,15 +9,18 @@ QAT:
'quantizable_layer_type': ['Conv2D', 'Linear']}
print_model: False
epoch: 50
epoch: 80
LearningRate:
base_lr: 0.0001
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 35
- 45
- 60
- 70
- !LinearWarmup
start_factor: 0.
steps: 1000
steps: 100
TrainReader:
batch_size: 96
......@@ -90,8 +90,13 @@ class Trainer(object):
self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
if self.use_ema:
ema_decay = self.cfg.get('ema_decay', 0.9998)
cycle_epoch = self.cfg.get('cycle_epoch', -1)
self.ema = ModelEMA(
cfg['ema_decay'], self.model, use_thres_step=True)
self.model,
decay=ema_decay,
use_thres_step=True,
cycle_epoch=cycle_epoch)
# EvalDataset build with BatchSampler to evaluate in single device
# TODO: multi-device evaluate
......@@ -547,8 +552,9 @@ class Trainer(object):
if image_shape is None:
image_shape = [3, -1, -1]
self.model.eval()
if hasattr(self.model, 'deploy'): self.model.deploy = True
if hasattr(self.cfg, 'lite_deploy'):
self.model.lite_deploy = self.cfg.lite_deploy
# Save infer cfg
_dump_infer_config(self.cfg,
......
......@@ -41,7 +41,7 @@ class PicoDet(BaseArch):
self.backbone = backbone
self.neck = neck
self.head = head
self.deploy = False
self.lite_deploy = False
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......@@ -63,7 +63,7 @@ class PicoDet(BaseArch):
body_feats = self.backbone(self.inputs)
fpn_feats = self.neck(body_feats)
head_outs = self.head(fpn_feats)
if self.training or self.deploy:
if self.training or self.lite_deploy:
return head_outs
else:
im_shape = self.inputs['im_shape']
......@@ -83,7 +83,7 @@ class PicoDet(BaseArch):
return loss
def get_pred(self):
if self.deploy:
if self.lite_deploy:
return {'picodet': self._forward()[0]}
else:
bbox_pred, bbox_num = self._forward()
......
......@@ -245,6 +245,9 @@ class GFLHead(nn.Layer):
if self.dgqp_module:
quality_score = self.dgqp_module(bbox_reg)
cls_logits = F.sigmoid(cls_logits) * quality_score
if not self.training:
cls_logits = F.sigmoid(cls_logits.transpose([0, 2, 3, 1]))
bbox_reg = bbox_reg.transpose([0, 2, 3, 1])
cls_logits_list.append(cls_logits)
bboxes_reg_list.append(bbox_reg)
......@@ -288,6 +291,11 @@ class GFLHead(nn.Layer):
bbox_targets_list = self._images_to_levels(gt_meta['bbox_targets'],
num_level_anchors)
num_total_pos = sum(gt_meta['pos_num'])
try:
num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
)) / paddle.distributed.get_world_size()
except:
num_total_pos = max(num_total_pos, 1)
loss_bbox_list, loss_dfl_list, loss_qfl_list, avg_factor = [], [], [], []
for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip(
......@@ -316,7 +324,7 @@ class GFLHead(nn.Layer):
weight_targets = F.sigmoid(cls_score.detach())
weight_targets = paddle.gather(
weight_targets.max(axis=1), pos_inds, axis=0)
weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)
pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)
pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers,
pos_bbox_pred_corners)
......@@ -333,20 +341,18 @@ class GFLHead(nn.Layer):
# regression loss
loss_bbox = paddle.sum(
self.loss_bbox(pos_decode_bbox_pred,
pos_decode_bbox_targets) *
weight_targets.mean(axis=-1))
pos_decode_bbox_targets) * weight_targets)
# dfl loss
loss_dfl = self.loss_dfl(
pred_corners,
target_corners,
weight=weight_targets.unsqueeze(-1).expand([-1, 4]).reshape(
[-1]),
weight=weight_targets.expand([-1, 4]).reshape([-1]),
avg_factor=4.0)
else:
loss_bbox = bbox_pred.sum() * 0
loss_dfl = bbox_pred.sum() * 0
weight_targets = paddle.to_tensor([0])
weight_targets = paddle.to_tensor([0], dtype='float32')
# qfl loss
score = paddle.to_tensor(score)
......@@ -360,6 +366,12 @@ class GFLHead(nn.Layer):
avg_factor.append(weight_targets.sum())
avg_factor = sum(avg_factor)
try:
avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
avg_factor = paddle.clip(
avg_factor / paddle.distributed.get_world_size(), min=1)
except:
avg_factor = max(avg_factor.item(), 1)
if avg_factor <= 0:
loss_qfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
loss_bbox = paddle.to_tensor(
......@@ -407,14 +419,11 @@ class GFLHead(nn.Layer):
mlvl_scores = []
for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores,
bbox_preds):
featmap_size = cls_score.shape[-2:]
featmap_size = cls_score.shape[:2]
y, x = self.get_single_level_center_point(
featmap_size, stride, cell_offset=cell_offset)
center_points = paddle.stack([x, y], axis=-1)
scores = F.sigmoid(
cls_score.transpose([1, 2, 0]).reshape(
[-1, self.cls_out_channels]))
bbox_pred = bbox_pred.transpose([1, 2, 0])
scores = cls_score.reshape([-1, self.cls_out_channels])
bbox_pred = self.distribution_project(bbox_pred) * stride
if scores.shape[0] > self.nms_pre:
......@@ -434,10 +443,6 @@ class GFLHead(nn.Layer):
im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]])
mlvl_bboxes /= im_scale
mlvl_scores = paddle.concat(mlvl_scores)
if self.use_sigmoid:
# add a dummy background class to the backend when use_sigmoid
padding = paddle.zeros([mlvl_scores.shape[0], 1])
mlvl_scores = paddle.concat([mlvl_scores, padding], axis=1)
mlvl_scores = mlvl_scores.transpose([1, 0])
return mlvl_bboxes, mlvl_scores
......
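In the `GFLHead` hunk above, `self.distribution_project(bbox_pred) * stride` turns the per-side bin logits into a distance. A minimal scalar sketch of that idea follows, assuming the projection is a softmax over `reg_max + 1` bins followed by the expected bin index, as described in the Generalized Focal Loss paper. The function name mirrors the diff, but this standalone version is illustrative and is not the layer used in the repository.

```python
import math


def distribution_project(logits, reg_max=7):
    """Expected bin index of a discrete distribution given by raw logits."""
    if len(logits) != reg_max + 1:
        raise ValueError("expected reg_max + 1 logits per box side")
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return sum(i * e / total for i, e in enumerate(exps))


if __name__ == "__main__":
    stride = 8
    uniform = [0.0] * 8                      # no preference: mean of 0..7
    peaked = [-50.0] * 8
    peaked[2] = 50.0                         # almost all mass on bin 2
    print(distribution_project(uniform) * stride)
    print(distribution_project(peaked) * stride)
```

Multiplying by the FPN stride converts the result from feature-map units to input-image pixels, which is why the hunk applies `* stride` right after the projection.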
......@@ -66,7 +66,7 @@ class PicoFeat(nn.Layer):
ConvNormLayer(
ch_in=in_c,
ch_out=feat_out,
filter_size=3,
filter_size=5,
stride=1,
groups=feat_out,
norm_type=norm_type,
......@@ -91,7 +91,7 @@ class PicoFeat(nn.Layer):
ConvNormLayer(
ch_in=in_c,
ch_out=feat_out,
filter_size=3,
filter_size=5,
stride=1,
groups=feat_out,
norm_type=norm_type,
......@@ -249,80 +249,9 @@ class PicoHead(GFLHead):
if not self.training:
cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
bbox_pred = self.distribution_project(
bbox_pred.transpose([0, 2, 3, 1])) * self.fpn_stride[i]
bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
cls_logits_list.append(cls_score)
bboxes_reg_list.append(bbox_pred)
return (cls_logits_list, bboxes_reg_list)
def get_bboxes_single(self,
cls_scores,
bbox_preds,
img_shape,
scale_factor,
rescale=True,
cell_offset=0):
assert len(cls_scores) == len(bbox_preds)
mlvl_bboxes = []
mlvl_scores = []
for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores,
bbox_preds):
featmap_size = cls_score.shape[0:2]
y, x = self.get_single_level_center_point(
featmap_size, stride, cell_offset=cell_offset)
center_points = paddle.stack([x, y], axis=-1)
scores = cls_score.reshape([-1, self.cls_out_channels])
if scores.shape[0] > self.nms_pre:
max_scores = scores.max(axis=1)
_, topk_inds = max_scores.topk(self.nms_pre)
center_points = center_points.gather(topk_inds)
bbox_pred = bbox_pred.gather(topk_inds)
scores = scores.gather(topk_inds)
bboxes = distance2bbox(
center_points, bbox_pred, max_shape=img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_bboxes = paddle.concat(mlvl_bboxes)
if rescale:
# [h_scale, w_scale] to [w_scale, h_scale, w_scale, h_scale]
im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]])
mlvl_bboxes /= im_scale
mlvl_scores = paddle.concat(mlvl_scores)
mlvl_scores = mlvl_scores.transpose([1, 0])
return mlvl_bboxes, mlvl_scores
def decode(self, cls_scores, bbox_preds, im_shape, scale_factor,
cell_offset):
batch_bboxes = []
batch_scores = []
batch_size = cls_scores[0].shape[0]
for img_id in range(batch_size):
num_levels = len(cls_scores)
cls_score_list = [cls_scores[i][img_id] for i in range(num_levels)]
bbox_pred_list = [
bbox_preds[i].reshape([batch_size, -1, 4])[img_id]
for i in range(num_levels)
]
bboxes, scores = self.get_bboxes_single(
cls_score_list,
bbox_pred_list,
im_shape[img_id],
scale_factor[img_id],
cell_offset=cell_offset)
batch_bboxes.append(bboxes)
batch_scores.append(scores)
batch_bboxes = paddle.stack(batch_bboxes, axis=0)
batch_scores = paddle.stack(batch_scores, axis=0)
return batch_bboxes, batch_scores
def post_process(self, gfl_head_outs, im_shape, scale_factor):
cls_scores, bboxes_reg = gfl_head_outs
bboxes, score = self.decode(cls_scores, bboxes_reg, im_shape,
scale_factor, self.cell_offset)
bbox_pred, bbox_num, _ = self.nms(bboxes, score)
return bbox_pred, bbox_num
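The decoding code in this hunk relies on `distance2bbox(center_points, bbox_pred, max_shape=img_shape)` to turn a centre point and four side distances into a box. The sketch below shows that conversion for a single point. It assumes the distances are ordered (left, top, right, bottom) and that `max_shape` is (height, width); the helper name `decode_box` is hypothetical.

```python
def decode_box(point, distances, max_shape=None):
    """Convert a centre point and side distances to (x1, y1, x2, y2)."""
    cx, cy = point
    left, top, right, bottom = distances  # assumed ordering
    x1, y1 = cx - left, cy - top
    x2, y2 = cx + right, cy + bottom
    if max_shape is not None:
        height, width = max_shape  # assumed (h, w) ordering
        # Clip the box so it stays inside the image.
        x1 = min(max(x1, 0), width)
        x2 = min(max(x2, 0), width)
        y1 = min(max(y1, 0), height)
        y2 = min(max(y2, 0), height)
    return (x1, y1, x2, y2)


if __name__ == "__main__":
    print(decode_box((10, 20), (3, 4, 5, 6)))
    print(decode_box((10, 20), (3, 4, 5, 6), max_shape=(20, 12)))
```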
......@@ -39,21 +39,13 @@ class PAN(nn.Layer):
spatial_scales (list[float]): the spatial scales between input feature
maps and original input image which can be derived from the output
shape of backbone by from_config
has_extra_convs (bool): whether to add extra conv to the last level.
default False
extra_stage (int): the number of extra stages added to the last level.
default 1
use_c5 (bool): Whether to use c5 as the input of extra stage,
otherwise p5 is used. default True
start_level (int): Index of the start input backbone level used to
build the feature pyramid. Default: 0.
end_level (int): Index of the end input backbone level (exclusive) to
build the feature pyramid. Default: -1, which means the last level.
norm_type (string|None): The normalization type in FPN module. If
norm_type is None, norm will not be used after conv and if
norm_type is string, bn, gn, sync_bn are available. default None
norm_decay (float): weight decay for normalization layer weights.
default 0.
freeze_norm (bool): whether to freeze normalization layer.
default False
relu_before_extra_convs (bool): whether to add relu before extra convs.
default False
"""
def __init__(self,
......
......@@ -251,13 +251,40 @@ class OptimizerBuilder():
class ModelEMA(object):
def __init__(self, decay, model, use_thres_step=False):
"""
Exponential Weighted Average for Deep Neural Networks
Args:
model (nn.Layer): The detector model.
decay (float): The decay used for updating the EMA parameters.
The EMA parameters are updated with the formula:
`ema_param = decay * ema_param + (1 - decay) * cur_param`.
Default is 0.9998.
use_thres_step (bool): Whether to set decay by thres_step or not.
cycle_epoch (int): The interval, in epochs, at which ema_param and
step are reset. Default is -1, which means no reset. It adds a
regularizing effect to EMA; the value is chosen empirically and
is effective when the total number of training epochs is large.
"""
def __init__(self,
model,
decay=0.9998,
use_thres_step=False,
cycle_epoch=-1):
self.step = 0
self.epoch = 0
self.decay = decay
self.state_dict = dict()
for k, v in model.state_dict().items():
self.state_dict[k] = paddle.zeros_like(v)
self.use_thres_step = use_thres_step
self.cycle_epoch = cycle_epoch
def reset(self):
self.step = 0
self.epoch = 0
for k, v in self.state_dict.items():
self.state_dict[k] = paddle.zeros_like(v)
def update(self, model):
if self.use_thres_step:
......@@ -280,4 +307,8 @@ class ModelEMA(object):
v = v / (1 - self._decay**self.step)
v.stop_gradient = True
state_dict[k] = v
self.epoch += 1
if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
self.reset()
return state_dict
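To make the cycle-EMA behaviour in the `ModelEMA` hunk concrete, the sketch below tracks a single scalar instead of a model `state_dict`. The update formula, the bias correction `v / (1 - decay**step)` and the reset at `cycle_epoch` follow the diff. The thresholded decay `min(decay, (1 + step) / (10 + step))` is an assumption, because the body of `update` is elided in the hunk, and the class name `ScalarEMA` is hypothetical.

```python
class ScalarEMA:
    """Scalar stand-in for ModelEMA with a periodic (cycle) reset."""

    def __init__(self, decay=0.9998, use_thres_step=False, cycle_epoch=-1):
        self.decay = decay
        self.use_thres_step = use_thres_step
        self.cycle_epoch = cycle_epoch
        self.reset()

    def reset(self):
        # Mirrors ModelEMA.reset(): counters and averaged values go to zero.
        self.step = 0
        self.epoch = 0
        self.value = 0.0
        self._decay = self.decay

    def update(self, param):
        if self.use_thres_step:
            # Assumed warm-up of the decay during the first steps.
            self._decay = min(self.decay, (1 + self.step) / (10 + self.step))
        else:
            self._decay = self.decay
        self.value = self._decay * self.value + (1 - self._decay) * param
        self.step += 1

    def apply(self):
        """Return the bias-corrected average; called once per epoch."""
        if self.step == 0:
            return self.value  # nothing accumulated yet
        corrected = self.value / (1 - self._decay ** self.step)
        self.epoch += 1
        if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
            self.reset()
        return corrected


if __name__ == "__main__":
    ema = ScalarEMA(decay=0.5, cycle_epoch=2)
    for epoch in range(4):
        ema.update(1.0)
        print(epoch, ema.apply(), ema.step)
```

With `cycle_epoch: 40`, as set in the PicoDet configs of this commit, the averaged weights would restart from zero every 40 calls to `apply`, assuming `apply` is invoked once per epoch.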