pretrain_weights.py 12.5 KB
Newer Older
J
jiangjiajun 已提交
1
import paddlex
2
import paddlex.utils.logging as logging
J
jiangjiajun 已提交
3
import paddlehub as hub
J
jiangjiajun 已提交
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
import os
import os.path as osp

# Backbone name -> URL of the ImageNet-pretrained weights archive.
# The keys are the names looked up by get_pretrain_weights() after it has
# normalized the backbone (e.g. 'Xception65' -> 'SegXception65').
# All URLs use HTTPS so that a manual download cannot be tampered with in
# transit; the host serves every archive over HTTPS.
image_pretrain = {
    'ResNet18':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_pretrained.tar',
    'ResNet34':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar',
    'ResNet50':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
    'ResNet101':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar',
    'ResNet50_vd':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
    'ResNet101_vd':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar',
    'ResNet50_vd_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar',
    'ResNet101_vd_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_ssld_pretrained.tar',
    'MobileNetV1':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar',
    'MobileNetV2_x1.0':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar',
    'MobileNetV2_x0.5':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_5_pretrained.tar',
    'MobileNetV2_x2.0':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x2_0_pretrained.tar',
    'MobileNetV2_x0.25':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_25_pretrained.tar',
    'MobileNetV2_x1.5':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x1_5_pretrained.tar',
    'MobileNetV3_small':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar',
    'MobileNetV3_large':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar',
    'MobileNetV3_small_x1_0_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar',
    'MobileNetV3_large_x1_0_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar',
    'DarkNet53':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar',
    'DenseNet121':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar',
    'DenseNet161':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar',
    'DenseNet201':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar',
    'DetResNet50':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar',
    'SegXception41':
    'https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_deeplab_pretrained.tar',
    'SegXception65':
    'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
    'ShuffleNetV2':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
    'HRNet_W18':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
    'HRNet_W30':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
    'HRNet_W32':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
    'HRNet_W40':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
    'HRNet_W44':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W44_C_pretrained.tar',
    'HRNet_W48':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
    'HRNet_W60':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
    'HRNet_W64':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
    'AlexNet':
    'https://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar'
}

# URL prefixes shared by the COCO-pretrained archives below.
_COCO_DET_ROOT = 'https://paddlemodels.bj.bcebos.com/object_detection/'
_COCO_PADDLEX_ROOT = 'https://bj.bcebos.com/paddlex/pretrained_weights/'
_COCO_SEG_ROOT = 'https://paddleseg.bj.bcebos.com/models/'

# '<Model>_<Backbone>_COCO' (or '<Model>_COCO') -> URL of the weights archive
# pretrained on COCO. get_pretrain_weights() builds these keys from the model
# class name, the backbone and the dataset flag.
coco_pretrain = {
    # YOLOv3
    'YOLOv3_DarkNet53_COCO': _COCO_DET_ROOT + 'yolov3_darknet.tar',
    'YOLOv3_MobileNetV1_COCO': _COCO_DET_ROOT + 'yolov3_mobilenet_v1.tar',
    'YOLOv3_MobileNetV3_large_COCO':
    'https://bj.bcebos.com/paddlex/models/yolov3_mobilenet_v3.tar',
    'YOLOv3_ResNet34_COCO': _COCO_DET_ROOT + 'yolov3_r34.tar',
    'YOLOv3_ResNet50_vd_COCO': _COCO_DET_ROOT + 'yolov3_r50vd_dcn.tar',
    # Faster R-CNN
    'FasterRCNN_ResNet18_COCO':
    _COCO_PADDLEX_ROOT + 'faster_rcnn_r18_fpn_1x.tar',
    'FasterRCNN_ResNet50_COCO': _COCO_DET_ROOT + 'faster_rcnn_r50_fpn_2x.tar',
    'FasterRCNN_ResNet50_vd_COCO':
    _COCO_DET_ROOT + 'faster_rcnn_r50_vd_fpn_2x.tar',
    'FasterRCNN_ResNet101_COCO':
    _COCO_DET_ROOT + 'faster_rcnn_r101_fpn_2x.tar',
    'FasterRCNN_ResNet101_vd_COCO':
    _COCO_DET_ROOT + 'faster_rcnn_r101_vd_fpn_2x.tar',
    'FasterRCNN_HRNet_W18_COCO':
    _COCO_DET_ROOT + 'faster_rcnn_hrnetv2p_w18_2x.tar',
    # Mask R-CNN
    'MaskRCNN_ResNet18_COCO': _COCO_PADDLEX_ROOT + 'mask_rcnn_r18_fpn_1x.tar',
    'MaskRCNN_ResNet50_COCO': _COCO_DET_ROOT + 'mask_rcnn_r50_fpn_2x.tar',
    'MaskRCNN_ResNet50_vd_COCO':
    _COCO_DET_ROOT + 'mask_rcnn_r50_vd_fpn_2x.tar',
    'MaskRCNN_ResNet101_COCO': _COCO_DET_ROOT + 'mask_rcnn_r101_fpn_1x.tar',
    'MaskRCNN_ResNet101_vd_COCO':
    _COCO_DET_ROOT + 'mask_rcnn_r101_vd_fpn_1x.tar',
    'MaskRCNN_HRNet_W18_COCO':
    _COCO_PADDLEX_ROOT + 'mask_rcnn_hrnetv2p_w18_2x.tar',
    # Segmentation
    'UNet_COCO': _COCO_SEG_ROOT + 'unet_coco_v3.tgz',
    'DeepLabv3p_MobileNetV2_x1.0_COCO':
    'https://bj.bcebos.com/v1/paddleseg/deeplab_mobilenet_x1_0_coco.tgz',
    'DeepLabv3p_Xception65_COCO': _COCO_SEG_ROOT + 'xception65_coco.tgz',
}

# '<Model>[_<Backbone>]_CITYSCAPES' -> URL of the weights archive pretrained
# on Cityscapes. Every archive lives under the same URL prefix, so only the
# archive file names are listed.
cityscapes_pretrain = dict(
    (model, 'https://paddleseg.bj.bcebos.com/models/' + archive)
    for model, archive in [
        ('DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES',
         'mobilenet_cityscapes.tgz'),
        ('DeepLabv3p_Xception65_CITYSCAPES', 'xception65_bn_cityscapes.tgz'),
        ('HRNet_W18_CITYSCAPES', 'hrnet_w18_bn_cityscapes.tgz'),
        ('FastSCNN_CITYSCAPES', 'fast_scnn_cityscape.tar'),
    ])


134
def _download_from_hub(resource_name, save_dir, manual_url):
    """Download ``resource_name`` through PaddleHub into ``save_dir``.

    Args:
        resource_name (str): Name passed to ``hub.download``.
        save_dir (str): Directory handed to ``hub.download`` as ``save_path``.
        manual_url (str): URL shown to the user for a manual download when
            the PaddleHub download fails.

    Returns:
        str: ``save_dir`` joined with ``resource_name``.

    Raises:
        Exception: If the download fails. The message distinguishes a missing
            resource, a connection problem and any other error.
    """
    try:
        logging.info(
            "Connecting PaddleHub server to get pretrain weights...")
        hub.download(resource_name, save_path=save_dir)
    except Exception as e:
        # exit=False: only log here, the tailored exception is raised below.
        logging.error(
            "Couldn't download pretrain weight, you can download it manualy from {} (decompress the file if it is a compressed file), and set pretrain weights by your self".
            format(manual_url),
            exit=False)
        if isinstance(e, hub.ResourceNotFoundError):
            raise Exception("Resource for backbone {} not found".format(
                resource_name))
        elif isinstance(e, hub.ServerConnectionError):
            raise Exception(
                "Cannot get reource for backbone {}, please check your internet connection"
                .format(resource_name))
        else:
            raise Exception(
                "Unexpected error, please make sure paddlehub >= 1.6.2")
    return osp.join(save_dir, resource_name)


def get_pretrain_weights(flag, class_name, backbone, save_dir):
    """Resolve the ``pretrain_weights`` setting to a local weights path.

    Args:
        flag: ``None``, a path to an existing directory or file, or one of the
            dataset names ``'IMAGENET'``, ``'COCO'`` and ``'CITYSCAPES'``.
        class_name (str): Model class name, e.g. ``'YOLOv3'``, ``'DeepLabv3p'``,
            ``'HRNet'``, ``'UNet'`` or ``'FastSCNN'``.
        backbone (str): Backbone name, e.g. ``'ResNet50'`` or ``'HRNet_W18'``.
        save_dir (str): Download directory; ``paddlex.pretrain_dir`` takes
            precedence when that attribute exists.

    Returns:
        ``None`` when ``flag`` is ``None``; ``flag`` itself when it is an
        existing directory or file; otherwise the path of the downloaded
        weights (download directory joined with the resource name).

    Raises:
        AssertionError: If no ImageNet weights are registered for the
            (normalized) backbone.
        KeyError: If no COCO/CITYSCAPES weights are registered for the model.
        Exception: If the PaddleHub download fails.
    """
    if flag is None:
        return None
    # An existing directory or file is returned untouched.
    if osp.isdir(flag) or osp.isfile(flag):
        return flag

    warning_info = "{} does not support to be finetuned with weights pretrained on the {} dataset, so pretrain_weights is forced to be set to {}"
    # DeepLabv3p backbones with neither COCO nor CITYSCAPES weights.
    deeplab_imagenet_only = [
        'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
        'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
    ]

    # Step 1: if the requested dataset has no weights for this model, fall
    # back to a dataset that does and warn the user about the switch.
    if flag == 'COCO':
        if class_name == 'DeepLabv3p' and backbone in deeplab_imagenet_only:
            model_name = '{}_{}'.format(class_name, backbone)
            logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'HRNet':
            logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'FastSCNN':
            logging.warning(
                warning_info.format(class_name, flag, 'CITYSCAPES'))
            flag = 'CITYSCAPES'
    elif flag == 'CITYSCAPES':
        # The three cases test different class names, so at most one applies.
        if class_name == 'UNet':
            logging.warning(warning_info.format(class_name, flag, 'COCO'))
            flag = 'COCO'
        # Only HRNet_W18 has Cityscapes weights. W44 is included because it
        # has ImageNet weights but no Cityscapes entry, which would otherwise
        # surface as a KeyError further down.
        elif class_name == 'HRNet' and backbone.split('_')[-1] in [
                'W30', 'W32', 'W40', 'W44', 'W48', 'W60', 'W64'
        ]:
            logging.warning(warning_info.format(backbone, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'DeepLabv3p' and backbone in deeplab_imagenet_only:
            model_name = '{}_{}'.format(class_name, backbone)
            logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
    elif flag == 'IMAGENET':
        if class_name == 'UNet':
            logging.warning(warning_info.format(class_name, flag, 'COCO'))
            flag = 'COCO'
        elif class_name == 'FastSCNN':
            logging.warning(
                warning_info.format(class_name, flag, 'CITYSCAPES'))
            flag = 'CITYSCAPES'

    # Step 2: download the weights for the (possibly adjusted) dataset.
    if flag == 'IMAGENET':
        new_save_dir = getattr(paddlex, 'pretrain_dir', save_dir)
        # Map public backbone names onto the keys used in image_pretrain.
        if backbone.startswith('Xception'):
            backbone = 'Seg{}'.format(backbone)
        elif backbone == 'MobileNetV2':
            backbone = 'MobileNetV2_x1.0'
        elif backbone == 'MobileNetV3_small_ssld':
            backbone = 'MobileNetV3_small_x1_0_ssld'
        elif backbone == 'MobileNetV3_large_ssld':
            backbone = 'MobileNetV3_large_x1_0_ssld'
        # Detection models use a dedicated ResNet50 checkpoint.
        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN']:
            if backbone == 'ResNet50':
                backbone = 'DetResNet50'
        assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
            backbone)
        return _download_from_hub(backbone, new_save_dir,
                                  image_pretrain[backbone])
    elif flag in ['COCO', 'CITYSCAPES']:
        new_save_dir = getattr(paddlex, 'pretrain_dir', save_dir)
        # Build the '<Model>_<Backbone>_<DATASET>' (or '<Backbone>_<DATASET>')
        # key used by coco_pretrain / cityscapes_pretrain.
        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p']:
            backbone = '{}_{}'.format(class_name, backbone)
        backbone = "{}_{}".format(backbone, flag)
        if flag == 'COCO':
            url = coco_pretrain[backbone]
        else:
            url = cityscapes_pretrain[backbone]
        return _download_from_hub(backbone, new_save_dir, url)
    else:
        # NOTE(review): other calls pass exit=False explicitly, so this call
        # presumably terminates the program - confirm in paddlex.utils.logging.
        logging.error("Path of retrain weights '{}' is not exists!".format(
            flag))