From 9cec01615ef559c134d6c3004b563a2ae79bc8ea Mon Sep 17 00:00:00 2001 From: Steffy-zxf <48793257+Steffy-zxf@users.noreply.github.com> Date: Thu, 3 Dec 2020 14:39:50 +0800 Subject: [PATCH] fix DATA_HOME path in win (#29222) (#29318) * fix DATA_HOME path in win --- python/paddle/dataset/common.py | 3 ++- python/paddle/tests/test_download.py | 13 +++++++++++++ python/paddle/utils/download.py | 14 +++++++++++--- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/python/paddle/dataset/common.py b/python/paddle/dataset/common.py index 372249e01f..2884fa0ce5 100644 --- a/python/paddle/dataset/common.py +++ b/python/paddle/dataset/common.py @@ -34,7 +34,8 @@ __all__ = [ 'cluster_files_reader', ] -DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') +HOME = os.path.expanduser('~') +DATA_HOME = os.path.join(HOME, '.cache', 'paddle', 'dataset') # When running unit tests, there could be multiple processes that diff --git a/python/paddle/tests/test_download.py b/python/paddle/tests/test_download.py index 6fb53573c2..b8af7f6a80 100644 --- a/python/paddle/tests/test_download.py +++ b/python/paddle/tests/test_download.py @@ -15,6 +15,7 @@ import unittest from paddle.utils.download import get_weights_path_from_url +from paddle.utils.download import get_path_from_url class TestDownload(unittest.TestCase): @@ -57,6 +58,18 @@ class TestDownload(unittest.TestCase): for url in urls: self.download(url, None) + def test_get_path_from_url(self): + urls = [ + "https://paddle-hapi.bj.bcebos.com/unittest/files.tar", + "https://paddle-hapi.bj.bcebos.com/unittest/files.zip", + "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar", + "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip", + "https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar", + "https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip", + ] + for url in urls: + get_path_from_url(url, root_dir='./test') + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/utils/download.py b/python/paddle/utils/download.py index 7ba2085743..c5c7de678e 100644 --- a/python/paddle/utils/download.py +++ b/python/paddle/utils/download.py @@ -335,8 +335,16 @@ def _is_a_single_file(file_list): def _is_a_single_dir(file_list): - file_name = file_list[0].split(os.sep)[0] - for i in range(1, len(file_list)): - if file_name != file_list[i].split(os.sep)[0]: + new_file_list = [] + for file_path in file_list: + if '/' in file_path: + file_path = file_path.replace('/', os.sep) + elif '\\' in file_path: + file_path = file_path.replace('\\', os.sep) + new_file_list.append(file_path) + + file_name = new_file_list[0].split(os.sep)[0] + for i in range(1, len(new_file_list)): + if file_name != new_file_list[i].split(os.sep)[0]: return False return True -- GitLab