提交 05c739b8 编写于 作者: M Megvii Engine Team

refactor(mge/data): rename `MapDataset` to `Dataset`

GitOrigin-RevId: 6262561355995679ff51a80c0187d204405d376b
上级 a892e5d0
......@@ -19,7 +19,7 @@ import numpy as np
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform
......@@ -88,7 +88,15 @@ class DataLoader:
self.divide = divide
if isinstance(dataset, MapDataset):
if isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
assert isinstance(
dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(dataset)
self.sampler = (
sampler
if sampler
......@@ -97,15 +105,6 @@ class DataLoader:
assert isinstance(
self.sampler, MapSampler
), "types of dataset and sampler do not match"
elif isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(dataset)
)
if divide:
if self.sampler.batch_size <= self.num_workers:
......@@ -140,15 +139,14 @@ class DataLoader:
return _SerialStreamDataLoaderIter(self)
else:
return _ParallelStreamDataLoaderIter(self)
elif isinstance(self.dataset, MapDataset):
else:
assert isinstance(
self.dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(self.dataset)
if not self.num_workers:
return _SerialMapDataLoaderIter(self)
else:
return _ParallelMapDataLoaderIter(self)
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(self.dataset)
)
def __len__(self):
return len(self.sampler)
......
......@@ -6,5 +6,5 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
from .meta_dataset import ArrayDataset, Dataset, StreamDataset
from .vision import *
......@@ -12,17 +12,7 @@ from typing import Tuple
class Dataset(ABC):
r"""
An abstract class for all Datasets.
"""
@abstractmethod
def __init__(self):
pass
class MapDataset(Dataset):
r"""
An abstract class for map data.
An abstract class for all datasets.
__getitem__ and __len__ method are aditionally needed.
"""
......@@ -53,8 +43,14 @@ class StreamDataset(Dataset):
def __iter__(self):
pass
def __getitem__(self):
raise AssertionError("can not get item from StreamDataset by index")
def __len__(self):
raise AssertionError("StreamDataset does not have length")
class ArrayDataset(MapDataset):
class ArrayDataset(Dataset):
def __init__(self, *arrays):
r"""
ArrayDataset is a dataset for numpy array data, one or more numpy arrays
......
......@@ -9,10 +9,10 @@
import collections.abc
import os
from ..meta_dataset import MapDataset
from ..meta_dataset import Dataset
class VisionDataset(MapDataset):
class VisionDataset(Dataset):
_repr_indent = 4
def __init__(self, root, *, order=None, supported_order=None):
......
......@@ -12,14 +12,12 @@ import sys
import numpy as np
import pytest
from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset
def test_abstract_cls():
with pytest.raises(TypeError):
Dataset()
with pytest.raises(TypeError):
MapDataset()
with pytest.raises(TypeError):
StreamDataset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册