Unverified commit 467099f0, authored by CtfGo, committed by GitHub

Speedup download uncompress function (#37311)

`paddle.utils.download`: switch to calling `extractall` on tar/zip compressed files to speed up decompression when the archive contains many files.

--- Decompression speed comparison (before vs after) ---
1. dataset: https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/cnn_stories.tgz, decompression time: 5m50s vs 20s
2. dataset: https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/dailymail_stories.tgz, decompression time: 33m20s vs 47s
Parent 2dfcdf21
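The speedup comes from replacing the per-member extract() loop with a single extractall() call. A minimal sketch of the before/after pattern, for illustration only (not the Paddle code): in CPython, TarFile.extract(name) resolves each name through getmember(), which scans the member list linearly, so extracting n members one at a time costs roughly O(n^2) lookups, while extractall() walks the already-parsed member list once.

import tarfile

# Before: one extract() call per member name; each call re-resolves the
# name with a linear scan over the member list, which dominates runtime
# for archives containing many files.
def uncompress_slow(filepath, dest_dir):
    with tarfile.open(filepath, "r:*") as files:
        for item in files.getnames():
            files.extract(item, dest_dir)

# After: a single extractall() iterates the member list once and
# extracts each entry sequentially.
def uncompress_fast(filepath, dest_dir):
    with tarfile.open(filepath, "r:*") as files:
        files.extractall(dest_dir)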
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import unittest
 
 from paddle.utils.download import get_weights_path_from_url
 
@@ -70,6 +71,36 @@ class TestDownload(unittest.TestCase):
         for url in urls:
             get_path_from_url(url, root_dir='./test')
 
+    def test_uncompress_result(self):
+        results = [
+            [
+                "files/single_dir/file1", "files/single_dir/file2",
+                "files/single_file.pdparams"
+            ],
+            ["single_dir/file1", "single_dir/file2"],
+            ["single_file.pdparams"],
+        ]
+        tar_urls = [
+            "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar",
+        ]
+
+        for url, uncompressd_res in zip(tar_urls, results):
+            uncompressed_path = get_path_from_url(url, root_dir='./test_tar')
+            self.assertTrue(all([os.path.exists(os.path.join("./test_tar", filepath)) \
+                for filepath in uncompressd_res]))
+
+        zip_urls = [
+            "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip",
+        ]
+        for url, uncompressd_res in zip(zip_urls, results):
+            uncompressed_path = get_path_from_url(url, root_dir='./test_zip')
+            self.assertTrue(all([os.path.exists(os.path.join("./test_zip", filepath)) \
+                for filepath in uncompressd_res]))
+
     def test_retry_exception(self, ):
         with self.assertRaises(RuntimeError):
             from paddle.utils.download import _download
...
@@ -79,7 +79,7 @@ def get_weights_path_from_url(url, md5sum=None):
     Args:
         url (str): download url
         md5sum (str): md5 sum of download package
-        
+
     Returns:
         str: a local path to save downloaded weights.
@@ -146,8 +146,8 @@ def get_path_from_url(url,
     assert is_url(url), "downloading from {} not a url".format(url)
     # parse path after download to decompress under root_dir
     fullpath = _map_path(url, root_dir)
-    # Mainly used to solve the problem of downloading data from different 
-    # machines in the case of multiple machines. Different ips will download 
+    # Mainly used to solve the problem of downloading data from different
+    # machines in the case of multiple machines. Different ips will download
     # data, and the same ip will only download data once.
     unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
     if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
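For context, a minimal usage sketch of the entry point this hunk touches, mirroring the calls in the new test above (the URL is taken from that test): get_path_from_url downloads under root_dir, verifies the md5 sum when one is given, decompresses tar/zip archives, and returns the resulting local path.

from paddle.utils.download import get_path_from_url

# Downloads once per unique endpoint in multi-machine runs, md5-checks
# the archive, decompresses it, and returns the local uncompressed path.
path = get_path_from_url(
    "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
    root_dir="./test_tar")
print(path)  # local path of the uncompressed result under ./test_tar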
@@ -302,70 +302,61 @@ def _decompress(fname):
 
 def _uncompress_file_zip(filepath):
-    files = zipfile.ZipFile(filepath, 'r')
-    file_list = files.namelist()
-
-    file_dir = os.path.dirname(filepath)
-
-    if _is_a_single_file(file_list):
-        rootpath = file_list[0]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-
-        for item in file_list:
-            files.extract(item, file_dir)
-
-    elif _is_a_single_dir(file_list):
-        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-
-        for item in file_list:
-            files.extract(item, file_dir)
-
-    else:
-        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        if not os.path.exists(uncompressed_path):
-            os.makedirs(uncompressed_path)
-        for item in file_list:
-            files.extract(item, os.path.join(file_dir, rootpath))
-
-    files.close()
-
-    return uncompressed_path
+    with zipfile.ZipFile(filepath, 'r') as files:
+        file_list = files.namelist()
+
+        file_dir = os.path.dirname(filepath)
+
+        if _is_a_single_file(file_list):
+            rootpath = file_list[0]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            files.extractall(file_dir)
+        elif _is_a_single_dir(file_list):
+            # `strip(os.sep)` to remove `os.sep` in the tail of path
+            rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
+                os.sep)[-1]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            files.extractall(file_dir)
+        else:
+            rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            if not os.path.exists(uncompressed_path):
+                os.makedirs(uncompressed_path)
+            files.extractall(os.path.join(file_dir, rootpath))
+
+        return uncompressed_path
 
 
 def _uncompress_file_tar(filepath, mode="r:*"):
-    files = tarfile.open(filepath, mode)
-    file_list = files.getnames()
-
-    file_dir = os.path.dirname(filepath)
-
-    if _is_a_single_file(file_list):
-        rootpath = file_list[0]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        for item in file_list:
-            files.extract(item, file_dir)
-    elif _is_a_single_dir(file_list):
-        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        for item in file_list:
-            files.extract(item, file_dir)
-    else:
-        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
-        uncompressed_path = os.path.join(file_dir, rootpath)
-        if not os.path.exists(uncompressed_path):
-            os.makedirs(uncompressed_path)
-        for item in file_list:
-            files.extract(item, os.path.join(file_dir, rootpath))
-
-    files.close()
-
-    return uncompressed_path
+    with tarfile.open(filepath, mode) as files:
+        file_list = files.getnames()
+
+        file_dir = os.path.dirname(filepath)
+
+        if _is_a_single_file(file_list):
+            rootpath = file_list[0]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            files.extractall(file_dir)
+        elif _is_a_single_dir(file_list):
+            rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
+                os.sep)[-1]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            files.extractall(file_dir)
+        else:
+            rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+            uncompressed_path = os.path.join(file_dir, rootpath)
+            if not os.path.exists(uncompressed_path):
+                os.makedirs(uncompressed_path)
+            files.extractall(os.path.join(file_dir, rootpath))
+
+        return uncompressed_path
 
 
 def _is_a_single_file(file_list):
-    if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
+    if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
         return True
     return False
...
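Two small correctness fixes ride along with the extractall() switch. First, _is_a_single_file used to test find(os.sep) < -1, which can never be true because str.find returns -1 when the separator is absent, so the helper always returned False; the new < 0 comparison behaves as intended. Second, strip(os.sep) protects the rootpath computation when a single-dir archive lists its top entry with a trailing separator. A quick demonstration of both, with illustrative values and assuming POSIX-style paths (not the Paddle code):

import os  # os.sep == "/" is assumed below (POSIX)

# str.find returns -1 when os.sep does not occur, so the old "< -1"
# comparison was always False, even for a bare filename.
name = "single_file.pdparams"
assert name.find(os.sep) == -1
assert not (name.find(os.sep) < -1)  # old, buggy check
assert name.find(os.sep) < 0         # fixed check

# Without strip(os.sep), a trailing separator yields an empty rootpath.
entry = "single_dir/"
assert os.path.splitext(entry)[0].split(os.sep)[-1] == ""
assert os.path.splitext(entry.strip(os.sep))[0].split(os.sep)[-1] == "single_dir"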