提交 b28bc4db 编写于 作者: P pengmian

modify pspnet

上级 963b9031
...@@ -133,20 +133,21 @@ class ResNet(): ...@@ -133,20 +133,21 @@ class ResNet():
if layers >= 50: if layers >= 50:
for block in range(len(depth)): for block in range(len(depth)):
for i in range(depth[block]): for i in range(depth[block]):
conv_name = "conv" + str(block + 2) + '_' + str(1 + i) if layers in [101, 152] and block == 2:
dilation_rate = get_dilated_rate(dilation_dict, block) if i == 0:
conv_name = "res" + str(block + 2) + "a"
if self.stem == 'pspnet': else:
stride = 2 if i == 0 and block == 1 else 1 conv_name = "res" + str(block + 2) + "b" + str(i)
else: else:
stride= 2 if i == 0 and block != 0 and dilation_rate == 1 else 1 conv_name = "conv" + str(block + 2) + '_' + str(1 + i)
dilation_rate = get_dilated_rate(dilation_dict, block)
conv = self.bottleneck_block( conv = self.bottleneck_block(
input=conv, input=conv,
num_filters=int(num_filters[block] * self.scale), num_filters=int(num_filters[block] * self.scale),
stride=stride, stride=2
name=conv_name, if i == 0 and block != 0 and dilation_rate == 1 else 1,
dilation=dilation_rate) name=conv_name,
dilation=dilation_rate)
layer_count += 3 layer_count += 3
if check_points(layer_count, decode_points): if check_points(layer_count, decode_points):
...@@ -172,7 +173,7 @@ class ResNet(): ...@@ -172,7 +173,7 @@ class ResNet():
else: else:
for block in range(len(depth)): for block in range(len(depth)):
for i in range(depth[block]): for i in range(depth[block]):
conv_name = "conv" + str(block + 2) + chr(97 + i) conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block( conv = self.basic_block(
input=conv, input=conv,
num_filters=num_filters[block], num_filters=num_filters[block],
......
...@@ -12,6 +12,7 @@ from models.backbone.resnet import ResNet as resnet_backbone ...@@ -12,6 +12,7 @@ from models.backbone.resnet import ResNet as resnet_backbone
from utils.config import cfg from utils.config import cfg
def get_logit_interp(input, num_classes, out_shape, name="logit"): def get_logit_interp(input, num_classes, out_shape, name="logit"):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
param_attr = fluid.ParamAttr( param_attr = fluid.ParamAttr(
name=name + 'weights', name=name + 'weights',
regularizer=fluid.regularizer.L2DecayRegularizer( regularizer=fluid.regularizer.L2DecayRegularizer(
...@@ -19,13 +20,12 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"): ...@@ -19,13 +20,12 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"):
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01)) initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
with scope(name): with scope(name):
logit = conv( logit = conv(input,
input, num_classes,
num_classes, filter_size=1,
filter_size=1, param_attr=param_attr,
param_attr=param_attr, bias_attr=True,
bias_attr=True, name=name+'_conv')
name=name+'.conv2d.output.1')
logit_interp = fluid.layers.resize_bilinear( logit_interp = fluid.layers.resize_bilinear(
logit, logit,
out_shape=out_shape, out_shape=out_shape,
...@@ -34,53 +34,67 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"): ...@@ -34,53 +34,67 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"):
def psp_module(input, out_features): def psp_module(input, out_features):
# Pyramid Scene Parsing 金字塔池化模块
# 输入:backbone输出的特征
# 输出:对输入进行不同尺度pooling, 卷积操作后插值回原始尺寸,并concat
# 最后进行一个卷积及BN操作
cat_layers = [] cat_layers = []
sizes = (1,2,3,6) sizes = (1,2,3,6)
for size in sizes: for size in sizes:
psp_name = "psp_conv" + str(size) psp_name = "psp_conv" + str(size)
with scope(psp_name): with scope(psp_name):
pool = fluid.layers.adaptive_pool2d(input, pool = fluid.layers.adaptive_pool2d(input,
pool_size=[size, size], pool_size=[size, size],
pool_type='avg', pool_type='avg',
name=psp_name+'_adapool') name=psp_name+'_adapool')
data = conv(pool, out_features, filter_size=1, bias_attr=True, data = conv(pool, out_features,
name= psp_name + '.conv2d.output.1') filter_size=1,
bias_attr=True,
name= psp_name + '_conv')
data_bn = bn(data, act='relu') data_bn = bn(data, act='relu')
interp = fluid.layers.resize_bilinear(data_bn, interp = fluid.layers.resize_bilinear(data_bn,
out_shape=input.shape[2:], out_shape=input.shape[2:],
name=psp_name+'_interp') name=psp_name+'_interp')
cat_layers.append(interp) cat_layers.append(interp)
cat_layers = [input] + cat_layers[::-1] cat_layers = [input] + cat_layers[::-1]
cat = fluid.layers.concat(cat_layers, axis=1, name='psp_cat') cat = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
with scope("psp_conv_end"):
psp_end_name = "psp_conv_end"
with scope(psp_end_name):
data = conv(cat, data = conv(cat,
out_features, out_features,
filter_size=3, filter_size=3,
padding=1, padding=1,
bias_attr=True, bias_attr=True,
name='psp_conv_end.conv2d.output.1') name=psp_end_name)
out = bn(data, act='relu') out = bn(data, act='relu')
return out return out
def resnet(input): def resnet(input):
# PSPNET backbone: resnet, 默认resnet50 # PSPNET backbone: resnet, 默认resnet50
# end_points: resnet终止层数 # end_points: resnet终止层数
# dilation_dict: resnet block数及对应的膨胀卷积尺度
scale = cfg.MODEL.ICNET.DEPTH_MULTIPLIER
scale = cfg.MODEL.PSPNET.DEPTH_MULTIPLIER scale = cfg.MODEL.PSPNET.DEPTH_MULTIPLIER
layers = cfg.MODEL.PSPNET.LAYERS layers = cfg.MODEL.PSPNET.LAYERS
end_points = layers - 1 end_points = layers - 1
dilation_dict = {2:2, 3:4} dilation_dict = {2:2, 3:4}
model = resnet_backbone(layers, scale, stem='pspnet') model = resnet_backbone(layers, scale, stem='pspnet')
data, _ = model.net(input, end_points=end_points, dilation_dict=dilation_dict) data, _ = model.net(input,
end_points=end_points,
dilation_dict=dilation_dict)
return data return data
def pspnet(input, num_classes): def pspnet(input, num_classes):
# Backbone: ResNet
res = resnet(input) res = resnet(input)
# PSP模块
psp = psp_module(res, 512) psp = psp_module(res, 512)
#dropout = fluid.layers.dropout(psp, dropout_prob=0.1, name="dropout") dropout = fluid.layers.dropout(psp, dropout_prob=0.1, name="dropout")
logit = get_logit_interp(psp, num_classes, input.shape[2:]) # 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
logit = get_logit_interp(dropout, num_classes, input.shape[2:])
return logit return logit
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册