pretrain_weights.py 6.8 KB
Newer Older
J
jiangjiajun 已提交
1
import paddlex
J
jiangjiajun 已提交
2
import paddlehub as hub
J
jiangjiajun 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
import os
import os.path as osp

# Backbone name -> URL of the ImageNet-pretrained weights archive.
# get_pretrain_weights() checks membership in this table (after resolving
# backbone aliases) before it attempts a download.
# All URLs use https so that every entry is fetched over the same scheme.
image_pretrain = {
    'ResNet18':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_pretrained.tar',
    'ResNet34':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar',
    'ResNet50':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
    'ResNet101':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar',
    'ResNet50_vd':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
    'ResNet101_vd':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar',
    'ResNet50_vd_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar',
    'ResNet101_vd_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_ssld_pretrained.tar',
    'MobileNetV1':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar',
    'MobileNetV2_x1.0':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar',
    'MobileNetV2_x0.5':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_5_pretrained.tar',
    'MobileNetV2_x2.0':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x2_0_pretrained.tar',
    'MobileNetV2_x0.25':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_25_pretrained.tar',
    'MobileNetV2_x1.5':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x1_5_pretrained.tar',
    'MobileNetV3_small':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar',
    'MobileNetV3_large':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar',
    'MobileNetV3_small_x1_0_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar',
    'MobileNetV3_large_x1_0_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar',
    'DarkNet53':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar',
    'DenseNet121':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar',
    'DenseNet161':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar',
    'DenseNet201':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar',
    'DetResNet50':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar',
    'SegXception41':
    'https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_deeplab_pretrained.tar',
    'SegXception65':
    'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
    'ShuffleNetV2':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
    'HRNet_W18':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar',
    'HRNet_W30':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar',
    'HRNet_W32':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar',
    'HRNet_W40':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar',
    'HRNet_W48':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar',
    'HRNet_W60':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar',
    'HRNet_W64':
    'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar',
}

# Model name -> URL of the COCO-pretrained weights archive.
# get_pretrain_weights() checks membership in this table before it attempts
# a COCO download.
coco_pretrain = {
    'UNet': 'https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz',
}


# User-facing backbone names that are stored under a different key in
# image_pretrain.
_IMAGENET_BACKBONE_ALIASES = {
    'MobileNetV2': 'MobileNetV2_x1.0',
    'MobileNetV3_small_ssld': 'MobileNetV3_small_x1_0_ssld',
    'MobileNetV3_large_ssld': 'MobileNetV3_large_x1_0_ssld',
}


def _resolve_save_dir(save_dir):
    """Return ``paddlex.pretrain_dir`` if it is set, otherwise ``save_dir``."""
    return getattr(paddlex, 'pretrain_dir', save_dir)


def _download_from_hub(backbone, save_dir):
    """Download ``backbone`` through paddlehub into ``save_dir``.

    Returns the path ``save_dir/backbone``. Any failure of ``hub.download``
    is re-raised as a plain ``Exception`` carrying a user-oriented message.
    """
    try:
        hub.download(backbone, save_path=save_dir)
    except Exception as e:
        if isinstance(e, hub.ResourceNotFoundError):
            raise Exception("Resource for backbone {} not found".format(
                backbone))
        elif isinstance(e, hub.ServerConnectionError):
            raise Exception(
                "Cannot get resource for backbone {}, please check your internet connection"
                .format(backbone))
        else:
            raise Exception(
                "Unexpected error, please make sure paddlehub >= 1.6.2")
    return osp.join(save_dir, backbone)


def get_pretrain_weights(flag, model_type, backbone, save_dir):
    """Resolve the ``pretrain_weights`` option to a local weights directory.

    Args:
        flag: ``None`` (no pretrained weights), a path to an existing
            directory (returned unchanged), ``'IMAGENET'`` or ``'COCO'``
            (weights are downloaded through paddlehub).
        model_type: Kind of model being built; only ``'detector'`` changes
            behavior (``ResNet50`` is mapped to ``DetResNet50``).
        backbone: Backbone (or model) name to fetch weights for.
        save_dir: Download directory, used unless ``paddlex.pretrain_dir``
            is set, in which case that directory takes precedence.

    Returns:
        ``None`` when ``flag`` is ``None``, ``flag`` itself when it is an
        existing directory, otherwise the path of the downloaded weights.

    Raises:
        AssertionError: No weights are registered for ``backbone`` in the
            requested table.
        Exception: The download failed, or ``flag`` is not a supported value.
    """
    if flag is None:
        return None
    if osp.isdir(flag):
        return flag

    if flag == 'IMAGENET':
        new_save_dir = _resolve_save_dir(save_dir)
        if backbone.startswith('Xception'):
            backbone = 'Seg{}'.format(backbone)
        else:
            backbone = _IMAGENET_BACKBONE_ALIASES.get(backbone, backbone)
        # Detection models use a dedicated ResNet50 checkpoint.
        if model_type == 'detector' and backbone == 'ResNet50':
            backbone = 'DetResNet50'
        assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
            backbone)
        return _download_from_hub(backbone, new_save_dir)

    if flag == 'COCO':
        new_save_dir = _resolve_save_dir(save_dir)
        # Validate before any lookup so that an unsupported backbone yields
        # the explanatory assertion message rather than a bare KeyError.
        assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
            backbone)
        return _download_from_hub(backbone, new_save_dir)

    raise Exception(
        "pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."
    )