未验证 提交 463075a8 编写于 作者: K Kaipeng Deng 提交者: GitHub

add paddle.io.ComposeDataset & paddle.io.ChainDataset (#28311)

* add paddle.io.ComposeDataset & paddle.io.ChainDataset. test=develop
上级 a4303496
......@@ -17,7 +17,10 @@ from __future__ import print_function
from .. import framework
import paddle.dataset.common
__all__ = ["Dataset", "IterableDataset", "TensorDataset"]
__all__ = [
"Dataset", "IterableDataset", "TensorDataset", "ComposeDataset",
"ChainDataset"
]
class Dataset(object):
......@@ -275,3 +278,130 @@ class TensorDataset(Dataset):
def __len__(self):
return self.tensors[0].shape[0]
def to_list(value):
if value is None:
return value
if isinstance(value, (list, tuple)):
return list(value)
return [value]
class ComposeDataset(Dataset):
"""
A Dataset which composes fields of multiple datasets.
This dataset is used for composing fileds of multiple map-style
datasets of same length.
Args:
datasets(list of Dataset): List of datasets to be composed.
Returns:
Dataset: A Dataset which composes fields of multiple datasets.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.io import Dataset, ComposeDataset
# define a random dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([32]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
dataset = ComposeDataset([RandomDataset(10), RandomDataset(10)])
for i in range(len(dataset)):
image1, label1, image2, label2 = dataset[i]
print(image1)
print(label1)
print(image2)
print(label2)
"""
def __init__(self, datasets):
self.datasets = list(datasets)
assert len(self.datasets) > 0, "input datasets shoule not be empty"
for i, dataset in enumerate(self.datasets):
assert isinstance(dataset, Dataset), \
"each input dataset should be paddle.io.Dataset"
assert not isinstance(dataset, IterableDataset), \
"paddle.io.IterableDataset not supported"
if i > 0:
assert len(dataset) == len(self.datasets[i-1]), \
"lengths of datasets should be same"
def __len__(self):
return len(self.datasets[0])
def __getitem__(self, idx):
sample = []
for dataset in self.datasets:
sample.extend(to_list(dataset[idx]))
return tuple(sample)
class ChainDataset(IterableDataset):
"""
A Dataset which chains multiple iterable-tyle datasets.
This dataset is used for assembling multiple datasets which should
be :code:`paddle.io.IterableDataset`.
Args:
datasets(list of Dataset): List of datasets to be chainned.
Returns:
Dataset: A Dataset which chains fields of multiple datasets.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.io import IterableDataset, ChainDataset
# define a random dataset
class RandomDataset(IterableDataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __iter__(self):
for i in range(10):
image = np.random.random([32]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
yield image, label
dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
for image, label in iter(dataset):
print(image, label)
"""
def __init__(self, datasets):
self.datasets = list(datasets)
assert len(self.datasets) > 0, "input datasets shoule not be empty"
for i, dataset in enumerate(self.datasets):
assert isinstance(dataset, IterableDataset), \
"ChainDataset only support paddle.io.IterableDataset"
def __iter__(self):
for dataset in self.datasets:
for sample in dataset:
yield sample
......@@ -19,9 +19,38 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.io import TensorDataset, DataLoader
from paddle.io import Dataset, IterableDataset, TensorDataset, \
ComposeDataset, ChainDataset, DataLoader
from paddle.fluid.dygraph.base import to_variable
IMAGE_SIZE = 32
class RandomDataset(Dataset):
def __init__(self, sample_num):
self.sample_num = sample_num
def __len__(self):
return self.sample_num
def __getitem__(self, idx):
np.random.seed(idx)
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
class RandomIterableDataset(IterableDataset):
def __init__(self, sample_num):
self.sample_num = sample_num
def __iter__(self):
for i in range(self.sample_num):
np.random.seed(i)
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
yield image, label
class TestTensorDataset(unittest.TestCase):
def run_main(self, num_workers, places):
......@@ -55,8 +84,56 @@ class TestTensorDataset(unittest.TestCase):
def test_main(self):
for p in [fluid.CPUPlace(), fluid.CUDAPlace(0)]:
for num_workers in [0, 2]:
ret = self.run_main(num_workers=num_workers, places=p)
self.run_main(num_workers=0, places=p)
class TestComposeDataset(unittest.TestCase):
def test_main(self):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
dataset1 = RandomDataset(10)
dataset2 = RandomDataset(10)
dataset = ComposeDataset([dataset1, dataset2])
assert len(dataset) == 10
for i in range(len(dataset)):
input1, label1, input2, label2 = dataset[i]
input1_t, label1_t = dataset1[i]
input2_t, label2_t = dataset2[i]
assert np.allclose(input1, input1_t)
assert np.allclose(label1, label1_t)
assert np.allclose(input2, input2_t)
assert np.allclose(label2, label2_t)
class TestChainDataset(unittest.TestCase):
def run_main(self, num_workers, places):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
dataset1 = RandomIterableDataset(10)
dataset2 = RandomIterableDataset(10)
dataset = ChainDataset([dataset1, dataset2])
samples = []
for data in iter(dataset):
samples.append(data)
assert len(samples) == 20
idx = 0
for image, label in iter(dataset1):
assert np.allclose(image, samples[idx][0])
assert np.allclose(label, samples[idx][1])
idx += 1
for image, label in iter(dataset2):
assert np.allclose(image, samples[idx][0])
assert np.allclose(label, samples[idx][1])
idx += 1
def test_main(self):
for p in [fluid.CPUPlace(), fluid.CUDAPlace(0)]:
self.run_main(num_workers=0, places=p)
if __name__ == '__main__':
......
......@@ -17,6 +17,8 @@ __all__ = [
'Dataset',
'IterableDataset',
'TensorDataset',
'ComposeDataset',
'ChainDataset',
'BatchSampler',
'DistributedBatchSampler',
# 'Transform',
......@@ -29,4 +31,5 @@ __all__ = [
from ..fluid.io import DataLoader
from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \
TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler
TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler, \
ComposeDataset, ChainDataset
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册