未验证 提交 4cd12915 编写于 作者: F Feng Ni 提交者: GitHub

[dygraph] Add gnfpn and gnhead (#2226)

* add gn head fpn, test=dygraph

* add gn for cascade

* update gn readme, test=dygraph
上级 e527466d
# Group Normalization
## Model Zoo
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps)| Box AP | Mask AP | 下载 | 配置文件 |
| :------------- | :------------- | :-----------: | :------: | :--------: |:-----: | :-----: | :----: | :----: |
| ResNet50-FPN | Faster | 1 | 2x | - | 41.9 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_r50_fpn_gn_2x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/gn/faster_rcnn_r50_fpn_gn_2x_coco.yml) |
| ResNet50-FPN | Mask | 1 | 2x | - | - | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/mask_rcnn_r50_fpn_gn_2x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/gn/mask_rcnn_r50_fpn_gn_2x_coco.yml) |
**注意:** Faster R-CNN baseline仅使用 `2fc` head,而此处使用[`4conv1fc` head](https://arxiv.org/abs/1803.08494)(4层conv之间使用GN),并且FPN也使用GN,而对于Mask R-CNN是在mask head的4层conv之间也使用GN。
## Citations
```
@inproceedings{wu2018group,
title={Group Normalization},
author={Wu, Yuxin and He, Kaiming},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
year={2018}
}
```
_BASE_: [
'../datasets/coco_instance.yml',
'../runtime.yml',
'../cascade_rcnn/_base_/optimizer_1x.yml',
'../cascade_rcnn/_base_/cascade_mask_rcnn_r50_fpn.yml',
'../cascade_rcnn/_base_/cascade_mask_fpn_reader.yml',
]
weights: output/cascade_mask_rcnn_r50_fpn_gn_2x/model_final
CascadeRCNN:
backbone: ResNet
neck: FPN
rpn_head: RPNHead
bbox_head: CascadeHead
mask_head: MaskHead
# post process
bbox_post_process: BBoxPostProcess
mask_post_process: MaskPostProcess
FPN:
out_channel: 256
norm_type: gn
CascadeHead:
head: CascadeXConvNormHead
roi_extractor:
resolution: 7
sampling_ratio: 0
aligned: True
bbox_assigner: BBoxAssigner
CascadeXConvNormHead:
num_convs: 4
mlp_dim: 1024
norm_type: gn
MaskHead:
head: MaskFeat
roi_extractor:
resolution: 14
sampling_ratio: 0
aligned: True
mask_assigner: MaskAssigner
share_bbox_feat: False
MaskFeat:
num_convs: 4
out_channels: 256
norm_type: gn
epoch: 24
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [16, 22]
- !LinearWarmup
start_factor: 0.1
steps: 1000
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../cascade_rcnn/_base_/optimizer_1x.yml',
'../cascade_rcnn/_base_/cascade_rcnn_r50_fpn.yml',
'../cascade_rcnn/_base_/cascade_fpn_reader.yml',
]
weights: output/cascade_rcnn_r50_fpn_gn_2x/model_final
FPN:
out_channel: 256
norm_type: gn
CascadeHead:
head: CascadeXConvNormHead
roi_extractor:
resolution: 7
sampling_ratio: 0
aligned: True
bbox_assigner: BBoxAssigner
CascadeXConvNormHead:
num_convs: 4
mlp_dim: 1024
norm_type: gn
epoch: 24
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [16, 22]
- !LinearWarmup
start_factor: 0.1
steps: 1000
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'../faster_rcnn/_base_/optimizer_1x.yml',
'../faster_rcnn/_base_/faster_rcnn_r50_fpn.yml',
'../faster_rcnn/_base_/faster_fpn_reader.yml',
]
weights: output/faster_rcnn_r50_fpn_gn_2x_coco/model_final
FasterRCNN:
backbone: ResNet
neck: FPN
rpn_head: RPNHead
bbox_head: BBoxHead
# post process
bbox_post_process: BBoxPostProcess
FPN:
out_channel: 256
norm_type: gn
BBoxHead:
head: XConvNormHead
roi_extractor:
resolution: 7
sampling_ratio: 0
aligned: True
bbox_assigner: BBoxAssigner
XConvNormHead:
num_convs: 4
mlp_dim: 1024
norm_type: gn
epoch: 24
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [16, 22]
- !LinearWarmup
start_factor: 0.1
steps: 1000
_BASE_: [
'../datasets/coco_instance.yml',
'../runtime.yml',
'../mask_rcnn/_base_/optimizer_1x.yml',
'../mask_rcnn/_base_/mask_rcnn_r50_fpn.yml',
'../mask_rcnn/_base_/mask_fpn_reader.yml',
]
weights: output/mask_rcnn_r50_fpn_gn_2x_coco/model_final
MaskRCNN:
backbone: ResNet
neck: FPN
rpn_head: RPNHead
bbox_head: BBoxHead
mask_head: MaskHead
# post process
bbox_post_process: BBoxPostProcess
mask_post_process: MaskPostProcess
FPN:
out_channel: 256
norm_type: gn
BBoxHead:
head: XConvNormHead
roi_extractor:
resolution: 7
sampling_ratio: 0
aligned: True
bbox_assigner: BBoxAssigner
XConvNormHead:
num_convs: 4
mlp_dim: 1024
norm_type: gn
MaskHead:
head: MaskFeat
roi_extractor:
resolution: 14
sampling_ratio: 0
aligned: True
mask_assigner: MaskAssigner
share_bbox_feat: False
MaskFeat:
num_convs: 4
out_channels: 256
norm_type: gn
epoch: 24
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [16, 22]
- !LinearWarmup
start_factor: 0.1
steps: 1000
......@@ -15,7 +15,7 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierUniform
from paddle.nn.initializer import Normal, XavierUniform, KaimingNormal
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, create
......@@ -24,6 +24,9 @@ from ppdet.modeling import ops
from .roi_extractor import RoIAlign
from ..shape_spec import ShapeSpec
from ..bbox_utils import bbox2delta
from ppdet.modeling.layers import ConvNormLayer
__all__ = ['TwoFCHead', 'XConvNormHead', 'BBoxHead']
@register
......@@ -63,6 +66,86 @@ class TwoFCHead(nn.Layer):
return fc7
@register
class XConvNormHead(nn.Layer):
"""
RCNN bbox head with serveral convolution layers
Args:
in_dim(int): num of channels for the input rois_feat
num_convs(int): num of convolution layers for the rcnn bbox head
conv_dim(int): num of channels for the conv layers
mlp_dim(int): num of channels for the fc layers
resolution(int): resolution of the rois_feat
norm_type(str): norm type, 'gn' by defalut
freeze_norm(bool): whether to freeze the norm
stage_name(str): used in CascadeXConvNormHead, '' by default
"""
__shared__ = ['norm_type', 'freeze_norm']
def __init__(self,
in_dim=256,
num_convs=4,
conv_dim=256,
mlp_dim=1024,
resolution=7,
norm_type='gn',
freeze_norm=False,
stage_name=''):
super(XConvNormHead, self).__init__()
self.in_dim = in_dim
self.num_convs = num_convs
self.conv_dim = conv_dim
self.mlp_dim = mlp_dim
self.norm_type = norm_type
self.freeze_norm = freeze_norm
self.bbox_head_convs = []
fan = conv_dim * 3 * 3
initializer = KaimingNormal(fan_in=fan)
for i in range(self.num_convs):
in_c = in_dim if i == 0 else conv_dim
head_conv_name = stage_name + 'bbox_head_conv{}'.format(i)
head_conv = self.add_sublayer(
head_conv_name,
ConvNormLayer(
ch_in=in_c,
ch_out=conv_dim,
filter_size=3,
stride=1,
norm_type=self.norm_type,
norm_name=head_conv_name + '_norm',
freeze_norm=self.freeze_norm,
initializer=initializer,
name=head_conv_name))
self.bbox_head_convs.append(head_conv)
fan = conv_dim * resolution * resolution
self.fc6 = nn.Linear(
conv_dim * resolution * resolution,
mlp_dim,
weight_attr=paddle.ParamAttr(
initializer=XavierUniform(fan_out=fan)),
bias_attr=paddle.ParamAttr(
learning_rate=2., regularizer=L2Decay(0.)))
@classmethod
def from_config(cls, cfg, input_shape):
s = input_shape
s = s[0] if isinstance(s, (list, tuple)) else s
return {'in_dim': s.channels}
@property
def out_shape(self):
return [ShapeSpec(channels=self.mlp_dim, )]
def forward(self, rois_feat):
for i in range(self.num_convs):
rois_feat = F.relu(self.bbox_head_convs[i](rois_feat))
rois_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
fc6 = F.relu(self.fc6(rois_feat))
return fc6
@register
class BBoxHead(nn.Layer):
__shared__ = ['num_classes']
......
......@@ -21,11 +21,13 @@ from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, create
from ppdet.modeling import ops
from .bbox_head import BBoxHead, TwoFCHead
from .bbox_head import BBoxHead, TwoFCHead, XConvNormHead
from .roi_extractor import RoIAlign
from ..shape_spec import ShapeSpec
from ..bbox_utils import bbox2delta, delta2bbox, clip_bbox, nonempty_bbox
__all__ = ['CascadeTwoFCHead', 'CascadeXConvNormHead', 'CascadeHead']
@register
class CascadeTwoFCHead(nn.Layer):
......@@ -62,6 +64,53 @@ class CascadeTwoFCHead(nn.Layer):
return out
@register
class CascadeXConvNormHead(nn.Layer):
__shared__ = ['norm_type', 'freeze_norm', 'num_cascade_stage']
def __init__(self,
in_dim=256,
num_convs=4,
conv_dim=256,
mlp_dim=1024,
resolution=7,
norm_type='gn',
freeze_norm=False,
num_cascade_stage=3):
super(CascadeXConvNormHead, self).__init__()
self.in_dim = in_dim
self.mlp_dim = mlp_dim
self.head_list = []
for stage in range(num_cascade_stage):
head_per_stage = self.add_sublayer(
str(stage),
XConvNormHead(
in_dim,
num_convs,
conv_dim,
mlp_dim,
resolution,
norm_type,
freeze_norm,
stage_name='stage{}_'.format(stage)))
self.head_list.append(head_per_stage)
@classmethod
def from_config(cls, cfg, input_shape):
s = input_shape
s = s[0] if isinstance(s, (list, tuple)) else s
return {'in_dim': s.channels}
@property
def out_shape(self):
return [ShapeSpec(channels=self.mlp_dim, )]
def forward(self, rois_feat, stage=0):
out = self.head_list[stage](rois_feat)
return out
@register
class CascadeHead(BBoxHead):
__shared__ = ['num_classes', 'num_cascade_stages']
......
......@@ -20,33 +20,55 @@ from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, create
from ppdet.modeling import ops
from ppdet.modeling.layers import ConvNormLayer
from .roi_extractor import RoIAlign
@register
class MaskFeat(nn.Layer):
def __init__(self, num_convs=0, in_channels=2048, out_channels=256):
def __init__(self,
num_convs=4,
in_channels=256,
out_channels=256,
norm_type=None):
super(MaskFeat, self).__init__()
self.num_convs = num_convs
self.in_channels = in_channels
self.out_channels = out_channels
self.norm_type = norm_type
fan_conv = out_channels * 3 * 3
fan_deconv = out_channels * 2 * 2
mask_conv = nn.Sequential()
for i in range(self.num_convs):
conv_name = 'mask_inter_feat_{}'.format(i + 1)
mask_conv.add_sublayer(
conv_name,
nn.Conv2D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
weight_attr=paddle.ParamAttr(
initializer=KaimingNormal(fan_in=fan_conv))))
mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
if norm_type == 'gn':
for i in range(self.num_convs):
conv_name = 'mask_inter_feat_{}'.format(i + 1)
mask_conv.add_sublayer(
conv_name,
ConvNormLayer(
ch_in=in_channels if i == 0 else out_channels,
ch_out=out_channels,
filter_size=3,
stride=1,
norm_type=self.norm_type,
norm_name=conv_name + '_norm',
initializer=KaimingNormal(fan_in=fan_conv),
name=conv_name))
mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
else:
for i in range(self.num_convs):
conv_name = 'mask_inter_feat_{}'.format(i + 1)
mask_conv.add_sublayer(
conv_name,
nn.Conv2D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
weight_attr=paddle.ParamAttr(
initializer=KaimingNormal(fan_in=fan_conv))))
mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
mask_conv.add_sublayer(
'conv5_mask',
nn.Conv2DTranspose(
......
......@@ -117,11 +117,15 @@ class ConvNormLayer(nn.Layer):
filter_size,
stride,
norm_type='bn',
norm_decay=0.,
norm_groups=32,
use_dcn=False,
norm_name=None,
bias_on=False,
lr_scale=1.,
freeze_norm=False,
initializer=Normal(
mean=0., std=0.01),
name=None):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn']
......@@ -144,8 +148,7 @@ class ConvNormLayer(nn.Layer):
groups=1,
weight_attr=ParamAttr(
name=name + "_weight",
initializer=Normal(
mean=0., std=0.01),
initializer=initializer,
learning_rate=1.),
bias_attr=bias_attr)
else:
......@@ -159,25 +162,28 @@ class ConvNormLayer(nn.Layer):
groups=1,
weight_attr=ParamAttr(
name=name + "_weight",
initializer=Normal(
mean=0., std=0.01),
initializer=initializer,
learning_rate=1.),
bias_attr=True,
lr_scale=2.,
regularizer=L2Decay(0.),
regularizer=L2Decay(norm_decay),
name=name)
norm_lr = 0. if freeze_norm else 1.
param_attr = ParamAttr(
name=norm_name + "_scale",
learning_rate=1.,
regularizer=L2Decay(0.))
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
bias_attr = ParamAttr(
name=norm_name + "_offset",
learning_rate=1.,
regularizer=L2Decay(0.))
if norm_type in ['bn', 'sync_bn']:
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
if norm_type == 'bn':
self.norm = nn.BatchNorm2D(
ch_out, weight_attr=param_attr, bias_attr=bias_attr)
elif norm_type == 'sync_bn':
self.norm = nn.SyncBatchNorm(
ch_out, weight_attr=param_attr, bias_attr=bias_attr)
elif norm_type == 'gn':
self.norm = nn.GroupNorm(
num_groups=norm_groups,
......
......@@ -14,13 +14,13 @@
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Layer
from paddle.nn import Conv2D
from paddle.nn.initializer import XavierUniform
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec
__all__ = ['FPN']
......@@ -28,7 +28,7 @@ __all__ = ['FPN']
@register
@serializable
class FPN(Layer):
class FPN(nn.Layer):
def __init__(self,
in_channels,
out_channel,
......@@ -36,8 +36,10 @@ class FPN(Layer):
has_extra_convs=False,
extra_stage=1,
use_c5=True,
norm_type=None,
norm_decay=0.,
freeze_norm=False,
relu_before_extra_convs=True):
super(FPN, self).__init__()
self.out_channel = out_channel
for s in range(extra_stage):
......@@ -47,6 +49,9 @@ class FPN(Layer):
self.extra_stage = extra_stage
self.use_c5 = use_c5
self.relu_before_extra_convs = relu_before_extra_convs
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.lateral_convs = []
self.fpn_convs = []
......@@ -62,26 +67,56 @@ class FPN(Layer):
else:
lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2)
in_c = in_channels[i - st_stage]
lateral = self.add_sublayer(
lateral_name,
Conv2D(
in_channels=in_c,
out_channels=out_channel,
kernel_size=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=in_c))))
if self.norm_type == 'gn':
lateral = self.add_sublayer(
lateral_name,
ConvNormLayer(
ch_in=in_c,
ch_out=out_channel,
filter_size=1,
stride=1,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
norm_name=lateral_name + '_norm',
freeze_norm=self.freeze_norm,
initializer=XavierUniform(fan_out=in_c),
name=lateral_name))
else:
lateral = self.add_sublayer(
lateral_name,
nn.Conv2D(
in_channels=in_c,
out_channels=out_channel,
kernel_size=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=in_c))))
self.lateral_convs.append(lateral)
fpn_name = 'fpn_res{}_sum'.format(i + 2)
fpn_conv = self.add_sublayer(
fpn_name,
Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan))))
if self.norm_type == 'gn':
fpn_conv = self.add_sublayer(
fpn_name,
ConvNormLayer(
ch_in=out_channel,
ch_out=out_channel,
filter_size=3,
stride=1,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
norm_name=fpn_name + '_norm',
freeze_norm=self.freeze_norm,
initializer=XavierUniform(fan_out=fan),
name=fpn_name))
else:
fpn_conv = self.add_sublayer(
fpn_name,
nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan))))
self.fpn_convs.append(fpn_conv)
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
......@@ -93,16 +128,31 @@ class FPN(Layer):
else:
in_c = out_channel
extra_fpn_name = 'fpn_{}'.format(lvl + 2)
extra_fpn_conv = self.add_sublayer(
extra_fpn_name,
Conv2D(
in_channels=in_c,
out_channels=out_channel,
kernel_size=3,
stride=2,
padding=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan))))
if self.norm_type == 'gn':
extra_fpn_conv = self.add_sublayer(
extra_fpn_name,
ConvNormLayer(
ch_in=in_c,
ch_out=out_channel,
filter_size=3,
stride=2,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
norm_name=extra_fpn_name + '_norm',
freeze_norm=self.freeze_norm,
initializer=XavierUniform(fan_out=fan),
name=extra_fpn_name))
else:
extra_fpn_conv = self.add_sublayer(
extra_fpn_name,
nn.Conv2D(
in_channels=in_c,
out_channels=out_channel,
kernel_size=3,
stride=2,
padding=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan))))
self.fpn_convs.append(extra_fpn_conv)
@classmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册