提交 ce96e567 编写于 作者: Y Yuan Gao 提交者: wangguanzhong

Add group norm (#3140)

* add a global norm_type flag

* add group norm on fpn

* add gn on box head

* add faster rcnn fpn 50 gn config

* add mask branch norm

* update configs
上级 bdd4bc8a
......@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.02
......
......@@ -85,7 +85,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.02
......
......@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.02
......
......@@ -88,7 +88,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -70,7 +70,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -93,7 +93,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -73,7 +73,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -96,7 +96,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -83,7 +83,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -83,7 +83,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.02
......
......@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.02
......
......@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.02
......
......@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
# Faster R-CNN with a ResNet-50 backbone and FPN, using group normalization
# ('gn') on the FPN and on the XConvNormHead box head.
architecture: FasterRCNN
train_feed: FasterRCNNTrainFeed
eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed
max_iters: 180000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/faster_rcnn_r50_fpn_gn/model_final
num_classes: 81
# Components that make up the detector.
FasterRCNN:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: BBoxAssigner
# Backbone keeps 'affine_channel' normalization; only FPN and the box head use 'gn'.
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: affine_channel
# FPN with group norm.
FPN:
min_level: 2
max_level: 6
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
norm_type: gn
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
min_level: 2
max_level: 6
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_positive_overlap: 0.7
rpn_negative_overlap: 0.3
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 1000
post_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
min_level: 2
max_level: 5
box_resolution: 7
sampling_ratio: 2
BBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_lo: 0.0
bg_thresh_hi: 0.5
fg_fraction: 0.25
fg_thresh: 0.5
# Box head built on XConvNormHead, configured with group norm below.
BBoxHead:
head: XConvNormHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
XConvNormHead:
norm_type: gn
# Piecewise decay (gamma 0.1) at iterations 120000 and 160000, combined with a
# 1000-step linear warmup starting from factor 0.1.
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [120000, 160000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
# Momentum optimizer with L2 regularization.
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
# Training data feed (COCO train2017).
FasterRCNNTrainFeed:
batch_size: 2
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_train2017.json
image_dir: train2017
batch_transforms:
- !PadBatch
pad_to_stride: 32
drop_last: false
num_workers: 16
# Evaluation data feed (COCO val2017).
FasterRCNNEvalFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
batch_transforms:
- !PadBatch
pad_to_stride: 32
# Test (inference) data feed.
FasterRCNNTestFeed:
batch_size: 1
dataset:
annotation: annotations/instances_val2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
drop_last: false
num_workers: 2
# Mask R-CNN with a ResNet-50 backbone and FPN, using group normalization
# ('gn') on the FPN, the mask head and the XConvNormHead box head.
architecture: MaskRCNN
train_feed: MaskRCNNTrainFeed
eval_feed: MaskRCNNEvalFeed
test_feed: MaskRCNNTestFeed
max_iters: 360000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights: output/mask_rcnn_r50_fpn_gn_2x/model_final/
metric: COCO
num_classes: 81
# Components that make up the detector.
MaskRCNN:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: BBoxAssigner
# Backbone keeps 'affine_channel' normalization; FPN and the heads use 'gn'.
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: affine_channel
# FPN with group norm.
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
norm_type: gn
FPNRPNHead:
anchor_generator:
aspect_ratios: [0.5, 1.0, 2.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 1000
post_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
sampling_ratio: 2
box_resolution: 7
mask_resolution: 14
# Mask head: 4 convolutions with group norm.
MaskHead:
dilation: 1
conv_dim: 256
num_convs: 4
resolution: 28
norm_type: gn
BBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
MaskAssigner:
resolution: 28
# Box head built on XConvNormHead, configured with group norm below.
BBoxHead:
head: XConvNormHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
XConvNormHead:
norm_type: gn
# Piecewise decay (gamma 0.1) at iterations 240000 and 320000, combined with a
# 500-step linear warmup starting from factor 1/3.
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [240000, 320000]
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
# Momentum optimizer with L2 regularization.
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
# Training data feed (COCO train2017).
MaskRCNNTrainFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_train2017.json
image_dir: train2017
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
# Evaluation data feed (COCO val2017).
MaskRCNNEvalFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
# Test (inference) data feed.
MaskRCNNTestFeed:
batch_size: 1
dataset:
annotation: dataset/coco/annotations/instances_val2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
......@@ -68,7 +68,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -91,7 +91,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -69,7 +69,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -92,7 +92,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -70,7 +70,7 @@ BBoxHead:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
resolution: 14
BBoxAssigner:
......
......@@ -71,7 +71,7 @@ BBoxHead:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
resolution: 14
BBoxAssigner:
......
......@@ -68,7 +68,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -91,7 +91,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -68,7 +68,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -91,7 +91,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -69,7 +69,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -92,7 +92,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead:
dilation: 1
num_chan_reduced: 256
conv_dim: 256
num_convs: 4
resolution: 28
......@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05
TwoFCHead:
num_chan: 1024
mlp_dim: 1024
LearningRate:
base_lr: 0.01
......
......@@ -67,8 +67,8 @@ def load(anno_path, sample_num=-1, with_background=True):
for img_id in img_ids:
img_anno = coco.loadImgs(img_id)[0]
im_fname = img_anno['file_name']
im_w = img_anno['width']
im_h = img_anno['height']
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
instances = coco.loadAnns(ins_anno_ids)
......@@ -85,8 +85,8 @@ def load(anno_path, sample_num=-1, with_background=True):
bboxes.append(inst)
else:
logger.warn(
'Found an invalid bbox in annotations: im_id: {}, area: {} x: {}, y: {}, h: {}, w: {}.'.
format(img_id, float(inst['area']), x, y, box_w, box_h))
'Found an invalid bbox in annotations: im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.
format(img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
......
......@@ -36,6 +36,7 @@ class DarkNet(object):
norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
norm_decay (float): weight decay for normalization layer weights
"""
__shared__ = ['norm_type']
def __init__(self, depth=53, norm_type='bn', norm_decay=0.):
assert depth in [53], "unsupported depth value"
......
......@@ -42,6 +42,7 @@ class FPN(object):
has_extra_convs (bool): whether has extra convolutions in higher levels
norm_type (str|None): normalization type, 'bn'/'sync_bn'/'affine_channel'/'gn'
"""
__shared__ = ['norm_type', 'freeze_norm']
def __init__(self,
num_chan=256,
......@@ -49,7 +50,9 @@ class FPN(object):
max_level=6,
spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. / 4.],
has_extra_convs=False,
norm_type=None):
norm_type=None,
freeze_norm=False):
self.freeze_norm = freeze_norm
self.num_chan = num_chan
self.min_level = min_level
self.max_level = max_level
......@@ -69,8 +72,9 @@ class FPN(object):
1,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=lateral_name,
bn_name=lateral_name)
norm_name=lateral_name)
else:
lateral = fluid.layers.conv2d(
body_input,
......@@ -120,8 +124,9 @@ class FPN(object):
1,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=fpn_inner_name,
bn_name=fpn_inner_name)
norm_name=fpn_inner_name)
else:
self.fpn_inner_output[0] = fluid.layers.conv2d(
body_input,
......@@ -155,8 +160,9 @@ class FPN(object):
3,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=fpn_name,
bn_name=fpn_name)
norm_name=fpn_name)
else:
fpn_output = fluid.layers.conv2d(
self.fpn_inner_output[i],
......
......@@ -37,6 +37,7 @@ class MobileNet(object):
with_extra_blocks (bool): if extra blocks should be added
extra_block_filters (list): number of filter for each extra block
"""
__shared__ = ['norm_type']
def __init__(self,
norm_type='bn',
......
......@@ -47,6 +47,7 @@ class ResNet(object):
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
"""
__shared__ = ['norm_type', 'freeze_norm']
def __init__(self,
depth=50,
......
......@@ -16,7 +16,6 @@ from numbers import Integral
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
......@@ -34,9 +33,11 @@ def ConvNorm(input,
groups=1,
norm_decay=0.,
norm_type='affine_channel',
norm_groups=32,
dilation=1,
freeze_norm=False,
act=None,
bn_name=None,
norm_name=None,
initializer=None,
name=None):
fan = num_filters
......@@ -45,7 +46,8 @@ def ConvNorm(input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
padding=((filter_size - 1) // 2) * dilation,
dilation=dilation,
groups=groups,
act=None,
param_attr=ParamAttr(
......@@ -55,11 +57,11 @@ def ConvNorm(input,
norm_lr = 0. if freeze_norm else 1.
pattr = ParamAttr(
name=bn_name + '_scale',
name=norm_name + '_scale',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
battr = ParamAttr(
name=bn_name + '_offset',
name=norm_name + '_offset',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
......@@ -68,14 +70,24 @@ def ConvNorm(input,
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
name=norm_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
moving_mean_name=norm_name + '_mean',
moving_variance_name=norm_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif norm_type == 'gn':
out = fluid.layers.group_norm(
input=conv,
act=act,
name=norm_name + '.output.1',
groups=norm_groups,
param_attr=pattr,
bias_attr=battr)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif norm_type == 'affine_channel':
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]],
......
......@@ -22,11 +22,13 @@ from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA
from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm
from ppdet.core.workspace import register, serializable
__all__ = ['BBoxHead', 'TwoFCHead']
__all__ = ['BBoxHead', 'TwoFCHead', 'XConvNormHead']
@register
......@@ -47,23 +49,79 @@ class BoxCoder(object):
self.axis = axis
@register
class XConvNormHead(object):
    """
    RCNN head with several convolution layers followed by one FC layer

    Args:
        num_conv (int): num of convolution layers for the rcnn head
        conv_dim (int): num of filters for the conv layers
        mlp_dim (int): num of filters for the fc layers
        norm_type (str|None): normalization type, forwarded to ConvNorm
            (e.g. 'gn')
        freeze_norm (bool): forwarded to ConvNorm as its `freeze_norm`
            argument
    """
    __shared__ = ['norm_type', 'freeze_norm']

    def __init__(self,
                 num_conv=4,
                 conv_dim=256,
                 mlp_dim=1024,
                 norm_type=None,
                 freeze_norm=False):
        super(XConvNormHead, self).__init__()
        self.conv_dim = conv_dim
        self.mlp_dim = mlp_dim
        self.num_conv = num_conv
        self.norm_type = norm_type
        self.freeze_norm = freeze_norm

    def __call__(self, roi_feat):
        """
        Apply `num_conv` 3x3 ConvNorm layers and one FC layer to `roi_feat`.

        Args:
            roi_feat: RoI feature variable fed to the first conv layer

        Returns:
            Output of the final FC layer (size `mlp_dim`, relu activation).
        """
        conv = roi_feat
        fan = self.conv_dim * 3 * 3
        initializer = MSRA(uniform=False, fan_in=fan)
        # The FC layer and its parameters are named after the LAST conv layer.
        # Start from an empty suffix so that num_conv == 0 yields plain 'fc6'
        # names instead of raising UnboundLocalError; names for num_conv >= 1
        # are unchanged.
        name = ''
        for i in range(self.num_conv):
            name = 'bbox_head_conv' + str(i)
            conv = ConvNorm(
                conv,
                self.conv_dim,
                3,
                act='relu',
                initializer=initializer,
                norm_type=self.norm_type,
                freeze_norm=self.freeze_norm,
                name=name,
                norm_name=name)
        # Product of the non-batch dimensions of the conv output, used as the
        # Xavier fan_out of the FC weights.
        fan = conv.shape[1] * conv.shape[2] * conv.shape[3]
        head_feat = fluid.layers.fc(input=conv,
                                    size=self.mlp_dim,
                                    act='relu',
                                    name='fc6' + name,
                                    param_attr=ParamAttr(
                                        name='fc6%s_w' % name,
                                        initializer=Xavier(fan_out=fan)),
                                    bias_attr=ParamAttr(
                                        name='fc6%s_b' % name,
                                        learning_rate=2.,
                                        regularizer=L2Decay(0.)))
        return head_feat
@register
class TwoFCHead(object):
"""
RCNN head with two Fully Connected layers
Args:
num_chan (int): num of filters for the fc layers
mlp_dim (int): num of filters for the fc layers
"""
def __init__(self, num_chan=1024):
def __init__(self, mlp_dim=1024):
super(TwoFCHead, self).__init__()
self.num_chan = num_chan
self.mlp_dim = mlp_dim
def __call__(self, roi_feat):
fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]
fc6 = fluid.layers.fc(input=roi_feat,
size=self.num_chan,
size=self.mlp_dim,
act='relu',
name='fc6',
param_attr=ParamAttr(
......@@ -74,7 +132,7 @@ class TwoFCHead(object):
learning_rate=2.,
regularizer=L2Decay(0.)))
head_feat = fluid.layers.fc(input=fc6,
size=self.num_chan,
size=self.mlp_dim,
act='relu',
name='fc7',
param_attr=ParamAttr(
......@@ -143,7 +201,8 @@ class BBoxHead(object):
"""
head_feat = self.get_head_feat(roi_feat)
# when ResNetC5 output a single feature map
if not isinstance(self.head, TwoFCHead):
if not isinstance(self.head, TwoFCHead) and not isinstance(
self.head, XConvNormHead):
head_feat = fluid.layers.pool2d(
head_feat, pool_type='avg', global_pooling=True)
cls_score = fluid.layers.fc(input=head_feat,
......
......@@ -22,6 +22,7 @@ from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling.ops import ConvNorm
__all__ = ['MaskHead']
......@@ -31,8 +32,8 @@ class MaskHead(object):
"""
RCNN mask head
Args:
num_convs (int): num of convolutions, 4 for FPN, 0 otherwise
num_chan_reduced (int): num of channels after first convolution
num_convs (int): num of convolutions, 4 for FPN, 1 otherwise
conv_dim (int): num of channels after first convolution
resolution (int): size of the output mask
dilation (int): dilation rate
num_classes (int): number of output classes
......@@ -42,24 +43,43 @@ class MaskHead(object):
def __init__(self,
num_convs=0,
num_chan_reduced=256,
conv_dim=256,
resolution=14,
dilation=1,
num_classes=81):
num_classes=81,
norm_type=None):
super(MaskHead, self).__init__()
self.num_convs = num_convs
self.num_chan_reduced = num_chan_reduced
self.conv_dim = conv_dim
self.resolution = resolution
self.dilation = dilation
self.num_classes = num_classes
self.norm_type = norm_type
def _mask_conv_head(self, roi_feat, num_convs):
def _mask_conv_head(self, roi_feat, num_convs, norm_type):
if norm_type == 'gn':
for i in range(num_convs):
layer_name = "mask_inter_feat_" + str(i + 1)
fan = self.num_chan_reduced * 3 * 3
fan = self.conv_dim * 3 * 3
initializer = MSRA(uniform=False, fan_in=fan)
roi_feat = ConvNorm(
roi_feat,
self.conv_dim,
3,
act='relu',
dilation=self.dilation,
initializer=initializer,
norm_type=self.norm_type,
name=layer_name,
norm_name=layer_name)
else:
for i in range(num_convs):
layer_name = "mask_inter_feat_" + str(i + 1)
fan = self.conv_dim * 3 * 3
initializer = MSRA(uniform=False, fan_in=fan)
roi_feat = fluid.layers.conv2d(
input=roi_feat,
num_filters=self.num_chan_reduced,
num_filters=self.conv_dim,
filter_size=3,
padding=1 * self.dilation,
act='relu',
......@@ -67,9 +87,7 @@ class MaskHead(object):
dilation=self.dilation,
name=layer_name,
param_attr=ParamAttr(
name=layer_name + '_w',
initializer=MSRA(
uniform=False, fan_in=fan)),
name=layer_name + '_w', initializer=initializer),
bias_attr=ParamAttr(
name=layer_name + '_b',
learning_rate=2.,
......@@ -77,7 +95,7 @@ class MaskHead(object):
fan = roi_feat.shape[1] * 2 * 2
feat = fluid.layers.conv2d_transpose(
input=roi_feat,
num_filters=self.num_chan_reduced,
num_filters=self.conv_dim,
filter_size=2,
stride=2,
act='relu',
......@@ -92,7 +110,8 @@ class MaskHead(object):
def _get_output(self, roi_feat):
class_num = self.num_classes
# configure the conv number for FPN if necessary
head_feat = self._mask_conv_head(roi_feat, self.num_convs)
head_feat = self._mask_conv_head(roi_feat, self.num_convs,
self.norm_type)
fan = class_num
mask_logits = fluid.layers.conv2d(
input=head_feat,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册