未验证 提交 1e21400e 编写于 作者: F Feng Ni 提交者: GitHub

Add ppyoloe semi-det base codes (#7680)

* add ppyoloe semi-det base codes

* fix configs

* fix head distill loss

* add more semi_det configs and fix doc, test=document_fix

* add contrast_loss config, test=document_fix
上级 a496c2dd
......@@ -21,6 +21,9 @@
| PP-YOLOE+_s | 5% | 80 (7200) | 32.8 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco_sup005.pdparams) | [config](ppyoloe_plus_crn_s_80e_coco_sup005.yml) |
| PP-YOLOE+_s | 10% | 80 (14480) | 35.3 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco_sup010.pdparams) | [config](ppyoloe_plus_crn_s_80e_coco_sup010.yml) |
| PP-YOLOE+_s | full | 80 (146560) | 43.7 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco.pdparams) | [config](../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml) |
| PP-YOLOE+_l | 5% | 80 (7200) | 42.9 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco_sup005.pdparams) | [config](ppyoloe_plus_crn_l_80e_coco_sup005.yml) |
| PP-YOLOE+_l | 10% | 80 (14480) | 45.7 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco_sup010.pdparams) | [config](ppyoloe_plus_crn_l_80e_coco_sup010.yml) |
| PP-YOLOE+_l | full | 80 (146560) | 49.8 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams) | [config](../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml) |
**注意:**
- 以上模型训练默认使用8 GPUs,总batch_size默认为64,默认初始学习率为0.001。如果改动了总batch_size,请按线性比例相应地调整学习率。
......
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
]
log_iter: 50
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_l_80e_coco_sup005/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams
depth_mult: 1.0
width_mult: 1.0
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: semi_annotations/instances_train2017.1@5.json
dataset_dir: dataset/coco
data_fields: ['image', 'gt_bbox', 'gt_class']
epoch: 80
LearningRate:
base_lr: 0.001
schedulers:
- !CosineDecay
max_epochs: 96
- !LinearWarmup
start_factor: 0.
epochs: 5
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
]
log_iter: 50
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_l_80e_coco_sup010/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams
depth_mult: 1.0
width_mult: 1.0
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: semi_annotations/instances_train2017.1@10.json
dataset_dir: dataset/coco
data_fields: ['image', 'gt_bbox', 'gt_class']
epoch: 80
LearningRate:
base_lr: 0.001
schedulers:
- !CosineDecay
max_epochs: 96
- !LinearWarmup
start_factor: 0.
epochs: 5
......@@ -2,7 +2,7 @@
# Dense Teacher: Dense Pseudo-Labels for Semi-supervised Object Detection
## 模型库
## FCOS模型库
| 模型 | 监督数据比例 | Sup Baseline | Sup Epochs (Iters) | Sup mAP<sup>val<br>0.5:0.95 | Semi mAP<sup>val<br>0.5:0.95 | Semi Epochs (Iters) | 模型下载 | 配置文件 |
| :------------: | :---------: | :---------------------: | :---------------------: |:---------------------------: |:----------------------------: | :------------------: |:--------: |:----------: |
......@@ -34,6 +34,16 @@
```
## PPYOLOE+ 模型库
| 模型 | 监督数据比例 | Sup Baseline | Sup Epochs (Iters) | Sup mAP<sup>val<br>0.5:0.95 | Semi mAP<sup>val<br>0.5:0.95 | Semi Epochs (Iters) | 模型下载 | 配置文件 |
| :------------: | :---------: | :---------------------: | :---------------------: |:---------------------------: |:----------------------------: | :------------------: |:--------: |:----------: |
| DenseTeacher-PPYOLOE+_s | 5% | [sup_config](../baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml) | 80 (14480) | 32.8 | **34.0** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_s_coco_semi005.pdparams) | [config](./denseteacher_ppyoloe_plus_crn_s_coco_semi005.yml) |
| DenseTeacher-PPYOLOE+_s | 10% | [sup_config](../baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml) | 80 (14480) | 35.3 | **37.5** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_s_coco_semi010.pdparams) | [config](./denseteacher_ppyoloe_plus_crn_s_coco_semi010.yml) |
| DenseTeacher-PPYOLOE+_l | 5% | [sup_config](../baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml) | 80 (14480) | 42.9 | **45.4** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_l_coco_semi005.pdparams) | [config](./denseteacher_ppyoloe_plus_crn_l_coco_semi005.yml) |
| DenseTeacher-PPYOLOE+_l | 10% | [sup_config](../baseline/ppyoloe_plus_crn_l_80e_coco_sup010.yml) | 80 (14480) | 45.7 | **47.4** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_l_coco_semi010.pdparams) | [config](./denseteacher_ppyoloe_plus_crn_l_coco_semi010.yml) |
## 使用说明
仅训练时必须使用半监督检测的配置文件去训练,评估、预测、部署也可以按基础检测器的配置文件去执行。
......
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
'../_base_/coco_detection_percent_5.yml',
]
log_iter: 50
snapshot_epoch: 5
weights: output/denseteacher_ppyoloe_plus_crn_l_coco_semi005/model_final
epochs: &epochs 200
cosine_epochs: &cosine_epochs 240
### pretrain and warmup config, choose one and comment another
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/semi_det/ppyoloe_plus_crn_l_80e_coco_sup005.pdparams # mAP=42.9
semi_start_iters: 0
ema_start_iters: 0
use_warmup: &use_warmup False
# pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams
# semi_start_iters: 5000
# ema_start_iters: 3000
# use_warmup: &use_warmup True
### global config
use_simple_ema: True
ema_decay: 0.9996
ssod_method: DenseTeacher
DenseTeacher:
train_cfg:
sup_weight: 1.0
unsup_weight: 1.0
loss_weight: {distill_loss_cls: 1.0, distill_loss_iou: 2.5, distill_loss_dfl: 0., distill_loss_contrast: 0.1}
contrast_loss:
temperature: 0.2
alpha: 0.9
smooth_iter: 100
concat_sup_data: True
suppress: linear
ratio: 0.01
test_cfg:
inference_on: teacher
### reader config
batch_size: &batch_size 8
worker_num: 2
SemiTrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomFlip: {}
- RandomCrop: {} # unsup will be fake gt_boxes
weak_aug:
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: true, norm_type: none}
strong_aug:
- StrongAugImage: {transforms: [
RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1},
RandomErasingCrop: {},
RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]},
RandomGrayscale: {prob: 0.2},
]}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: true, norm_type: none}
sup_batch_transforms:
- BatchRandomResize: {target_size: [640], random_size: True, random_interp: True, keep_ratio: False}
- Permute: {}
- PadGT: {}
unsup_batch_transforms:
- BatchRandomResize: {target_size: [640], random_size: True, random_interp: True, keep_ratio: False}
- Permute: {}
sup_batch_size: *batch_size
unsup_batch_size: *batch_size
shuffle: True
drop_last: True
collate_batch: True
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 2
TestReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
### model config
architecture: PPYOLOE
norm_type: sync_bn
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean']
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
eval_size: ~ # means None, but not str 'None'
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 #
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
### other config
epoch: *epochs
LearningRate:
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: *cosine_epochs
use_warmup: *use_warmup
- !LinearWarmup
start_factor: 0.001
epochs: 3
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005 # dt-fcos 0.0001
type: L2
clip_grad_by_norm: 1.0 # dt-fcos clip_grad_by_value
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml',
'../_base_/coco_detection_percent_10.yml',
]
log_iter: 50
snapshot_epoch: 5
weights: output/denseteacher_ppyoloe_plus_crn_l_coco_semi010/model_final
epochs: &epochs 200
cosine_epochs: &cosine_epochs 240
### pretrain and warmup config, choose one and comment another
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/semi_det/ppyoloe_plus_crn_l_80e_coco_sup010.pdparams # mAP=45.7
semi_start_iters: 0
ema_start_iters: 0
use_warmup: &use_warmup False
# pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams
# semi_start_iters: 5000
# ema_start_iters: 3000
# use_warmup: &use_warmup True
### global config
use_simple_ema: True
ema_decay: 0.9996
ssod_method: DenseTeacher
DenseTeacher:
train_cfg:
sup_weight: 1.0
unsup_weight: 1.0
loss_weight: {distill_loss_cls: 1.0, distill_loss_iou: 2.5, distill_loss_dfl: 0., distill_loss_contrast: 0.1}
contrast_loss:
temperature: 0.2
alpha: 0.9
smooth_iter: 100
concat_sup_data: True
suppress: linear
ratio: 0.01
test_cfg:
inference_on: teacher
### reader config
batch_size: &batch_size 8
worker_num: 2
SemiTrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomFlip: {}
- RandomCrop: {} # unsup will be fake gt_boxes
weak_aug:
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: true, norm_type: none}
strong_aug:
- StrongAugImage: {transforms: [
RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1},
RandomErasingCrop: {},
RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]},
RandomGrayscale: {prob: 0.2},
]}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: true, norm_type: none}
sup_batch_transforms:
- BatchRandomResize: {target_size: [640], random_size: True, random_interp: True, keep_ratio: False}
- Permute: {}
- PadGT: {}
unsup_batch_transforms:
- BatchRandomResize: {target_size: [640], random_size: True, random_interp: True, keep_ratio: False}
- Permute: {}
sup_batch_size: *batch_size
unsup_batch_size: *batch_size
shuffle: True
drop_last: True
collate_batch: True
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 2
TestReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
### model config
architecture: PPYOLOE
norm_type: sync_bn
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean']
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
eval_size: ~ # means None, but not str 'None'
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 #
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
### other config
epoch: *epochs
LearningRate:
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: *cosine_epochs
use_warmup: *use_warmup
- !LinearWarmup
start_factor: 0.001
epochs: 3
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005 # dt-fcos 0.0001
type: L2
clip_grad_by_norm: 1.0 # dt-fcos clip_grad_by_value
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml',
'../_base_/coco_detection_percent_5.yml',
]
log_iter: 50
snapshot_epoch: 5
weights: output/denseteacher_ppyoloe_plus_crn_s_coco_semi005/model_final
epochs: &epochs 200
cosine_epochs: &cosine_epochs 240
### pretrain and warmup config, choose one and comment another
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/semi_det/ppyoloe_plus_crn_s_80e_coco_sup005.pdparams # mAP=32.8
semi_start_iters: 0
ema_start_iters: 0
use_warmup: &use_warmup False
# pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
# semi_start_iters: 5000
# ema_start_iters: 3000
# use_warmup: &use_warmup True
### global config
use_simple_ema: True
ema_decay: 0.9996
ssod_method: DenseTeacher
DenseTeacher:
train_cfg:
sup_weight: 1.0
unsup_weight: 1.0
loss_weight: {distill_loss_cls: 1.0, distill_loss_iou: 2.5, distill_loss_dfl: 0., distill_loss_contrast: 0.1}
contrast_loss:
temperature: 0.2
alpha: 0.9
smooth_iter: 100
concat_sup_data: True
suppress: linear
ratio: 0.01
test_cfg:
inference_on: teacher
### reader config
batch_size: &batch_size 8
worker_num: 2
SemiTrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomFlip: {}
- RandomCrop: {} # unsup will be fake gt_boxes
weak_aug:
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: true, norm_type: none}
strong_aug:
- StrongAugImage: {transforms: [
RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1},
RandomErasingCrop: {},
RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]},
RandomGrayscale: {prob: 0.2},
]}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: true, norm_type: none}
sup_batch_transforms:
- BatchRandomResize: {target_size: [640], random_size: True, random_interp: True, keep_ratio: False}
- Permute: {}
- PadGT: {}
unsup_batch_transforms:
- BatchRandomResize: {target_size: [640], random_size: True, random_interp: True, keep_ratio: False}
- Permute: {}
sup_batch_size: *batch_size
unsup_batch_size: *batch_size
shuffle: True
drop_last: True
collate_batch: True
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 2
TestReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
### model config
architecture: PPYOLOE
norm_type: sync_bn
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean']
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
eval_size: ~ # means None, but not str 'None'
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 #
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
### other config
epoch: *epochs
LearningRate:
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: *cosine_epochs
use_warmup: *use_warmup
- !LinearWarmup
start_factor: 0.001
epochs: 3
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005 # dt-fcos 0.0001
type: L2
clip_grad_by_norm: 1.0 # dt-fcos clip_grad_by_value
_BASE_: [
'../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml',
'../_base_/coco_detection_percent_10.yml',
]
log_iter: 50
snapshot_epoch: 5
weights: output/denseteacher_ppyoloe_plus_crn_s_coco_semi010/model_final
epochs: &epochs 200
cosine_epochs: &cosine_epochs 240
### pretrain and warmup config, choose one and comment another
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/semi_det/ppyoloe_plus_crn_s_80e_coco_sup010.pdparams # mAP=35.3
semi_start_iters: 0
ema_start_iters: 0
use_warmup: &use_warmup False
# pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
# semi_start_iters: 5000
# ema_start_iters: 3000
# use_warmup: &use_warmup True
### global config
use_simple_ema: True
ema_decay: 0.9996
ssod_method: DenseTeacher
DenseTeacher:
train_cfg:
sup_weight: 1.0
unsup_weight: 1.0
loss_weight: {distill_loss_cls: 1.0, distill_loss_iou: 2.5, distill_loss_dfl: 0., distill_loss_contrast: 0.1}
contrast_loss:
temperature: 0.2
alpha: 0.9
smooth_iter: 100
concat_sup_data: True
suppress: linear
ratio: 0.01
test_cfg:
inference_on: teacher
### reader config
batch_size: &batch_size 8
worker_num: 2
SemiTrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomFlip: {}
- RandomCrop: {} # unsup will be fake gt_boxes
weak_aug:
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: true, norm_type: none}
strong_aug:
- StrongAugImage: {transforms: [
RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1},
RandomErasingCrop: {},
RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]},
RandomGrayscale: {prob: 0.2},
]}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: true, norm_type: none}
sup_batch_transforms:
- BatchRandomResize: {target_size: [640], random_size: True, random_interp: True, keep_ratio: False}
- Permute: {}
- PadGT: {}
unsup_batch_transforms:
- BatchRandomResize: {target_size: [640], random_size: True, random_interp: True, keep_ratio: False}
- Permute: {}
sup_batch_size: *batch_size
unsup_batch_size: *batch_size
shuffle: True
drop_last: True
collate_batch: True
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 2
TestReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
### model config
architecture: PPYOLOE
norm_type: sync_bn
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean']
PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
eval_size: ~ # means None, but not str 'None'
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: -1 #
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
### other config
epoch: *epochs
LearningRate:
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: *cosine_epochs
use_warmup: *use_warmup
- !LinearWarmup
start_factor: 0.001
epochs: 3
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005 # dt-fcos 0.0001
type: L2
clip_grad_by_norm: 1.0 # dt-fcos clip_grad_by_value
......@@ -1420,10 +1420,38 @@ class RandomCrop(BaseOperator):
crop_segms.append(_crop_rle(segm, crop, height, width))
return crop_segms
def set_fake_bboxes(self, sample):
sample['gt_bbox'] = np.array(
[
[32, 32, 128, 128],
[32, 32, 128, 256],
[32, 64, 128, 128],
[32, 64, 128, 256],
[64, 64, 128, 256],
[64, 64, 256, 256],
[64, 32, 128, 256],
[64, 32, 128, 256],
[96, 32, 128, 256],
[96, 32, 128, 256],
],
dtype=np.float32)
sample['gt_class'] = np.array(
[[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], np.int32)
return sample
def apply(self, sample, context=None):
if 'gt_bbox' not in sample:
# only used in semi-det as unsup data
sample = self.set_fake_bboxes(sample)
sample = self.random_crop(sample, fake_bboxes=True)
return sample
if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
return sample
sample = self.random_crop(sample)
return sample
def random_crop(self, sample, fake_bboxes=False):
h, w = sample['image'].shape[:2]
gt_bbox = sample['gt_bbox']
......@@ -1515,6 +1543,9 @@ class RandomCrop(BaseOperator):
sample['gt_segm'], valid_ids, axis=0)
sample['image'] = self._crop_image(sample['image'], crop_box)
if fake_bboxes == True:
return sample
sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
sample['gt_class'] = np.take(
sample['gt_class'], valid_ids, axis=0)
......
......@@ -16,12 +16,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import copy
import time
import typing
import math
import numpy as np
import paddle
......@@ -317,16 +314,14 @@ class Trainer_DenseTeacher(Trainer):
data_unsup_w['is_teacher'] = True
teacher_preds = self.ema.model(data_unsup_w)
train_cfg['curr_iter'] = curr_iter
train_cfg['st_iter'] = st_iter
if self._nranks > 1:
loss_dict_unsup = self.model._layers.get_distill_loss(
student_preds,
teacher_preds,
ratio=train_cfg['ratio'])
loss_dict_unsup = self.model._layers.get_ssod_distill_loss(
student_preds, teacher_preds, train_cfg)
else:
loss_dict_unsup = self.model.get_distill_loss(
student_preds,
teacher_preds,
ratio=train_cfg['ratio'])
loss_dict_unsup = self.model.get_ssod_distill_loss(
student_preds, teacher_preds, train_cfg)
fg_num = loss_dict_unsup["fg_sum"]
del loss_dict_unsup["fg_sum"]
......
......@@ -85,12 +85,10 @@ class FCOS(BaseArch):
def get_loss_keys(self):
return ['loss_cls', 'loss_box', 'loss_quality']
def get_distill_loss(self,
fcos_head_outs,
teacher_fcos_head_outs,
ratio=0.01):
student_logits, student_deltas, student_quality = fcos_head_outs
teacher_logits, teacher_deltas, teacher_quality = teacher_fcos_head_outs
def get_ssod_distill_loss(self, student_head_outs, teacher_head_outs,
train_cfg):
student_logits, student_deltas, student_quality = student_head_outs
teacher_logits, teacher_deltas, teacher_quality = teacher_head_outs
nc = student_logits[0].shape[1]
student_logits = paddle.concat(
......@@ -132,6 +130,7 @@ class FCOS(BaseArch):
],
axis=0)
ratio = train_cfg.get('ratio', 0.01)
with paddle.no_grad():
# Region Selection
count_num = int(teacher_logits.shape[0] * ratio)
......
......@@ -16,10 +16,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import copy
import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
from ..ssod_utils import QFLv2
from ..losses import GIoULoss
__all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead']
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture, especially when use distillation or aux head
......@@ -57,6 +61,11 @@ class PPYOLOE(BaseArch):
self.yolo_head = yolo_head
self.post_process = post_process
self.for_mot = for_mot
# semi-det
self.is_teacher = False
# distill
self.for_distill = for_distill
self.feat_distill_place = feat_distill_place
if for_distill:
......@@ -85,7 +94,8 @@ class PPYOLOE(BaseArch):
body_feats = self.backbone(self.inputs)
neck_feats = self.neck(body_feats, self.for_mot)
if self.training:
self.is_teacher = self.inputs.get('is_teacher', False) # for semi-det
if self.training or self.is_teacher:
yolo_losses = self.yolo_head(neck_feats, self.inputs)
if self.for_distill:
......@@ -121,6 +131,110 @@ class PPYOLOE(BaseArch):
def get_pred(self):
return self._forward()
def get_loss_keys(self):
return ['loss_cls', 'loss_iou', 'loss_dfl', 'loss_contrast']
def get_ssod_distill_loss(self, student_head_outs, teacher_head_outs,
train_cfg):
# for semi-det distill
# student_probs: already sigmoid
student_probs, student_deltas, student_dfl = student_head_outs
teacher_probs, teacher_deltas, teacher_dfl = teacher_head_outs
bs, l, nc = student_probs.shape[:]
student_probs = student_probs.reshape([-1, nc])
teacher_probs = teacher_probs.reshape([-1, nc])
student_deltas = student_deltas.reshape([-1, 4])
teacher_deltas = teacher_deltas.reshape([-1, 4])
student_dfl = student_dfl.reshape([-1, 4, self.yolo_head.reg_channels])
teacher_dfl = teacher_dfl.reshape([-1, 4, self.yolo_head.reg_channels])
ratio = train_cfg.get('ratio', 0.01)
# for contrast loss
curr_iter = train_cfg['curr_iter']
st_iter = train_cfg['st_iter']
if curr_iter == st_iter + 1:
# start semi-det training
self.queue_ptr = 0
self.queue_size = int(bs * l * ratio)
self.queue_feats = paddle.zeros([self.queue_size, nc])
self.queue_probs = paddle.zeros([self.queue_size, nc])
contrast_loss_cfg = train_cfg['contrast_loss']
temperature = contrast_loss_cfg.get('temperature', 0.2)
alpha = contrast_loss_cfg.get('alpha', 0.9)
smooth_iter = contrast_loss_cfg.get('smooth_iter', 100) + st_iter
with paddle.no_grad():
# Region Selection
count_num = int(teacher_probs.shape[0] * ratio)
max_vals = paddle.max(teacher_probs, 1)
sorted_vals, sorted_inds = paddle.topk(max_vals,
teacher_probs.shape[0])
mask = paddle.zeros_like(max_vals)
mask[sorted_inds[:count_num]] = 1.
fg_num = sorted_vals[:count_num].sum()
b_mask = mask > 0.
# for contrast loss
probs = teacher_probs[b_mask].detach()
if curr_iter > smooth_iter: # memory-smoothing
A = paddle.exp(
paddle.mm(teacher_probs[b_mask], self.queue_probs.t()) /
temperature)
A = A / A.sum(1, keepdim=True)
probs = alpha * probs + (1 - alpha) * paddle.mm(
A, self.queue_probs)
n = student_probs[b_mask].shape[0]
# update memory bank
self.queue_feats[self.queue_ptr:self.queue_ptr +
n, :] = teacher_probs[b_mask].detach()
self.queue_probs[self.queue_ptr:self.queue_ptr +
n, :] = teacher_probs[b_mask].detach()
self.queue_ptr = (self.queue_ptr + n) % self.queue_size
# embedding similarity
sim = paddle.exp(
paddle.mm(student_probs[b_mask], teacher_probs[b_mask].t()) / 0.2)
sim_probs = sim / sim.sum(1, keepdim=True)
# pseudo-label graph with self-loop
Q = paddle.mm(probs, probs.t())
Q.fill_diagonal_(1)
pos_mask = (Q >= 0.5).astype('float32')
Q = Q * pos_mask
Q = Q / Q.sum(1, keepdim=True)
# contrastive loss
loss_contrast = -(paddle.log(sim_probs + 1e-7) * Q).sum(1)
loss_contrast = loss_contrast.mean()
# distill_loss_cls
loss_cls = QFLv2(
student_probs, teacher_probs, weight=mask, reduction="sum") / fg_num
# distill_loss_iou
inputs = paddle.concat(
(-student_deltas[b_mask][..., :2], student_deltas[b_mask][..., 2:]),
-1)
targets = paddle.concat(
(-teacher_deltas[b_mask][..., :2], teacher_deltas[b_mask][..., 2:]),
-1)
iou_loss = GIoULoss(reduction='mean')
loss_iou = iou_loss(inputs, targets)
# distill_loss_dfl
loss_dfl = F.cross_entropy(
student_dfl[b_mask].reshape([-1, self.yolo_head.reg_channels]),
teacher_dfl[b_mask].reshape([-1, self.yolo_head.reg_channels]),
soft_label=True,
reduction='mean')
return {
"distill_loss_cls": loss_cls,
"distill_loss_iou": loss_iou,
"distill_loss_dfl": loss_dfl,
"distill_loss_contrast": loss_contrast,
"fg_sum": fg_num,
}
@register
class PPYOLOEWithAuxHead(BaseArch):
......
......@@ -112,6 +112,7 @@ class PPYOLOEHead(nn.Layer):
self.exclude_post_process = exclude_post_process
self.use_shared_conv = use_shared_conv
self.for_distill = for_distill
self.is_teacher = False
# stem
self.stem_cls = nn.LayerList()
......@@ -181,6 +182,14 @@ class PPYOLOEHead(nn.Layer):
cls_score_list = paddle.concat(cls_score_list, axis=1)
reg_distri_list = paddle.concat(reg_distri_list, axis=1)
if targets.get('is_teacher', False):
pred_deltas, pred_dfls = self._bbox_decode_fake(reg_distri_list)
return cls_score_list, pred_deltas * stride_tensor, pred_dfls
if targets.get('get_data', False):
pred_deltas, pred_dfls = self._bbox_decode_fake(reg_distri_list)
return cls_score_list, pred_deltas * stride_tensor, pred_dfls
return self.get_loss([
cls_score_list, reg_distri_list, anchors, anchor_points,
num_anchors_list, stride_tensor
......@@ -249,6 +258,14 @@ class PPYOLOEHead(nn.Layer):
if self.training:
return self.forward_train(feats, targets, aux_pred)
else:
if targets is not None:
# only for semi-det
self.is_teacher = targets.get('is_teacher', False)
if self.is_teacher:
return self.forward_train(feats, targets, aux_pred=None)
else:
return self.forward_eval(feats)
return self.forward_eval(feats)
@staticmethod
......@@ -274,6 +291,14 @@ class PPYOLOEHead(nn.Layer):
pred_dist = self.proj_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1)
return batch_distance2bbox(anchor_points, pred_dist)
def _bbox_decode_fake(self, pred_dist):
_, l, _ = get_static_shape(pred_dist)
pred_dist_dfl = F.softmax(
pred_dist.reshape([-1, l, 4, self.reg_channels]))
pred_dist = self.proj_conv(pred_dist_dfl.transpose([0, 3, 1, 2
])).squeeze(1)
return pred_dist, pred_dist_dfl
def _bbox2distance(self, points, bbox):
x1y1, x2y2 = paddle.split(bbox, 2, -1)
lt = points - x1y1
......@@ -388,11 +413,13 @@ class PPYOLOEHead(nn.Layer):
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
self.assigned_labels = assigned_labels
self.assigned_bboxes = assigned_bboxes
self.assigned_scores = assigned_scores
self.mask_positive = mask_positive
if self.for_distill:
self.assigned_labels = assigned_labels
self.assigned_bboxes = assigned_bboxes
self.assigned_scores = assigned_scores
self.mask_positive = mask_positive
else:
# only used in distill
assigned_labels = self.assigned_labels
assigned_bboxes = self.assigned_bboxes
assigned_scores = self.assigned_scores
......
......@@ -35,12 +35,12 @@ def align_weak_strong_shape(data_weak, data_strong):
mode='bilinear',
align_corners=False)
if 'gt_bbox' in data_strong:
gt_bboxes = data_strong['gt_bbox']
gt_bboxes = data_strong['gt_bbox'].numpy()
for i in range(len(gt_bboxes)):
if len(gt_bboxes[i]) > 0:
gt_bboxes[i][:, 0::2] = gt_bboxes[i][:, 0::2] * scale_x_s
gt_bboxes[i][:, 1::2] = gt_bboxes[i][:, 1::2] * scale_y_s
data_strong['gt_bbox'] = gt_bboxes
data_strong['gt_bbox'] = paddle.to_tensor(gt_bboxes)
if scale_x_w != 1 or scale_y_w != 1:
data_weak['image'] = F.interpolate(
......@@ -49,12 +49,12 @@ def align_weak_strong_shape(data_weak, data_strong):
mode='bilinear',
align_corners=False)
if 'gt_bbox' in data_weak:
gt_bboxes = data_weak['gt_bbox']
gt_bboxes = data_weak['gt_bbox'].numpy()
for i in range(len(gt_bboxes)):
if len(gt_bboxes[i]) > 0:
gt_bboxes[i][:, 0::2] = gt_bboxes[i][:, 0::2] * scale_x_w
gt_bboxes[i][:, 1::2] = gt_bboxes[i][:, 1::2] * scale_y_w
data_weak['gt_bbox'] = gt_bboxes
data_weak['gt_bbox'] = paddle.to_tensor(gt_bboxes)
return data_weak, data_strong
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册