未验证 提交 a917efca 编写于 作者: F Feng Ni 提交者: GitHub

[Dygraph] add YOLOv3-MobileNet v1 v3 (#2022)

* add mbv1v3-yolov3

* add ssd mbv1v3 voc

* yolov3 mbv1v3 voc

* training log fix

* yolov3 voc fix

* yolov3 mbv1v3 training and eval, ssd fix

* add syncbn

* mbv1v3 clean code

* mbv1v3 clean codes

* remove ssd mbv1v3

* update modelzoo doc

* remove others

* fix ssd reader and prior_box

* ssd mbv1 config

* fix anchor steps and reader op

* ssdlite mbv1v3 infer and train

* update yolov3 config and modelzoo

* update yolov3_mbv3 modelzoo

* pre commit update

* update modelzoo

* fix mbv3 in yolov3

* fix precommit
Co-authored-by: Nnemonameless <nemonameless@qq.com>
上级 11a059c9
architecture: SSD
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/ssd_mobilenet_v1_coco_pretrained.tar
weights: output/ssd_mobilenetv1/model_final
SSD:
backbone: MobileNet
ssd_head: SSDHead
post_process: BBoxPostProcess
MobileNet:
norm_decay: 0.
scale: 1
conv_learning_rate: 0.1
extra_block_filters: [[256, 512], [128, 256], [128, 256], [64, 128]]
with_extra_blocks: true
feature_maps: [11, 13, 14, 15, 16, 17]
SSDHead:
in_channels: [512, 1024, 512, 256, 256, 128]
anchor_generator: AnchorGeneratorSSD
kernel_size: 1
padding: 0
AnchorGeneratorSSD:
steps: [0, 0, 0, 0, 0, 0]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]]
min_ratio: 20
max_ratio: 90
base_size: 300
min_sizes: [60.0, 105.0, 150.0, 195.0, 240.0, 285.0]
max_sizes: [[], 150.0, 195.0, 240.0, 285.0, 300.0]
offset: 0.5
flip: true
min_max_aspect_ratios_order: false
BBoxPostProcess:
decode:
name: SSDBox
nms:
name: MultiClassNMS
keep_top_k: 200
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 400
nms_eta: 1.0
architecture: SSD
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_ssld_pretrained.tar
weights: output/ssdlite_mobilenet_v1/model_final
SSD:
backbone: MobileNet
ssd_head: SSDHead
post_process: BBoxPostProcess
MobileNet:
conv_decay: 0.00004
scale: 1
extra_block_filters: [[256, 512], [128, 256], [128, 256], [64, 128]]
with_extra_blocks: true
feature_maps: [11, 13, 14, 15, 16, 17]
SSDHead:
in_channels: [512, 1024, 512, 256, 256, 128]
anchor_generator: AnchorGeneratorSSD
use_sepconv: True
conv_decay: 0.00004
AnchorGeneratorSSD:
steps: [16, 32, 64, 100, 150, 300]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]]
min_ratio: 20
max_ratio: 95
base_size: 300
min_sizes: []
max_sizes: []
offset: 0.5
flip: true
clip: true
min_max_aspect_ratios_order: False
BBoxPostProcess:
decode:
name: SSDBox
nms:
name: MultiClassNMS
keep_top_k: 200
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 400
nms_eta: 1.0
architecture: SSD
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar
weights: output/ssdlite_mobilenet_v3_large/model_final
SSD:
backbone: MobileNetV3
ssd_head: SSDHead
post_process: BBoxPostProcess
MobileNetV3:
scale: 1.0
model_name: large
conv_decay: 0.00004
with_extra_blocks: true
extra_block_filters: [[256, 512], [128, 256], [128, 256], [64, 128]]
feature_maps: [14, 17, 18, 19, 20, 21]
lr_mult_list: [0.25, 0.25, 0.5, 0.5, 0.75]
multiplier: 0.5
SSDHead:
in_channels: [672, 480, 512, 256, 256, 128]
anchor_generator: AnchorGeneratorSSD
use_sepconv: True
conv_decay: 0.00004
AnchorGeneratorSSD:
steps: [16, 32, 64, 107, 160, 320]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]]
min_ratio: 20
max_ratio: 95
base_size: 320
min_sizes: []
max_sizes: []
offset: 0.5
flip: true
clip: true
min_max_aspect_ratios_order: false
BBoxPostProcess:
decode:
name: SSDBox
nms:
name: MultiClassNMS
keep_top_k: 200
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 400
nms_eta: 1.0
architecture: SSD
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar
weights: output/ssd_mobilenet_v3_small/model_final
SSD:
backbone: MobileNetV3
ssd_head: SSDHead
post_process: BBoxPostProcess
MobileNetV3:
scale: 1.0
model_name: small
conv_decay: 0.00004
with_extra_blocks: true
extra_block_filters: [[256, 512], [128, 256], [128, 256], [64, 128]]
feature_maps: [10, 13, 14, 15, 16, 17]
lr_mult_list: [0.25, 0.25, 0.5, 0.5, 0.75]
multiplier: 0.5
SSDHead:
in_channels: [288, 288, 512, 256, 256, 128]
anchor_generator: AnchorGeneratorSSD
use_sepconv: True
conv_decay: 0.00004
AnchorGeneratorSSD:
steps: [16, 32, 64, 107, 160, 320]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]]
min_ratio: 20
max_ratio: 95
base_size: 320
min_sizes: []
max_sizes: []
offset: 0.5
flip: true
clip: true
min_max_aspect_ratios_order: false
BBoxPostProcess:
decode:
name: SSDBox
nms:
name: MultiClassNMS
keep_top_k: 200
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 400
nms_eta: 1.0
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
weights: output/yolov3_mobilenet_v1/model_final
load_static_weights: True
norm_type: sync_bn
YOLOv3:
backbone: MobileNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
MobileNet:
scale: 1
feature_maps: [4, 6, 13]
with_extra_blocks: false
extra_block_filters: []
YOLOv3FPN:
feat_channels: [1024, 768, 384]
YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.7
downsample: [32, 16, 8]
label_smooth: false
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
background_label: -1
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar
weights: output/yolov3_mobilenet_v3_large/model_final
load_static_weights: True
norm_type: sync_bn
YOLOv3:
backbone: MobileNetV3
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
MobileNetV3:
model_name: large
scale: 1.
with_extra_blocks: false
extra_block_filters: []
feature_maps: [7, 13, 16]
YOLOv3FPN:
feat_channels: [160, 368, 168]
YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.7
downsample: [32, 16, 8]
label_smooth: false
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
background_label: -1
architecture: YOLOv3
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar
weights: output/yolov3_mobilenet_v3_small/model_final
load_static_weights: True
norm_type: sync_bn
YOLOv3:
backbone: MobileNetV3
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
MobileNetV3:
model_name: small
scale: 1.
with_extra_blocks: false
extra_block_filters: []
feature_maps: [4, 9, 12]
YOLOv3FPN:
feat_channels: [96, 304, 152]
YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.7
downsample: [32, 16, 8]
label_smooth: false
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
background_label: -1
epoch: 120
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 80
- 100
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
epoch: 1746
LearningRate:
base_lr: 0.4
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 160
- 200
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 2000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
worker_num: 8
TrainReader:
inputs_def:
num_max_boxes: 90
sample_transforms:
- DecodeOp: {}
- RandomDistortOp: {brightness: [0.5, 1.125, 0.875], random_apply: False}
- RandomExpandOp: {fill_value: [127.5, 127.5, 127.5]}
- RandomCropOp: {allow_no_crop: Fasle}
- RandomFlipOp: {}
- NormalizeBoxOp: {}
- ResizeImage: {target_size: 300, interp: 1, use_cv2: false}
- PadBoxOp: {num_max_boxes: 90}
batch_transforms:
- NormalizeImageOp: {mean: [127.5, 127.5, 127.5], std: [127.502231, 127.502231, 127.502231], is_scale: false}
- Permute: {to_bgr: true}
batch_size: 32
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeImageOp: {mean: [127.5, 127.5, 127.5], std: [127.502231, 127.502231, 127.502231], is_scale: false}
- Permute: {to_bgr: true}
batch_size: 1
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 300, 300]
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeImageOp: {mean: [127.5, 127.5, 127.5], std: [127.502231, 127.502231, 127.502231], is_scale: false}
- Permute: {to_bgr: true}
batch_size: 1
......@@ -9,8 +9,8 @@ TrainReader:
- RandomExpandOp: {fill_value: [104., 117., 123.]}
- RandomCropOp: {allow_no_crop: true}
- RandomFlipOp: {}
- NormalizeBoxOp: {}
- ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeBoxOp: {}
- PadBoxOp: {num_max_boxes: 90}
batch_transforms:
......
worker_num: 8
TrainReader:
inputs_def:
num_max_boxes: 90
sample_transforms:
- DecodeOp: {}
- RandomDistortOp: {brightness: [0.5, 1.125, 0.875], random_apply: False}
- RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
- RandomCropOp: {allow_no_crop: Fasle}
- RandomFlipOp: {}
- NormalizeBoxOp: {}
- ResizeImage: {target_size: 300, interp: 1, use_cv2: false}
- PadBoxOp: {num_max_boxes: 90}
batch_transforms:
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- PermuteOp: {}
batch_size: 64
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- PermuteOp: {}
batch_size: 1
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 300, 300]
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- PermuteOp: {}
batch_size: 1
worker_num: 8
TrainReader:
inputs_def:
num_max_boxes: 90
sample_transforms:
- DecodeOp: {}
- RandomDistortOp: {brightness: [0.5, 1.125, 0.875], random_apply: False}
- RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
- RandomCropOp: {allow_no_crop: Fasle}
- RandomFlipOp: {}
- NormalizeBoxOp: {}
- ResizeImage: {target_size: 320, interp: 1, use_cv2: false}
- PadBoxOp: {num_max_boxes: 90}
batch_transforms:
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- PermuteOp: {}
batch_size: 64
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [320, 320], keep_ratio: False, interp: 1}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- PermuteOp: {}
batch_size: 1
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 320, 320]
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [320, 320], keep_ratio: False, interp: 1}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- PermuteOp: {}
batch_size: 1
_BASE_: [
'./_base_/models/ssd_mobilenet_v1_300.yml',
'./_base_/optimizers/ssd_mobilenet_120e.yml',
'./_base_/datasets/voc.yml',
'./_base_/readers/ssd_mobilenet_reader.yml',
'./_base_/runtime.yml',
]
_BASE_: [
'./_base_/models/ssdlite_mobilenet_v1_300.yml',
'./_base_/optimizers/ssdlite_1000e.yml',
'./_base_/datasets/coco_detection.yml',
'./_base_/readers/ssdlite300_reader.yml',
'./_base_/runtime.yml',
]
_BASE_: [
'./_base_/models/ssdlite_mobilenet_v3_large_320.yml',
'./_base_/optimizers/ssdlite_1000e.yml',
'./_base_/datasets/coco_detection.yml',
'./_base_/readers/ssdlite320_reader.yml',
'./_base_/runtime.yml',
]
_BASE_: [
'./_base_/models/ssdlite_mobilenet_v3_small_320.yml',
'./_base_/optimizers/ssdlite_1000e.yml',
'./_base_/datasets/coco_detection.yml',
'./_base_/readers/ssdlite320_reader.yml',
'./_base_/runtime.yml',
]
_BASE_: [
'./_base_/models/yolov3_mobilenet_v1.yml',
'./_base_/optimizers/yolov3_270e.yml',
'./_base_/datasets/coco_detection.yml',
'./_base_/readers/yolov3_reader.yml',
'./_base_/runtime.yml',
]
_BASE_: [
'./_base_/models/yolov3_mobilenet_v1.yml',
'./_base_/optimizers/yolov3_270e.yml',
'./_base_/datasets/voc.yml',
'./_base_/readers/yolov3_reader.yml',
'./_base_/runtime.yml',
]
TrainReader:
inputs_def:
num_max_boxes: 50
sample_transforms:
- DecodeOp: {}
- MixupOp: {alpha: 1.5, beta: 1.5}
- RandomDistortOp: {}
- RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
- RandomCropOp: {}
- RandomFlipOp: {}
batch_transforms:
- BatchRandomResizeOp:
target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_size: True
random_interp: True
keep_ratio: False
- NormalizeBoxOp: {}
- PadBoxOp: {num_max_boxes: 50}
- BboxXYXY2XYWHOp: {}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
- Gt2YoloTargetOp:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]]
downsample_ratios: [32, 16, 8]
num_classes: 20
batch_size: 8
shuffle: true
drop_last: true
mixup_epoch: 250
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 216
- 243
- !LinearWarmup
start_factor: 0.
steps: 1000
_BASE_: [
'./_base_/models/yolov3_mobilenet_v3_large.yml',
'./_base_/optimizers/yolov3_270e.yml',
'./_base_/datasets/coco_detection.yml',
'./_base_/readers/yolov3_reader.yml',
'./_base_/runtime.yml',
]
_BASE_: [
'./_base_/models/yolov3_mobilenet_v3_large.yml',
'./_base_/optimizers/yolov3_270e.yml',
'./_base_/datasets/voc.yml',
'./_base_/readers/yolov3_reader.yml',
'./_base_/runtime.yml',
]
TrainReader:
inputs_def:
num_max_boxes: 50
sample_transforms:
- DecodeOp: {}
- MixupOp: {alpha: 1.5, beta: 1.5}
- RandomDistortOp: {}
- RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
- RandomCropOp: {}
- RandomFlipOp: {}
batch_transforms:
- BatchRandomResizeOp:
target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_size: True
random_interp: True
keep_ratio: False
- NormalizeBoxOp: {}
- PadBoxOp: {num_max_boxes: 50}
- BboxXYXY2XYWHOp: {}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
- Gt2YoloTargetOp:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]]
downsample_ratios: [32, 16, 8]
num_classes: 20
batch_size: 8
shuffle: true
drop_last: true
mixup_epoch: 250
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 216
- 243
- !LinearWarmup
start_factor: 0.
steps: 1000
......@@ -40,7 +40,37 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
| ResNet50-FPN | Mask | 1 | 1x | ---- | 38.3 | 34.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/mask_rcnn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/mask_rcnn_r50_fpn_1x_coco.yml) |
| ResNet50-FPN | Cascade Faster | 1 | 1x | ---- | 41.1 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/cascade_rcnn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/cascade_faster_rcnn_r50_fpn_1x_coco.yml) |
| ResNet50-FPN | Cascade Mask | 1 | 1x | ---- | 41.6 | 35.3 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/cascade_mask_rcnn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/cascade_mask_rcnn_r50_fpn_1x_coco.yml) |
| DarkNet53 | YOLOv3 | 1 | 270e | ---- | 39.0 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_darknet53_270e_coco.yml) |
### YOLOv3 on COCO
| 骨架网络 | 输入尺寸 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 |
| :------------------- | :------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| DarkNet53(paper) | 608 | 8 | 270e | ---- | 33.0 | - | - |
| DarkNet53(paper) | 416 | 8 | 270e | ---- | 31.0 | - | - |
| DarkNet53(paper) | 320 | 8 | 270e | ---- | 28.2 | - | - |
| DarkNet53 | 608 | 8 | 270e | ---- | 39.0 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_darknet53_270e_coco.yml) |
| DarkNet53 | 416 | 8 | 270e | ---- | 37.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_darknet53_270e_coco.yml) |
| DarkNet53 | 320 | 8 | 270e | ---- | 34.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_darknet53_270e_coco.yml) |
| MobileNet-V1 | 608 | 8 | 270e | ---- | 28.8 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v1_270e_coco.yml) |
| MobileNet-V1 | 416 | 8 | 270e | ---- | 28.7 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v1_270e_coco.yml) |
| MobileNet-V1 | 320 | 8 | 270e | ---- | 26.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v1_270e_coco.yml) |
| MobileNet-V3 | 608 | 8 | 270e | ---- | 31.4 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v3_large_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v3_large_270e_coco.yml) |
| MobileNet-V3 | 416 | 8 | 270e | ---- | 29.7 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v3_large_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v3_large_270e_coco.yml) |
| MobileNet-V3 | 320 | 8 | 270e | ---- | 26.9 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v3_large_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v3_large_270e_coco.yml) |
### YOLOv3 on Pasacl VOC
| 骨架网络 | 输入尺寸 | 每张GPU图片个数 | 学习率策略 |推理时间(fps)| Box AP | 下载 | 配置文件 |
| :----------- | :--: | :-----: | :-----: |:------------: |:----: | :-------: | :----: |
| MobileNet-V1 | 608 | 8 | 270e | - | 75.1 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v1_270e_voc.yml) |
| MobileNet-V1 | 416 | 8 | 270e | - | 76.1 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v1_270e_voc.yml) |
| MobileNet-V1 | 320 | 8 | 270e | - | 73.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v1_270e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v1_270e_voc.yml) |
| MobileNet-V3 | 608 | 8 | 270e | - | 79.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v3_large_270e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v3_large_270e_voc.yml) |
| MobileNet-V3 | 416 | 8 | 270e | - | 78.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v3_large_270e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v3_large_270e_voc.yml) |
| MobileNet-V3 | 320 | 8 | 270e | - | 76.4 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_mobilenet_v3_large_270e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_mobilenet_v3_large_270e_voc.yml) |
**注意:** YOLOv3均使用8GPU训练,训练270个epoch
### SSD on Pascal VOC
......
......@@ -229,6 +229,9 @@ class Gt2YoloTargetOp(BaseOperator):
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
if 'gt_score' not in sample:
sample['gt_score'] = np.ones(
(gt_bbox.shape[0], 1), dtype=np.float32)
gt_score = sample['gt_score']
for i, (
mask, downsample_ratio
......
......@@ -1582,6 +1582,12 @@ class MixupOp(BaseOperator):
is_crowd2 = sample[1]['is_crowd']
is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
result['is_crowd'] = is_crowd
if 'difficult' in sample[0]:
is_difficult1 = sample[0]['difficult']
is_difficult2 = sample[1]['difficult']
is_difficult = np.concatenate(
(is_difficult1, is_difficult2), axis=0)
result['difficult'] = is_difficult
return result
......
from . import vgg
from . import resnet
from . import darknet
from . import mobilenet_v1
from . import mobilenet_v3
from .vgg import *
from .resnet import *
from .darknet import *
from .mobilenet_v1 import *
from .mobilenet_v3 import *
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal
from ppdet.core.workspace import register, serializable
from numbers import Integral
__all__ = ['MobileNet']
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
num_groups=1,
act='relu',
conv_lr=1.,
conv_decay=0.,
norm_decay=0.,
norm_type='bn',
name=None):
super(ConvBNLayer, self).__init__()
self.act = act
self._conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(
learning_rate=conv_lr,
initializer=KaimingNormal(),
regularizer=L2Decay(conv_decay),
name=name + "_weights"),
bias_attr=False)
if norm_type == 'sync_bn':
batch_norm = nn.SyncBatchNorm
else:
batch_norm = nn.BatchNorm2D
self._batch_norm = batch_norm(
out_channels,
weight_attr=ParamAttr(
name=name + "_bn_scale", regularizer=L2Decay(norm_decay)),
bias_attr=ParamAttr(
name=name + "_bn_offset", regularizer=L2Decay(norm_decay)))
def forward(self, x):
x = self._conv(x)
x = self._batch_norm(x)
if self.act == "relu":
x = F.relu(x)
elif self.act == "relu6":
x = F.relu6(x)
return x
class DepthwiseSeparable(nn.Layer):
def __init__(self,
in_channels,
out_channels1,
out_channels2,
num_groups,
stride,
scale,
conv_lr=1.,
conv_decay=0.,
norm_decay=0.,
norm_type='bn',
name=None):
super(DepthwiseSeparable, self).__init__()
self._depthwise_conv = ConvBNLayer(
in_channels,
int(out_channels1 * scale),
kernel_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
conv_lr=conv_lr,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name=name + "_dw")
self._pointwise_conv = ConvBNLayer(
int(out_channels1 * scale),
int(out_channels2 * scale),
kernel_size=1,
stride=1,
padding=0,
conv_lr=conv_lr,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name=name + "_sep")
def forward(self, x):
x = self._depthwise_conv(x)
x = self._pointwise_conv(x)
return x
class ExtraBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels1,
out_channels2,
num_groups=1,
stride=2,
conv_lr=1.,
conv_decay=0.,
norm_decay=0.,
norm_type='bn',
name=None):
super(ExtraBlock, self).__init__()
self.pointwise_conv = ConvBNLayer(
in_channels,
int(out_channels1),
kernel_size=1,
stride=1,
padding=0,
num_groups=int(num_groups),
act='relu6',
conv_lr=conv_lr,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name=name + "_extra1")
self.normal_conv = ConvBNLayer(
int(out_channels1),
int(out_channels2),
kernel_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups),
act='relu6',
conv_lr=conv_lr,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name=name + "_extra2")
def forward(self, x):
x = self.pointwise_conv(x)
x = self.normal_conv(x)
return x
@register
@serializable
class MobileNet(nn.Layer):
__shared__ = ['norm_type']
def __init__(self,
norm_type='bn',
norm_decay=0.,
conv_decay=0.,
scale=1,
conv_learning_rate=1.0,
feature_maps=[4, 6, 13],
with_extra_blocks=False,
extra_block_filters=[[256, 512], [128, 256], [128, 256],
[64, 128]]):
super(MobileNet, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
self.feature_maps = feature_maps
self.with_extra_blocks = with_extra_blocks
self.extra_block_filters = extra_block_filters
self.conv1 = ConvBNLayer(
in_channels=3,
out_channels=int(32 * scale),
kernel_size=3,
stride=2,
padding=1,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv1")
self.dwsl = []
dws21 = self.add_sublayer(
"conv2_1",
sublayer=DepthwiseSeparable(
in_channels=int(32 * scale),
out_channels1=32,
out_channels2=64,
num_groups=32,
stride=1,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv2_1"))
self.dwsl.append(dws21)
dws22 = self.add_sublayer(
"conv2_2",
sublayer=DepthwiseSeparable(
in_channels=int(64 * scale),
out_channels1=64,
out_channels2=128,
num_groups=64,
stride=2,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv2_2"))
self.dwsl.append(dws22)
# 1/4
dws31 = self.add_sublayer(
"conv3_1",
sublayer=DepthwiseSeparable(
in_channels=int(128 * scale),
out_channels1=128,
out_channels2=128,
num_groups=128,
stride=1,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv3_1"))
self.dwsl.append(dws31)
dws32 = self.add_sublayer(
"conv3_2",
sublayer=DepthwiseSeparable(
in_channels=int(128 * scale),
out_channels1=128,
out_channels2=256,
num_groups=128,
stride=2,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv3_2"))
self.dwsl.append(dws32)
# 1/8
dws41 = self.add_sublayer(
"conv4_1",
sublayer=DepthwiseSeparable(
in_channels=int(256 * scale),
out_channels1=256,
out_channels2=256,
num_groups=256,
stride=1,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv4_1"))
self.dwsl.append(dws41)
dws42 = self.add_sublayer(
"conv4_2",
sublayer=DepthwiseSeparable(
in_channels=int(256 * scale),
out_channels1=256,
out_channels2=512,
num_groups=256,
stride=2,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv4_2"))
self.dwsl.append(dws42)
# 1/16
for i in range(5):
tmp = self.add_sublayer(
"conv5_" + str(i + 1),
sublayer=DepthwiseSeparable(
in_channels=512,
out_channels1=512,
out_channels2=512,
num_groups=512,
stride=1,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv5_" + str(i + 1)))
self.dwsl.append(tmp)
dws56 = self.add_sublayer(
"conv5_6",
sublayer=DepthwiseSeparable(
in_channels=int(512 * scale),
out_channels1=512,
out_channels2=1024,
num_groups=512,
stride=2,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv5_6"))
self.dwsl.append(dws56)
# 1/32
dws6 = self.add_sublayer(
"conv6",
sublayer=DepthwiseSeparable(
in_channels=int(1024 * scale),
out_channels1=1024,
out_channels2=1024,
num_groups=1024,
stride=1,
scale=scale,
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv6"))
self.dwsl.append(dws6)
if self.with_extra_blocks:
self.extra_blocks = []
for i, block_filter in enumerate(self.extra_block_filters):
in_c = 1024 if i == 0 else self.extra_block_filters[i - 1][1]
conv_extra = self.add_sublayer(
"conv7_" + str(i + 1),
sublayer=ExtraBlock(
in_c,
block_filter[0],
block_filter[1],
conv_lr=conv_learning_rate,
conv_decay=conv_decay,
norm_decay=norm_decay,
norm_type=norm_type,
name="conv7_" + str(i + 1)))
self.extra_blocks.append(conv_extra)
def forward(self, inputs):
outs = []
y = self.conv1(inputs['image'])
for i, block in enumerate(self.dwsl):
y = block(y)
if i + 1 in self.feature_maps:
outs.append(y)
if not self.with_extra_blocks:
return outs
y = outs[-1]
for i, block in enumerate(self.extra_blocks):
idx = i + len(self.dwsl)
y = block(y)
if idx + 1 in self.feature_maps:
outs.append(y)
return outs
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.functional.activation import hard_sigmoid
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from numbers import Integral
__all__ = ['MobileNetV3']
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNLayer(nn.Layer):
def __init__(self,
in_c,
out_c,
filter_size,
stride,
padding,
num_groups=1,
act=None,
lr_mult=1.,
conv_decay=0.,
norm_type='bn',
norm_decay=0.,
freeze_norm=False,
name=""):
super(ConvBNLayer, self).__init__()
self.act = act
self.conv = nn.Conv2D(
in_channels=in_c,
out_channels=out_c,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(
learning_rate=lr_mult,
regularizer=L2Decay(conv_decay),
name=name + "_weights"),
bias_attr=False)
norm_lr = 0. if freeze_norm else lr_mult
if norm_type == 'sync_bn':
batch_norm = nn.SyncBatchNorm
else:
batch_norm = nn.BatchNorm2D
self.bn = batch_norm(
out_c,
weight_attr=ParamAttr(
learning_rate=norm_lr,
name=name + "_bn_scale",
regularizer=L2Decay(norm_decay)),
bias_attr=ParamAttr(
learning_rate=norm_lr,
name=name + "_bn_offset",
regularizer=L2Decay(norm_decay)))
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.act is not None:
if self.act == "relu":
x = F.relu(x)
elif self.act == "relu6":
x = F.relu6(x)
elif self.act == "hard_swish":
x = F.hardswish(x)
else:
raise NotImplementedError(
"The activation function is selected incorrectly.")
return x
class ResidualUnit(nn.Layer):
def __init__(self,
in_c,
mid_c,
out_c,
filter_size,
stride,
use_se,
lr_mult,
conv_decay=0.,
norm_type='bn',
norm_decay=0.,
freeze_norm=False,
act=None,
return_list=False,
name=''):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_c == out_c
self.use_se = use_se
self.return_list = return_list
self.expand_conv = ConvBNLayer(
in_c=in_c,
out_c=mid_c,
filter_size=1,
stride=1,
padding=0,
act=act,
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
filter_size=filter_size,
stride=stride,
padding=int((filter_size - 1) // 2),
num_groups=mid_c,
act=act,
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + "_depthwise")
if self.use_se:
self.mid_se = SEModule(
mid_c, lr_mult, conv_decay, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_c=mid_c,
out_c=out_c,
filter_size=1,
stride=1,
padding=0,
act=None,
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + "_linear")
def forward(self, inputs):
y = self.expand_conv(inputs)
x = self.bottleneck_conv(y)
if self.use_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = paddle.add(inputs, x)
if self.return_list:
return [y, x]
else:
return x
class SEModule(nn.Layer):
def __init__(self, channel, lr_mult, conv_decay, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
mid_channels = int(channel // reduction)
self.conv1 = nn.Conv2D(
in_channels=channel,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(
learning_rate=lr_mult,
regularizer=L2Decay(conv_decay),
name=name + "_1_weights"),
bias_attr=ParamAttr(
learning_rate=lr_mult,
regularizer=L2Decay(conv_decay),
name=name + "_1_offset"))
self.conv2 = nn.Conv2D(
in_channels=mid_channels,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(
learning_rate=lr_mult,
regularizer=L2Decay(conv_decay),
name=name + "_2_weights"),
bias_attr=ParamAttr(
learning_rate=lr_mult,
regularizer=L2Decay(conv_decay),
name=name + "_2_offset"))
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hard_sigmoid(outputs)
return paddle.multiply(x=inputs, y=outputs)
class ExtraBlockDW(nn.Layer):
def __init__(self,
in_c,
ch_1,
ch_2,
stride,
lr_mult,
conv_decay=0.,
norm_type='bn',
norm_decay=0.,
freeze_norm=False,
name=None):
super(ExtraBlockDW, self).__init__()
self.pointwise_conv = ConvBNLayer(
in_c=in_c,
out_c=ch_1,
filter_size=1,
stride=1,
padding='SAME',
act='relu6',
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + "_extra1")
self.depthwise_conv = ConvBNLayer(
in_c=ch_1,
out_c=ch_2,
filter_size=3,
stride=stride,
padding='SAME',
num_groups=int(ch_1),
act='relu6',
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + "_extra2_dw")
self.normal_conv = ConvBNLayer(
in_c=ch_2,
out_c=ch_2,
filter_size=1,
stride=1,
padding='SAME',
act='relu6',
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name=name + "_extra2_sep")
def forward(self, inputs):
x = self.pointwise_conv(inputs)
x = self.depthwise_conv(x)
x = self.normal_conv(x)
return x
@register
@serializable
class MobileNetV3(nn.Layer):
__shared__ = ['norm_type']
def __init__(
self,
scale=1.0,
model_name="large",
feature_maps=[6, 12, 15],
with_extra_blocks=False,
extra_block_filters=[[256, 512], [128, 256], [128, 256], [64, 128]],
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
conv_decay=0.0,
multiplier=1.0,
norm_type='bn',
norm_decay=0.0,
freeze_norm=False):
super(MobileNetV3, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
if norm_type == 'sync_bn' and freeze_norm:
raise ValueError(
"The norm_type should not be sync_bn when freeze_norm is True")
self.feature_maps = feature_maps
self.with_extra_blocks = with_extra_blocks
self.extra_block_filters = extra_block_filters
inplanes = 16
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", 1],
[3, 64, 24, False, "relu", 2],
[3, 72, 24, False, "relu", 1],
[5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1], # YOLOv3 output
[3, 240, 80, False, "hard_swish", 2],
[3, 200, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish", 1], # YOLOv3 output
[5, 672, 160, True, "hard_swish", 2], # SSD/SSDLite output
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish", 1], # YOLOv3 output
]
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", 2],
[3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1], # YOLOv3 output
[5, 96, 40, True, "hard_swish", 2],
[5, 240, 40, True, "hard_swish", 1],
[5, 240, 40, True, "hard_swish", 1],
[5, 120, 48, True, "hard_swish", 1],
[5, 144, 48, True, "hard_swish", 1], # YOLOv3 output
[5, 288, 96, True, "hard_swish", 2], # SSD/SSDLite output
[5, 576, 96, True, "hard_swish", 1],
[5, 576, 96, True, "hard_swish", 1], # YOLOv3 output
]
else:
raise NotImplementedError(
"mode[{}_model] is not implemented!".format(model_name))
if multiplier != 1.0:
self.cfg[-3][2] = int(self.cfg[-3][2] * multiplier)
self.cfg[-2][1] = int(self.cfg[-2][1] * multiplier)
self.cfg[-2][2] = int(self.cfg[-2][2] * multiplier)
self.cfg[-1][1] = int(self.cfg[-1][1] * multiplier)
self.cfg[-1][2] = int(self.cfg[-1][2] * multiplier)
self.conv1 = ConvBNLayer(
in_c=3,
out_c=make_divisible(inplanes * scale),
filter_size=3,
stride=2,
padding=1,
num_groups=1,
act="hard_swish",
lr_mult=lr_mult_list[0],
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name="conv1")
self.block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in self.cfg:
lr_idx = min(i // 3, len(lr_mult_list) - 1)
lr_mult = lr_mult_list[lr_idx]
# for SSD/SSDLite, first head input is after ResidualUnit expand_conv
return_list = self.with_extra_blocks and i + 2 in self.feature_maps
block = self.add_sublayer(
"conv" + str(i + 2),
sublayer=ResidualUnit(
in_c=inplanes,
mid_c=make_divisible(scale * exp),
out_c=make_divisible(scale * c),
filter_size=k,
stride=s,
use_se=se,
act=nl,
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
return_list=return_list,
name="conv" + str(i + 2)))
self.block_list.append(block)
inplanes = make_divisible(scale * c)
i += 1
if self.with_extra_blocks:
self.extra_block_list = []
extra_out_c = make_divisible(scale * self.cfg[-1][1])
lr_idx = min(i // 3, len(lr_mult_list) - 1)
lr_mult = lr_mult_list[lr_idx]
conv_extra = self.add_sublayer(
"conv" + str(i + 2),
sublayer=ConvBNLayer(
in_c=inplanes,
out_c=extra_out_c,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
act="hard_swish",
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name="conv" + str(i + 2)))
self.extra_block_list.append(conv_extra)
i += 1
for j, block_filter in enumerate(self.extra_block_filters):
in_c = extra_out_c if j == 0 else self.extra_block_filters[j -
1][1]
conv_extra = self.add_sublayer(
"conv" + str(i + 2),
sublayer=ExtraBlockDW(
in_c,
block_filter[0],
block_filter[1],
stride=2,
lr_mult=lr_mult,
conv_decay=conv_decay,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
name='conv' + str(i + 2)))
self.extra_block_list.append(conv_extra)
i += 1
def forward(self, inputs):
x = self.conv1(inputs['image'])
outs = []
for idx, block in enumerate(self.block_list):
x = block(x)
if idx + 2 in self.feature_maps:
if isinstance(x, list):
outs.append(x[0])
x = x[1]
else:
outs.append(x)
if not self.with_extra_blocks:
return outs
for i, block in enumerate(self.extra_block_list):
idx = i + len(self.block_list)
x = block(x)
if idx + 2 in self.feature_maps:
outs.append(x)
return outs
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from ppdet.core.workspace import register
from paddle.regularizer import L2Decay
from paddle import ParamAttr
class SepConvLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
padding=1,
conv_decay=0,
name=None):
super(SepConvLayer, self).__init__()
self.dw_conv = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
groups=in_channels,
weight_attr=ParamAttr(
name=name + "_dw_weights", regularizer=L2Decay(conv_decay)),
bias_attr=False)
self.bn = nn.BatchNorm2D(
in_channels,
weight_attr=ParamAttr(
name=name + "_bn_scale", regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
name=name + "_bn_offset", regularizer=L2Decay(0.)))
self.pw_conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(
name=name + "_pw_weights", regularizer=L2Decay(conv_decay)),
bias_attr=False)
def forward(self, x):
x = self.dw_conv(x)
x = F.relu6(self.bn(x))
x = self.pw_conv(x)
return x
@register
......@@ -14,6 +59,10 @@ class SSDHead(nn.Layer):
num_classes=81,
in_channels=(512, 1024, 512, 256, 256, 256),
anchor_generator='AnchorGeneratorSSD',
kernel_size=3,
padding=1,
use_sepconv=False,
conv_decay=0.,
loss='SSDLoss'):
super(SSDHead, self).__init__()
self.num_classes = num_classes
......@@ -25,22 +74,47 @@ class SSDHead(nn.Layer):
self.box_convs = []
self.score_convs = []
for i, num_prior in enumerate(self.num_priors):
self.box_convs.append(
self.add_sublayer(
"boxes{}".format(i),
box_conv_name = "boxes{}".format(i)
if not use_sepconv:
box_conv = self.add_sublayer(
box_conv_name,
nn.Conv2D(
in_channels=in_channels[i],
out_channels=num_prior * 4,
kernel_size=3,
padding=1)))
self.score_convs.append(
self.add_sublayer(
"scores{}".format(i),
kernel_size=kernel_size,
padding=padding))
else:
box_conv = self.add_sublayer(
box_conv_name,
SepConvLayer(
in_channels=in_channels[i],
out_channels=num_prior * 4,
kernel_size=kernel_size,
padding=padding,
conv_decay=conv_decay,
name=box_conv_name))
self.box_convs.append(box_conv)
score_conv_name = "scores{}".format(i)
if not use_sepconv:
score_conv = self.add_sublayer(
score_conv_name,
nn.Conv2D(
in_channels=in_channels[i],
out_channels=num_prior * num_classes,
kernel_size=3,
padding=1)))
kernel_size=kernel_size,
padding=padding))
else:
score_conv = self.add_sublayer(
score_conv_name,
SepConvLayer(
in_channels=in_channels[i],
out_channels=num_prior * num_classes,
kernel_size=kernel_size,
padding=padding,
conv_decay=conv_decay,
name=score_conv_name))
self.score_convs.append(score_conv)
def forward(self, feats, image):
box_preds = []
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import math
import six
import numpy as np
from numbers import Integral
......@@ -118,6 +119,7 @@ class AnchorGeneratorSSD(object):
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
min_ratio=15,
max_ratio=90,
base_size=300,
min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
offset=0.5,
......@@ -128,6 +130,7 @@ class AnchorGeneratorSSD(object):
self.aspect_ratios = aspect_ratios
self.min_ratio = min_ratio
self.max_ratio = max_ratio
self.base_size = base_size
self.min_sizes = min_sizes
self.max_sizes = max_sizes
self.offset = offset
......@@ -135,9 +138,21 @@ class AnchorGeneratorSSD(object):
self.clip = clip
self.min_max_aspect_ratios_order = min_max_aspect_ratios_order
if self.min_sizes == [] and self.max_sizes == []:
num_layer = len(aspect_ratios)
step = int(
math.floor(((self.max_ratio - self.min_ratio)) / (num_layer - 2
)))
for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
step):
self.min_sizes.append(self.base_size * ratio / 100.)
self.max_sizes.append(self.base_size * (ratio + step) / 100.)
self.min_sizes = [self.base_size * .10] + self.min_sizes
self.max_sizes = [self.base_size * .20] + self.max_sizes
self.num_priors = []
for aspect_ratio, min_size, max_size in zip(aspect_ratios, min_sizes,
max_sizes):
for aspect_ratio, min_size, max_size in zip(
aspect_ratios, self.min_sizes, self.max_sizes):
self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
_to_list(min_size)) + len(_to_list(max_size)))
......
......@@ -806,12 +806,12 @@ def prior_box(input,
cur_max_sizes = max_sizes
if in_dygraph_mode():
assert cur_max_sizes is not None
attrs = ('min_sizes', min_sizes, 'max_sizes', cur_max_sizes,
'aspect_ratios', aspect_ratios, 'variances', variance, 'flip',
flip, 'clip', clip, 'step_w', steps[0], 'step_h', steps[1],
'offset', offset, 'min_max_aspect_ratios_order',
min_max_aspect_ratios_order)
attrs = ('min_sizes', min_sizes, 'aspect_ratios', aspect_ratios,
'variances', variance, 'flip', flip, 'clip', clip, 'step_w',
steps[0], 'step_h', steps[1], 'offset', offset,
'min_max_aspect_ratios_order', min_max_aspect_ratios_order)
if cur_max_sizes is not None:
attrs += ('max_sizes', cur_max_sizes)
box, var = core.ops.prior_box(input, image, *attrs)
return box, var
else:
......
......@@ -96,8 +96,9 @@ def list_modules(**kwargs):
print("")
max_len = max([len(mod.name) for mod in modules])
for mod in modules:
print(color_tty.green(mod.name.ljust(max_len)),
mod.doc.split('\n')[0])
print(
color_tty.green(mod.name.ljust(max_len)),
mod.doc.split('\n')[0])
print("")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册