提交 ce96e567 编写于 作者: Y Yuan Gao 提交者: wangguanzhong

Add group norm (#3140)

* add a global norm_type flag

* add group norm on fpn

* add gn on box head

* add faster rcnn fpn 50 gn config

* add mask branch norm

* update configs
上级 bdd4bc8a
...@@ -86,7 +86,7 @@ BBoxHead: ...@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.02 base_lr: 0.02
......
...@@ -85,7 +85,7 @@ BBoxHead: ...@@ -85,7 +85,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.02 base_lr: 0.02
......
...@@ -86,7 +86,7 @@ BBoxHead: ...@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.02 base_lr: 0.02
......
...@@ -88,7 +88,7 @@ BBoxHead: ...@@ -88,7 +88,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -71,7 +71,7 @@ FPNRoIAlign: ...@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -94,7 +94,7 @@ BBoxHead: ...@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -70,7 +70,7 @@ FPNRoIAlign: ...@@ -70,7 +70,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -93,7 +93,7 @@ BBoxHead: ...@@ -93,7 +93,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -71,7 +71,7 @@ FPNRoIAlign: ...@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -94,7 +94,7 @@ BBoxHead: ...@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -73,7 +73,7 @@ FPNRoIAlign: ...@@ -73,7 +73,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -96,7 +96,7 @@ BBoxHead: ...@@ -96,7 +96,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -83,7 +83,7 @@ BBoxHead: ...@@ -83,7 +83,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -83,7 +83,7 @@ BBoxHead: ...@@ -83,7 +83,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -84,7 +84,7 @@ BBoxHead: ...@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -84,7 +84,7 @@ BBoxHead: ...@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -84,7 +84,7 @@ BBoxHead: ...@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.02 base_lr: 0.02
......
...@@ -84,7 +84,7 @@ BBoxHead: ...@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.02 base_lr: 0.02
......
...@@ -84,7 +84,7 @@ BBoxHead: ...@@ -84,7 +84,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.02 base_lr: 0.02
......
...@@ -86,7 +86,7 @@ BBoxHead: ...@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -86,7 +86,7 @@ BBoxHead: ...@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -86,7 +86,7 @@ BBoxHead: ...@@ -86,7 +86,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
# Faster R-CNN, ResNet-50 backbone with FPN, group normalization ("gn") on the
# FPN and on the box head (XConvNormHead).
#
# NOTE(review): the nesting below was reconstructed from a flattened listing in
# which every key sat at column 0 (producing duplicate top-level keys such as
# min_level / max_level / num_chan). Verify the structure against the config in
# the repository before relying on it.

# ---- global settings ----
architecture: FasterRCNN
train_feed: FasterRCNNTrainFeed
eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed
max_iters: 180000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/faster_rcnn_r50_fpn_gn/model_final
num_classes: 81

# ---- model composition: each value names a component configured below ----
FasterRCNN:
  backbone: ResNet
  fpn: FPN
  rpn_head: FPNRPNHead
  roi_extractor: FPNRoIAlign
  bbox_head: BBoxHead
  bbox_assigner: BBoxAssigner

# The backbone keeps affine_channel normalization; only FPN and the box head
# switch to group norm in this config.
ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: affine_channel

FPN:
  min_level: 2
  max_level: 6
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
  norm_type: gn

FPNRPNHead:
  anchor_generator:
    anchor_sizes: [32, 64, 128, 256, 512]
    aspect_ratios: [0.5, 1.0, 2.0]
    stride: [16.0, 16.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  anchor_start_size: 32
  min_level: 2
  max_level: 6
  num_chan: 256
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_positive_overlap: 0.7
    rpn_negative_overlap: 0.3
    rpn_straddle_thresh: 0.0
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 2000
    post_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 1000
    post_nms_top_n: 1000

FPNRoIAlign:
  # NOTE(review): the spelling "canconical_level" is kept exactly as found; it
  # presumably has to match the parameter name of FPNRoIAlign - confirm before
  # renaming.
  canconical_level: 4
  canonical_size: 224
  min_level: 2
  max_level: 5
  box_resolution: 7
  sampling_ratio: 2

BBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
  bg_thresh_lo: 0.0
  bg_thresh_hi: 0.5
  fg_fraction: 0.25
  fg_thresh: 0.5

BBoxHead:
  # Conv + norm head instead of TwoFCHead.
  head: XConvNormHead
  nms:
    keep_top_k: 100
    nms_threshold: 0.5
    score_threshold: 0.05

XConvNormHead:
  norm_type: gn

# ---- optimization ----
LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [120000, 160000]
  - !LinearWarmup
    start_factor: 0.1
    steps: 1000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

# ---- data feeds ----
FasterRCNNTrainFeed:
  batch_size: 2
  dataset:
    dataset_dir: dataset/coco
    annotation: annotations/instances_train2017.json
    image_dir: train2017
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  drop_last: false
  num_workers: 16

FasterRCNNEvalFeed:
  batch_size: 1
  dataset:
    dataset_dir: dataset/coco
    annotation: annotations/instances_val2017.json
    image_dir: val2017
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32

FasterRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  drop_last: false
  num_workers: 2
# Mask R-CNN, ResNet-50 backbone with FPN, group normalization ("gn") on the
# FPN, the mask head and the box head (XConvNormHead).
#
# NOTE(review): the nesting below was reconstructed from a flattened listing in
# which every key sat at column 0 (producing duplicate top-level keys such as
# min_level / max_level / resolution). Verify the structure against the config
# in the repository before relying on it.

# ---- global settings ----
architecture: MaskRCNN
train_feed: MaskRCNNTrainFeed
eval_feed: MaskRCNNEvalFeed
test_feed: MaskRCNNTestFeed
max_iters: 360000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights: output/mask_rcnn_r50_fpn_gn_2x/model_final/
metric: COCO
num_classes: 81

# ---- model composition: each value names a component configured below ----
MaskRCNN:
  backbone: ResNet
  fpn: FPN
  rpn_head: FPNRPNHead
  roi_extractor: FPNRoIAlign
  bbox_head: BBoxHead
  bbox_assigner: BBoxAssigner

# The backbone keeps affine_channel normalization; FPN and the heads use gn.
ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: affine_channel

FPN:
  max_level: 6
  min_level: 2
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
  norm_type: gn

FPNRPNHead:
  anchor_generator:
    aspect_ratios: [0.5, 1.0, 2.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  anchor_start_size: 32
  max_level: 6
  min_level: 2
  num_chan: 256
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_negative_overlap: 0.3
    rpn_positive_overlap: 0.7
    rpn_straddle_thresh: 0.0
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 2000
    post_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 1000
    post_nms_top_n: 1000

FPNRoIAlign:
  # NOTE(review): the spelling "canconical_level" is kept exactly as found; it
  # presumably has to match the parameter name of FPNRoIAlign - confirm before
  # renaming.
  canconical_level: 4
  canonical_size: 224
  max_level: 5
  min_level: 2
  sampling_ratio: 2
  box_resolution: 7
  mask_resolution: 14

MaskHead:
  dilation: 1
  conv_dim: 256
  num_convs: 4
  resolution: 28
  norm_type: gn

BBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
  bg_thresh_hi: 0.5
  bg_thresh_lo: 0.0
  fg_fraction: 0.25
  fg_thresh: 0.5

MaskAssigner:
  # Same value as MaskHead.resolution above.
  resolution: 28

BBoxHead:
  # Conv + norm head instead of TwoFCHead.
  head: XConvNormHead
  nms:
    keep_top_k: 100
    nms_threshold: 0.5
    score_threshold: 0.05

XConvNormHead:
  norm_type: gn

# ---- optimization ----
LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [240000, 320000]
  - !LinearWarmup
    start_factor: 0.3333333333333333
    steps: 500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

# ---- data feeds ----
MaskRCNNTrainFeed:
  batch_size: 1
  dataset:
    dataset_dir: dataset/coco
    annotation: annotations/instances_train2017.json
    image_dir: train2017
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  num_workers: 2

MaskRCNNEvalFeed:
  batch_size: 1
  dataset:
    dataset_dir: dataset/coco
    annotation: annotations/instances_val2017.json
    image_dir: val2017
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  num_workers: 2

MaskRCNNTestFeed:
  batch_size: 1
  dataset:
    annotation: dataset/coco/annotations/instances_val2017.json
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  num_workers: 2
...@@ -68,7 +68,7 @@ FPNRoIAlign: ...@@ -68,7 +68,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -91,7 +91,7 @@ BBoxHead: ...@@ -91,7 +91,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -69,7 +69,7 @@ FPNRoIAlign: ...@@ -69,7 +69,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -92,7 +92,7 @@ BBoxHead: ...@@ -92,7 +92,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -70,7 +70,7 @@ BBoxHead: ...@@ -70,7 +70,7 @@ BBoxHead:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
resolution: 14 resolution: 14
BBoxAssigner: BBoxAssigner:
......
...@@ -71,7 +71,7 @@ BBoxHead: ...@@ -71,7 +71,7 @@ BBoxHead:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
resolution: 14 resolution: 14
BBoxAssigner: BBoxAssigner:
......
...@@ -68,7 +68,7 @@ FPNRoIAlign: ...@@ -68,7 +68,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -91,7 +91,7 @@ BBoxHead: ...@@ -91,7 +91,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -68,7 +68,7 @@ FPNRoIAlign: ...@@ -68,7 +68,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -91,7 +91,7 @@ BBoxHead: ...@@ -91,7 +91,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -69,7 +69,7 @@ FPNRoIAlign: ...@@ -69,7 +69,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -92,7 +92,7 @@ BBoxHead: ...@@ -92,7 +92,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -71,7 +71,7 @@ FPNRoIAlign: ...@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -94,7 +94,7 @@ BBoxHead: ...@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -71,7 +71,7 @@ FPNRoIAlign: ...@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -94,7 +94,7 @@ BBoxHead: ...@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -71,7 +71,7 @@ FPNRoIAlign: ...@@ -71,7 +71,7 @@ FPNRoIAlign:
MaskHead: MaskHead:
dilation: 1 dilation: 1
num_chan_reduced: 256 conv_dim: 256
num_convs: 4 num_convs: 4
resolution: 28 resolution: 28
...@@ -94,7 +94,7 @@ BBoxHead: ...@@ -94,7 +94,7 @@ BBoxHead:
score_threshold: 0.05 score_threshold: 0.05
TwoFCHead: TwoFCHead:
num_chan: 1024 mlp_dim: 1024
LearningRate: LearningRate:
base_lr: 0.01 base_lr: 0.01
......
...@@ -67,8 +67,8 @@ def load(anno_path, sample_num=-1, with_background=True): ...@@ -67,8 +67,8 @@ def load(anno_path, sample_num=-1, with_background=True):
for img_id in img_ids: for img_id in img_ids:
img_anno = coco.loadImgs(img_id)[0] img_anno = coco.loadImgs(img_id)[0]
im_fname = img_anno['file_name'] im_fname = img_anno['file_name']
im_w = img_anno['width'] im_w = float(img_anno['width'])
im_h = img_anno['height'] im_h = float(img_anno['height'])
ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False) ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
instances = coco.loadAnns(ins_anno_ids) instances = coco.loadAnns(ins_anno_ids)
...@@ -85,8 +85,8 @@ def load(anno_path, sample_num=-1, with_background=True): ...@@ -85,8 +85,8 @@ def load(anno_path, sample_num=-1, with_background=True):
bboxes.append(inst) bboxes.append(inst)
else: else:
logger.warn( logger.warn(
'Found an invalid bbox in annotations: im_id: {}, area: {} x: {}, y: {}, h: {}, w: {}.'. 'Found an invalid bbox in annotations: im_id: {}, area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.
format(img_id, float(inst['area']), x, y, box_w, box_h)) format(img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes) num_bbox = len(bboxes)
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32) gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
......
...@@ -36,6 +36,7 @@ class DarkNet(object): ...@@ -36,6 +36,7 @@ class DarkNet(object):
norm_type (str): normalization type, 'bn' and 'sync_bn' are supported norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
norm_decay (float): weight decay for normalization layer weights norm_decay (float): weight decay for normalization layer weights
""" """
__shared__ = ['norm_type']
def __init__(self, depth=53, norm_type='bn', norm_decay=0.): def __init__(self, depth=53, norm_type='bn', norm_decay=0.):
assert depth in [53], "unsupported depth value" assert depth in [53], "unsupported depth value"
......
...@@ -42,6 +42,7 @@ class FPN(object): ...@@ -42,6 +42,7 @@ class FPN(object):
has_extra_convs (bool): whether has extral convolutions in higher levels has_extra_convs (bool): whether has extral convolutions in higher levels
norm_type (str|None): normalization type, 'bn'/'sync_bn'/'affine_channel' norm_type (str|None): normalization type, 'bn'/'sync_bn'/'affine_channel'
""" """
__shared__ = ['norm_type', 'freeze_norm']
def __init__(self, def __init__(self,
num_chan=256, num_chan=256,
...@@ -49,7 +50,9 @@ class FPN(object): ...@@ -49,7 +50,9 @@ class FPN(object):
max_level=6, max_level=6,
spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. / 4.], spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. / 4.],
has_extra_convs=False, has_extra_convs=False,
norm_type=None): norm_type=None,
freeze_norm=False):
self.freeze_norm = freeze_norm
self.num_chan = num_chan self.num_chan = num_chan
self.min_level = min_level self.min_level = min_level
self.max_level = max_level self.max_level = max_level
...@@ -69,8 +72,9 @@ class FPN(object): ...@@ -69,8 +72,9 @@ class FPN(object):
1, 1,
initializer=initializer, initializer=initializer,
norm_type=self.norm_type, norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=lateral_name, name=lateral_name,
bn_name=lateral_name) norm_name=lateral_name)
else: else:
lateral = fluid.layers.conv2d( lateral = fluid.layers.conv2d(
body_input, body_input,
...@@ -120,8 +124,9 @@ class FPN(object): ...@@ -120,8 +124,9 @@ class FPN(object):
1, 1,
initializer=initializer, initializer=initializer,
norm_type=self.norm_type, norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=fpn_inner_name, name=fpn_inner_name,
bn_name=fpn_inner_name) norm_name=fpn_inner_name)
else: else:
self.fpn_inner_output[0] = fluid.layers.conv2d( self.fpn_inner_output[0] = fluid.layers.conv2d(
body_input, body_input,
...@@ -155,8 +160,9 @@ class FPN(object): ...@@ -155,8 +160,9 @@ class FPN(object):
3, 3,
initializer=initializer, initializer=initializer,
norm_type=self.norm_type, norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=fpn_name, name=fpn_name,
bn_name=fpn_name) norm_name=fpn_name)
else: else:
fpn_output = fluid.layers.conv2d( fpn_output = fluid.layers.conv2d(
self.fpn_inner_output[i], self.fpn_inner_output[i],
......
...@@ -37,6 +37,7 @@ class MobileNet(object): ...@@ -37,6 +37,7 @@ class MobileNet(object):
with_extra_blocks (bool): if extra blocks should be added with_extra_blocks (bool): if extra blocks should be added
extra_block_filters (list): number of filter for each extra block extra_block_filters (list): number of filter for each extra block
""" """
__shared__ = ['norm_type']
def __init__(self, def __init__(self,
norm_type='bn', norm_type='bn',
......
...@@ -47,6 +47,7 @@ class ResNet(object): ...@@ -47,6 +47,7 @@ class ResNet(object):
feature_maps (list): index of stages whose feature maps are returned feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2 dcn_v2_stages (list): index of stages who select deformable conv v2
""" """
__shared__ = ['norm_type', 'freeze_norm']
def __init__(self, def __init__(self,
depth=50, depth=50,
......
...@@ -16,7 +16,6 @@ from numbers import Integral ...@@ -16,7 +16,6 @@ from numbers import Integral
from paddle import fluid from paddle import fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
...@@ -34,9 +33,11 @@ def ConvNorm(input, ...@@ -34,9 +33,11 @@ def ConvNorm(input,
groups=1, groups=1,
norm_decay=0., norm_decay=0.,
norm_type='affine_channel', norm_type='affine_channel',
norm_groups=32,
dilation=1,
freeze_norm=False, freeze_norm=False,
act=None, act=None,
bn_name=None, norm_name=None,
initializer=None, initializer=None,
name=None): name=None):
fan = num_filters fan = num_filters
...@@ -45,7 +46,8 @@ def ConvNorm(input, ...@@ -45,7 +46,8 @@ def ConvNorm(input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) // 2, padding=((filter_size - 1) // 2) * dilation,
dilation=dilation,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr( param_attr=ParamAttr(
...@@ -55,11 +57,11 @@ def ConvNorm(input, ...@@ -55,11 +57,11 @@ def ConvNorm(input,
norm_lr = 0. if freeze_norm else 1. norm_lr = 0. if freeze_norm else 1.
pattr = ParamAttr( pattr = ParamAttr(
name=bn_name + '_scale', name=norm_name + '_scale',
learning_rate=norm_lr, learning_rate=norm_lr,
regularizer=L2Decay(norm_decay)) regularizer=L2Decay(norm_decay))
battr = ParamAttr( battr = ParamAttr(
name=bn_name + '_offset', name=norm_name + '_offset',
learning_rate=norm_lr, learning_rate=norm_lr,
regularizer=L2Decay(norm_decay)) regularizer=L2Decay(norm_decay))
...@@ -68,14 +70,24 @@ def ConvNorm(input, ...@@ -68,14 +70,24 @@ def ConvNorm(input,
out = fluid.layers.batch_norm( out = fluid.layers.batch_norm(
input=conv, input=conv,
act=act, act=act,
name=bn_name + '.output.1', name=norm_name + '.output.1',
param_attr=pattr, param_attr=pattr,
bias_attr=battr, bias_attr=battr,
moving_mean_name=bn_name + '_mean', moving_mean_name=norm_name + '_mean',
moving_variance_name=bn_name + '_variance', moving_variance_name=norm_name + '_variance',
use_global_stats=global_stats) use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name) scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name) bias = fluid.framework._get_var(battr.name)
elif norm_type == 'gn':
out = fluid.layers.group_norm(
input=conv,
act=act,
name=norm_name + '.output.1',
groups=norm_groups,
param_attr=pattr,
bias_attr=battr)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif norm_type == 'affine_channel': elif norm_type == 'affine_channel':
scale = fluid.layers.create_parameter( scale = fluid.layers.create_parameter(
shape=[conv.shape[1]], shape=[conv.shape[1]],
......
...@@ -22,11 +22,13 @@ from paddle import fluid ...@@ -22,11 +22,13 @@ from paddle import fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA
from ppdet.modeling.ops import MultiClassNMS from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
__all__ = ['BBoxHead', 'TwoFCHead'] __all__ = ['BBoxHead', 'TwoFCHead', 'XConvNormHead']
@register @register
...@@ -47,23 +49,79 @@ class BoxCoder(object): ...@@ -47,23 +49,79 @@ class BoxCoder(object):
self.axis = axis self.axis = axis
@register
class XConvNormHead(object):
    """
    RCNN head with several convolution layers followed by one FC layer

    Args:
        num_conv (int): num of convolution layers for the rcnn head,
            must be at least 1
        conv_dim (int): num of filters for the conv layers
        mlp_dim (int): num of filters for the fc layers
        norm_type (str|None): normalization type forwarded to ConvNorm
            (e.g. 'gn'); shared with the global `norm_type` setting
        freeze_norm (bool): forwarded to ConvNorm; shared with the global
            `freeze_norm` setting

    Raises:
        ValueError: if `num_conv` is smaller than 1
    """
    __shared__ = ['norm_type', 'freeze_norm']

    def __init__(self,
                 num_conv=4,
                 conv_dim=256,
                 mlp_dim=1024,
                 norm_type=None,
                 freeze_norm=False):
        super(XConvNormHead, self).__init__()
        # The FC layer in __call__ is named after the last conv layer, so at
        # least one conv layer is required. Fail early with a clear message
        # instead of an unbound-variable error at graph construction time.
        if num_conv < 1:
            raise ValueError(
                "XConvNormHead requires num_conv >= 1, got {}".format(
                    num_conv))
        self.conv_dim = conv_dim
        self.mlp_dim = mlp_dim
        self.num_conv = num_conv
        self.norm_type = norm_type
        self.freeze_norm = freeze_norm

    def __call__(self, roi_feat):
        """
        Build the head on top of the RoI features.

        Args:
            roi_feat (Variable): RoI feature map; it is passed through
                `num_conv` 3x3 ConvNorm layers and then one FC layer

        Returns:
            Variable: output of the FC layer with `mlp_dim` units
        """
        conv = roi_feat
        fan = self.conv_dim * 3 * 3
        initializer = MSRA(uniform=False, fan_in=fan)
        for i in range(self.num_conv):
            name = 'bbox_head_conv' + str(i)
            conv = ConvNorm(
                conv,
                self.conv_dim,
                3,
                act='relu',
                initializer=initializer,
                norm_type=self.norm_type,
                freeze_norm=self.freeze_norm,
                name=name,
                norm_name=name)
        # `name` still holds the name of the last conv layer; the FC layer and
        # its parameters are deliberately named after it, so the parameter
        # names stay unchanged.
        fan = conv.shape[1] * conv.shape[2] * conv.shape[3]
        head_feat = fluid.layers.fc(input=conv,
                                    size=self.mlp_dim,
                                    act='relu',
                                    name='fc6' + name,
                                    param_attr=ParamAttr(
                                        name='fc6%s_w' % name,
                                        initializer=Xavier(fan_out=fan)),
                                    bias_attr=ParamAttr(
                                        name='fc6%s_b' % name,
                                        learning_rate=2,
                                        regularizer=L2Decay(0.)))
        return head_feat
@register @register
class TwoFCHead(object): class TwoFCHead(object):
""" """
RCNN head with two Fully Connected layers RCNN head with two Fully Connected layers
Args: Args:
num_chan (int): num of filters for the fc layers mlp_dim (int): num of filters for the fc layers
""" """
def __init__(self, num_chan=1024): def __init__(self, mlp_dim=1024):
super(TwoFCHead, self).__init__() super(TwoFCHead, self).__init__()
self.num_chan = num_chan self.mlp_dim = mlp_dim
def __call__(self, roi_feat): def __call__(self, roi_feat):
fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3] fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]
fc6 = fluid.layers.fc(input=roi_feat, fc6 = fluid.layers.fc(input=roi_feat,
size=self.num_chan, size=self.mlp_dim,
act='relu', act='relu',
name='fc6', name='fc6',
param_attr=ParamAttr( param_attr=ParamAttr(
...@@ -74,7 +132,7 @@ class TwoFCHead(object): ...@@ -74,7 +132,7 @@ class TwoFCHead(object):
learning_rate=2., learning_rate=2.,
regularizer=L2Decay(0.))) regularizer=L2Decay(0.)))
head_feat = fluid.layers.fc(input=fc6, head_feat = fluid.layers.fc(input=fc6,
size=self.num_chan, size=self.mlp_dim,
act='relu', act='relu',
name='fc7', name='fc7',
param_attr=ParamAttr( param_attr=ParamAttr(
...@@ -143,7 +201,8 @@ class BBoxHead(object): ...@@ -143,7 +201,8 @@ class BBoxHead(object):
""" """
head_feat = self.get_head_feat(roi_feat) head_feat = self.get_head_feat(roi_feat)
# when ResNetC5 output a single feature map # when ResNetC5 output a single feature map
if not isinstance(self.head, TwoFCHead): if not isinstance(self.head, TwoFCHead) and not isinstance(
self.head, XConvNormHead):
head_feat = fluid.layers.pool2d( head_feat = fluid.layers.pool2d(
head_feat, pool_type='avg', global_pooling=True) head_feat, pool_type='avg', global_pooling=True)
cls_score = fluid.layers.fc(input=head_feat, cls_score = fluid.layers.fc(input=head_feat,
......
...@@ -22,6 +22,7 @@ from paddle.fluid.initializer import MSRA ...@@ -22,6 +22,7 @@ from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register from ppdet.core.workspace import register
from ppdet.modeling.ops import ConvNorm
__all__ = ['MaskHead'] __all__ = ['MaskHead']
...@@ -31,8 +32,8 @@ class MaskHead(object): ...@@ -31,8 +32,8 @@ class MaskHead(object):
""" """
RCNN mask head RCNN mask head
Args: Args:
num_convs (int): num of convolutions, 4 for FPN, 0 otherwise num_convs (int): num of convolutions, 4 for FPN, 1 otherwise
num_chan_reduced (int): num of channels after first convolution conv_dim (int): num of channels after first convolution
resolution (int): size of the output mask resolution (int): size of the output mask
dilation (int): dilation rate dilation (int): dilation rate
num_classes (int): number of output classes num_classes (int): number of output classes
...@@ -42,42 +43,59 @@ class MaskHead(object): ...@@ -42,42 +43,59 @@ class MaskHead(object):
def __init__(self, def __init__(self,
num_convs=0, num_convs=0,
num_chan_reduced=256, conv_dim=256,
resolution=14, resolution=14,
dilation=1, dilation=1,
num_classes=81): num_classes=81,
norm_type=None):
super(MaskHead, self).__init__() super(MaskHead, self).__init__()
self.num_convs = num_convs self.num_convs = num_convs
self.num_chan_reduced = num_chan_reduced self.conv_dim = conv_dim
self.resolution = resolution self.resolution = resolution
self.dilation = dilation self.dilation = dilation
self.num_classes = num_classes self.num_classes = num_classes
self.norm_type = norm_type
def _mask_conv_head(self, roi_feat, num_convs): def _mask_conv_head(self, roi_feat, num_convs, norm_type):
for i in range(num_convs): if norm_type == 'gn':
layer_name = "mask_inter_feat_" + str(i + 1) for i in range(num_convs):
fan = self.num_chan_reduced * 3 * 3 layer_name = "mask_inter_feat_" + str(i + 1)
roi_feat = fluid.layers.conv2d( fan = self.conv_dim * 3 * 3
input=roi_feat, initializer = MSRA(uniform=False, fan_in=fan)
num_filters=self.num_chan_reduced, roi_feat = ConvNorm(
filter_size=3, roi_feat,
padding=1 * self.dilation, self.conv_dim,
act='relu', 3,
stride=1, act='relu',
dilation=self.dilation, dilation=self.dilation,
name=layer_name, initializer=initializer,
param_attr=ParamAttr( norm_type=self.norm_type,
name=layer_name + '_w', name=layer_name,
initializer=MSRA( norm_name=layer_name)
uniform=False, fan_in=fan)), else:
bias_attr=ParamAttr( for i in range(num_convs):
name=layer_name + '_b', layer_name = "mask_inter_feat_" + str(i + 1)
learning_rate=2., fan = self.conv_dim * 3 * 3
regularizer=L2Decay(0.))) initializer = MSRA(uniform=False, fan_in=fan)
roi_feat = fluid.layers.conv2d(
input=roi_feat,
num_filters=self.conv_dim,
filter_size=3,
padding=1 * self.dilation,
act='relu',
stride=1,
dilation=self.dilation,
name=layer_name,
param_attr=ParamAttr(
name=layer_name + '_w', initializer=initializer),
bias_attr=ParamAttr(
name=layer_name + '_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
fan = roi_feat.shape[1] * 2 * 2 fan = roi_feat.shape[1] * 2 * 2
feat = fluid.layers.conv2d_transpose( feat = fluid.layers.conv2d_transpose(
input=roi_feat, input=roi_feat,
num_filters=self.num_chan_reduced, num_filters=self.conv_dim,
filter_size=2, filter_size=2,
stride=2, stride=2,
act='relu', act='relu',
...@@ -92,7 +110,8 @@ class MaskHead(object): ...@@ -92,7 +110,8 @@ class MaskHead(object):
def _get_output(self, roi_feat): def _get_output(self, roi_feat):
class_num = self.num_classes class_num = self.num_classes
# configure the conv number for FPN if necessary # configure the conv number for FPN if necessary
head_feat = self._mask_conv_head(roi_feat, self.num_convs) head_feat = self._mask_conv_head(roi_feat, self.num_convs,
self.norm_type)
fan = class_num fan = class_num
mask_logits = fluid.layers.conv2d( mask_logits = fluid.layers.conv2d(
input=head_feat, input=head_feat,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册