From cbd52309cc2e8f705fb007d35f7dfda5b3266f64 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Thu, 5 Sep 2019 15:26:14 +0800 Subject: [PATCH] Fix download bug --- pretrained_model/download_model.py | 2 +- test/test_utils.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pretrained_model/download_model.py b/pretrained_model/download_model.py index 7f94f218..b2bde566 100644 --- a/pretrained_model/download_model.py +++ b/pretrained_model/download_model.py @@ -46,7 +46,7 @@ model_urls = { "unet_bn_coco": "https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz", # Cityscapes pretrained - "deeplabv3plus_mobilenetv2-1-0_bn_cityscapes": + "deeplabv3p_mobilenetv2-1-0_bn_cityscapes": "https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz", "deeplabv3p_xception65_gn_cityscapes": "https://paddleseg.bj.bcebos.com/models/deeplabv3p_xception65_cityscapes.tgz", diff --git a/test/test_utils.py b/test/test_utils.py index 304003d8..ed1f8ed8 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -84,7 +84,7 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress): else: handler = functools.partial(_uncompress_file_tar, mode="r") - for total_num, index in handler(filepath, extrapath): + for total_num, index, rootpath in handler(filepath, extrapath): if print_progress: done = int(50 * float(index) / total_num) progress( @@ -95,27 +95,31 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress): if delete_file: os.remove(filepath) + return rootpath + def _uncompress_file_zip(filepath, extrapath): files = zipfile.ZipFile(filepath, 'r') filelist = files.namelist() + rootpath = filelist[0] total_num = len(filelist) for index, file in enumerate(filelist): files.extract(file, extrapath) - yield total_num, index + yield total_num, index, rootpath files.close() - yield total_num, index + yield total_num, index, rootpath def _uncompress_file_tar(filepath, extrapath, mode="r:gz"): files = tarfile.open(filepath, mode) filelist = files.getnames() total_num = len(filelist) + rootpath = filelist[0] for index, file in enumerate(filelist): files.extract(file, extrapath) - yield total_num, index + yield total_num, index, rootpath files.close() - yield total_num, index + yield total_num, index, rootpath def download_file_and_uncompress(url, @@ -150,7 +154,9 @@ def download_file_and_uncompress(url, if not os.path.exists(savename): if not os.path.exists(savepath): _download_file(url, savepath, print_progress) - _uncompress_file(savepath, extrapath, delete_file, print_progress) + savename = _uncompress_file(savepath, extrapath, delete_file, + print_progress) + savename = os.path.join(extrapath, savename) shutil.move(savename, extraname) -- GitLab