diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index 31cce8da4f8c31f399f71a77bb2d9c451a00e06e..c1fb5432043b0557320d01e34757889d09f70d26 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -19,7 +19,7 @@ import numpy as np from ..logger import get_logger from ..random.rng import _random_seed_generator from .collator import Collator -from .dataset import Dataset, MapDataset, StreamDataset +from .dataset import Dataset, StreamDataset from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler from .transform import PseudoTransform, Transform @@ -88,7 +88,15 @@ class DataLoader: self.divide = divide - if isinstance(dataset, MapDataset): + if isinstance(dataset, StreamDataset): + self.sampler = sampler if sampler else StreamSampler(batch_size=1) + assert isinstance( + self.sampler, StreamSampler + ), "types of dataset and sampler do not match" + else: + assert isinstance( + dataset, Dataset + ), "Can not recognize this kind of dataset: %s" % type(dataset) self.sampler = ( sampler if sampler @@ -97,15 +105,6 @@ class DataLoader: assert isinstance( self.sampler, MapSampler ), "types of dataset and sampler do not match" - elif isinstance(dataset, StreamDataset): - self.sampler = sampler if sampler else StreamSampler(batch_size=1) - assert isinstance( - self.sampler, StreamSampler - ), "types of dataset and sampler do not match" - else: - raise TypeError( - "can not recognize this kind of dataset: %s" % type(dataset) - ) if divide: if self.sampler.batch_size <= self.num_workers: @@ -140,15 +139,14 @@ class DataLoader: return _SerialStreamDataLoaderIter(self) else: return _ParallelStreamDataLoaderIter(self) - elif isinstance(self.dataset, MapDataset): + else: + assert isinstance( + self.dataset, Dataset + ), "Can not recognize this kind of dataset: %s" % type(self.dataset) if not self.num_workers: return _SerialMapDataLoaderIter(self) 
else: return _ParallelMapDataLoaderIter(self) - else: - raise TypeError( - "can not recognize this kind of dataset: %s" % type(self.dataset) - ) def __len__(self): return len(self.sampler) diff --git a/imperative/python/megengine/data/dataset/__init__.py b/imperative/python/megengine/data/dataset/__init__.py index 8b70d22111ba33a749a8c90491b2db52a700ed44..2d6fe2f1af4e7d2ecab50a37ad41b4870cf93cf4 100644 --- a/imperative/python/megengine/data/dataset/__init__.py +++ b/imperative/python/megengine/data/dataset/__init__.py @@ -6,5 +6,5 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset +from .meta_dataset import ArrayDataset, Dataset, StreamDataset from .vision import * diff --git a/imperative/python/megengine/data/dataset/meta_dataset.py b/imperative/python/megengine/data/dataset/meta_dataset.py index 8b2a304d8eb86057d038a794ed40ce51b7e7a7ef..dd1f01c95088712c5132d8e513b4e6d271d97f8b 100644 --- a/imperative/python/megengine/data/dataset/meta_dataset.py +++ b/imperative/python/megengine/data/dataset/meta_dataset.py @@ -12,17 +12,7 @@ from typing import Tuple class Dataset(ABC): r""" - An abstract class for all Datasets. - """ - - @abstractmethod - def __init__(self): - pass - - -class MapDataset(Dataset): - r""" - An abstract class for map data. + An abstract class for all datasets. __getitem__ and __len__ method are aditionally needed. 
""" @@ -53,8 +43,14 @@ class StreamDataset(Dataset): def __iter__(self): pass + def __getitem__(self, idx): + raise AssertionError("can not get item from StreamDataset by index") + + def __len__(self): + raise AssertionError("StreamDataset does not have length") + -class ArrayDataset(MapDataset): +class ArrayDataset(Dataset): def __init__(self, *arrays): r""" ArrayDataset is a dataset for numpy array data, one or more numpy arrays diff --git a/imperative/python/megengine/data/dataset/vision/meta_vision.py b/imperative/python/megengine/data/dataset/vision/meta_vision.py index 6d03d3eda5451a05039f513034f32444004db218..f89df740788f0bb1f9a421e21053b33b8fc17d17 100644 --- a/imperative/python/megengine/data/dataset/vision/meta_vision.py +++ b/imperative/python/megengine/data/dataset/vision/meta_vision.py @@ -9,10 +9,10 @@ import collections.abc import os -from ..meta_dataset import MapDataset +from ..meta_dataset import Dataset -class VisionDataset(MapDataset): +class VisionDataset(Dataset): _repr_indent = 4 def __init__(self, root, *, order=None, supported_order=None): diff --git a/imperative/python/test/unit/data/test_dataset.py b/imperative/python/test/unit/data/test_dataset.py index d68d37846d4b69d784df3d29125fb4c23240e918..4df34921e20cab56b1af7fe9c5dddc1e9623ce66 100644 --- a/imperative/python/test/unit/data/test_dataset.py +++ b/imperative/python/test/unit/data/test_dataset.py @@ -12,14 +12,12 @@ import sys import numpy as np import pytest -from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset +from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset def test_abstract_cls(): with pytest.raises(TypeError): Dataset() - with pytest.raises(TypeError): - MapDataset() with pytest.raises(TypeError): StreamDataset()