pretrain_weights.py 6.1 KB
Newer Older
J
jiangjiajun 已提交
1
import paddlex
J
jiangjiajun 已提交
2
import paddlehub as hub
J
jiangjiajun 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
import os
import os.path as osp

# ImageNet-pretrained backbone weights, keyed by the backbone name used for
# lookup in this module. Notes on the non-obvious keys:
#   * 'DetResNet50'   -- ResNet50 variant selected for detectors.
#   * 'SegXception*'  -- Xception variants (DeepLab weights); the 'Seg' prefix
#                        is added to any backbone name starting with
#                        'Xception' before lookup.
# Every URL uses https so all archives are addressed over the same scheme.
image_pretrain = {
    # ResNet family
    'ResNet18':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_pretrained.tar',
    'ResNet34':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_pretrained.tar',
    'ResNet50':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
    'ResNet101':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_pretrained.tar',
    'ResNet50_vd':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar',
    'ResNet101_vd':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar',
    'ResNet50_vd_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar',
    'ResNet101_vd_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_ssld_pretrained.tar',
    # MobileNet family
    'MobileNetV1':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar',
    'MobileNetV2_x1.0':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_pretrained.tar',
    'MobileNetV2_x0.5':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_5_pretrained.tar',
    'MobileNetV2_x2.0':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x2_0_pretrained.tar',
    'MobileNetV2_x0.25':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x0_25_pretrained.tar',
    'MobileNetV2_x1.5':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV2_x1_5_pretrained.tar',
    'MobileNetV3_small':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_pretrained.tar',
    'MobileNetV3_large':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_pretrained.tar',
    'MobileNetV3_small_x1_0_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar',
    'MobileNetV3_large_x1_0_ssld':
    'https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar',
    # Other backbones
    'DarkNet53':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_ImageNet1k_pretrained.tar',
    'DenseNet121':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet121_pretrained.tar',
    'DenseNet161':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet161_pretrained.tar',
    'DenseNet201':
    'https://paddle-imagenet-models-name.bj.bcebos.com/DenseNet201_pretrained.tar',
    'DetResNet50':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar',
    'SegXception41':
    'https://paddle-imagenet-models-name.bj.bcebos.com/Xception41_deeplab_pretrained.tar',
    'SegXception65':
    'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar',
    'ShuffleNetV2':
    'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar',
}

# COCO-pretrained weights, keyed by the name looked up when the 'COCO'
# pretrain flag is requested.
coco_pretrain = dict(
    UNet='https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz',
)


def _download_from_hub(backbone, save_dir):
    """Download the PaddleHub resource named `backbone` into `save_dir`.

    Args:
        backbone: Resource name passed to `hub.download`.
        save_dir: Directory passed to `hub.download` as `save_path`.

    Returns:
        `save_dir` joined with `backbone`.

    Raises:
        Exception: If the download fails; the message distinguishes a missing
            resource, a connection failure, and any other error.
    """
    try:
        hub.download(backbone, save_path=save_dir)
    except Exception as e:
        # The hub exception classes are looked up only after a failure, so the
        # broad `except` is kept and the error is classified with isinstance.
        if isinstance(e, hub.ResourceNotFoundError):
            raise Exception("Resource for backbone {} not found".format(
                backbone))
        elif isinstance(e, hub.ServerConnectionError):
            raise Exception(
                "Cannot get resource for backbone {}, please check your internet connection"
                .format(backbone))
        else:
            raise Exception(
                "Unexpected error, please make sure paddlehub >= 1.6.2")
    return osp.join(save_dir, backbone)


def get_pretrain_weights(flag, model_type, backbone, save_dir):
    """Resolve the `pretrain_weights` setting to a weights directory.

    Args:
        flag: `None`, a path to an existing directory, `'IMAGENET'` or
            `'COCO'`.
        model_type: Only the value `'detector'` is special-cased: with the
            `'IMAGENET'` flag it maps backbone `'ResNet50'` to
            `'DetResNet50'`.
        backbone: Backbone name; a few aliases are normalised to the keys of
            `image_pretrain` before lookup.
        save_dir: Download directory, used unless `paddlex.pretrain_dir` is
            set, in which case that directory is used instead.

    Returns:
        `None` if `flag` is `None`; `flag` itself if it is an existing
        directory; otherwise the download directory joined with the resolved
        backbone name.

    Raises:
        AssertionError: If no weights are registered for the backbone under
            the requested flag.
        Exception: If the download fails or `flag` is not a supported value.
    """
    if flag is None:
        return None
    elif osp.isdir(flag):
        return flag
    elif flag == 'IMAGENET':
        new_save_dir = save_dir
        if hasattr(paddlex, 'pretrain_dir'):
            new_save_dir = paddlex.pretrain_dir
        # Normalise user-facing backbone names to the keys of image_pretrain.
        if backbone.startswith('Xception'):
            backbone = 'Seg{}'.format(backbone)
        elif backbone == 'MobileNetV2':
            backbone = 'MobileNetV2_x1.0'
        elif backbone == 'MobileNetV3_small_ssld':
            backbone = 'MobileNetV3_small_x1_0_ssld'
        elif backbone == 'MobileNetV3_large_ssld':
            backbone = 'MobileNetV3_large_x1_0_ssld'
        if model_type == 'detector':
            if backbone == 'ResNet50':
                backbone = 'DetResNet50'
        assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format(
            backbone)
        return _download_from_hub(backbone, new_save_dir)
    elif flag == 'COCO':
        new_save_dir = save_dir
        if hasattr(paddlex, 'pretrain_dir'):
            new_save_dir = paddlex.pretrain_dir
        # Validate before any lookup so an unsupported backbone produces the
        # descriptive message below rather than a bare KeyError.
        assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format(
            backbone)
        return _download_from_hub(backbone, new_save_dir)
    else:
        raise Exception(
            "pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."
        )