pretrain_weights.py 11.0 KB
Newer Older
J
jiangjiajun 已提交
1
import paddlex
2
import paddlex.utils.logging as logging
J
jiangjiajun 已提交
3
import paddlehub as hub
J
jiangjiajun 已提交
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
import os
import os.path as osp

# ImageNet-pretrained backbone weights, keyed by backbone name.
# get_pretrain_weights() checks that a backbone name is a key of this table
# before fetching its weights; every URL uses HTTPS.
image_pretrain = {
    'ResNet18':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_pretrained.tar',
    'ResNet34':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar',
    'ResNet50':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
    'ResNet101':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar',
    'ResNet50_vd':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
    'ResNet101_vd':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar',
    'ResNet50_vd_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar',
    'ResNet101_vd_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_ssld_pretrained.tar',
    'MobileNetV1':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar',
    'MobileNetV2_x1.0':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar',
    'MobileNetV2_x0.5':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_5_pretrained.tar',
    'MobileNetV2_x2.0':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x2_0_pretrained.tar',
    'MobileNetV2_x0.25':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_25_pretrained.tar',
    'MobileNetV2_x1.5':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x1_5_pretrained.tar',
    'MobileNetV3_small':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar',
    'MobileNetV3_large':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar',
    'MobileNetV3_small_x1_0_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar',
    'MobileNetV3_large_x1_0_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar',
    'DarkNet53':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar',
    'DenseNet121':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar',
    'DenseNet161':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar',
    'DenseNet201':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar',
    # ResNet50 variant used by the detection models (YOLOv3/FasterRCNN/MaskRCNN).
    'DetResNet50':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar',
    # Xception variants used by the segmentation models.
    'SegXception41':
    'https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_deeplab_pretrained.tar',
    'SegXception65':
    'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
    'ShuffleNetV2':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
    'HRNet_W18':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
    'HRNet_W30':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
    'HRNet_W32':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
    'HRNet_W40':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
    'HRNet_W48':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
    'HRNet_W60':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
    'HRNet_W64':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
}

# Weights pretrained on COCO. get_pretrain_weights() builds the key as
# '<class_name>_<backbone>_COCO' for YOLOv3/FasterRCNN/MaskRCNN/DeepLabv3p and
# '<backbone>_COCO' otherwise (e.g. 'UNet_COCO').
coco_pretrain = {
    'YOLOv3_DarkNet53_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar',
    'YOLOv3_MobileNetV1_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar',
    'YOLOv3_MobileNetV3_large_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams',
    'YOLOv3_ResNet34_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar',
    'YOLOv3_ResNet50_vd_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r50vd_dcn.tar',
    'FasterRCNN_ResNet50_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_fpn_2x.tar',
    'FasterRCNN_ResNet50_vd_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar',
    'FasterRCNN_ResNet101_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar',
    'FasterRCNN_ResNet101_vd_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar',
    'FasterRCNN_HRNet_W18_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_hrnetv2p_w18_2x.tar',
    'MaskRCNN_ResNet50_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_fpn_2x.tar',
    'MaskRCNN_ResNet50_vd_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar',
    # Previously pointed at the Faster R-CNN R101 archive by mistake.
    'MaskRCNN_ResNet101_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_fpn_1x.tar',
    'MaskRCNN_ResNet101_vd_COCO':
    'https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar',
    'UNet_COCO': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz',
    'DeepLabv3p_MobileNetV2_x1.0_COCO':
    'https://bj.bcebos.com/v1/paddleseg/deeplab_mobilenet_x1_0_coco.tgz',
    'DeepLabv3p_Xception65_COCO':
    'https://paddleseg.bj.bcebos.com/models/xception65_coco.tgz',
}

# Weights pretrained on Cityscapes, keyed by the name get_pretrain_weights()
# builds for flag='CITYSCAPES': '<class_name>_<backbone>_CITYSCAPES' for
# DeepLabv3p, '<backbone>_CITYSCAPES' for HRNet.
cityscapes_pretrain = dict([
    ('DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES',
     'https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz'),
    ('DeepLabv3p_Xception65_CITYSCAPES',
     'https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz'),
    ('HRNet_W18_CITYSCAPES',
     'https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz'),
])


122
def _resolve_pretrain_flag(flag, class_name, backbone):
    """Return the dataset whose pretrained weights should actually be used.

    Some model/backbone combinations have no weights for the requested
    dataset. For those a warning is logged and a fallback is returned instead:
    'IMAGENET' for the COCO/CITYSCAPES cases listed below, and 'COCO' for UNet
    when CITYSCAPES or IMAGENET was requested. Any other `flag` is returned
    unchanged.
    """
    warning_info = "{} does not support to be finetuned with weights pretrained on the {} dataset, so pretrain_weights is forced to be set to {}"
    model_name = '{}_{}'.format(class_name, backbone)
    # DeepLabv3p backbones that only have ImageNet-pretrained weights.
    deeplab_imagenet_only = [
        'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
        'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
    ]
    if flag == 'COCO':
        no_coco_weights = (
            (class_name == 'FasterRCNN' and backbone in ['ResNet18']) or
            (class_name == 'MaskRCNN' and
             backbone in ['ResNet18', 'HRNet_W18']) or
            (class_name == 'DeepLabv3p' and backbone in deeplab_imagenet_only))
        if no_coco_weights:
            logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'HRNet':
            logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
    elif flag == 'CITYSCAPES':
        # The three class_name checks are mutually exclusive.
        if class_name == 'UNet':
            logging.warning(warning_info.format(class_name, flag, 'COCO'))
            flag = 'COCO'
        elif class_name == 'HRNet' and backbone.split('_')[-1] in [
                'W30', 'W32', 'W40', 'W48', 'W60', 'W64'
        ]:
            logging.warning(warning_info.format(backbone, flag, 'IMAGENET'))
            flag = 'IMAGENET'
        elif class_name == 'DeepLabv3p' and backbone in deeplab_imagenet_only:
            logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
            flag = 'IMAGENET'
    elif flag == 'IMAGENET' and class_name == 'UNet':
        logging.warning(warning_info.format(class_name, flag, 'COCO'))
        flag = 'COCO'
    return flag


def _download_from_hub(name, save_dir):
    """Download resource `name` via PaddleHub into `save_dir`.

    Returns:
        osp.join(save_dir, name).

    Raises:
        Exception: with a readable message for any download failure.
    """
    try:
        hub.download(name, save_path=save_dir)
    except hub.ResourceNotFoundError:
        raise Exception("Resource for backbone {} not found".format(name))
    except hub.ServerConnectionError:
        raise Exception(
            "Cannot get resource for backbone {}, please check your internet connection"
            .format(name))
    except Exception:
        raise Exception(
            "Unexpected error, please make sure paddlehub >= 1.6.2")
    return osp.join(save_dir, name)


def get_pretrain_weights(flag, class_name, backbone, save_dir):
    """Resolve `flag` to a local path holding pretrained weights.

    Args:
        flag: None, an existing file/directory path, or one of 'IMAGENET',
            'COCO', 'CITYSCAPES' (weights are then downloaded automatically).
        class_name: model class name, e.g. 'YOLOv3', 'DeepLabv3p', 'UNet'.
        backbone: backbone name, e.g. 'ResNet50', 'MobileNetV2'.
        save_dir: download directory; `paddlex.pretrain_dir` takes precedence
            when that attribute is set.

    Returns:
        None if `flag` is None; `flag` itself if it is an existing file or
        directory; otherwise the path of the downloaded weights.

    Raises:
        AssertionError: no ImageNet weights are known for the backbone.
        KeyError: no COCO/Cityscapes weights are known for the model.
        Exception: the download failed or `flag` is not recognized.
    """
    if flag is None:
        return None
    if osp.isdir(flag) or osp.isfile(flag):
        return flag

    flag = _resolve_pretrain_flag(flag, class_name, backbone)
    if flag not in ('IMAGENET', 'COCO', 'CITYSCAPES'):
        raise Exception(
            "pretrain_weights need to be defined as directory path or 'IMAGENET' or 'COCO' or 'CITYSCAPES' (download pretrain weights automatically)."
        )

    new_save_dir = getattr(paddlex, 'pretrain_dir', save_dir)

    if flag == 'IMAGENET':
        # Map user-facing backbone names onto the keys of image_pretrain.
        if backbone.startswith('Xception'):
            backbone = 'Seg{}'.format(backbone)
        elif backbone == 'MobileNetV2':
            backbone = 'MobileNetV2_x1.0'
        elif backbone == 'MobileNetV3_small_ssld':
            backbone = 'MobileNetV3_small_x1_0_ssld'
        elif backbone == 'MobileNetV3_large_ssld':
            backbone = 'MobileNetV3_large_x1_0_ssld'
        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN'] and \
                backbone == 'ResNet50':
            # Detection models use a dedicated ResNet50 checkpoint.
            backbone = 'DetResNet50'
        assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
            backbone)
    else:
        if class_name in ['YOLOv3', 'FasterRCNN', 'MaskRCNN', 'DeepLabv3p']:
            backbone = '{}_{}'.format(class_name, backbone)
        backbone = "{}_{}".format(backbone, flag)
        pretrain_table = coco_pretrain if flag == 'COCO' else cityscapes_pretrain
        if backbone not in pretrain_table:
            raise KeyError("There is no {} pretrain weights for {}".format(
                flag, backbone))

    return _download_from_hub(backbone, new_save_dir)