提交 f961433c 编写于 作者: D dengkaipeng

add ppyolo mobilenetv3

上级 593b6b9e
......@@ -50,6 +50,9 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 512 | 44.4 | 45.0 | 89.9 | 188.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_2x.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 416 | 42.7 | 43.2 | 109.1 | 215.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_2x.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 320 | 39.5 | 40.1 | 132.2 | 242.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_2x.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 512 | 29.3 | 29.5 | 357.1 | 657.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 416 | 28.6 | 28.9 | 409.8 | 719.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 320 | 26.2 | 26.4 | 480.7 | 763.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
**Notes:**
......@@ -63,14 +66,14 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
### PP-YOLO for mobile
| Model | GPU number | images/GPU | backbone | input shape | Box AP50<sup>val</sup> | Box AP50<sup>test</sup> | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:------------------------:|:----------:|:----------:|:----------:| :----------:| :--------------------: | :---------------------: | :------------: | :---------------------: | :------: | :-----: |
| PP-YOLO_r18vd | 4 | 32 | ResNet18vd | 416 | 47.0 | 47.7 | 401.6 | 724.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_r18vd | 4 | 32 | ResNet18vd | 320 | 43.7 | 44.4 | 478.5 | 791.3 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| Model | GPU number | images/GPU | Model Size | input shape | Box AP<sup>val</sup> | Kirin 990(FPS) | download | config |
|:------------- -------------:|:----------:|:----------:| :--------: | :----------:| :------------------: | :------------: | :------: | :-----: |
| PP-YOLO_MobileNetV3_large | 4 | 32 | 18MB | 320 | 22.0 | 14.1 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_large.yml) |
| PP-YOLO_MobileNetV3_small | 4 | 32 | 11MB | 320 | 16.8 | 21.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) |
- PP-YOLO_r18vd is trained on COCO train2017 datast and evaluated on val2017 & test-dev2017 dataset,Box AP50<sup>val</sup> is evaluation results of `mAP(IoU=0.5)`.
- PP-YOLO_r18vd used 4 GPUs for training and mini-batch size as 32 on each GPU, if GPU number and mini-batch size is changed, learning rate and iteration times should be adjusted according [FAQ](../../docs/FAQ.md).
- PP-YOLO_r18vd inference speeding testing environment and configuration is same as PP-YOLO above.
- PP-YOLO_MobileNetV3 is trained on COCO train2017 datast and evaluated on val2017 dataset,Box AP<sup>val</sup> is evaluation results of `mAP(IoU=0.5:0.95)`.
- PP-YOLO_MobileNetV3 used 4 GPUs for training and mini-batch size as 32 on each GPU, if GPU number and mini-batch size is changed, learning rate and iteration times should be adjusted according [FAQ](../../docs/FAQ.md).
- PP-YOLO_MobileNetV3 inference speed is tested on Kirin 990 with 1 thread.
## Getting Start
......
......@@ -50,6 +50,9 @@ PP-YOLO从如下方面优化和提升YOLOv3模型的精度和速度:
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 512 | 44.4 | 45.0 | 89.9 | 188.4 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 416 | 42.7 | 43.2 | 109.1 | 215.4 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 320 | 39.5 | 40.1 | 132.2 | 242.2 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 512 | 29.3 | 29.5 | 357.1 | 657.9 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 416 | 28.6 | 28.9 | 409.8 | 719.4 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 320 | 26.2 | 26.4 | 480.7 | 763.4 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
**注意:**
......@@ -64,14 +67,14 @@ PP-YOLO从如下方面优化和提升YOLOv3模型的精度和速度:
### PP-YOLO 移动端模型
| 模型 | GPU个数 | 每GPU图片个数 | 骨干网络 | 输入尺寸 | Box AP50<sup>val</sup> | Box AP50<sup>test</sup> | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | 模型下载 | 配置文件 |
|:------------------------:|:-------:|:-------------:|:----------:| :-------:| :--------------------: | :---------------------: |------------: | :---------------------: | :------: | :------: |
| PP-YOLO_r18vd | 4 | 32 | ResNet18vd | 416 | 47.0 | 47.7 | 401.6 | 724.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_r18vd | 4 | 32 | ResNet18vd | 320 | 43.7 | 44.4 | 478.5 | 791.3 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| 模型 | GPU个数 | 每GPU图片个数 | 模型体积 | 输入尺寸 | Box AP<sup>val</sup> | Kirin 990 (FPS) | 模型下载 | 配置文件 |
|:----------------------------:|:-------:|:-------------:|:----------:| :-------:| :------------------: | :-------------: |------------: | :---------------------: | :------: | :------: |
| PP-YOLO_MobileNetV3_large | 4 | 32 | 18MB | 320 | 22.0 | 14.1 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_large.yml) |
| PP-YOLO_MobileNetV3_small | 4 | 32 | 11MB | 320 | 16.8 | 21.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) |
- PP-YOLO_r18vd 模型使用COCO数据集中train2017作为训练集,使用val2017和test-dev2017作为测试集,Box AP50<sup>val</sup>`mAP(IoU=0.5)`评估结果。
- PP-YOLO_r18vd 模型训练过程中使用4GPU,每GPU batch size为32进行训练,如训练GPU数和batch size不使用上述配置,须参考[FAQ](../../docs/FAQ.md)调整学习率和迭代次数。
- PP-YOLO_r18vd 模型推理速度测试环境配置和测试方法与PP-YOLO模型一致
- PP-YOLO_MobileNetV3 模型使用COCO数据集中train2017作为训练集,使用val2017作为测试集,Box AP50<sup>val</sup>`mAP(IoU=0.5:0.95)`评估结果。
- PP-YOLO_MobileNetV3 模型训练过程中使用4GPU,每GPU batch size为32进行训练,如训练GPU数和batch size不使用上述配置,须参考[FAQ](../../docs/FAQ.md)调整学习率和迭代次数。
- PP-YOLO_MobileNetV3 模型推理速度测试环境配置为麒麟990芯片单线程
## 使用说明
......
architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 20
log_iter: 20
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar
weights: output/ppyolo_tiny/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: MobileNetV3
yolo_head: YOLOv3Head
use_fine_grained_loss: true
MobileNetV3:
norm_type: sync_bn
norm_decay: 0.
model_name: large
scale: 1.
extra_block_filters: []
feature_maps: [1, 2, 3, 4, 6]
YOLOv3Head:
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
norm_decay: 0.
conv_block_num: 0
scale_x_y: 1.05
yolo_loss: YOLOv3Loss
spp: true
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
drop_block: true
YOLOv3Loss:
batch_size: 32
ignore_thresh: 0.7
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.00666
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 150000
- 200000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: train_data/dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: True
- !MixupImage
alpha: 1.5
beta: 1.5
- !ColorDistort {}
- !RandomExpand
fill_value: [123.675, 116.28, 103.53]
- !RandomCrop {}
- !RandomFlipImage
is_normalized: false
- !NormalizeBox {}
- !PadBox
num_max_boxes: 50
- !BboxXYXY2XYWH {}
batch_transforms:
- !RandomShape
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
# Gt2YoloTarget is only used when use_fine_grained_loss set as true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set as false
- !Gt2YoloTarget
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
downsample_ratios: [32, 16]
batch_size: 32
shuffle: true
mixup_epoch: 500
drop_last: true
worker_num: 16
bufsize: 16
memsize: 8G
use_process: true
EvalReader:
inputs_def:
fields: ['image', 'im_size', 'im_id']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 320
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !PadBox
num_max_boxes: 50
- !Permute
to_bgr: false
channel_first: True
batch_size: 8
drop_empty: false
worker_num: 2
bufsize: 4
architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 20
log_iter: 20
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar
weights: output/ppyolo_tiny/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: MobileNetV3
yolo_head: YOLOv3Head
use_fine_grained_loss: true
MobileNetV3:
norm_type: sync_bn
norm_decay: 0.
model_name: small
scale: 1.
extra_block_filters: []
feature_maps: [1, 2, 3, 4, 6]
YOLOv3Head:
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
norm_decay: 0.
conv_block_num: 0
scale_x_y: 1.05
yolo_loss: YOLOv3Loss
spp: true
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
drop_block: true
YOLOv3Loss:
batch_size: 32
ignore_thresh: 0.7
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.00666
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 150000
- 200000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: train_data/dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: True
- !MixupImage
alpha: 1.5
beta: 1.5
- !ColorDistort {}
- !RandomExpand
fill_value: [123.675, 116.28, 103.53]
- !RandomCrop {}
- !RandomFlipImage
is_normalized: false
- !NormalizeBox {}
- !PadBox
num_max_boxes: 50
- !BboxXYXY2XYWH {}
batch_transforms:
- !RandomShape
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
# Gt2YoloTarget is only used when use_fine_grained_loss set as true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set as false
- !Gt2YoloTarget
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
downsample_ratios: [32, 16]
batch_size: 32
shuffle: true
mixup_epoch: 500
drop_last: true
worker_num: 16
bufsize: 16
memsize: 8G
use_process: true
EvalReader:
inputs_def:
fields: ['image', 'im_size', 'im_id']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 320
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !PadBox
num_max_boxes: 50
- !Permute
to_bgr: false
channel_first: True
batch_size: 8
drop_empty: false
worker_num: 2
bufsize: 4
......@@ -263,7 +263,10 @@ class YOLOv3Head(object):
keep_prob=self.keep_prob,
is_test=is_test)
if self.drop_block and is_first:
if self.use_spp and conv_block_num == 0 and is_first:
conv = self._spp_module(conv, name="spp")
if self.drop_block and (is_first or conv_block_num == 0):
conv = DropBlock(
conv,
block_size=self.block_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册