From 05c739b84630bbdefb443887f6e04b75e4ee9ee8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 31 Dec 2020 15:24:47 +0800 Subject: [PATCH] refactor(mge/data): rename `MapDataset` to `Dataset` GitOrigin-RevId: 6262561355995679ff51a80c0187d204405d376b --- .../python/megengine/data/dataloader.py | 30 +++++++++---------- .../python/megengine/data/dataset/__init__.py | 2 +- .../megengine/data/dataset/meta_dataset.py | 20 +++++-------- .../data/dataset/vision/meta_vision.py | 4 +-- .../python/test/unit/data/test_dataset.py | 4 +-- 5 files changed, 26 insertions(+), 34 deletions(-) diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index 31cce8da4..c1fb54320 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -19,7 +19,7 @@ import numpy as np from ..logger import get_logger from ..random.rng import _random_seed_generator from .collator import Collator -from .dataset import Dataset, MapDataset, StreamDataset +from .dataset import Dataset, StreamDataset from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler from .transform import PseudoTransform, Transform @@ -88,7 +88,15 @@ class DataLoader: self.divide = divide - if isinstance(dataset, MapDataset): + if isinstance(dataset, StreamDataset): + self.sampler = sampler if sampler else StreamSampler(batch_size=1) + assert isinstance( + self.sampler, StreamSampler + ), "types of dataset and sampler do not match" + else: + assert isinstance( + dataset, Dataset + ), "Can not recognize this kind of dataset: %s" % type(dataset) self.sampler = ( sampler if sampler @@ -97,15 +105,6 @@ class DataLoader: assert isinstance( self.sampler, MapSampler ), "types of dataset and sampler do not match" - elif isinstance(dataset, StreamDataset): - self.sampler = sampler if sampler else StreamSampler(batch_size=1) - assert isinstance( - self.sampler, StreamSampler - ), "types of dataset and 
sampler do not match" - else: - raise TypeError( - "can not recognize this kind of dataset: %s" % type(dataset) - ) if divide: if self.sampler.batch_size <= self.num_workers: @@ -140,15 +139,14 @@ class DataLoader: return _SerialStreamDataLoaderIter(self) else: return _ParallelStreamDataLoaderIter(self) - elif isinstance(self.dataset, MapDataset): + else: + assert isinstance( + self.dataset, Dataset + ), "Can not recognize this kind of dataset: %s" % type(self.dataset) if not self.num_workers: return _SerialMapDataLoaderIter(self) else: return _ParallelMapDataLoaderIter(self) - else: - raise TypeError( - "can not recognize this kind of dataset: %s" % type(self.dataset) - ) def __len__(self): return len(self.sampler) diff --git a/imperative/python/megengine/data/dataset/__init__.py b/imperative/python/megengine/data/dataset/__init__.py index 8b70d2211..2d6fe2f1a 100644 --- a/imperative/python/megengine/data/dataset/__init__.py +++ b/imperative/python/megengine/data/dataset/__init__.py @@ -6,5 +6,5 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset +from .meta_dataset import ArrayDataset, Dataset, StreamDataset from .vision import * diff --git a/imperative/python/megengine/data/dataset/meta_dataset.py b/imperative/python/megengine/data/dataset/meta_dataset.py index 8b2a304d8..dd1f01c95 100644 --- a/imperative/python/megengine/data/dataset/meta_dataset.py +++ b/imperative/python/megengine/data/dataset/meta_dataset.py @@ -12,17 +12,7 @@ from typing import Tuple class Dataset(ABC): r""" - An abstract class for all Datasets. - """ - - @abstractmethod - def __init__(self): - pass - - -class MapDataset(Dataset): - r""" - An abstract class for map data. + An abstract class for all datasets. 
__getitem__ and __len__ methods are additionally needed. """ @@ -53,8 +43,14 @@ class StreamDataset(Dataset): def __iter__(self): pass + def __getitem__(self): + raise AssertionError("can not get item from StreamDataset by index") + + def __len__(self): + raise AssertionError("StreamDataset does not have length") + -class ArrayDataset(MapDataset): +class ArrayDataset(Dataset): def __init__(self, *arrays): r""" ArrayDataset is a dataset for numpy array data, one or more numpy arrays diff --git a/imperative/python/megengine/data/dataset/vision/meta_vision.py b/imperative/python/megengine/data/dataset/vision/meta_vision.py index 6d03d3eda..f89df7407 100644 --- a/imperative/python/megengine/data/dataset/vision/meta_vision.py +++ b/imperative/python/megengine/data/dataset/vision/meta_vision.py @@ -9,10 +9,10 @@ import collections.abc import os -from ..meta_dataset import MapDataset +from ..meta_dataset import Dataset -class VisionDataset(MapDataset): +class VisionDataset(Dataset): _repr_indent = 4 def __init__(self, root, *, order=None, supported_order=None): diff --git a/imperative/python/test/unit/data/test_dataset.py b/imperative/python/test/unit/data/test_dataset.py index d68d37846..4df34921e 100644 --- a/imperative/python/test/unit/data/test_dataset.py +++ b/imperative/python/test/unit/data/test_dataset.py @@ -12,14 +12,12 @@ import sys import numpy as np import pytest -from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset +from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset def test_abstract_cls(): with pytest.raises(TypeError): Dataset() - with pytest.raises(TypeError): - MapDataset() with pytest.raises(TypeError): StreamDataset() -- GitLab