提交 beaa62a7 编写于 作者: L longxiang

update yolov3

上级 a66dfe9c
architecture: YOLOv3
use_gpu: true
max_iters: 500000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: ResNet
yolo_head: YOLOv3Head
use_fine_grained_loss: true
ResNet:
norm_type: sync_bn
freeze_at: 0
freeze_norm: false
norm_decay: 0.
depth: 50
feature_maps: [3, 4, 5]
variant: d
dcn_v2_stages: [5]
YOLOv3Head:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
coord_conv: true
iou_aware: true
iou_aware_factor: 0.4
scale_x_y: 1.05
spp: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
# nms_threshold: 0.45
# nms_top_k: 1000
normalized: false
score_threshold: 0.01
drop_block: true
YOLOv3Loss:
batch_size: 24
ignore_thresh: 0.7
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
iou_aware_loss: IouAwareLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
IouAwareLoss:
loss_weight: 1.0
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.00333
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 400000
- 450000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo_lb/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: ResNet
yolo_head: YOLOv3Head
use_fine_grained_loss: true
ResNet:
norm_type: sync_bn
freeze_at: 0
freeze_norm: false
norm_decay: 0.
depth: 50
feature_maps: [3, 4, 5]
variant: d
dcn_v2_stages: [5]
YOLOv3Head:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
coord_conv: true
iou_aware: true
iou_aware_factor: 0.4
scale_x_y: 1.05
spp: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
# nms_threshold: 0.45
# nms_top_k: 1000
normalized: false
score_threshold: 0.01
drop_block: true
YOLOv3Loss:
batch_size: 24
ignore_thresh: 0.7
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
iou_aware_loss: IouAwareLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
IouAwareLoss:
loss_weight: 1.0
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 150000
- 200000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: True
- !MixupImage
alpha: 1.5
beta: 1.5
- !ColorDistort {}
- !RandomExpand
fill_value: [123.675, 116.28, 103.53]
- !RandomCrop {}
- !RandomFlipImage
is_normalized: false
- !NormalizeBox {}
- !PadBox
num_max_boxes: 50
- !BboxXYXY2XYWH {}
batch_transforms:
- !RandomShape
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
# Gt2YoloTarget is only used when use_fine_grained_loss set as true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set as false
- !Gt2YoloTarget
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
downsample_ratios: [32, 16, 8]
batch_size: 24
shuffle: true
# mixup_epoch: 250
mixup_epoch: 25000
drop_last: true
worker_num: 8
bufsize: 4
use_process: true
EvalReader:
inputs_def:
fields: ['image', 'im_size', 'im_id']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 608
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !PadBox
num_max_boxes: 50
- !Permute
to_bgr: false
channel_first: True
batch_size: 8
drop_empty: false
worker_num: 8
bufsize: 4
TestReader:
inputs_def:
image_shape: [3, 608, 608]
fields: ['image', 'im_size', 'im_id']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 608
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
batch_size: 1
...@@ -21,6 +21,7 @@ from paddle.fluid.param_attr import ParamAttr ...@@ -21,6 +21,7 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import MultiClassNMS, MultiClassSoftNMS from ppdet.modeling.ops import MultiClassNMS, MultiClassSoftNMS
from ppdet.modeling.ops import MultiClassMatrixNMS
from ppdet.modeling.losses.yolo_loss import YOLOv3Loss from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from ppdet.core.workspace import register from ppdet.core.workspace import register
from ppdet.modeling.ops import DropBlock from ppdet.modeling.ops import DropBlock
...@@ -56,11 +57,13 @@ class YOLOv3Head(object): ...@@ -56,11 +57,13 @@ class YOLOv3Head(object):
[59, 119], [116, 90], [156, 198], [373, 326]], [59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
drop_block=False, drop_block=False,
coord_conv=False,
iou_aware=False, iou_aware=False,
iou_aware_factor=0.4, iou_aware_factor=0.4,
block_size=3, block_size=3,
keep_prob=0.9, keep_prob=0.9,
yolo_loss="YOLOv3Loss", yolo_loss="YOLOv3Loss",
spp=False,
nms=MultiClassNMS( nms=MultiClassNMS(
score_threshold=0.01, score_threshold=0.01,
nms_top_k=1000, nms_top_k=1000,
...@@ -81,24 +84,45 @@ class YOLOv3Head(object): ...@@ -81,24 +84,45 @@ class YOLOv3Head(object):
self.prefix_name = weight_prefix_name self.prefix_name = weight_prefix_name
self.drop_block = drop_block self.drop_block = drop_block
self.iou_aware = iou_aware self.iou_aware = iou_aware
self.coord_conv = coord_conv
self.iou_aware_factor = iou_aware_factor self.iou_aware_factor = iou_aware_factor
self.block_size = block_size self.block_size = block_size
self.keep_prob = keep_prob self.keep_prob = keep_prob
self.use_spp = spp
if isinstance(nms, dict): if isinstance(nms, dict):
self.nms = MultiClassNMS(**nms) self.nms = MultiClassMatrixNMS(**nms)
self.downsample = downsample self.downsample = downsample
self.scale_x_y = scale_x_y self.scale_x_y = scale_x_y
self.clip_bbox = clip_bbox self.clip_bbox = clip_bbox
def _add_coord(self, input):
input_shape = fluid.layers.shape(input)
b = input_shape[0]
h = input_shape[2]
w = input_shape[3]
x_range = fluid.layers.range(0, w, 1, 'float32') / (w - 1.)
x_range = x_range * 2. - 1.
x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
x_range.stop_gradient = True
y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
y_range.stop_gradient = True
return fluid.layers.concat([input, x_range, y_range], axis=1)
def _conv_bn(self, def _conv_bn(self,
input, input,
ch_out, ch_out,
filter_size, filter_size,
stride, stride,
padding, padding,
coord_conv=False,
act='leaky', act='leaky',
is_test=True, is_test=True,
name=None): name=None):
if coord_conv:
input = self._add_coord(input)
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
num_filters=ch_out, num_filters=ch_out,
...@@ -117,6 +141,7 @@ class YOLOv3Head(object): ...@@ -117,6 +141,7 @@ class YOLOv3Head(object):
out = fluid.layers.batch_norm( out = fluid.layers.batch_norm(
input=conv, input=conv,
act=None, act=None,
is_test=is_test,
param_attr=bn_param_attr, param_attr=bn_param_attr,
bias_attr=bn_bias_attr, bias_attr=bn_bias_attr,
moving_mean_name=bn_name + '.mean', moving_mean_name=bn_name + '.mean',
...@@ -126,6 +151,32 @@ class YOLOv3Head(object): ...@@ -126,6 +151,32 @@ class YOLOv3Head(object):
out = fluid.layers.leaky_relu(x=out, alpha=0.1) out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out return out
def _spp_module(self, input, is_test=True, name=""):
output1 = input
output2 = fluid.layers.pool2d(
input=output1,
pool_size=5,
pool_stride=1,
pool_padding=2,
ceil_mode=False,
pool_type='max')
output3 = fluid.layers.pool2d(
input=output1,
pool_size=9,
pool_stride=1,
pool_padding=4,
ceil_mode=False,
pool_type='max')
output4 = fluid.layers.pool2d(
input=output1,
pool_size=13,
pool_stride=1,
pool_padding=6,
ceil_mode=False,
pool_type='max')
output = fluid.layers.concat(input=[output1, output2, output3, output4], axis=1)
return output
def _detection_block(self, input, channel, is_test=True, name=None): def _detection_block(self, input, channel, is_test=True, name=None):
assert channel % 2 == 0, \ assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \ "channel {} cannot be divided by 2 in detection block {}" \
...@@ -139,8 +190,19 @@ class YOLOv3Head(object): ...@@ -139,8 +190,19 @@ class YOLOv3Head(object):
filter_size=1, filter_size=1,
stride=1, stride=1,
padding=0, padding=0,
coord_conv=True,
is_test=is_test, is_test=is_test,
name='{}.{}.0'.format(name, j)) name='{}.{}.0'.format(name, j))
if self.use_spp and channel == 512 and j == 1:
conv = self._spp_module(conv, is_test=is_test, name="spp")
conv = self._conv_bn(
conv,
512,
filter_size=1,
stride=1,
padding=0,
is_test=is_test,
name='{}.{}.spp.conv'.format(name, j))
conv = self._conv_bn( conv = self._conv_bn(
conv, conv,
channel * 2, channel * 2,
...@@ -168,6 +230,7 @@ class YOLOv3Head(object): ...@@ -168,6 +230,7 @@ class YOLOv3Head(object):
filter_size=1, filter_size=1,
stride=1, stride=1,
padding=0, padding=0,
coord_conv=True,
is_test=is_test, is_test=is_test,
name='{}.2'.format(name)) name='{}.2'.format(name))
tip = self._conv_bn( tip = self._conv_bn(
...@@ -176,6 +239,7 @@ class YOLOv3Head(object): ...@@ -176,6 +239,7 @@ class YOLOv3Head(object):
filter_size=3, filter_size=3,
stride=1, stride=1,
padding=1, padding=1,
coord_conv=True,
is_test=is_test, is_test=is_test,
name='{}.tip'.format(name)) name='{}.tip'.format(name))
return route, tip return route, tip
......
...@@ -30,9 +30,33 @@ __all__ = [ ...@@ -30,9 +30,33 @@ __all__ = [
'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead', 'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm', 'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner' 'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner', 'MultiClassMatrixNMS'
] ]
@register
@serializable
class MultiClassMatrixNMS(object):
__op__ = fluid.layers.matrix_nms
__append_doc__ = True
def __init__(self,
score_threshold=.05,
post_threshold=.01,
nms_top_k=-1,
keep_top_k=100,
use_gaussian=False,
gaussian_sigma=2.0,
normalized=False,
background_label=0):
super(MultiClassMatrixNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.score_threshold = score_threshold
self.post_threshold = post_threshold
self.use_gaussian = use_gaussian
self.normalized = normalized
self.background_label = background_label
def _conv_offset(input, filter_size, stride, padding, act=None, name=None): def _conv_offset(input, filter_size, stride, padding, act=None, name=None):
out_channel = filter_size * filter_size * 3 out_channel = filter_size * filter_size * 3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册