Unverified commit f239d898 authored by wangxinxin08, committed by GitHub

add senet (#2553)

* add senet

* add annotations according to review
Parent 7b8c9eab
@@ -8,7 +8,6 @@ ResNet:
   depth: 101
   groups: 64
   base_width: 4
-  base_channels: 64
   variant: d
   norm_type: bn
   freeze_at: 0
......
@@ -9,7 +9,6 @@ ResNet:
   depth: 101
   groups: 64
   base_width: 4
-  base_channels: 64
   variant: d
   norm_type: bn
   freeze_at: 0
......
@@ -10,7 +10,6 @@ ResNet:
   variant: d
   groups: 64
   base_width: 4
-  base_channels: 64
   norm_type: bn
   freeze_at: 0
   return_idx: [0,1,2,3]
......
@@ -10,7 +10,6 @@ ResNet:
   depth: 101
   groups: 64
   base_width: 4
-  base_channels: 64
   variant: d
   norm_type: bn
   freeze_at: 0
......
@@ -10,7 +10,6 @@ ResNet:
   depth: 101
   groups: 64
   base_width: 4
-  base_channels: 64
   variant: d
   norm_type: bn
   freeze_at: 0
......
@@ -11,7 +11,6 @@ ResNet:
   variant: d
   groups: 64
   base_width: 4
-  base_channels: 64
   norm_type: bn
   freeze_at: 0
   return_idx: [0,1,2,3]
......
@@ -11,7 +11,6 @@ ResNet:
   variant: d
   groups: 64
   base_width: 4
-  base_channels: 64
   norm_type: bn
   freeze_at: 0
   return_idx: [0,1,2,3]
......
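Every config hunk above makes the same one-line change: `base_channels: 64` is dropped from the ResNeXt 64x4d configs. The refactored `BottleNeck` below folds the constant 64 into its width formula, so the key becomes redundant. A standalone sanity check (hypothetical values, not part of the commit) that the old and new formulas agree whenever base_channels == 64:

    import math

    groups, base_width, base_channels = 64, 4, 64
    for ch_out in [64, 128, 256, 512]:  # per-stage bottleneck channels
        old_width = int(math.floor(ch_out * (base_width * 1.0 / base_channels)) * groups)
        new_width = int(ch_out * (base_width / 64.)) * groups
        assert old_width == new_width   # e.g. ch_out=64 -> width 256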
@@ -20,6 +20,7 @@ from . import mobilenet_v3
 from . import hrnet
 from . import blazenet
 from . import ghostnet
+from . import senet

 from .vgg import *
 from .resnet import *
@@ -29,3 +30,4 @@ from .mobilenet_v3 import *
 from .hrnet import *
 from .blazenet import *
 from .ghostnet import *
+from .senet import *
@@ -20,11 +20,12 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from ppdet.core.workspace import register, serializable
 from paddle.regularizer import L2Decay
+from paddle.nn.initializer import Uniform
 from ppdet.modeling.layers import DeformableConvV2
 from .name_adapter import NameAdapter
 from ..shape_spec import ShapeSpec

-__all__ = ['ResNet', 'Res5Head']
+__all__ = ['ResNet', 'Res5Head', 'Blocks', 'BasicBlock', 'BottleNeck']

 ResNet_cfg = {
     18: [2, 2, 2, 2],
@@ -41,15 +42,13 @@ class ConvNormLayer(nn.Layer):
                  ch_out,
                  filter_size,
                  stride,
-                 name_adapter,
                  groups=1,
                  act=None,
                  norm_type='bn',
                  norm_decay=0.,
                  freeze_norm=True,
                  lr=1.0,
-                 dcn_v2=False,
-                 name=None):
+                 dcn_v2=False):
         super(ConvNormLayer, self).__init__()
         assert norm_type in ['bn', 'sync_bn']
         self.norm_type = norm_type
@@ -63,8 +62,7 @@ class ConvNormLayer(nn.Layer):
                 stride=stride,
                 padding=(filter_size - 1) // 2,
                 groups=groups,
-                weight_attr=paddle.ParamAttr(
-                    learning_rate=lr, ),
+                weight_attr=paddle.ParamAttr(learning_rate=lr),
                 bias_attr=False)
         else:
             self.conv = DeformableConvV2(
@@ -74,12 +72,9 @@ class ConvNormLayer(nn.Layer):
                 stride=stride,
                 padding=(filter_size - 1) // 2,
                 groups=groups,
-                weight_attr=paddle.ParamAttr(
-                    learning_rate=lr, ),
-                bias_attr=False,
-                name=name)
+                weight_attr=paddle.ParamAttr(learning_rate=lr),
+                bias_attr=False)

-        bn_name = name_adapter.fix_conv_norm_name(name)
         norm_lr = 0. if freeze_norm else lr
         param_attr = paddle.ParamAttr(
             learning_rate=norm_lr,
@@ -116,24 +111,58 @@ class ConvNormLayer(nn.Layer):
         return out

+class SELayer(nn.Layer):
+    def __init__(self, ch, reduction_ratio=16):
+        super(SELayer, self).__init__()
+        self.pool = nn.AdaptiveAvgPool2D(1)
+        stdv = 1.0 / math.sqrt(ch)
+        c_ = ch // reduction_ratio
+        self.squeeze = nn.Linear(
+            ch,
+            c_,
+            weight_attr=paddle.ParamAttr(initializer=Uniform(-stdv, stdv)),
+            bias_attr=True)
+
+        stdv = 1.0 / math.sqrt(c_)
+        self.extract = nn.Linear(
+            c_,
+            ch,
+            weight_attr=paddle.ParamAttr(initializer=Uniform(-stdv, stdv)),
+            bias_attr=True)
+
+    def forward(self, inputs):
+        out = self.pool(inputs)
+        out = paddle.squeeze(out, axis=[2, 3])
+        out = self.squeeze(out)
+        out = F.relu(out)
+        out = self.extract(out)
+        out = F.sigmoid(out)
+        out = paddle.unsqueeze(out, axis=[2, 3])
+        scale = out * inputs
+        return scale
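Reviewer note on the new class: `SELayer` is the squeeze-and-excitation gate from https://arxiv.org/abs/1709.01507 — global average pooling down to [N, C], an FC bottleneck C -> C/reduction_ratio -> C with ReLU and sigmoid, then a per-channel rescaling of the input. A minimal shape check (a sketch assuming the class above is in scope; not part of the diff):

    import paddle

    se = SELayer(ch=256, reduction_ratio=16)
    x = paddle.rand([2, 256, 14, 14])   # [N, C, H, W] feature map
    y = se(x)
    assert y.shape == x.shape           # SE only rescales channels; shapes are unchanged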

 class BasicBlock(nn.Layer):
+    expansion = 1
+
     def __init__(self,
                  ch_in,
                  ch_out,
                  stride,
                  shortcut,
-                 name_adapter,
-                 name,
                  variant='b',
+                 groups=1,
+                 base_width=64,
                  lr=1.0,
                  norm_type='bn',
                  norm_decay=0.,
                  freeze_norm=True,
-                 dcn_v2=False):
+                 dcn_v2=False,
+                 std_senet=False):
         super(BasicBlock, self).__init__()
-        assert dcn_v2 is False, "Not implemented yet."
-        conv_name1, conv_name2, shortcut_name = name_adapter.fix_basicblock_name(
-            name)
+        assert groups == 1 and base_width == 64, 'BasicBlock only supports groups=1 and base_width=64'

         self.shortcut = shortcut
         if not shortcut:
@@ -150,54 +179,52 @@ class BasicBlock(nn.Layer):
                         ch_out=ch_out,
                         filter_size=1,
                         stride=1,
-                        name_adapter=name_adapter,
                         norm_type=norm_type,
                         norm_decay=norm_decay,
                         freeze_norm=freeze_norm,
-                        lr=lr,
-                        name=shortcut_name))
+                        lr=lr))
         else:
             self.short = ConvNormLayer(
                 ch_in=ch_in,
                 ch_out=ch_out,
                 filter_size=1,
                 stride=stride,
-                name_adapter=name_adapter,
                 norm_type=norm_type,
                 norm_decay=norm_decay,
                 freeze_norm=freeze_norm,
-                lr=lr,
-                name=shortcut_name)
+                lr=lr)

         self.branch2a = ConvNormLayer(
             ch_in=ch_in,
             ch_out=ch_out,
             filter_size=3,
             stride=stride,
-            name_adapter=name_adapter,
             act='relu',
             norm_type=norm_type,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
-            lr=lr,
-            name=conv_name1)
+            lr=lr)

         self.branch2b = ConvNormLayer(
             ch_in=ch_out,
             ch_out=ch_out,
             filter_size=3,
             stride=1,
-            name_adapter=name_adapter,
             act=None,
             norm_type=norm_type,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
-            lr=lr,
-            name=conv_name2)
+            lr=lr)
+
+        self.std_senet = std_senet
+        if self.std_senet:
+            self.se = SELayer(ch_out)

     def forward(self, inputs):
         out = self.branch2a(inputs)
         out = self.branch2b(out)
+        if self.std_senet:
+            out = self.se(out)

         if self.shortcut:
             short = inputs
@@ -211,22 +238,23 @@ class BasicBlock(nn.Layer):
 class BottleNeck(nn.Layer):
+    expansion = 4
+
     def __init__(self,
                  ch_in,
                  ch_out,
                  stride,
                  shortcut,
-                 name_adapter,
-                 name,
                  variant='b',
                  groups=1,
                  base_width=4,
-                 base_channels=64,
                  lr=1.0,
                  norm_type='bn',
                  norm_decay=0.,
                  freeze_norm=True,
-                 dcn_v2=False):
+                 dcn_v2=False,
+                 std_senet=False):
         super(BottleNeck, self).__init__()
         if variant == 'a':
             stride1, stride2 = stride, 1
@@ -234,15 +262,7 @@ class BottleNeck(nn.Layer):
             stride1, stride2 = 1, stride

         # ResNeXt
-        if groups == 1:
-            width = ch_out
-        else:
-            width = int(
-                math.floor(ch_out * (base_width * 1.0 / base_channels)) *
-                groups)
-
-        conv_name1, conv_name2, conv_name3, \
-            shortcut_name = name_adapter.fix_bottleneck_name(name)
+        width = int(ch_out * (base_width / 64.)) * groups

         self.shortcut = shortcut
         if not shortcut:
@@ -256,75 +276,73 @@ class BottleNeck(nn.Layer):
                     'conv',
                     ConvNormLayer(
                         ch_in=ch_in,
-                        ch_out=ch_out * 4,
+                        ch_out=ch_out * self.expansion,
                         filter_size=1,
                         stride=1,
-                        name_adapter=name_adapter,
                         norm_type=norm_type,
                         norm_decay=norm_decay,
                         freeze_norm=freeze_norm,
-                        lr=lr,
-                        name=shortcut_name))
+                        lr=lr))
         else:
             self.short = ConvNormLayer(
                 ch_in=ch_in,
-                ch_out=ch_out * 4,
+                ch_out=ch_out * self.expansion,
                 filter_size=1,
                 stride=stride,
-                name_adapter=name_adapter,
                 norm_type=norm_type,
                 norm_decay=norm_decay,
                 freeze_norm=freeze_norm,
-                lr=lr,
-                name=shortcut_name)
+                lr=lr)

         self.branch2a = ConvNormLayer(
             ch_in=ch_in,
             ch_out=width,
             filter_size=1,
             stride=stride1,
-            name_adapter=name_adapter,
             groups=1,
             act='relu',
             norm_type=norm_type,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
-            lr=lr,
-            name=conv_name1)
+            lr=lr)

         self.branch2b = ConvNormLayer(
             ch_in=width,
             ch_out=width,
             filter_size=3,
             stride=stride2,
-            name_adapter=name_adapter,
             groups=groups,
             act='relu',
             norm_type=norm_type,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
             lr=lr,
-            dcn_v2=dcn_v2,
-            name=conv_name2)
+            dcn_v2=dcn_v2)

         self.branch2c = ConvNormLayer(
             ch_in=width,
-            ch_out=ch_out * 4,
+            ch_out=ch_out * self.expansion,
             filter_size=1,
             stride=1,
-            name_adapter=name_adapter,
             groups=1,
             norm_type=norm_type,
             norm_decay=norm_decay,
             freeze_norm=freeze_norm,
-            lr=lr,
-            name=conv_name3)
+            lr=lr)
+
+        self.std_senet = std_senet
+        if self.std_senet:
+            self.se = SELayer(ch_out * self.expansion)

     def forward(self, inputs):
         out = self.branch2a(inputs)
         out = self.branch2b(out)
         out = self.branch2c(out)
+        if self.std_senet:
+            out = self.se(out)

         if self.shortcut:
             short = inputs
         else:
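The gate is applied to the `branch2c` output before the residual addition, the standard SE-ResNet placement. A condensed restatement of the forward pass above (a sketch; the add-and-ReLU tail is unchanged code outside this hunk):

    def bottleneck_forward(block, inputs):
        out = block.branch2c(block.branch2b(block.branch2a(inputs)))
        if block.std_senet:
            out = block.se(out)  # channel gate before the residual add
        short = inputs if block.shortcut else block.short(inputs)
        return F.relu(paddle.add(x=out, y=short))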
@@ -338,7 +356,7 @@

 class Blocks(nn.Layer):
     def __init__(self,
-                 depth,
+                 block,
                  ch_in,
                  ch_out,
                  count,
@@ -346,55 +364,37 @@ class Blocks(nn.Layer):
                  stage_num,
                  variant='b',
                  groups=1,
-                 base_width=-1,
-                 base_channels=-1,
+                 base_width=64,
                  lr=1.0,
                  norm_type='bn',
                  norm_decay=0.,
                  freeze_norm=True,
-                 dcn_v2=False):
+                 dcn_v2=False,
+                 std_senet=False):
         super(Blocks, self).__init__()

         self.blocks = []
         for i in range(count):
             conv_name = name_adapter.fix_layer_warp_name(stage_num, count, i)
-            if depth >= 50:
-                block = self.add_sublayer(
-                    conv_name,
-                    BottleNeck(
-                        ch_in=ch_in if i == 0 else ch_out * 4,
-                        ch_out=ch_out,
-                        stride=2 if i == 0 and stage_num != 2 else 1,
-                        shortcut=False if i == 0 else True,
-                        name_adapter=name_adapter,
-                        name=conv_name,
-                        variant=variant,
-                        groups=groups,
-                        base_width=base_width,
-                        base_channels=base_channels,
-                        lr=lr,
-                        norm_type=norm_type,
-                        norm_decay=norm_decay,
-                        freeze_norm=freeze_norm,
-                        dcn_v2=dcn_v2))
-            else:
-                ch_in = ch_in // 4 if i > 0 else ch_in
-                block = self.add_sublayer(
-                    conv_name,
-                    BasicBlock(
-                        ch_in=ch_in if i == 0 else ch_out,
-                        ch_out=ch_out,
-                        stride=2 if i == 0 and stage_num != 2 else 1,
-                        shortcut=False if i == 0 else True,
-                        name_adapter=name_adapter,
-                        name=conv_name,
-                        variant=variant,
-                        lr=lr,
-                        norm_type=norm_type,
-                        norm_decay=norm_decay,
-                        freeze_norm=freeze_norm,
-                        dcn_v2=dcn_v2))
-            self.blocks.append(block)
+            layer = self.add_sublayer(
+                conv_name,
+                block(
+                    ch_in=ch_in,
+                    ch_out=ch_out,
+                    stride=2 if i == 0 and stage_num != 2 else 1,
+                    shortcut=False if i == 0 else True,
+                    variant=variant,
+                    groups=groups,
+                    base_width=base_width,
+                    lr=lr,
+                    norm_type=norm_type,
+                    norm_decay=norm_decay,
+                    freeze_norm=freeze_norm,
+                    dcn_v2=dcn_v2,
+                    std_senet=std_senet))
+            self.blocks.append(layer)
+            if i == 0:
+                ch_in = ch_out * block.expansion

     def forward(self, inputs):
         block_out = inputs
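`Blocks` no longer branches on `depth >= 50`; the caller passes the block class itself, and channel bookkeeping rides on the class attribute `expansion` (1 for `BasicBlock`, 4 for `BottleNeck`). The resulting per-stage channel math, as a standalone sketch:

    def stage_channels(expansion, ch_in=64, ch_out_list=(64, 128, 256, 512)):
        stages = []
        for ch_out in ch_out_list:
            stages.append((ch_in, ch_out * expansion))  # (first-block ch_in, stage output)
            ch_in = ch_out * expansion                  # updated after the first block
        return stages

    print(stage_channels(4))  # BottleNeck: [(64, 256), (256, 512), (512, 1024), (1024, 2048)]
    print(stage_channels(1))  # BasicBlock: [(64, 64), (64, 128), (128, 256), (256, 512)]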
@@ -410,23 +410,47 @@ class ResNet(nn.Layer):
     def __init__(self,
                  depth=50,
+                 ch_in=64,
                  variant='b',
                  lr_mult_list=[1.0, 1.0, 1.0, 1.0],
                  groups=1,
-                 base_width=-1,
-                 base_channels=-1,
+                 base_width=64,
                  norm_type='bn',
                  norm_decay=0,
                  freeze_norm=True,
                  freeze_at=0,
                  return_idx=[0, 1, 2, 3],
                  dcn_v2_stages=[-1],
-                 num_stages=4):
+                 num_stages=4,
+                 std_senet=False):
+        """
+        Residual Network, see https://arxiv.org/abs/1512.03385
+
+        Args:
+            depth (int): ResNet depth, should be 18, 34, 50, 101 or 152
+            ch_in (int): output channel of the first stage, default 64
+            variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
+            lr_mult_list (list): learning rate ratio of different resnet stages (2, 3, 4, 5);
+                a lower learning rate ratio is needed for pretrained models
+                obtained by distillation (default [1.0, 1.0, 1.0, 1.0])
+            groups (int): group convolution cardinality
+            base_width (int): base width of each group convolution
+            norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
+            norm_decay (float): weight decay for normalization layer weights
+            freeze_norm (bool): whether to freeze normalization layers
+            freeze_at (int): stage at which the backbone is frozen
+            return_idx (list): indices of the stages whose feature maps are returned
+            dcn_v2_stages (list): indices of the stages that use deformable conv v2
+            num_stages (int): total number of stages
+            std_senet (bool): whether to use the SE block, default False
+        """
         super(ResNet, self).__init__()
         self._model_type = 'ResNet' if groups == 1 else 'ResNeXt'
         assert num_stages >= 1 and num_stages <= 4
         self.depth = depth
         self.variant = variant
+        self.groups = groups
+        self.base_width = base_width
         self.norm_type = norm_type
         self.norm_decay = norm_decay
         self.freeze_norm = freeze_norm
@@ -456,12 +480,12 @@ class ResNet(nn.Layer):
         conv1_name = na.fix_c1_stage_name()
         if variant in ['c', 'd']:
             conv_def = [
-                [3, 32, 3, 2, "conv1_1"],
-                [32, 32, 3, 1, "conv1_2"],
-                [32, 64, 3, 1, "conv1_3"],
+                [3, ch_in // 2, 3, 2, "conv1_1"],
+                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
+                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
             ]
         else:
-            conv_def = [[3, 64, 7, 2, conv1_name]]
+            conv_def = [[3, ch_in, 7, 2, conv1_name]]
         self.conv1 = nn.Sequential()
         for (c_in, c_out, k, s, _name) in conv_def:
             self.conv1.add_sublayer(
@@ -471,20 +495,18 @@ class ResNet(nn.Layer):
                     ch_out=c_out,
                     filter_size=k,
                     stride=s,
-                    name_adapter=na,
                     groups=1,
                     act='relu',
                     norm_type=norm_type,
                     norm_decay=norm_decay,
                     freeze_norm=freeze_norm,
-                    lr=1.0,
-                    name=_name))
+                    lr=1.0))

-        ch_in_list = [64, 256, 512, 1024]
+        self.ch_in = ch_in
         ch_out_list = [64, 128, 256, 512]
-        self.expansion = 4 if depth >= 50 else 1
+        block = BottleNeck if depth >= 50 else BasicBlock

-        self._out_channels = [self.expansion * v for v in ch_out_list]
+        self._out_channels = [block.expansion * v for v in ch_out_list]
         self._out_strides = [4, 8, 16, 32]

         self.res_layers = []
@@ -495,9 +517,8 @@ class ResNet(nn.Layer):
             res_layer = self.add_sublayer(
                 res_name,
                 Blocks(
-                    depth,
-                    ch_in_list[i] // 4
-                    if i > 0 and depth < 50 else ch_in_list[i],
+                    block,
+                    self.ch_in,
                     ch_out_list[i],
                     count=block_nums[i],
                     name_adapter=na,
@@ -505,13 +526,14 @@ class ResNet(nn.Layer):
                     variant=variant,
                     groups=groups,
                     base_width=base_width,
-                    base_channels=base_channels,
                     lr=lr_mult,
                     norm_type=norm_type,
                     norm_decay=norm_decay,
                     freeze_norm=freeze_norm,
-                    dcn_v2=(i in self.dcn_v2_stages)))
+                    dcn_v2=(i in self.dcn_v2_stages),
+                    std_senet=std_senet))
             self.res_layers.append(res_layer)
+            self.ch_in = self._out_channels[i]

     @property
     def out_shape(self):
......
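The new `ch_in` argument parameterizes the stem width that was previously hard-coded to 64; `SENet` below passes `ch_in=128` to build the wider SE-ResNet deep stem. For variant 'd', the stem definitions work out as follows (a worked restatement of the `conv_def` list above):

    def stem_def(ch_in):
        return [
            [3, ch_in // 2, 3, 2, "conv1_1"],
            [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
            [ch_in // 2, ch_in, 3, 1, "conv1_3"],
        ]

    print(stem_def(64))   # classic ResNet-D stem: 3 -> 32 -> 32 -> 64
    print(stem_def(128))  # SENet deep stem:       3 -> 64 -> 64 -> 128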
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register, serializable
from .resnet import ResNet, Blocks, BasicBlock, BottleNeck
from .name_adapter import NameAdapter
from ..shape_spec import ShapeSpec

__all__ = ['SENet', 'SERes5Head']
@register
@serializable
class SENet(ResNet):
    __shared__ = ['norm_type']

    def __init__(self,
                 depth=50,
                 variant='b',
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0],
                 groups=1,
                 base_width=64,
                 norm_type='bn',
                 norm_decay=0,
                 freeze_norm=True,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 dcn_v2_stages=[-1],
                 std_senet=True,
                 num_stages=4):
        """
        Squeeze-and-Excitation Networks, see https://arxiv.org/abs/1709.01507

        Args:
            depth (int): SENet depth, should be 50, 101 or 152
            variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
            lr_mult_list (list): learning rate ratio of different resnet stages (2, 3, 4, 5);
                a lower learning rate ratio is needed for pretrained models
                obtained by distillation (default [1.0, 1.0, 1.0, 1.0])
            groups (int): group convolution cardinality
            base_width (int): base width of each group convolution
            norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
            norm_decay (float): weight decay for normalization layer weights
            freeze_norm (bool): whether to freeze normalization layers
            freeze_at (int): stage at which the backbone is frozen
            return_idx (list): indices of the stages whose feature maps are returned
            dcn_v2_stages (list): indices of the stages that use deformable conv v2
            std_senet (bool): whether to use the SE block, default True
            num_stages (int): total number of stages
        """
        super(SENet, self).__init__(
            depth=depth,
            variant=variant,
            lr_mult_list=lr_mult_list,
            ch_in=128,
            groups=groups,
            base_width=base_width,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            freeze_at=freeze_at,
            return_idx=return_idx,
            dcn_v2_stages=dcn_v2_stages,
            std_senet=std_senet,
            num_stages=num_stages)
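`SENet` is a thin wrapper that flips `std_senet` on and widens the stem to 128 channels; everything else is inherited from the refactored `ResNet`. It is normally instantiated through the YAML config system, but direct construction illustrates the output contract (a sketch; the expected values follow from `_out_channels` with `BottleNeck.expansion = 4`):

    senet = SENet(depth=50, variant='b', std_senet=True)
    print([s.channels for s in senet.out_shape])  # [256, 512, 1024, 2048]
    print([s.stride for s in senet.out_shape])    # [4, 8, 16, 32]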
@register
class SERes5Head(nn.Layer):
    def __init__(self,
                 depth=50,
                 variant='b',
                 lr_mult=1.0,
                 groups=1,
                 base_width=64,
                 norm_type='bn',
                 norm_decay=0,
                 dcn_v2=False,
                 freeze_norm=False,
                 std_senet=True):
        """
        SERes5Head layer

        Args:
            depth (int): SENet depth, should be 50, 101 or 152
            variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
            lr_mult (float): learning rate ratio of SERes5Head, default 1.0
            groups (int): group convolution cardinality
            base_width (int): base width of each group convolution
            norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
            norm_decay (float): weight decay for normalization layer weights
            dcn_v2 (bool): whether to use deformable conv v2, default False
            freeze_norm (bool): whether to freeze normalization layers, default False
            std_senet (bool): whether to use the SE block, default True
        """
        super(SERes5Head, self).__init__()
        ch_out = 512
        ch_in = 256 if depth < 50 else 1024
        na = NameAdapter(self)
        block = BottleNeck if depth >= 50 else BasicBlock
        self.res5 = Blocks(
            block,
            ch_in,
            ch_out,
            count=3,
            name_adapter=na,
            stage_num=5,
            variant=variant,
            groups=groups,
            base_width=base_width,
            lr=lr_mult,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            dcn_v2=dcn_v2,
            std_senet=std_senet)
        self.ch_out = ch_out * block.expansion

    @property
    def out_shape(self):
        return [ShapeSpec(
            channels=self.ch_out,
            stride=16, )]

    def forward(self, roi_feat):
        y = self.res5(roi_feat)
        return y
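`SERes5Head` packages the stage-5 blocks as an RoI feature head: for depth >= 50 it takes 1024-channel RoI features and returns 512 * BottleNeck.expansion = 2048 channels at stride 16, halving the spatial size via the first block's stride-2 conv. Usage sketch (hypothetical RoI batch, assuming paddle is imported and the classes above are in scope):

    head = SERes5Head(depth=50, std_senet=True)
    roi_feat = paddle.rand([8, 1024, 14, 14])  # 8 RoIs with stage-4 channels
    y = head(roi_feat)                         # [8, 2048, 7, 7]
    print(head.out_shape[0].channels)          # 2048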
@@ -52,31 +52,25 @@ class DeformableConvV2(nn.Layer):
                 bias_attr=None,
                 lr_scale=1,
                 regularizer=None,
-                skip_quant=False,
-                name=None):
+                skip_quant=False):
        super(DeformableConvV2, self).__init__()
        self.offset_channel = 2 * kernel_size**2
        self.mask_channel = kernel_size**2

        if lr_scale == 1 and regularizer is None:
-            offset_bias_attr = ParamAttr(
-                initializer=Constant(0.),
-                name='{}._conv_offset.bias'.format(name))
+            offset_bias_attr = ParamAttr(initializer=Constant(0.))
        else:
            offset_bias_attr = ParamAttr(
                initializer=Constant(0.),
                learning_rate=lr_scale,
-                regularizer=regularizer,
-                name='{}._conv_offset.bias'.format(name))
+                regularizer=regularizer)
        self.conv_offset = nn.Conv2D(
            in_channels,
            3 * kernel_size**2,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
-            weight_attr=ParamAttr(
-                initializer=Constant(0.0),
-                name='{}._conv_offset.weight'.format(name)),
+            weight_attr=ParamAttr(initializer=Constant(0.0)),
            bias_attr=offset_bias_attr)
        if skip_quant:
            self.conv_offset.skip_quant = True
@@ -84,7 +78,6 @@ class DeformableConvV2(nn.Layer):
        if bias_attr:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            dcn_bias_attr = ParamAttr(
-                name=name + "_bias",
                initializer=Constant(value=0),
                regularizer=L2Decay(0.),
                learning_rate=2.)
......
@@ -43,8 +43,7 @@ class Upsample(nn.Layer):
                regularizer=L2Decay(0.),
                learning_rate=2.),
            lr_scale=2.,
-            regularizer=L2Decay(0.),
-            name=name)
+            regularizer=L2Decay(0.))

        self.bn = batch_norm(
            ch_out, norm_type='bn', initializer=Constant(1.), name=name)
......