diff --git a/pdseg/models/modeling/deeplab.py b/pdseg/models/modeling/deeplab.py index 685c7495b68934d34dd19672ec6941787ed2492e..454a5c4700e6b68a1be177b735c1711dc69564cc 100644 --- a/pdseg/models/modeling/deeplab.py +++ b/pdseg/models/modeling/deeplab.py @@ -248,7 +248,7 @@ def resnet_vd(input): dilation_dict = {3: 2} else: raise Exception("deeplab only support stride 8 or 16") - lr_mult_list = cfg.MODEL.DEEPLAB.RESNET_LR_MULT_LIST + lr_mult_list = cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST model = resnet_vd_backbone( layers, stem='deeplab', lr_mult_list=lr_mult_list) data, decode_shortcuts = model.net( @@ -265,11 +265,16 @@ def deeplabv3p(img, num_classes): # Backbone设置:xception 或 mobilenetv2 if 'xception' in cfg.MODEL.DEEPLAB.BACKBONE: data, decode_shortcut = xception(img) - print('xception backbone do not support BACKBONE_LR_MULT_LIST setting') + if cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST is not None: + print( + 'xception backbone do not support BACKBONE_LR_MULT_LIST setting' + ) elif 'mobilenet' in cfg.MODEL.DEEPLAB.BACKBONE: data, decode_shortcut = mobilenetv2(img) - print( - 'mobilenetv2 backbone do not support BACKBONE_LR_MULT_LIST setting') + if cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST is not None: + print( + 'mobilenetv2 backbone do not support BACKBONE_LR_MULT_LIST setting' + ) elif 'resnet' in cfg.MODEL.DEEPLAB.BACKBONE: data, decode_shortcut = resnet_vd(img) else: diff --git a/pdseg/utils/config.py b/pdseg/utils/config.py index 9cce381f267831cc88b699c2d8b0e428171fbfef..4a2b9d852f2174607b5ec37f716beaedb252b33d 100644 --- a/pdseg/utils/config.py +++ b/pdseg/utils/config.py @@ -209,7 +209,7 @@ cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV = True # 解码器是否使用可分离卷积 cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV = True # resnet_vd分阶段学习率 -cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST = [1.0, 1.0, 1.0, 1.0, 1.0] +cfg.MODEL.DEEPLAB.BACKBONE_LR_MULT_LIST = None ########################## UNET模型配置 ####################################### # 上采样方式, 默认为双线性插值