未验证 提交 a24ae7e3 编写于 作者: K Kaipeng Deng 提交者: GitHub

add PP-YOLO tiny (#2594)

* add PP-YOLO tiny
上级 2d949d19
......@@ -71,6 +71,20 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
- PP-YOLO_MobileNetV3 used 4 GPUs for training and mini-batch size as 32 on each GPU, if GPU number and mini-batch size is changed, learning rate and iteration times should be adjusted according [FAQ](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/FAQ.md).
- PP-YOLO_MobileNetV3 inference speed is tested on Kirin 990 with 1 thread.
### PP-YOLO tiny
| Model | GPU number | images/GPU | Model Size | Post Quant Model Size | input shape | Box AP<sup>val</sup> | Kirin 990 4xCore(FPS) | download | config | post quant model |
|:----------------------------:|:-------:|:-------------:|:----------:| :-------------------: | :---------: | :------------------: | :-------------------: | :------: | :----: | :--------------: |
| PP-YOLO tiny | 8 | 32 | 4.2MB | **1.3M** | 320 | 20.6 | 92.3 | [model](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_650e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ppyolo/ppyolo_tiny_650e_coco.yml) | [inference model](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_quant.tar) |
| PP-YOLO tiny | 8 | 32 | 4.2MB | **1.3M** | 416 | 22.7 | 65.4 | [model](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_650e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ppyolo/ppyolo_tiny_650e_coco.yml) | [inference model](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_quant.tar) |
**Notes:**
- PP-YOLO-tiny is trained on COCO train2017 datast and evaluated on val2017 dataset,Box AP<sup>val</sup> is evaluation results of `mAP(IoU=0.5:0.95)`, Box AP<sup>val</sup> is evaluation results of `mAP(IoU=0.5)`.
- PP-YOLO-tiny used 8 GPUs for training and mini-batch size as 32 on each GPU, if GPU number and mini-batch size is changed, learning rate and iteration times should be adjusted according [FAQ](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/FAQ.md).
- PP-YOLO-tiny inference speed is tested on Kirin 990 with 4 threads by arm8
- we alse provide PP-YOLO-tiny post quant inference model, which can compress model to **1.3MB** with nearly no inference on inference speed and performance
### PP-YOLO on Pascal VOC
PP-YOLO trained on Pascal VOC dataset as follows:
......
epoch: 650
LearningRate:
base_lr: 0.005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 430
- 540
- 610
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
architecture: YOLOv3
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: MobileNetV3
neck: PPYOLOTinyFPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
MobileNetV3:
model_name: large
scale: .5
with_extra_blocks: false
extra_block_filters: []
feature_maps: [7, 13, 16]
PPYOLOTinyFPN:
detection_block_channels: [160, 128, 96]
spp: true
drop_block: true
YOLOv3Head:
anchors: [[10, 15], [24, 36], [72, 42],
[35, 87], [102, 96], [60, 170],
[220, 125], [128, 222], [264, 266]]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.5
downsample: [32, 16, 8]
label_smooth: false
scale_x_y: 1.05
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
loss_square: true
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
scale_x_y: 1.05
nms:
name: MultiClassNMS
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
score_threshold: 0.005
worker_num: 4
TrainReader:
inputs_def:
num_max_boxes: 100
sample_transforms:
- Decode: {}
- Mixup: {alpha: 1.5, beta: 1.5}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [192, 224, 256, 288, 320, 352, 384, 416, 448, 480, 512], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeBox: {}
- PadBox: {num_max_boxes: 100}
- BboxXYXY2XYWH: {}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
- Gt2YoloTarget: {anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchors: [[10, 15], [24, 36], [72, 42], [35, 87], [102, 96], [60, 170], [220, 125], [128, 222], [264, 266]], downsample_ratios: [32, 16, 8]}
batch_size: 32
shuffle: true
drop_last: true
mixup_epoch: 500
use_shared_memory: true
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [320, 320], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 320, 320]
sample_transforms:
- Decode: {}
- Resize: {target_size: [320, 320], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/ppyolo_tiny.yml',
'./_base_/optimizer_650e.yml',
'./_base_/ppyolo_tiny_reader.yml',
]
snapshot_epoch: 1
weights: output/ppyolo_tiny_650e_coco/model_final
......@@ -20,6 +20,7 @@ class YOLOv3Head(nn.Layer):
__inject__ = ['loss']
def __init__(self,
in_channels=[1024, 512, 256],
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
......@@ -41,6 +42,8 @@ class YOLOv3Head(nn.Layer):
data_format (str): data format, NCHW or NHWC
"""
super(YOLOv3Head, self).__init__()
assert len(in_channels) > 0, "in_channels length should > 0"
self.in_channels = in_channels
self.num_classes = num_classes
self.loss = loss
......@@ -60,7 +63,7 @@ class YOLOv3Head(nn.Layer):
num_filters = len(self.anchors[i]) * (self.num_classes + 5)
name = 'yolo_output.{}'.format(i)
conv = nn.Conv2D(
in_channels=128 * (2**self.num_outputs) // (2**i),
in_channels=self.in_channels[i],
out_channels=num_filters,
kernel_size=1,
stride=1,
......@@ -116,3 +119,7 @@ class YOLOv3Head(nn.Layer):
return y
else:
return yolo_outputs
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
......@@ -262,6 +262,77 @@ class PPYOLODetBlock(nn.Layer):
return route, tip
class PPYOLOTinyDetBlock(nn.Layer):
def __init__(self,
ch_in,
ch_out,
name,
drop_block=False,
block_size=3,
keep_prob=0.9,
data_format='NCHW'):
"""
PPYOLO Tiny DetBlock layer
Args:
ch_in (list): input channel number
ch_out (list): output channel number
name (str): block name
drop_block: whether user DropBlock
block_size: drop block size
keep_prob: probability to keep block in DropBlock
data_format (str): data format, NCHW or NHWC
"""
super(PPYOLOTinyDetBlock, self).__init__()
self.drop_block_ = drop_block
self.conv_module = nn.Sequential()
cfgs = [
# name, in channels, out channels, filter_size,
# stride, padding, groups
['.0', ch_in, ch_out, 1, 1, 0, 1],
['.1', ch_out, ch_out, 5, 1, 2, ch_out],
['.2', ch_out, ch_out, 1, 1, 0, 1],
['.route', ch_out, ch_out, 5, 1, 2, ch_out],
]
for cfg in cfgs:
conv_name, conv_ch_in, conv_ch_out, filter_size, stride, padding, \
groups = cfg
self.conv_module.add_sublayer(
name + conv_name,
ConvBNLayer(
ch_in=conv_ch_in,
ch_out=conv_ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
name=name + conv_name))
self.tip = ConvBNLayer(
ch_in=ch_out,
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0,
groups=1,
name=name + conv_name)
if self.drop_block_:
self.drop_block = DropBlock(
block_size=block_size,
keep_prob=keep_prob,
data_format=data_format,
name=name + '.dropblock')
def forward(self, inputs):
if self.drop_block_:
inputs = self.drop_block(inputs)
route = self.conv_module(inputs)
tip = self.tip(route)
return route, tip
@register
@serializable
class YOLOv3FPN(nn.Layer):
......@@ -497,3 +568,116 @@ class PPYOLOFPN(nn.Layer):
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
@register
@serializable
class PPYOLOTinyFPN(nn.Layer):
__shared__ = ['norm_type', 'data_format']
def __init__(self,
in_channels=[80, 56, 34],
detection_block_channels=[160, 128, 96],
norm_type='bn',
data_format='NCHW',
**kwargs):
"""
PPYOLO Tiny FPN layer
Args:
in_channels (list): input channels for fpn
detection_block_channels (list): channels in fpn
norm_type (str): batch norm type, default bn
data_format (str): data format, NCHW or NHWC
kwargs: extra key-value pairs, such as parameter of DropBlock and spp
"""
super(PPYOLOTinyFPN, self).__init__()
assert len(in_channels) > 0, "in_channels length should > 0"
self.in_channels = in_channels[::-1]
assert len(detection_block_channels
) > 0, "detection_block_channelslength should > 0"
self.detection_block_channels = detection_block_channels
self.data_format = data_format
self.num_blocks = len(in_channels)
# parse kwargs
self.drop_block = kwargs.get('drop_block', False)
self.block_size = kwargs.get('block_size', 3)
self.keep_prob = kwargs.get('keep_prob', 0.9)
self.spp_ = kwargs.get('spp', False)
if self.spp_:
self.spp = SPP(self.in_channels[0] * 4,
self.in_channels[0],
k=1,
pool_size=[5, 9, 13],
norm_type=norm_type,
name='spp')
self._out_channels = []
self.yolo_blocks = []
self.routes = []
for i, (
ch_in, ch_out
) in enumerate(zip(self.in_channels, self.detection_block_channels)):
name = 'yolo_block.{}'.format(i)
if i > 0:
ch_in += self.detection_block_channels[i - 1]
yolo_block = self.add_sublayer(
name,
PPYOLOTinyDetBlock(
ch_in,
ch_out,
name,
drop_block=self.drop_block,
block_size=self.block_size,
keep_prob=self.keep_prob))
self.yolo_blocks.append(yolo_block)
self._out_channels.append(ch_out)
if i < self.num_blocks - 1:
name = 'yolo_transition.{}'.format(i)
route = self.add_sublayer(
name,
ConvBNLayer(
ch_in=ch_out,
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0,
norm_type=norm_type,
data_format=data_format,
name=name))
self.routes.append(route)
def forward(self, blocks):
assert len(blocks) == self.num_blocks
blocks = blocks[::-1]
yolo_feats = []
for i, block in enumerate(blocks):
if i == 0 and self.spp_:
block = self.spp(block)
if i > 0:
if self.data_format == 'NCHW':
block = paddle.concat([route, block], axis=1)
else:
block = paddle.concat([route, block], axis=-1)
route, tip = self.yolo_blocks[i](block)
yolo_feats.append(tip)
if i < self.num_blocks - 1:
route = self.routes[i](route)
route = F.interpolate(
route, scale_factor=2., data_format=self.data_format)
return yolo_feats
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
......@@ -88,6 +88,20 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
- Pruning detectiom head of PP-YOLO model with ratio as 75%, while the arguments are `--pruned_params="yolo_block.0.2.conv.weights,yolo_block.0.tip.conv.weights,yolo_block.1.2.conv.weights,yolo_block.1.tip.conv.weights" --pruned_ratios="0.75,0.75,0.75,0.75"`
- For Slim PP-YOLO training, evaluation, inference and model exporting, please see [Distill pruned model](../../slim/extentions/distill_pruned_model/README.md)
### PP-YOLO tiny
| Model | GPU number | images/GPU | Model Size | Post Quant Model Size | input shape | Box AP<sup>val</sup> | Kirin 990 4xCore(FPS) | download | config | config | post quant model |
|:----------------------------:|:-------:|:-------------:|:----------:| :-------------------: | :----------:| :------------------: | :-------------------: | :------: | :----: | :----: | :--------------: |
| PP-YOLO tiny | 8 | 32 | 4.2MB | **1.3M** | 320 | 20.6 | 92.3 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_tiny.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/static/configs/ppyolo/ppyolo_tiny.yml) | [inference model](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_quant.tar) |
| PP-YOLO tiny | 8 | 32 | 4.2MB | **1.3M** | 416 | 22.7 | 65.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_tiny.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/static/configs/ppyolo/ppyolo_tiny.yml) | [inference model](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_quant.tar) |
**Notes:**
- PP-YOLO-tiny is trained on COCO train2017 datast and evaluated on val2017 dataset,Box AP<sup>val</sup> is evaluation results of `mAP(IoU=0.5:0.95)`, Box AP<sup>val</sup> is evaluation results of `mAP(IoU=0.5)`.
- PP-YOLO-tiny used 8 GPUs for training and mini-batch size as 32 on each GPU, if GPU number and mini-batch size is changed, learning rate and iteration times should be adjusted according [FAQ](https://github.com/PaddlePaddle/PaddleDetection/blob/master/docs/FAQ.md).
- PP-YOLO-tiny inference speed is tested on Kirin 990 with 4 threads by arm8
- we alse provide PP-YOLO-tiny post quant inference model, which can compress model to **1.3MB** with nearly no inference on inference speed and performance
### PP-YOLO on Pascal VOC
PP-YOLO trained on Pascal VOC dataset as follows:
......
......@@ -87,6 +87,18 @@ PP-YOLO从如下方面优化和提升YOLOv3模型的精度和速度:
- 卷积通道检测对Head部分剪裁掉75%的通道数,及剪裁参数为`--pruned_params="yolo_block.0.2.conv.weights,yolo_block.0.tip.conv.weights,yolo_block.1.2.conv.weights,yolo_block.1.tip.conv.weights" --pruned_ratios="0.75,0.75,0.75,0.75"`
- PP-YOLO 轻量级裁剪模型的训练、评估、预测及模型导出方法见[蒸馏通道剪裁模型](../../slim/extentions/distill_pruned_model/README.md)
### PP-YOLO tiny模型
| 模型 | GPU 个数 | 每GPU图片个数 | 模型体积 | 后量化模型体积 | 输入尺寸 | Box AP<sup>val</sup> | Kirin 990 1xCore (FPS) | 模型下载 | 配置文件 | 后量化模型 |
|:----------------------------:|:----------:|:-------------:| :--------: | :------------: | :----------:| :------------------: | :--------------------: | :------: | :------: | :--------: |
| PP-YOLO tiny | 8 | 32 | 4.2MB | **1.3M** | 320 | 20.6 | 92.3 | [model](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_650e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_tiny_650e_coco.yml) | [预测模型](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_quant.tar) |
| PP-YOLO tiny | 8 | 32 | 4.2MB | **1.3M** | 416 | 22.7 | 65.4 | [model](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_650e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ppyolo/ppyolo_tiny_650e_coco.yml) | [预测模型](https://paddledet.bj.bcebos.com/models/ppyolo_tiny_quant.tar) |
- PP-YOLO-tiny 模型使用COCO数据集中train2017作为训练集,使用val2017作为测试集,Box AP<sup>val</sup>`mAP(IoU=0.5:0.95)`评估结果, Box AP50<sup>val</sup>`mAP(IoU=0.5)`评估结果。
- PP-YOLO-tiny 模型训练过程中使用8GPU,每GPU batch size为32进行训练,如训练GPU数和batch size不使用上述配置,须参考[FAQ](../../docs/FAQ.md)调整学习率和迭代次数。
- PP-YOLO-tiny 模型推理速度测试环境配置为麒麟990芯片4线程,arm8架构。
- 我们也提供的PP-YOLO-tiny的后量化压缩模型,将模型体积压缩到**1.3M**,对精度和预测速度基本无影响
### Pascal VOC数据集上的PP-YOLO
PP-YOLO在Pascal VOC数据集上训练模型如下:
......
architecture: YOLOv3
use_gpu: true
max_iters: 300000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
weights: output/ppyolo_tiny/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: MobileNetV3
yolo_head: PPYOLOTinyHead
use_fine_grained_loss: true
MobileNetV3:
norm_type: sync_bn
norm_decay: 0.
model_name: large
scale: .5
extra_block_filters: []
feature_maps: [1, 2, 3, 4, 6]
PPYOLOTinyHead:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 15], [24, 36], [72, 42],
[35, 87], [102, 96], [60, 170],
[220, 125], [128, 222], [264, 266]]
detection_block_channels: [160, 128, 96]
norm_decay: 0.
scale_x_y: 1.05
yolo_loss: YOLOv3Loss
spp: true
drop_block: true
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
YOLOv3Loss:
ignore_thresh: 0.5
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
max_height: 512
max_width: 512
LearningRate:
base_lr: 0.005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 200000
- 250000
- 280000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.949
type: Momentum
regularizer:
factor: 0.0005
type: L2
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
num_max_boxes: 100
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: train_data/dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: True
- !MixupImage
alpha: 1.5
beta: 1.5
- !ColorDistort {}
- !RandomExpand
fill_value: [123.675, 116.28, 103.53]
ratio: 2
- !RandomCrop {}
- !RandomFlipImage
is_normalized: false
- !NormalizeBox {}
- !PadBox
num_max_boxes: 100
- !BboxXYXY2XYWH {}
batch_transforms:
- !RandomShape
sizes: [192, 224, 256, 288, 320, 352, 384, 416, 448, 480, 512]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
# Gt2YoloTarget is only used when use_fine_grained_loss set as true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set as false
- !Gt2YoloTarget
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 15], [24, 36], [72, 42],
[35, 87], [102, 96], [60, 170],
[220, 125], [128, 222], [264, 266]]
downsample_ratios: [32, 16, 8]
iou_thresh: 0.25
num_classes: 80
batch_size: 32
shuffle: true
mixup_epoch: 200
drop_last: true
worker_num: 16
bufsize: 4
use_process: true
EvalReader:
inputs_def:
fields: ['image', 'im_size', 'im_id']
num_max_boxes: 100
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: train_data/dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 320
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !PadBox
num_max_boxes: 100
- !Permute
to_bgr: false
channel_first: True
batch_size: 1
drop_empty: false
worker_num: 2
bufsize: 4
TestReader:
inputs_def:
image_shape: [3, 320, 320]
fields: ['image', 'im_size', 'im_id']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 320
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
batch_size: 1
......@@ -163,6 +163,7 @@ class YOLOv3Head(object):
filter_size,
stride,
padding,
groups=None,
act='leaky',
name=None):
conv = fluid.layers.conv2d(
......@@ -171,6 +172,7 @@ class YOLOv3Head(object):
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + ".conv.weights"),
bias_attr=False)
......@@ -649,3 +651,190 @@ class YOLOv4Head(YOLOv3Head):
outputs.append(block_out)
return outputs
@register
class PPYOLOTinyHead(YOLOv3Head):
"""
Head block for YOLOv3 network
Args:
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
detection_block_channels (list): the channel number of each
detection block.
"""
__inject__ = ['yolo_loss', 'nms']
__shared__ = ['num_classes', 'weight_prefix_name']
def __init__(self,
norm_decay=0.,
num_classes=80,
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
detection_block_channels=[128, 96],
drop_block=False,
block_size=3,
keep_prob=0.9,
yolo_loss="YOLOv3Loss",
spp=False,
nms=MultiClassNMS(
score_threshold=0.01,
nms_top_k=1000,
keep_top_k=100,
nms_threshold=0.45,
background_label=-1).__dict__,
weight_prefix_name='',
downsample=[32, 16, 8],
scale_x_y=1.0,
clip_bbox=True):
super(PPYOLOTinyHead, self).__init__(
norm_decay=norm_decay,
num_classes=num_classes,
anchors=anchors,
anchor_masks=anchor_masks,
drop_block=drop_block,
block_size=block_size,
keep_prob=0.9,
spp=spp,
yolo_loss=yolo_loss,
nms=nms,
weight_prefix_name=weight_prefix_name,
downsample=downsample,
scale_x_y=scale_x_y,
clip_bbox=clip_bbox)
self.detection_block_channels = detection_block_channels
def _detection_block(self,
input,
channel,
is_first=False,
is_test=True,
name=None):
assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \
.format(channel, name)
conv = input
if self.use_spp and is_first:
c = conv.shape[1]
conv = self._spp_module(conv, name="spp")
conv = self._conv_bn(
conv,
c,
filter_size=1,
stride=1,
padding=0,
name='{}.spp.conv'.format(name))
if self.drop_block:
conv = DropBlock(
conv,
block_size=self.block_size,
keep_prob=self.keep_prob,
is_test=is_test)
conv = self._conv_bn(
conv,
ch_out=channel,
filter_size=1,
stride=1,
padding=0,
groups=1,
name='{}.0'.format(name))
conv = self._conv_bn(
conv,
channel,
filter_size=5,
stride=1,
padding=2,
groups=channel,
name='{}.1'.format(name))
conv = self._conv_bn(
conv,
channel,
filter_size=1,
stride=1,
padding=0,
name='{}.2'.format(name))
route = self._conv_bn(
conv,
channel,
filter_size=5,
stride=1,
padding=2,
groups=channel,
name='{}.route'.format(name))
tip = self._conv_bn(
route,
channel,
filter_size=1,
stride=1,
padding=0,
name='{}.tip'.format(name))
return route, tip
def _get_outputs(self, input, is_train=True):
"""
Get PP-YOLO tiny head output
Args:
input (list): List of Variables, output of backbone stages
is_train (bool): whether in train or test mode
Returns:
outputs (list): Variables of each output layer
"""
outputs = []
# get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks)
blocks = input[-1:-out_layer_num - 1:-1]
route = None
for i, block in enumerate(blocks):
if i > 0: # perform concat in first 2 detection_block
block = fluid.layers.concat(input=[route, block], axis=1)
route, tip = self._detection_block(
block,
channel=self.detection_block_channels[i],
is_first=i == 0,
is_test=(not is_train),
name=self.prefix_name + "yolo_block.{}".format(i))
# out channel number = mask_num * (5 + class_num)
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
with fluid.name_scope('yolo_output'):
block_out = fluid.layers.conv2d(
input=tip,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(
name=self.prefix_name +
"yolo_output.{}.conv.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name=self.prefix_name +
"yolo_output.{}.conv.bias".format(i)))
outputs.append(block_out)
if i < len(blocks) - 1:
# upsample
route = self._conv_bn(
input=route,
ch_out=self.detection_block_channels[i],
filter_size=1,
stride=1,
padding=0,
name=self.prefix_name + "yolo_transition.{}".format(i))
route = self._upsample(route)
return outputs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册