未验证 提交 467099f0 编写于 作者: C CtfGo 提交者: GitHub

Speedup download uncompress function (#37311)

`paddle.utils.download`: change to call `extractall` on tar/zip compressed files to speed up the uncompress process when the archive includes many files

--- result of decompression speed comparison ---
1. dataset:https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/cnn_stories.tgz, decompression time
:5m50s vs 20s
2. dataset:https://paddlenlp.bj.bcebos.com/datasets/cnn_dailymail/dailymail_stories.tgz, decompression time:33m20s vs 47s
上级 2dfcdf21
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import unittest import unittest
from paddle.utils.download import get_weights_path_from_url from paddle.utils.download import get_weights_path_from_url
...@@ -70,6 +71,36 @@ class TestDownload(unittest.TestCase): ...@@ -70,6 +71,36 @@ class TestDownload(unittest.TestCase):
for url in urls: for url in urls:
get_path_from_url(url, root_dir='./test') get_path_from_url(url, root_dir='./test')
def test_uncompress_result(self):
    """Verify that downloaded tar/zip archives uncompress to the expected layout.

    Three archive shapes are exercised for each format: an archive with
    multiple top-level entries, an archive containing a single directory,
    and an archive containing a single file.
    """
    # Expected relative paths after extraction, one list per archive.
    expected_files = [
        [
            "files/single_dir/file1",
            "files/single_dir/file2",
            "files/single_file.pdparams",
        ],
        ["single_dir/file1", "single_dir/file2"],
        ["single_file.pdparams"],
    ]

    tar_urls = [
        "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
        "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar",
        "https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar",
    ]
    for url, expected in zip(tar_urls, expected_files):
        get_path_from_url(url, root_dir='./test_tar')
        self.assertTrue(
            all(
                os.path.exists(os.path.join("./test_tar", filepath))
                for filepath in expected))

    zip_urls = [
        "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
        "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip",
        "https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip",
    ]
    for url, expected in zip(zip_urls, expected_files):
        get_path_from_url(url, root_dir='./test_zip')
        self.assertTrue(
            all(
                os.path.exists(os.path.join("./test_zip", filepath))
                for filepath in expected))
def test_retry_exception(self, ): def test_retry_exception(self, ):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
from paddle.utils.download import _download from paddle.utils.download import _download
......
...@@ -302,7 +302,7 @@ def _decompress(fname): ...@@ -302,7 +302,7 @@ def _decompress(fname):
def _uncompress_file_zip(filepath):
    """Extract a zip archive next to it on disk and return the extraction path.

    Uses a single ``extractall`` call (instead of per-member ``extract``)
    to speed up extraction of archives with many files, and a ``with``
    block so the archive handle is always closed.

    Args:
        filepath (str): path to the ``.zip`` file.

    Returns:
        str: path of the uncompressed file or directory. For an archive
        containing a single file or a single directory, this is that
        entry's path; otherwise a directory named after the archive is
        created and everything is extracted into it.
    """
    with zipfile.ZipFile(filepath, 'r') as files:
        file_list = files.namelist()

        file_dir = os.path.dirname(filepath)

        if _is_a_single_file(file_list):
            rootpath = file_list[0]
            uncompressed_path = os.path.join(file_dir, rootpath)
            files.extractall(file_dir)
        elif _is_a_single_dir(file_list):
            # `strip(os.sep)` to remove `os.sep` in the tail of path
            rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
                os.sep)[-1]
            uncompressed_path = os.path.join(file_dir, rootpath)
            files.extractall(file_dir)
        else:
            # No common root inside the archive: create a directory named
            # after the archive and extract everything into it.
            rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
            uncompressed_path = os.path.join(file_dir, rootpath)
            if not os.path.exists(uncompressed_path):
                os.makedirs(uncompressed_path)
            files.extractall(os.path.join(file_dir, rootpath))

    return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
    """Extract a tar archive next to it on disk and return the extraction path.

    Uses a single ``extractall`` call (instead of per-member ``extract``)
    to speed up extraction of archives with many files, and a ``with``
    block so the archive handle is always closed.

    Args:
        filepath (str): path to the tar file.
        mode (str): mode passed to ``tarfile.open``; the default ``"r:*"``
            auto-detects the compression (gz/bz2/...).

    Returns:
        str: path of the uncompressed file or directory. For an archive
        containing a single file or a single directory, this is that
        entry's path; otherwise a directory named after the archive is
        created and everything is extracted into it.
    """
    with tarfile.open(filepath, mode) as files:
        file_list = files.getnames()

        file_dir = os.path.dirname(filepath)

        if _is_a_single_file(file_list):
            rootpath = file_list[0]
            uncompressed_path = os.path.join(file_dir, rootpath)
            files.extractall(file_dir)
        elif _is_a_single_dir(file_list):
            # `strip(os.sep)` to remove `os.sep` in the tail of path
            rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
                os.sep)[-1]
            uncompressed_path = os.path.join(file_dir, rootpath)
            files.extractall(file_dir)
        else:
            # No common root inside the archive: create a directory named
            # after the archive and extract everything into it.
            rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
            uncompressed_path = os.path.join(file_dir, rootpath)
            if not os.path.exists(uncompressed_path):
                os.makedirs(uncompressed_path)
            files.extractall(os.path.join(file_dir, rootpath))

    return uncompressed_path
def _is_a_single_file(file_list): def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < -1: if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
return True return True
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册