未验证 提交 7d625608 编写于 作者: G Guanghua Yu 提交者: GitHub

update PicoDet Architecture and config (#4323)

* update PicoDet Architecture and config
上级 04d837cf
...@@ -32,7 +32,7 @@ GFLHead: ...@@ -32,7 +32,7 @@ GFLHead:
fpn_stride: [8, 16, 32, 64, 128] fpn_stride: [8, 16, 32, 64, 128]
prior_prob: 0.01 prior_prob: 0.01
reg_max: 16 reg_max: 16
loss_qfl: loss_class:
name: QualityFocalLoss name: QualityFocalLoss
use_sigmoid: True use_sigmoid: True
beta: 2.0 beta: 2.0
......
...@@ -37,7 +37,7 @@ GFLHead: ...@@ -37,7 +37,7 @@ GFLHead:
reg_topk: 4 reg_topk: 4
reg_channels: 64 reg_channels: 64
add_mean: True add_mean: True
loss_qfl: loss_class:
name: QualityFocalLoss name: QualityFocalLoss
use_sigmoid: False use_sigmoid: False
beta: 2.0 beta: 2.0
......
...@@ -5,41 +5,29 @@ ...@@ -5,41 +5,29 @@
We developed a series of lightweight models, which named `PicoDet`. Because of its excellent performance, it is very suitable for deployment on mobile or CPU. We developed a series of lightweight models, which named `PicoDet`. Because of its excellent performance, it is very suitable for deployment on mobile or CPU.
Optimizing method of we use: - 🌟 Higher mAP: The **first** model which within 1M parameter with mAP reaching 30+.
- [ATSS](https://arxiv.org/abs/1912.02424) - 🚀 Faster latency: 114FPS on mobile ARM CPU.
- [Generalized Focal Loss](https://arxiv.org/abs/2006.04388) - 😊 Deploy friendly: support PaddleLite/MNN/NCNN/OpenVINO and provide C++/Python/Android implementation.
- Lr Cosine Decay and cycle-EMA - 😍 Advanced algorithm: use the most advanced algorithms and innovate, such as ESNet, CSP-PAN, SimOTA with VFL, etc.
- lightweight head
## Requirements ## Requirements
- PaddlePaddle == 2.1.2 - PaddlePaddle >= 2.1.2
- PaddleSlim >= 2.1.1 - PaddleSlim >= 2.1.1
## Comming soon ## Comming soon
- [ ] Benchmark of PicoDet. - [ ] More series of model, such as Smaller or larger model.
- [ ] deploy for most platforms, such as PaddleLite、MNN、ncnn、openvino etc. - [ ] Pretrained models for more scenarios.
- [ ] PicoDet-XS and PicoDet-L series of model.
- [ ] Slim for PicoDet.
- [ ] More features in need. - [ ] More features in need.
## Model Zoo ## Model Zoo
### PicoDet-S | Model | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config |
| Backbone | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config |
| :------------------------ | :-------: | :-------: | :------: | :---: | :---: | :---: | :------------: | :-------------------------------------------------: | :-----: | | :------------------------ | :-------: | :-------: | :------: | :---: | :---: | :---: | :------------: | :-------------------------------------------------: | :-----: |
| ShuffleNetv2-1x | 320*320 | 280e | 22.8 | 37.7 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_320_coco.yml) | | PicoDet-S | 320*320 | 300e | 27.1 | 41.4 | -- | 3.9M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_320_coco.yml) |
| ShuffleNetv2-1x | 416*416 | 280e | 25.3 | 41.1 | -- | 3.8M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_shufflenetv2_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_shufflenetv2_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_shufflenetv2_416_coco.yml) | | PicoDet-S | 416*416 | 300e | 30.6 | 45.5 | -- | 3.9M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_s_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco.yml) |
| PicoDet-M | 320*320 | 300e | - | 41.2 | -- | 8.4M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_320_coco.yml) |
| PicoDet-M | 416*416 | 300e | 34.3 | 49.8 | -- | 8.4M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_416_coco.yml) |
### PicoDet-M
| Backbone | Input size | lr schedule | Box AP(0.5:0.95) | Box AP(0.5) | FLOPS | Model Size | Inference Time | download | config |
| :------------------------ | :-------: | :-------: | :-----------: | :---: | :---: | :---: | :-----: | :-------------------------------------------------: | :-----: |
| ShuffleNetv2-1.5x | 320*320 | 280e | 25.3 | 41.2 | -- | 8.1M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_shufflenetv2_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_320_coco.yml) |
| MobileNetv3-large-1x | 320*320 | 280e | 26.7 | 44.1 | -- | 11.6M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_mbv3_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_320_coco.yml) |
| ShuffleNetv2-1.5x | 416*416 | 280e | 28.0 | 44.3 | -- | 8.1M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_shufflenetv2_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_shufflenetv2_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_shufflenetv2_416_coco.yml) |
| MobileNetv3-large-1x | 416*416 | 280e | 29.3 | 47.2 | -- | 11.6M | -- | [model](https://paddledet.bj.bcebos.com/models/picodet_m_mbv3_416_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_mbv3_416_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_m_mbv3_416_coco.yml) |
**Notes:** **Notes:**
......
worker_num: 8 worker_num: 6
TrainReader: TrainReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
...@@ -9,13 +9,10 @@ TrainReader: ...@@ -9,13 +9,10 @@ TrainReader:
- BatchRandomResize: {target_size: [256, 288, 320, 352, 384], random_size: True, random_interp: True, keep_ratio: False} - BatchRandomResize: {target_size: [256, 288, 320, 352, 384], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
- Gt2GFLTarget:
downsample_ratios: [8, 16, 32]
grid_cell_scale: 5
cell_offset: 0.5
batch_size: 128 batch_size: 128
shuffle: true shuffle: true
drop_last: true drop_last: true
collate_batch: false
EvalReader: EvalReader:
...@@ -32,7 +29,7 @@ EvalReader: ...@@ -32,7 +29,7 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [3, 320, 320] image_shape: [1, 3, 320, 320]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [320, 320], keep_ratio: False} - Resize: {interp: 2, target_size: [320, 320], keep_ratio: False}
......
...@@ -9,13 +9,10 @@ TrainReader: ...@@ -9,13 +9,10 @@ TrainReader:
- BatchRandomResize: {target_size: [352, 384, 416, 448, 480], random_size: True, random_interp: True, keep_ratio: False} - BatchRandomResize: {target_size: [352, 384, 416, 448, 480], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {} - Permute: {}
- Gt2GFLTarget:
downsample_ratios: [8, 16, 32]
grid_cell_scale: 5
cell_offset: 0.5
batch_size: 80 batch_size: 80
shuffle: true shuffle: true
drop_last: true drop_last: true
collate_batch: false
EvalReader: EvalReader:
...@@ -32,7 +29,7 @@ EvalReader: ...@@ -32,7 +29,7 @@ EvalReader:
TestReader: TestReader:
inputs_def: inputs_def:
image_shape: [3, 416, 416] image_shape: [1, 3, 416, 416]
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False}
......
worker_num: 6
TrainReader:
sample_transforms:
- Decode: {}
- RandomCrop: {}
- RandomFlip: {prob: 0.5}
- RandomDistort: {}
batch_transforms:
- BatchRandomResize: {target_size: [576, 608, 640, 672, 704], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_size: 56
shuffle: true
drop_last: true
collate_batch: false
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [640, 640], keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 8
shuffle: false
TestReader:
inputs_def:
image_shape: [1, 3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [640, 640], keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
shuffle: false
architecture: PicoDet architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams
PicoDet: PicoDet:
backbone: MobileNetV3 backbone: ESNet
neck: PAN neck: CSPPAN
head: PicoHead head: PicoHead
MobileNetV3: ESNet:
model_name: large
scale: 1.0 scale: 1.0
with_extra_blocks: false feature_maps: [4, 11, 14]
extra_block_filters: [] act: hard_swish
feature_maps: [7, 13, 16] channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75]
PAN: CSPPAN:
out_channel: 128 out_channels: 128
start_level: 0 use_depthwise: True
end_level: 3 num_csp_blocks: 1
spatial_scales: [0.125, 0.0625, 0.03125] num_features: 4
PicoHead: PicoHead:
conv_feat: conv_feat:
name: PicoFeat name: PicoFeat
feat_in: 128 feat_in: 128
feat_out: 128 feat_out: 128
num_convs: 2 num_convs: 4
num_fpn_stride: 4
norm_type: bn norm_type: bn
share_cls_reg: True share_cls_reg: True
fpn_stride: [8, 16, 32] fpn_stride: [8, 16, 32, 64]
feat_in_chan: 128 feat_in_chan: 128
prior_prob: 0.01 prior_prob: 0.01
reg_max: 7 reg_max: 7
cell_offset: 0.5 cell_offset: 0.5
loss_qfl: loss_class:
name: QualityFocalLoss name: VarifocalLoss
use_sigmoid: True use_sigmoid: True
beta: 2.0 iou_weighted: True
loss_weight: 1.0 loss_weight: 1.0
loss_dfl: loss_dfl:
name: DistributionFocalLoss name: DistributionFocalLoss
...@@ -43,6 +43,10 @@ PicoHead: ...@@ -43,6 +43,10 @@ PicoHead:
loss_bbox: loss_bbox:
name: GIoULoss name: GIoULoss
loss_weight: 2.0 loss_weight: 2.0
assigner:
name: SimOTAAssigner
candidate_topk: 10
iou_weight: 6
nms: nms:
name: MultiClassNMS name: MultiClassNMS
nms_top_k: 1000 nms_top_k: 1000
......
_BASE_: [ _BASE_: [
'../datasets/coco_detection.yml', '../../datasets/coco_detection.yml',
'../runtime.yml', '../../runtime.yml',
'_base_/picodet_shufflenetv2_1x.yml', '../_base_/picodet_esnet.yml',
'_base_/optimizer_300e.yml', '../_base_/optimizer_300e.yml',
'_base_/picodet_320_reader.yml', '../_base_/picodet_416_reader.yml',
] ]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_0_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_0_pretrained.pdparams
weights: output/picodet_s_lcnet_320_coco/model_final weights: output/picodet_lcnet_416_coco/model_final
find_unused_parameters: True find_unused_parameters: True
use_ema: true use_ema: true
cycle_epoch: 40 cycle_epoch: 40
...@@ -15,9 +15,23 @@ snapshot_epoch: 10 ...@@ -15,9 +15,23 @@ snapshot_epoch: 10
PicoDet: PicoDet:
backbone: LCNet backbone: LCNet
neck: PAN neck: CSPPAN
head: PicoHead head: PicoHead
LCNet: LCNet:
scale: 1.0 scale: 1.0
feature_maps: [3, 4, 5] feature_maps: [3, 4, 5]
CSPPAN:
out_channels: 96
PicoHead:
conv_feat:
name: PicoFeat
feat_in: 96
feat_out: 96
num_convs: 2
num_fpn_stride: 4
norm_type: bn
share_cls_reg: True
feat_in_chan: 96
_BASE_: [
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'../_base_/picodet_esnet.yml',
'../_base_/optimizer_300e.yml',
'../_base_/picodet_416_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams
weights: output/picodet_mobilenetv3_416_coco/model_final
find_unused_parameters: True
use_ema: true
cycle_epoch: 40
snapshot_epoch: 10
PicoDet:
backbone: MobileNetV3
neck: CSPPAN
head: PicoHead
MobileNetV3:
model_name: large
scale: 1.0
with_extra_blocks: false
extra_block_filters: []
feature_maps: [7, 13, 16]
TrainReader:
batch_size: 56
LearningRate:
base_lr: 0.3
schedulers:
- !CosineDecay
max_epochs: 300
- !LinearWarmup
start_factor: 0.1
steps: 300
_BASE_: [
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'../_base_/picodet_esnet.yml',
'../_base_/optimizer_300e.yml',
'../_base_/picodet_640_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams
weights: output/picodet_r18_640_coco/model_final
find_unused_parameters: True
use_ema: true
cycle_epoch: 40
snapshot_epoch: 10
PicoDet:
backbone: ResNet
neck: CSPPAN
head: PicoHead
ResNet:
depth: 18
variant: d
return_idx: [1, 2, 3]
freeze_at: -1
freeze_norm: false
norm_decay: 0.
TrainReader:
batch_size: 56
LearningRate:
base_lr: 0.3
schedulers:
- !CosineDecay
max_epochs: 300
- !LinearWarmup
start_factor: 0.1
steps: 300
architecture: PicoDet _BASE_: [
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'../_base_/picodet_esnet.yml',
'../_base_/optimizer_300e.yml',
'../_base_/picodet_416_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_0_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_0_pretrained.pdparams
weights: output/picodet_shufflenetv2_416_coco/model_final
find_unused_parameters: True
use_ema: true
cycle_epoch: 40
snapshot_epoch: 10
PicoDet: PicoDet:
backbone: ShuffleNetV2 backbone: ShuffleNetV2
neck: PAN neck: CSPPAN
head: PicoHead head: PicoHead
ShuffleNetV2: ShuffleNetV2:
...@@ -11,11 +23,8 @@ ShuffleNetV2: ...@@ -11,11 +23,8 @@ ShuffleNetV2:
feature_maps: [5, 13, 17] feature_maps: [5, 13, 17]
act: leaky_relu act: leaky_relu
PAN: CSPPAN:
out_channel: 96 out_channels: 96
start_level: 0
end_level: 3
spatial_scales: [0.125, 0.0625, 0.03125]
PicoHead: PicoHead:
conv_feat: conv_feat:
...@@ -23,27 +32,7 @@ PicoHead: ...@@ -23,27 +32,7 @@ PicoHead:
feat_in: 96 feat_in: 96
feat_out: 96 feat_out: 96
num_convs: 2 num_convs: 2
num_fpn_stride: 4
norm_type: bn norm_type: bn
share_cls_reg: True share_cls_reg: True
fpn_stride: [8, 16, 32]
feat_in_chan: 96 feat_in_chan: 96
prior_prob: 0.01
reg_max: 7
cell_offset: 0.5
loss_qfl:
name: QualityFocalLoss
use_sigmoid: True
beta: 2.0
loss_weight: 1.0
loss_dfl:
name: DistributionFocalLoss
loss_weight: 0.25
loss_bbox:
name: GIoULoss
loss_weight: 2.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.025
nms_threshold: 0.6
_BASE_: [ _BASE_: [
'../datasets/coco_detection.yml', '../datasets/coco_detection.yml',
'../runtime.yml', '../runtime.yml',
'_base_/picodet_mbv3_0_5x.yml', '_base_/picodet_esnet.yml',
'_base_/optimizer_300e.yml', '_base_/optimizer_300e.yml',
'_base_/picodet_320_reader.yml', '_base_/picodet_320_reader.yml',
] ]
weights: output/picodet_l_r18_320_coco/model_final pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_25_pretrained.pdparams
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet18_vd_pretrained.pdparams weights: output/picodet_l_320_coco/model_final
find_unused_parameters: True find_unused_parameters: True
use_ema: true use_ema: true
cycle_epoch: 40 cycle_epoch: 40
snapshot_epoch: 10 snapshot_epoch: 10
PicoDet: ESNet:
backbone: ResNet scale: 1.25
neck: PAN feature_maps: [4, 11, 14]
head: PicoHead act: hard_swish
channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75]
ResNet:
depth: 18
variant: d
return_idx: [1, 2, 3]
freeze_at: -1
freeze_norm: false
norm_decay: 0.
PAN:
out_channel: 128
start_level: 0
end_level: 3
spatial_scales: [0.125, 0.0625, 0.03125]
PicoHead: PicoHead:
conv_feat: conv_feat:
name: PicoFeat name: PicoFeat
feat_in: 128 feat_in: 128
feat_out: 128 feat_out: 128
num_convs: 2 num_convs: 4
num_fpn_stride: 4
norm_type: bn norm_type: bn
share_cls_reg: True share_cls_reg: False
feat_in_chan: 128 feat_in_chan: 128
TrainReader: TrainReader:
...@@ -49,7 +37,7 @@ LearningRate: ...@@ -49,7 +37,7 @@ LearningRate:
base_lr: 0.3 base_lr: 0.3
schedulers: schedulers:
- !CosineDecay - !CosineDecay
max_epochs: 280 max_epochs: 300
- !LinearWarmup - !LinearWarmup
start_factor: 0.1 start_factor: 0.1
steps: 300 steps: 300
_BASE_: [ _BASE_: [
'../datasets/coco_detection.yml', '../datasets/coco_detection.yml',
'../runtime.yml', '../runtime.yml',
'_base_/picodet_mobilenetv3.yml', '_base_/picodet_esnet.yml',
'_base_/optimizer_300e.yml', '_base_/optimizer_300e.yml',
'_base_/picodet_416_reader.yml', '_base_/picodet_416_reader.yml',
] ]
weights: output/picodet_m_mbv3_416_coco/model_final pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_25_pretrained.pdparams
weights: output/picodet_l_416_coco/model_final
find_unused_parameters: True find_unused_parameters: True
use_ema: true use_ema: true
cycle_epoch: 40 cycle_epoch: 40
snapshot_epoch: 10 snapshot_epoch: 10
ESNet:
scale: 1.25
feature_maps: [4, 11, 14]
act: hard_swish
channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75]
PicoHead:
conv_feat:
name: PicoFeat
feat_in: 128
feat_out: 128
num_convs: 4
num_fpn_stride: 4
norm_type: bn
share_cls_reg: False
feat_in_chan: 128
TrainReader: TrainReader:
batch_size: 56 batch_size: 48
LearningRate: LearningRate:
base_lr: 0.3 base_lr: 0.3
schedulers: schedulers:
- !CosineDecay - !CosineDecay
max_epochs: 280 max_epochs: 300
- !LinearWarmup - !LinearWarmup
start_factor: 0.1 start_factor: 0.1
steps: 300 steps: 300
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/picodet_esnet.yml',
'_base_/optimizer_300e.yml',
'_base_/picodet_640_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_25_pretrained.pdparams
weights: output/picodet_l_640_coco/model_final
find_unused_parameters: True
use_ema: true
cycle_epoch: 40
snapshot_epoch: 10
ESNet:
scale: 1.25
feature_maps: [4, 11, 14]
act: hard_swish
channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75]
PicoHead:
conv_feat:
name: PicoFeat
feat_in: 128
feat_out: 128
num_convs: 4
num_fpn_stride: 4
norm_type: bn
share_cls_reg: False
feat_in_chan: 128
TrainReader:
batch_size: 48
LearningRate:
base_lr: 0.3
schedulers:
- !CosineDecay
max_epochs: 300
- !LinearWarmup
start_factor: 0.1
steps: 300
_BASE_: [ _BASE_: [
'../datasets/coco_detection.yml', '../datasets/coco_detection.yml',
'../runtime.yml', '../runtime.yml',
'_base_/picodet_shufflenetv2_1x.yml', '_base_/picodet_esnet.yml',
'_base_/optimizer_300e.yml', '_base_/optimizer_300e.yml',
'_base_/picodet_320_reader.yml', '_base_/picodet_320_reader.yml',
] ]
weights: output/picodet_s_shufflenetv2_320_coco/model_final weights: output/picodet_m_320_coco/model_final
find_unused_parameters: True find_unused_parameters: True
use_ema: true use_ema: true
cycle_epoch: 40 cycle_epoch: 40
......
_BASE_: [ _BASE_: [
'../datasets/coco_detection.yml', '../datasets/coco_detection.yml',
'../runtime.yml', '../runtime.yml',
'_base_/picodet_shufflenetv2_1x.yml', '_base_/picodet_esnet.yml',
'_base_/optimizer_300e.yml', '_base_/optimizer_300e.yml',
'_base_/picodet_416_reader.yml', '_base_/picodet_416_reader.yml',
] ]
weights: output/picodet_s_shufflenetv2_416_coco/model_final weights: output/picodet_m_416_coco/model_final
find_unused_parameters: True find_unused_parameters: True
use_ema: true use_ema: true
cycle_epoch: 40 cycle_epoch: 40
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/picodet_mobilenetv3.yml',
'_base_/optimizer_300e.yml',
'_base_/picodet_320_reader.yml',
]
weights: output/picodet_m_mbv3_320_coco/model_final
find_unused_parameters: True
use_ema: true
cycle_epoch: 40
snapshot_epoch: 10
TrainReader:
batch_size: 88
_BASE_: [ _BASE_: [
'../datasets/coco_detection.yml', '../datasets/coco_detection.yml',
'../runtime.yml', '../runtime.yml',
'_base_/picodet_shufflenetv2_1x.yml', '_base_/picodet_esnet.yml',
'_base_/optimizer_300e.yml', '_base_/optimizer_300e.yml',
'_base_/picodet_320_reader.yml', '_base_/picodet_320_reader.yml',
] ]
weights: output/picodet_m_shufflenetv2_320_coco/model_final pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x0_75_pretrained.pdparams
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_5_pretrained.pdparams weights: output/picodet_s_320_coco/model_final
find_unused_parameters: True find_unused_parameters: True
use_ema: true use_ema: true
cycle_epoch: 40 cycle_epoch: 40
snapshot_epoch: 10 snapshot_epoch: 10
ShuffleNetV2: ESNet:
scale: 1.5 scale: 0.75
feature_maps: [5, 13, 17] feature_maps: [4, 11, 14]
act: leaky_relu act: hard_swish
channel_ratio: [0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
PAN: CSPPAN:
out_channel: 128 out_channels: 96
start_level: 0
end_level: 3
spatial_scales: [0.125, 0.0625, 0.03125]
PicoHead: PicoHead:
conv_feat: conv_feat:
name: PicoFeat name: PicoFeat
feat_in: 128 feat_in: 96
feat_out: 128 feat_out: 96
num_convs: 2 num_convs: 2
num_fpn_stride: 4
norm_type: bn norm_type: bn
share_cls_reg: True share_cls_reg: True
feat_in_chan: 128 feat_in_chan: 96
_BASE_: [ _BASE_: [
'../datasets/coco_detection.yml', '../datasets/coco_detection.yml',
'../runtime.yml', '../runtime.yml',
'_base_/picodet_shufflenetv2_1x.yml', '_base_/picodet_esnet.yml',
'_base_/optimizer_300e.yml', '_base_/optimizer_300e.yml',
'_base_/picodet_416_reader.yml', '_base_/picodet_416_reader.yml',
] ]
weights: output/picodet_m_shufflenetv2_416_coco/model_final pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x0_75_pretrained.pdparams
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ShuffleNetV2_x1_5_pretrained.pdparams weights: output/picodet_s_416_coco/model_final
find_unused_parameters: True find_unused_parameters: True
use_ema: true use_ema: true
cycle_epoch: 40 cycle_epoch: 40
snapshot_epoch: 10 snapshot_epoch: 10
ShuffleNetV2: ESNet:
scale: 1.5 scale: 0.75
feature_maps: [5, 13, 17] feature_maps: [4, 11, 14]
act: leaky_relu act: hard_swish
channel_ratio: [0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
PAN: CSPPAN:
out_channel: 128 out_channels: 96
start_level: 0
end_level: 3
spatial_scales: [0.125, 0.0625, 0.03125]
PicoHead: PicoHead:
conv_feat: conv_feat:
name: PicoFeat name: PicoFeat
feat_in: 128 feat_in: 96
feat_out: 128 feat_out: 96
num_convs: 2 num_convs: 2
num_fpn_stride: 4
norm_type: bn norm_type: bn
share_cls_reg: True share_cls_reg: True
feat_in_chan: 128 feat_in_chan: 96
TrainReader:
batch_size: 88
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/picodet_shufflenetv2_1x.yml',
'_base_/optimizer_300e.yml',
'_base_/picodet_416_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_0_pretrained.pdparams
weights: output/picodet_s_lcnet_416_coco/model_final
find_unused_parameters: True
use_ema: true
cycle_epoch: 40
snapshot_epoch: 10
PicoDet:
backbone: LCNet
neck: PAN
head: PicoHead
LCNet:
scale: 1.0
feature_maps: [3, 4, 5]
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/picodet_shufflenetv2_1x.yml',
'_base_/optimizer_280e.yml',
'_base_/picodet_320_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x0_25_pretrained.pdparams
weights: output/picodet_s_shufflenetv2_320_coco/model_final
find_unused_parameters: True
use_ema: true
cycle_epoch: 40
snapshot_epoch: 10
PicoDet:
backbone: LCNet
neck: PAN
head: PicoHead
LCNet:
scale: 0.25
feature_maps: [3, 4, 5]
...@@ -543,6 +543,8 @@ class Trainer(object): ...@@ -543,6 +543,8 @@ class Trainer(object):
def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True): def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
image_shape = None image_shape = None
im_shape = [None, 2]
scale_factor = [None, 2]
if self.cfg.architecture in MOT_ARCH: if self.cfg.architecture in MOT_ARCH:
test_reader_name = 'TestMOTReader' test_reader_name = 'TestMOTReader'
else: else:
...@@ -553,8 +555,12 @@ class Trainer(object): ...@@ -553,8 +555,12 @@ class Trainer(object):
# set image_shape=[None, 3, -1, -1] as default # set image_shape=[None, 3, -1, -1] as default
if image_shape is None: if image_shape is None:
image_shape = [None, 3, -1, -1] image_shape = [None, 3, -1, -1]
if len(image_shape) == 3: if len(image_shape) == 3:
image_shape = [None] + image_shape image_shape = [None] + image_shape
else:
im_shape = [image_shape[0], 2]
scale_factor = [image_shape[0], 2]
if hasattr(self.model, 'deploy'): if hasattr(self.model, 'deploy'):
self.model.deploy = True self.model.deploy = True
...@@ -571,9 +577,9 @@ class Trainer(object): ...@@ -571,9 +577,9 @@ class Trainer(object):
"image": InputSpec( "image": InputSpec(
shape=image_shape, name='image'), shape=image_shape, name='image'),
"im_shape": InputSpec( "im_shape": InputSpec(
shape=[None, 2], name='im_shape'), shape=im_shape, name='im_shape'),
"scale_factor": InputSpec( "scale_factor": InputSpec(
shape=[None, 2], name='scale_factor') shape=scale_factor, name='scale_factor')
}] }]
if self.cfg.architecture == 'DeepSORT': if self.cfg.architecture == 'DeepSORT':
input_spec[0].update({ input_spec[0].update({
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
from . import utils from . import utils
from . import task_aligned_assigner from . import task_aligned_assigner
from . import atss_assigner from . import atss_assigner
from . import simota_assigner
from .utils import * from .utils import *
from .task_aligned_assigner import * from .task_aligned_assigner import *
from .atss_assigner import * from .atss_assigner import *
from .simota_assigner import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import paddle.nn.functional as F
import paddle.nn as nn
from ppdet.modeling.losses.varifocal_loss import varifocal_loss
from ppdet.modeling.bbox_utils import batch_bbox_overlaps
from ppdet.core.workspace import register
@register
class SimOTAAssigner(object):
"""Computes matching between predictions and ground truth.
Args:
center_radius (int | float, optional): Ground truth center size
to judge whether a prior is in center. Default 2.5.
candidate_topk (int, optional): The candidate top-k which used to
get top-k ious to calculate dynamic-k. Default 10.
iou_weight (int | float, optional): The scale factor for regression
iou cost. Default 3.0.
cls_weight (int | float, optional): The scale factor for classification
cost. Default 1.0.
num_classes (int): The num_classes of dataset.
use_vfl (int): Whether to use varifocal_loss when calculating the cost matrix.
"""
__shared__ = ['num_classes']
def __init__(self,
center_radius=2.5,
candidate_topk=10,
iou_weight=3.0,
cls_weight=1.0,
num_classes=80,
use_vfl=True):
self.center_radius = center_radius
self.candidate_topk = candidate_topk
self.iou_weight = iou_weight
self.cls_weight = cls_weight
self.num_classes = num_classes
self.use_vfl = use_vfl
def get_in_gt_and_in_center_info(self, priors, gt_bboxes):
num_gt = gt_bboxes.shape[0]
repeated_x = priors[:, 0].unsqueeze(1).tile([1, num_gt])
repeated_y = priors[:, 1].unsqueeze(1).tile([1, num_gt])
repeated_stride_x = priors[:, 2].unsqueeze(1).tile([1, num_gt])
repeated_stride_y = priors[:, 3].unsqueeze(1).tile([1, num_gt])
# is prior centers in gt bboxes, shape: [n_prior, n_gt]
l_ = repeated_x - gt_bboxes[:, 0]
t_ = repeated_y - gt_bboxes[:, 1]
r_ = gt_bboxes[:, 2] - repeated_x
b_ = gt_bboxes[:, 3] - repeated_y
deltas = paddle.stack([l_, t_, r_, b_], axis=1)
is_in_gts = deltas.min(axis=1) > 0
is_in_gts_all = is_in_gts.sum(axis=1) > 0
# is prior centers in gt centers
gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
ct_box_l = gt_cxs - self.center_radius * repeated_stride_x
ct_box_t = gt_cys - self.center_radius * repeated_stride_y
ct_box_r = gt_cxs + self.center_radius * repeated_stride_x
ct_box_b = gt_cys + self.center_radius * repeated_stride_y
cl_ = repeated_x - ct_box_l
ct_ = repeated_y - ct_box_t
cr_ = ct_box_r - repeated_x
cb_ = ct_box_b - repeated_y
ct_deltas = paddle.stack([cl_, ct_, cr_, cb_], axis=1)
is_in_cts = ct_deltas.min(axis=1) > 0
is_in_cts_all = is_in_cts.sum(axis=1) > 0
# in boxes or in centers, shape: [num_priors]
is_in_gts_or_centers = paddle.logical_or(is_in_gts_all, is_in_cts_all)
is_in_gts_or_centers_inds = paddle.nonzero(
is_in_gts_or_centers).squeeze(1)
# both in boxes and centers, shape: [num_fg, num_gt]
is_in_boxes_and_centers = paddle.logical_and(
paddle.gather(
is_in_gts.cast('int'), is_in_gts_or_centers_inds,
axis=0).cast('bool'),
paddle.gather(
is_in_cts.cast('int'), is_in_gts_or_centers_inds,
axis=0).cast('bool'))
return is_in_gts_or_centers, is_in_boxes_and_centers
def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
matching_matrix = np.zeros_like(cost.numpy())
# select candidate topk ious for dynamic-k calculation
topk_ious, _ = paddle.topk(pairwise_ious, self.candidate_topk, axis=0)
# calculate dynamic k for each gt
dynamic_ks = paddle.clip(topk_ious.sum(0).cast('int'), min=1)
for gt_idx in range(num_gt):
_, pos_idx = paddle.topk(
cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
matching_matrix[:, gt_idx][pos_idx.numpy()] = 1.0
del topk_ious, dynamic_ks, pos_idx
prior_match_gt_mask = matching_matrix.sum(1) > 1
if prior_match_gt_mask.sum() > 0:
cost = cost.numpy()
cost_argmin = np.argmin(cost[prior_match_gt_mask, :], axis=1)
matching_matrix[prior_match_gt_mask, :] *= 0.0
matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0
# get foreground mask inside box and center prior
fg_mask_inboxes = matching_matrix.sum(1) > 0.0
valid_mask[valid_mask.copy()] = fg_mask_inboxes
matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
matched_pred_ious = (matching_matrix *
pairwise_ious.numpy()).sum(1)[fg_mask_inboxes]
matched_pred_ious = paddle.to_tensor(
matched_pred_ious, place=pairwise_ious.place)
matched_gt_inds = paddle.to_tensor(
matched_gt_inds, place=pairwise_ious.place)
return matched_pred_ious, matched_gt_inds, valid_mask
def get_sample(self, assign_gt_inds, gt_bboxes):
pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0])
neg_inds = np.unique(np.nonzero(assign_gt_inds == 0)[0])
pos_assigned_gt_inds = assign_gt_inds[pos_inds] - 1
if gt_bboxes.size == 0:
# hack for index error case
assert pos_assigned_gt_inds.size == 0
pos_gt_bboxes = np.empty_like(gt_bboxes).reshape(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.resize(-1, 4)
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds
def __call__(self,
pred_scores,
priors,
decoded_bboxes,
gt_bboxes,
gt_labels,
eps=1e-7):
"""Assign gt to priors using SimOTA.
TODO: add comment.
Returns:
assign_result: The assigned result.
"""
INF = 100000000
num_gt = gt_bboxes.shape[0]
num_bboxes = decoded_bboxes.shape[0]
# assign 0 by default
assigned_gt_inds = paddle.full(
(num_bboxes, ), 0, dtype=paddle.int64).numpy()
if num_gt == 0 or num_bboxes == 0:
# No ground truth or boxes, return empty assignment
max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
if num_gt == 0:
# No truth, assign everything to background
assigned_gt_inds[:] = 0
if gt_labels is None:
assigned_labels = None
else:
assigned_labels = paddle.full(
(num_bboxes, ), -1, dtype=paddle.int64)
return
valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
priors, gt_bboxes)
valid_mask_inds = paddle.nonzero(valid_mask).squeeze(1)
valid_decoded_bbox = decoded_bboxes[valid_mask_inds]
valid_pred_scores = pred_scores[valid_mask_inds]
num_valid = valid_decoded_bbox.shape[0]
pairwise_ious = batch_bbox_overlaps(valid_decoded_bbox, gt_bboxes)
iou_cost = -paddle.log(pairwise_ious + eps)
if self.use_vfl:
gt_vfl_labels = gt_labels.squeeze(-1).unsqueeze(0).tile(
[num_valid, 1]).reshape([-1])
valid_pred_scores = valid_pred_scores.unsqueeze(1).tile(
[1, num_gt, 1]).reshape([-1, self.num_classes])
vfl_score = np.zeros(valid_pred_scores.shape)
vfl_score[np.arange(0, vfl_score.shape[0]), gt_vfl_labels.numpy(
)] = pairwise_ious.reshape([-1])
vfl_score = paddle.to_tensor(vfl_score)
losses_vfl = varifocal_loss(
valid_pred_scores, vfl_score,
use_sigmoid=False).reshape([num_valid, num_gt])
losses_giou = batch_bbox_overlaps(
valid_decoded_bbox, gt_bboxes, mode='giou')
cost_matrix = (
losses_vfl * self.cls_weight + losses_giou * self.iou_weight +
paddle.logical_not(is_in_boxes_and_center).cast('float32') * INF
)
else:
gt_onehot_label = (F.one_hot(
gt_labels.squeeze(-1).cast(paddle.int64),
pred_scores.shape[-1]).cast('float32').unsqueeze(0).tile(
[num_valid, 1, 1]))
valid_pred_scores = valid_pred_scores.unsqueeze(1).tile(
[1, num_gt, 1])
cls_cost = F.binary_cross_entropy(
valid_pred_scores, gt_onehot_label, reduction='none').sum(-1)
cost_matrix = (
cls_cost * self.cls_weight + iou_cost * self.iou_weight +
paddle.logical_not(is_in_boxes_and_center).cast('float32') * INF
)
matched_pred_ious, matched_gt_inds, valid_mask = \
self.dynamic_k_matching(
cost_matrix, pairwise_ious, num_gt, valid_mask.numpy())
# assign results
gt_labels = gt_labels.numpy()
priors = priors.numpy()
matched_gt_inds = matched_gt_inds.numpy()
gt_bboxes = gt_bboxes.numpy()
assigned_gt_inds[valid_mask] = matched_gt_inds + 1
assigned_labels = np.full((num_bboxes, ), self.num_classes)
assigned_labels[valid_mask] = gt_labels.squeeze(-1)[matched_gt_inds]
pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds \
= self.get_sample(assigned_gt_inds, gt_bboxes)
num_cells = priors.shape[0]
bbox_targets = np.zeros_like(priors)
bbox_weights = np.zeros_like(priors)
labels = np.ones([num_cells], dtype=np.int64) * self.num_classes
label_weights = np.zeros([num_cells], dtype=np.float32)
if len(pos_inds) > 0:
pos_bbox_targets = pos_gt_bboxes
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if not np.any(gt_labels):
labels[pos_inds] = 0
else:
labels[pos_inds] = gt_labels.squeeze(-1)[pos_assigned_gt_inds]
label_weights[pos_inds] = 1.0
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
pos_num = max(pos_inds.size, 1)
return priors, labels, label_weights, bbox_targets, pos_num
...@@ -28,6 +28,7 @@ from . import shufflenet_v2 ...@@ -28,6 +28,7 @@ from . import shufflenet_v2
from . import swin_transformer from . import swin_transformer
from . import lcnet from . import lcnet
from . import hardnet from . import hardnet
from . import esnet
from .vgg import * from .vgg import *
from .resnet import * from .resnet import *
...@@ -45,3 +46,4 @@ from .shufflenet_v2 import * ...@@ -45,3 +46,4 @@ from .shufflenet_v2 import *
from .swin_transformer import * from .swin_transformer import *
from .lcnet import * from .lcnet import *
from .hardnet import * from .hardnet import *
from .esnet import *
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Conv2D, MaxPool2D, AdaptiveAvgPool2D, BatchNorm
from paddle.nn.initializer import KaimingNormal
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from numbers import Integral
from ..shape_spec import ShapeSpec
from ppdet.modeling.ops import channel_shuffle
from ppdet.modeling.backbones.shufflenet_v2 import ConvBNLayer
__all__ = ['ESNet']
def make_divisible(v, divisor=16, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class SEModule(nn.Layer):
def __init__(self, channel, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(),
bias_attr=ParamAttr())
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(),
bias_attr=ParamAttr())
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = F.hardsigmoid(outputs)
return paddle.multiply(x=inputs, y=outputs)
class InvertedResidual(nn.Layer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
stride,
act="relu"):
super(InvertedResidual, self).__init__()
self._conv_pw = ConvBNLayer(
in_channels=in_channels // 2,
out_channels=mid_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
self._conv_dw = ConvBNLayer(
in_channels=mid_channels // 2,
out_channels=mid_channels // 2,
kernel_size=3,
stride=stride,
padding=1,
groups=mid_channels // 2,
act=None)
self._se = SEModule(mid_channels)
self._conv_linear = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
def forward(self, inputs):
x1, x2 = paddle.split(
inputs,
num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
axis=1)
x2 = self._conv_pw(x2)
x3 = self._conv_dw(x2)
x3 = paddle.concat([x2, x3], axis=1)
x3 = self._se(x3)
x3 = self._conv_linear(x3)
out = paddle.concat([x1, x3], axis=1)
return channel_shuffle(out, 2)
class InvertedResidualDS(nn.Layer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
stride,
act="relu"):
super(InvertedResidualDS, self).__init__()
# branch1
self._conv_dw_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
act=None)
self._conv_linear_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
# branch2
self._conv_pw_2 = ConvBNLayer(
in_channels=in_channels,
out_channels=mid_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
self._conv_dw_2 = ConvBNLayer(
in_channels=mid_channels // 2,
out_channels=mid_channels // 2,
kernel_size=3,
stride=stride,
padding=1,
groups=mid_channels // 2,
act=None)
self._se = SEModule(mid_channels // 2)
self._conv_linear_2 = ConvBNLayer(
in_channels=mid_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act)
self._conv_dw_mv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
groups=out_channels,
act="hard_swish")
self._conv_pw_mv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act="hard_swish")
def forward(self, inputs):
x1 = self._conv_dw_1(inputs)
x1 = self._conv_linear_1(x1)
x2 = self._conv_pw_2(inputs)
x2 = self._conv_dw_2(x2)
x2 = self._se(x2)
x2 = self._conv_linear_2(x2)
out = paddle.concat([x1, x2], axis=1)
out = self._conv_dw_mv1(out)
out = self._conv_pw_mv1(out)
return out
@register
@serializable
class ESNet(nn.Layer):
def __init__(self,
scale=1.0,
act="hard_swish",
feature_maps=[4, 11, 14],
channel_ratio=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]):
super(ESNet, self).__init__()
self.scale = scale
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
self.feature_maps = feature_maps
stage_repeats = [3, 7, 3]
stage_out_channels = [
-1, 24, make_divisible(128 * scale), make_divisible(256 * scale),
make_divisible(512 * scale), 1024
]
self._out_channels = []
self._feature_idx = 0
# 1. conv1
self._conv1 = ConvBNLayer(
in_channels=3,
out_channels=stage_out_channels[1],
kernel_size=3,
stride=2,
padding=1,
act=act)
self._max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
self._feature_idx += 1
# 2. bottleneck sequences
self._block_list = []
arch_idx = 0
for stage_id, num_repeat in enumerate(stage_repeats):
for i in range(num_repeat):
channels_scales = channel_ratio[arch_idx]
mid_c = make_divisible(
int(stage_out_channels[stage_id + 2] * channels_scales),
divisor=8)
if i == 0:
block = self.add_sublayer(
name=str(stage_id + 2) + '_' + str(i + 1),
sublayer=InvertedResidualDS(
in_channels=stage_out_channels[stage_id + 1],
mid_channels=mid_c,
out_channels=stage_out_channels[stage_id + 2],
stride=2,
act=act))
else:
block = self.add_sublayer(
name=str(stage_id + 2) + '_' + str(i + 1),
sublayer=InvertedResidual(
in_channels=stage_out_channels[stage_id + 2],
mid_channels=mid_c,
out_channels=stage_out_channels[stage_id + 2],
stride=1,
act=act))
self._block_list.append(block)
arch_idx += 1
self._feature_idx += 1
self._update_out_channels(stage_out_channels[stage_id + 2],
self._feature_idx, self.feature_maps)
def _update_out_channels(self, channel, feature_idx, feature_maps):
if feature_idx in feature_maps:
self._out_channels.append(channel)
def forward(self, inputs):
y = self._conv1(inputs['image'])
y = self._max_pool(y)
outs = []
for i, inv in enumerate(self._block_list):
y = inv(y)
if i + 2 in self.feature_maps:
outs.append(y)
return outs
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
...@@ -143,6 +143,100 @@ def bbox_overlaps(boxes1, boxes2): ...@@ -143,6 +143,100 @@ def bbox_overlaps(boxes1, boxes2):
return overlaps return overlaps
def batch_bbox_overlaps(bboxes1,
bboxes2,
mode='iou',
is_aligned=False,
eps=1e-6):
"""Calculate overlap between two set of bboxes.
If ``is_aligned `` is ``False``, then calculate the overlaps between each
bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
pair of bboxes1 and bboxes2.
Args:
bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
B indicates the batch dim, in shape (B1, B2, ..., Bn).
If ``is_aligned `` is ``True``, then m and n must be equal.
mode (str): "iou" (intersection over union) or "iof" (intersection over
foreground).
is_aligned (bool, optional): If True, then m and n must be equal.
Default False.
eps (float, optional): A value added to the denominator for numerical
stability. Default 1e-6.
Returns:
Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,)
"""
assert mode in ['iou', 'iof', 'giou'], 'Unsupported mode {}'.format(mode)
# Either the boxes are empty or the length of boxes's last dimenstion is 4
assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0)
assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0)
# Batch dim must be the same
# Batch dim: (B1, B2, ... Bn)
assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
batch_shape = bboxes1.shape[:-2]
rows = bboxes1.shape[-2] if bboxes1.shape[0] > 0 else 0
cols = bboxes2.shape[-2] if bboxes2.shape[0] > 0 else 0
if is_aligned:
assert rows == cols
if rows * cols == 0:
if is_aligned:
return paddle.full(batch_shape + (rows, ), 1)
else:
return paddle.full(batch_shape + (rows, cols), 1)
area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
if is_aligned:
lt = paddle.maximum(bboxes1[:, :2], bboxes2[:, :2]) # [B, rows, 2]
rb = paddle.minimum(bboxes1[:, 2:], bboxes2[:, 2:]) # [B, rows, 2]
wh = (rb - lt).clip(min=0) # [B, rows, 2]
overlap = wh[:, 0] * wh[:, 1]
if mode in ['iou', 'giou']:
union = area1 + area2 - overlap
else:
union = area1
if mode == 'giou':
enclosed_lt = paddle.minimum(bboxes1[:, :2], bboxes2[:, :2])
enclosed_rb = paddle.maximum(bboxes1[:, 2:], bboxes2[:, 2:])
else:
lt = paddle.maximum(bboxes1[:, :2].reshape([rows, 1, 2]),
bboxes2[:, :2]) # [B, rows, cols, 2]
rb = paddle.minimum(bboxes1[:, 2:].reshape([rows, 1, 2]),
bboxes2[:, 2:]) # [B, rows, cols, 2]
wh = (rb - lt).clip(min=0) # [B, rows, cols, 2]
overlap = wh[:, :, 0] * wh[:, :, 1]
if mode in ['iou', 'giou']:
union = area1.reshape([rows,1]) \
+ area2.reshape([1,cols]) - overlap
else:
union = area1[:, None]
if mode == 'giou':
enclosed_lt = paddle.minimum(bboxes1[:, :2].reshape([rows, 1, 2]),
bboxes2[:, :2])
enclosed_rb = paddle.maximum(bboxes1[:, 2:].reshape([rows, 1, 2]),
bboxes2[:, 2:])
eps = paddle.to_tensor([eps])
union = paddle.maximum(union, eps)
ious = overlap / union
if mode in ['iou', 'iof']:
return ious
# calculate gious
enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0)
enclose_area = enclose_wh[:, :, 0] * enclose_wh[:, :, 1]
enclose_area = paddle.maximum(enclose_area, eps)
gious = ious - (enclose_area - union) / enclose_area
return 1 - gious
def xywh2xyxy(box): def xywh2xyxy(box):
x, y, w, h = box x, y, w, h = box
x1 = x - w * 0.5 x1 = x - w * 0.5
......
...@@ -26,6 +26,7 @@ from . import s2anet_head ...@@ -26,6 +26,7 @@ from . import s2anet_head
from . import keypoint_hrhrnet_head from . import keypoint_hrhrnet_head
from . import centernet_head from . import centernet_head
from . import gfl_head from . import gfl_head
from . import simota_head
from . import pico_head from . import pico_head
from . import detr_head from . import detr_head
from . import sparsercnn_head from . import sparsercnn_head
...@@ -45,6 +46,7 @@ from .s2anet_head import * ...@@ -45,6 +46,7 @@ from .s2anet_head import *
from .keypoint_hrhrnet_head import * from .keypoint_hrhrnet_head import *
from .centernet_head import * from .centernet_head import *
from .gfl_head import * from .gfl_head import *
from .simota_head import *
from .pico_head import * from .pico_head import *
from .detr_head import * from .detr_head import *
from .sparsercnn_head import * from .sparsercnn_head import *
......
...@@ -150,14 +150,14 @@ class GFLHead(nn.Layer): ...@@ -150,14 +150,14 @@ class GFLHead(nn.Layer):
num_classes (int): Number of classes num_classes (int): Number of classes
fpn_stride (list): The stride of each FPN Layer fpn_stride (list): The stride of each FPN Layer
prior_prob (float): Used to set the bias init for the class prediction layer prior_prob (float): Used to set the bias init for the class prediction layer
loss_qfl (object): loss_class (object): Instance of QualityFocalLoss.
loss_dfl (object): loss_dfl (object): Instance of DistributionFocalLoss.
loss_bbox (object): loss_bbox (object): Instance of bbox loss.
reg_max: Max value of integral set :math: `{0, ..., reg_max}` reg_max: Max value of integral set :math: `{0, ..., reg_max}`
n QFL setting. Default: 16. n QFL setting. Default: 16.
""" """
__inject__ = [ __inject__ = [
'conv_feat', 'dgqp_module', 'loss_qfl', 'loss_dfl', 'loss_bbox', 'nms' 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', 'nms'
] ]
__shared__ = ['num_classes'] __shared__ = ['num_classes']
...@@ -167,7 +167,7 @@ class GFLHead(nn.Layer): ...@@ -167,7 +167,7 @@ class GFLHead(nn.Layer):
num_classes=80, num_classes=80,
fpn_stride=[8, 16, 32, 64, 128], fpn_stride=[8, 16, 32, 64, 128],
prior_prob=0.01, prior_prob=0.01,
loss_qfl='QualityFocalLoss', loss_class='QualityFocalLoss',
loss_dfl='DistributionFocalLoss', loss_dfl='DistributionFocalLoss',
loss_bbox='GIoULoss', loss_bbox='GIoULoss',
reg_max=16, reg_max=16,
...@@ -181,7 +181,7 @@ class GFLHead(nn.Layer): ...@@ -181,7 +181,7 @@ class GFLHead(nn.Layer):
self.num_classes = num_classes self.num_classes = num_classes
self.fpn_stride = fpn_stride self.fpn_stride = fpn_stride
self.prior_prob = prior_prob self.prior_prob = prior_prob
self.loss_qfl = loss_qfl self.loss_qfl = loss_class
self.loss_dfl = loss_dfl self.loss_dfl = loss_dfl
self.loss_bbox = loss_bbox self.loss_bbox = loss_bbox
self.reg_max = reg_max self.reg_max = reg_max
......
...@@ -26,9 +26,7 @@ from paddle.nn.initializer import Normal, Constant ...@@ -26,9 +26,7 @@ from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance from .simota_head import OTAVFLHead
from ppdet.data.transform.atss_assigner import bbox_overlaps
from .gfl_head import GFLHead
@register @register
...@@ -49,11 +47,13 @@ class PicoFeat(nn.Layer): ...@@ -49,11 +47,13 @@ class PicoFeat(nn.Layer):
num_fpn_stride=3, num_fpn_stride=3,
num_convs=2, num_convs=2,
norm_type='bn', norm_type='bn',
share_cls_reg=False): share_cls_reg=False,
act='hard_swish'):
super(PicoFeat, self).__init__() super(PicoFeat, self).__init__()
self.num_convs = num_convs self.num_convs = num_convs
self.norm_type = norm_type self.norm_type = norm_type
self.share_cls_reg = share_cls_reg self.share_cls_reg = share_cls_reg
self.act = act
self.cls_convs = [] self.cls_convs = []
self.reg_convs = [] self.reg_convs = []
for stage_idx in range(num_fpn_stride): for stage_idx in range(num_fpn_stride):
...@@ -112,35 +112,43 @@ class PicoFeat(nn.Layer): ...@@ -112,35 +112,43 @@ class PicoFeat(nn.Layer):
self.cls_convs.append(cls_subnet_convs) self.cls_convs.append(cls_subnet_convs)
self.reg_convs.append(reg_subnet_convs) self.reg_convs.append(reg_subnet_convs)
def act_func(self, x):
if self.act == "leaky_relu":
x = F.leaky_relu(x)
elif self.act == "hard_swish":
x = F.hardswish(x)
return x
def forward(self, fpn_feat, stage_idx): def forward(self, fpn_feat, stage_idx):
assert stage_idx < len(self.cls_convs) assert stage_idx < len(self.cls_convs)
cls_feat = fpn_feat cls_feat = fpn_feat
reg_feat = fpn_feat reg_feat = fpn_feat
for i in range(len(self.cls_convs[stage_idx])): for i in range(len(self.cls_convs[stage_idx])):
cls_feat = F.leaky_relu(self.cls_convs[stage_idx][i](cls_feat), 0.1) cls_feat = self.act_func(self.cls_convs[stage_idx][i](cls_feat))
if not self.share_cls_reg: if not self.share_cls_reg:
reg_feat = F.leaky_relu(self.reg_convs[stage_idx][i](reg_feat), reg_feat = self.act_func(self.reg_convs[stage_idx][i](reg_feat))
0.1)
return cls_feat, reg_feat return cls_feat, reg_feat
@register @register
class PicoHead(GFLHead): class PicoHead(OTAVFLHead):
""" """
PicoHead PicoHead
Args: Args:
conv_feat (object): Instance of 'LiteGFLFeat' conv_feat (object): Instance of 'PicoFeat'
num_classes (int): Number of classes num_classes (int): Number of classes
fpn_stride (list): The stride of each FPN Layer fpn_stride (list): The stride of each FPN Layer
prior_prob (float): Used to set the bias init for the class prediction layer prior_prob (float): Used to set the bias init for the class prediction layer
loss_qfl (object): loss_class (object): Instance of VariFocalLoss.
loss_dfl (object): loss_dfl (object): Instance of DistributionFocalLoss.
loss_bbox (object): loss_bbox (object): Instance of bbox loss.
assigner (object): Instance of label assigner.
reg_max: Max value of integral set :math: `{0, ..., reg_max}` reg_max: Max value of integral set :math: `{0, ..., reg_max}`
n QFL setting. Default: 16. n QFL setting. Default: 7.
""" """
__inject__ = [ __inject__ = [
'conv_feat', 'dgqp_module', 'loss_qfl', 'loss_dfl', 'loss_bbox', 'nms' 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
'assigner', 'nms'
] ]
__shared__ = ['num_classes'] __shared__ = ['num_classes']
...@@ -150,9 +158,10 @@ class PicoHead(GFLHead): ...@@ -150,9 +158,10 @@ class PicoHead(GFLHead):
num_classes=80, num_classes=80,
fpn_stride=[8, 16, 32], fpn_stride=[8, 16, 32],
prior_prob=0.01, prior_prob=0.01,
loss_qfl='QualityFocalLoss', loss_class='VariFocalLoss',
loss_dfl='DistributionFocalLoss', loss_dfl='DistributionFocalLoss',
loss_bbox='GIoULoss', loss_bbox='GIoULoss',
assigner='SimOTAAssigner',
reg_max=16, reg_max=16,
feat_in_chan=96, feat_in_chan=96,
nms=None, nms=None,
...@@ -164,9 +173,10 @@ class PicoHead(GFLHead): ...@@ -164,9 +173,10 @@ class PicoHead(GFLHead):
num_classes=num_classes, num_classes=num_classes,
fpn_stride=fpn_stride, fpn_stride=fpn_stride,
prior_prob=prior_prob, prior_prob=prior_prob,
loss_qfl=loss_qfl, loss_class=loss_class,
loss_dfl=loss_dfl, loss_dfl=loss_dfl,
loss_bbox=loss_bbox, loss_bbox=loss_bbox,
assigner=assigner,
reg_max=reg_max, reg_max=reg_max,
feat_in_chan=feat_in_chan, feat_in_chan=feat_in_chan,
nms=nms, nms=nms,
...@@ -176,15 +186,17 @@ class PicoHead(GFLHead): ...@@ -176,15 +186,17 @@ class PicoHead(GFLHead):
self.num_classes = num_classes self.num_classes = num_classes
self.fpn_stride = fpn_stride self.fpn_stride = fpn_stride
self.prior_prob = prior_prob self.prior_prob = prior_prob
self.loss_qfl = loss_qfl self.loss_vfl = loss_class
self.loss_dfl = loss_dfl self.loss_dfl = loss_dfl
self.loss_bbox = loss_bbox self.loss_bbox = loss_bbox
self.assigner = assigner
self.reg_max = reg_max self.reg_max = reg_max
self.feat_in_chan = feat_in_chan self.feat_in_chan = feat_in_chan
self.nms = nms self.nms = nms
self.nms_pre = nms_pre self.nms_pre = nms_pre
self.cell_offset = cell_offset self.cell_offset = cell_offset
self.use_sigmoid = self.loss_qfl.use_sigmoid
self.use_sigmoid = self.loss_vfl.use_sigmoid
if self.use_sigmoid: if self.use_sigmoid:
self.cls_out_channels = self.num_classes self.cls_out_channels = self.num_classes
else: else:
......
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling import ops
__all__ = ['VarifocalLoss']
def varifocal_loss(pred,
target,
alpha=0.75,
gamma=2.0,
iou_weighted=True,
use_sigmoid=True):
"""`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning target of the iou-aware
classification score with shape (N, C), C is the number of classes.
alpha (float, optional): A balance factor for the negative part of
Varifocal Loss, which is different from the alpha of Focal Loss.
Defaults to 0.75.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
iou_weighted (bool, optional): Whether to weight the loss of the
positive example with the iou target. Defaults to True.
"""
# pred and target should be of the same size
assert pred.shape == target.shape
if use_sigmoid:
pred_new = F.sigmoid(pred)
else:
pred_new = pred
target = target.cast(pred.dtype)
if iou_weighted:
focal_weight = target * (target > 0.0).cast('float32') + \
alpha * (pred_new - target).abs().pow(gamma) * \
(target <= 0.0).cast('float32')
else:
focal_weight = (target > 0.0).cast('float32') + \
alpha * (pred_new - target).abs().pow(gamma) * \
(target <= 0.0).cast('float32')
if use_sigmoid:
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
else:
loss = F.binary_cross_entropy(
pred, target, reduction='none') * focal_weight
loss = loss.sum(axis=1)
return loss
@register
@serializable
class VarifocalLoss(nn.Layer):
def __init__(self,
use_sigmoid=True,
alpha=0.75,
gamma=2.0,
iou_weighted=True,
reduction='mean',
loss_weight=1.0):
"""`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
Args:
use_sigmoid (bool, optional): Whether the prediction is
used for sigmoid or softmax. Defaults to True.
alpha (float, optional): A balance factor for the negative part of
Varifocal Loss, which is different from the alpha of Focal
Loss. Defaults to 0.75.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
iou_weighted (bool, optional): Whether to weight the loss of the
positive examples with the iou target. Defaults to True.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and
"sum".
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
"""
super(VarifocalLoss, self).__init__()
assert alpha >= 0.0
self.use_sigmoid = use_sigmoid
self.alpha = alpha
self.gamma = gamma
self.iou_weighted = iou_weighted
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
loss = self.loss_weight * varifocal_loss(
pred,
target,
alpha=self.alpha,
gamma=self.gamma,
iou_weighted=self.iou_weighted,
use_sigmoid=self.use_sigmoid)
if weight is not None:
loss = loss * weight
if avg_factor is None:
if self.reduction == 'none':
return loss
elif self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if self.reduction == 'mean':
loss = loss.sum() / avg_factor
# if reduction is 'none', then do nothing, otherwise raise an error
elif self.reduction != 'none':
raise ValueError(
'avg_factor can not be used with reduction="sum"')
return loss
...@@ -19,6 +19,7 @@ from . import ttf_fpn ...@@ -19,6 +19,7 @@ from . import ttf_fpn
from . import centernet_fpn from . import centernet_fpn
from . import pan from . import pan
from . import bifpn from . import bifpn
from . import csp_pan
from .fpn import * from .fpn import *
from .yolo_fpn import * from .yolo_fpn import *
...@@ -28,3 +29,4 @@ from .centernet_fpn import * ...@@ -28,3 +29,4 @@ from .centernet_fpn import *
from .blazeface_fpn import * from .blazeface_fpn import *
from .pan import * from .pan import *
from .bifpn import * from .bifpn import *
from .csp_pan import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec
__all__ = ['CSPPAN']
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channel=96,
out_channel=96,
kernel_size=3,
stride=1,
groups=1,
act='leaky_relu'):
super(ConvBNLayer, self).__init__()
initializer = nn.initializer.KaimingUniform()
self.act = act
assert self.act in ['leaky_relu', "hard_swish"]
self.conv = nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
kernel_size=kernel_size,
groups=groups,
padding=(kernel_size - 1) // 2,
stride=stride,
weight_attr=ParamAttr(initializer=initializer),
bias_attr=False)
self.bn = nn.BatchNorm2D(out_channel)
def forward(self, x):
x = self.bn(self.conv(x))
if self.act == "leaky_relu":
x = F.leaky_relu(x)
elif self.act == "hard_swish":
x = F.hardswish(x)
return x
class DPModule(nn.Layer):
"""
Depth-wise and point-wise module.
Args:
in_channel (int): The input channels of this Module.
out_channel (int): The output channels of this Module.
kernel_size (int): The conv2d kernel size of this Module.
stride (int): The conv2d's stride of this Module.
act (str): The activation function of this Module,
Now support `leaky_relu` and `hard_swish`.
"""
def __init__(self,
in_channel=96,
out_channel=96,
kernel_size=3,
stride=1,
act='leaky_relu'):
super(DPModule, self).__init__()
initializer = nn.initializer.KaimingUniform()
self.act = act
self.dwconv = nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
kernel_size=kernel_size,
groups=out_channel,
padding=(kernel_size - 1) // 2,
stride=stride,
weight_attr=ParamAttr(initializer=initializer),
bias_attr=False)
self.bn1 = nn.BatchNorm2D(out_channel)
self.pwconv = nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=1,
groups=1,
padding=0,
weight_attr=ParamAttr(initializer=initializer),
bias_attr=False)
self.bn2 = nn.BatchNorm2D(out_channel)
def act_func(self, x):
if self.act == "leaky_relu":
x = F.leaky_relu(x)
elif self.act == "hard_swish":
x = F.hardswish(x)
return x
def forward(self, x):
x = self.act_func(self.bn1(self.dwconv(x)))
x = self.act_func(self.bn2(self.pwconv(x)))
return x
class DarknetBottleneck(nn.Layer):
"""The basic bottleneck block used in Darknet.
Each Block consists of two ConvModules and the input is added to the
final output. Each ConvModule is composed of Conv, BN, and act.
The first convLayer has filter size of 1x1 and the second one has the
filter size of 3x3.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
expansion (int): The kernel size of the convolution. Default: 0.5
add_identity (bool): Whether to add identity to the out.
Default: True
use_depthwise (bool): Whether to use depthwise separable convolution.
Default: False
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
expansion=0.5,
add_identity=True,
use_depthwise=False,
act="leaky_relu"):
super(DarknetBottleneck, self).__init__()
hidden_channels = int(out_channels * expansion)
conv_func = DPModule if use_depthwise else ConvBNLayer
self.conv1 = ConvBNLayer(
in_channel=in_channels,
out_channel=hidden_channels,
kernel_size=1,
act=act)
self.conv2 = conv_func(
in_channel=hidden_channels,
out_channel=out_channels,
kernel_size=kernel_size,
stride=1,
act=act)
self.add_identity = \
add_identity and in_channels == out_channels
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
if self.add_identity:
return out + identity
else:
return out
class CSPLayer(nn.Layer):
"""Cross Stage Partial Layer.
Args:
in_channels (int): The input channels of the CSP layer.
out_channels (int): The output channels of the CSP layer.
expand_ratio (float): Ratio to adjust the number of channels of the
hidden layer. Default: 0.5
num_blocks (int): Number of blocks. Default: 1
add_identity (bool): Whether to add identity in blocks.
Default: True
use_depthwise (bool): Whether to depthwise separable convolution in
blocks. Default: False
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
expand_ratio=0.5,
num_blocks=1,
add_identity=True,
use_depthwise=False,
act="leaky_relu"):
super().__init__()
mid_channels = int(out_channels * expand_ratio)
self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
self.final_conv = ConvBNLayer(
2 * mid_channels, out_channels, 1, act=act)
self.blocks = nn.Sequential(* [
DarknetBottleneck(
mid_channels,
mid_channels,
kernel_size,
1.0,
add_identity,
use_depthwise,
act=act) for _ in range(num_blocks)
])
def forward(self, x):
x_short = self.short_conv(x)
x_main = self.main_conv(x)
x_main = self.blocks(x_main)
x_final = paddle.concat((x_main, x_short), axis=1)
return self.final_conv(x_final)
class Channel_T(nn.Layer):
def __init__(self,
in_channels=[116, 232, 464],
out_channels=96,
act="leaky_relu"):
super(Channel_T, self).__init__()
self.convs = nn.LayerList()
for i in range(len(in_channels)):
self.convs.append(
ConvBNLayer(
in_channels[i], out_channels, 1, act=act))
def forward(self, x):
outs = [self.convs[i](x[i]) for i in range(len(x))]
return outs
@register
@serializable
class CSPPAN(nn.Layer):
"""Path Aggregation Network with CSP module.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
kernel_size (int): The conv2d kernel size of this Module.
num_features (int): Number of output features of CSPPAN module.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
use_depthwise (bool): Whether to depthwise separable convolution in
blocks. Default: True
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=5,
num_features=3,
num_csp_blocks=1,
use_depthwise=True,
act='hard_swish',
spatial_scales=[0.125, 0.0625, 0.03125]):
super(CSPPAN, self).__init__()
self.conv_t = Channel_T(in_channels, out_channels, act=act)
in_channels = [out_channels] * len(spatial_scales)
self.in_channels = in_channels
self.out_channels = out_channels
self.spatial_scales = spatial_scales
self.num_features = num_features
conv_func = DPModule if use_depthwise else ConvBNLayer
if self.num_features == 4:
self.first_top_conv = conv_func(
in_channels[0], in_channels[0], kernel_size, stride=2, act=act)
self.second_top_conv = conv_func(
in_channels[0], in_channels[0], kernel_size, stride=2, act=act)
self.spatial_scales.append(self.spatial_scales[-1] / 2)
# build top-down blocks
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.top_down_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1, 0, -1):
self.top_down_blocks.append(
CSPLayer(
in_channels[idx - 1] * 2,
in_channels[idx - 1],
kernel_size=kernel_size,
num_blocks=num_csp_blocks,
add_identity=False,
use_depthwise=use_depthwise,
act=act))
# build bottom-up blocks
self.downsamples = nn.LayerList()
self.bottom_up_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1):
self.downsamples.append(
conv_func(
in_channels[idx],
in_channels[idx],
kernel_size=kernel_size,
stride=2,
act=act))
self.bottom_up_blocks.append(
CSPLayer(
in_channels[idx] * 2,
in_channels[idx + 1],
kernel_size=kernel_size,
num_blocks=num_csp_blocks,
add_identity=False,
use_depthwise=use_depthwise,
act=act))
def forward(self, inputs):
"""
Args:
inputs (tuple[Tensor]): input features.
Returns:
tuple[Tensor]: CSPPAN features.
"""
assert len(inputs) == len(self.in_channels)
inputs = self.conv_t(inputs)
# top-down path
inner_outs = [inputs[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_heigh = inner_outs[0]
feat_low = inputs[idx - 1]
upsample_feat = self.upsample(feat_heigh)
inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
paddle.concat([upsample_feat, feat_low], 1))
inner_outs.insert(0, inner_out)
# bottom-up path
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_height = inner_outs[idx + 1]
downsample_feat = self.downsamples[idx](feat_low)
out = self.bottom_up_blocks[idx](paddle.concat(
[downsample_feat, feat_height], 1))
outs.append(out)
top_features = None
if self.num_features == 4:
top_features = self.first_top_conv(inputs[-1])
top_features = top_features + self.second_top_conv(outs[-1])
outs.append(top_features)
return tuple(outs)
@property
def out_shape(self):
return [
ShapeSpec(
channels=self.out_channels, stride=1. / s)
for s in self.spatial_scales
]
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册