diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py
index 372249e01f66bc39045db3f36238da1500c7738e..2884fa0ce5e3d037fe2e929218da0aa52c1c0d8e 100644
--- a/python/paddle/dataset/common.py
+++ b/python/paddle/dataset/common.py
@@ -34,7 +34,8 @@ __all__ = [
     'cluster_files_reader',
 ]
 
-DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
+HOME = os.path.expanduser('~')
+DATA_HOME = os.path.join(HOME, '.cache', 'paddle', 'dataset')
 
 
 # When running unit tests, there could be multiple processes that
diff --git a/python/paddle/tests/test_download.py b/python/paddle/tests/test_download.py
index 6fb53573c21a1589e474e337d058294c09f65f38..b8af7f6a80e72148a4f793a4de2188d3cc7a8b69 100644
--- a/python/paddle/tests/test_download.py
+++ b/python/paddle/tests/test_download.py
@@ -15,6 +15,7 @@
 import unittest
 
 from paddle.utils.download import get_weights_path_from_url
+from paddle.utils.download import get_path_from_url
 
 
 class TestDownload(unittest.TestCase):
@@ -57,6 +58,18 @@ class TestDownload(unittest.TestCase):
         for url in urls:
             self.download(url, None)
 
+    def test_get_path_from_url(self):
+        urls = [
+            "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
+            "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar",
+            "https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip",
+        ]
+        for url in urls:
+            get_path_from_url(url, root_dir='./test')
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/utils/download.py b/python/paddle/utils/download.py
index 7ba208574353fa32d1e6f1a36ca1c24956909cf1..c5c7de678edee4dc510383a09bfec4c8fbf80950 100644
--- a/python/paddle/utils/download.py
+++ b/python/paddle/utils/download.py
@@ -335,8 +335,18 @@ def _is_a_single_file(file_list):
 
 
 def _is_a_single_dir(file_list):
-    file_name = file_list[0].split(os.sep)[0]
-    for i in range(1, len(file_list)):
-        if file_name != file_list[i].split(os.sep)[0]:
+    # Archive member names may use '/' or '\\' regardless of the host OS, and
+    # a single name may mix both, so normalize every separator to os.sep.
+    new_file_list = [
+        file_path.replace('/', os.sep).replace('\\', os.sep)
+        for file_path in file_list
+    ]
+    # An empty member list has no common top-level directory.
+    if not new_file_list:
+        return False
+
+    file_name = new_file_list[0].split(os.sep)[0]
+    for file_path in new_file_list[1:]:
+        if file_name != file_path.split(os.sep)[0]:
             return False
     return True