提交 c5ebd13a 编写于 作者: C chenguowei01

Merge branch 'develop' of https://github.com/wuyefeilin/PaddleSeg into develop

...@@ -40,11 +40,21 @@ train_parameters = { ...@@ -40,11 +40,21 @@ train_parameters = {
class ResNet(): class ResNet():
def __init__(self, layers=50, scale=1.0, stem=None): def __init__(self,
layers=50,
scale=1.0,
stem=None,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
self.params = train_parameters self.params = train_parameters
self.layers = layers self.layers = layers
self.scale = scale self.scale = scale
self.stem = stem self.stem = stem
self.lr_mult_list = lr_mult_list
assert len(
self.lr_mult_list
) == 5, "lr_mult_list length in ResNet must be 5 but got {}!!".format(
len(self.lr_mult_list))
self.curr_stage = 0
def net(self, def net(self,
input, input,
...@@ -86,38 +96,33 @@ class ResNet(): ...@@ -86,38 +96,33 @@ class ResNet():
num_filters = [64, 128, 256, 512] num_filters = [64, 128, 256, 512]
if self.stem == 'icnet' or self.stem == 'pspnet' or self.stem == 'deeplab': if self.stem == 'icnet' or self.stem == 'pspnet' or self.stem == 'deeplab':
conv = self.conv_bn_layer( conv = self.conv_bn_layer(input=input,
input=input,
num_filters=int(32 * self.scale), num_filters=int(32 * self.scale),
filter_size=3, filter_size=3,
stride=2, stride=2,
act='relu', act='relu',
name="conv1_1") name="conv1_1")
conv = self.conv_bn_layer( conv = self.conv_bn_layer(input=conv,
input=conv,
num_filters=int(32 * self.scale), num_filters=int(32 * self.scale),
filter_size=3, filter_size=3,
stride=1, stride=1,
act='relu', act='relu',
name="conv1_2") name="conv1_2")
conv = self.conv_bn_layer( conv = self.conv_bn_layer(input=conv,
input=conv,
num_filters=int(64 * self.scale), num_filters=int(64 * self.scale),
filter_size=3, filter_size=3,
stride=1, stride=1,
act='relu', act='relu',
name="conv1_3") name="conv1_3")
else: else:
conv = self.conv_bn_layer( conv = self.conv_bn_layer(input=input,
input=input,
num_filters=int(64 * self.scale), num_filters=int(64 * self.scale),
filter_size=7, filter_size=7,
stride=2, stride=2,
act='relu', act='relu',
name="conv1") name="conv1")
conv = fluid.layers.pool2d( conv = fluid.layers.pool2d(input=conv,
input=conv,
pool_size=3, pool_size=3,
pool_stride=2, pool_stride=2,
pool_padding=1, pool_padding=1,
...@@ -132,6 +137,7 @@ class ResNet(): ...@@ -132,6 +137,7 @@ class ResNet():
if layers >= 50: if layers >= 50:
for block in range(len(depth)): for block in range(len(depth)):
self.curr_stage += 1
for i in range(depth[block]): for i in range(depth[block]):
if layers in [101, 152] and block == 2: if layers in [101, 152] and block == 2:
if i == 0: if i == 0:
...@@ -164,8 +170,10 @@ class ResNet(): ...@@ -164,8 +170,10 @@ class ResNet():
np.ceil( np.ceil(
np.array(conv.shape[2:]).astype('int32') / 2)) np.array(conv.shape[2:]).astype('int32') / 2))
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(input=conv,
input=conv, pool_size=7, pool_type='avg', global_pooling=True) pool_size=7,
pool_type='avg',
global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc( out = fluid.layers.fc(
input=pool, input=pool,
...@@ -174,6 +182,7 @@ class ResNet(): ...@@ -174,6 +182,7 @@ class ResNet():
initializer=fluid.initializer.Uniform(-stdv, stdv))) initializer=fluid.initializer.Uniform(-stdv, stdv)))
else: else:
for block in range(len(depth)): for block in range(len(depth)):
self.curr_stage += 1
for i in range(depth[block]): for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i) conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block( conv = self.basic_block(
...@@ -189,8 +198,10 @@ class ResNet(): ...@@ -189,8 +198,10 @@ class ResNet():
if check_points(layer_count, end_points): if check_points(layer_count, end_points):
return conv, decode_ends return conv, decode_ends
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(input=conv,
input=conv, pool_size=7, pool_type='avg', global_pooling=True) pool_size=7,
pool_type='avg',
global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc( out = fluid.layers.fc(
input=pool, input=pool,
...@@ -217,21 +228,23 @@ class ResNet(): ...@@ -217,21 +228,23 @@ class ResNet():
act=None, act=None,
name=None): name=None):
lr_mult = self.lr_mult_list[self.curr_stage]
if self.stem == 'pspnet': if self.stem == 'pspnet':
bias_attr = ParamAttr(name=name + "_biases") bias_attr = ParamAttr(name=name + "_biases")
else: else:
bias_attr = False bias_attr = False
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(input=input,
input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) // 2 if dilation == 1 else 0, padding=(filter_size - 1) //
2 if dilation == 1 else 0,
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights",
learning_rate=lr_mult),
bias_attr=bias_attr, bias_attr=bias_attr,
name=name + '.conv2d.output.1') name=name + '.conv2d.output.1')
...@@ -243,8 +256,9 @@ class ResNet(): ...@@ -243,8 +256,9 @@ class ResNet():
input=conv, input=conv,
act=act, act=act,
name=bn_name + '.output.1', name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'), param_attr=ParamAttr(name=bn_name + '_scale',
bias_attr=ParamAttr(bn_name + '_offset'), learning_rate=lr_mult),
bias_attr=ParamAttr(bn_name + '_offset', learning_rate=lr_mult),
moving_mean_name=bn_name + '_mean', moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', moving_variance_name=bn_name + '_variance',
) )
...@@ -257,23 +271,23 @@ class ResNet(): ...@@ -257,23 +271,23 @@ class ResNet():
groups=1, groups=1,
act=None, act=None,
name=None): name=None):
pool = fluid.layers.pool2d( lr_mult = self.lr_mult_list[self.curr_stage]
input=input, pool = fluid.layers.pool2d(input=input,
pool_size=2, pool_size=2,
pool_stride=2, pool_stride=2,
pool_padding=0, pool_padding=0,
pool_type='avg', pool_type='avg',
ceil_mode=True) ceil_mode=True)
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(input=pool,
input=pool,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=1, stride=1,
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights",
learning_rate=lr_mult),
bias_attr=False) bias_attr=False)
if name == "conv1": if name == "conv1":
bn_name = "bn_" + name bn_name = "bn_" + name
...@@ -282,8 +296,9 @@ class ResNet(): ...@@ -282,8 +296,9 @@ class ResNet():
return fluid.layers.batch_norm( return fluid.layers.batch_norm(
input=conv, input=conv,
act=act, act=act,
param_attr=ParamAttr(name=bn_name + '_scale'), param_attr=ParamAttr(name=bn_name + '_scale',
bias_attr=ParamAttr(bn_name + '_offset'), learning_rate=lr_mult),
bias_attr=ParamAttr(bn_name + '_offset', learning_rate=lr_mult),
moving_mean_name=bn_name + '_mean', moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance') moving_variance_name=bn_name + '_variance')
...@@ -294,8 +309,11 @@ class ResNet(): ...@@ -294,8 +309,11 @@ class ResNet():
if is_first or stride == 1: if is_first or stride == 1:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name) return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else: else:
return self.conv_bn_layer_new( return self.conv_bn_layer_new(input,
input, ch_out, 1, stride, name=name) ch_out,
1,
stride,
name=name)
elif is_first: elif is_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name) return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else: else:
...@@ -308,8 +326,7 @@ class ResNet(): ...@@ -308,8 +326,7 @@ class ResNet():
name, name,
is_first=False, is_first=False,
dilation=1): dilation=1):
conv0 = self.conv_bn_layer( conv0 = self.conv_bn_layer(input=input,
input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=1, filter_size=1,
dilation=1, dilation=1,
...@@ -318,24 +335,21 @@ class ResNet(): ...@@ -318,24 +335,21 @@ class ResNet():
name=name + "_branch2a") name=name + "_branch2a")
if dilation > 1: if dilation > 1:
conv0 = self.zero_padding(conv0, dilation) conv0 = self.zero_padding(conv0, dilation)
conv1 = self.conv_bn_layer( conv1 = self.conv_bn_layer(input=conv0,
input=conv0,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
dilation=dilation, dilation=dilation,
stride=stride, stride=stride,
act='relu', act='relu',
name=name + "_branch2b") name=name + "_branch2b")
conv2 = self.conv_bn_layer( conv2 = self.conv_bn_layer(input=conv1,
input=conv1,
num_filters=num_filters * 4, num_filters=num_filters * 4,
dilation=1, dilation=1,
filter_size=1, filter_size=1,
act=None, act=None,
name=name + "_branch2c") name=name + "_branch2c")
short = self.shortcut( short = self.shortcut(input,
input,
num_filters * 4, num_filters * 4,
stride, stride,
is_first=is_first, is_first=is_first,
...@@ -343,25 +357,28 @@ class ResNet(): ...@@ -343,25 +357,28 @@ class ResNet():
print(input.shape, short.shape, conv2.shape) print(input.shape, short.shape, conv2.shape)
print(stride) print(stride)
return fluid.layers.elementwise_add( return fluid.layers.elementwise_add(x=short,
x=short, y=conv2, act='relu', name=name + ".add.output.5") y=conv2,
act='relu',
name=name + ".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name): def basic_block(self, input, num_filters, stride, is_first, name):
conv0 = self.conv_bn_layer( conv0 = self.conv_bn_layer(input=input,
input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
act='relu', act='relu',
stride=stride, stride=stride,
name=name + "_branch2a") name=name + "_branch2a")
conv1 = self.conv_bn_layer( conv1 = self.conv_bn_layer(input=conv0,
input=conv0,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
act=None, act=None,
name=name + "_branch2b") name=name + "_branch2b")
short = self.shortcut( short = self.shortcut(input,
input, num_filters, stride, is_first, name=name + "_branch1") num_filters,
stride,
is_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
......
...@@ -280,7 +280,9 @@ def resnet_vd(input): ...@@ -280,7 +280,9 @@ def resnet_vd(input):
dilation_dict = {3: 2} dilation_dict = {3: 2}
else: else:
raise Exception("deeplab only support stride 8 or 16") raise Exception("deeplab only support stride 8 or 16")
model = resnet_vd_backbone(layers, stem='deeplab') lr_mult_list = cfg.MODEL.DEEPLAB.RESNET_LR_MULT_LIST
model = resnet_vd_backbone(
layers, stem='deeplab', lr_mult_list=lr_mult_list)
data, decode_shortcuts = model.net( data, decode_shortcuts = model.net(
input, input,
end_points=end_points, end_points=end_points,
......
...@@ -206,6 +206,8 @@ cfg.MODEL.DEEPLAB.ENABLE_DECODER = True ...@@ -206,6 +206,8 @@ cfg.MODEL.DEEPLAB.ENABLE_DECODER = True
cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV = True cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV = True
# 解码器是否使用可分离卷积 # 解码器是否使用可分离卷积
cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV = True cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV = True
# resnet_vd分阶段学习率
cfg.MODEL.DEEPLAB.RESNET_LR_MULT_LIST = [1.0, 1.0, 1.0, 1.0, 1.0]
########################## UNET模型配置 ####################################### ########################## UNET模型配置 #######################################
# 上采样方式, 默认为双线性插值 # 上采样方式, 默认为双线性插值
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册