提交 d137720b 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix download to support pdparams

上级 e0f31392
...@@ -58,9 +58,9 @@ class RetryError(Exception): ...@@ -58,9 +58,9 @@ class RetryError(Exception):
super(RetryError, self).__init__(message) super(RetryError, self).__init__(message)
def _get_url(architecture): def _get_url(architecture, postfix="tar"):
prefix = "https://paddle-imagenet-models-name.bj.bcebos.com/" prefix = "https://paddle-imagenet-models-name.bj.bcebos.com/"
fname = architecture + "_pretrained.tar" fname = architecture + "_pretrained." + postfix
return prefix + fname return prefix + fname
...@@ -193,13 +193,13 @@ def list_models(): ...@@ -193,13 +193,13 @@ def list_models():
return return
def get(architecture, path, decompress=True): def get(architecture, path, decompress=True, postfix="tar"):
""" """
Get the pretrained model. Get the pretrained model.
""" """
_check_pretrained_name(architecture) _check_pretrained_name(architecture)
url = _get_url(architecture) url = _get_url(architecture, postfix=postfix)
fname = _download(url, path) fname = _download(url, path)
if decompress: if postfix == "tar" and decompress:
_decompress(fname) _decompress(fname)
logger.info("download {} finished ".format(fname)) logger.info("download {} finished ".format(fname))
...@@ -116,3 +116,4 @@ VGG16 ...@@ -116,3 +116,4 @@ VGG16
VGG19 VGG19
DarkNet53_ImageNet1k DarkNet53_ImageNet1k
ResNet50_ACNet_deploy ResNet50_ACNet_deploy
CSPResNet50_leaky
...@@ -24,6 +24,7 @@ def parse_args(): ...@@ -24,6 +24,7 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-a', '--architecture', type=str, default='ResNet50') parser.add_argument('-a', '--architecture', type=str, default='ResNet50')
parser.add_argument('-p', '--path', type=str, default='./pretrained/') parser.add_argument('-p', '--path', type=str, default='./pretrained/')
parser.add_argument('--postfix', type=str, default="tar")
parser.add_argument('-d', '--decompress', type=str2bool, default=True) parser.add_argument('-d', '--decompress', type=str2bool, default=True)
parser.add_argument('-l', '--list', type=str2bool, default=False) parser.add_argument('-l', '--list', type=str2bool, default=False)
...@@ -36,7 +37,8 @@ def main(): ...@@ -36,7 +37,8 @@ def main():
if args.list: if args.list:
model_zoo.list_models() model_zoo.list_models()
else: else:
model_zoo.get(args.architecture, args.path, args.decompress) model_zoo.get(args.architecture, args.path, args.decompress,
args.postfix)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册