From 8393d05b35a800ed0f91384d427bac94f28fa90f Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Mon, 29 Jun 2020 11:52:49 +0800 Subject: [PATCH] update --- pdseg/models/backbone/resnet.py | 4 +--- pdseg/models/modeling/deeplab.py | 35 +------------------------------- pdseg/utils/config.py | 2 +- 3 files changed, 3 insertions(+), 38 deletions(-) diff --git a/pdseg/models/backbone/resnet.py b/pdseg/models/backbone/resnet.py index 260a1806..60a7bc5d 100644 --- a/pdseg/models/backbone/resnet.py +++ b/pdseg/models/backbone/resnet.py @@ -85,7 +85,7 @@ class ResNet(): depth = [3, 8, 36, 3] num_filters = [64, 128, 256, 512] - if self.stem == 'icnet' or self.stem == 'pspnet' or self.stem == 'deeplab': + if self.stem == 'icnet' or self.stem == 'pspnet': conv = self.conv_bn_layer( input=input, num_filters=int(64 * self.scale), @@ -256,8 +256,6 @@ class ResNet(): return input def bottleneck_block(self, input, num_filters, stride, name, dilation=1): - if self.stem == 'deeplab': - strides = [1, stride] if self.stem == 'pspnet' and self.layers == 101: strides = [1, stride] else: diff --git a/pdseg/models/modeling/deeplab.py b/pdseg/models/modeling/deeplab.py index 0c338504..fae27653 100644 --- a/pdseg/models/modeling/deeplab.py +++ b/pdseg/models/modeling/deeplab.py @@ -229,37 +229,6 @@ def xception(input): return data, decode_shortcut -def resnet(input): - # backbone: resnet, 可选resnet_50, resnet_101 - # end_points: resnet终止层数 - # dilation_dict: resnet block数及对应的膨胀卷积尺度 - backbone = cfg.MODEL.DEEPLAB.BACKBONE - if '50' in backbone: - layers = 50 - elif '101' in backbone: - layers = 101 - else: - raise Exception("resnet backbone only support layers 50 or 101") - output_stride = cfg.MODEL.DEEPLAB.OUTPUT_STRIDE - end_points = layers - 1 - decode_point = 10 - if output_stride == 8: - dilation_dict = {2: 2, 3: 4} - elif output_stride == 16: - dilation_dict = {3: 2} - else: - raise Exception("deeplab only support stride 8 or 16") - model = resnet_backbone(layers, stem='deeplab') - data, decode_shortcuts = model.net( - input, - end_points=end_points, - decode_points=decode_point, - dilation_dict=dilation_dict) - decode_shortcut = decode_shortcuts[decode_point] - - return data, decode_shortcut - - def resnet_vd(input): # backbone: resnet_vd, 可选resnet_vd_50, resnet_vd_101 # end_points: resnet终止层数 @@ -299,10 +268,8 @@ def deeplabv3p(img, num_classes): data, decode_shortcut = xception(img) elif 'mobilenet' in cfg.MODEL.DEEPLAB.BACKBONE: data, decode_shortcut = mobilenetv2(img) - elif 'resnet_vd' in cfg.MODEL.DEEPLAB.BACKBONE: - data, decode_shortcut = resnet_vd(img) elif 'resnet' in cfg.MODEL.DEEPLAB.BACKBONE: - data, decode_shortcut = resnet(img) + data, decode_shortcut = resnet_vd(img) else: raise Exception( "deeplab only support xception, mobilenet, resnet and resnet_vd backbone" diff --git a/pdseg/utils/config.py b/pdseg/utils/config.py index a139c0a3..7058433c 100644 --- a/pdseg/utils/config.py +++ b/pdseg/utils/config.py @@ -194,7 +194,7 @@ cfg.MODEL.FP16 = False cfg.MODEL.SCALE_LOSS = "DYNAMIC" ########################## DeepLab模型配置 #################################### -# DeepLab backbone 配置, 可选项xception_65, xception_41, xception_71, mobilenetv2, resnet_50, resnet_101, resnet_vd_50, resnet_vd_101 +# DeepLab backbone 配置, 可选项xception_65, xception_41, xception_71, mobilenetv2, resnet50_vd, resnet101_vd cfg.MODEL.DEEPLAB.BACKBONE = "xception_65" # DeepLab output stride cfg.MODEL.DEEPLAB.OUTPUT_STRIDE = 16 -- GitLab