提交 8393d05b 编写于 作者: C chenguowei01

update

上级 ce7fb0d0
......@@ -85,7 +85,7 @@ class ResNet():
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
if self.stem == 'icnet' or self.stem == 'pspnet' or self.stem == 'deeplab':
if self.stem == 'icnet' or self.stem == 'pspnet':
conv = self.conv_bn_layer(
input=input,
num_filters=int(64 * self.scale),
......@@ -256,8 +256,6 @@ class ResNet():
return input
def bottleneck_block(self, input, num_filters, stride, name, dilation=1):
if self.stem == 'deeplab':
strides = [1, stride]
if self.stem == 'pspnet' and self.layers == 101:
strides = [1, stride]
else:
......
......@@ -229,37 +229,6 @@ def xception(input):
return data, decode_shortcut
def resnet(input):
# backbone: resnet, 可选resnet_50, resnet_101
# end_points: resnet终止层数
# dilation_dict: resnet block数及对应的膨胀卷积尺度
backbone = cfg.MODEL.DEEPLAB.BACKBONE
if '50' in backbone:
layers = 50
elif '101' in backbone:
layers = 101
else:
raise Exception("resnet backbone only support layers 50 or 101")
output_stride = cfg.MODEL.DEEPLAB.OUTPUT_STRIDE
end_points = layers - 1
decode_point = 10
if output_stride == 8:
dilation_dict = {2: 2, 3: 4}
elif output_stride == 16:
dilation_dict = {3: 2}
else:
raise Exception("deeplab only support stride 8 or 16")
model = resnet_backbone(layers, stem='deeplab')
data, decode_shortcuts = model.net(
input,
end_points=end_points,
decode_points=decode_point,
dilation_dict=dilation_dict)
decode_shortcut = decode_shortcuts[decode_point]
return data, decode_shortcut
def resnet_vd(input):
# backbone: resnet_vd, 可选resnet_vd_50, resnet_vd_101
# end_points: resnet终止层数
......@@ -299,10 +268,8 @@ def deeplabv3p(img, num_classes):
data, decode_shortcut = xception(img)
elif 'mobilenet' in cfg.MODEL.DEEPLAB.BACKBONE:
data, decode_shortcut = mobilenetv2(img)
elif 'resnet_vd' in cfg.MODEL.DEEPLAB.BACKBONE:
data, decode_shortcut = resnet_vd(img)
elif 'resnet' in cfg.MODEL.DEEPLAB.BACKBONE:
data, decode_shortcut = resnet(img)
data, decode_shortcut = resnet_vd(img)
else:
raise Exception(
"deeplab only support xception, mobilenet, resnet and resnet_vd backbone"
......
......@@ -194,7 +194,7 @@ cfg.MODEL.FP16 = False
cfg.MODEL.SCALE_LOSS = "DYNAMIC"
########################## DeepLab模型配置 ####################################
# DeepLab backbone 配置, 可选项xception_65, xception_41, xception_71, mobilenetv2, resnet_50, resnet_101, resnet_vd_50, resnet_vd_101
# DeepLab backbone 配置, 可选项xception_65, xception_41, xception_71, mobilenetv2, resnet50_vd, resnet101_vd
cfg.MODEL.DEEPLAB.BACKBONE = "xception_65"
# DeepLab output stride
cfg.MODEL.DEEPLAB.OUTPUT_STRIDE = 16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册