Commit b85792ac authored by Megvii Engine Team

feat(imperative/data): add concat dataset

GitOrigin-RevId: a82b720998c797c45de8a396a0d80a5db68925ef
Parent 55cbab7a
# -*- coding: utf-8 -*-
from .meta_dataset import ArrayDataset, ConcatDataset, Dataset, StreamDataset
from .vision import *
# -*- coding: utf-8 -*-
import bisect
from abc import ABC, abstractmethod
from typing import Tuple
@@ -143,3 +144,73 @@ class ArrayDataset(Dataset):
    def __len__(self) -> int:
        return len(self.arrays[0])
class ConcatDataset(Dataset):
    r"""ConcatDataset is a concatenation of multiple datasets.

    This dataset is used for assembling multiple map-style datasets.

    Examples:

        .. code-block:: python

            import numpy as np
            from megengine.data import DataLoader, SequentialSampler
            from megengine.data.dataset import ArrayDataset, ConcatDataset

            data1 = np.random.randint(0, 255, size=(2, 1, 32, 32), dtype=np.uint8)
            data2 = np.random.randint(0, 255, size=(2, 1, 32, 32), dtype=np.uint8)
            label1 = np.random.randint(0, 10, size=(2,), dtype=int)
            label2 = np.random.randint(0, 10, size=(2,), dtype=int)
            dataset1 = ArrayDataset(data1, label1)
            dataset2 = ArrayDataset(data2, label2)
            dataset = ConcatDataset([dataset1, dataset2])
            seque_sampler = SequentialSampler(dataset, batch_size=2)
            dataloader = DataLoader(
                dataset,
                sampler=seque_sampler,
                num_workers=3,
            )
            for step, data in enumerate(dataloader):
                print(data)
    """

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        self.datasets = datasets

        def cumsum(datasets):
            # Running total of dataset lengths, e.g. lengths [10, 20] -> [10, 30].
            r, s = [], 0
            for e in datasets:
                l = len(e)
                r.append(l + s)
                s += l
            return r

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        for d in self.datasets:
            assert not isinstance(
                d, StreamDataset
            ), "ConcatDataset does not support StreamDataset"
        self.datasets = list(datasets)
        self.cumulative_sizes = cumsum(self.datasets)

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        # Locate the sub-dataset containing idx, then convert to a local index.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    def __len__(self):
        return self.cumulative_sizes[-1]
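The index lookup in __getitem__ can be traced by hand: bisect_right finds the first cumulative size greater than the global index, and the preceding cumulative size is subtracted to obtain the local index. Below is a minimal standalone sketch of the same mapping (the helper name "locate" is purely illustrative, not part of the commit), assuming two sub-datasets of lengths 10 and 20 so that cumulative_sizes is [10, 30]:

import bisect

def locate(cumulative_sizes, idx):
    # Mirrors ConcatDataset.__getitem__: pick the sub-dataset whose range
    # contains idx, then convert the global index to a local one.
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx

sizes = [10, 30]                     # lengths 10 and 20, accumulated
assert locate(sizes, 0) == (0, 0)    # first sample of the first dataset
assert locate(sizes, 9) == (0, 9)    # last sample of the first dataset
assert locate(sizes, 10) == (1, 0)   # first sample of the second dataset
assert locate(sizes, 29) == (1, 19)  # last sample of the second dataset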
@@ -5,7 +5,7 @@ import sys
import numpy as np
import pytest
from megengine.data.dataset import ArrayDataset, ConcatDataset, Dataset, StreamDataset
def test_abstract_cls():
@@ -32,3 +32,21 @@ def test_array_dataset_dim_error():
    label = np.random.randint(0, 9, (1,))
    with pytest.raises(ValueError):
        ArrayDataset(data, label)
def test_concat_dataset():
    size1 = (10,)
    size2 = (20,)
    data_shape1 = (3, 256, 256)
    data_shape2 = (2, 128, 128)
    label_shape1 = (1,)
    label_shape2 = (2,)
    data1 = np.random.randint(0, 255, size1 + data_shape1)
    data2 = np.random.randint(0, 255, size2 + data_shape2)
    label1 = np.random.randint(0, 9, size1 + label_shape1)
    label2 = np.random.randint(0, 9, size2 + label_shape2)
    dataset1 = ArrayDataset(data1, label1)
    dataset2 = ArrayDataset(data2, label2)
    dataset = ConcatDataset([dataset1, dataset2])
    assert dataset[15][0].shape == data_shape2
    assert dataset[15][1].shape == label_shape2
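The committed assertions follow from the index mapping: with sub-dataset lengths 10 and 20, global index 15 falls in dataset2 (local index 5), so the sample has data_shape2 and label_shape2. A short sketch of two further checks the implementation above supports (not part of the committed test, shown only to illustrate __len__ and negative indexing):

import numpy as np
from megengine.data.dataset import ArrayDataset, ConcatDataset

data1 = np.random.randint(0, 255, (10, 3, 256, 256))
data2 = np.random.randint(0, 255, (20, 2, 128, 128))
label1 = np.random.randint(0, 9, (10, 1))
label2 = np.random.randint(0, 9, (20, 2))
dataset = ConcatDataset([ArrayDataset(data1, label1), ArrayDataset(data2, label2)])

assert len(dataset) == 30                     # __len__ is the last cumulative size
assert dataset[-1][0].shape == (2, 128, 128)  # negative index wraps to index 29, in dataset2
assert dataset[9][0].shape == (3, 256, 256)   # index 9 is still in dataset1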