提交 31766e3a 编写于 作者: L LielinJiang

refine datasets

上级 c42caaaa
......@@ -34,13 +34,10 @@ def has_valid_extension(filename, extensions):
return filename.lower().endswith(extensions)
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
def make_dataset(dir, class_to_idx, extensions, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
if not ((extensions is None) ^ (is_valid_file is None)):
raise ValueError(
"Both extensions and is_valid_file cannot be None or not None at the same time"
)
if extensions is not None:
def is_valid_file(x):
......@@ -200,10 +197,7 @@ class ImageFolder(Dataset):
samples = []
path = os.path.expanduser(root)
if not ((extensions is None) ^ (is_valid_file is None)):
raise ValueError(
"Both extensions and is_valid_file cannot be None or not None at the same time"
)
if extensions is not None:
def is_valid_file(x):
......
......@@ -25,5 +25,5 @@ def _check_exists_and_download(path, url, md5, module_name, download=True):
if download:
return paddle.dataset.common.download(url, module_name, md5)
else:
raise FileNotFoundError(
'{} not exists and auto download disabled'.format(path))
raise ValueError('{} not exists and auto download disabled'.format(
path))
......@@ -20,11 +20,14 @@ import shutil
import cv2
from hapi.datasets import *
from hapi.datasets.utils import _check_exists_and_download
from hapi.vision.transforms import Compose
class TestFolderDatasets(unittest.TestCase):
def makedata(self):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
self.empty_dir = tempfile.mkdtemp()
for i in range(2):
sub_dir = os.path.join(self.data_dir, 'class_' + str(i))
if not os.path.exists(sub_dir):
......@@ -34,8 +37,10 @@ class TestFolderDatasets(unittest.TestCase):
(32, 32, 3)) * 255).astype('uint8')
cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)
def tearDown(self):
shutil.rmtree(self.data_dir)
def test_dataset(self):
self.makedata()
dataset_folder = DatasetFolder(self.data_dir)
for _ in dataset_folder:
......@@ -44,7 +49,30 @@ class TestFolderDatasets(unittest.TestCase):
assert len(dataset_folder) == 4
assert len(dataset_folder.classes) == 2
shutil.rmtree(self.data_dir)
transform = Compose([])
dataset_folder = DatasetFolder(self.data_dir, transform=transform)
for _ in dataset_folder:
pass
def test_folder(self):
loader = ImageFolder(self.data_dir)
for _ in loader:
pass
transform = Compose([])
loader = ImageFolder(self.data_dir, transform=transform)
for _ in loader:
pass
def test_errors(self):
with self.assertRaises(RuntimeError):
ImageFolder(self.empty_dir)
with self.assertRaises(RuntimeError):
DatasetFolder(self.empty_dir)
with self.assertRaises(ValueError):
_check_exists_and_download('temp_paddle', None, None, None, False)
class TestMNISTTest(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册