diff --git a/ppcls/utils/model_zoo.py b/ppcls/utils/model_zoo.py index dd65e921f3e258b75c911599ee234359f3fd7c51..d023f4d1fbc310d50ca56fa389632b41485a8174 100644 --- a/ppcls/utils/model_zoo.py +++ b/ppcls/utils/model_zoo.py @@ -58,9 +58,9 @@ class RetryError(Exception): super(RetryError, self).__init__(message) -def _get_url(architecture): +def _get_url(architecture, postfix="tar"): prefix = "https://paddle-imagenet-models-name.bj.bcebos.com/" - fname = architecture + "_pretrained.tar" + fname = architecture + "_pretrained." + postfix return prefix + fname @@ -193,13 +193,13 @@ def list_models(): return -def get(architecture, path, decompress=True): +def get(architecture, path, decompress=True, postfix="tar"): """ Get the pretrained model. """ _check_pretrained_name(architecture) - url = _get_url(architecture) + url = _get_url(architecture, postfix=postfix) fname = _download(url, path) - if decompress: + if postfix == "tar" and decompress: _decompress(fname) logger.info("download {} finished ".format(fname)) diff --git a/ppcls/utils/pretrained.list b/ppcls/utils/pretrained.list index 633cafd921d7390f434b2b5f82dad70129349658..91ae4409f9289b0634b4c6fa95ae3e1d75cc42aa 100644 --- a/ppcls/utils/pretrained.list +++ b/ppcls/utils/pretrained.list @@ -116,3 +116,4 @@ VGG16 VGG19 DarkNet53_ImageNet1k ResNet50_ACNet_deploy +CSPResNet50_leaky diff --git a/tools/download.py b/tools/download.py index d9fe1a8ee04a14b31cf4917e60f9348ae51b8d20..35cf77a725a9790c3cd2804ffd4e6ce1509b39de 100644 --- a/tools/download.py +++ b/tools/download.py @@ -24,6 +24,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('-a', '--architecture', type=str, default='ResNet50') parser.add_argument('-p', '--path', type=str, default='./pretrained/') + parser.add_argument('--postfix', type=str, default="tar") parser.add_argument('-d', '--decompress', type=str2bool, default=True) parser.add_argument('-l', '--list', type=str2bool, default=False) @@ -36,7 +37,8 @@ def main(): if args.list: model_zoo.list_models() else: - model_zoo.get(args.architecture, args.path, args.decompress) + model_zoo.get(args.architecture, args.path, args.decompress, + args.postfix) if __name__ == '__main__':