提交 cbd52309 编写于 作者: W wuzewu

Fix download bug

上级 e9b1393b
...@@ -46,7 +46,7 @@ model_urls = { ...@@ -46,7 +46,7 @@ model_urls = {
"unet_bn_coco": "https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz", "unet_bn_coco": "https://paddleseg.bj.bcebos.com/models/unet_coco_v3.tgz",
# Cityscapes pretrained # Cityscapes pretrained
"deeplabv3plus_mobilenetv2-1-0_bn_cityscapes": "deeplabv3p_mobilenetv2-1-0_bn_cityscapes":
"https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz", "https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz",
"deeplabv3p_xception65_gn_cityscapes": "deeplabv3p_xception65_gn_cityscapes":
"https://paddleseg.bj.bcebos.com/models/deeplabv3p_xception65_cityscapes.tgz", "https://paddleseg.bj.bcebos.com/models/deeplabv3p_xception65_cityscapes.tgz",
......
...@@ -84,7 +84,7 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress): ...@@ -84,7 +84,7 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress):
else: else:
handler = functools.partial(_uncompress_file_tar, mode="r") handler = functools.partial(_uncompress_file_tar, mode="r")
for total_num, index in handler(filepath, extrapath): for total_num, index, rootpath in handler(filepath, extrapath):
if print_progress: if print_progress:
done = int(50 * float(index) / total_num) done = int(50 * float(index) / total_num)
progress( progress(
...@@ -95,27 +95,31 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress): ...@@ -95,27 +95,31 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress):
if delete_file: if delete_file:
os.remove(filepath) os.remove(filepath)
return rootpath
def _uncompress_file_zip(filepath, extrapath): def _uncompress_file_zip(filepath, extrapath):
files = zipfile.ZipFile(filepath, 'r') files = zipfile.ZipFile(filepath, 'r')
filelist = files.namelist() filelist = files.namelist()
rootpath = filelist[0]
total_num = len(filelist) total_num = len(filelist)
for index, file in enumerate(filelist): for index, file in enumerate(filelist):
files.extract(file, extrapath) files.extract(file, extrapath)
yield total_num, index yield total_num, index, rootpath
files.close() files.close()
yield total_num, index yield total_num, index, rootpath
def _uncompress_file_tar(filepath, extrapath, mode="r:gz"): def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
files = tarfile.open(filepath, mode) files = tarfile.open(filepath, mode)
filelist = files.getnames() filelist = files.getnames()
total_num = len(filelist) total_num = len(filelist)
rootpath = filelist[0]
for index, file in enumerate(filelist): for index, file in enumerate(filelist):
files.extract(file, extrapath) files.extract(file, extrapath)
yield total_num, index yield total_num, index, rootpath
files.close() files.close()
yield total_num, index yield total_num, index, rootpath
def download_file_and_uncompress(url, def download_file_and_uncompress(url,
...@@ -150,7 +154,9 @@ def download_file_and_uncompress(url, ...@@ -150,7 +154,9 @@ def download_file_and_uncompress(url,
if not os.path.exists(savename): if not os.path.exists(savename):
if not os.path.exists(savepath): if not os.path.exists(savepath):
_download_file(url, savepath, print_progress) _download_file(url, savepath, print_progress)
_uncompress_file(savepath, extrapath, delete_file, print_progress) savename = _uncompress_file(savepath, extrapath, delete_file,
print_progress)
savename = os.path.join(extrapath, savename)
shutil.move(savename, extraname) shutil.move(savename, extraname)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册