diff --git a/python/paddle/tests/test_download.py b/python/paddle/tests/test_download.py index 986d84dd153b2f54624d3394fbb5a4b1b52b8953..49e76d9416e69439b06f2adb5eb2bef918d0b52b 100644 --- a/python/paddle/tests/test_download.py +++ b/python/paddle/tests/test_download.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest from paddle.utils.download import get_weights_path_from_url @@ -70,6 +71,36 @@ class TestDownload(unittest.TestCase): for url in urls: get_path_from_url(url, root_dir='./test') + def test_uncompress_result(self): + results = [ + [ + "files/single_dir/file1", "files/single_dir/file2", + "files/single_file.pdparams" + ], + ["single_dir/file1", "single_dir/file2"], + ["single_file.pdparams"], + ] + tar_urls = [ + "https://paddle-hapi.bj.bcebos.com/unittest/files.tar", + "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar", + "https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar", + ] + + for url, uncompressd_res in zip(tar_urls, results): + uncompressed_path = get_path_from_url(url, root_dir='./test_tar') + self.assertTrue(all([os.path.exists(os.path.join("./test_tar", filepath)) \ + for filepath in uncompressd_res])) + + zip_urls = [ + "https://paddle-hapi.bj.bcebos.com/unittest/files.zip", + "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip", + "https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip", + ] + for url, uncompressd_res in zip(zip_urls, results): + uncompressed_path = get_path_from_url(url, root_dir='./test_zip') + self.assertTrue(all([os.path.exists(os.path.join("./test_zip", filepath)) \ + for filepath in uncompressd_res])) + def test_retry_exception(self, ): with self.assertRaises(RuntimeError): from paddle.utils.download import _download diff --git a/python/paddle/utils/download.py b/python/paddle/utils/download.py index 29baddff05af22df4f11e8e0fcb38b6d66983a47..bf40ff9ab221c9f519b88b95266fcb0b01cc8487 100644 --- a/python/paddle/utils/download.py +++ b/python/paddle/utils/download.py @@ -79,7 +79,7 @@ def get_weights_path_from_url(url, md5sum=None): Args: url (str): download url md5sum (str): md5 sum of download package - + Returns: str: a local path to save downloaded weights. @@ -146,8 +146,8 @@ def get_path_from_url(url, assert is_url(url), "downloading from {} not a url".format(url) # parse path after download to decompress under root_dir fullpath = _map_path(url, root_dir) - # Mainly used to solve the problem of downloading data from different - # machines in the case of multiple machines. Different ips will download + # Mainly used to solve the problem of downloading data from different + # machines in the case of multiple machines. Different ips will download # data, and the same ip will only download data once. unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:]) if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): @@ -302,70 +302,61 @@ def _decompress(fname): def _uncompress_file_zip(filepath): - files = zipfile.ZipFile(filepath, 'r') - file_list = files.namelist() - - file_dir = os.path.dirname(filepath) - - if _is_a_single_file(file_list): - rootpath = file_list[0] - uncompressed_path = os.path.join(file_dir, rootpath) + with zipfile.ZipFile(filepath, 'r') as files: + file_list = files.namelist() - for item in file_list: - files.extract(item, file_dir) + file_dir = os.path.dirname(filepath) - elif _is_a_single_dir(file_list): - rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + files.extractall(file_dir) - for item in file_list: - files.extract(item, file_dir) + elif _is_a_single_dir(file_list): + # `strip(os.sep)` to remove `os.sep` in the tail of path + rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split( + os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) - else: - rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - if not os.path.exists(uncompressed_path): - os.makedirs(uncompressed_path) - for item in file_list: - files.extract(item, os.path.join(file_dir, rootpath)) + files.extractall(file_dir) + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) + files.extractall(os.path.join(file_dir, rootpath)) - files.close() - - return uncompressed_path + return uncompressed_path def _uncompress_file_tar(filepath, mode="r:*"): - files = tarfile.open(filepath, mode) - file_list = files.getnames() - - file_dir = os.path.dirname(filepath) - - if _is_a_single_file(file_list): - rootpath = file_list[0] - uncompressed_path = os.path.join(file_dir, rootpath) - for item in file_list: - files.extract(item, file_dir) - elif _is_a_single_dir(file_list): - rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - for item in file_list: - files.extract(item, file_dir) - else: - rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - if not os.path.exists(uncompressed_path): - os.makedirs(uncompressed_path) - - for item in file_list: - files.extract(item, os.path.join(file_dir, rootpath)) + with tarfile.open(filepath, mode) as files: + file_list = files.getnames() + + file_dir = os.path.dirname(filepath) + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + files.extractall(file_dir) + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split( + os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + files.extractall(file_dir) + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) - files.close() + files.extractall(os.path.join(file_dir, rootpath)) - return uncompressed_path + return uncompressed_path def _is_a_single_file(file_list): - if len(file_list) == 1 and file_list[0].find(os.sep) < -1: + if len(file_list) == 1 and file_list[0].find(os.sep) < 0: return True return False