From 31766e3a9d16071dd9231b6620dc7c302153e42a Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Wed, 29 Apr 2020 12:29:21 +0000 Subject: [PATCH] refine datasets --- hapi/datasets/folder.py | 12 +++--------- hapi/datasets/utils.py | 4 ++-- hapi/tests/test_datasets.py | 34 +++++++++++++++++++++++++++++++--- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/hapi/datasets/folder.py b/hapi/datasets/folder.py index 1d8c2a3..c0b7c08 100644 --- a/hapi/datasets/folder.py +++ b/hapi/datasets/folder.py @@ -34,13 +34,10 @@ def has_valid_extension(filename, extensions): return filename.lower().endswith(extensions) -def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None): +def make_dataset(dir, class_to_idx, extensions, is_valid_file=None): images = [] dir = os.path.expanduser(dir) - if not ((extensions is None) ^ (is_valid_file is None)): - raise ValueError( - "Both extensions and is_valid_file cannot be None or not None at the same time" - ) + if extensions is not None: def is_valid_file(x): @@ -200,10 +197,7 @@ class ImageFolder(Dataset): samples = [] path = os.path.expanduser(root) - if not ((extensions is None) ^ (is_valid_file is None)): - raise ValueError( - "Both extensions and is_valid_file cannot be None or not None at the same time" - ) + if extensions is not None: def is_valid_file(x): diff --git a/hapi/datasets/utils.py b/hapi/datasets/utils.py index b580dd2..171f794 100644 --- a/hapi/datasets/utils.py +++ b/hapi/datasets/utils.py @@ -25,5 +25,5 @@ def _check_exists_and_download(path, url, md5, module_name, download=True): if download: return paddle.dataset.common.download(url, module_name, md5) else: - raise FileNotFoundError( - '{} not exists and auto download disabled'.format(path)) + raise ValueError('{} not exists and auto download disabled'.format( + path)) diff --git a/hapi/tests/test_datasets.py b/hapi/tests/test_datasets.py index 857d037..cec6f1e 100644 --- a/hapi/tests/test_datasets.py +++ b/hapi/tests/test_datasets.py @@ -20,11 +20,14 @@ import shutil import cv2 from hapi.datasets import * +from hapi.datasets.utils import _check_exists_and_download +from hapi.vision.transforms import Compose class TestFolderDatasets(unittest.TestCase): - def makedata(self): + def setUp(self): self.data_dir = tempfile.mkdtemp() + self.empty_dir = tempfile.mkdtemp() for i in range(2): sub_dir = os.path.join(self.data_dir, 'class_' + str(i)) if not os.path.exists(sub_dir): @@ -34,8 +37,10 @@ class TestFolderDatasets(unittest.TestCase): (32, 32, 3)) * 255).astype('uint8') cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img) + def tearDown(self): + shutil.rmtree(self.data_dir) + def test_dataset(self): - self.makedata() dataset_folder = DatasetFolder(self.data_dir) for _ in dataset_folder: @@ -44,7 +49,30 @@ class TestFolderDatasets(unittest.TestCase): assert len(dataset_folder) == 4 assert len(dataset_folder.classes) == 2 - shutil.rmtree(self.data_dir) + transform = Compose([]) + dataset_folder = DatasetFolder(self.data_dir, transform=transform) + for _ in dataset_folder: + pass + + def test_folder(self): + loader = ImageFolder(self.data_dir) + + for _ in loader: + pass + + transform = Compose([]) + loader = ImageFolder(self.data_dir, transform=transform) + for _ in loader: + pass + + def test_errors(self): + with self.assertRaises(RuntimeError): + ImageFolder(self.empty_dir) + with self.assertRaises(RuntimeError): + DatasetFolder(self.empty_dir) + + with self.assertRaises(ValueError): + _check_exists_and_download('temp_paddle', None, None, None, False) class TestMNISTTest(unittest.TestCase): -- GitLab