diff --git a/configs/ppyoloe/_base_/optimizer_300e.yml b/configs/ppyoloe/_base_/optimizer_300e.yml new file mode 100644 index 0000000000000000000000000000000000000000..f38a34b2316cbaf738e8b40e2db01aef7e0a81d6 --- /dev/null +++ b/configs/ppyoloe/_base_/optimizer_300e.yml @@ -0,0 +1,18 @@ +epoch: 300 + +LearningRate: + base_lr: 0.03 + schedulers: + - !CosineDecay + max_epochs: 360 + - !LinearWarmup + start_factor: 0.001 + steps: 3000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 diff --git a/configs/ppyoloe/_base_/ppyoloe_crn.yml b/configs/ppyoloe/_base_/ppyoloe_crn.yml new file mode 100644 index 0000000000000000000000000000000000000000..2ad9a11a8e98fe0f415606fb0b1119263d3b5aa4 --- /dev/null +++ b/configs/ppyoloe/_base_/ppyoloe_crn.yml @@ -0,0 +1,46 @@ +architecture: YOLOv3 +norm_type: sync_bn +use_ema: true +ema_decay: 0.9998 + +YOLOv3: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + +CSPResNet: + layers: [3, 6, 6, 3] + channels: [64, 128, 256, 512, 1024] + return_idx: [1, 2, 3] + use_large_stem: True + +CustomCSPPAN: + out_channels: [768, 384, 192] + stage_num: 1 + block_num: 3 + act: 'swish' + spp: true + +PPYOLOEHead: + fpn_strides: [32, 16, 8] + grid_cell_scale: 5.0 + grid_cell_offset: 0.5 + static_assigner_epoch: 100 + use_varifocal_loss: True + eval_input_size: [640, 640] + loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5} + static_assigner: + name: ATSSAssigner + topk: 9 + assigner: + name: TaskAlignedAssigner + topk: 13 + alpha: 1.0 + beta: 6.0 + nms: + name: MultiClassNMS + nms_top_k: 1000 + keep_top_k: 100 + score_threshold: 0.01 + nms_threshold: 0.6 diff --git a/configs/ppyoloe/_base_/ppyoloe_reader.yml b/configs/ppyoloe/_base_/ppyoloe_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..0774db5a83dd2e4349957e435292dbe290d11a96 --- /dev/null +++ b/configs/ppyoloe/_base_/ppyoloe_reader.yml @@ -0,0 +1,36 @@ +worker_num: 8 +TrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {} + - RandomFlip: {} + batch_transforms: + - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + - PadGT: {} + batch_size: 24 + shuffle: true + drop_last: true + use_shared_memory: true + collate_batch: true + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + batch_size: 4 + +TestReader: + inputs_def: + image_shape: [3, 640, 640] + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + batch_size: 1 diff --git a/configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml b/configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..0b2a25b9041d7da3376a67342e5308af899d3fb6 --- /dev/null +++ b/configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml @@ -0,0 +1,16 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + './_base_/optimizer_300e.yml', + './_base_/ppyoloe_crn.yml', + './_base_/ppyoloe_reader.yml', +] + +log_iter: 100 +snapshot_epoch: 10 +weights: output/ppyoloe_crn_l_300e_coco/model_final +find_unused_parameters: True + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_l_pretrained.pdparams +depth_mult: 1.0 +width_mult: 1.0 diff --git a/configs/ppyoloe/ppyoloe_crn_m_300e_coco.yml b/configs/ppyoloe/ppyoloe_crn_m_300e_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..b7e78d72ffa4cc10a96385c75fcd58b379718562 --- /dev/null +++ b/configs/ppyoloe/ppyoloe_crn_m_300e_coco.yml @@ -0,0 +1,28 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + './_base_/optimizer_300e.yml', + './_base_/ppyoloe_crn.yml', + './_base_/ppyoloe_reader.yml', +] + +log_iter: 100 +snapshot_epoch: 10 +weights: output/ppyoloe_crn_m_300e_coco/model_final +find_unused_parameters: True + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_m_pretrained.pdparams +depth_mult: 0.67 +width_mult: 0.75 + +TrainReader: + batch_size: 32 + +LearningRate: + base_lr: 0.04 + schedulers: + - !CosineDecay + max_epochs: 360 + - !LinearWarmup + start_factor: 0.001 + steps: 2300 diff --git a/configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml b/configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..7feffed06ab6f9f94a2cd3a8ff18783fa798e9a7 --- /dev/null +++ b/configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml @@ -0,0 +1,28 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + './_base_/optimizer_300e.yml', + './_base_/ppyoloe_crn.yml', + './_base_/ppyoloe_reader.yml', +] + +log_iter: 100 +snapshot_epoch: 10 +weights: output/ppyoloe_crn_s_300e_coco/model_final +find_unused_parameters: True + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_s_pretrained.pdparams +depth_mult: 0.33 +width_mult: 0.50 + +TrainReader: + batch_size: 32 + +LearningRate: + base_lr: 0.04 + schedulers: + - !CosineDecay + max_epochs: 360 + - !LinearWarmup + start_factor: 0.001 + steps: 2300 diff --git a/configs/ppyoloe/ppyoloe_crn_x_300e_coco.yml b/configs/ppyoloe/ppyoloe_crn_x_300e_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..4e72b19c71f4348f04c292c2db2db992c679b829 --- /dev/null +++ b/configs/ppyoloe/ppyoloe_crn_x_300e_coco.yml @@ -0,0 +1,28 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + './_base_/optimizer_300e.yml', + './_base_/ppyoloe_crn.yml', + './_base_/ppyoloe_reader.yml', +] + +log_iter: 100 +snapshot_epoch: 10 +weights: output/ppyoloe_crn_x_300e_coco/model_final +find_unused_parameters: True + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_x_pretrained.pdparams +depth_mult: 1.33 +width_mult: 1.25 + +TrainReader: + batch_size: 16 + +LearningRate: + base_lr: 0.02 + schedulers: + - !CosineDecay + max_epochs: 360 + - !LinearWarmup + start_factor: 0.001 + steps: 4600 diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index 41e3fe0a56f3e16949fac0a80f3908e7a4ff9b11..e831933381c88dbd14f588b045d619333cc537c3 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -32,7 +32,7 @@ from . import detr_head from . import sparsercnn_head from . import tood_head from . import retina_head -from . import ppyolo_head +from . import ppyoloe_head from .bbox_head import * from .mask_head import * @@ -54,4 +54,4 @@ from .detr_head import * from .sparsercnn_head import * from .tood_head import * from .retina_head import * -from .ppyolo_head import * +from .ppyoloe_head import * diff --git a/ppdet/modeling/heads/ppyolo_head.py b/ppdet/modeling/heads/ppyoloe_head.py similarity index 96% rename from ppdet/modeling/heads/ppyolo_head.py rename to ppdet/modeling/heads/ppyoloe_head.py index 38c6b3368c1034a683ecc36f05d160679d6e495c..920bb2298909e5275c9bc04f3c73cce3f4c8ff36 100644 --- a/ppdet/modeling/heads/ppyolo_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -24,7 +24,7 @@ from ..assigners.utils import generate_anchors_for_grid_cell from ppdet.modeling.backbones.cspresnet import ConvBNLayer from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn -__all__ = ['PPYOLOHead'] +__all__ = ['PPYOLOEHead'] class ESEAttn(nn.Layer): @@ -44,8 +44,8 @@ class ESEAttn(nn.Layer): @register -class PPYOLOHead(nn.Layer): - __shared__ = ['num_classes', 'trt'] +class PPYOLOEHead(nn.Layer): + __shared__ = ['num_classes', 'trt', 'exclude_nms'] __inject__ = ['static_assigner', 'assigner', 'nms'] def __init__(self, @@ -67,8 +67,9 @@ class PPYOLOHead(nn.Layer): 'iou': 2.5, 'dfl': 0.5, }, - trt=False): - super(PPYOLOHead, self).__init__() + trt=False, + exclude_nms=False): + super(PPYOLOEHead, self).__init__() assert len(in_channels) > 0, "len(in_channels) should > 0" self.in_channels = in_channels self.num_classes = num_classes @@ -85,6 +86,7 @@ class PPYOLOHead(nn.Layer): self.static_assigner = static_assigner self.assigner = assigner self.nms = nms + self.exclude_nms = exclude_nms # stem self.stem_cls = nn.LayerList() self.stem_reg = nn.LayerList() @@ -333,8 +335,7 @@ class PPYOLOHead(nn.Layer): loss_cls = self._varifocal_loss(pred_scores, assigned_scores, one_hot_label) else: - loss_cls = self._focal_loss( - pred_scores, assigned_scores, alpha=alpha_l) + loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l) assigned_scores_sum = assigned_scores.sum() if paddle_distributed_is_initialized(): @@ -370,5 +371,9 @@ class PPYOLOHead(nn.Layer): scale_factor = paddle.concat( [scale_x, scale_y, scale_x, scale_y], axis=-1).reshape([-1, 1, 4]) pred_bboxes /= scale_factor - bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores) - return bbox_pred, bbox_num + if self.exclude_nms: + # `exclude_nms=True` just use in benchmark + return pred_bboxes.sum(), pred_scores.sum() + else: + bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores) + return bbox_pred, bbox_num