“b878027c9ab2fc1db938a64fabc25885df0ed0a1”上不存在“parakeet/git@gitcode.net:paddlepaddle/DeepSpeech.git”
提交 beaa62a7 编写于 作者: L longxiang

update yolov3

上级 a66dfe9c
architecture: YOLOv3
use_gpu: true
max_iters: 500000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: ResNet
yolo_head: YOLOv3Head
use_fine_grained_loss: true
ResNet:
norm_type: sync_bn
freeze_at: 0
freeze_norm: false
norm_decay: 0.
depth: 50
feature_maps: [3, 4, 5]
variant: d
dcn_v2_stages: [5]
YOLOv3Head:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
coord_conv: true
iou_aware: true
iou_aware_factor: 0.4
scale_x_y: 1.05
spp: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
# nms_threshold: 0.45
# nms_top_k: 1000
normalized: false
score_threshold: 0.01
drop_block: true
YOLOv3Loss:
batch_size: 24
ignore_thresh: 0.7
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
iou_aware_loss: IouAwareLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
IouAwareLoss:
loss_weight: 1.0
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.00333
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 400000
- 450000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo_lb/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: ResNet
yolo_head: YOLOv3Head
use_fine_grained_loss: true
ResNet:
norm_type: sync_bn
freeze_at: 0
freeze_norm: false
norm_decay: 0.
depth: 50
feature_maps: [3, 4, 5]
variant: d
dcn_v2_stages: [5]
YOLOv3Head:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
coord_conv: true
iou_aware: true
iou_aware_factor: 0.4
scale_x_y: 1.05
spp: true
yolo_loss: YOLOv3Loss
nms:
background_label: -1
keep_top_k: 100
# nms_threshold: 0.45
# nms_top_k: 1000
normalized: false
score_threshold: 0.01
drop_block: true
YOLOv3Loss:
batch_size: 24
ignore_thresh: 0.7
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
iou_aware_loss: IouAwareLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
IouAwareLoss:
loss_weight: 1.0
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 150000
- 200000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: True
- !MixupImage
alpha: 1.5
beta: 1.5
- !ColorDistort {}
- !RandomExpand
fill_value: [123.675, 116.28, 103.53]
- !RandomCrop {}
- !RandomFlipImage
is_normalized: false
- !NormalizeBox {}
- !PadBox
num_max_boxes: 50
- !BboxXYXY2XYWH {}
batch_transforms:
- !RandomShape
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
# Gt2YoloTarget is only used when use_fine_grained_loss set as true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set as false
- !Gt2YoloTarget
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
downsample_ratios: [32, 16, 8]
batch_size: 24
shuffle: true
# mixup_epoch: 250
mixup_epoch: 25000
drop_last: true
worker_num: 8
bufsize: 4
use_process: true
EvalReader:
inputs_def:
fields: ['image', 'im_size', 'im_id']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 608
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !PadBox
num_max_boxes: 50
- !Permute
to_bgr: false
channel_first: True
batch_size: 8
drop_empty: false
worker_num: 8
bufsize: 4
TestReader:
inputs_def:
image_shape: [3, 608, 608]
fields: ['image', 'im_size', 'im_id']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 608
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
batch_size: 1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle import fluid from paddle import fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import MultiClassNMS, MultiClassSoftNMS from ppdet.modeling.ops import MultiClassNMS, MultiClassSoftNMS
from ppdet.modeling.losses.yolo_loss import YOLOv3Loss from ppdet.modeling.ops import MultiClassMatrixNMS
from ppdet.core.workspace import register from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from ppdet.modeling.ops import DropBlock from ppdet.core.workspace import register
from .iou_aware import get_iou_aware_score from ppdet.modeling.ops import DropBlock
try: from .iou_aware import get_iou_aware_score
from collections.abc import Sequence try:
except Exception: from collections.abc import Sequence
from collections import Sequence except Exception:
from ppdet.utils.check import check_version from collections import Sequence
from ppdet.utils.check import check_version
__all__ = ['YOLOv3Head', 'YOLOv4Head']
__all__ = ['YOLOv3Head', 'YOLOv4Head']
@register
class YOLOv3Head(object): @register
""" class YOLOv3Head(object):
Head block for YOLOv3 network """
Head block for YOLOv3 network
Args:
norm_decay (float): weight decay for normalization layer weights Args:
num_classes (int): number of output classes norm_decay (float): weight decay for normalization layer weights
anchors (list): anchors num_classes (int): number of output classes
anchor_masks (list): anchor masks anchors (list): anchors
nms (object): an instance of `MultiClassNMS` anchor_masks (list): anchor masks
""" nms (object): an instance of `MultiClassNMS`
__inject__ = ['yolo_loss', 'nms'] """
__shared__ = ['num_classes', 'weight_prefix_name'] __inject__ = ['yolo_loss', 'nms']
__shared__ = ['num_classes', 'weight_prefix_name']
def __init__(self,
norm_decay=0., def __init__(self,
num_classes=80, norm_decay=0.,
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], num_classes=80,
[59, 119], [116, 90], [156, 198], [373, 326]], anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]], [59, 119], [116, 90], [156, 198], [373, 326]],
drop_block=False, anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
iou_aware=False, drop_block=False,
iou_aware_factor=0.4, coord_conv=False,
block_size=3, iou_aware=False,
keep_prob=0.9, iou_aware_factor=0.4,
yolo_loss="YOLOv3Loss", block_size=3,
nms=MultiClassNMS( keep_prob=0.9,
score_threshold=0.01, yolo_loss="YOLOv3Loss",
nms_top_k=1000, spp=False,
keep_top_k=100, nms=MultiClassNMS(
nms_threshold=0.45, score_threshold=0.01,
background_label=-1).__dict__, nms_top_k=1000,
weight_prefix_name='', keep_top_k=100,
downsample=[32, 16, 8], nms_threshold=0.45,
scale_x_y=1.0, background_label=-1).__dict__,
clip_bbox=True): weight_prefix_name='',
check_version('2.0.0') downsample=[32, 16, 8],
self.norm_decay = norm_decay scale_x_y=1.0,
self.num_classes = num_classes clip_bbox=True):
self.anchor_masks = anchor_masks check_version('2.0.0')
self._parse_anchors(anchors) self.norm_decay = norm_decay
self.yolo_loss = yolo_loss self.num_classes = num_classes
self.nms = nms self.anchor_masks = anchor_masks
self.prefix_name = weight_prefix_name self._parse_anchors(anchors)
self.drop_block = drop_block self.yolo_loss = yolo_loss
self.iou_aware = iou_aware self.nms = nms
self.iou_aware_factor = iou_aware_factor self.prefix_name = weight_prefix_name
self.block_size = block_size self.drop_block = drop_block
self.keep_prob = keep_prob self.iou_aware = iou_aware
if isinstance(nms, dict): self.coord_conv = coord_conv
self.nms = MultiClassNMS(**nms) self.iou_aware_factor = iou_aware_factor
self.downsample = downsample self.block_size = block_size
self.scale_x_y = scale_x_y self.keep_prob = keep_prob
self.clip_bbox = clip_bbox self.use_spp = spp
if isinstance(nms, dict):
def _conv_bn(self, self.nms = MultiClassMatrixNMS(**nms)
input, self.downsample = downsample
ch_out, self.scale_x_y = scale_x_y
filter_size, self.clip_bbox = clip_bbox
stride,
padding, def _add_coord(self, input):
act='leaky', input_shape = fluid.layers.shape(input)
is_test=True, b = input_shape[0]
name=None): h = input_shape[2]
conv = fluid.layers.conv2d( w = input_shape[3]
input=input,
num_filters=ch_out, x_range = fluid.layers.range(0, w, 1, 'float32') / (w - 1.)
filter_size=filter_size, x_range = x_range * 2. - 1.
stride=stride, x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
padding=padding, x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
act=None, x_range.stop_gradient = True
param_attr=ParamAttr(name=name + ".conv.weights"), y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
bias_attr=False) y_range.stop_gradient = True
bn_name = name + ".bn" return fluid.layers.concat([input, x_range, y_range], axis=1)
bn_param_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale') def _conv_bn(self,
bn_bias_attr = ParamAttr( input,
regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset') ch_out,
out = fluid.layers.batch_norm( filter_size,
input=conv, stride,
act=None, padding,
param_attr=bn_param_attr, coord_conv=False,
bias_attr=bn_bias_attr, act='leaky',
moving_mean_name=bn_name + '.mean', is_test=True,
moving_variance_name=bn_name + '.var') name=None):
if coord_conv:
if act == 'leaky': input = self._add_coord(input)
out = fluid.layers.leaky_relu(x=out, alpha=0.1) conv = fluid.layers.conv2d(
return out input=input,
num_filters=ch_out,
def _detection_block(self, input, channel, is_test=True, name=None): filter_size=filter_size,
assert channel % 2 == 0, \ stride=stride,
"channel {} cannot be divided by 2 in detection block {}" \ padding=padding,
.format(channel, name) act=None,
param_attr=ParamAttr(name=name + ".conv.weights"),
conv = input bias_attr=False)
for j in range(2):
conv = self._conv_bn( bn_name = name + ".bn"
conv, bn_param_attr = ParamAttr(
channel, regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
filter_size=1, bn_bias_attr = ParamAttr(
stride=1, regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
padding=0, out = fluid.layers.batch_norm(
is_test=is_test, input=conv,
name='{}.{}.0'.format(name, j)) act=None,
conv = self._conv_bn( is_test=is_test,
conv, param_attr=bn_param_attr,
channel * 2, bias_attr=bn_bias_attr,
filter_size=3, moving_mean_name=bn_name + '.mean',
stride=1, moving_variance_name=bn_name + '.var')
padding=1,
is_test=is_test, if act == 'leaky':
name='{}.{}.1'.format(name, j)) out = fluid.layers.leaky_relu(x=out, alpha=0.1)
if self.drop_block and j == 0 and channel != 512: return out
conv = DropBlock(
conv, def _spp_module(self, input, is_test=True, name=""):
block_size=self.block_size, output1 = input
keep_prob=self.keep_prob, output2 = fluid.layers.pool2d(
is_test=is_test) input=output1,
pool_size=5,
if self.drop_block and channel == 512: pool_stride=1,
conv = DropBlock( pool_padding=2,
conv, ceil_mode=False,
block_size=self.block_size, pool_type='max')
keep_prob=self.keep_prob, output3 = fluid.layers.pool2d(
is_test=is_test) input=output1,
route = self._conv_bn( pool_size=9,
conv, pool_stride=1,
channel, pool_padding=4,
filter_size=1, ceil_mode=False,
stride=1, pool_type='max')
padding=0, output4 = fluid.layers.pool2d(
is_test=is_test, input=output1,
name='{}.2'.format(name)) pool_size=13,
tip = self._conv_bn( pool_stride=1,
route, pool_padding=6,
channel * 2, ceil_mode=False,
filter_size=3, pool_type='max')
stride=1, output = fluid.layers.concat(input=[output1, output2, output3, output4], axis=1)
padding=1, return output
is_test=is_test,
name='{}.tip'.format(name)) def _detection_block(self, input, channel, is_test=True, name=None):
return route, tip assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \
def _upsample(self, input, scale=2, name=None): .format(channel, name)
out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name) conv = input
return out for j in range(2):
conv = self._conv_bn(
def _parse_anchors(self, anchors): conv,
""" channel,
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors filter_size=1,
stride=1,
""" padding=0,
self.anchors = [] coord_conv=True,
self.mask_anchors = [] is_test=is_test,
name='{}.{}.0'.format(name, j))
assert len(anchors) > 0, "ANCHORS not set." if self.use_spp and channel == 512 and j == 1:
assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set." conv = self._spp_module(conv, is_test=is_test, name="spp")
conv = self._conv_bn(
for anchor in anchors: conv,
assert len(anchor) == 2, "anchor {} len should be 2".format(anchor) 512,
self.anchors.extend(anchor) filter_size=1,
stride=1,
anchor_num = len(anchors) padding=0,
for masks in self.anchor_masks: is_test=is_test,
self.mask_anchors.append([]) name='{}.{}.spp.conv'.format(name, j))
for mask in masks: conv = self._conv_bn(
assert mask < anchor_num, "anchor mask index overflow" conv,
self.mask_anchors[-1].extend(anchors[mask]) channel * 2,
filter_size=3,
def _get_outputs(self, input, is_train=True): stride=1,
""" padding=1,
Get YOLOv3 head output is_test=is_test,
name='{}.{}.1'.format(name, j))
Args: if self.drop_block and j == 0 and channel != 512:
input (list): List of Variables, output of backbone stages conv = DropBlock(
is_train (bool): whether in train or test mode conv,
block_size=self.block_size,
Returns: keep_prob=self.keep_prob,
outputs (list): Variables of each output layer is_test=is_test)
"""
if self.drop_block and channel == 512:
outputs = [] conv = DropBlock(
conv,
# get last out_layer_num blocks in reverse order block_size=self.block_size,
out_layer_num = len(self.anchor_masks) keep_prob=self.keep_prob,
blocks = input[-1:-out_layer_num - 1:-1] is_test=is_test)
route = self._conv_bn(
route = None conv,
for i, block in enumerate(blocks): channel,
if i > 0: # perform concat in first 2 detection_block filter_size=1,
block = fluid.layers.concat(input=[route, block], axis=1) stride=1,
route, tip = self._detection_block( padding=0,
block, coord_conv=True,
channel=512 // (2**i), is_test=is_test,
is_test=(not is_train), name='{}.2'.format(name))
name=self.prefix_name + "yolo_block.{}".format(i)) tip = self._conv_bn(
route,
# out channel number = mask_num * (5 + class_num) channel * 2,
if self.iou_aware: filter_size=3,
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6) stride=1,
else: padding=1,
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) coord_conv=True,
with fluid.name_scope('yolo_output'): is_test=is_test,
block_out = fluid.layers.conv2d( name='{}.tip'.format(name))
input=tip, return route, tip
num_filters=num_filters,
filter_size=1, def _upsample(self, input, scale=2, name=None):
stride=1, out = fluid.layers.resize_nearest(
padding=0, input=input, scale=float(scale), name=name)
act=None, return out
param_attr=ParamAttr(
name=self.prefix_name + def _parse_anchors(self, anchors):
"yolo_output.{}.conv.weights".format(i)), """
bias_attr=ParamAttr( Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
regularizer=L2Decay(0.),
name=self.prefix_name + """
"yolo_output.{}.conv.bias".format(i))) self.anchors = []
outputs.append(block_out) self.mask_anchors = []
if i < len(blocks) - 1: assert len(anchors) > 0, "ANCHORS not set."
# do not perform upsample in the last detection_block assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
route = self._conv_bn(
input=route, for anchor in anchors:
ch_out=256 // (2**i), assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
filter_size=1, self.anchors.extend(anchor)
stride=1,
padding=0, anchor_num = len(anchors)
is_test=(not is_train), for masks in self.anchor_masks:
name=self.prefix_name + "yolo_transition.{}".format(i)) self.mask_anchors.append([])
# upsample for mask in masks:
route = self._upsample(route) assert mask < anchor_num, "anchor mask index overflow"
self.mask_anchors[-1].extend(anchors[mask])
return outputs
def _get_outputs(self, input, is_train=True):
def get_loss(self, input, gt_box, gt_label, gt_score, targets): """
""" Get YOLOv3 head output
Get final loss of network of YOLOv3.
Args:
Args: input (list): List of Variables, output of backbone stages
input (list): List of Variables, output of backbone stages is_train (bool): whether in train or test mode
gt_box (Variable): The ground-truth boudding boxes.
gt_label (Variable): The ground-truth class labels. Returns:
gt_score (Variable): The ground-truth boudding boxes mixup scores. outputs (list): Variables of each output layer
targets ([Variables]): List of Variables, the targets for yolo """
loss calculatation.
outputs = []
Returns:
loss (Variable): The loss Variable of YOLOv3 network. # get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks)
""" blocks = input[-1:-out_layer_num - 1:-1]
outputs = self._get_outputs(input, is_train=True)
route = None
return self.yolo_loss(outputs, gt_box, gt_label, gt_score, targets, for i, block in enumerate(blocks):
self.anchors, self.anchor_masks, if i > 0: # perform concat in first 2 detection_block
self.mask_anchors, self.num_classes, block = fluid.layers.concat(input=[route, block], axis=1)
self.prefix_name) route, tip = self._detection_block(
block,
def get_prediction(self, input, im_size): channel=512 // (2**i),
""" is_test=(not is_train),
Get prediction result of YOLOv3 network name=self.prefix_name + "yolo_block.{}".format(i))
Args: # out channel number = mask_num * (5 + class_num)
input (list): List of Variables, output of backbone stages if self.iou_aware:
im_size (Variable): Variable of size([h, w]) of each image num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6)
else:
Returns: num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
pred (Variable): The prediction result after non-max suppress. with fluid.name_scope('yolo_output'):
block_out = fluid.layers.conv2d(
""" input=tip,
num_filters=num_filters,
outputs = self._get_outputs(input, is_train=False) filter_size=1,
stride=1,
boxes = [] padding=0,
scores = [] act=None,
for i, output in enumerate(outputs): param_attr=ParamAttr(
if self.iou_aware: name=self.prefix_name +
output = get_iou_aware_score(output, "yolo_output.{}.conv.weights".format(i)),
len(self.anchor_masks[i]), bias_attr=ParamAttr(
self.num_classes, regularizer=L2Decay(0.),
self.iou_aware_factor) name=self.prefix_name +
scale_x_y = self.scale_x_y if not isinstance( "yolo_output.{}.conv.bias".format(i)))
self.scale_x_y, Sequence) else self.scale_x_y[i] outputs.append(block_out)
box, score = fluid.layers.yolo_box(
x=output, if i < len(blocks) - 1:
img_size=im_size, # do not perform upsample in the last detection_block
anchors=self.mask_anchors[i], route = self._conv_bn(
class_num=self.num_classes, input=route,
conf_thresh=self.nms.score_threshold, ch_out=256 // (2**i),
downsample_ratio=self.downsample[i], filter_size=1,
name=self.prefix_name + "yolo_box" + str(i), stride=1,
clip_bbox=self.clip_bbox, padding=0,
scale_x_y=scale_x_y) is_test=(not is_train),
boxes.append(box) name=self.prefix_name + "yolo_transition.{}".format(i))
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1])) # upsample
route = self._upsample(route)
yolo_boxes = fluid.layers.concat(boxes, axis=1)
yolo_scores = fluid.layers.concat(scores, axis=2) return outputs
if type(self.nms) is MultiClassSoftNMS:
yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1]) def get_loss(self, input, gt_box, gt_label, gt_score, targets):
pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores) """
return {'bbox': pred} Get final loss of network of YOLOv3.
Args:
@register input (list): List of Variables, output of backbone stages
class YOLOv4Head(YOLOv3Head): gt_box (Variable): The ground-truth boudding boxes.
""" gt_label (Variable): The ground-truth class labels.
Head block for YOLOv4 network gt_score (Variable): The ground-truth boudding boxes mixup scores.
targets ([Variables]): List of Variables, the targets for yolo
Args: loss calculatation.
anchors (list): anchors
anchor_masks (list): anchor masks Returns:
nms (object): an instance of `MultiClassNMS` loss (Variable): The loss Variable of YOLOv3 network.
spp_stage (int): apply spp on which stage.
num_classes (int): number of output classes """
downsample (list): downsample ratio for each yolo_head outputs = self._get_outputs(input, is_train=True)
scale_x_y (list): scale the center point of bbox at each stage
""" return self.yolo_loss(outputs, gt_box, gt_label, gt_score, targets,
__inject__ = ['nms', 'yolo_loss'] self.anchors, self.anchor_masks,
__shared__ = ['num_classes', 'weight_prefix_name'] self.mask_anchors, self.num_classes,
self.prefix_name)
def __init__(self,
anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55], def get_prediction(self, input, im_size):
[72, 146], [142, 110], [192, 243], [459, 401]], """
anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]], Get prediction result of YOLOv3 network
nms=MultiClassNMS(
score_threshold=0.01, Args:
nms_top_k=-1, input (list): List of Variables, output of backbone stages
keep_top_k=-1, im_size (Variable): Variable of size([h, w]) of each image
nms_threshold=0.45,
background_label=-1).__dict__, Returns:
spp_stage=5, pred (Variable): The prediction result after non-max suppress.
num_classes=80,
weight_prefix_name='', """
downsample=[8, 16, 32],
scale_x_y=1.0, outputs = self._get_outputs(input, is_train=False)
yolo_loss="YOLOv3Loss",
iou_aware=False, boxes = []
iou_aware_factor=0.4, scores = []
clip_bbox=False): for i, output in enumerate(outputs):
super(YOLOv4Head, self).__init__( if self.iou_aware:
anchors=anchors, output = get_iou_aware_score(output,
anchor_masks=anchor_masks, len(self.anchor_masks[i]),
nms=nms, self.num_classes,
num_classes=num_classes, self.iou_aware_factor)
weight_prefix_name=weight_prefix_name, scale_x_y = self.scale_x_y if not isinstance(
downsample=downsample, self.scale_x_y, Sequence) else self.scale_x_y[i]
scale_x_y=scale_x_y, box, score = fluid.layers.yolo_box(
yolo_loss=yolo_loss, x=output,
iou_aware=iou_aware, img_size=im_size,
iou_aware_factor=iou_aware_factor, anchors=self.mask_anchors[i],
clip_bbox=clip_bbox) class_num=self.num_classes,
self.spp_stage = spp_stage conf_thresh=self.nms.score_threshold,
downsample_ratio=self.downsample[i],
def _upsample(self, input, scale=2, name=None): name=self.prefix_name + "yolo_box" + str(i),
out = fluid.layers.resize_nearest( clip_bbox=self.clip_bbox,
input=input, scale=float(scale), name=name) scale_x_y=scale_x_y)
return out boxes.append(box)
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
def max_pool(self, input, size):
pad = [(size - 1) // 2] * 2 yolo_boxes = fluid.layers.concat(boxes, axis=1)
return fluid.layers.pool2d(input, size, 'max', pool_padding=pad) yolo_scores = fluid.layers.concat(scores, axis=2)
if type(self.nms) is MultiClassSoftNMS:
def spp(self, input): yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
branch_a = self.max_pool(input, 13) pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
branch_b = self.max_pool(input, 9) return {'bbox': pred}
branch_c = self.max_pool(input, 5)
out = fluid.layers.concat([branch_a, branch_b, branch_c, input], axis=1)
return out @register
class YOLOv4Head(YOLOv3Head):
def stack_conv(self, """
input, Head block for YOLOv4 network
ch_list=[512, 1024, 512],
filter_list=[1, 3, 1], Args:
stride=1, anchors (list): anchors
name=None): anchor_masks (list): anchor masks
conv = input nms (object): an instance of `MultiClassNMS`
for i, (ch_out, f_size) in enumerate(zip(ch_list, filter_list)): spp_stage (int): apply spp on which stage.
padding = 1 if f_size == 3 else 0 num_classes (int): number of output classes
conv = self._conv_bn( downsample (list): downsample ratio for each yolo_head
conv, scale_x_y (list): scale the center point of bbox at each stage
ch_out=ch_out, """
filter_size=f_size, __inject__ = ['nms', 'yolo_loss']
stride=stride, __shared__ = ['num_classes', 'weight_prefix_name']
padding=padding,
name='{}.{}'.format(name, i)) def __init__(self,
return conv anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
[72, 146], [142, 110], [192, 243], [459, 401]],
def spp_module(self, input, name=None): anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
conv = self.stack_conv(input, name=name + '.stack_conv.0') nms=MultiClassNMS(
spp_out = self.spp(conv) score_threshold=0.01,
conv = self.stack_conv(spp_out, name=name + '.stack_conv.1') nms_top_k=-1,
return conv keep_top_k=-1,
nms_threshold=0.45,
def pan_module(self, input, filter_list, name=None): background_label=-1).__dict__,
for i in range(1, len(input)): spp_stage=5,
ch_out = input[i].shape[1] // 2 num_classes=80,
conv_left = self._conv_bn( weight_prefix_name='',
input[i], downsample=[8, 16, 32],
ch_out=ch_out, scale_x_y=1.0,
filter_size=1, yolo_loss="YOLOv3Loss",
stride=1, iou_aware=False,
padding=0, iou_aware_factor=0.4,
name=name + '.{}.left'.format(i)) clip_bbox=False):
ch_out = input[i - 1].shape[1] // 2 super(YOLOv4Head, self).__init__(
conv_right = self._conv_bn( anchors=anchors,
input[i - 1], anchor_masks=anchor_masks,
ch_out=ch_out, nms=nms,
filter_size=1, num_classes=num_classes,
stride=1, weight_prefix_name=weight_prefix_name,
padding=0, downsample=downsample,
name=name + '.{}.right'.format(i)) scale_x_y=scale_x_y,
conv_right = self._upsample(conv_right) yolo_loss=yolo_loss,
pan_out = fluid.layers.concat([conv_left, conv_right], axis=1) iou_aware=iou_aware,
ch_list = [pan_out.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]] iou_aware_factor=iou_aware_factor,
input[i] = self.stack_conv( clip_bbox=clip_bbox)
pan_out, self.spp_stage = spp_stage
ch_list=ch_list,
filter_list=filter_list, def _upsample(self, input, scale=2, name=None):
name=name + '.stack_conv.{}'.format(i)) out = fluid.layers.resize_nearest(
return input input=input, scale=float(scale), name=name)
return out
def _get_outputs(self, input, is_train=True):
outputs = [] def max_pool(self, input, size):
filter_list = [1, 3, 1, 3, 1] pad = [(size - 1) // 2] * 2
spp_stage = len(input) - self.spp_stage return fluid.layers.pool2d(input, size, 'max', pool_padding=pad)
# get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks) def spp(self, input):
blocks = input[-1:-out_layer_num - 1:-1] branch_a = self.max_pool(input, 13)
blocks[spp_stage] = self.spp_module( branch_b = self.max_pool(input, 9)
blocks[spp_stage], name=self.prefix_name + "spp_module") branch_c = self.max_pool(input, 5)
blocks = self.pan_module( out = fluid.layers.concat([branch_a, branch_b, branch_c, input], axis=1)
blocks, return out
filter_list=filter_list,
name=self.prefix_name + 'pan_module') def stack_conv(self,
input,
# reverse order back to input ch_list=[512, 1024, 512],
blocks = blocks[::-1] filter_list=[1, 3, 1],
stride=1,
route = None name=None):
for i, block in enumerate(blocks): conv = input
if i > 0: # perform concat in first 2 detection_block for i, (ch_out, f_size) in enumerate(zip(ch_list, filter_list)):
route = self._conv_bn( padding = 1 if f_size == 3 else 0
route, conv = self._conv_bn(
ch_out=route.shape[1] * 2, conv,
filter_size=3, ch_out=ch_out,
stride=2, filter_size=f_size,
padding=1, stride=stride,
name=self.prefix_name + 'yolo_block.route.{}'.format(i)) padding=padding,
block = fluid.layers.concat(input=[route, block], axis=1) name='{}.{}'.format(name, i))
ch_list = [block.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]] return conv
block = self.stack_conv(
block, def spp_module(self, input, name=None):
ch_list=ch_list, conv = self.stack_conv(input, name=name + '.stack_conv.0')
filter_list=filter_list, spp_out = self.spp(conv)
name=self.prefix_name + conv = self.stack_conv(spp_out, name=name + '.stack_conv.1')
'yolo_block.stack_conv.{}'.format(i)) return conv
route = block
def pan_module(self, input, filter_list, name=None):
block_out = self._conv_bn( for i in range(1, len(input)):
block, ch_out = input[i].shape[1] // 2
ch_out=block.shape[1] * 2, conv_left = self._conv_bn(
filter_size=3, input[i],
stride=1, ch_out=ch_out,
padding=1, filter_size=1,
name=self.prefix_name + 'yolo_output.{}.conv.0'.format(i)) stride=1,
padding=0,
if self.iou_aware: name=name + '.{}.left'.format(i))
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6) ch_out = input[i - 1].shape[1] // 2
else: conv_right = self._conv_bn(
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) input[i - 1],
block_out = fluid.layers.conv2d( ch_out=ch_out,
input=block_out, filter_size=1,
num_filters=num_filters, stride=1,
filter_size=1, padding=0,
stride=1, name=name + '.{}.right'.format(i))
padding=0, conv_right = self._upsample(conv_right)
act=None, pan_out = fluid.layers.concat([conv_left, conv_right], axis=1)
param_attr=ParamAttr(name=self.prefix_name + ch_list = [pan_out.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]]
"yolo_output.{}.conv.1.weights".format(i)), input[i] = self.stack_conv(
bias_attr=ParamAttr( pan_out,
regularizer=L2Decay(0.), ch_list=ch_list,
name=self.prefix_name + filter_list=filter_list,
"yolo_output.{}.conv.1.bias".format(i))) name=name + '.stack_conv.{}'.format(i))
outputs.append(block_out) return input
return outputs def _get_outputs(self, input, is_train=True):
outputs = []
filter_list = [1, 3, 1, 3, 1]
spp_stage = len(input) - self.spp_stage
# get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks)
blocks = input[-1:-out_layer_num - 1:-1]
blocks[spp_stage] = self.spp_module(
blocks[spp_stage], name=self.prefix_name + "spp_module")
blocks = self.pan_module(
blocks,
filter_list=filter_list,
name=self.prefix_name + 'pan_module')
# reverse order back to input
blocks = blocks[::-1]
route = None
for i, block in enumerate(blocks):
if i > 0: # perform concat in first 2 detection_block
route = self._conv_bn(
route,
ch_out=route.shape[1] * 2,
filter_size=3,
stride=2,
padding=1,
name=self.prefix_name + 'yolo_block.route.{}'.format(i))
block = fluid.layers.concat(input=[route, block], axis=1)
ch_list = [block.shape[1] // 2 * k for k in [1, 2, 1, 2, 1]]
block = self.stack_conv(
block,
ch_list=ch_list,
filter_list=filter_list,
name=self.prefix_name +
'yolo_block.stack_conv.{}'.format(i))
route = block
block_out = self._conv_bn(
block,
ch_out=block.shape[1] * 2,
filter_size=3,
stride=1,
padding=1,
name=self.prefix_name + 'yolo_output.{}.conv.0'.format(i))
if self.iou_aware:
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 6)
else:
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
block_out = fluid.layers.conv2d(
input=block_out,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name=self.prefix_name +
"yolo_output.{}.conv.1.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name=self.prefix_name +
"yolo_output.{}.conv.1.bias".format(i)))
outputs.append(block_out)
return outputs
...@@ -30,9 +30,33 @@ __all__ = [ ...@@ -30,9 +30,33 @@ __all__ = [
'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead', 'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm', 'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner' 'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner', 'MultiClassMatrixNMS'
] ]
@register
@serializable
class MultiClassMatrixNMS(object):
__op__ = fluid.layers.matrix_nms
__append_doc__ = True
def __init__(self,
score_threshold=.05,
post_threshold=.01,
nms_top_k=-1,
keep_top_k=100,
use_gaussian=False,
gaussian_sigma=2.0,
normalized=False,
background_label=0):
super(MultiClassMatrixNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.score_threshold = score_threshold
self.post_threshold = post_threshold
self.use_gaussian = use_gaussian
self.normalized = normalized
self.background_label = background_label
def _conv_offset(input, filter_size, stride, padding, act=None, name=None): def _conv_offset(input, filter_size, stride, padding, act=None, name=None):
out_channel = filter_size * filter_size * 3 out_channel = filter_size * filter_size * 3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册