提交 47c31c15 编写于 作者: S sunyanfang01

add ssld

上级 c5a1dc18
...@@ -21,11 +21,15 @@ ResNet50 = cv.models.ResNet50 ...@@ -21,11 +21,15 @@ ResNet50 = cv.models.ResNet50
ResNet101 = cv.models.ResNet101 ResNet101 = cv.models.ResNet101
ResNet50_vd = cv.models.ResNet50_vd ResNet50_vd = cv.models.ResNet50_vd
ResNet101_vd = cv.models.ResNet101_vd ResNet101_vd = cv.models.ResNet101_vd
ResNet50_vd_ssld = cv.models.ResNet50_vd_ssld
ResNet101_vd_ssld = cv.models.ResNet101_vd_ssld
DarkNet53 = cv.models.DarkNet53 DarkNet53 = cv.models.DarkNet53
MobileNetV1 = cv.models.MobileNetV1 MobileNetV1 = cv.models.MobileNetV1
MobileNetV2 = cv.models.MobileNetV2 MobileNetV2 = cv.models.MobileNetV2
MobileNetV3_small = cv.models.MobileNetV3_small MobileNetV3_small = cv.models.MobileNetV3_small
MobileNetV3_large = cv.models.MobileNetV3_large MobileNetV3_large = cv.models.MobileNetV3_large
MobileNetV3_small_ssld = cv.models.MobileNetV3_small_ssld
MobileNetV3_large_ssld = cv.models.MobileNetV3_large_ssld
Xception41 = cv.models.Xception41 Xception41 = cv.models.Xception41
Xception65 = cv.models.Xception65 Xception65 = cv.models.Xception65
DenseNet121 = cv.models.DenseNet121 DenseNet121 = cv.models.DenseNet121
......
...@@ -19,11 +19,15 @@ from .classifier import ResNet50 ...@@ -19,11 +19,15 @@ from .classifier import ResNet50
from .classifier import ResNet101 from .classifier import ResNet101
from .classifier import ResNet50_vd from .classifier import ResNet50_vd
from .classifier import ResNet101_vd from .classifier import ResNet101_vd
from .classifier import ResNet50_vd_ssld
from .classifier import ResNet101_vd_ssld
from .classifier import DarkNet53 from .classifier import DarkNet53
from .classifier import MobileNetV1 from .classifier import MobileNetV1
from .classifier import MobileNetV2 from .classifier import MobileNetV2
from .classifier import MobileNetV3_small from .classifier import MobileNetV3_small
from .classifier import MobileNetV3_large from .classifier import MobileNetV3_large
from .classifier import MobileNetV3_small_ssld
from .classifier import MobileNetV3_large_ssld
from .classifier import Xception41 from .classifier import Xception41
from .classifier import Xception65 from .classifier import Xception65
from .classifier import DenseNet121 from .classifier import DenseNet121
......
...@@ -302,6 +302,17 @@ class ResNet101_vd(BaseClassifier): ...@@ -302,6 +302,17 @@ class ResNet101_vd(BaseClassifier):
model_name='ResNet101_vd', num_classes=num_classes) model_name='ResNet101_vd', num_classes=num_classes)
class ResNet50_vd_ssld(BaseClassifier):
def __init__(self, num_classes=1000):
super(ResNet50_vd_ssld, self).__init__(model_name='ResNet50_vd_ssld',
num_classes=num_classes)
class ResNet101_vd_ssld(BaseClassifier):
def __init__(self, num_classes=1000):
super(ResNet101_vd_ssld, self).__init__(model_name='ResNet101_vd_ssld',
num_classes=num_classes)
class DarkNet53(BaseClassifier): class DarkNet53(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(DarkNet53, self).__init__( super(DarkNet53, self).__init__(
...@@ -332,6 +343,19 @@ class MobileNetV3_large(BaseClassifier): ...@@ -332,6 +343,19 @@ class MobileNetV3_large(BaseClassifier):
model_name='MobileNetV3_large', num_classes=num_classes) model_name='MobileNetV3_large', num_classes=num_classes)
class MobileNetV3_small_ssld(BaseClassifier):
def __init__(self, num_classes=1000):
super(MobileNetV3_small_ssld, self).__init__(model_name='MobileNetV3_small_ssld',
num_classes=num_classes)
class MobileNetV3_large_ssld(BaseClassifier):
def __init__(self, num_classes=1000):
super(MobileNetV3_large_ssld, self).__init__(model_name='MobileNetV3_large_ssld',
num_classes=num_classes)
class Xception65(BaseClassifier): class Xception65(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(Xception65, self).__init__( super(Xception65, self).__init__(
......
...@@ -16,6 +16,10 @@ image_pretrain = { ...@@ -16,6 +16,10 @@ image_pretrain = {
'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar', 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
'ResNet101_vd': 'ResNet101_vd':
'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar', 'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar',
'ResNet50_vd_ssld':
'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar',
'ResNet101_vd_ssld':
'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_ssld_pretrained.tar',
'MobileNetV1': 'MobileNetV1':
'http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar', 'http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar',
'MobileNetV2_x1.0': 'MobileNetV2_x1.0':
...@@ -32,6 +36,10 @@ image_pretrain = { ...@@ -32,6 +36,10 @@ image_pretrain = {
'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar', 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar',
'MobileNetV3_large': 'MobileNetV3_large':
'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar', 'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar',
'MobileNetV3_small_x1_0_ssld':
'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar',
'MobileNetV3_large_x1_0_ssld':
'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar',
'DarkNet53': 'DarkNet53':
'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar', 'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar',
'DenseNet121': 'DenseNet121':
...@@ -68,6 +76,10 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): ...@@ -68,6 +76,10 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
backbone = 'Seg{}'.format(backbone) backbone = 'Seg{}'.format(backbone)
elif backbone == 'MobileNetV2': elif backbone == 'MobileNetV2':
backbone = 'MobileNetV2_x1.0' backbone = 'MobileNetV2_x1.0'
elif backbone == 'MobileNetV3_small_ssld':
backbone = 'MobileNetV3_small_x1_0_ssld'
elif backbone == 'MobileNetV3_large_ssld':
backbone = 'MobileNetV3_large_x1_0_ssld'
if model_type == 'detector': if model_type == 'detector':
if backbone == 'ResNet50': if backbone == 'ResNet50':
backbone = 'DetResNet50' backbone = 'DetResNet50'
......
...@@ -50,6 +50,18 @@ def resnet50_vd(input, num_classes=1000): ...@@ -50,6 +50,18 @@ def resnet50_vd(input, num_classes=1000):
return model(input) return model(input)
def resnet50_vd_ssld(input, num_classes=1000):
model = ResNet(layers=50, num_classes=num_classes,
variant='d', lr_mult_list=[1.0, 0.1, 0.2, 0.2, 0.3])
return model(input)
def resnet101_vd_ssld(input, num_classes=1000):
model = ResNet(layers=101, num_classes=num_classes,
variant='d', lr_mult_list=[1.0, 0.1, 0.2, 0.2, 0.3])
return model(input)
def resnet101_vd(input, num_classes=1000): def resnet101_vd(input, num_classes=1000):
model = ResNet(layers=101, num_classes=num_classes, variant='d') model = ResNet(layers=101, num_classes=num_classes, variant='d')
return model(input) return model(input)
...@@ -80,6 +92,18 @@ def mobilenetv3_large(input, num_classes=1000): ...@@ -80,6 +92,18 @@ def mobilenetv3_large(input, num_classes=1000):
return model(input) return model(input)
def mobilenetv3_small_ssld(input, num_classes=1000):
model = MobileNetV3(num_classes=num_classes, model_name='small',
lr_mult_list=[0.25, 0.25, 0.5, 0.5, 0.75])
return model(input)
def mobilenetv3_large_ssld(input, num_classes=1000):
model = MobileNetV3(num_classes=num_classes, model_name='large',
lr_mult_list=[0.25, 0.25, 0.5, 0.5, 0.75])
return model(input)
def xception65(input, num_classes=1000): def xception65(input, num_classes=1000):
model = Xception(layers=65, num_classes=num_classes) model = Xception(layers=65, num_classes=num_classes)
return model(input) return model(input)
...@@ -109,7 +133,6 @@ def densenet201(input, num_classes=1000): ...@@ -109,7 +133,6 @@ def densenet201(input, num_classes=1000):
model = DenseNet(layers=201, num_classes=num_classes) model = DenseNet(layers=201, num_classes=num_classes)
return model(input) return model(input)
def shufflenetv2(input, num_classes=1000): def shufflenetv2(input, num_classes=1000):
model = ShuffleNetV2(num_classes=num_classes) model = ShuffleNetV2(num_classes=num_classes)
return model(input) return model(input)
...@@ -31,7 +31,6 @@ class MobileNetV3(): ...@@ -31,7 +31,6 @@ class MobileNetV3():
with_extra_blocks (bool): if extra blocks should be added. with_extra_blocks (bool): if extra blocks should be added.
extra_block_filters (list): number of filter for each extra block. extra_block_filters (list): number of filter for each extra block.
""" """
def __init__(self, def __init__(self,
scale=1.0, scale=1.0,
model_name='small', model_name='small',
...@@ -41,7 +40,11 @@ class MobileNetV3(): ...@@ -41,7 +40,11 @@ class MobileNetV3():
norm_decay=0.0, norm_decay=0.0,
extra_block_filters=[[256, 512], [128, 256], [128, 256], extra_block_filters=[[256, 512], [128, 256], [128, 256],
[64, 128]], [64, 128]],
num_classes=None): num_classes=None,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
assert len(lr_mult_list) == 5, \
"lr_mult_list length in MobileNetV3 must be 5 but got {}!!".format(
len(lr_mult_list))
self.scale = scale self.scale = scale
self.with_extra_blocks = with_extra_blocks self.with_extra_blocks = with_extra_blocks
self.extra_block_filters = extra_block_filters self.extra_block_filters = extra_block_filters
...@@ -51,6 +54,8 @@ class MobileNetV3(): ...@@ -51,6 +54,8 @@ class MobileNetV3():
self.end_points = [] self.end_points = []
self.block_stride = 1 self.block_stride = 1
self.num_classes = num_classes self.num_classes = num_classes
self.lr_mult_list = lr_mult_list
self.curr_stage = 0
if model_name == "large": if model_name == "large":
self.cfg = [ self.cfg = [
# kernel_size, expand, channel, se_block, act_mode, stride # kernel_size, expand, channel, se_block, act_mode, stride
...@@ -72,6 +77,7 @@ class MobileNetV3(): ...@@ -72,6 +77,7 @@ class MobileNetV3():
] ]
self.cls_ch_squeeze = 960 self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280 self.cls_ch_expand = 1280
self.lr_interval = 3
elif model_name == "small": elif model_name == "small":
self.cfg = [ self.cfg = [
# kernel_size, expand, channel, se_block, act_mode, stride # kernel_size, expand, channel, se_block, act_mode, stride
...@@ -89,6 +95,7 @@ class MobileNetV3(): ...@@ -89,6 +95,7 @@ class MobileNetV3():
] ]
self.cls_ch_squeeze = 576 self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280 self.cls_ch_expand = 1280
self.lr_interval = 2
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -103,10 +110,13 @@ class MobileNetV3(): ...@@ -103,10 +110,13 @@ class MobileNetV3():
act=None, act=None,
name=None, name=None,
use_cudnn=True): use_cudnn=True):
conv_param_attr = ParamAttr( lr_idx = self.curr_stage // self.lr_interval
name=name + '_weights', regularizer=L2Decay(self.conv_decay)) lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
conv = fluid.layers.conv2d( lr_mult = self.lr_mult_list[lr_idx]
input=input, conv_param_attr = ParamAttr(name=name + '_weights',
learning_rate=lr_mult,
regularizer=L2Decay(self.conv_decay))
conv = fluid.layers.conv2d(input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
...@@ -117,12 +127,11 @@ class MobileNetV3(): ...@@ -117,12 +127,11 @@ class MobileNetV3():
param_attr=conv_param_attr, param_attr=conv_param_attr,
bias_attr=False) bias_attr=False)
bn_name = name + '_bn' bn_name = name + '_bn'
bn_param_attr = ParamAttr( bn_param_attr = ParamAttr(name=bn_name + "_scale",
name=bn_name + "_scale", regularizer=L2Decay(self.norm_decay)) regularizer=L2Decay(self.norm_decay))
bn_bias_attr = ParamAttr( bn_bias_attr = ParamAttr(name=bn_name + "_offset",
name=bn_name + "_offset", regularizer=L2Decay(self.norm_decay)) regularizer=L2Decay(self.norm_decay))
bn = fluid.layers.batch_norm( bn = fluid.layers.batch_norm(input=conv,
input=conv,
param_attr=bn_param_attr, param_attr=bn_param_attr,
bias_attr=bn_bias_attr, bias_attr=bn_bias_attr,
moving_mean_name=bn_name + '_mean', moving_mean_name=bn_name + '_mean',
...@@ -140,23 +149,33 @@ class MobileNetV3(): ...@@ -140,23 +149,33 @@ class MobileNetV3():
return x * fluid.layers.relu6(x + 3) / 6. return x * fluid.layers.relu6(x + 3) / 6.
def _se_block(self, input, num_out_filter, ratio=4, name=None): def _se_block(self, input, num_out_filter, ratio=4, name=None):
lr_idx = self.curr_stage // self.lr_interval
lr_idx = min(lr_idx, len(self.lr_mult_list) - 1)
lr_mult = self.lr_mult_list[lr_idx]
num_mid_filter = int(num_out_filter // ratio) num_mid_filter = int(num_out_filter // ratio)
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(input=input,
input=input, pool_type='avg', global_pooling=True, use_cudnn=False) pool_type='avg',
global_pooling=True,
use_cudnn=False)
conv1 = fluid.layers.conv2d( conv1 = fluid.layers.conv2d(
input=pool, input=pool,
filter_size=1, filter_size=1,
num_filters=num_mid_filter, num_filters=num_mid_filter,
act='relu', act='relu',
param_attr=ParamAttr(name=name + '_1_weights'), param_attr=ParamAttr(
bias_attr=ParamAttr(name=name + '_1_offset')) name=name + '_1_weights', learning_rate=lr_mult),
bias_attr=ParamAttr(
name=name + '_1_offset', learning_rate=lr_mult))
conv2 = fluid.layers.conv2d( conv2 = fluid.layers.conv2d(
input=conv1, input=conv1,
filter_size=1, filter_size=1,
num_filters=num_out_filter, num_filters=num_out_filter,
act='hard_sigmoid', act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'), param_attr=ParamAttr(
bias_attr=ParamAttr(name=name + '_2_offset')) name=name + '_2_weights', learning_rate=lr_mult),
bias_attr=ParamAttr(
name=name + '_2_offset', learning_rate=lr_mult))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0) scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale return scale
...@@ -172,8 +191,7 @@ class MobileNetV3(): ...@@ -172,8 +191,7 @@ class MobileNetV3():
use_se=False, use_se=False,
name=None): name=None):
input_data = input input_data = input
conv0 = self._conv_bn_layer( conv0 = self._conv_bn_layer(input=input,
input=input,
filter_size=1, filter_size=1,
num_filters=num_mid_filter, num_filters=num_mid_filter,
stride=1, stride=1,
...@@ -183,8 +201,7 @@ class MobileNetV3(): ...@@ -183,8 +201,7 @@ class MobileNetV3():
name=name + '_expand') name=name + '_expand')
if self.block_stride == 16 and stride == 2: if self.block_stride == 16 and stride == 2:
self.end_points.append(conv0) self.end_points.append(conv0)
conv1 = self._conv_bn_layer( conv1 = self._conv_bn_layer(input=conv0,
input=conv0,
filter_size=filter_size, filter_size=filter_size,
num_filters=num_mid_filter, num_filters=num_mid_filter,
stride=stride, stride=stride,
...@@ -196,11 +213,11 @@ class MobileNetV3(): ...@@ -196,11 +213,11 @@ class MobileNetV3():
name=name + '_depthwise') name=name + '_depthwise')
if use_se: if use_se:
conv1 = self._se_block( conv1 = self._se_block(input=conv1,
input=conv1, num_out_filter=num_mid_filter, name=name + '_se') num_out_filter=num_mid_filter,
name=name + '_se')
conv2 = self._conv_bn_layer( conv2 = self._conv_bn_layer(input=conv1,
input=conv1,
filter_size=1, filter_size=1,
num_filters=num_out_filter, num_filters=num_out_filter,
stride=1, stride=1,
...@@ -210,8 +227,7 @@ class MobileNetV3(): ...@@ -210,8 +227,7 @@ class MobileNetV3():
if num_in_filter != num_out_filter or stride != 1: if num_in_filter != num_out_filter or stride != 1:
return conv2 return conv2
else: else:
return fluid.layers.elementwise_add( return fluid.layers.elementwise_add(x=input_data, y=conv2, act=None)
x=input_data, y=conv2, act=None)
def _extra_block_dw(self, def _extra_block_dw(self,
input, input,
...@@ -219,16 +235,14 @@ class MobileNetV3(): ...@@ -219,16 +235,14 @@ class MobileNetV3():
num_filters2, num_filters2,
stride, stride,
name=None): name=None):
pointwise_conv = self._conv_bn_layer( pointwise_conv = self._conv_bn_layer(input=input,
input=input,
filter_size=1, filter_size=1,
num_filters=int(num_filters1), num_filters=int(num_filters1),
stride=1, stride=1,
padding="SAME", padding="SAME",
act='relu6', act='relu6',
name=name + "_extra1") name=name + "_extra1")
depthwise_conv = self._conv_bn_layer( depthwise_conv = self._conv_bn_layer(input=pointwise_conv,
input=pointwise_conv,
filter_size=3, filter_size=3,
num_filters=int(num_filters2), num_filters=int(num_filters2),
stride=stride, stride=stride,
...@@ -237,8 +251,7 @@ class MobileNetV3(): ...@@ -237,8 +251,7 @@ class MobileNetV3():
act='relu6', act='relu6',
use_cudnn=False, use_cudnn=False,
name=name + "_extra2_dw") name=name + "_extra2_dw")
normal_conv = self._conv_bn_layer( normal_conv = self._conv_bn_layer(input=depthwise_conv,
input=depthwise_conv,
filter_size=1, filter_size=1,
num_filters=int(num_filters2), num_filters=int(num_filters2),
stride=1, stride=1,
...@@ -269,8 +282,7 @@ class MobileNetV3(): ...@@ -269,8 +282,7 @@ class MobileNetV3():
self.block_stride *= layer_cfg[5] self.block_stride *= layer_cfg[5]
if layer_cfg[5] == 2: if layer_cfg[5] == 2:
blocks.append(conv) blocks.append(conv)
conv = self._residual_unit( conv = self._residual_unit(input=conv,
input=conv,
num_in_filter=inplanes, num_in_filter=inplanes,
num_mid_filter=int(scale * layer_cfg[1]), num_mid_filter=int(scale * layer_cfg[1]),
num_out_filter=int(scale * layer_cfg[2]), num_out_filter=int(scale * layer_cfg[2]),
...@@ -282,11 +294,11 @@ class MobileNetV3(): ...@@ -282,11 +294,11 @@ class MobileNetV3():
inplanes = int(scale * layer_cfg[2]) inplanes = int(scale * layer_cfg[2])
i += 1 i += 1
self.curr_stage = i
blocks.append(conv) blocks.append(conv)
if self.num_classes: if self.num_classes:
conv = self._conv_bn_layer( conv = self._conv_bn_layer(input=conv,
input=conv,
filter_size=1, filter_size=1,
num_filters=int(scale * self.cls_ch_squeeze), num_filters=int(scale * self.cls_ch_squeeze),
stride=1, stride=1,
...@@ -296,8 +308,7 @@ class MobileNetV3(): ...@@ -296,8 +308,7 @@ class MobileNetV3():
act='hard_swish', act='hard_swish',
name='conv_last') name='conv_last')
conv = fluid.layers.pool2d( conv = fluid.layers.pool2d(input=conv,
input=conv,
pool_type='avg', pool_type='avg',
global_pooling=True, global_pooling=True,
use_cudnn=False) use_cudnn=False)
...@@ -312,8 +323,7 @@ class MobileNetV3(): ...@@ -312,8 +323,7 @@ class MobileNetV3():
bias_attr=False) bias_attr=False)
conv = self._hard_swish(conv) conv = self._hard_swish(conv)
drop = fluid.layers.dropout(x=conv, dropout_prob=0.2) drop = fluid.layers.dropout(x=conv, dropout_prob=0.2)
out = fluid.layers.fc( out = fluid.layers.fc(input=drop,
input=drop,
size=self.num_classes, size=self.num_classes,
param_attr=ParamAttr(name='fc_weights'), param_attr=ParamAttr(name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset')) bias_attr=ParamAttr(name='fc_offset'))
...@@ -323,8 +333,7 @@ class MobileNetV3(): ...@@ -323,8 +333,7 @@ class MobileNetV3():
return blocks return blocks
# extra block # extra block
conv_extra = self._conv_bn_layer( conv_extra = self._conv_bn_layer(conv,
conv,
filter_size=1, filter_size=1,
num_filters=int(scale * cfg[-1][1]), num_filters=int(scale * cfg[-1][1]),
stride=1, stride=1,
......
...@@ -65,7 +65,8 @@ class ResNet(object): ...@@ -65,7 +65,8 @@ class ResNet(object):
nonlocal_stages=[], nonlocal_stages=[],
gcb_stages=[], gcb_stages=[],
gcb_params=dict(), gcb_params=dict(),
num_classes=None): num_classes=None,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
super(ResNet, self).__init__() super(ResNet, self).__init__()
if isinstance(feature_maps, Integral): if isinstance(feature_maps, Integral):
...@@ -79,6 +80,10 @@ class ResNet(object): ...@@ -79,6 +80,10 @@ class ResNet(object):
assert norm_type in ['bn', 'sync_bn', 'affine_channel'] assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and layers<50), \ assert not (len(nonlocal_stages)>0 and layers<50), \
"non-local is not supported for resnet18 or resnet34" "non-local is not supported for resnet18 or resnet34"
assert len(
lr_mult_list
) == 5, "lr_mult_list length in ResNet must be 5 but got {}!!".format(
len(lr_mult_list))
self.layers = layers self.layers = layers
self.freeze_at = freeze_at self.freeze_at = freeze_at
...@@ -113,6 +118,8 @@ class ResNet(object): ...@@ -113,6 +118,8 @@ class ResNet(object):
self.gcb_stages = gcb_stages self.gcb_stages = gcb_stages
self.gcb_params = gcb_params self.gcb_params = gcb_params
self.num_classes = num_classes self.num_classes = num_classes
self.lr_mult_list = lr_mult_list
self.curr_stage = 0
def _conv_offset(self, def _conv_offset(self,
input, input,
...@@ -128,8 +135,7 @@ class ResNet(object): ...@@ -128,8 +135,7 @@ class ResNet(object):
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
param_attr=ParamAttr( param_attr=ParamAttr(initializer=Constant(0.0), name=name + ".w_0"),
initializer=Constant(0.0), name=name + ".w_0"),
bias_attr=ParamAttr(initializer=Constant(0.0), name=name + ".b_0"), bias_attr=ParamAttr(initializer=Constant(0.0), name=name + ".b_0"),
act=act, act=act,
name=name) name=name)
...@@ -143,7 +149,9 @@ class ResNet(object): ...@@ -143,7 +149,9 @@ class ResNet(object):
groups=1, groups=1,
act=None, act=None,
name=None, name=None,
dcn_v2=False): dcn_v2=False,
use_lr_mult_list=False):
lr_mult = self.lr_mult_list[self.curr_stage] if use_lr_mult_list else 1.0
_name = self.prefix_name + name if self.prefix_name != '' else name _name = self.prefix_name + name if self.prefix_name != '' else name
if not dcn_v2: if not dcn_v2:
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
...@@ -154,7 +162,8 @@ class ResNet(object): ...@@ -154,7 +162,8 @@ class ResNet(object):
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=_name + "_weights"), param_attr=ParamAttr(name=_name + "_weights",
learning_rate=lr_mult),
bias_attr=False, bias_attr=False,
name=_name + '.conv2d.output.1') name=_name + '.conv2d.output.1')
else: else:
...@@ -191,7 +200,7 @@ class ResNet(object): ...@@ -191,7 +200,7 @@ class ResNet(object):
bn_name = self.na.fix_conv_norm_name(name) bn_name = self.na.fix_conv_norm_name(name)
bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name
norm_lr = 0. if self.freeze_norm else 1. norm_lr = 0. if self.freeze_norm else lr_mult
norm_decay = self.norm_decay norm_decay = self.norm_decay
pattr = ParamAttr( pattr = ParamAttr(
name=bn_name + '_scale', name=bn_name + '_scale',
...@@ -253,7 +262,8 @@ class ResNet(object): ...@@ -253,7 +262,8 @@ class ResNet(object):
pool_padding=0, pool_padding=0,
ceil_mode=True, ceil_mode=True,
pool_type='avg') pool_type='avg')
return self._conv_norm(input, ch_out, 1, 1, name=name) return self._conv_norm(input, ch_out, 1, 1, name=name,
use_lr_mult_list=True)
return self._conv_norm(input, ch_out, 1, stride, name=name) return self._conv_norm(input, ch_out, 1, stride, name=name)
else: else:
return input return input
...@@ -448,6 +458,7 @@ class ResNet(object): ...@@ -448,6 +458,7 @@ class ResNet(object):
feature_maps = range(2, max(self.feature_maps) + 1) feature_maps = range(2, max(self.feature_maps) + 1)
for i in feature_maps: for i in feature_maps:
self.curr_stage += 1
res = self.layer_warp(res, i) res = self.layer_warp(res, i)
if i in self.feature_maps: if i in self.feature_maps:
res_endpoints.append(res) res_endpoints.append(res)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册