From c45481d7be71e90c194c615b61db411463d43349 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Wed, 19 Aug 2020 11:40:49 +0800 Subject: [PATCH] add Sampler/SequenceSampler/RandomSampler (#26375) * add Sampler/SequenceSampler/RandomSampler. test=develop --- python/paddle/fluid/dataloader/__init__.py | 6 +- .../paddle/fluid/dataloader/batch_sampler.py | 79 +++--- python/paddle/fluid/dataloader/sampler.py | 232 ++++++++++++++++++ .../tests/unittests/test_batch_sampler.py | 78 +++++- python/paddle/io/__init__.py | 6 +- 5 files changed, 349 insertions(+), 52 deletions(-) create mode 100644 python/paddle/fluid/dataloader/sampler.py diff --git a/python/paddle/fluid/dataloader/__init__.py b/python/paddle/fluid/dataloader/__init__.py index 2f15811e4f3..597f1f21748 100644 --- a/python/paddle/fluid/dataloader/__init__.py +++ b/python/paddle/fluid/dataloader/__init__.py @@ -23,6 +23,10 @@ from .batch_sampler import * from . import dataloader_iter from .dataloader_iter import * +from . import sampler +from .sampler import * + __all__ = dataset.__all__ \ + batch_sampler.__all__ \ - + dataloader_iter.__all__ + + dataloader_iter.__all__ \ + + sampler.__all__ diff --git a/python/paddle/fluid/dataloader/batch_sampler.py b/python/paddle/fluid/dataloader/batch_sampler.py index 811468c523b..8043237c0d9 100644 --- a/python/paddle/fluid/dataloader/batch_sampler.py +++ b/python/paddle/fluid/dataloader/batch_sampler.py @@ -16,12 +16,13 @@ from __future__ import print_function from __future__ import division import numpy as np +from .sampler import Sampler, SequenceSampler from .dataset import Dataset, IterableDataset __all__ = ["BatchSampler"] -class BatchSampler(object): +class BatchSampler(Sampler): """ A base implement of batch sampler used by `paddle.io.DataLoader` which yield mini-batch indices(a list/tuple with length as @@ -41,10 +42,11 @@ class BatchSampler(object): implement or other python object which implemented :code:`__len__` for BatchSampler to get indices as the range of 
:attr:`dataset` length. Default None. - indices (list|tuple): a substitution parameter for - :attr:`dataset` either :attr:`dataset` or - :attr:`indices` should be set, give the whole - indices to sampler from directly. Default None. + sampler (Sampler): this could be a :code:`paddle.io.Sampler` + instance which implements :code:`__iter__` to yield + sample indices. :attr:`sampler` and :attr:`dataset` + can not be set at the same time. If :attr:`sampler` + is set, :attr:`shuffle` should not be set. Default None. shuffle(bool): whether to shuffle indices order before genrating batch indices. Default False. batch_size(int): sample indice number in a mini-batch indices. @@ -58,16 +60,7 @@ class BatchSampler(object): .. code-block:: python - from paddle.io import BatchSampler, Dataset - - # init with indices - bs = BatchSampler(indices=list(range(100)), - shuffle=True, - batch_size=8, - drop_last=True) - - for batch_indices in bs: - print(batch_indices) + from paddle.io import RandomSampler, BatchSampler, Dataset # init with dataset class RandomDataset(Dataset): @@ -90,34 +83,42 @@ class BatchSampler(object): for batch_indices in bs: print(batch_indices) + # init with sampler + sampler = RandomSampler(RandomDataset(100)) + bs = BatchSampler(sampler=sampler, + shuffle=False, + batch_size=8, + drop_last=True) + + for batch_indices in bs: + print(batch_indices) + + see `paddle.io.DataLoader` """ def __init__(self, dataset=None, - indices=None, + sampler=None, shuffle=False, batch_size=1, drop_last=False): if dataset is None: - assert indices is not None, \ - "either dataset or indices should be set" - assert isinstance(indices, list) or isinstance(indices, tuple), \ - "indices should be a list or tuple, but got {}".format(type(indices)) - self.indices = indices - self.sampler_iter = None + assert sampler is not None, \ + "either dataset or sampler should be set" + assert isinstance(sampler, Sampler), \ + "sampler should be a paddle.io.Sampler, but got 
{}".format(type(sampler)) + assert not shuffle, "shuffle should be False when sampler is set" + self.sampler = sampler else: - if isinstance(dataset, IterableDataset): - self.sampler_iter = iter( - _InfiniteIterableSampler(dataset, batch_size)) - else: - self.sampler_iter = None - assert isinstance(dataset, Dataset), \ - "dataset should be an instance of paddle.io.Dataset" - assert indices is None, \ - "should not set both dataset and indices" - self.indices = list(range(len(dataset))) + assert isinstance(dataset, Dataset), \ + "dataset should be a paddle.io.Dataset" + assert not isinstance(dataset, IterableDataset), \ + "dataset should not be a paddle.io.IterableDataset" + assert sampler is None, \ + "should not set both dataset and sampler" + self.sampler = SequenceSampler(dataset) assert isinstance(batch_size, int) and batch_size > 0, \ "batch_size should be a positive integer, but got {}".format(batch_size) @@ -130,15 +131,8 @@ class BatchSampler(object): self.drop_last = drop_last def __iter__(self): - if self.sampler_iter: - yield next(self.sampler_iter) - - if self.shuffle: - np.random.shuffle(self.indices) - _iter = iter(self.indices) - batch_indices = [] - for idx in _iter: + for idx in self.sampler: batch_indices.append(idx) if len(batch_indices) == self.batch_size: yield batch_indices @@ -147,10 +141,7 @@ class BatchSampler(object): yield batch_indices def __len__(self): - if self.sampler_iter: - raise RuntimeError("'{}' should not be called for IterableDataset". - format('__len__')) - num_samples = len(self.indices) + num_samples = len(self.sampler) num_samples += int(not self.drop_last) * (self.batch_size - 1) return num_samples // self.batch_size diff --git a/python/paddle/fluid/dataloader/sampler.py b/python/paddle/fluid/dataloader/sampler.py new file mode 100644 index 00000000000..d2f3231cc6b --- /dev/null +++ b/python/paddle/fluid/dataloader/sampler.py @@ -0,0 +1,232 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from __future__ import division + +import numpy as np + +__all__ = ["Sampler", "SequenceSampler", "RandomSampler"] + + +class Sampler(object): + """ + An abstract class to encapsulate methods and behaviors of samplers. + + All samplers used by :code:`paddle.io.BatchSampler` should be subclasses + of :code:`paddle.io.Sampler`. Sampler subclasses should + implement the following methods: + + :code:`__iter__`: return an iterator over the sample indices + of dataset elements + + :code:`__len__`: the number of samples in :attr:`data_source` + + + Args: + data_source(Dataset, optional): this could be an instance of + :code:`paddle.io.Dataset` or another Python object which + implements :code:`__len__` for Sampler to get indices + as the range of :attr:`data_source` length. Default None. + + Returns: + Sampler: an iterable object that yields sample indices + + Examples: + + .. 
code-block:: python + + from paddle.io import Dataset, Sampler + + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + class MySampler(Sampler): + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + sampler = MySampler(data_source=RandomDataset(100)) + + for index in sampler: + print(index) + + see `paddle.io.BatchSampler` + see `paddle.io.DataLoader` + + """ + + def __init__(self, data_source=None): + self.data_source = data_source + + def __iter__(self): + raise NotImplementedError + + # __len__ is deliberately not defined in this base class because it + # is not needed in some cases, e.g. paddle.io.IterableDataset + + +class SequenceSampler(Sampler): + """ + Iterate samples sequentially, yield :code:`0, 1, 2, ..., len(data_source) - 1` + in order. + + Args: + data_source(Dataset): dataset to sample, this could be an + instance of :code:`paddle.io.Dataset` or another Python + object which implements :code:`__len__`. + + Returns: + Sampler: a Sampler that yields sample indices sequentially + + Examples: + + .. 
code-block:: python + + from paddle.io import Dataset, SequenceSampler + + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + sampler = SequenceSampler(data_source=RandomDataset(100)) + + for index in sampler: + print(index) + + see `paddle.io.Sampler` + """ + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + +class RandomSampler(Sampler): + """ + Iterate samples randomly and yield shuffled indices. If :attr:`replacement=False`, + yield shuffled indices of the whole data source; if :attr:`replacement=True`, + :attr:`num_samples` can be set to specify the number of samples to draw. + + Args: + data_source(Dataset): dataset to sample, this could be an + instance of :code:`paddle.io.Dataset` or another Python + object which implements :code:`__len__`. + replacement(bool): If False, sample the whole dataset. If True, + set :attr:`num_samples` for how many samples to draw. Default False. + num_samples(int): set sample number to draw if :attr:`replacement` + is True. Default None. + generator(Generator): an iterable that yields the sample indices directly; when set, it is iterated in place of random sampling. Default None + + Returns: + Sampler: a Sampler that yields sample indices randomly + + Examples: + + .. 
code-block:: python + + from paddle.io import Dataset, RandomSampler + + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + sampler = RandomSampler(data_source=RandomDataset(100)) + + for index in sampler: + print(index) + + see `paddle.io.Sampler` + """ + + def __init__(self, + data_source, + replacement=False, + num_samples=None, + generator=None): + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + if not isinstance(self.replacement, bool): + raise TypeError("expect boolean value for replacement, but got " + "replacement={}".format(self.replacement)) + + if self._num_samples is not None and not replacement: + raise ValueError( + "num_samples should not be specified while replacement is False") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError("num_samples should be a positive integer, " + "but got num_samples={}".format(self.num_samples)) + + @property + def num_samples(self): + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self): + n = len(self.data_source) + if self.generator: + for index in self.generator: + yield index + else: + if self.replacement: + for index in np.random.choice( + np.arange(n), self.num_samples, replace=True).tolist(): + yield index + else: + for index in np.random.choice( + np.arange(n), n, replace=False).tolist(): + yield index + + def __len__(self): + return self.num_samples diff --git a/python/paddle/fluid/tests/unittests/test_batch_sampler.py b/python/paddle/fluid/tests/unittests/test_batch_sampler.py index 7d90bbd0357..2e2a6144fd0 100644 --- 
a/python/paddle/fluid/tests/unittests/test_batch_sampler.py +++ b/python/paddle/fluid/tests/unittests/test_batch_sampler.py @@ -17,7 +17,7 @@ from __future__ import division import unittest import paddle.fluid as fluid -from paddle.io import BatchSampler, Dataset +from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, RandomSampler class RandomDataset(Dataset): @@ -35,6 +35,60 @@ class RandomDataset(Dataset): return self.sample_num +class TestSampler(unittest.TestCase): + def test_main(self): + dataset = RandomDataset(100, 10) + sampler = Sampler(dataset) + try: + iter(sampler) + self.assertTrue(False) + except NotImplementedError: + pass + + +class TestSequenceSampler(unittest.TestCase): + def test_main(self): + dataset = RandomDataset(100, 10) + sampler = SequenceSampler(dataset) + assert len(sampler) == 100 + + for i, index in enumerate(iter(sampler)): + assert i == index + + +class TestRandomSampler(unittest.TestCase): + def test_main(self): + dataset = RandomDataset(100, 10) + sampler = RandomSampler(dataset) + assert len(sampler) == 100 + + rets = [] + for i in iter(sampler): + rets.append(i) + assert tuple(sorted(rets)) == tuple(range(0, 100)) + + def test_with_num_samples(self): + dataset = RandomDataset(100, 10) + sampler = RandomSampler(dataset, num_samples=50, replacement=True) + assert len(sampler) == 50 + + rets = [] + for i in iter(sampler): + rets.append(i) + assert i >= 0 and i < 100 + + def test_with_generator(self): + dataset = RandomDataset(100, 10) + generator = iter(range(0, 60)) + sampler = RandomSampler(dataset, generator=generator) + assert len(sampler) == 100 + + rets = [] + for i in iter(sampler): + rets.append(i) + assert tuple(sorted(rets)) == tuple(range(0, 60)) + + class TestBatchSampler(unittest.TestCase): def setUp(self): self.num_samples = 1000 @@ -86,16 +140,18 @@ class TestBatchSamplerShuffle(TestBatchSampler): self.drop_last = True -class TestBatchSamplerWithIndices(TestBatchSampler): +class 
TestBatchSamplerWithSampler(TestBatchSampler): def init_batch_sampler(self): + dataset = RandomDataset(1000, 10) + sampler = SequenceSampler(dataset) bs = BatchSampler( - indices=list(range(self.num_samples)), + sampler=sampler, batch_size=self.batch_size, drop_last=self.drop_last) return bs -class TestBatchSamplerWithIndicesAndDataSource(unittest.TestCase): +class TestBatchSamplerWithSamplerDropLast(unittest.TestCase): def setUp(self): self.num_samples = 1000 self.num_classes = 10 @@ -103,12 +159,22 @@ class TestBatchSamplerWithIndicesAndDataSource(unittest.TestCase): self.shuffle = False self.drop_last = True + +class TestBatchSamplerWithSamplerShuffle(unittest.TestCase): + def setUp(self): + self.num_samples = 1000 + self.num_classes = 10 + self.batch_size = 32 + self.shuffle = True + self.drop_last = True + def test_main(self): try: dataset = RandomDataset(self.num_samples, self.num_classes) + sampler = RandomSampler(dataset) bs = BatchSampler( - dataset=dataset, - indices=list(range(self.num_samples)), + sampler=sampler, + shuffle=self.shuffle, batch_size=self.batch_size, drop_last=self.drop_last) self.assertTrue(False) diff --git a/python/paddle/io/__init__.py b/python/paddle/io/__init__.py index 875f3ff2e91..89bbd591657 100644 --- a/python/paddle/io/__init__.py +++ b/python/paddle/io/__init__.py @@ -20,6 +20,9 @@ __all__ = [ # 'Transform', 'DataLoader', 'get_worker_info', + 'Sampler', + 'SequenceSampler', + 'RandomSampler', 'load', 'save', 'load_program_state', @@ -38,7 +41,8 @@ __all__ = [ ] from ..fluid.io import DataLoader -from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info +from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \ + Sampler, SequenceSampler, RandomSampler from ..fluid.io import load, save, load_program_state, set_program_state, \ load_inference_model, save_inference_model, batch from ..reader import shuffle, buffered, cache, chain, firstn, compose, map_readers, 
xmap_readers -- GitLab