Unverified commit 467099f0, authored by CtfGo, committed by GitHub

Speedup download uncompress function (#37311)

`paddle.utils.download`: switch tar/zip decompression to `extractall` to speed up uncompressing archives that contain many files.

--- Decompression speed comparison (old vs. new) ---
1. dataset: https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/cnn_stories.tgz, decompression time: 5m50s vs 20s
2. dataset: https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/dailymail_stories.tgz, decompression time: 33m20s vs 47s
Parent 2dfcdf21
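The speedup comes from replacing a per-member `extract` loop with a single `extractall` call. For tar archives this matters a lot: `TarFile.extract(name)` resolves each name by scanning the archive's member list, so extracting n members costs on the order of n² lookups, while `extractall` walks the members once. A minimal benchmark sketch (not part of this commit; the archive path is illustrative) that reproduces the before/after comparison:

```python
import tarfile
import tempfile
import time


def extract_one_by_one(archive_path, dest):
    # old strategy: each extract(name) call re-resolves the name against
    # the member list, which is a linear scan per call
    with tarfile.open(archive_path, "r:*") as tf:
        for name in tf.getnames():
            tf.extract(name, dest)


def extract_all(archive_path, dest):
    # new strategy: a single pass over all members
    with tarfile.open(archive_path, "r:*") as tf:
        tf.extractall(dest)


if __name__ == "__main__":
    archive = "cnn_stories.tgz"  # any locally downloaded many-file archive
    for fn in (extract_one_by_one, extract_all):
        with tempfile.TemporaryDirectory() as dest:
            start = time.perf_counter()
            fn(archive, dest)
            print(f"{fn.__name__}: {time.perf_counter() - start:.1f}s")
```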
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import unittest
 
 from paddle.utils.download import get_weights_path_from_url
```
```diff
@@ -70,6 +71,36 @@ class TestDownload(unittest.TestCase):
         for url in urls:
             get_path_from_url(url, root_dir='./test')
 
+    def test_uncompress_result(self):
+        results = [
+            [
+                "files/single_dir/file1", "files/single_dir/file2",
+                "files/single_file.pdparams"
+            ],
+            ["single_dir/file1", "single_dir/file2"],
+            ["single_file.pdparams"],
+        ]
+        tar_urls = [
+            "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar",
+        ]
+
+        for url, uncompressd_res in zip(tar_urls, results):
+            uncompressed_path = get_path_from_url(url, root_dir='./test_tar')
+            self.assertTrue(all([os.path.exists(os.path.join("./test_tar", filepath)) \
+                    for filepath in uncompressd_res]))
+
+        zip_urls = [
+            "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip",
+        ]
+
+        for url, uncompressd_res in zip(zip_urls, results):
+            uncompressed_path = get_path_from_url(url, root_dir='./test_zip')
+            self.assertTrue(all([os.path.exists(os.path.join("./test_zip", filepath)) \
+                    for filepath in uncompressd_res]))
+
     def test_retry_exception(self, ):
         with self.assertRaises(RuntimeError):
             from paddle.utils.download import _download
```
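For reference, a hedged usage sketch of the API the new test exercises: `get_path_from_url` downloads the archive under `root_dir`, decompresses it, and returns the decompressed path (URL and file layout taken from the test above):

```python
from paddle.utils.download import get_path_from_url

# single_dir.tar holds one top-level directory, so the test can assert that
# ./test_tar/single_dir/file1 and ./test_tar/single_dir/file2 exist afterwards.
uncompressed_path = get_path_from_url(
    "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar",
    root_dir="./test_tar")
print(uncompressed_path)  # e.g. ./test_tar/single_dir
```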
```diff
@@ -79,7 +79,7 @@ def get_weights_path_from_url(url, md5sum=None):
     Args:
         url (str): download url
         md5sum (str): md5 sum of download package
 
     Returns:
         str: a local path to save downloaded weights.
```
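A short usage sketch of the documented function (the URL is illustrative, not taken from this diff):

```python
from paddle.utils.download import get_weights_path_from_url

# Downloads on first call and reuses the cached copy afterwards; when
# md5sum is given, the downloaded file is checked against it.
weights_path = get_weights_path_from_url(
    "https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams")
print(weights_path)  # local path of the cached weights file
```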
```diff
@@ -146,8 +146,8 @@ def get_path_from_url(url,
     assert is_url(url), "downloading from {} not a url".format(url)
     # parse path after download to decompress under root_dir
     fullpath = _map_path(url, root_dir)
-    # Mainly used to solve the problem of downloading data from different 
-    # machines in the case of multiple machines. Different ips will download 
+    # Mainly used to solve the problem of downloading data from different
+    # machines in the case of multiple machines. Different ips will download
     # data, and the same ip will only download data once.
     unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
     if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
```
```diff
@@ -302,70 +302,61 @@ def _decompress(fname):
 def _uncompress_file_zip(filepath):
-    files = zipfile.ZipFile(filepath, 'r')
-    file_list = files.namelist()
-    file_dir = os.path.dirname(filepath)
-    if _is_a_single_file(file_list):
-        rootpath = file_list[0]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        for item in file_list:
-            files.extract(item, file_dir)
-    elif _is_a_single_dir(file_list):
-        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        for item in file_list:
-            files.extract(item, file_dir)
-    else:
-        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        if not os.path.exists(uncompressed_path):
-            os.makedirs(uncompressed_path)
-        for item in file_list:
-            files.extract(item, os.path.join(file_dir, rootpath))
-    files.close()
-    return uncompressed_path
+    with zipfile.ZipFile(filepath, 'r') as files:
+        file_list = files.namelist()
+
+        file_dir = os.path.dirname(filepath)
+
+        if _is_a_single_file(file_list):
+            rootpath = file_list[0]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            files.extractall(file_dir)
+        elif _is_a_single_dir(file_list):
+            # `strip(os.sep)` to remove `os.sep` in the tail of path
+            rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
+                os.sep)[-1]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            files.extractall(file_dir)
+        else:
+            rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            if not os.path.exists(uncompressed_path):
+                os.makedirs(uncompressed_path)
+            files.extractall(os.path.join(file_dir, rootpath))
+
+        return uncompressed_path
```
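The branching is unchanged; only the extraction strategy (per-member `extract` to `extractall`) and the resource handling (explicit `close` to a `with` block) moved. The companion helper `_is_a_single_dir` is not shown in this diff; a hedged sketch of the check it has to perform, treating an archive as a single directory when every member shares one top-level path component:

```python
import os


def _is_a_single_dir_sketch(file_list):
    """Hypothetical stand-in for paddle's _is_a_single_dir (not in this diff)."""
    # Normalize separators so archives built on Windows and Unix compare equally.
    normalized = [p.replace("\\", "/").replace(os.sep, "/") for p in file_list]
    top = normalized[0].split("/")[0]
    return all(p.split("/")[0] == top for p in normalized)


assert _is_a_single_dir_sketch(["single_dir/file1", "single_dir/file2"])
assert not _is_a_single_dir_sketch(["file1", "single_dir/file2"])
```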
```diff
 def _uncompress_file_tar(filepath, mode="r:*"):
-    files = tarfile.open(filepath, mode)
-    file_list = files.getnames()
-    file_dir = os.path.dirname(filepath)
-    if _is_a_single_file(file_list):
-        rootpath = file_list[0]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        for item in file_list:
-            files.extract(item, file_dir)
-    elif _is_a_single_dir(file_list):
-        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        for item in file_list:
-            files.extract(item, file_dir)
-    else:
-        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        if not os.path.exists(uncompressed_path):
-            os.makedirs(uncompressed_path)
-        for item in file_list:
-            files.extract(item, os.path.join(file_dir, rootpath))
-    files.close()
-    return uncompressed_path
+    with tarfile.open(filepath, mode) as files:
+        file_list = files.getnames()
+
+        file_dir = os.path.dirname(filepath)
+
+        if _is_a_single_file(file_list):
+            rootpath = file_list[0]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            files.extractall(file_dir)
+        elif _is_a_single_dir(file_list):
+            rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
+                os.sep)[-1]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            files.extractall(file_dir)
+        else:
+            rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            if not os.path.exists(uncompressed_path):
+                os.makedirs(uncompressed_path)
+            files.extractall(os.path.join(file_dir, rootpath))
+
+        return uncompressed_path
```
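Both functions also pick up the `strip(os.sep)` fix flagged by the comment in the zip branch: archive listings often name a directory entry with a trailing separator, and the old expression then derived an empty root name. A worked example:

```python
import os

entry = "single_dir" + os.sep  # e.g. "single_dir/" as returned by namelist()

# old expression: the trailing separator leaves an empty last component
old_rootpath = os.path.splitext(entry)[0].split(os.sep)[-1]
assert old_rootpath == ""

# new expression: strip the trailing separator first
new_rootpath = os.path.splitext(entry.strip(os.sep))[0].split(os.sep)[-1]
assert new_rootpath == "single_dir"
```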
```diff
 def _is_a_single_file(file_list):
-    if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
+    if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
         return True
     return False
```
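This hunk fixes a latent bug rather than just restyling: `str.find` returns -1 when the substring is absent and never anything smaller, so the old `< -1` comparison could never be true and no archive was ever classified as a single file. A quick demonstration:

```python
import os

name = "single_file.pdparams"
assert name.find(os.sep) == -1       # absent separator -> -1, never less
assert not (name.find(os.sep) < -1)  # old condition: unsatisfiable
assert name.find(os.sep) < 0         # fixed condition: true for single files
```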