diff --git a/paddlex/cv/models/utils/pretrain_weights.py b/paddlex/cv/models/utils/pretrain_weights.py index 3f41838b7d3e1529558ced1db23e84292bdd5270..81790a20144d8c255601b8a778eebf02c409c55d 100644 --- a/paddlex/cv/models/utils/pretrain_weights.py +++ b/paddlex/cv/models/utils/pretrain_weights.py @@ -1,5 +1,5 @@ import paddlex -#import paddlehub as hub +import paddlehub as hub import os import os.path as osp @@ -85,53 +85,49 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): backbone = 'DetResNet50' assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format( backbone) - url = image_pretrain[backbone] - fname = osp.split(url)[-1].split('.')[0] - paddlex.utils.download_and_decompress(url, path=new_save_dir) - return osp.join(new_save_dir, fname) -# try: -# hub.download(backbone, save_path=new_save_dir) -# except Exception as e: -# if isinstance(e, hub.ResourceNotFoundError): -# raise Exception( -# "Resource for backbone {} not found".format(backbone)) -# elif isinstance(e, hub.ServerConnectionError): -# raise Exception( -# "Cannot get reource for backbone {}, please check your internet connecgtion" -# .format(backbone)) -# else: -# raise Exception( -# "Unexpected error, please make sure paddlehub >= 1.6.2") -# return osp.join(new_save_dir, backbone) + # url = image_pretrain[backbone] + # fname = osp.split(url)[-1].split('.')[0] + # paddlex.utils.download_and_decompress(url, path=new_save_dir) + # return osp.join(new_save_dir, fname) + try: + hub.download(backbone, save_path=new_save_dir) + except Exception as e: + if isinstance(e, hub.ResourceNotFoundError): + raise Exception("Resource for backbone {} not found".format( + backbone)) + elif isinstance(e, hub.ServerConnectionError): + raise Exception( + "Cannot get reource for backbone {}, please check your internet connecgtion" + .format(backbone)) + else: + raise Exception( + "Unexpected error, please make sure paddlehub >= 1.6.2") + return osp.join(new_save_dir, backbone) elif flag == 'COCO': new_save_dir = save_dir if hasattr(paddlex, 'pretrain_dir'): new_save_dir = paddlex.pretrain_dir url = coco_pretrain[backbone] fname = osp.split(url)[-1].split('.')[0] - paddlex.utils.download_and_decompress(url, path=new_save_dir) - return osp.join(new_save_dir, fname) - + # paddlex.utils.download_and_decompress(url, path=new_save_dir) + # return osp.join(new_save_dir, fname) -# new_save_dir = save_dir -# if hasattr(paddlex, 'pretrain_dir'): -# new_save_dir = paddlex.pretrain_dir -# assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format( -# backbone) -# try: -# hub.download(backbone, save_path=new_save_dir) -# except Exception as e: -# if isinstance(hub.ResourceNotFoundError): -# raise Exception( -# "Resource for backbone {} not found".format(backbone)) -# elif isinstance(hub.ServerConnectionError): -# raise Exception( -# "Cannot get reource for backbone {}, please check your internet connecgtion" -# .format(backbone)) -# else: -# raise Exception( -# "Unexpected error, please make sure paddlehub >= 1.6.2") -# return osp.join(new_save_dir, backbone) + assert backbone in coco_pretrain, "There is not COCO pretrain weights for {}, you may try ImageNet.".format( + backbone) + try: + hub.download(backbone, save_path=new_save_dir) + except Exception as e: + if isinstance(hub.ResourceNotFoundError): + raise Exception("Resource for backbone {} not found".format( + backbone)) + elif isinstance(hub.ServerConnectionError): + raise Exception( + "Cannot get reource for backbone {}, please check your internet connecgtion" + .format(backbone)) + else: + raise Exception( + "Unexpected error, please make sure paddlehub >= 1.6.2") + return osp.join(new_save_dir, backbone) else: raise Exception( "pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."