Unverified · Commit 34d78329 · Author: shangliang Xu · Committer: GitHub

[dev] add ppyoloe_plus configs and alter NormalizeImage (#6675)

* [dev] add ppyoloe_plus configs and alter NormalizeImage

* alter other NormalizeImage

* alter cpp NormalizeImage
Parent 6eff0446
epoch: 80
LearningRate:
base_lr: 0.001
schedulers:
- !CosineDecay
max_epochs: 96
- !LinearWarmup
start_factor: 0.
epochs: 5
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
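The optimizer settings above combine a 5-epoch linear warmup with a cosine decay whose `max_epochs` (96) is larger than the training length (`epoch: 80`), so the learning rate does not decay all the way to zero by the end of training. The following is a minimal sketch of such a schedule at epoch granularity. It assumes the cosine phase starts when warmup ends; the function name `lr_at_epoch` is hypothetical, and the framework's actual per-iteration implementation may differ in detail.

```python
import math


def lr_at_epoch(epoch, base_lr=0.001, max_epochs=96,
                warmup_epochs=5, start_factor=0.0):
    """Approximate learning rate at a (possibly fractional) epoch.

    Assumption: linear warmup from start_factor * base_lr to base_lr,
    then cosine decay measured from the end of warmup to max_epochs.
    """
    if epoch < warmup_epochs:
        alpha = epoch / warmup_epochs
        factor = start_factor * (1.0 - alpha) + alpha
        return base_lr * factor
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    # Values taken from the config above: base_lr 0.001, 80 training epochs.
    for e in (0, 1, 5, 40, 80):
        print(f"epoch {e:>2}: lr = {lr_at_epoch(e):.6f}")
```

Under these assumptions the rate at epoch 80 is still positive, because the cosine curve is stretched over 96 epochs rather than 80.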
architecture: YOLOv3
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
CSPResNet:
layers: [3, 6, 6, 3]
channels: [64, 128, 256, 512, 1024]
return_idx: [1, 2, 3]
use_large_stem: True
use_alpha: True
CustomCSPPAN:
out_channels: [768, 384, 192]
stage_num: 1
block_num: 3
act: 'swish'
spp: true
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: 30
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
worker_num: 4
eval_height: &eval_height 640
eval_width: &eval_width 640
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
- PadGT: {}
batch_size: 8
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 2
TestReader:
inputs_def:
image_shape: [3, *eval_height, *eval_width]
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
# PP-YOLOE Legacy Model Zoo (2022.03)
## Legacy Model Zoo
| Model | Epoch | GPU number | images/GPU | backbone | input shape | Box AP<sup>val</sup><br>0.5:0.95 | Box AP<sup>test</sup><br>0.5:0.95 | Params(M) | FLOPs(G) | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:------------------------:|:-------:|:-------:|:--------:|:----------:| :-------:| :------------------: | :-------------------: |:---------:|:--------:|:---------------:| :---------------------: | :------: | :------: |
| PP-YOLOE-s | 400 | 8 | 32 | cspresnet-s | 640 | 43.4 | 43.6 | 7.93 | 17.36 | 208.3 | 333.3 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_400e_coco.pdparams) | [config](./ppyoloe_crn_s_400e_coco.yml) |
| PP-YOLOE-s | 300 | 8 | 32 | cspresnet-s | 640 | 43.0 | 43.2 | 7.93 | 17.36 | 208.3 | 333.3 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams) | [config](./ppyoloe_crn_s_300e_coco.yml) |
| PP-YOLOE-m | 300 | 8 | 28 | cspresnet-m | 640 | 49.0 | 49.1 | 23.43 | 49.91 | 123.4 | 208.3 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_m_300e_coco.pdparams) | [config](./ppyoloe_crn_m_300e_coco.yml) |
| PP-YOLOE-l | 300 | 8 | 20 | cspresnet-l | 640 | 51.4 | 51.6 | 52.20 | 110.07 | 78.1 | 149.2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams) | [config](./ppyoloe_crn_l_300e_coco.yml) |
| PP-YOLOE-x | 300 | 8 | 16 | cspresnet-x | 640 | 52.3 | 52.4 | 98.42 | 206.59 | 45.0 | 95.2 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_x_300e_coco.pdparams) | [config](./ppyoloe_crn_x_300e_coco.yml) |
### Comprehensive Metrics
| Model | Epoch | AP<sup>0.5:0.95</sup> | AP<sup>0.5</sup> | AP<sup>0.75</sup> | AP<sup>small</sup> | AP<sup>medium</sup> | AP<sup>large</sup> | AR<sup>small</sup> | AR<sup>medium</sup> | AR<sup>large</sup> | download | config |
|:----------------------:|:-----:|:---------------:|:----------:|:-------------:| :------------:| :-----------: | :----------: |:------------:|:-------------:|:------------:| :-----: | :-----: |
| PP-YOLOE-s | 400 | 43.4 | 60.0 | 47.5 | 25.7 | 47.8 | 59.2 | 43.9 | 70.8 | 81.9 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_400e_coco.pdparams) | [config](./ppyoloe_crn_s_400e_coco.yml)|
| PP-YOLOE-s | 300 | 43.0 | 59.6 | 47.2 | 26.0 | 47.4 | 58.7 | 45.1 | 70.6 | 81.4 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams) | [config](./ppyoloe_crn_s_300e_coco.yml)|
| PP-YOLOE-m | 300 | 49.0 | 65.9 | 53.8 | 30.9 | 53.5 | 65.3 | 50.9 | 74.4 | 84.7 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_m_300e_coco.pdparams) | [config](./ppyoloe_crn_m_300e_coco.yml)|
| PP-YOLOE-l | 300 | 51.4 | 68.6 | 56.2 | 34.8 | 56.1 | 68.0 | 53.1 | 76.8 | 85.6 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams) | [config](./ppyoloe_crn_l_300e_coco.yml)|
| PP-YOLOE-x | 300 | 52.3 | 69.5 | 56.8 | 35.1 | 57.0 | 68.6 | 55.5 | 76.9 | 85.7 | [model](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_x_300e_coco.pdparams) | [config](./ppyoloe_crn_x_300e_coco.yml)|
**Notes:**
- PP-YOLOE is trained on the COCO train2017 dataset and evaluated on the val2017 and test-dev2017 datasets.
- The model weights in the Comprehensive Metrics table are **the same as** those in the original Model Zoo, and are evaluated on **val2017**.
- PP-YOLOE was trained on 8 GPUs. If the **GPU number** or **mini-batch size** changes, the **learning rate** should be adjusted according to the formula **lr<sub>new</sub> = lr<sub>default</sub> * (batch_size<sub>new</sub> * GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> * GPU_number<sub>default</sub>)**.
- PP-YOLOE inference speed is tested on a single Tesla V100 with batch size 1, using **CUDA 10.2**, **CUDNN 7.6.5**, and **TensorRT 6.0.1.8** in TensorRT mode.
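
The learning-rate rule in the notes is a linear scaling by total batch size. The following is a small sketch of that formula; the helper name `scale_lr` is hypothetical. The worked numbers are pairs that appear in this change's configs (0.01 at 8 images/GPU, 0.04 at 32 images/GPU, both on 8 GPUs).

```python
def scale_lr(lr_default, batch_size_default, gpus_default,
             batch_size_new, gpus_new):
    """Linearly rescale the learning rate with the total batch size."""
    total_default = batch_size_default * gpus_default
    total_new = batch_size_new * gpus_new
    return lr_default * total_new / total_default


if __name__ == "__main__":
    # 8 GPUs x 8 images at lr 0.01, rescaled to 8 GPUs x 32 images.
    print(scale_lr(0.01, 8, 8, 32, 8))
    # The same default, rescaled to a single GPU with 8 images.
    print(scale_lr(0.01, 8, 8, 8, 1))
```
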
## Appendix
Ablation experiments of PP-YOLOE.
| NO. | Model | Box AP<sup>val</sup> | Params(M) | FLOPs(G) | V100 FP32 FPS |
| :--: | :---------------------------: | :------------------: | :-------: | :------: | :-----------: |
| A | PP-YOLOv2 | 49.1 | 54.58 | 115.77 | 68.9 |
| B | A + Anchor-free | 48.8 | 54.27 | 114.78 | 69.8 |
| C | B + CSPRepResNet | 49.5 | 47.42 | 101.87 | 85.5 |
| D | C + TAL | 50.4 | 48.32 | 104.75 | 84.0 |
| E | D + ET-Head | 50.9 | 52.20 | 110.07 | 78.1 |
epoch: 300
LearningRate:
base_lr: 0.025
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: 360
......
epoch: 400
LearningRate:
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: 480
- !LinearWarmup
start_factor: 0.
epochs: 5
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
......@@ -15,7 +15,7 @@ TrainReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
- PadGT: {}
batch_size: 20
batch_size: 8
shuffle: true
drop_last: true
use_shared_memory: true
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'./_base_/optimizer_36e_xpu.yml',
'./_base_/ppyoloe_reader.yml',
]
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
......@@ -13,9 +13,3 @@ weights: output/ppyoloe_crn_m_300e_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_m_pretrained.pdparams
depth_mult: 0.67
width_mult: 0.75
TrainReader:
batch_size: 28
LearningRate:
base_lr: 0.035
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
......@@ -13,9 +13,3 @@ weights: output/ppyoloe_crn_s_300e_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_s_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50
TrainReader:
batch_size: 32
LearningRate:
base_lr: 0.04
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_300e.yml',
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'./_base_/optimizer_400e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
]
......@@ -14,33 +14,5 @@ pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_s
depth_mult: 0.33
width_mult: 0.50
TrainReader:
batch_size: 32
epoch: 400
LearningRate:
base_lr: 0.04
schedulers:
- !CosineDecay
max_epochs: 480
- !LinearWarmup
start_factor: 0.
epochs: 5
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
PPYOLOEHead:
static_assigner_epoch: 133
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../../datasets/coco_detection.yml',
'../../runtime.yml',
'./_base_/optimizer_300e.yml',
'./_base_/ppyoloe_crn.yml',
'./_base_/ppyoloe_reader.yml',
......@@ -13,9 +13,3 @@ weights: output/ppyoloe_crn_x_300e_coco/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/CSPResNetb_x_pretrained.pdparams
depth_mult: 1.33
width_mult: 1.25
TrainReader:
batch_size: 16
LearningRate:
base_lr: 0.02
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_80e.yml',
'./_base_/ppyoloe_plus_crn.yml',
'./_base_/ppyoloe_plus_reader.yml',
]
log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_l_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_l_obj365_pretrained.pdparams
depth_mult: 1.0
width_mult: 1.0
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_80e.yml',
'./_base_/ppyoloe_plus_crn.yml',
'./_base_/ppyoloe_plus_reader.yml',
]
log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_m_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_m_obj365_pretrained.pdparams
depth_mult: 0.67
width_mult: 0.75
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_80e.yml',
'./_base_/ppyoloe_plus_crn.yml',
'./_base_/ppyoloe_plus_reader.yml',
]
log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_s_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_80e.yml',
'./_base_/ppyoloe_plus_crn.yml',
'./_base_/ppyoloe_plus_reader.yml',
]
log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_x_80e_coco/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_x_obj365_pretrained.pdparams
depth_mult: 1.33
width_mult: 1.25
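
The s/m/l/x configs above differ only in `depth_mult` and `width_mult`, which scale the base `layers: [3, 6, 6, 3]` and `channels: [64, 128, 256, 512, 1024]` of the backbone. The following is a rough sketch of how such multipliers are commonly applied. It assumes each value is multiplied, rounded, and clamped to at least 1; the helper `scale_backbone` is hypothetical and is not taken from this commit.

```python
def scale_backbone(layers, channels, depth_mult, width_mult):
    """Scale block counts by depth_mult and channel widths by width_mult.

    Assumption: values are rounded to the nearest integer, minimum 1.
    """
    scaled_layers = [max(round(n * depth_mult), 1) for n in layers]
    scaled_channels = [max(round(c * width_mult), 1) for c in channels]
    return scaled_layers, scaled_channels


if __name__ == "__main__":
    base_layers = [3, 6, 6, 3]
    base_channels = [64, 128, 256, 512, 1024]
    # (depth_mult, width_mult) pairs copied from the configs above.
    variants = {"s": (0.33, 0.50), "m": (0.67, 0.75),
                "l": (1.0, 1.0), "x": (1.33, 1.25)}
    for name, (d, w) in variants.items():
        print(name, scale_backbone(base_layers, base_channels, d, w))
```
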
......@@ -65,7 +65,8 @@ class NormalizeImage : public PreprocessOp {
virtual void Init(const YAML::Node& item) {
mean_ = item["mean"].as<std::vector<float>>();
scale_ = item["std"].as<std::vector<float>>();
is_scale_ = item["is_scale"].as<bool>();
if (item["is_scale"]) is_scale_ = item["is_scale"].as<bool>();
if (item["norm_type"]) norm_type_ = item["norm_type"].as<std::string>();
}
virtual void Run(cv::Mat* im, ImageBlob* data);
......@@ -75,6 +76,7 @@ class NormalizeImage : public PreprocessOp {
std::vector<float> mean_;
std::vector<float> scale_;
bool is_scale_ = true;
std::string norm_type_ = "mean_std";
};
class Permute : public PreprocessOp {
......
......@@ -34,14 +34,16 @@ void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
e /= 255.0;
}
(*im).convertTo(*im, CV_32FC3, e);
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean_[0]) / scale_[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean_[1]) / scale_[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean_[2]) / scale_[2];
if (norm_type_ == "mean_std"){
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean_[0]) / scale_[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean_[1]) / scale_[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean_[2]) / scale_[2];
}
}
}
}
......
......@@ -275,13 +275,14 @@ class NormalizeImage(object):
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True):
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
......@@ -293,13 +294,15 @@ class NormalizeImage(object):
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
......
......@@ -87,13 +87,14 @@ class NormalizeImage(object):
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True):
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
......@@ -105,13 +106,15 @@ class NormalizeImage(object):
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
......
......@@ -91,13 +91,14 @@ class NormalizeImage(object):
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True):
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
......@@ -109,13 +110,15 @@ class NormalizeImage(object):
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
......
......@@ -359,19 +359,26 @@ class RandomErasingImage(BaseOperator):
@register_op
class NormalizeImage(BaseOperator):
def __init__(self, mean=[0.485, 0.456, 0.406], std=[1, 1, 1],
is_scale=True):
def __init__(self,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
norm_type='mean_std'):
"""
Args:
mean (list): the pixel mean
std (list): the pixel standard deviation
is_scale (bool): scale the pixel to [0,1]
norm_type (str): type in ['mean_std', 'none']
"""
super(NormalizeImage, self).__init__()
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
if not (isinstance(self.mean, list) and isinstance(self.std, list) and
isinstance(self.is_scale, bool)):
isinstance(self.is_scale, bool) and
self.norm_type in ['mean_std', 'none']):
raise TypeError("{}: input type is invalid.".format(self))
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
......@@ -380,20 +387,20 @@ class NormalizeImage(BaseOperator):
def apply(self, sample, context=None):
"""Normalize the image.
Operators:
1.(optional) Scale the image to [0,1]
2. Each pixel minus mean and is divided by std
1.(optional) Scale the pixel to [0,1]
2.(optional) Each pixel minus mean and is divided by std
"""
im = sample['image']
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
sample['image'] = im
return sample
......
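
The `NormalizeImage` changes above make the mean/std step optional: with `norm_type: none` and the default `is_scale=True`, as in the ppyoloe_plus reader configs, pixels are only divided by 255. The following is a standalone sketch of that behaviour for an HWC image. The free function `normalize_image` is hypothetical and only mirrors the logic shown in the diffs; it is not part of the repository's API.

```python
import numpy as np


def normalize_image(im, mean, std, is_scale=True, norm_type="mean_std"):
    """Optionally scale an HWC image to [0, 1], then optionally apply mean/std."""
    if norm_type not in ("mean_std", "none"):
        raise ValueError(f"unsupported norm_type: {norm_type!r}")
    # astype copies by default, so the caller's array is left untouched.
    im = im.astype(np.float32)
    if is_scale:
        im *= 1.0 / 255.0
    if norm_type == "mean_std":
        # Broadcast the per-channel statistics over height and width.
        im -= np.array(mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
        im /= np.array(std, dtype=np.float32)[np.newaxis, np.newaxis, :]
    return im


if __name__ == "__main__":
    image = np.full((2, 2, 3), 255, dtype=np.uint8)
    # ppyoloe_plus style: scale only, mean/std are ignored.
    print(normalize_image(image, [0., 0., 0.], [1., 1., 1.], norm_type="none"))
    # Legacy style: scale, then subtract the mean and divide by the std.
    print(normalize_image(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
```
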