提交 05c739b8 编写于 作者: M Megvii Engine Team

refactor(mge/data): rename `MapDataset` to `Dataset`

GitOrigin-RevId: 6262561355995679ff51a80c0187d204405d376b
上级 a892e5d0
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
from ..logger import get_logger from ..logger import get_logger
from ..random.rng import _random_seed_generator from ..random.rng import _random_seed_generator
from .collator import Collator from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform from .transform import PseudoTransform, Transform
...@@ -88,7 +88,15 @@ class DataLoader: ...@@ -88,7 +88,15 @@ class DataLoader:
self.divide = divide self.divide = divide
if isinstance(dataset, MapDataset): if isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
assert isinstance(
dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(dataset)
self.sampler = ( self.sampler = (
sampler sampler
if sampler if sampler
...@@ -97,15 +105,6 @@ class DataLoader: ...@@ -97,15 +105,6 @@ class DataLoader:
assert isinstance( assert isinstance(
self.sampler, MapSampler self.sampler, MapSampler
), "types of dataset and sampler do not match" ), "types of dataset and sampler do not match"
elif isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(dataset)
)
if divide: if divide:
if self.sampler.batch_size <= self.num_workers: if self.sampler.batch_size <= self.num_workers:
...@@ -140,15 +139,14 @@ class DataLoader: ...@@ -140,15 +139,14 @@ class DataLoader:
return _SerialStreamDataLoaderIter(self) return _SerialStreamDataLoaderIter(self)
else: else:
return _ParallelStreamDataLoaderIter(self) return _ParallelStreamDataLoaderIter(self)
elif isinstance(self.dataset, MapDataset): else:
assert isinstance(
self.dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(self.dataset)
if not self.num_workers: if not self.num_workers:
return _SerialMapDataLoaderIter(self) return _SerialMapDataLoaderIter(self)
else: else:
return _ParallelMapDataLoaderIter(self) return _ParallelMapDataLoaderIter(self)
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(self.dataset)
)
def __len__(self): def __len__(self):
return len(self.sampler) return len(self.sampler)
......
...@@ -6,5 +6,5 @@ ...@@ -6,5 +6,5 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset from .meta_dataset import ArrayDataset, Dataset, StreamDataset
from .vision import * from .vision import *
...@@ -12,17 +12,7 @@ from typing import Tuple ...@@ -12,17 +12,7 @@ from typing import Tuple
class Dataset(ABC): class Dataset(ABC):
r""" r"""
An abstract class for all Datasets. An abstract class for all datasets.
"""
@abstractmethod
def __init__(self):
pass
class MapDataset(Dataset):
r"""
An abstract class for map data.
__getitem__ and __len__ methods are additionally needed. __getitem__ and __len__ methods are additionally needed.
""" """
...@@ -53,8 +43,14 @@ class StreamDataset(Dataset): ...@@ -53,8 +43,14 @@ class StreamDataset(Dataset):
def __iter__(self): def __iter__(self):
pass pass
def __getitem__(self):
raise AssertionError("can not get item from StreamDataset by index")
def __len__(self):
raise AssertionError("StreamDataset does not have length")
class ArrayDataset(MapDataset): class ArrayDataset(Dataset):
def __init__(self, *arrays): def __init__(self, *arrays):
r""" r"""
ArrayDataset is a dataset for numpy array data, one or more numpy arrays ArrayDataset is a dataset for numpy array data, one or more numpy arrays
......
...@@ -9,10 +9,10 @@ ...@@ -9,10 +9,10 @@
import collections.abc import collections.abc
import os import os
from ..meta_dataset import MapDataset from ..meta_dataset import Dataset
class VisionDataset(MapDataset): class VisionDataset(Dataset):
_repr_indent = 4 _repr_indent = 4
def __init__(self, root, *, order=None, supported_order=None): def __init__(self, root, *, order=None, supported_order=None):
......
...@@ -12,14 +12,12 @@ import sys ...@@ -12,14 +12,12 @@ import sys
import numpy as np import numpy as np
import pytest import pytest
from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset
def test_abstract_cls(): def test_abstract_cls():
with pytest.raises(TypeError): with pytest.raises(TypeError):
Dataset() Dataset()
with pytest.raises(TypeError):
MapDataset()
with pytest.raises(TypeError): with pytest.raises(TypeError):
StreamDataset() StreamDataset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册