提交 a689b8dd 编写于 作者: C chenguowei01

Add multi-grid dilation (mult_grid) for ResNet block 3 and an optional channel-order flip to RGB (cfg.AUG.TO_RGB)

上级 73fc0c03
...@@ -65,6 +65,7 @@ class ResNet(): ...@@ -65,6 +65,7 @@ class ResNet():
dilation_dict=None): dilation_dict=None):
layers = self.layers layers = self.layers
supported_layers = [18, 34, 50, 101, 152] supported_layers = [18, 34, 50, 101, 152]
mult_grid = [1, 2, 4]
assert layers in supported_layers, \ assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers) "supported layers are {} but input layer is {}".format(supported_layers, layers)
...@@ -96,37 +97,42 @@ class ResNet(): ...@@ -96,37 +97,42 @@ class ResNet():
num_filters = [64, 128, 256, 512] num_filters = [64, 128, 256, 512]
if self.stem == 'icnet' or self.stem == 'pspnet' or self.stem == 'deeplab': if self.stem == 'icnet' or self.stem == 'pspnet' or self.stem == 'deeplab':
conv = self.conv_bn_layer(input=input, conv = self.conv_bn_layer(
num_filters=int(32 * self.scale), input=input,
filter_size=3, num_filters=int(32 * self.scale),
stride=2, filter_size=3,
act='relu', stride=2,
name="conv1_1") act='relu',
conv = self.conv_bn_layer(input=conv, name="conv1_1")
num_filters=int(32 * self.scale), conv = self.conv_bn_layer(
filter_size=3, input=conv,
stride=1, num_filters=int(32 * self.scale),
act='relu', filter_size=3,
name="conv1_2") stride=1,
conv = self.conv_bn_layer(input=conv, act='relu',
num_filters=int(64 * self.scale), name="conv1_2")
filter_size=3, conv = self.conv_bn_layer(
stride=1, input=conv,
act='relu', num_filters=int(64 * self.scale),
name="conv1_3") filter_size=3,
stride=1,
act='relu',
name="conv1_3")
else: else:
conv = self.conv_bn_layer(input=input, conv = self.conv_bn_layer(
num_filters=int(64 * self.scale), input=input,
filter_size=7, num_filters=int(64 * self.scale),
stride=2, filter_size=7,
act='relu', stride=2,
name="conv1") act='relu',
name="conv1")
conv = fluid.layers.pool2d(input=conv,
pool_size=3, conv = fluid.layers.pool2d(
pool_stride=2, input=conv,
pool_padding=1, pool_size=3,
pool_type='max') pool_stride=2,
pool_padding=1,
pool_type='max')
layer_count = 1 layer_count = 1
if check_points(layer_count, decode_points): if check_points(layer_count, decode_points):
...@@ -147,6 +153,8 @@ class ResNet(): ...@@ -147,6 +153,8 @@ class ResNet():
else: else:
conv_name = "res" + str(block + 2) + chr(97 + i) conv_name = "res" + str(block + 2) + chr(97 + i)
dilation_rate = get_dilated_rate(dilation_dict, block) dilation_rate = get_dilated_rate(dilation_dict, block)
if block == 3:
dilation_rate = dilation_rate * mult_grid[i]
conv = self.bottleneck_block( conv = self.bottleneck_block(
input=conv, input=conv,
...@@ -170,10 +178,8 @@ class ResNet(): ...@@ -170,10 +178,8 @@ class ResNet():
np.ceil( np.ceil(
np.array(conv.shape[2:]).astype('int32') / 2)) np.array(conv.shape[2:]).astype('int32') / 2))
pool = fluid.layers.pool2d(input=conv, pool = fluid.layers.pool2d(
pool_size=7, input=conv, pool_size=7, pool_type='avg', global_pooling=True)
pool_type='avg',
global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc( out = fluid.layers.fc(
input=pool, input=pool,
...@@ -198,10 +204,8 @@ class ResNet(): ...@@ -198,10 +204,8 @@ class ResNet():
if check_points(layer_count, end_points): if check_points(layer_count, end_points):
return conv, decode_ends return conv, decode_ends
pool = fluid.layers.pool2d(input=conv, pool = fluid.layers.pool2d(
pool_size=7, input=conv, pool_size=7, pool_type='avg', global_pooling=True)
pool_type='avg',
global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc( out = fluid.layers.fc(
input=pool, input=pool,
...@@ -234,19 +238,18 @@ class ResNet(): ...@@ -234,19 +238,18 @@ class ResNet():
else: else:
bias_attr = False bias_attr = False
conv = fluid.layers.conv2d(input=input, conv = fluid.layers.conv2d(
num_filters=num_filters, input=input,
filter_size=filter_size, num_filters=num_filters,
stride=stride, filter_size=filter_size,
padding=(filter_size - 1) // stride=stride,
2 if dilation == 1 else 0, padding=(filter_size - 1) // 2 if dilation == 1 else 0,
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights", param_attr=ParamAttr(name=name + "_weights", learning_rate=lr_mult),
learning_rate=lr_mult), bias_attr=bias_attr,
bias_attr=bias_attr, name=name + '.conv2d.output.1')
name=name + '.conv2d.output.1')
if name == "conv1": if name == "conv1":
bn_name = "bn_" + name bn_name = "bn_" + name
...@@ -256,8 +259,8 @@ class ResNet(): ...@@ -256,8 +259,8 @@ class ResNet():
input=conv, input=conv,
act=act, act=act,
name=bn_name + '.output.1', name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale', param_attr=ParamAttr(
learning_rate=lr_mult), name=bn_name + '_scale', learning_rate=lr_mult),
bias_attr=ParamAttr(bn_name + '_offset', learning_rate=lr_mult), bias_attr=ParamAttr(bn_name + '_offset', learning_rate=lr_mult),
moving_mean_name=bn_name + '_mean', moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', moving_variance_name=bn_name + '_variance',
...@@ -272,23 +275,24 @@ class ResNet(): ...@@ -272,23 +275,24 @@ class ResNet():
act=None, act=None,
name=None): name=None):
lr_mult = self.lr_mult_list[self.curr_stage] lr_mult = self.lr_mult_list[self.curr_stage]
pool = fluid.layers.pool2d(input=input, pool = fluid.layers.pool2d(
pool_size=2, input=input,
pool_stride=2, pool_size=2,
pool_padding=0, pool_stride=2,
pool_type='avg', pool_padding=0,
ceil_mode=True) pool_type='avg',
ceil_mode=True)
conv = fluid.layers.conv2d(input=pool,
num_filters=num_filters, conv = fluid.layers.conv2d(
filter_size=filter_size, input=pool,
stride=1, num_filters=num_filters,
padding=(filter_size - 1) // 2, filter_size=filter_size,
groups=groups, stride=1,
act=None, padding=(filter_size - 1) // 2,
param_attr=ParamAttr(name=name + "_weights", groups=groups,
learning_rate=lr_mult), act=None,
bias_attr=False) param_attr=ParamAttr(name=name + "_weights", learning_rate=lr_mult),
bias_attr=False)
if name == "conv1": if name == "conv1":
bn_name = "bn_" + name bn_name = "bn_" + name
else: else:
...@@ -296,24 +300,20 @@ class ResNet(): ...@@ -296,24 +300,20 @@ class ResNet():
return fluid.layers.batch_norm( return fluid.layers.batch_norm(
input=conv, input=conv,
act=act, act=act,
param_attr=ParamAttr(name=bn_name + '_scale', param_attr=ParamAttr(
learning_rate=lr_mult), name=bn_name + '_scale', learning_rate=lr_mult),
bias_attr=ParamAttr(bn_name + '_offset', learning_rate=lr_mult), bias_attr=ParamAttr(bn_name + '_offset', learning_rate=lr_mult),
moving_mean_name=bn_name + '_mean', moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance') moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, is_first, name): def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1] ch_in = input.shape[1]
print('shortcut:', stride, is_first, ch_in, ch_out)
if ch_in != ch_out or stride != 1: if ch_in != ch_out or stride != 1:
if is_first or stride == 1: if is_first or stride == 1:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name) return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else: else:
return self.conv_bn_layer_new(input, return self.conv_bn_layer_new(
ch_out, input, ch_out, 1, stride, name=name)
1,
stride,
name=name)
elif is_first: elif is_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name) return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else: else:
...@@ -326,59 +326,58 @@ class ResNet(): ...@@ -326,59 +326,58 @@ class ResNet():
name, name,
is_first=False, is_first=False,
dilation=1): dilation=1):
conv0 = self.conv_bn_layer(input=input, conv0 = self.conv_bn_layer(
num_filters=num_filters, input=input,
filter_size=1, num_filters=num_filters,
dilation=1, filter_size=1,
stride=1, dilation=1,
act='relu', stride=1,
name=name + "_branch2a") act='relu',
name=name + "_branch2a")
if dilation > 1: if dilation > 1:
conv0 = self.zero_padding(conv0, dilation) conv0 = self.zero_padding(conv0, dilation)
conv1 = self.conv_bn_layer(input=conv0, conv1 = self.conv_bn_layer(
num_filters=num_filters, input=conv0,
filter_size=3, num_filters=num_filters,
dilation=dilation, filter_size=3,
stride=stride, dilation=dilation,
act='relu', stride=stride,
name=name + "_branch2b") act='relu',
conv2 = self.conv_bn_layer(input=conv1, name=name + "_branch2b")
num_filters=num_filters * 4, conv2 = self.conv_bn_layer(
dilation=1, input=conv1,
filter_size=1, num_filters=num_filters * 4,
act=None, dilation=1,
name=name + "_branch2c") filter_size=1,
act=None,
short = self.shortcut(input, name=name + "_branch2c")
num_filters * 4,
stride, short = self.shortcut(
is_first=is_first, input,
name=name + "_branch1") num_filters * 4,
print(input.shape, short.shape, conv2.shape) stride,
print(stride) is_first=is_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short,
y=conv2, return fluid.layers.elementwise_add(
act='relu', x=short, y=conv2, act='relu', name=name + ".add.output.5")
name=name + ".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name): def basic_block(self, input, num_filters, stride, is_first, name):
conv0 = self.conv_bn_layer(input=input, conv0 = self.conv_bn_layer(
num_filters=num_filters, input=input,
filter_size=3, num_filters=num_filters,
act='relu', filter_size=3,
stride=stride, act='relu',
name=name + "_branch2a") stride=stride,
conv1 = self.conv_bn_layer(input=conv0, name=name + "_branch2a")
num_filters=num_filters, conv1 = self.conv_bn_layer(
filter_size=3, input=conv0,
act=None, num_filters=num_filters,
name=name + "_branch2b") filter_size=3,
short = self.shortcut(input, act=None,
num_filters, name=name + "_branch2b")
stride, short = self.shortcut(
is_first, input, num_filters, stride, is_first, name=name + "_branch1")
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
......
...@@ -318,6 +318,8 @@ class SegDataset(object): ...@@ -318,6 +318,8 @@ class SegDataset(object):
raise ValueError("Dataset mode={} Error!".format(mode)) raise ValueError("Dataset mode={} Error!".format(mode))
# Normalize image # Normalize image
if cfg.AUG.TO_RGB:
img = img[..., ::-1]
img = self.normalize_image(img) img = self.normalize_image(img)
if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode): if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode):
......
...@@ -117,6 +117,8 @@ cfg.AUG.RICH_CROP.CONTRAST_JITTER_RATIO = 0.5 ...@@ -117,6 +117,8 @@ cfg.AUG.RICH_CROP.CONTRAST_JITTER_RATIO = 0.5
cfg.AUG.RICH_CROP.BLUR = False cfg.AUG.RICH_CROP.BLUR = False
# 图像启动模糊百分比,0-1 # 图像启动模糊百分比,0-1
cfg.AUG.RICH_CROP.BLUR_RATIO = 0.1 cfg.AUG.RICH_CROP.BLUR_RATIO = 0.1
# 图像是否切换到rgb模式
cfg.AUG.TO_RGB = True
########################### 训练配置 ########################################## ########################### 训练配置 ##########################################
# 模型保存路径 # 模型保存路径
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册