diff --git a/configs/pphuman/README.md b/configs/pphuman/README.md
index dcaa9543e1390b674e07859006ceeb3b810bed61..6a713b6ce390856805ff093f9b85e6b5a000218d 100644
--- a/configs/pphuman/README.md
+++ b/configs/pphuman/README.md
@@ -11,6 +11,8 @@ PaddleDetection团队提供了针对行人的基于PP-YOLOE的检测模型,用
|PP-YOLOE-l| CrowdHuman | 48.0 | 81.9 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_36e_crowdhuman.pdparams) | [配置文件](./ppyoloe_crn_l_36e_crowdhuman.yml) |
|PP-YOLOE-s| 业务数据集 | 53.2 | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_s_36e_pipeline.zip) | [配置文件](./ppyoloe_crn_s_36e_pphuman.yml) |
|PP-YOLOE-l| 业务数据集 | 57.8 | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip) | [配置文件](./ppyoloe_crn_l_36e_pphuman.yml) |
+|PP-YOLOE+_t-P2(320)| 业务数据集 | 49.8 | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_t_p2_60e_pipeline.zip) | [配置文件](./ppyoloe_plus_crn_t_p2_60e_pphuman.yml) |
+|PP-YOLOE+_t-P2(416)| 业务数据集 | 52.2 | - | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_t_p2_60e_pipeline.zip) | [配置文件](./ppyoloe_plus_crn_t_p2_60e_pphuman.yml) |
**注意:**
diff --git a/configs/pphuman/ppyoloe_plus_crn_t_p2_60e_pphuman.yml b/configs/pphuman/ppyoloe_plus_crn_t_p2_60e_pphuman.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b13f19c55645b7cb6423ad0a1e84508a3acb67c4
--- /dev/null
+++ b/configs/pphuman/ppyoloe_plus_crn_t_p2_60e_pphuman.yml
@@ -0,0 +1,60 @@
+# Base configs merged into this file. The reader defaults come from
+# ppyoloe_plus_tiny_reader.yml, the same reader the COCO tiny configs use.
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ '../ppyoloe/_base_/optimizer_300e.yml',
+ '../ppyoloe/_base_/ppyoloe_plus_crn_tiny_auxhead.yml',
+ '../ppyoloe/_base_/ppyoloe_plus_tiny_reader.yml',
+]
+
+log_iter: 100
+snapshot_epoch: 4
+weights: output/ppyoloe_plus_crn_tiny_60e_pphuman/model_final
+
+pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_tiny_auxhead_300e_coco.pdparams
+depth_mult: 0.33
+width_mult: 0.375
+
+
+num_classes: 1
+TrainDataset:
+ !COCODataSet
+ image_dir: ""
+ anno_path: annotations/train.json
+ dataset_dir: dataset/pphuman
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+ !COCODataSet
+ image_dir: ""
+ anno_path: annotations/val.json
+ dataset_dir: dataset/pphuman
+
+TestDataset:
+ !ImageFolder
+ anno_path: annotations/val.json
+ dataset_dir: dataset/pphuman
+
+
+TrainReader:
+ batch_size: 8
+
+
+epoch: 60
+LearningRate:
+ base_lr: 0.001
+ schedulers:
+ - !CosineDecay
+ max_epochs: 72
+ - !LinearWarmup
+ start_factor: 0.
+ epochs: 1
+
+
+PPYOLOEHead:
+ static_assigner_epoch: -1
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 300
+ score_threshold: 0.01
+ nms_threshold: 0.7
diff --git a/configs/ppvehicle/README.md b/configs/ppvehicle/README.md
index a0a6ee28c0714cc43ab795e68ac25d2236b6d0fd..9181162900048d6c7ffd82a2ee976996be046fc8 100644
--- a/configs/ppvehicle/README.md
+++ b/configs/ppvehicle/README.md
@@ -19,6 +19,9 @@ PaddleDetection团队提供了针对自动驾驶场景的基于PP-YOLOE的检测
|PP-YOLOE-s| PPVehicle9cls | 9 | 35.3 | [下载链接](https://paddledet.bj.bcebos.com/models/mot_ppyoloe_s_36e_ppvehicle9cls.pdparams) | [配置文件](./mot_ppyoloe_s_36e_ppvehicle9cls.yml) |
|PP-YOLOE-l| PPVehicle | 1 | 63.9 | [下载链接](https://paddledet.bj.bcebos.com/models/mot_ppyoloe_l_36e_ppvehicle.pdparams) | [配置文件](./mot_ppyoloe_l_36e_ppvehicle.yml) |
|PP-YOLOE-s| PPVehicle | 1 | 61.3 | [下载链接](https://paddledet.bj.bcebos.com/models/mot_ppyoloe_s_36e_ppvehicle.pdparams) | [配置文件](./mot_ppyoloe_s_36e_ppvehicle.yml) |
+|PP-YOLOE+_t-P2(320)| PPVehicle | 1 | 58.2 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_t_p2_60e_ppvehicle.zip) | [配置文件](./ppyoloe_plus_crn_t_p2_60e_ppvehicle.yml) |
+|PP-YOLOE+_t-P2(416)| PPVehicle | 1 | 60.5 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_t_p2_60e_ppvehicle.zip) | [配置文件](./ppyoloe_plus_crn_t_p2_60e_ppvehicle.yml) |
+
**注意:**
- PP-YOLOE模型训练过程中使用8 GPUs进行混合精度训练,如果**GPU卡数**或者**batch size**发生了改变,你需要按照公式 **lrnew = lrdefault * (batch_sizenew * GPU_numbernew) / (batch_sizedefault * GPU_numberdefault)** 调整学习率。
diff --git a/configs/ppvehicle/ppyoloe_plus_crn_t_p2_60e_ppvehicle.yml b/configs/ppvehicle/ppyoloe_plus_crn_t_p2_60e_ppvehicle.yml
new file mode 100644
index 0000000000000000000000000000000000000000..815935f92f825c359d79b250e44c89ea644abda5
--- /dev/null
+++ b/configs/ppvehicle/ppyoloe_plus_crn_t_p2_60e_ppvehicle.yml
@@ -0,0 +1,61 @@
+# Base configs merged into this file. The reader defaults come from
+# ppyoloe_plus_tiny_reader.yml, the same reader the COCO tiny configs use.
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ '../ppyoloe/_base_/optimizer_300e.yml',
+ '../ppyoloe/_base_/ppyoloe_plus_crn_tiny_auxhead.yml',
+ '../ppyoloe/_base_/ppyoloe_plus_tiny_reader.yml',
+]
+
+log_iter: 100
+snapshot_epoch: 4
+weights: output/ppyoloe_plus_crn_tiny_60e_ppvehicle/model_final
+
+pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_tiny_auxhead_300e_coco.pdparams
+depth_mult: 0.33
+width_mult: 0.375
+
+
+num_classes: 1
+TrainDataset:
+ !COCODataSet
+ image_dir: ""
+ anno_path: annotations/train_all.json
+ dataset_dir: dataset/ppvehicle
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+ allow_empty: true
+
+EvalDataset:
+ !COCODataSet
+ image_dir: ""
+ anno_path: annotations/val_all.json
+ dataset_dir: dataset/ppvehicle
+
+TestDataset:
+ !ImageFolder
+ anno_path: annotations/val_all.json
+ dataset_dir: dataset/ppvehicle
+
+
+TrainReader:
+ batch_size: 8
+
+
+epoch: 60
+LearningRate:
+ base_lr: 0.001
+ schedulers:
+ - !CosineDecay
+ max_epochs: 72
+ - !LinearWarmup
+ start_factor: 0.
+ epochs: 1
+
+
+PPYOLOEHead:
+ static_assigner_epoch: -1
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 300
+ score_threshold: 0.01
+ nms_threshold: 0.7
diff --git a/configs/ppyoloe/README.md b/configs/ppyoloe/README.md
index ba5e9c0fcfe1c5ad6763114bd3b726d7716a8597..9b550b77e9e140acef259e597a072478454314d4 100644
--- a/configs/ppyoloe/README.md
+++ b/configs/ppyoloe/README.md
@@ -44,6 +44,16 @@ PP-YOLOE is composed of following methods:
| PP-YOLOE+_x | 80 | 8 | 8 | cspresnet-x | 640 | 54.7 | 54.9 | 98.42 | 206.59 | 45.0 | 95.2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams) | [config](./ppyoloe_plus_crn_x_80e_coco.yml) |
+#### Tiny model
+
+| Model | Epoch | GPU number | images/GPU | backbone | input shape | Box AP<sup>val</sup><br>0.5:0.95 | Box AP<sup>test</sup><br>0.5:0.95 | Params(M) | FLOPs(G) | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
+|:--------------:|:-----:|:-------:|:----------:|:----------:| :-------:|:--------------------------:|:---------------------------:|:---------:|:--------:|:---------------:| :---------------------: |:------------------------------------------------------------------------------------:|:-------------------------------------------:|
+| PP-YOLOE-t-P2 | 300 | 8 | 8 | cspresnet-t | 320 | 34.7 | 50.0 | 6.82 | 4.78 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_t_p2_300e_coco.pdparams) | [config](./ppyoloe_crn_t_p2_300e_coco.yml) |
+| PP-YOLOE-t-P2 | 300 | 8 | 8 | cspresnet-t | 416 | 36.4 | 52.3 | 6.82 | 8.07 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_t_p2_300e_coco.pdparams) | [config](./ppyoloe_crn_t_p2_300e_coco.yml) |
+| PP-YOLOE+_t-P2(aux) | 300 | 8 | 8 | cspresnet-t | 320 | 36.3 | 51.7 | 6.00 | 15.46 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.pdparams) | [config](./ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml) |
+| PP-YOLOE+_t-P2(aux) | 300 | 8 | 8 | cspresnet-t | 416 | 39.0 | 55.1 | 6.00 | 26.13 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.pdparams) | [config](./ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml) |
+
+
### Comprehensive Metrics
| Model | Epoch | AP0.5:0.95 | AP0.5 | AP0.75 | APsmall | APmedium | APlarge | ARsmall | ARmedium | ARlarge |
|:------------------------:|:-----:|:---------------:|:----------:|:------------:|:------------:| :-----------: |:------------:|:------------:|:-------------:|:------------:|
diff --git a/configs/ppyoloe/README_cn.md b/configs/ppyoloe/README_cn.md
index d73bc8415024e68527a17a4f306a9a8374084b05..c07730ee927bd9d178d8912defae479656e04f9a 100644
--- a/configs/ppyoloe/README_cn.md
+++ b/configs/ppyoloe/README_cn.md
@@ -43,6 +43,15 @@ PP-YOLOE由以下方法组成
| PP-YOLOE+_l | 80 | 8 | 8 | cspresnet-l | 640 | 52.9 | 53.3 | 52.20 | 110.07 | 78.1 | 149.2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams) | [config](./ppyoloe_plus_crn_l_80e_coco.yml) |
| PP-YOLOE+_x | 80 | 8 | 8 | cspresnet-x | 640 | 54.7 | 54.9 | 98.42 | 206.59 | 45.0 | 95.2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams) | [config](./ppyoloe_plus_crn_x_80e_coco.yml) |
+#### Tiny模型
+
+| 模型 | Epoch | GPU个数 | 每GPU图片个数 | 骨干网络 | 输入尺寸 | Box AP<sup>val</sup><br>0.5:0.95 | Box AP<sup>test</sup><br>0.5:0.95 | Params(M) | FLOPs(G) | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | 模型下载 | 配置文件 |
+|:---------------:|:-----:|:---------:|:--------:|:----------:|:----------:|:--------------------------:|:---------------------------:|:---------:|:--------:|:---------------:| :---------------------: |:------------------------------------------------------------------------------------:|:-------------------------------------------:|
+| PP-YOLOE-t-P2 | 300 | 8 | 8 | cspresnet-t | 320 | 34.7 | 50.0 | 6.82 | 4.78 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_t_p2_300e_coco.pdparams) | [config](./ppyoloe_crn_t_p2_300e_coco.yml) |
+| PP-YOLOE-t-P2 | 300 | 8 | 8 | cspresnet-t | 416 | 36.4 | 52.3 | 6.82 | 8.07 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_t_p2_300e_coco.pdparams) | [config](./ppyoloe_crn_t_p2_300e_coco.yml) |
+| PP-YOLOE+_t-P2(aux) | 300 | 8 | 8 | cspresnet-t | 320 | 36.3 | 51.7 | 6.00 | 15.46 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.pdparams) | [config](./ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml) |
+| PP-YOLOE+_t-P2(aux) | 300 | 8 | 8 | cspresnet-t | 416 | 39.0 | 55.1 | 6.00 | 26.13 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.pdparams) | [config](./ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml) |
+
### 综合指标
| 模型 | Epoch | AP0.5:0.95 | AP0.5 | AP0.75 | APsmall | APmedium | APlarge | ARsmall | ARmedium | ARlarge |
diff --git a/configs/ppyoloe/_base_/ppyoloe_plus_crn_tiny_auxhead.yml b/configs/ppyoloe/_base_/ppyoloe_plus_crn_tiny_auxhead.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8aea82150dfaef11a9c7e7362642fdd8e5e951d9
--- /dev/null
+++ b/configs/ppyoloe/_base_/ppyoloe_plus_crn_tiny_auxhead.yml
@@ -0,0 +1,60 @@
+architecture: PPYOLOEWithAuxHead
+norm_type: sync_bn
+use_ema: true
+ema_decay: 0.9998
+ema_black_list: ['proj_conv.weight']
+custom_black_list: ['reduce_mean']
+
+PPYOLOEWithAuxHead:
+ backbone: CSPResNet
+ neck: CustomCSPPAN
+ yolo_head: PPYOLOEHead
+ aux_head: SimpleConvHead
+ post_process: ~
+
+CSPResNet:
+ layers: [3, 6, 6, 3]
+ channels: [64, 128, 256, 512, 1024]
+ return_idx: [1, 2, 3]
+ use_large_stem: True
+ use_alpha: True
+
+CustomCSPPAN:
+ out_channels: [384, 384, 384]
+ stage_num: 1
+ block_num: 3
+ act: 'swish'
+ spp: true
+
+SimpleConvHead:
+ feat_in: 288
+ feat_out: 288
+ num_convs: 1
+ fpn_strides: [32, 16, 8]
+ norm_type: 'gn'
+ act: 'LeakyReLU'
+ reg_max: 16
+
+PPYOLOEHead:
+ fpn_strides: [32, 16, 8]
+ grid_cell_scale: 5.0
+ grid_cell_offset: 0.5
+ static_assigner_epoch: 100
+ use_varifocal_loss: True
+ loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
+ attn_conv: 'repvgg' # ESEAttn stem conv: 'convbn' -> ConvBNLayer, any other value -> RepVggBlock
+ static_assigner:
+ name: ATSSAssigner
+ topk: 9
+ assigner:
+ name: TaskAlignedAssigner
+ topk: 13
+ alpha: 1.0
+ beta: 6.0
+ is_close_gt: True # pick candidates with is_close_gt (distance to gt center) instead of check_points_inside_bboxes
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 300
+ score_threshold: 0.01
+ nms_threshold: 0.7
diff --git a/configs/ppyoloe/_base_/ppyoloe_plus_tiny_reader.yml b/configs/ppyoloe/_base_/ppyoloe_plus_tiny_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2b7be58daf8e208c4875cff6be9ea48dbf0073e5
--- /dev/null
+++ b/configs/ppyoloe/_base_/ppyoloe_plus_tiny_reader.yml
@@ -0,0 +1,40 @@
+worker_num: 4
+eval_height: &eval_height 320
+eval_width: &eval_width 320
+eval_size: &eval_size [*eval_height, *eval_width]
+
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomDistort: {}
+ - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+ - RandomCrop: {}
+ - RandomFlip: {}
+ batch_transforms:
+ - BatchRandomResize: {target_size: [224, 256, 288, 320, 352, 384, 416, 448, 480, 512, 544], random_size: True, random_interp: True, keep_ratio: False}
+ - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+ - Permute: {}
+ - PadGT: {}
+ batch_size: 8
+ shuffle: true
+ drop_last: true
+ use_shared_memory: true
+ collate_batch: true
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+ - Permute: {}
+ batch_size: 2
+
+TestReader:
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+ - Permute: {}
+ batch_size: 1
diff --git a/configs/ppyoloe/ppyoloe_crn_t_p2_300e_coco.yml b/configs/ppyoloe/ppyoloe_crn_t_p2_300e_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..02dc0ddc9f61ac99231029238453b9490d6df546
--- /dev/null
+++ b/configs/ppyoloe/ppyoloe_crn_t_p2_300e_coco.yml
@@ -0,0 +1,81 @@
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ './_base_/optimizer_300e.yml',
+ './_base_/ppyoloe_crn.yml',
+ './_base_/ppyoloe_reader.yml',
+]
+
+log_iter: 100
+snapshot_epoch: 10
+weights: output/ppyoloe_crn_t_p2_300e_coco/model_final
+
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_t_pretrained.pdparams
+depth_mult: 0.33
+width_mult: 0.375
+
+
+CSPResNet:
+ return_idx: [0, 1, 2, 3]
+
+CustomCSPPAN:
+ out_channels: [768, 384, 192, 96]
+
+PPYOLOEHead:
+ fpn_strides: [32, 16, 8, 4]
+ attn_conv: 'repvgg' # ESEAttn stem conv: 'convbn' -> ConvBNLayer, any other value -> RepVggBlock
+ assigner:
+ name: TaskAlignedAssigner
+ topk: 13
+ alpha: 1.0
+ beta: 6.0
+ is_close_gt: True # pick candidates with is_close_gt (distance to gt center) instead of check_points_inside_bboxes
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 300
+ score_threshold: 0.01
+ nms_threshold: 0.7
+
+
+worker_num: 4
+eval_height: &eval_height 320
+eval_width: &eval_width 320
+eval_size: &eval_size [*eval_height, *eval_width]
+
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomDistort: {}
+ - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+ - RandomCrop: {}
+ - RandomFlip: {}
+ batch_transforms:
+ - BatchRandomResize: {target_size: [224, 256, 288, 320, 352, 384, 416, 448, 480, 512, 544], random_size: True, random_interp: True, keep_ratio: False}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ - PadGT: {}
+ batch_size: 8
+ shuffle: true
+ drop_last: true
+ use_shared_memory: true
+ collate_batch: true
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 2
+
+TestReader:
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+ fuse_normalize: True
diff --git a/configs/ppyoloe/ppyoloe_plus_crn_t_auxhead_300e_coco.yml b/configs/ppyoloe/ppyoloe_plus_crn_t_auxhead_300e_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..61422ddcceeaf8bcdeab73e24f6fcbabfd5d34a3
--- /dev/null
+++ b/configs/ppyoloe/ppyoloe_plus_crn_t_auxhead_300e_coco.yml
@@ -0,0 +1,15 @@
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ './_base_/optimizer_300e.yml',
+ './_base_/ppyoloe_plus_crn_tiny_auxhead.yml',
+ './_base_/ppyoloe_plus_tiny_reader.yml',
+]
+
+log_iter: 100
+snapshot_epoch: 10
+weights: output/ppyoloe_plus_crn_t_auxhead_300e_coco/model_final
+
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_t_pretrained.pdparams
+depth_mult: 0.33
+width_mult: 0.375
diff --git a/configs/ppyoloe/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml b/configs/ppyoloe/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..27f8f5220149f6d0424c6d3d19b03dae803c33d5
--- /dev/null
+++ b/configs/ppyoloe/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml
@@ -0,0 +1,36 @@
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ './_base_/optimizer_300e.yml',
+ './_base_/ppyoloe_plus_crn_tiny_auxhead.yml',
+ './_base_/ppyoloe_plus_tiny_reader.yml',
+]
+
+log_iter: 100
+snapshot_epoch: 10
+weights: output/ppyoloe_plus_crn_t_p2_auxhead_300e_coco/model_final
+
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_t_pretrained.pdparams
+depth_mult: 0.33
+width_mult: 0.375
+
+
+architecture: PPYOLOEWithAuxHead
+PPYOLOEWithAuxHead:
+ backbone: CSPResNet
+ neck: CustomCSPPAN
+ yolo_head: PPYOLOEHead
+ aux_head: SimpleConvHead
+ post_process: ~
+
+CSPResNet:
+ return_idx: [0, 1, 2, 3] # index 0 stands for P2
+
+CustomCSPPAN:
+ out_channels: [384, 384, 384, 384]
+
+SimpleConvHead:
+ fpn_strides: [32, 16, 8, 4]
+
+PPYOLOEHead:
+ fpn_strides: [32, 16, 8, 4]
diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py
index eb8e181690b507e0851a25134a2662b9a51e7d3c..d7d2e883d2dc38ee09fc4e3c4cdab3a1dad6d8ac 100644
--- a/ppdet/engine/export_utils.py
+++ b/ppdet/engine/export_utils.py
@@ -194,6 +194,9 @@ def _dump_infer_config(config, path, image_shape, model):
arch_state = True
break
+ if infer_arch == 'PPYOLOEWithAuxHead':
+ infer_arch = 'PPYOLOE'
+
if infer_arch in ['PPYOLOE', 'YOLOX', 'YOLOF']:
infer_cfg['arch'] = infer_arch
infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 272cb32bdb59082abce638d7fb5ac96abc449fa0..362140ac2029cfb5e0b740a607c85340ac8e5a87 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -157,9 +157,10 @@ class Trainer(object):
if print_params:
params = sum([
p.numel() for n, p in self.model.named_parameters()
- if all([x not in n for x in ['_mean', '_variance']])
+ if all([x not in n for x in ['_mean', '_variance', 'aux_']])
]) # exclude BatchNorm running status
- logger.info('Params: ', params / 1e6)
+ logger.info('Model Params : {} M.'.format((params / 1e6).numpy()[
+ 0]))
# build optimizer in train mode
if self.mode == 'train':
@@ -1105,6 +1106,10 @@ class Trainer(object):
return static_model, pruned_input_spec
def export(self, output_dir='output_inference'):
+ if hasattr(self.model, 'aux_neck'):
+ self.model.__delattr__('aux_neck')
+ if hasattr(self.model, 'aux_head'):
+ self.model.__delattr__('aux_head')
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
@@ -1151,6 +1156,10 @@ class Trainer(object):
logger.info("Export Post-Quant model and saved in {}".format(save_dir))
def _flops(self, loader):
+ if hasattr(self.model, 'aux_neck'):
+ self.model.__delattr__('aux_neck')
+ if hasattr(self.model, 'aux_head'):
+ self.model.__delattr__('aux_head')
self.model.eval()
try:
import paddleslim
diff --git a/ppdet/modeling/architectures/ppyoloe.py b/ppdet/modeling/architectures/ppyoloe.py
index 0d0e926f49bd164382b9707efd5d44211e9beaec..7ff7c254da92546be6b4ee684f02dfdca730ebba 100644
--- a/ppdet/modeling/architectures/ppyoloe.py
+++ b/ppdet/modeling/architectures/ppyoloe.py
@@ -16,10 +16,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import paddle
+import copy
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
-__all__ = ['PPYOLOE']
+__all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead']
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture
# PP-YOLOE and PP-YOLOE+ can also use the same architecture of YOLOv3 in yolo.py
@@ -97,3 +99,101 @@ class PPYOLOE(BaseArch):
def get_pred(self):
return self._forward()
+
+
+@register
+class PPYOLOEWithAuxHead(BaseArch):
+    """PP-YOLOE detector with an auxiliary neck/head used only in training.
+
+    In training mode the backbone features also go through ``aux_neck`` (a
+    deep copy of ``neck`` made in ``__init__``). Its outputs are concatenated
+    channel-wise with the main neck outputs and fed to ``aux_head``, whose
+    predictions are handed to ``yolo_head`` as ``aux_pred``. In eval mode
+    only ``backbone``, ``neck`` and ``yolo_head`` are used.
+    """
+
+    __category__ = 'architecture'
+    __inject__ = ['post_process']
+
+    def __init__(self,
+                 backbone='CSPResNet',
+                 neck='CustomCSPPAN',
+                 yolo_head='PPYOLOEHead',
+                 aux_head='SimpleConvHead',
+                 post_process='BBoxPostProcess',
+                 for_mot=False,
+                 detach_epoch=5):
+        """
+        PPYOLOE network, see https://arxiv.org/abs/2203.16250
+
+        Args:
+            backbone (nn.Layer): backbone instance
+            neck (nn.Layer): neck instance; it is also deep-copied to
+                build the auxiliary neck
+            yolo_head (nn.Layer): anchor_head instance
+            aux_head (nn.Layer): auxiliary head fed with the concatenated
+                main/auxiliary neck features during training
+            post_process (object): `BBoxPostProcess` instance, or None to
+                use `yolo_head.post_process`
+            for_mot (bool): forwarded to the neck call, default False in
+                pure object detection models.
+            detach_epoch (int): from this epoch on, the backbone features
+                and the main neck features are detached before they enter
+                the auxiliary branch
+        """
+        super(PPYOLOEWithAuxHead, self).__init__()
+        self.backbone = backbone
+        self.neck = neck
+        # the auxiliary neck starts as an independent copy of the main neck
+        self.aux_neck = copy.deepcopy(self.neck)
+
+        self.yolo_head = yolo_head
+        self.aux_head = aux_head
+        self.post_process = post_process
+        self.for_mot = for_mot
+        self.detach_epoch = detach_epoch
+
+    @classmethod
+    def from_config(cls, cfg, *args, **kwargs):
+        """Create the sub-modules named in ``cfg``.
+
+        ``aux_neck`` is not built here: ``__init__`` derives it from ``neck``,
+        so no extra deep copy is made at this point.
+        """
+        # backbone
+        backbone = create(cfg['backbone'])
+
+        # fpn, created with the backbone output shape
+        neck = create(cfg['neck'], input_shape=backbone.out_shape)
+
+        # heads, both created with the neck output shape
+        head_kwargs = {'input_shape': neck.out_shape}
+        yolo_head = create(cfg['yolo_head'], **head_kwargs)
+        aux_head = create(cfg['aux_head'], **head_kwargs)
+
+        return {
+            'backbone': backbone,
+            'neck': neck,
+            "yolo_head": yolo_head,
+            'aux_head': aux_head,
+        }
+
+    def _forward(self):
+        """Return the loss in training mode, else a bbox/bbox_num dict."""
+        body_feats = self.backbone(self.inputs)
+        neck_feats = self.neck(body_feats, self.for_mot)
+
+        if not self.training:
+            yolo_head_outs = self.yolo_head(neck_feats)
+            if self.post_process is not None:
+                bbox, bbox_num = self.post_process(
+                    yolo_head_outs, self.yolo_head.mask_anchors,
+                    self.inputs['im_shape'], self.inputs['scale_factor'])
+            else:
+                bbox, bbox_num = self.yolo_head.post_process(
+                    yolo_head_outs, self.inputs['scale_factor'])
+            return {'bbox': bbox, 'bbox_num': bbox_num}
+
+        if self.inputs['epoch_id'] >= self.detach_epoch:
+            # stop gradients of the auxiliary branch from reaching the
+            # backbone and the main neck
+            aux_inputs = [f.detach() for f in body_feats]
+            main_feats = [f.detach() for f in neck_feats]
+        else:
+            aux_inputs = body_feats
+            main_feats = neck_feats
+
+        aux_neck_feats = self.aux_neck(aux_inputs)
+        # per level: concatenate main and auxiliary neck features on axis 1
+        dual_neck_feats = [
+            paddle.concat(
+                [f, aux_f], axis=1)
+            for f, aux_f in zip(main_feats, aux_neck_feats)
+        ]
+        aux_cls_scores, aux_bbox_preds = self.aux_head(dual_neck_feats)
+        # the main head always receives the non-detached neck features
+        return self.yolo_head(
+            neck_feats,
+            self.inputs,
+            aux_pred=[aux_cls_scores, aux_bbox_preds])
+
+    def get_loss(self):
+        return self._forward()
+
+    def get_pred(self):
+        return self._forward()
diff --git a/ppdet/modeling/assigners/task_aligned_assigner.py b/ppdet/modeling/assigners/task_aligned_assigner.py
index cb932c7886120258b0ca15824ca1d6040ee8c830..23af79439ae7074b1a0f7fd74c42c1866c4de6ce 100644
--- a/ppdet/modeling/assigners/task_aligned_assigner.py
+++ b/ppdet/modeling/assigners/task_aligned_assigner.py
@@ -28,17 +28,47 @@ from .utils import (gather_topk_anchors, check_points_inside_bboxes,
__all__ = ['TaskAlignedAssigner']
+def is_close_gt(anchor, gt, stride_lst, max_dist=2.0, alpha=2.):
+    """Build a 0/1 mask of the anchors that lie close to each gt box center.
+
+    The L2 distance between every anchor point and every gt box center is
+    divided by a per-anchor stride and compared against ``max_dist``.
+
+    Args:
+        anchor (Tensor): anchor points with the shape [L, 2]
+        gt (Tensor): gt boxes with the shape [N, M2, 4]; the center is the
+            mean of the first two and the last two coordinates
+        stride_lst (list[int]): number of anchors on each level (the caller
+            passes ``num_anchors_list``). Level ``idx`` is given the stride
+            ``32 / 2**idx``, i.e. the first entry is treated as stride 32.
+        max_dist (float): threshold on distance / stride
+        alpha (float): unused, kept for interface compatibility
+    Return:
+        dist_ratio (Tensor): mask with the shape [N, M2, L], 1. where
+            distance / stride < max_dist and 0. elsewhere
+    """
+    center1 = anchor.unsqueeze(0)
+    center2 = (gt[..., :2] + gt[..., -2:]) / 2.
+    center1 = center1.unsqueeze(1)  # [1, L, 2] -> [1, 1, L, 2]
+    center2 = center2.unsqueeze(2)  # [N, M2, 2] -> [N, M2, 1, 2]
+
+    stride = paddle.concat([
+        paddle.full([x], 32 / pow(2, idx)) for idx, x in enumerate(stride_lst)
+    ]).unsqueeze(0).unsqueeze(0)
+    dist = paddle.linalg.norm(center1 - center2, p=2, axis=-1) / stride
+    # Build the mask from a single comparison. Writing 1. into `dist` in
+    # place and then testing `dist >= max_dist` would zero the freshly
+    # written ones whenever max_dist <= 1.
+    dist_ratio = (dist < max_dist).astype(dist.dtype)
+    return dist_ratio
+
+
@register
class TaskAlignedAssigner(nn.Layer):
"""TOOD: Task-aligned One-stage Object Detection
"""
- def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9):
+ def __init__(self,
+ topk=13,
+ alpha=1.0,
+ beta=6.0,
+ eps=1e-9,
+ is_close_gt=False):
super(TaskAlignedAssigner, self).__init__()
self.topk = topk
self.alpha = alpha
self.beta = beta
self.eps = eps
+ self.is_close_gt = is_close_gt
@paddle.no_grad()
def forward(self,
@@ -107,7 +137,10 @@ class TaskAlignedAssigner(nn.Layer):
self.beta)
# check the positive sample's center in gt, [B, n, L]
- is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
+ if self.is_close_gt:
+ is_in_gts = is_close_gt(anchor_points, gt_bboxes, num_anchors_list)
+ else:
+ is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
# select topk largest alignment metrics pred bbox as candidates
# for each gt, [B, n, L]
diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py
index 699d601313e702d7b01e04b9ad461cf07047beca..d29e9ac73ac044ea2c65711dd895cf6e9cc2be8a 100644
--- a/ppdet/modeling/heads/ppyoloe_head.py
+++ b/ppdet/modeling/heads/ppyoloe_head.py
@@ -16,24 +16,29 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
+from paddle import ParamAttr
+from paddle.nn.initializer import KaimingNormal
+from paddle.nn.initializer import Normal, Constant
from ..bbox_utils import batch_distance2bbox
from ..losses import GIoULoss
from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell
-from ppdet.modeling.backbones.cspresnet import ConvBNLayer
+from ppdet.modeling.backbones.cspresnet import ConvBNLayer, RepVggBlock
from ppdet.modeling.ops import get_static_shape, get_act_fn
from ppdet.modeling.layers import MultiClassNMS
-__all__ = ['PPYOLOEHead']
+__all__ = ['PPYOLOEHead', 'SimpleConvHead']
class ESEAttn(nn.Layer):
- def __init__(self, feat_channels, act='swish'):
+ def __init__(self, feat_channels, act='swish', attn_conv='convbn'):
super(ESEAttn, self).__init__()
self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
- self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
-
+ if attn_conv == 'convbn':
+ self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
+ else:
+ self.conv = RepVggBlock(feat_channels, feat_channels, act=act)
self._init_weights()
def _init_weights(self):
@@ -73,6 +78,7 @@ class PPYOLOEHead(nn.Layer):
'dfl': 0.5,
},
trt=False,
+ attn_conv='convbn',
exclude_nms=False,
exclude_post_process=False,
use_shared_conv=True):
@@ -112,8 +118,8 @@ class PPYOLOEHead(nn.Layer):
act, trt=trt) if act is None or isinstance(act,
(str, dict)) else act
for in_c in self.in_channels:
- self.stem_cls.append(ESEAttn(in_c, act=act))
- self.stem_reg.append(ESEAttn(in_c, act=act))
+ self.stem_cls.append(ESEAttn(in_c, act=act, attn_conv=attn_conv))
+ self.stem_reg.append(ESEAttn(in_c, act=act, attn_conv=attn_conv))
# pred head
self.pred_cls = nn.LayerList()
self.pred_reg = nn.LayerList()
@@ -151,7 +157,7 @@ class PPYOLOEHead(nn.Layer):
self.anchor_points = anchor_points
self.stride_tensor = stride_tensor
- def forward_train(self, feats, targets):
+ def forward_train(self, feats, targets, aux_pred=None):
anchors, anchor_points, num_anchors_list, stride_tensor = \
generate_anchors_for_grid_cell(
feats, self.fpn_strides, self.grid_cell_scale,
@@ -173,7 +179,7 @@ class PPYOLOEHead(nn.Layer):
return self.get_loss([
cls_score_list, reg_distri_list, anchors, anchor_points,
num_anchors_list, stride_tensor
- ], targets)
+ ], targets, aux_pred)
def _generate_anchors(self, feats=None, dtype='float32'):
# just use in eval time
@@ -231,12 +237,12 @@ class PPYOLOEHead(nn.Layer):
return cls_score_list, reg_dist_list, anchor_points, stride_tensor
- def forward(self, feats, targets=None):
+ def forward(self, feats, targets=None, aux_pred=None):
assert len(feats) == len(self.fpn_strides), \
"The size of feats is not equal to size of fpn_strides"
if self.training:
- return self.forward_train(feats, targets)
+ return self.forward_train(feats, targets, aux_pred)
else:
return self.forward_eval(feats)
@@ -321,13 +327,17 @@ class PPYOLOEHead(nn.Layer):
loss_dfl = pred_dist.sum() * 0.
return loss_l1, loss_iou, loss_dfl
- def get_loss(self, head_outs, gt_meta):
+ def get_loss(self, head_outs, gt_meta, aux_pred=None):
pred_scores, pred_distri, anchors,\
anchor_points, num_anchors_list, stride_tensor = head_outs
anchor_points_s = anchor_points / stride_tensor
pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
+ if aux_pred is not None:
+ pred_scores_aux = aux_pred[0]
+ pred_bboxes_aux = self._bbox_decode(anchor_points_s, aux_pred[1])
+
gt_labels = gt_meta['gt_class']
gt_bboxes = gt_meta['gt_bbox']
pad_gt_mask = gt_meta['pad_gt_mask']
@@ -345,6 +355,7 @@ class PPYOLOEHead(nn.Layer):
alpha_l = 0.25
else:
if self.sm_use:
+ # only used in smalldet of PPYOLOE-SOD model
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
@@ -356,19 +367,51 @@ class PPYOLOEHead(nn.Layer):
pad_gt_mask,
bg_index=self.num_classes)
else:
- assigned_labels, assigned_bboxes, assigned_scores = \
- self.assigner(
- pred_scores.detach(),
- pred_bboxes.detach() * stride_tensor,
- anchor_points,
- num_anchors_list,
- gt_labels,
- gt_bboxes,
- pad_gt_mask,
- bg_index=self.num_classes)
+ if aux_pred is None:
+ assigned_labels, assigned_bboxes, assigned_scores = \
+ self.assigner(
+ pred_scores.detach(),
+ pred_bboxes.detach() * stride_tensor,
+ anchor_points,
+ num_anchors_list,
+ gt_labels,
+ gt_bboxes,
+ pad_gt_mask,
+ bg_index=self.num_classes)
+ else:
+ assigned_labels, assigned_bboxes, assigned_scores = \
+ self.assigner(
+ pred_scores_aux.detach(),
+ pred_bboxes_aux.detach() * stride_tensor,
+ anchor_points,
+ num_anchors_list,
+ gt_labels,
+ gt_bboxes,
+ pad_gt_mask,
+ bg_index=self.num_classes)
alpha_l = -1
# rescale bbox
assigned_bboxes /= stride_tensor
+
+ assign_out_dict = self.get_loss_from_assign(
+ pred_scores, pred_distri, pred_bboxes, anchor_points_s,
+ assigned_labels, assigned_bboxes, assigned_scores, alpha_l)
+
+ if aux_pred is not None:
+ assign_out_dict_aux = self.get_loss_from_assign(
+ aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s,
+ assigned_labels, assigned_bboxes, assigned_scores, alpha_l)
+ loss = {}
+ for key in assign_out_dict.keys():
+ loss[key] = assign_out_dict[key] + assign_out_dict_aux[key]
+ else:
+ loss = assign_out_dict
+
+ return loss
+
+ def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes,
+ anchor_points_s, assigned_labels, assigned_bboxes,
+ assigned_scores, alpha_l):
# cls loss
if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels,
@@ -421,3 +464,169 @@ class PPYOLOEHead(nn.Layer):
else:
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
+
+
+def get_activation(name="LeakyReLU"):
+    """Return a new activation layer selected by ``name``.
+
+    Supported names: "silu", "relu", "LeakyReLU" / "leakyrelu" / "lrelu"
+    (negative slope 0.1) and None (identity).
+
+    Raises:
+        AttributeError: if ``name`` is none of the above.
+    """
+    if name is None:
+        return nn.Identity()
+    if name == "silu":
+        return nn.Silu()
+    if name == "relu":
+        return nn.ReLU()
+    if name in ("LeakyReLU", 'leakyrelu', 'lrelu'):
+        return nn.LeakyReLU(0.1)
+    raise AttributeError("Unsupported act type: {}".format(name))
+
+
+class ConvNormLayer(nn.Layer):
+    """Bias-free Conv2D followed by an optional norm and an activation.
+
+    The conv weights use the KaimingNormal initializer. ``norm_type`` picks
+    the norm: 'bn' / 'sync_bn' / 'syncbn' all build ``nn.BatchNorm2D``, 'gn'
+    builds ``nn.GroupNorm`` with 32 groups, and None disables the norm. The
+    activation is built by ``get_activation(activation)``.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 norm_type='gn',
+                 activation="LeakyReLU"):
+        super(ConvNormLayer, self).__init__()
+        assert norm_type in ['bn', 'sync_bn', 'syncbn', 'gn', None]
+
+        conv_weight_attr = ParamAttr(initializer=KaimingNormal())
+        self.conv = nn.Conv2D(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias_attr=False,
+            weight_attr=conv_weight_attr)
+
+        if norm_type == 'gn':
+            self.norm = nn.GroupNorm(num_groups=32, num_channels=out_channels)
+        elif norm_type in ('bn', 'sync_bn', 'syncbn'):
+            self.norm = nn.BatchNorm2D(out_channels)
+        else:
+            self.norm = None
+
+        self.act = get_activation(activation)
+
+    def forward(self, x):
+        """Apply conv, then the norm when one was built, then the activation."""
+        out = self.conv(x)
+        if self.norm is not None:
+            out = self.norm(out)
+        return self.act(out)
+
+
+class ScaleReg(nn.Layer):
+    """
+    Learnable scalar that multiplies the regression outputs.
+
+    Args:
+        scale (float): initial value of the single-element parameter
+    """
+
+    def __init__(self, scale=1.0):
+        super(ScaleReg, self).__init__()
+        init_value = paddle.to_tensor(scale)
+        # float32 parameter of shape [1], initialised to `scale`
+        self.scale = self.create_parameter(
+            shape=[1],
+            dtype='float32',
+            default_initializer=nn.initializer.Assign(init_value))
+
+    def forward(self, x):
+        scaled = x * self.scale
+        return scaled
+
+
+@register
+class SimpleConvHead(nn.Layer):
+    """Convolutional head predicting class scores and box distributions.
+
+    Every input level goes through two towers of ``num_convs``
+    ``ConvNormLayer`` blocks (one for classification, one for regression)
+    whose weights are shared across levels. A 3x3 conv then yields
+    ``num_classes`` sigmoid scores, and another 3x3 conv yields
+    ``4 * (reg_max + 1)`` regression values that are multiplied by a
+    per-level ``ScaleReg``.
+    """
+    __shared__ = ['num_classes']
+
+    def __init__(self,
+                 num_classes=80,
+                 feat_in=288,
+                 feat_out=288,
+                 num_convs=1,
+                 fpn_strides=[32, 16, 8, 4],
+                 norm_type='gn',
+                 act='LeakyReLU',
+                 prior_prob=0.01,
+                 reg_max=16):
+        super(SimpleConvHead, self).__init__()
+        self.num_classes = num_classes
+        self.feat_in = feat_in
+        self.feat_out = feat_out
+        self.num_convs = num_convs
+        self.fpn_strides = fpn_strides
+        self.reg_max = reg_max
+
+        # Two towers; for each depth the cls layer is created before the
+        # reg layer. Only the first layer takes `feat_in` channels.
+        self.cls_convs = nn.LayerList()
+        self.reg_convs = nn.LayerList()
+        tower_in = feat_in
+        for _ in range(self.num_convs):
+            for tower in (self.cls_convs, self.reg_convs):
+                tower.append(
+                    ConvNormLayer(
+                        tower_in,
+                        feat_out,
+                        3,
+                        stride=1,
+                        padding=1,
+                        norm_type=norm_type,
+                        activation=act))
+            tower_in = feat_out
+
+        # classification conv: bias derived from `prior_prob`
+        cls_bias_value = bias_init_with_prob(prior_prob)
+        self.gfl_cls = nn.Conv2D(
+            feat_out,
+            self.num_classes,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0.0, std=0.01)),
+            bias_attr=ParamAttr(initializer=Constant(value=cls_bias_value)))
+        # regression conv: reg_max + 1 values for each of the 4 box sides
+        self.gfl_reg = nn.Conv2D(
+            feat_out,
+            4 * (self.reg_max + 1),
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0.0, std=0.01)),
+            bias_attr=ParamAttr(initializer=Constant(value=0)))
+
+        # one learnable scale per entry of fpn_strides
+        self.scales = nn.LayerList()
+        for _ in self.fpn_strides:
+            self.scales.append(ScaleReg(1.0))
+
+    def forward(self, feats):
+        """Run the head on every level and concatenate the results.
+
+        Args:
+            feats (iterable of Tensor): one feature map per level; it is
+                zipped with ``self.scales``, so extra entries on either side
+                are ignored. Assumed to be [B, C, H, W] - TODO confirm.
+
+        Returns:
+            tuple: ``(cls_scores, bbox_preds)``, each flattened from dim 2
+            onwards, transposed to put that dim second, and concatenated
+            over levels along axis 1.
+        """
+        level_scores = []
+        level_bboxes = []
+        for feat, scale in zip(feats, self.scales):
+            cls_feat = feat
+            for layer in self.cls_convs:
+                cls_feat = layer(cls_feat)
+            reg_feat = feat
+            for layer in self.reg_convs:
+                reg_feat = layer(reg_feat)
+
+            score = F.sigmoid(self.gfl_cls(cls_feat))
+            level_scores.append(score.flatten(2).transpose([0, 2, 1]))
+
+            bbox = scale(self.gfl_reg(reg_feat))
+            level_bboxes.append(bbox.flatten(2).transpose([0, 2, 1]))
+
+        cls_scores = paddle.concat(level_scores, axis=1)
+        bbox_preds = paddle.concat(level_bboxes, axis=1)
+        return cls_scores, bbox_preds