未验证 提交 d8508359 编写于 作者: S shangliang Xu 提交者: GitHub

Add PP-YOLOE yaml (#5318)

* Add PP-YOLOE yaml

* rename PP-YOLOE

* Add exclude_nms in ppyoloe_head

* rename pretrain_weights
上级 64865333
epoch: 300
LearningRate:
base_lr: 0.03
schedulers:
- !CosineDecay
max_epochs: 360
- !LinearWarmup
start_factor: 0.001
steps: 3000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
architecture: YOLOv3
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
CSPResNet:
layers: [3, 6, 6, 3]
channels: [64, 128, 256, 512, 1024]
return_idx: [1, 2, 3]
use_large_stem: True
CustomCSPPAN:
out_channels: [768, 384, 192]
stage_num: 1
block_num: 3
act: 'swish'
spp: true
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: 100
use_varifocal_loss: True
eval_input_size: [640, 640]
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.01
nms_threshold: 0.6
worker_num: 8
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
- PadGT: {}
batch_size: 24
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 4
TestReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_l_300e_coco/model_final
find_unused_parameters: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_l_pretrained.pdparams
depth_mult: 1.0
width_mult: 1.0
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_m_300e_coco/model_final
find_unused_parameters: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_m_pretrained.pdparams
depth_mult: 0.67
width_mult: 0.75
TrainReader:
batch_size: 32
LearningRate:
base_lr: 0.04
schedulers:
- !CosineDecay
max_epochs: 360
- !LinearWarmup
start_factor: 0.001
steps: 2300
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_s_300e_coco/model_final
find_unused_parameters: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_s_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50
TrainReader:
batch_size: 32
LearningRate:
base_lr: 0.04
schedulers:
- !CosineDecay
max_epochs: 360
- !LinearWarmup
start_factor: 0.001
steps: 2300
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_x_300e_coco/model_final
find_unused_parameters: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_x_pretrained.pdparams
depth_mult: 1.33
width_mult: 1.25
TrainReader:
batch_size: 16
LearningRate:
base_lr: 0.02
schedulers:
- !CosineDecay
max_epochs: 360
- !LinearWarmup
start_factor: 0.001
steps: 4600
......@@ -32,7 +32,7 @@ from . import detr_head
from . import sparsercnn_head
from . import tood_head
from . import retina_head
from . import ppyolo_head
from . import ppyoloe_head
from .bbox_head import *
from .mask_head import *
......@@ -54,4 +54,4 @@ from .detr_head import *
from .sparsercnn_head import *
from .tood_head import *
from .retina_head import *
from .ppyolo_head import *
from .ppyoloe_head import *
......@@ -24,7 +24,7 @@ from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn
__all__ = ['PPYOLOHead']
__all__ = ['PPYOLOEHead']
class ESEAttn(nn.Layer):
......@@ -44,8 +44,8 @@ class ESEAttn(nn.Layer):
@register
class PPYOLOHead(nn.Layer):
__shared__ = ['num_classes', 'trt']
class PPYOLOEHead(nn.Layer):
__shared__ = ['num_classes', 'trt', 'exclude_nms']
__inject__ = ['static_assigner', 'assigner', 'nms']
def __init__(self,
......@@ -67,8 +67,9 @@ class PPYOLOHead(nn.Layer):
'iou': 2.5,
'dfl': 0.5,
},
trt=False):
super(PPYOLOHead, self).__init__()
trt=False,
exclude_nms=False):
super(PPYOLOEHead, self).__init__()
assert len(in_channels) > 0, "len(in_channels) should > 0"
self.in_channels = in_channels
self.num_classes = num_classes
......@@ -85,6 +86,7 @@ class PPYOLOHead(nn.Layer):
self.static_assigner = static_assigner
self.assigner = assigner
self.nms = nms
self.exclude_nms = exclude_nms
# stem
self.stem_cls = nn.LayerList()
self.stem_reg = nn.LayerList()
......@@ -333,8 +335,7 @@ class PPYOLOHead(nn.Layer):
loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
one_hot_label)
else:
loss_cls = self._focal_loss(
pred_scores, assigned_scores, alpha=alpha_l)
loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)
assigned_scores_sum = assigned_scores.sum()
if paddle_distributed_is_initialized():
......@@ -370,5 +371,9 @@ class PPYOLOHead(nn.Layer):
scale_factor = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=-1).reshape([-1, 1, 4])
pred_bboxes /= scale_factor
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
if self.exclude_nms:
# `exclude_nms=True` just use in benchmark
return pred_bboxes.sum(), pred_scores.sum()
else:
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册