From d137720b94caccc7d418985d5336cb1210794d79 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 18 May 2020 02:50:52 +0000 Subject: [PATCH] fix download to support pdparams --- ppcls/utils/model_zoo.py | 10 +++++----- ppcls/utils/pretrained.list | 1 + tools/download.py | 4 +++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ppcls/utils/model_zoo.py b/ppcls/utils/model_zoo.py index dd65e921..d023f4d1 100644 --- a/ppcls/utils/model_zoo.py +++ b/ppcls/utils/model_zoo.py @@ -58,9 +58,9 @@ class RetryError(Exception): super(RetryError, self).__init__(message) -def _get_url(architecture): +def _get_url(architecture, postfix="tar"): prefix = "https://paddle-imagenet-models-name.bj.bcebos.com/" - fname = architecture + "_pretrained.tar" + fname = architecture + "_pretrained." + postfix return prefix + fname @@ -193,13 +193,13 @@ def list_models(): return -def get(architecture, path, decompress=True): +def get(architecture, path, decompress=True, postfix="tar"): """ Get the pretrained model. """ _check_pretrained_name(architecture) - url = _get_url(architecture) + url = _get_url(architecture, postfix=postfix) fname = _download(url, path) - if decompress: + if postfix == "tar" and decompress: _decompress(fname) logger.info("download {} finished ".format(fname)) diff --git a/ppcls/utils/pretrained.list b/ppcls/utils/pretrained.list index 633cafd9..91ae4409 100644 --- a/ppcls/utils/pretrained.list +++ b/ppcls/utils/pretrained.list @@ -116,3 +116,4 @@ VGG16 VGG19 DarkNet53_ImageNet1k ResNet50_ACNet_deploy +CSPResNet50_leaky diff --git a/tools/download.py b/tools/download.py index d9fe1a8e..35cf77a7 100644 --- a/tools/download.py +++ b/tools/download.py @@ -24,6 +24,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('-a', '--architecture', type=str, default='ResNet50') parser.add_argument('-p', '--path', type=str, default='./pretrained/') + parser.add_argument('--postfix', type=str, default="tar") parser.add_argument('-d', '--decompress', type=str2bool, default=True) parser.add_argument('-l', '--list', type=str2bool, default=False) @@ -36,7 +37,8 @@ def main(): if args.list: model_zoo.list_models() else: - model_zoo.get(args.architecture, args.path, args.decompress) + model_zoo.get(args.architecture, args.path, args.decompress, + args.postfix) if __name__ == '__main__': -- GitLab