未验证 提交 6e018b8e 编写于 作者: F Feng Ni 提交者: GitHub

[Dygraph] add FCOS-DCN (#2066)

* add fcos dcn mstrain doc and config

* add dcn on resnet and head

* clean code
上级 e333d629
......@@ -13,6 +13,8 @@ FCOS (Fully Convolutional One-Stage Object Detection) is a fast anchor-free obje
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| ResNet50-FPN | FCOS | 2 | 1x | ---- | 39.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/fcos_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/fcos/fcos_r50_fpn_1x_coco.yml) |
| ResNet50-FPN | FCOS+DCN | 2 | 1x | ---- | 44.3 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/fcos_dcn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/fcos/fcos_dcn_r50_fpn_1x_coco.yml) |
| ResNet50-FPN | FCOS+multiscale_train | 2 | 2x | ---- | 42.0 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/fcos_r50_fpn_multiscale_2x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/fcos/fcos_r50_fpn_multiscale_2x_coco.yml) |
**Notes:**
......
# FCOS config that turns on deformable convolution (DCN) in both the ResNet
# backbone and the FCOS head; everything not overridden here comes from the
# files listed in _BASE_.
# NOTE(review): this listing has no indentation, so the nesting is not visible.
# Presumably `depth` .. `dcn_v2_stages` belong to `ResNet`, `name` .. `use_dcn`
# belong to `fcos_feat`, and `fcos_feat` .. `centerness_on_reg` belong to
# `FCOSHead` -- confirm against the original file.
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/fcos_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/fcos_reader.yml',
]
# Presumably the path prefix of the final trained weights -- verify against
# the training/eval tooling.
weights: output/fcos_dcn_r50_fpn_1x_coco/model_final
ResNet:
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
# Stage indices that use DeformableConvV2: ResNet passes
# dcn_v2=(i in dcn_v2_stages) when building stage i, and asserts
# max(dcn_v2_stages) < num_stages.
dcn_v2_stages: [1,2,3]
FCOSHead:
fcos_feat:
name: FCOSFeat
feat_in: 256
feat_out: 256
num_convs: 4
norm_type: "gn"
# Presumably forwarded to the head's ConvNormLayer, which builds a
# DeformableConvV2 instead of nn.Conv2D when use_dcn is set -- confirm in
# FCOSFeat.
use_dcn: true
num_classes: 80
fpn_stride: [8, 16, 32, 64, 128]
prior_prob: 0.01
fcos_loss: FCOSLoss
norm_reg_targets: true
centerness_on_reg: true
# FCOS ResNet50-FPN config with several training target sizes and a 24-epoch
# schedule; everything not overridden here comes from the files in _BASE_.
# NOTE(review): this listing has no indentation, so the nesting is not visible.
# Presumably `sample_transforms` .. `drop_last` belong to `TrainReader` and
# `base_lr` / `schedulers` belong to `LearningRate` -- confirm against the
# original file.
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/fcos_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/fcos_reader.yml',
]
# Presumably the path prefix of the final trained weights -- verify against
# the training/eval tooling.
weights: output/fcos_r50_fpn_multiscale_2x_coco/model_final
TrainReader:
sample_transforms:
- DecodeOp: {}
- RandomFlipOp: {prob: 0.5}
- NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
# NOTE(review): this is the only transform here without the `Op` suffix used
# by its siblings -- confirm the registered operator name.
- ResizeImage:
# Several candidate sizes are listed; presumably one is chosen per sample for
# multi-scale training -- confirm in the resize operator.
target_size: [640, 672, 704, 736, 768, 800]
max_size: 1333
interp: 1
use_cv2: true
- PermuteOp: {}
batch_transforms:
- PadBatchOp: {pad_to_stride: 128}
- Gt2FCOSTarget:
object_sizes_boundary: [64, 128, 256, 512]
center_sampling_radius: 1.5
# Same values as FCOSHead.fpn_stride in the DCN config ([8, 16, 32, 64, 128]);
# presumably these must stay in sync with the head's strides -- confirm.
downsample_ratios: [8, 16, 32, 64, 128]
norm_reg_targets: True
batch_size: 2
shuffle: true
drop_last: true
# 24 epochs, i.e. the "2x" schedule named in the weights path above.
epoch: 24
LearningRate:
base_lr: 0.01
schedulers:
# Presumably multiplies the learning rate by `gamma` at each listed epoch --
# confirm in the PiecewiseDecay scheduler.
- !PiecewiseDecay
gamma: 0.1
milestones: [16, 22]
# Presumably ramps the learning rate from start_factor * base_lr over the
# first `steps` iterations -- confirm in the LinearWarmup scheduler.
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
......@@ -25,6 +25,7 @@ from ppdet.core.workspace import register, serializable
from paddle.regularizer import L2Decay
from .name_adapter import NameAdapter
from numbers import Integral
from ppdet.modeling.layers import DeformableConvV2
__all__ = ['ResNet', 'Res5Head']
......@@ -41,22 +42,36 @@ class ConvNormLayer(nn.Layer):
norm_decay=0.,
freeze_norm=True,
lr=1.0,
dcn_v2=False,
name=None):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn']
self.norm_type = norm_type
self.act = act
self.conv = Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
weight_attr=ParamAttr(
learning_rate=lr, name=name + "_weights"),
bias_attr=False)
if not dcn_v2:
self.conv = Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
weight_attr=ParamAttr(
learning_rate=lr, name=name + "_weights"),
bias_attr=False)
else:
self.conv = DeformableConvV2(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
weight_attr=ParamAttr(
learning_rate=lr, name=name + '_weights'),
bias_attr=False,
name=name)
bn_name = name_adapter.fix_conv_norm_name(name)
norm_lr = 0. if freeze_norm else lr
......@@ -105,7 +120,8 @@ class BottleNeck(nn.Layer):
lr=1.0,
norm_type='bn',
norm_decay=0.,
freeze_norm=True):
freeze_norm=True,
dcn_v2=False):
super(BottleNeck, self).__init__()
if variant == 'a':
stride1, stride2 = stride, 1
......@@ -153,6 +169,7 @@ class BottleNeck(nn.Layer):
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr,
dcn_v2=dcn_v2,
name=conv_name2)
self.branch2c = ConvNormLayer(
......@@ -193,7 +210,8 @@ class Blocks(nn.Layer):
lr=1.0,
norm_type='bn',
norm_decay=0.,
freeze_norm=True):
freeze_norm=True,
dcn_v2=False):
super(Blocks, self).__init__()
self.blocks = []
......@@ -213,7 +231,8 @@ class Blocks(nn.Layer):
lr=lr,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm))
freeze_norm=freeze_norm,
dcn_v2=dcn_v2))
self.blocks.append(block)
def forward(self, inputs):
......@@ -238,6 +257,7 @@ class ResNet(nn.Layer):
freeze_norm=True,
freeze_at=0,
return_idx=[0, 1, 2, 3],
dcn_v2_stages=[-1],
num_stages=4):
super(ResNet, self).__init__()
self.depth = depth
......@@ -255,6 +275,11 @@ class ResNet(nn.Layer):
self.return_idx = return_idx
self.num_stages = num_stages
if isinstance(dcn_v2_stages, Integral):
dcn_v2_stages = [dcn_v2_stages]
assert max(dcn_v2_stages) < num_stages
self.dcn_v2_stages = dcn_v2_stages
block_nums = ResNet_cfg[depth]
na = NameAdapter(self)
......@@ -304,7 +329,8 @@ class ResNet(nn.Layer):
lr=lr_mult,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm))
freeze_norm=freeze_norm,
dcn_v2=(i in self.dcn_v2_stages)))
self.res_layers.append(res_layer)
def forward(self, inputs):
......
......@@ -30,6 +30,7 @@ from ppdet.core.workspace import register, serializable
from ppdet.py_op.target import generate_rpn_anchor_target, generate_proposal_target, generate_mask_target
from ppdet.py_op.post_process import bbox_post_process
from . import ops
from paddle.vision.ops import DeformConv2D
def _to_list(l):
......@@ -38,6 +39,77 @@ def _to_list(l):
return [l]
class DeformableConvV2(nn.Layer):
    """Deformable convolution that predicts its own offsets and mask.

    A regular ``nn.Conv2D`` (``conv_offset``) produces ``3 * kernel_size**2``
    channels from the input. They are split into ``2 * kernel_size**2`` offset
    channels and ``kernel_size**2`` mask channels; the mask is passed through
    a sigmoid and both tensors are handed to ``DeformConv2D`` (``conv_dcn``).

    Args:
        in_channels: Number of input channels for both convolutions.
        out_channels: Number of output channels of the deformable conv.
        kernel_size: Kernel size. It is squared to derive the offset/mask
            channel counts, so a single integer is expected.
        stride: Stride applied to both convolutions.
        padding: Accepted for signature compatibility but not used. The
            offset conv always pads by ``(kernel_size - 1) // 2`` and the
            deformable conv by ``(kernel_size - 1) // 2 * dilation``.
        dilation: Dilation of the deformable conv only.
        groups: Groups of the deformable conv only.
        weight_attr: Forwarded unchanged to ``DeformConv2D``.
        bias_attr: Only its truthiness is used. Truthy: the deformable conv
            gets a zero-initialised bias with ``L2Decay(0.)`` and learning
            rate ``2.``. Falsy: the deformable conv has no bias.
        lr_scale: Learning rate of the offset conv's bias.
        regularizer: Regularizer of the offset conv's bias.
        name: Optional prefix for explicit parameter names
            (``<name>._conv_offset.weight``, ``<name>._conv_offset.bias``,
            ``<name>_bias``). When ``None`` no explicit names are set.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 lr_scale=1,
                 regularizer=None,
                 name=None):
        super(DeformableConvV2, self).__init__()
        self.offset_channel = 2 * kernel_size**2
        self.mask_channel = kernel_size**2

        def _scoped(suffix):
            # With the default name=None, building names by concatenation
            # would raise TypeError, and formatting would give every unnamed
            # layer the same literal "None..." parameter name. Leave the name
            # unset instead; explicit names are built exactly as before.
            if name is None:
                return None
            return '{}{}'.format(name, suffix)

        # learning_rate=1 and regularizer=None are ParamAttr's defaults, so a
        # single construction covers the "no scaling" case as well.
        offset_bias_attr = ParamAttr(
            initializer=Constant(0.),
            learning_rate=lr_scale,
            regularizer=regularizer,
            name=_scoped('._conv_offset.bias'))

        # Zero-initialised weight and bias make the predicted offsets start
        # at 0 and the mask start at sigmoid(0).
        self.conv_offset = nn.Conv2D(
            in_channels,
            # Equals 3 * kernel_size**2: offsets plus mask in one tensor.
            self.offset_channel + self.mask_channel,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            weight_attr=ParamAttr(
                initializer=Constant(0.0),
                name=_scoped('._conv_offset.weight')),
            bias_attr=offset_bias_attr)

        if bias_attr:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            dcn_bias_attr = ParamAttr(
                name=_scoped('_bias'),
                initializer=Constant(value=0),
                regularizer=L2Decay(0.),
                learning_rate=2.)
        else:
            # in ResNet backbone, do not need bias
            dcn_bias_attr = False

        self.conv_dcn = DeformConv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 * dilation,
            dilation=dilation,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=dcn_bias_attr)

    def forward(self, x):
        """Predict offsets and mask from ``x``, then apply the deformable conv."""
        offset_mask = self.conv_offset(x)
        # Channel axis layout: offsets first, then the mask.
        offset, mask = paddle.split(
            offset_mask,
            num_or_sections=[self.offset_channel, self.mask_channel],
            axis=1)
        mask = F.sigmoid(mask)
        return self.conv_dcn(x, offset, mask=mask)
class ConvNormLayer(nn.Layer):
def __init__(self,
ch_in,
......@@ -62,19 +134,38 @@ class ConvNormLayer(nn.Layer):
else:
bias_attr = False
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
weight_attr=ParamAttr(
name=name + "_weight",
initializer=Normal(
mean=0., std=0.01),
learning_rate=1.),
bias_attr=bias_attr)
if not use_dcn:
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
weight_attr=ParamAttr(
name=name + "_weight",
initializer=Normal(
mean=0., std=0.01),
learning_rate=1.),
bias_attr=bias_attr)
else:
# in FCOS-DCN head, specifically need learning_rate and regularizer
self.conv = DeformableConvV2(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
weight_attr=ParamAttr(
name=name + "_weight",
initializer=Normal(
mean=0., std=0.01),
learning_rate=1.),
bias_attr=True,
lr_scale=2.,
regularizer=L2Decay(0.),
name=name)
param_attr = ParamAttr(
name=norm_name + "_scale",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册