Commit 7e172f88, authored by Guanghua Yu, committed by qingqing01

[PaddleDetection] Add VGG-SSD on VOC and COCO dataset. (#3037)

- Add VGG-SSD on VOC and COCO datasets.
- Add configs and model zoo entries.
- Refine bbox2out and draw_bbox:
    - Add bbox de-normalization in the bbox2out function.
    - Remove bbox de-normalization from draw_bbox.
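A minimal sketch of the de-normalization being relocated (illustrative only, not the literal patch; the real changes are in the `bbox2out` and `draw_bbox` hunks below):

```python
# Boxes decoded by SSD's output_decoder are normalized to [0, 1]; scaling by
# the image height/width converts an [xmin, ymin, w, h] box to pixel units.
def denormalize_xywh(xmin, ymin, w, h, im_height, im_width):
    return xmin * im_width, ymin * im_height, w * im_width, h * im_height

assert denormalize_xywh(0.25, 0.5, 0.25, 0.5, 300, 300) == (75.0, 150.0, 75.0, 150.0)
```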
Parent 85126e83
......@@ -38,16 +38,16 @@ multi-GPU training.
Supported Architectures:
-|                    | ResNet | ResNet-vd <sup>[1](#vd)</sup> | ResNeXt-vd | SENet | MobileNet | DarkNet |
-|--------------------|:------:|:-----------------------------:|:----------:|:-----:|:---------:|:-------:|
-| Faster R-CNN       |   ✓    |   ✓   |   ✗   |   ✓   |   ✗   |   ✗   |
-| Faster R-CNN + FPN |   ✓    |   ✓   |   ✓   |   ✓   |   ✗   |   ✗   |
-| Mask R-CNN         |   ✓    |   ✓   |   ✗   |   ✓   |   ✗   |   ✗   |
-| Mask R-CNN + FPN   |   ✓    |   ✓   |   ✓   |   ✓   |   ✗   |   ✗   |
-| Cascade R-CNN      |   ✓    |   ✗   |   ✗   |   ✗   |   ✗   |   ✗   |
-| RetinaNet          |   ✓    |   ✗   |   ✗   |   ✗   |   ✗   |   ✗   |
-| Yolov3             |   ✓    |   ✗   |   ✗   |   ✗   |   ✓   |   ✓   |
-| SSD                |   ✗    |   ✗   |   ✗   |   ✗   |   ✓   |   ✗   |
+|                    | ResNet | ResNet-vd <sup>[1](#vd)</sup> | ResNeXt-vd | SENet | MobileNet | DarkNet | VGG |
+|--------------------|:------:|:-----------------------------:|:----------:|:-----:|:---------:|:-------:|:---:|
+| Faster R-CNN       |   ✓    |   ✓   |   ✗   |   ✓   |   ✗   |   ✗   |  ✗  |
+| Faster R-CNN + FPN |   ✓    |   ✓   |   ✓   |   ✓   |   ✗   |   ✗   |  ✗  |
+| Mask R-CNN         |   ✓    |   ✓   |   ✗   |   ✓   |   ✗   |   ✗   |  ✗  |
+| Mask R-CNN + FPN   |   ✓    |   ✓   |   ✓   |   ✓   |   ✗   |   ✗   |  ✗  |
+| Cascade R-CNN      |   ✓    |   ✗   |   ✗   |   ✗   |   ✗   |   ✗   |  ✗  |
+| RetinaNet          |   ✓    |   ✗   |   ✗   |   ✗   |   ✗   |   ✗   |  ✗  |
+| Yolov3             |   ✓    |   ✗   |   ✗   |   ✗   |   ✓   |   ✓   |  ✗  |
+| SSD                |   ✗    |   ✗   |   ✗   |   ✗   |   ✓   |   ✗   |  ✓  |
<a name="vd">[1]</a> [ResNet-vd](https://arxiv.org/pdf/1812.01187) models offer much improved accuracy with negligible performance cost.
......
......@@ -27,16 +27,16 @@ PaddleDetection aims to provide industry and academia with abundant, easy-to-use object det
Supported architectures:

-|                    | ResNet | ResNet-vd <sup>[1](#vd)</sup> | ResNeXt-vd | SENet | MobileNet | DarkNet |
-|--------------------|:------:|:-----------------------------:|:----------:|:-----:|:---------:|:-------:|
-| Faster R-CNN       |   ✓    |   ✓   |   ✗   |   ✓   |   ✗   |   ✗   |
-| Faster R-CNN + FPN |   ✓    |   ✓   |   ✓   |   ✓   |   ✗   |   ✗   |
-| Mask R-CNN         |   ✓    |   ✓   |   ✗   |   ✓   |   ✗   |   ✗   |
-| Mask R-CNN + FPN   |   ✓    |   ✓   |   ✓   |   ✓   |   ✗   |   ✗   |
-| Cascade R-CNN      |   ✓    |   ✗   |   ✗   |   ✗   |   ✗   |   ✗   |
-| RetinaNet          |   ✓    |   ✗   |   ✗   |   ✗   |   ✗   |   ✗   |
-| Yolov3             |   ✓    |   ✗   |   ✗   |   ✗   |   ✓   |   ✓   |
-| SSD                |   ✗    |   ✗   |   ✗   |   ✗   |   ✓   |   ✗   |
+|                    | ResNet | ResNet-vd <sup>[1](#vd)</sup> | ResNeXt-vd | SENet | MobileNet | DarkNet | VGG |
+|--------------------|:------:|:-----------------------------:|:----------:|:-----:|:---------:|:-------:|:---:|
+| Faster R-CNN       |   ✓    |   ✓   |   ✗   |   ✓   |   ✗   |   ✗   |  ✗  |
+| Faster R-CNN + FPN |   ✓    |   ✓   |   ✓   |   ✓   |   ✗   |   ✗   |  ✗  |
+| Mask R-CNN         |   ✓    |   ✓   |   ✗   |   ✓   |   ✗   |   ✗   |  ✗  |
+| Mask R-CNN + FPN   |   ✓    |   ✓   |   ✓   |   ✓   |   ✗   |   ✗   |  ✗  |
+| Cascade R-CNN      |   ✓    |   ✗   |   ✗   |   ✗   |   ✗   |   ✗   |  ✗  |
+| RetinaNet          |   ✓    |   ✗   |   ✗   |   ✗   |   ✗   |   ✗   |  ✗  |
+| Yolov3             |   ✓    |   ✗   |   ✗   |   ✗   |   ✓   |   ✓   |  ✗  |
+| SSD                |   ✗    |   ✗   |   ✗   |   ✗   |   ✓   |   ✗   |  ✓  |

<a name="vd">[1]</a> [ResNet-vd](https://arxiv.org/pdf/1812.01187) models deliver a considerable accuracy improvement at little performance cost.
......
# ---- New file: SSD config, VGG16 backbone, 300x300 input, COCO ----
architecture: SSD
train_feed: SSDTrainFeed
eval_feed: SSDEvalFeed
test_feed: SSDTestFeed
use_gpu: true
max_iters: 400000
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_caffe_pretrained.tar
save_dir: output
weights: output/ssd_vgg16_300/model_final
num_classes: 81

SSD:
  backbone: VGG
  multi_box_head: MultiBoxHead
  metric:
    ap_version: 11point
    evaluate_difficult: false
    overlap_threshold: 0.5
  output_decoder:
    background_label: 0
    keep_top_k: 200
    nms_eta: 1.0
    nms_threshold: 0.45
    nms_top_k: 400
    score_threshold: 0.01

VGG:
  depth: 16
  with_extra_blocks: true
  normalizations: [20., -1, -1, -1, -1, -1]

MultiBoxHead:
  base_size: 300
  aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]
  min_ratio: 15
  max_ratio: 90
  min_sizes: [30.0, 60.0, 111.0, 162.0, 213.0, 264.0]
  max_sizes: [60.0, 111.0, 162.0, 213.0, 264.0, 315.0]
  steps: [8, 16, 32, 64, 100, 300]
  offset: 0.5
  flip: true
  kernel_size: 3
  pad: 1
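The explicit `min_sizes`/`max_sizes` above normally take precedence over the ratio fields. As a cross-check, a sketch of the classic SSD sizing recipe (assumed from the SSD paper, and using `min_ratio=20` as the VOC variant below does) regenerates exactly these 300-input scales:

```python
import math

def ssd_prior_sizes(base_size=300, min_ratio=20, max_ratio=90, num_layers=6):
    # scales grow linearly, `step` percent of base_size per prior layer
    step = int(math.floor((max_ratio - min_ratio) / (num_layers - 2)))
    min_sizes, max_sizes = [], []
    for ratio in range(min_ratio, max_ratio + 1, step):
        min_sizes.append(base_size * ratio / 100.)
        max_sizes.append(base_size * (ratio + step) / 100.)
    # the first prior layer (conv4_3) uses smaller, hand-picked scales
    return ([base_size * 10 / 100.] + min_sizes,
            [base_size * 20 / 100.] + max_sizes)

mins, maxs = ssd_prior_sizes()
assert mins == [30.0, 60.0, 111.0, 162.0, 213.0, 264.0]
assert maxs == [60.0, 111.0, 162.0, 213.0, 264.0, 315.0]
```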
LearningRate:
  base_lr: 0.001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [280000, 360000]
  - !LinearWarmup
    start_factor: 0.3333333333333333
    steps: 500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2
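Taken together, the two schedulers ramp the learning rate linearly from a third of `base_lr` over the first 500 iterations, then cut it by `gamma` at each milestone. A hedged sketch of that semantics (my reading of the config, not PaddleDetection's scheduler code):

```python
def learning_rate(it, base_lr=0.001, gamma=0.1,
                  milestones=(280000, 360000),
                  warmup_steps=500, start_factor=1. / 3):
    if it < warmup_steps:
        # linear warmup from start_factor * base_lr up to base_lr
        alpha = it / float(warmup_steps)
        return base_lr * (start_factor * (1 - alpha) + alpha)
    # piecewise decay: multiply by gamma once per milestone already passed
    return base_lr * gamma ** sum(it >= m for m in milestones)

assert abs(learning_rate(0) - 0.001 / 3) < 1e-12
assert learning_rate(300000) == 0.001 * 0.1
```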
SSDTrainFeed:
  batch_size: 8
  dataset:
    dataset_dir: dataset/coco
    annotation: annotations/instances_train2017.json
    image_dir: train2017
  image_shape: [3, 300, 300]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeBox {}
  - !RandomDistort
    brightness_lower: 0.875
    brightness_upper: 1.125
    is_order: true
  - !ExpandImage
    max_ratio: 4
    mean: [104, 117, 123]
    prob: 0.5
  - !CropImage
    avoid_no_bbox: true
    batch_sampler:
    - [1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]
    satisfy_all: false
  - !ResizeImage
    interp: 1
    target_size: 300
    use_cv2: false
  - !RandomFlipImage
    is_normalized: true
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]

SSDEvalFeed:
  batch_size: 16
  dataset:
    dataset_dir: dataset/coco
    annotation: annotations/instances_val2017.json
    image_dir: val2017
  drop_last: false
  image_shape: [3, 300, 300]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !ResizeImage
    interp: 1
    target_size: 300
    use_cv2: false
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]

SSDTestFeed:
  batch_size: 1
  dataset:
    annotation: dataset/coco/annotations/instances_val2017.json
  image_shape: [3, 300, 300]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !ResizeImage
    interp: 1
    max_size: 0
    target_size: 300
    use_cv2: false
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]
# ---- New file: SSD config, VGG16 backbone, 300x300 input, Pascal VOC ----
architecture: SSD
train_feed: SSDTrainFeed
eval_feed: SSDEvalFeed
test_feed: SSDTestFeed
use_gpu: true
max_iters: 120001
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: VOC
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_caffe_pretrained.tar
save_dir: output
weights: output/ssd_vgg16_300_voc/model_final/
num_classes: 21

SSD:
  backbone: VGG
  multi_box_head: MultiBoxHead
  metric:
    ap_version: 11point
    evaluate_difficult: false
    overlap_threshold: 0.5
  output_decoder:
    background_label: 0
    keep_top_k: 200
    nms_eta: 1.0
    nms_threshold: 0.45
    nms_top_k: 400
    score_threshold: 0.01

VGG:
  depth: 16
  with_extra_blocks: true
  normalizations: [20., -1, -1, -1, -1, -1]

MultiBoxHead:
  base_size: 300
  aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]
  min_ratio: 20
  max_ratio: 90
  min_sizes: [30.0, 60.0, 111.0, 162.0, 213.0, 264.0]
  max_sizes: [60.0, 111.0, 162.0, 213.0, 264.0, 315.0]
  steps: [8, 16, 32, 64, 100, 300]
  offset: 0.5
  flip: true
  min_max_aspect_ratios_order: true
  kernel_size: 3
  pad: 1

LearningRate:
  base_lr: 0.001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [80000, 100000]

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

SSDTrainFeed:
  batch_size: 8
  dataset:
    dataset_dir: dataset/voc
    annotation: VOCdevkit/VOC_all/ImageSets/Main/train.txt
    image_dir: VOCdevkit/VOC_all/JPEGImages
    use_default_label: true
  image_shape: [3, 300, 300]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeBox {}
  - !RandomDistort
    brightness_lower: 0.875
    brightness_upper: 1.125
    is_order: true
  - !ExpandImage
    max_ratio: 4
    mean: [104, 117, 123]
    prob: 0.5
  - !CropImage
    avoid_no_bbox: true
    batch_sampler:
    - [1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]
    satisfy_all: false
  - !ResizeImage
    interp: 1
    target_size: 300
    use_cv2: false
  - !RandomFlipImage
    is_normalized: true
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]

SSDEvalFeed:
  batch_size: 32
  dataset:
    dataset_dir: dataset/voc
    annotation: VOCdevkit/VOC_all/ImageSets/Main/val.txt
    image_dir: VOCdevkit/VOC_all/JPEGImages
    use_default_label: true
  drop_last: false
  image_shape: [3, 300, 300]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeBox {}
  - !ResizeImage
    interp: 1
    target_size: 300
    use_cv2: false
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]

SSDTestFeed:
  batch_size: 1
  dataset:
    use_default_label: true
  drop_last: false
  image_shape: [3, 300, 300]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !ResizeImage
    interp: 1
    max_size: 0
    target_size: 300
    use_cv2: false
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]
# ---- New file: SSD config, VGG16 backbone, 512x512 input, COCO ----
architecture: SSD
train_feed: SSDTrainFeed
eval_feed: SSDEvalFeed
test_feed: SSDTestFeed
use_gpu: true
max_iters: 400000
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_caffe_pretrained.tar
save_dir: output
weights: output/ssd_vgg16_512/model_final
num_classes: 81

SSD:
  backbone: VGG
  multi_box_head: MultiBoxHead
  metric:
    ap_version: 11point
    evaluate_difficult: false
    overlap_threshold: 0.5
  output_decoder:
    background_label: 0
    keep_top_k: 200
    nms_eta: 1.0
    nms_threshold: 0.45
    nms_top_k: 400
    score_threshold: 0.01

VGG:
  depth: 16
  with_extra_blocks: true
  normalizations: [20., -1, -1, -1, -1, -1, -1]
  extra_block_filters: [[256, 512, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 1, 4]]

MultiBoxHead:
  base_size: 512
  aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]
  min_ratio: 15
  max_ratio: 90
  min_sizes: [20.0, 51.0, 133.0, 215.0, 296.0, 378.0, 460.0]
  max_sizes: [51.0, 133.0, 215.0, 296.0, 378.0, 460.0, 542.0]
  steps: [8, 16, 32, 64, 128, 256, 512]
  offset: 0.5
  flip: true
  kernel_size: 3
  pad: 1

LearningRate:
  base_lr: 0.001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [280000, 360000]
  - !LinearWarmup
    start_factor: 0.3333333333333333
    steps: 500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

SSDTrainFeed:
  batch_size: 8
  dataset:
    dataset_dir: dataset/coco
    annotation: annotations/instances_train2017.json
    image_dir: train2017
  image_shape: [3, 512, 512]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeBox {}
  - !RandomDistort
    brightness_lower: 0.875
    brightness_upper: 1.125
    is_order: true
  - !ExpandImage
    max_ratio: 4
    mean: [104, 117, 123]
    prob: 0.5
  - !CropImage
    avoid_no_bbox: true
    batch_sampler:
    - [1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]
    satisfy_all: false
  - !ResizeImage
    interp: 1
    target_size: 512
    use_cv2: false
  - !RandomFlipImage
    is_normalized: true
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]

SSDEvalFeed:
  batch_size: 8
  dataset:
    dataset_dir: dataset/coco
    annotation: annotations/instances_val2017.json
    image_dir: val2017
  drop_last: false
  image_shape: [3, 512, 512]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !ResizeImage
    interp: 1
    target_size: 512
    use_cv2: false
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]

SSDTestFeed:
  batch_size: 1
  dataset:
    annotation: dataset/coco/annotations/instances_val2017.json
  image_shape: [3, 512, 512]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !ResizeImage
    interp: 1
    max_size: 0
    target_size: 512
    use_cv2: false
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [104, 117, 123]
    std: [1, 1, 1]
# ---- New file: SSD config, VGG16 backbone, 512x512 input, Pascal VOC ----
architecture: SSD
train_feed: SSDTrainFeed
eval_feed: SSDEvalFeed
test_feed: SSDTestFeed
use_gpu: true
max_iters: 120000
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
metric: VOC
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_caffe_pretrained.tar
save_dir: output
weights: output/ssd_vgg16_512_voc/model_final/
num_classes: 21

SSD:
  backbone: VGG
  multi_box_head: MultiBoxHead
  metric:
    ap_version: 11point
    evaluate_difficult: false
    overlap_threshold: 0.5
  output_decoder:
    background_label: 0
    keep_top_k: 200
    nms_eta: 1.0
    nms_threshold: 0.45
    nms_top_k: 400
    score_threshold: 0.01

VGG:
  depth: 16
  with_extra_blocks: true
  normalizations: [20., -1, -1, -1, -1, -1, -1]
  extra_block_filters: [[256, 512, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 1, 1, 4]]

MultiBoxHead:
  base_size: 512
  aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]
  min_ratio: 20
  max_ratio: 90
  min_sizes: [20.0, 51.0, 133.0, 215.0, 296.0, 378.0, 460.0]
  max_sizes: [51.0, 133.0, 215.0, 296.0, 378.0, 460.0, 542.0]
  steps: [8, 16, 32, 64, 128, 256, 512]
  offset: 0.5
  flip: true
  kernel_size: 3
  pad: 1

LearningRate:
  base_lr: 0.001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [80000, 100000]
  - !LinearWarmup
    start_factor: 0.3333333333333333
    steps: 500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

SSDTrainFeed:
  batch_size: 8
  dataset:
    dataset_dir: dataset/voc
    annotation: VOCdevkit/VOC_all/ImageSets/Main/train.txt
    image_dir: VOCdevkit/VOC_all/JPEGImages
    use_default_label: true
  image_shape: [3, 512, 512]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeBox {}
  - !RandomDistort
    brightness_lower: 0.875
    brightness_upper: 1.125
    is_order: true
  - !ExpandImage
    max_ratio: 4
    mean: [123, 117, 104]
    prob: 0.5
  - !CropImage
    avoid_no_bbox: true
    batch_sampler:
    - [1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0]
    - [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]
    satisfy_all: false
  - !ResizeImage
    interp: 1
    target_size: 512
    use_cv2: false
  - !RandomFlipImage
    is_normalized: true
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [123, 117, 104]
    std: [1, 1, 1]

SSDEvalFeed:
  batch_size: 32
  dataset:
    dataset_dir: dataset/voc
    annotation: VOCdevkit/VOC_all/ImageSets/Main/val.txt
    image_dir: VOCdevkit/VOC_all/JPEGImages
    use_default_label: true
  drop_last: false
  image_shape: [3, 512, 512]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeBox {}
  - !ResizeImage
    interp: 1
    target_size: 512
    use_cv2: false
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [123, 117, 104]
    std: [1, 1, 1]

SSDTestFeed:
  batch_size: 1
  dataset:
    use_default_label: true
  drop_last: false
  image_shape: [3, 512, 512]
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !ResizeImage
    interp: 1
    max_size: 0
    target_size: 512
    use_cv2: false
  - !Permute
    to_bgr: false
  - !NormalizeImage
    is_scale: false
    mean: [123, 117, 104]
    std: [1, 1, 1]
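For reference, a hedged sketch of how PaddleDetection's workspace typically materializes one of these files (mirroring what `tools/train.py` does; the config path is an assumption, since file names are not shown in this view):

```python
from ppdet.core.workspace import load_config, create

cfg = load_config('configs/ssd/ssd_vgg16_300.yml')  # assumed path
model = create(cfg.architecture)     # the SSD object, wired to the VGG backbone
train_feed = create(cfg.train_feed)  # SSDTrainFeed with the listed transforms
```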
......@@ -118,11 +118,21 @@ results of image size 608/416/320 above.
**Notes:** In RetinaNet, the base LR is changed to 0.01 for minibatch size 16.

+### SSD
+
+| Backbone | Size | Image/gpu | Lr schd | Box AP | Download |
+| :------- | :--: | :-------: | :-----: | :----: | :------: |
+| VGG16    | 300  | 8         | 400k    | 25.1   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_vgg16_300.tar) |
+| VGG16    | 512  | 8         | 400k    | 29.1   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_vgg16_512.tar) |
+
+**Notes:** VGG-SSD is trained on 4 GPUs with a total batch size of 32 for 400,000 iterations.
+
### SSD on Pascal VOC

| Backbone     | Size | Image/gpu | Lr schd | Box AP | Download |
| :----------- | :--: | :-------: | :-----: | :----: | :------: |
-| MobileNet v1 | 300  | 32        | 120e    | 73.13  | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_mobilenet_v1_voc.tar) |
+| MobileNet v1 | 300  | 32        | 120e    | 73.2   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_mobilenet_v1_voc.tar) |
+| VGG16        | 300  | 8         | 240e    | 77.5   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_vgg16_300_voc.tar) |
+| VGG16        | 512  | 8         | 240e    | 80.2   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_vgg16_512_voc.tar) |

-**Notes:** SSD is trained on 2 GPUs with a total batch size of 64 for 120 epochs. SSD training data augmentations: random color distortion,
+**Notes:** MobileNet-SSD is trained on 2 GPUs with a total batch size of 64 for 120 epochs. VGG-SSD is trained on 4 GPUs with a total batch size of 32 for 240 epochs. SSD training data augmentations: random color distortion,
random cropping, random expansion, random flipping.
......@@ -115,10 +115,20 @@ Paddle provides backbone models pretrained on ImageNet. All pretrained mod
**Notes:** For RetinaNet models, the base learning rate is changed to 0.01 for a total batch size of 16.

+### SSD
+
+| Backbone | Size | Images/GPU | Lr schd | Box AP | Download |
+| :------- | :--: | :--------: | :-----: | :----: | :------: |
+| VGG16    | 300  | 8          | 400k    | 25.1   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_vgg16_300.tar) |
+| VGG16    | 512  | 8          | 400k    | 29.1   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_vgg16_512.tar) |
+
+**Notes:** VGG-SSD is trained with a total batch size of 32 for 400,000 iterations.
+
### SSD on Pascal VOC

| Backbone     | Size | Images/GPU | Lr schd | Box AP | Download |
| :----------- | :--: | :--------: | :-----: | :----: | :------: |
-| MobileNet v1 | 300  | 32         | 120e    | 73.13  | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_mobilenet_v1_voc.tar) |
+| MobileNet v1 | 300  | 32         | 120e    | 73.2   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_mobilenet_v1_voc.tar) |
+| VGG16        | 300  | 8          | 240e    | 77.5   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_vgg16_300_voc.tar) |
+| VGG16        | 512  | 8          | 240e    | 80.2   | [model](https://paddlemodels.bj.bcebos.com/object_detection/ssd_vgg16_512_voc.tar) |

-**Notes:** SSD is trained on 2 GPUs with a total batch size of 64 for 120 epochs. Data augmentation includes random color distortion, random cropping, random expansion, and random flipping.
+**Notes:** MobileNet-SSD is trained on 2 GPUs with a total batch size of 64 for 120 epochs. VGG-SSD is trained with a total batch size of 32 for 240 epochs. Data augmentation includes random color distortion, random cropping, random expansion, and random flipping.
......@@ -28,9 +28,10 @@ from ppdet.data.transform.operators import (
    DecodeImage, MixupImage, NormalizeBox, NormalizeImage, RandomDistort,
    RandomFlipImage, RandomInterpImage, ResizeImage, ExpandImage, CropImage,
    Permute)
from ppdet.data.transform.arrange_sample import (
-    ArrangeRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeTestSSD, ArrangeYOLO,
-    ArrangeEvalYOLO, ArrangeTestYOLO)
+    ArrangeRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD, ArrangeTestSSD,
+    ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO)

__all__ = [
    'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet',
......@@ -690,7 +691,7 @@ class SSDTrainFeed(DataFeed):
    def __init__(self,
                 dataset=VocDataSet().__dict__,
-                fields=['image', 'gt_box', 'gt_label', 'is_difficult'],
+                fields=['image', 'gt_box', 'gt_label'],
                 image_shape=[3, 300, 300],
                 sample_transforms=[
                     DecodeImage(to_rgb=True, with_mixup=False),
......@@ -723,8 +724,6 @@ class SSDTrainFeed(DataFeed):
                 bufsize=10,
                 use_process=True):
        sample_transforms.append(ArrangeSSD())
-        if isinstance(dataset, dict):
-            dataset = VocDataSet(**dataset)
        super(SSDTrainFeed, self).__init__(
            dataset,
            fields,
......@@ -736,6 +735,7 @@ class SSDTrainFeed(DataFeed):
            samples=samples,
            drop_last=drop_last,
            num_workers=num_workers,
+            bufsize=bufsize,
            use_process=use_process)
        self.mode = 'TRAIN'
......@@ -747,7 +747,8 @@ class SSDEvalFeed(DataFeed):
    def __init__(
            self,
            dataset=VocDataSet(VOC_VAL_ANNOTATION).__dict__,
-           fields=['image', 'gt_box', 'gt_label', 'is_difficult'],
+           fields=['image', 'im_shape', 'im_id', 'gt_box',
+                   'gt_label', 'is_difficult'],
            image_shape=[3, 300, 300],
            sample_transforms=[
                DecodeImage(to_rgb=True, with_mixup=False),
......@@ -767,9 +768,7 @@ class SSDEvalFeed(DataFeed):
            num_workers=8,
            bufsize=10,
            use_process=False):
-       sample_transforms.append(ArrangeSSD())
-       if isinstance(dataset, dict):
-           dataset = VocDataSet(**dataset)
+       sample_transforms.append(ArrangeEvalSSD())
        super(SSDEvalFeed, self).__init__(
            dataset,
            fields,
......@@ -781,6 +780,7 @@ class SSDEvalFeed(DataFeed):
            samples=samples,
            drop_last=drop_last,
            num_workers=num_workers,
+           bufsize=bufsize,
            use_process=use_process)
        self.mode = 'VAL'
......@@ -791,7 +791,7 @@ class SSDTestFeed(DataFeed):
    def __init__(self,
                 dataset=SimpleDataSet(VOC_TEST_ANNOTATION).__dict__,
-                fields=['image', 'im_id'],
+                fields=['image', 'im_id', 'im_shape'],
                 image_shape=[3, 300, 300],
                 sample_transforms=[
                     DecodeImage(to_rgb=True),
......@@ -823,7 +823,9 @@ class SSDTestFeed(DataFeed):
            shuffle=shuffle,
            samples=samples,
            drop_last=drop_last,
-           num_workers=num_workers)
+           num_workers=num_workers,
+           bufsize=bufsize,
+           use_process=use_process)
        self.mode = 'TEST'
......
......@@ -131,15 +131,10 @@ class ArrangeTestRCNN(BaseOperator):
class ArrangeSSD(BaseOperator):
    """
    Transform dict to tuple format needed for training.
-
-   Args:
-       is_mask (bool): whether to use include mask data
    """

-   def __init__(self, is_mask=False):
+   def __init__(self):
        super(ArrangeSSD, self).__init__()
-       self.is_mask = is_mask
-       assert isinstance(self.is_mask, bool), "wrong type for is_mask"

    def __call__(self, sample, context=None):
        """
......@@ -154,10 +149,40 @@ class ArrangeSSD(BaseOperator):
        im = sample['image']
        gt_bbox = sample['gt_bbox']
        gt_class = sample['gt_class']
-       difficult = sample['difficult']
-       outs = (im, gt_bbox, gt_class, difficult)
+       outs = (im, gt_bbox, gt_class)
        return outs
+@register_op
+class ArrangeEvalSSD(BaseOperator):
+    """
+    Transform dict to tuple format needed for evaluation.
+    """
+
+    def __init__(self):
+        super(ArrangeEvalSSD, self).__init__()
+
+    def __call__(self, sample, context=None):
+        """
+        Args:
+            sample: a dict which contains image
+                info and annotation info.
+            context: a dict which contains additional info.
+        Returns:
+            sample: a tuple of (im, im_shape, im_id, gt_bbox,
+                gt_class, difficult)
+        """
+        im = sample['image']
+        if len(sample['gt_bbox']) != len(sample['gt_class']):
+            raise ValueError("gt num mismatch: bbox and class.")
+        im_id = sample['im_id']
+        h = sample['h']
+        w = sample['w']
+        im_shape = np.array((h, w))
+        gt_bbox = sample['gt_bbox']
+        gt_class = sample['gt_class']
+        difficult = sample['difficult']
+        outs = (im, im_shape, im_id, gt_bbox, gt_class, difficult)
+        return outs
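A hedged sanity check of the new operator's contract, feeding it a fabricated sample whose keys mirror exactly what `ArrangeEvalSSD.__call__` reads:

```python
import numpy as np
from ppdet.data.transform.arrange_sample import ArrangeEvalSSD

sample = {
    'image': np.zeros((3, 300, 300), dtype='float32'),
    'im_id': np.array([0]),
    'h': 300,
    'w': 300,
    'gt_bbox': np.array([[0.25, 0.25, 0.75, 0.75]], dtype='float32'),
    'gt_class': np.array([[1]], dtype='int32'),
    'difficult': np.array([[0]], dtype='int32'),
}
im, im_shape, im_id, gt_bbox, gt_class, difficult = ArrangeEvalSSD()(sample)
assert tuple(im_shape) == (300, 300)
```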
@register_op
class ArrangeTestSSD(BaseOperator):
......@@ -168,10 +193,8 @@ class ArrangeTestSSD(BaseOperator):
-       is_mask (bool): whether to use include mask data
    """

-   def __init__(self, is_mask=False):
+   def __init__(self):
        super(ArrangeTestSSD, self).__init__()
-       self.is_mask = is_mask
-       assert isinstance(self.is_mask, bool), "wrong type for is_mask"

    def __call__(self, sample, context=None):
        """
......@@ -184,7 +207,10 @@ class ArrangeTestSSD(BaseOperator):
        """
        im = sample['image']
        im_id = sample['im_id']
-       outs = (im, im_id)
+       h = sample['h']
+       w = sample['w']
+       im_shape = np.array((h, w))
+       outs = (im, im_id, im_shape)
        return outs
......
......@@ -63,7 +63,6 @@ class SSD(object):
        if mode == 'train' or mode == 'eval':
            gt_box = feed_vars['gt_box']
            gt_label = feed_vars['gt_label']
-           difficult = feed_vars['is_difficult']

        body_feats = self.backbone(im)
        locs, confs, box, box_var = self.multi_box_head(
......@@ -76,17 +75,7 @@ class SSD(object):
            return {'loss': loss}
        else:
            pred = self.output_decoder(locs, confs, box, box_var)
-           if mode == 'eval':
-               map_eval = self.metric(
-                   pred,
-                   gt_label,
-                   gt_box,
-                   difficult,
-                   class_num=self.num_classes)
-               _, accum_map = map_eval.get_map_var()
-               return {'map': map_eval, 'accum_map': accum_map}
-           else:
-               return {'bbox': pred}
+           return {'bbox': pred}

    def train(self, feed_vars):
        return self.build(feed_vars, 'train')
......@@ -99,5 +88,5 @@ class SSD(object):
    def is_bbox_normalized(self):
        # SSD use output_decoder in output layers, bbox is normalized
-       # to range [0, 1], is_bbox_normalized is used in infer.py
+       # to range [0, 1], is_bbox_normalized is used in eval.py and infer.py
        return True
......@@ -20,6 +20,7 @@ from . import darknet
from . import mobilenet
from . import senet
from . import fpn
+from . import vgg

from .resnet import *
from .resnext import *
......@@ -27,3 +28,4 @@ from .darknet import *
from .mobilenet import *
from .senet import *
from .fpn import *
+from .vgg import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from ppdet.core.workspace import register
__all__ = ['VGG']
@register
class VGG(object):
    """
    VGG, see https://arxiv.org/abs/1409.1556

    Args:
        depth (int): the VGG net depth (16 or 19)
        normalizations (list): params list of init scale in l2 norm, skip init
            scale if param is -1.
        with_extra_blocks (bool): whether or not extra blocks should be added
        extra_block_filters (list): in each extra block, params:
            [in_channel, out_channel, padding_size, stride_size, filter_size]
    """

    def __init__(self,
                 depth=16,
                 with_extra_blocks=False,
                 normalizations=[20., -1, -1, -1, -1, -1],
                 extra_block_filters=[[256, 512, 1, 2, 3], [128, 256, 1, 2, 3],
                                      [128, 256, 0, 1, 3], [128, 256, 0, 1, 3]]):
        assert depth in [16, 19], \
            "depth {} not in [16, 19]".format(depth)

        self.depth = depth
        # number of 3x3 conv layers in each of the five VGG stages
        self.depth_cfg = {
            16: [2, 2, 3, 3, 3],
            19: [2, 2, 4, 4, 4]
        }
        self.with_extra_blocks = with_extra_blocks
        self.normalizations = normalizations
        self.extra_block_filters = extra_block_filters

    def __call__(self, input):
        layers = []
        layers += self._vgg_block(input)

        if not self.with_extra_blocks:
            return layers[-1]

        layers += self._add_extras_block(layers[-1])
        norm_cfg = self.normalizations
        for k, v in enumerate(layers):
            if not norm_cfg[k] == -1:
                layers[k] = self._l2_norm_scale(v, init_scale=norm_cfg[k])

        return layers

    def _vgg_block(self, input):
        nums = self.depth_cfg[self.depth]
        vgg_base = [64, 128, 256, 512, 512]
        conv = input
        layers = []
        for k, v in enumerate(vgg_base):
            conv = self._conv_block(
                conv, v, nums[k], name="conv{}_".format(k + 1))
            layers.append(conv)
            if k == 4:
                # SSD-style pool5: 3x3 window, stride 1, keeps spatial size
                conv = self._pooling_block(conv, 3, 1, pool_padding=1)
            else:
                conv = self._pooling_block(conv, 2, 2)

        fc6 = self._conv_layer(conv, 1024, 3, 1, 6, dilation=6, name="fc6")
        fc7 = self._conv_layer(fc6, 1024, 1, 1, 0, name="fc7")

        # conv4_3 and fc7 are the first two SSD feature maps
        return [layers[3], fc7]

    def _add_extras_block(self, input):
        cfg = self.extra_block_filters
        conv = input
        layers = []
        for k, v in enumerate(cfg):
            assert len(v) == 5, \
                "each entry in extra_block_filters must have 5 params, got {}".format(v)
            conv = self._extra_block(
                conv, v[0], v[1], v[2], v[3], v[4],
                name="conv{}_".format(6 + k))
            layers.append(conv)
        return layers

    def _conv_block(self, input, num_filter, groups, name=None):
        conv = input
        for i in range(groups):
            conv = self._conv_layer(
                input=conv,
                num_filters=num_filter,
                filter_size=3,
                stride=1,
                padding=1,
                act='relu',
                name=name + str(i + 1))
        return conv

    def _extra_block(self,
                     input,
                     num_filters1,
                     num_filters2,
                     padding_size,
                     stride_size,
                     filter_size,
                     name=None):
        # 1x1 conv
        conv_1 = self._conv_layer(
            input=input,
            num_filters=int(num_filters1),
            filter_size=1,
            stride=1,
            act='relu',
            padding=0,
            name=name + "1")

        # 3x3 conv
        conv_2 = self._conv_layer(
            input=conv_1,
            num_filters=int(num_filters2),
            filter_size=filter_size,
            stride=stride_size,
            act='relu',
            padding=padding_size,
            name=name + "2")
        return conv_2

    def _conv_layer(self,
                    input,
                    num_filters,
                    filter_size,
                    stride,
                    padding,
                    dilation=1,
                    act='relu',
                    use_cudnn=True,
                    name=None):
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            act=act,
            use_cudnn=use_cudnn,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=ParamAttr(name=name + "_biases"),
            name=name + '.conv2d.output.1')
        return conv

    def _pooling_block(self,
                       conv,
                       pool_size,
                       pool_stride,
                       pool_padding=0,
                       ceil_mode=True):
        pool = fluid.layers.pool2d(
            input=conv,
            pool_size=pool_size,
            pool_type='max',
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            ceil_mode=ceil_mode)
        return pool

    def _l2_norm_scale(self, input, init_scale=1.0, channel_shared=False):
        from paddle.fluid.layer_helper import LayerHelper
        from paddle.fluid.initializer import Constant
        helper = LayerHelper("Scale")
        l2_norm = fluid.layers.l2_normalize(
            input, axis=1)  # l2 norm along channel
        shape = [1] if channel_shared else [input.shape[1]]
        scale = helper.create_parameter(
            attr=helper.param_attr,
            shape=shape,
            dtype=input.dtype,
            default_initializer=Constant(init_scale))
        out = fluid.layers.elementwise_mul(
            x=l2_norm,
            y=scale,
            axis=-1 if channel_shared else 1,
            name="conv4_3_norm_scale")
        return out
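A hedged usage sketch of the new backbone under the SSD-300 settings above (assumes a Paddle 1.x program-building context; `image` is a placeholder input):

```python
import paddle.fluid as fluid
from ppdet.modeling.backbones.vgg import VGG

image = fluid.layers.data(name='image', shape=[3, 300, 300], dtype='float32')
backbone = VGG(depth=16, with_extra_blocks=True,
               normalizations=[20., -1, -1, -1, -1, -1])
body_feats = backbone(image)
# body_feats[0] is conv4_3 (L2-normalized with init scale 20),
# body_feats[1] is fc7, and the rest come from the extra blocks.
```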
......@@ -255,22 +255,30 @@ class MultiBoxHead(object):
    def __init__(self,
                 min_ratio=20,
                 max_ratio=90,
-                base_size=300,
                 min_sizes=[60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
                 max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0],
                 aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.],
                                [2., 3.]],
+                base_size=300,
                 steps=None,
                 offset=0.5,
-                flip=True):
+                flip=True,
+                min_max_aspect_ratios_order=False,
+                kernel_size=1,
+                pad=0):
        super(MultiBoxHead, self).__init__()
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
-       self.base_size = base_size
        self.min_sizes = min_sizes
        self.max_sizes = max_sizes
        self.aspect_ratios = aspect_ratios
+       self.base_size = base_size
        self.steps = steps
        self.offset = offset
        self.flip = flip
+       self.min_max_aspect_ratios_order = min_max_aspect_ratios_order
+       self.kernel_size = kernel_size
+       self.pad = pad
@register
......
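The added `kernel_size`, `pad`, and `min_max_aspect_ratios_order` fields line up with keyword arguments of Paddle's prior-box head. A hedged sketch of the likely mapping (not the literal body of `MultiBoxHead.__call__`; `body_feats`, `image`, and `num_classes` are placeholders):

```python
import paddle.fluid as fluid

def build_multi_box_head(head, body_feats, image, num_classes):
    # `head` is a configured MultiBoxHead instance as defined above
    return fluid.layers.multi_box_head(
        inputs=body_feats,
        image=image,
        num_classes=num_classes,
        base_size=head.base_size,
        aspect_ratios=head.aspect_ratios,
        min_ratio=head.min_ratio,
        max_ratio=head.max_ratio,
        min_sizes=head.min_sizes,
        max_sizes=head.max_sizes,
        steps=head.steps,
        offset=head.offset,
        flip=head.flip,
        kernel_size=head.kernel_size,
        pad=head.pad,
        min_max_aspect_ratios_order=head.min_max_aspect_ratios_order)
```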
......@@ -66,7 +66,12 @@ def proposal_eval(results, anno_file, outfile, max_dets=(100, 300, 1000)):
    # flush coco evaluation result
    sys.stdout.flush()


-def bbox_eval(results, anno_file, outfile, with_background=True):
+def bbox_eval(results,
+              anno_file,
+              outfile,
+              with_background=True,
+              is_bbox_normalized=False):
    assert 'bbox' in results[0]
    assert outfile.endswith('.json')
......@@ -79,7 +84,9 @@ def bbox_eval(results, anno_file, outfile, with_background=True):
        {i + int(with_background): catid
         for i, catid in enumerate(cat_ids)})

-   xywh_results = bbox2out(results, clsid2catid)
+   xywh_results = bbox2out(
+       results, clsid2catid, is_bbox_normalized=is_bbox_normalized)
+
    if len(xywh_results) == 0:
        logger.warning("The number of valid bbox detected is zero.\n \
            Please use reasonable model and check input data.\n \
......@@ -111,6 +118,7 @@ def mask_eval(results, anno_file, outfile, resolution, thresh_binarize=0.5):
    cocoapi_eval(outfile, 'segm', coco_gt=coco_gt)


def cocoapi_eval(jsonfile,
                 style,
                 coco_gt=None,
......@@ -141,6 +149,7 @@ def cocoapi_eval(jsonfile,
    coco_eval.summarize()
    return coco_eval.stats
def proposal2out(results, is_bbox_normalized=False):
    xywh_res = []
    for t in results:
......@@ -180,6 +189,13 @@ def proposal2out(results, is_bbox_normalized=False):
def bbox2out(results, clsid2catid, is_bbox_normalized=False):
+    """
+    Args:
+        results: a list of result dicts; each must contain `bbox` and
+            `im_id`, plus `im_shape` when is_bbox_normalized is True.
+        clsid2catid: class id to category id map of COCO2017 dataset.
+        is_bbox_normalized: whether or not the bbox coordinates are normalized.
+    """
    xywh_res = []
    for t in results:
        bboxes = t['bbox'][0]
......@@ -202,6 +218,11 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
                    clip_bbox([xmin, ymin, xmax, ymax])
                w = xmax - xmin
                h = ymax - ymin
+               im_height, im_width = t['im_shape'][0][i].tolist()
+               xmin *= im_width
+               ymin *= im_height
+               w *= im_width
+               h *= im_height
            else:
                w = xmax - xmin + 1
                h = ymax - ymin + 1
......
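A fabricated single-image example of the contract documented above (hypothetical values; the nesting follows the loop in `bbox2out`):

```python
import numpy as np

# One detection: bbox rows are [clsid, score, xmin, ymin, xmax, ymax],
# normalized to [0, 1] for SSD, with (h, w) carried in im_shape.
t = {
    'bbox': [np.array([[1.0, 0.9, 0.25, 0.25, 0.75, 0.75]]), [[1]]],
    'im_id': [np.array([[42]])],
    'im_shape': [np.array([[300.0, 300.0]])],
}
# After de-normalization the COCO-style xywh box becomes
# [75.0, 75.0, 150.0, 150.0] on the 300x300 image.
```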
......@@ -113,8 +113,11 @@ def eval_results(results,
        output = 'bbox.json'
        if output_directory:
            output = os.path.join(output_directory, 'bbox.json')
        box_ap_stats = bbox_eval(results, anno_file, output,
-                                with_background)
+                                with_background,
+                                is_bbox_normalized=is_bbox_normalized)

    if 'mask' in results[0]:
        output = 'mask.json'
        if output_directory:
......
......@@ -31,8 +31,7 @@ def visualize_results(image,
                      catid2name,
                      threshold=0.5,
                      bbox_results=None,
-                     mask_results=None,
-                     is_bbox_normalized=False):
+                     mask_results=None):
    """
    Visualize bbox and mask results
    """
......@@ -40,7 +39,7 @@ def visualize_results(image,
        image = draw_mask(image, im_id, mask_results, threshold)
    if bbox_results:
        image = draw_bbox(image, im_id, catid2name, bbox_results,
-                         threshold, is_bbox_normalized)
+                         threshold)
    return image
......@@ -69,8 +68,7 @@ def draw_mask(image, im_id, segms, threshold, alpha=0.7):
    return Image.fromarray(img_array.astype('uint8'))


-def draw_bbox(image, im_id, catid2name, bboxes, threshold,
-              is_bbox_normalized=False):
+def draw_bbox(image, im_id, catid2name, bboxes, threshold):
    """
    Draw bbox on image
    """
......@@ -86,12 +84,6 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold,
            continue

        xmin, ymin, w, h = bbox
-       if is_bbox_normalized:
-           im_width, im_height = image.size
-           xmin *= im_width
-           ymin *= im_height
-           w *= im_width
-           h *= im_height
        xmax = xmin + w
        ymax = ymin + h
......@@ -186,7 +186,7 @@ def main():
    if cfg['metric'] == 'COCO':
        extra_keys = ['im_info', 'im_id', 'im_shape']
    if cfg['metric'] == 'VOC':
-       extra_keys = ['im_id']
+       extra_keys = ['im_id', 'im_shape']
    keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)

    # parse dataset category
......@@ -235,7 +235,7 @@ def main():
            image = visualize_results(image,
                                      int(im_id), catid2name,
                                      FLAGS.draw_threshold, bbox_results,
-                                     mask_results, is_bbox_normalized)
+                                     mask_results)
            save_name = get_save_image_name(FLAGS.output_dir, image_path)
            logger.info("Detection bbox results save in {}".format(save_name))
            image.save(save_name, quality=95)
......