Unverified · Commit 55cc99b1 authored by Feng Ni, committed by GitHub

Add PPYOLOE tiny and PPYOLOE+ tiny with aux_head (#7649)

* add ppyoloe plus tiny with aux_head

* fix ppyoloe tiny training

* fix ppyoloe+ tiny p2 model

* add ppyoloe t p2 model
Parent da166b0e
......@@ -11,6 +11,8 @@ The PaddleDetection team provides PP-YOLOE-based pedestrian detection models,
|PP-YOLOE-l| CrowdHuman | 48.0 | 81.9 | [Download](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_36e_crowdhuman.pdparams) | [Config](./ppyoloe_crn_l_36e_crowdhuman.yml) |
|PP-YOLOE-s| Business dataset | 53.2 | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_s_36e_pipeline.zip) | [Config](./ppyoloe_crn_s_36e_pphuman.yml) |
|PP-YOLOE-l| Business dataset | 57.8 | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip) | [Config](./ppyoloe_crn_l_36e_pphuman.yml) |
|PP-YOLOE+_t-P2(320)| Business dataset | 49.8 | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_t_p2_60e_pipeline.zip) | [Config](./ppyoloe_plus_crn_t_p2_60e_pphuman.yml) |
|PP-YOLOE+_t-P2(416)| Business dataset | 52.2 | - | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_t_p2_60e_pipeline.zip) | [Config](./ppyoloe_plus_crn_t_p2_60e_pphuman.yml) |
**Note:**
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../ppyoloe/_base_/optimizer_300e.yml',
'../ppyoloe/_base_/ppyoloe_plus_crn_tiny_auxhead.yml',
'../ppyoloe/_base_/ppyoloe_plus_reader_tiny.yml',
]
log_iter: 100
snapshot_epoch: 4
weights: output/ppyoloe_plus_crn_tiny_60e_pphuman/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_tiny_auxhead_300e_coco.pdparams
depth_mult: 0.33
width_mult: 0.375
num_classes: 1
TrainDataset:
!COCODataSet
image_dir: ""
anno_path: annotations/train.json
dataset_dir: dataset/pphuman
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset:
!COCODataSet
image_dir: ""
anno_path: annotations/val.json
dataset_dir: dataset/pphuman
TestDataset:
!ImageFolder
anno_path: annotations/val.json
dataset_dir: dataset/pphuman
TrainReader:
batch_size: 8
epoch: 60
LearningRate:
base_lr: 0.001
schedulers:
- !CosineDecay
max_epochs: 72
- !LinearWarmup
start_factor: 0.
epochs: 1
PPYOLOEHead:
static_assigner_epoch: -1
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
......@@ -19,6 +19,9 @@ The PaddleDetection team provides PP-YOLOE-based detection models for autonomous driving scenarios
|PP-YOLOE-s| PPVehicle9cls | 9 | 35.3 | [Download](https://paddledet.bj.bcebos.com/models/mot_ppyoloe_s_36e_ppvehicle9cls.pdparams) | [Config](./mot_ppyoloe_s_36e_ppvehicle9cls.yml) |
|PP-YOLOE-l| PPVehicle | 1 | 63.9 | [Download](https://paddledet.bj.bcebos.com/models/mot_ppyoloe_l_36e_ppvehicle.pdparams) | [Config](./mot_ppyoloe_l_36e_ppvehicle.yml) |
|PP-YOLOE-s| PPVehicle | 1 | 61.3 | [Download](https://paddledet.bj.bcebos.com/models/mot_ppyoloe_s_36e_ppvehicle.pdparams) | [Config](./mot_ppyoloe_s_36e_ppvehicle.yml) |
|PP-YOLOE+_t-P2(320)| PPVehicle | 1 | 58.2 | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_t_p2_60e_ppvehicle.zip) | [Config](./ppyoloe_plus_crn_t_p2_60e_ppvehicle.yml) |
|PP-YOLOE+_t-P2(416)| PPVehicle | 1 | 60.5 | [Download](https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_t_p2_60e_ppvehicle.zip) | [Config](./ppyoloe_plus_crn_t_p2_60e_ppvehicle.yml) |
**Note:**
- PP-YOLOE models are trained on 8 GPUs with mixed precision. If the **number of GPUs** or the **batch size** changes, adjust the learning rate according to **lr<sub>new</sub> = lr<sub>default</sub> * (batch_size<sub>new</sub> * GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> * GPU_number<sub>default</sub>)**, as sketched in the Python snippet below.
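A minimal sketch of this scaling rule, assuming the 8 GPU × 8 images/GPU schedule and base_lr 0.001 used by the 60e PP-Human/PPVehicle configs in this commit (the helper name `scale_lr` is hypothetical, for illustration only):

```python
def scale_lr(lr_default, batch_size_new, gpu_number_new,
             batch_size_default=8, gpu_number_default=8):
    """Linearly scale the base learning rate with the total batch size."""
    return lr_default * (batch_size_new * gpu_number_new) / (
        batch_size_default * gpu_number_default)


# Training the same config on 4 GPUs at the same per-GPU batch size:
print(scale_lr(0.001, batch_size_new=8, gpu_number_new=4))  # 0.0005
```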
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../ppyoloe/_base_/optimizer_300e.yml',
'../ppyoloe/_base_/ppyoloe_plus_crn_tiny_auxhead.yml',
'../ppyoloe/_base_/ppyoloe_plus_reader_tiny.yml',
]
log_iter: 100
snapshot_epoch: 4
weights: output/ppyoloe_plus_crn_tiny_60e_ppvehicle/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_tiny_auxhead_300e_coco.pdparams
depth_mult: 0.33
width_mult: 0.375
num_classes: 1
TrainDataset:
!COCODataSet
image_dir: ""
anno_path: annotations/train_all.json
dataset_dir: dataset/ppvehicle
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
allow_empty: true
EvalDataset:
!COCODataSet
image_dir: ""
anno_path: annotations/val_all.json
dataset_dir: dataset/ppvehicle
TestDataset:
!ImageFolder
anno_path: annotations/val_all.json
dataset_dir: dataset/ppvehicle
TrainReader:
batch_size: 8
epoch: 60
LearningRate:
base_lr: 0.001
schedulers:
- !CosineDecay
max_epochs: 72
- !LinearWarmup
start_factor: 0.
epochs: 1
PPYOLOEHead:
static_assigner_epoch: -1
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
......@@ -44,6 +44,16 @@ PP-YOLOE is composed of the following methods:
| PP-YOLOE+_x | 80 | 8 | 8 | cspresnet-x | 640 | 54.7 | 54.9 | 98.42 | 206.59 | 45.0 | 95.2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams) | [config](./ppyoloe_plus_crn_x_80e_coco.yml) |
#### Tiny model
| Model | Epoch | GPU number | images/GPU | backbone | input shape | Box AP<sup>val<br>0.5:0.95 | Box AP<sup>test<br>0.5:0.95 | Params(M) | FLOPs(G) | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:--------------:|:-----:|:-------:|:----------:|:----------:| :-------:|:--------------------------:|:---------------------------:|:---------:|:--------:|:---------------:| :---------------------: |:------------------------------------------------------------------------------------:|:-------------------------------------------:|
| PP-YOLOE-t-P2 | 300 | 8 | 8 | cspresnet-t | 320 | 34.7 | 50.0 | 6.82 | 4.78 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_t_p2_300e_coco.pdparams) | [config](./ppyoloe_crn_t_p2_300e_coco.yml) |
| PP-YOLOE-t-P2 | 300 | 8 | 8 | cspresnet-t | 416 | 36.4 | 52.3 | 6.82 | 8.07 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_t_p2_300e_coco.pdparams) | [config](./ppyoloe_crn_t_p2_300e_coco.yml) |
| PP-YOLOE+_t-P2(aux) | 300 | 8 | 8 | cspresnet-t | 320 | 36.3 | 51.7 | 6.00 | 15.46 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.pdparams) | [config](./ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml) |
| PP-YOLOE+_t-P2(aux) | 300 | 8 | 8 | cspresnet-t | 416 | 39.0 | 55.1 | 6.00 | 26.13 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.pdparams) | [config](./ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml) |
### Comprehensive Metrics
| Model | Epoch | AP<sup>0.5:0.95 | AP<sup>0.5 | AP<sup>0.75 | AP<sup>small | AP<sup>medium | AP<sup>large | AR<sup>small | AR<sup>medium | AR<sup>large |
|:------------------------:|:-----:|:---------------:|:----------:|:------------:|:------------:| :-----------: |:------------:|:------------:|:-------------:|:------------:|
......
......@@ -43,6 +43,15 @@ PP-YOLOE is composed of the following methods
| PP-YOLOE+_l | 80 | 8 | 8 | cspresnet-l | 640 | 52.9 | 53.3 | 52.20 | 110.07 | 78.1 | 149.2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams) | [config](./ppyoloe_plus_crn_l_80e_coco.yml) |
| PP-YOLOE+_x | 80 | 8 | 8 | cspresnet-x | 640 | 54.7 | 54.9 | 98.42 | 206.59 | 45.0 | 95.2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams) | [config](./ppyoloe_plus_crn_x_80e_coco.yml) |
#### Tiny model
| Model | Epoch | GPU number | images/GPU | backbone | input shape | Box AP<sup>val<br>0.5:0.95 | Box AP<sup>test<br>0.5:0.95 | Params(M) | FLOPs(G) | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:---------------:|:-----:|:---------:|:--------:|:----------:|:----------:|:--------------------------:|:---------------------------:|:---------:|:--------:|:---------------:| :---------------------: |:------------------------------------------------------------------------------------:|:-------------------------------------------:|
| PP-YOLOE-t-P2 | 300 | 8 | 8 | cspresnet-t | 320 | 34.7 | 50.0 | 6.82 | 4.78 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_t_p2_300e_coco.pdparams) | [config](./ppyoloe_crn_t_p2_300e_coco.yml) |
| PP-YOLOE-t-P2 | 300 | 8 | 8 | cspresnet-t | 416 | 36.4 | 52.3 | 6.82 | 8.07 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_t_p2_300e_coco.pdparams) | [config](./ppyoloe_crn_t_p2_300e_coco.yml) |
| PP-YOLOE+_t-P2(aux) | 300 | 8 | 8 | cspresnet-t | 320 | 36.3 | 51.7 | 6.00 | 15.46 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.pdparams) | [config](./ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml) |
| PP-YOLOE+_t-P2(aux) | 300 | 8 | 8 | cspresnet-t | 416 | 39.0 | 55.1 | 6.00 | 26.13 | - | - | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_t_p2_auxhead_300e_coco.pdparams) | [config](./ppyoloe_plus_crn_t_p2_auxhead_300e_coco.yml) |
### Comprehensive Metrics
| Model | Epoch | AP<sup>0.5:0.95 | AP<sup>0.5 | AP<sup>0.75 | AP<sup>small | AP<sup>medium | AP<sup>large | AR<sup>small | AR<sup>medium | AR<sup>large |
......
architecture: PPYOLOEWithAuxHead
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean']
PPYOLOEWithAuxHead:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
aux_head: SimpleConvHead
post_process: ~
CSPResNet:
layers: [3, 6, 6, 3]
channels: [64, 128, 256, 512, 1024]
return_idx: [1, 2, 3]
use_large_stem: True
use_alpha: True
CustomCSPPAN:
out_channels: [384, 384, 384]
stage_num: 1
block_num: 3
act: 'swish'
spp: true
SimpleConvHead:
feat_in: 288
feat_out: 288
num_convs: 1
fpn_strides: [32, 16, 8]
norm_type: 'gn'
act: 'LeakyReLU'
reg_max: 16
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: 100
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
attn_conv: 'repvgg' # use RepVggBlock instead of ConvBNLayer in the ESEAttn stem
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
is_close_gt: True # select candidate anchors by stride-normalized center distance instead of center-inside-gt
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
worker_num: 4
eval_height: &eval_height 320
eval_width: &eval_width 320
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [224, 256, 288, 320, 352, 384, 416, 448, 480, 512, 544], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
- PadGT: {}
batch_size: 8
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 2
TestReader:
inputs_def:
image_shape: [3, *eval_height, *eval_width]
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_t_p2_300e_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_t_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.375
CSPResNet:
return_idx: [0, 1, 2, 3]
CustomCSPPAN:
out_channels: [768, 384, 192, 96]
PPYOLOEHead:
fpn_strides: [32, 16, 8, 4]
attn_conv: 'repvgg' # use RepVggBlock instead of ConvBNLayer in the ESEAttn stem
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
is_close_gt: True # select candidate anchors by stride-normalized center distance instead of center-inside-gt
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
worker_num: 4
eval_height: &eval_height 320
eval_width: &eval_width 320
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [224, 256, 288, 320, 352, 384, 416, 448, 480, 512, 544], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
- PadGT: {}
batch_size: 8
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 2
TestReader:
inputs_def:
image_shape: [3, *eval_height, *eval_width]
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
fuse_normalize: True
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_plus_crn_tiny_auxhead.yml',
'./_base_/ppyoloe_plus_tiny_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_plus_crn_t_auxhead_300e_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_t_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.375
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_plus_crn_tiny_auxhead.yml',
'./_base_/ppyoloe_plus_tiny_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_plus_crn_t_p2_auxhead_300e_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_t_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.375
architecture: PPYOLOEWithAuxHead
PPYOLOEWithAuxHead:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
aux_head: SimpleConvHead
post_process: ~
CSPResNet:
return_idx: [0, 1, 2, 3] # index 0 stands for P2
CustomCSPPAN:
out_channels: [384, 384, 384, 384]
SimpleConvHead:
fpn_strides: [32, 16, 8, 4]
PPYOLOEHead:
fpn_strides: [32, 16, 8, 4]
......@@ -194,6 +194,9 @@ def _dump_infer_config(config, path, image_shape, model):
arch_state = True
break
if infer_arch == 'PPYOLOEWithAuxHead':
infer_arch = 'PPYOLOE'
if infer_arch in ['PPYOLOE', 'YOLOX', 'YOLOF']:
infer_cfg['arch'] = infer_arch
infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
......
......@@ -157,9 +157,10 @@ class Trainer(object):
if print_params:
params = sum([
p.numel() for n, p in self.model.named_parameters()
if all([x not in n for x in ['_mean', '_variance']])
if all([x not in n for x in ['_mean', '_variance', 'aux_']])
]) # exclude BatchNorm running status
logger.info('Params: ', params / 1e6)
logger.info('Model Params : {} M.'.format((params / 1e6).numpy()[
0]))
# build optimizer in train mode
if self.mode == 'train':
......@@ -1105,6 +1106,10 @@ class Trainer(object):
return static_model, pruned_input_spec
def export(self, output_dir='output_inference'):
if hasattr(self.model, 'aux_neck'):
self.model.__delattr__('aux_neck')
if hasattr(self.model, 'aux_head'):
self.model.__delattr__('aux_head')
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
......@@ -1151,6 +1156,10 @@ class Trainer(object):
logger.info("Export Post-Quant model and saved in {}".format(save_dir))
def _flops(self, loader):
if hasattr(self.model, 'aux_neck'):
self.model.__delattr__('aux_neck')
if hasattr(self.model, 'aux_head'):
self.model.__delattr__('aux_head')
self.model.eval()
try:
import paddleslim
......
......@@ -16,10 +16,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import copy
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ['PPYOLOE']
__all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead']
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture
# PP-YOLOE and PP-YOLOE+ can also use the same architecture as YOLOv3 in yolo.py
......@@ -97,3 +99,101 @@ class PPYOLOE(BaseArch):
def get_pred(self):
return self._forward()
@register
class PPYOLOEWithAuxHead(BaseArch):
__category__ = 'architecture'
__inject__ = ['post_process']
def __init__(self,
backbone='CSPResNet',
neck='CustomCSPPAN',
yolo_head='PPYOLOEHead',
aux_head='SimpleConvHead',
post_process='BBoxPostProcess',
for_mot=False,
detach_epoch=5):
"""
PPYOLOE network, see https://arxiv.org/abs/2203.16250
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck instance
yolo_head (nn.Layer): anchor_head instance
post_process (object): `BBoxPostProcess` instance
for_mot (bool): whether return other features for multi-object tracking
models, default False in pure object detection models.
"""
super(PPYOLOEWithAuxHead, self).__init__()
self.backbone = backbone
self.neck = neck
self.aux_neck = copy.deepcopy(self.neck)
self.yolo_head = yolo_head
self.aux_head = aux_head
self.post_process = post_process
self.for_mot = for_mot
self.detach_epoch = detach_epoch
@classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])
# fpn
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
aux_neck = copy.deepcopy(neck)
# head
kwargs = {'input_shape': neck.out_shape}
yolo_head = create(cfg['yolo_head'], **kwargs)
aux_head = create(cfg['aux_head'], **kwargs)
return {
'backbone': backbone,
'neck': neck,
"yolo_head": yolo_head,
'aux_head': aux_head,
}
def _forward(self):
body_feats = self.backbone(self.inputs)
neck_feats = self.neck(body_feats, self.for_mot)
if self.training:
if self.inputs['epoch_id'] >= self.detach_epoch:
aux_neck_feats = self.aux_neck([f.detach() for f in body_feats])
dual_neck_feats = (paddle.concat(
[f.detach(), aux_f], axis=1) for f, aux_f in
zip(neck_feats, aux_neck_feats))
else:
aux_neck_feats = self.aux_neck(body_feats)
dual_neck_feats = (paddle.concat(
[f, aux_f], axis=1) for f, aux_f in
zip(neck_feats, aux_neck_feats))
aux_cls_scores, aux_bbox_preds = self.aux_head(dual_neck_feats)
loss = self.yolo_head(
neck_feats,
self.inputs,
aux_pred=[aux_cls_scores, aux_bbox_preds])
return loss
else:
yolo_head_outs = self.yolo_head(neck_feats)
if self.post_process is not None:
bbox, bbox_num = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
else:
bbox, bbox_num = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
def get_loss(self):
return self._forward()
def get_pred(self):
return self._forward()
......@@ -28,17 +28,47 @@ from .utils import (gather_topk_anchors, check_points_inside_bboxes,
__all__ = ['TaskAlignedAssigner']
def is_close_gt(anchor, gt, stride_lst, max_dist=2.0, alpha=2.):
"""Calculate distance ratio of box1 and box2 in batch for larger stride
anchors dist/stride to promote the survive of large distance match
Args:
anchor (Tensor): box with the shape [L, 2]
gt (Tensor): box with the shape [N, M2, 4]
Return:
dist (Tensor): dist ratio between box1 and box2 with the shape [N, M1, M2]
"""
center1 = anchor.unsqueeze(0)
center2 = (gt[..., :2] + gt[..., -2:]) / 2.
center1 = center1.unsqueeze(1) # [N, M1, 2] -> [N, 1, M1, 2]
center2 = center2.unsqueeze(2) # [N, M2, 2] -> [N, M2, 1, 2]
# per-anchor stride, assuming FPN strides halve per level starting from 32
stride = paddle.concat([
paddle.full([x], 32 / pow(2, idx)) for idx, x in enumerate(stride_lst)
]).unsqueeze(0).unsqueeze(0)
dist = paddle.linalg.norm(center1 - center2, p=2, axis=-1) / stride
dist_ratio = dist
dist_ratio[dist < max_dist] = 1.
dist_ratio[dist >= max_dist] = 0.
return dist_ratio
@register
class TaskAlignedAssigner(nn.Layer):
"""TOOD: Task-aligned One-stage Object Detection
"""
def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9):
def __init__(self,
topk=13,
alpha=1.0,
beta=6.0,
eps=1e-9,
is_close_gt=False):
super(TaskAlignedAssigner, self).__init__()
self.topk = topk
self.alpha = alpha
self.beta = beta
self.eps = eps
self.is_close_gt = is_close_gt
@paddle.no_grad()
def forward(self,
......@@ -107,7 +137,10 @@ class TaskAlignedAssigner(nn.Layer):
self.beta)
# check the positive sample's center in gt, [B, n, L]
is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
if self.is_close_gt:
is_in_gts = is_close_gt(anchor_points, gt_bboxes, num_anchors_list)
else:
is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
# select topk largest alignment metrics pred bbox as candidates
# for each gt, [B, n, L]
......
......@@ -16,24 +16,29 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal
from paddle.nn.initializer import Normal, Constant
from ..bbox_utils import batch_distance2bbox
from ..losses import GIoULoss
from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.backbones.cspresnet import ConvBNLayer, RepVggBlock
from ppdet.modeling.ops import get_static_shape, get_act_fn
from ppdet.modeling.layers import MultiClassNMS
__all__ = ['PPYOLOEHead']
__all__ = ['PPYOLOEHead', 'SimpleConvHead']
class ESEAttn(nn.Layer):
def __init__(self, feat_channels, act='swish'):
def __init__(self, feat_channels, act='swish', attn_conv='convbn'):
super(ESEAttn, self).__init__()
self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
if attn_conv == 'convbn':
self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
else:
self.conv = RepVggBlock(feat_channels, feat_channels, act=act)
self._init_weights()
def _init_weights(self):
......@@ -73,6 +78,7 @@ class PPYOLOEHead(nn.Layer):
'dfl': 0.5,
},
trt=False,
attn_conv='convbn',
exclude_nms=False,
exclude_post_process=False,
use_shared_conv=True):
......@@ -112,8 +118,8 @@ class PPYOLOEHead(nn.Layer):
act, trt=trt) if act is None or isinstance(act,
(str, dict)) else act
for in_c in self.in_channels:
self.stem_cls.append(ESEAttn(in_c, act=act))
self.stem_reg.append(ESEAttn(in_c, act=act))
self.stem_cls.append(ESEAttn(in_c, act=act, attn_conv=attn_conv))
self.stem_reg.append(ESEAttn(in_c, act=act, attn_conv=attn_conv))
# pred head
self.pred_cls = nn.LayerList()
self.pred_reg = nn.LayerList()
......@@ -151,7 +157,7 @@ class PPYOLOEHead(nn.Layer):
self.anchor_points = anchor_points
self.stride_tensor = stride_tensor
def forward_train(self, feats, targets):
def forward_train(self, feats, targets, aux_pred=None):
anchors, anchor_points, num_anchors_list, stride_tensor = \
generate_anchors_for_grid_cell(
feats, self.fpn_strides, self.grid_cell_scale,
......@@ -173,7 +179,7 @@ class PPYOLOEHead(nn.Layer):
return self.get_loss([
cls_score_list, reg_distri_list, anchors, anchor_points,
num_anchors_list, stride_tensor
], targets)
], targets, aux_pred)
def _generate_anchors(self, feats=None, dtype='float32'):
# only used at eval time
......@@ -231,12 +237,12 @@ class PPYOLOEHead(nn.Layer):
return cls_score_list, reg_dist_list, anchor_points, stride_tensor
def forward(self, feats, targets=None):
def forward(self, feats, targets=None, aux_pred=None):
assert len(feats) == len(self.fpn_strides), \
"The size of feats is not equal to size of fpn_strides"
if self.training:
return self.forward_train(feats, targets)
return self.forward_train(feats, targets, aux_pred)
else:
return self.forward_eval(feats)
......@@ -321,13 +327,17 @@ class PPYOLOEHead(nn.Layer):
loss_dfl = pred_dist.sum() * 0.
return loss_l1, loss_iou, loss_dfl
def get_loss(self, head_outs, gt_meta):
def get_loss(self, head_outs, gt_meta, aux_pred=None):
pred_scores, pred_distri, anchors,\
anchor_points, num_anchors_list, stride_tensor = head_outs
anchor_points_s = anchor_points / stride_tensor
pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
if aux_pred is not None:
pred_scores_aux = aux_pred[0]
pred_bboxes_aux = self._bbox_decode(anchor_points_s, aux_pred[1])
gt_labels = gt_meta['gt_class']
gt_bboxes = gt_meta['gt_bbox']
pad_gt_mask = gt_meta['pad_gt_mask']
......@@ -345,6 +355,7 @@ class PPYOLOEHead(nn.Layer):
alpha_l = 0.25
else:
if self.sm_use:
# only used for small-object detection (PP-YOLOE-SOD models)
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
......@@ -356,19 +367,51 @@ class PPYOLOEHead(nn.Layer):
pad_gt_mask,
bg_index=self.num_classes)
else:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
if aux_pred is None:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
else:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores_aux.detach(),
pred_bboxes_aux.detach() * stride_tensor,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
alpha_l = -1
# rescale bbox
assigned_bboxes /= stride_tensor
assign_out_dict = self.get_loss_from_assign(
pred_scores, pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, alpha_l)
if aux_pred is not None:
assign_out_dict_aux = self.get_loss_from_assign(
aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores, alpha_l)
loss = {}
for key in assign_out_dict.keys():
loss[key] = assign_out_dict[key] + assign_out_dict_aux[key]
else:
loss = assign_out_dict
return loss
def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes,
anchor_points_s, assigned_labels, assigned_bboxes,
assigned_scores, alpha_l):
# cls loss
if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels,
......@@ -421,3 +464,169 @@ class PPYOLOEHead(nn.Layer):
else:
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
def get_activation(name="LeakyReLU"):
if name == "silu":
module = nn.Silu()
elif name == "relu":
module = nn.ReLU()
elif name in ["LeakyReLU", 'leakyrelu', 'lrelu']:
module = nn.LeakyReLU(0.1)
elif name is None:
module = nn.Identity()
else:
raise AttributeError("Unsupported act type: {}".format(name))
return module
class ConvNormLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
norm_type='gn',
activation="LeakyReLU"):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'syncbn', 'gn', None]
self.conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias_attr=False,
weight_attr=ParamAttr(initializer=KaimingNormal()))
if norm_type in ['bn', 'sync_bn', 'syncbn']:
self.norm = nn.BatchNorm2D(out_channels)
elif norm_type == 'gn':
self.norm = nn.GroupNorm(num_groups=32, num_channels=out_channels)
else:
self.norm = None
self.act = get_activation(activation)
def forward(self, x):
y = self.conv(x)
if self.norm is not None:
y = self.norm(y)
y = self.act(y)
return y
class ScaleReg(nn.Layer):
"""
Parameter for scaling the regression outputs.
"""
def __init__(self, scale=1.0):
super(ScaleReg, self).__init__()
scale = paddle.to_tensor(scale)
self.scale = self.create_parameter(
shape=[1],
dtype='float32',
default_initializer=nn.initializer.Assign(scale))
def forward(self, x):
return x * self.scale
@register
class SimpleConvHead(nn.Layer):
__shared__ = ['num_classes']
def __init__(self,
num_classes=80,
feat_in=288,
feat_out=288,
num_convs=1,
fpn_strides=[32, 16, 8, 4],
norm_type='gn',
act='LeakyReLU',
prior_prob=0.01,
reg_max=16):
super(SimpleConvHead, self).__init__()
self.num_classes = num_classes
self.feat_in = feat_in
self.feat_out = feat_out
self.num_convs = num_convs
self.fpn_strides = fpn_strides
self.reg_max = reg_max
self.cls_convs = nn.LayerList()
self.reg_convs = nn.LayerList()
for i in range(self.num_convs):
in_c = feat_in if i == 0 else feat_out
self.cls_convs.append(
ConvNormLayer(
in_c,
feat_out,
3,
stride=1,
padding=1,
norm_type=norm_type,
activation=act))
self.reg_convs.append(
ConvNormLayer(
in_c,
feat_out,
3,
stride=1,
padding=1,
norm_type=norm_type,
activation=act))
bias_cls = bias_init_with_prob(prior_prob)
self.gfl_cls = nn.Conv2D(
feat_out,
self.num_classes,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(
mean=0.0, std=0.01)),
bias_attr=ParamAttr(initializer=Constant(value=bias_cls)))
self.gfl_reg = nn.Conv2D(
feat_out,
4 * (self.reg_max + 1),
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(
mean=0.0, std=0.01)),
bias_attr=ParamAttr(initializer=Constant(value=0)))
self.scales = nn.LayerList()
for i in range(len(self.fpn_strides)):
self.scales.append(ScaleReg(1.0))
def forward(self, feats):
cls_scores = []
bbox_preds = []
for x, scale in zip(feats, self.scales):
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
reg_feat = reg_conv(reg_feat)
cls_score = self.gfl_cls(cls_feat)
cls_score = F.sigmoid(cls_score)
cls_score = cls_score.flatten(2).transpose([0, 2, 1])
cls_scores.append(cls_score)
bbox_pred = scale(self.gfl_reg(reg_feat))
bbox_pred = bbox_pred.flatten(2).transpose([0, 2, 1])
bbox_preds.append(bbox_pred)
cls_scores = paddle.concat(cls_scores, axis=1)
bbox_preds = paddle.concat(bbox_preds, axis=1)
return cls_scores, bbox_preds