提交 d98eef2e 编写于 作者: D dengkaipeng

add CARAFE upsample.

上级 8a95c4b2
architecture: YOLOv3
use_gpu: true
max_iters: 500000
log_smooth_window: 20
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
weights: output/yolov3_darknet/model_final
num_classes: 80
use_fine_grained_loss: false
YOLOv3:
backbone: DarkNet
yolo_head: YOLOv3Head
DarkNet:
norm_type: sync_bn
norm_decay: 0.
depth: 53
YOLOv3Head:
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
norm_decay: 0.
yolo_loss: YOLOv3Loss
upsample:
type: carafe
scale: 2
kernel_size: 3
group_size: 5
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
YOLOv3Loss:
# batch_size here is used only by the fine-grained loss; it does not
# set the training batch size. The training batch size is configured
# by TrainReader.batch_size in configs/yolov3_reader.yml, and this
# value should be kept equal to TrainReader.batch_size.
batch_size: 8
ignore_thresh: 0.7
label_smooth: true
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 400000
- 450000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'yolov3_reader.yml'
......@@ -20,7 +20,7 @@ from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import MultiClassNMS, CARAFEUpsample
from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from ppdet.core.workspace import register
from ppdet.modeling.ops import DropBlock
......@@ -55,6 +55,7 @@ class YOLOv3Head(object):
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
upsample='nearest',
drop_block=False,
iou_aware=False,
iou_aware_factor=0.4,
......@@ -75,6 +76,7 @@ class YOLOv3Head(object):
self.norm_decay = norm_decay
self.num_classes = num_classes
self.anchor_masks = anchor_masks
self.upsample = upsample
self._parse_anchors(anchors)
self.yolo_loss = yolo_loss
self.nms = nms
......@@ -180,9 +182,27 @@ class YOLOv3Head(object):
name='{}.tip'.format(name))
return route, tip
def _upsample(self, input, scale=2, upsample='nearest', name=None):
    """Upsample a feature map with the configured method.

    Args:
        input: feature map to upsample.
        scale: upsampling factor. Used directly for nearest-neighbour
            resizing; for CARAFE it only fills in a ``scale`` entry
            that the config dict does not provide.
        upsample: either the string ``'nearest'`` or a dict whose
            ``'type'`` entry is ``'carafe'``. The remaining entries of
            the dict are forwarded to ``CARAFEUpsample`` as keyword
            arguments.
        name: name passed to ``resize_nearest``, or used as the
            ``name`` of the ``CARAFEUpsample`` instance.

    Returns:
        The upsampled feature map.

    Raises:
        AssertionError: if ``upsample`` is neither ``'nearest'`` nor a
            dict, or if its ``'type'`` is not a supported method.
        KeyError: if ``upsample`` is a dict without a ``'type'`` entry.
    """
    # Handle the plain string first: a str has no ``.copy()``, so it
    # must never reach the dict handling below.
    if upsample == 'nearest':
        return fluid.layers.resize_nearest(
            input=input, scale=float(scale), name=name)

    assert isinstance(
        upsample, dict), "Unknown upsample method: {}".format(upsample)
    assert upsample['type'] in [
        'carafe'
    ], 'Unknown upsample type {}'.format(upsample['type'])

    # Work on a copy: the same config dict is passed in again for every
    # upsampling stage, so 'type' must not be popped from the original.
    carafe_args = dict(upsample)
    carafe_args.pop('type')
    # An explicit 'scale' in the config wins; otherwise honour the
    # ``scale`` argument instead of silently ignoring it.
    carafe_args.setdefault('scale', scale)
    carafe_args['name'] = name
    return CARAFEUpsample(**carafe_args)(input)
def _parse_anchors(self, anchors):
......@@ -268,7 +288,10 @@ class YOLOv3Head(object):
is_test=(not is_train),
name=self.prefix_name + "yolo_transition.{}".format(i))
# upsample
route = self._upsample(route)
route = self._upsample(
route,
upsample=self.upsample,
name="yolo_upsample.{}".format(i))
return outputs
......
......@@ -26,11 +26,26 @@ from ppdet.core.workspace import register, serializable
from ppdet.utils.bbox_utils import bbox_overlaps, box_to_delta
# Names exported by ``from <this module> import *``. One entry per line,
# each terminated by a comma, so that adjacent string literals can never
# be concatenated into a single bogus name.
__all__ = [
    'AnchorGenerator',
    'AnchorGrid',
    'DropBlock',
    'RPNTargetAssign',
    'GenerateProposals',
    'MultiClassNMS',
    'BBoxAssigner',
    'MaskAssigner',
    'RoIAlign',
    'RoIPool',
    'MultiBoxHead',
    'SSDLiteMultiBoxHead',
    'SSDOutputDecoder',
    'RetinaTargetAssign',
    'RetinaOutputDecoder',
    'ConvNorm',
    'DeformConvNorm',
    'MultiClassSoftNMS',
    'LibraBBoxAssigner',
    'CARAFEUpsample',
]
......@@ -53,6 +68,31 @@ def _conv_offset(input, filter_size, stride, padding, act=None, name=None):
return out
def _conv_bn(input, ch_out, filter_size, stride, padding, act=None, name=None):
    """Apply a bias-free ``conv2d`` followed by ``batch_norm``.

    Args:
        input: input to the convolution.
        ch_out: number of convolution filters.
        filter_size: convolution filter size.
        stride: convolution stride.
        padding: convolution padding.
        act: activation applied by the batch-norm layer; the
            convolution itself is always created without activation.
        name: prefix for all parameter names. It is concatenated with
            string suffixes, so it has to be a string despite the
            ``None`` default.

    Returns:
        The output of the batch-norm layer.
    """
    conv_weights = ParamAttr(name=name + ".conv.weights")
    conv_out = fluid.layers.conv2d(
        input=input,
        num_filters=ch_out,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        act=None,
        param_attr=conv_weights,
        bias_attr=False)

    # Batch-norm scale and offset are created with a zero L2 decay factor.
    bn_prefix = name + ".bn"
    scale_attr = ParamAttr(
        regularizer=L2Decay(0.), name=bn_prefix + '.scale')
    offset_attr = ParamAttr(
        regularizer=L2Decay(0.), name=bn_prefix + '.offset')

    return fluid.layers.batch_norm(
        input=conv_out,
        act=act,
        param_attr=scale_attr,
        bias_attr=offset_attr,
        moving_mean_name=bn_prefix + '.mean',
        moving_variance_name=bn_prefix + '.var')
def DeformConvNorm(input,
num_filters,
filter_size,
......@@ -1497,3 +1537,57 @@ class RetinaOutputDecoder(object):
self.nms_top_k = pre_nms_top_n
self.keep_top_k = detections_per_im
self.nms_eta = nms_eta
@register
@serializable
class CARAFEUpsample(object):
    """Upsampling op that re-weights neighbourhoods with predicted kernels.

    Args:
        scale (int): upsampling factor. Used as the pixel-shuffle
            upscale factor, the nearest-resize scale and the dilation
            of the unfold step.
        mid_channels (int): number of output channels of the 1x1
            ``.compresser`` convolution.
        kernel_size (int): filter size of the ``.encoder`` convolution.
        group_size (int): window size passed to ``unfold``. The encoder
            emits ``(scale * group_size) ** 2`` channels.
        name (str): prefix for parameter names. It is concatenated with
            string suffixes in ``__call__``, so it has to be a string
            by the time the op is called.
    """

    def __init__(self,
                 scale=2,
                 mid_channels=64,
                 kernel_size=3,
                 group_size=3,
                 name=None):
        super(CARAFEUpsample, self).__init__()
        self.scale = scale
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.name = name

    def __call__(self, input):
        """Build the upsampling sub-graph for ``input`` and return its output.

        ``input`` is indexed as a 4-D tensor (dims 0, 2 and 3 are read
        at run time) and its channel count ``input.shape[1]`` must be
        convertible with ``int()`` when the graph is built.
        """
        scale = self.scale
        window = self.group_size

        # Kernel prediction branch: 1x1 conv, then the encoder conv.
        compressed = _conv_bn(
            input,
            self.mid_channels,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu',
            name=self.name + '.compresser')
        encoder_padding = (self.kernel_size - 1) // 2
        kernels = _conv_bn(
            compressed, (scale * window)**2,
            filter_size=self.kernel_size,
            stride=1,
            padding=encoder_padding,
            act=None,
            name=self.name + '.encoder')
        kernels = fluid.layers.pixel_shuffle(kernels, upscale_factor=scale)
        # Normalise the predicted weights along axis 1, then add an axis
        # at position 1 so they broadcast against the reshaped patches.
        kernels = fluid.layers.softmax(kernels, axis=1)
        kernels = fluid.layers.unsqueeze(kernels, axes=[1])

        # Feature branch: nearest upsampling, then sliding-window patches.
        patches = fluid.layers.resize_nearest(input, scale=float(scale))
        unfold_padding = (window - 1) // 2 * scale
        patches = fluid.layers.unfold(
            patches, window, dilations=scale, paddings=unfold_padding)

        # Batch and spatial sizes are taken from the run-time shape.
        runtime_shape = fluid.layers.shape(input)
        batch = runtime_shape[0]
        height = runtime_shape[2]
        width = runtime_shape[3]
        channels = int(input.shape[1])
        patches = fluid.layers.reshape(
            patches, [batch, channels, -1, height * scale, width * scale])

        # Weighted sum over the window axis.
        return fluid.layers.reduce_sum(kernels * patches, dim=2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册