#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np

from .dataset import IterableDataset
from .sampler import RandomSampler, Sampler, SequenceSampler


class BatchSampler(Sampler):
    """
    A base implementation of the batch sampler used by :code:`paddle.io.DataLoader`,
    which iterably yields mini-batch indices (a list holding the sample
    indices of one mini-batch, whose length is the mini-batch size).

    Batch samplers used by :code:`paddle.io.DataLoader` should be subclasses
    of :code:`paddle.io.BatchSampler`. Subclasses should implement the
    following methods:

    :code:`__iter__`: return mini-batch indices iterably.

    :code:`__len__`: get the number of mini-batches in an epoch.

    See :ref:`api_paddle_io_DataLoader` for how batch samplers are consumed.

    Args:
        dataset(Dataset, optional): an instance of a subclass of :ref:`api_paddle_io_Dataset`,
                or any other python object implementing :code:`__len__`, whose
                length defines the range of sample indices. It must NOT be a
                :ref:`api_paddle_io_IterableDataset`. Default None, disabled.
        sampler (Sampler, optional): an instance of :ref:`api_paddle_io_Sampler`
                implementing :code:`__iter__` to generate sample indices.
                :attr:`sampler` and :attr:`dataset` can not be set at the same
                time; exactly one of them must be given. Default None, disabled.
        shuffle(bool, optional): whether to shuffle the indices order before
                generating batch indices. Only usable together with
                :attr:`dataset`; it must be False when :attr:`sampler` is set.
                Default False, don't shuffle indices before generating batch indices.
        batch_size(int, optional): number of sample indices in each mini-batch; must
                be a positive integer. Default 1, each mini-batch includes 1 sample.
        drop_last(bool, optional): whether to drop the last incomplete mini-batch
                (one with fewer than :attr:`batch_size` indices). Default False, keep it.

    Returns:
        BatchSampler: an iterable object for indices iterating

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import RandomSampler, BatchSampler, Dataset

            # init with dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            bs = BatchSampler(dataset=RandomDataset(100),
                              shuffle=False,
                              batch_size=16,
                              drop_last=False)

            for batch_indices in bs:
                print(batch_indices)

            # init with sampler
            sampler = RandomSampler(RandomDataset(100))
            bs = BatchSampler(sampler=sampler,
                              batch_size=8,
                              drop_last=True)

            for batch_indices in bs:
                print(batch_indices)

    """

    def __init__(
        self,
        dataset=None,
        sampler=None,
        shuffle=False,
        batch_size=1,
        drop_last=False,
    ):
        if dataset is None:
            # Indices come from the user-supplied sampler; shuffling is the
            # sampler's own responsibility, so ``shuffle`` must stay False.
            assert (
                sampler is not None
            ), "either dataset or sampler should be set"
            assert isinstance(
                sampler, Sampler
            ), "sampler should be a paddle.io.Sampler, but got {}".format(
                type(sampler)
            )
            assert not shuffle, "shuffle should be False when sampler is set"
            self.sampler = sampler
        else:
            # Build a sampler over ``range(len(dataset))`` ourselves.
            assert not isinstance(
                dataset, IterableDataset
            ), "dataset should not be a paddle.io.IterableDataset"
            assert sampler is None, "should not set both dataset and sampler"
            assert isinstance(
                shuffle, bool
            ), "shuffle should be a boolean value, but got {}".format(
                type(shuffle)
            )
            if shuffle:
                self.sampler = RandomSampler(dataset)
            else:
                self.sampler = SequenceSampler(dataset)

        assert (
            isinstance(batch_size, int) and batch_size > 0
        ), "batch_size should be a positive integer, but got {}".format(
            batch_size
        )
        self.batch_size = batch_size
        assert isinstance(
            drop_last, bool
        ), "drop_last should be a boolean value, but got {}".format(
            type(drop_last)
        )
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of up to ``batch_size`` indices drawn from the sampler.

        The final, shorter list is yielded only when ``drop_last`` is False.
        """
        batch_indices = []
        for idx in self.sampler:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        """Return the number of mini-batches per epoch.

        This is ``ceil(len(sampler) / batch_size)`` when ``drop_last`` is
        False and ``floor(len(sampler) / batch_size)`` otherwise.
        """
        num_samples = len(self.sampler)
        # Adding ``batch_size - 1`` before floor division rounds up.
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size


class _InfiniteIterableSampler:
    """
    Placeholder batch sampler for a :code:`paddle.io.IterableDataset`.

    It generates no real sample indices: iterating over it never ends, and
    each step yields a new list of :attr:`batch_size` ``None`` entries.

    Args:
        dataset(IterableDataset): the iterable dataset this sampler belongs to.
        batch_size(int, optional): length of each yielded list. Default 1.
    """

    def __init__(self, dataset, batch_size=1):
        is_iterable_dataset = isinstance(dataset, IterableDataset)
        assert (
            is_iterable_dataset
        ), "dataset should be an instance of paddle.io.IterableDataset"
        self.batch_size = batch_size
        self.dataset = dataset

    def __iter__(self):
        # ``self.batch_size`` is re-read on every step and a fresh list is
        # built each time, so callers may mutate what they receive.
        while True:
            placeholder_batch = [None] * self.batch_size
            yield placeholder_batch


class DistributedBatchSampler(BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.

    In distributed training, each process can pass a DistributedBatchSampler
    instance as a DataLoader sampler and load its own subset of the original
    dataset. The index list is padded by repeating indices until its length
    is divisible by the number of replicas, so a few samples may be loaded
    by more than one process.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset(Dataset): this could be an instance of subclass of :ref:`api_paddle_io_Dataset`
                     or other python object which implemented
                     `__len__` for BatchSampler to get indices of samples.
        batch_size(int): sample size of each mini-batch; must be a positive integer.
        num_replicas(int, optional): process number in distributed training.
            If :attr:`num_replicas` is None, :attr:`num_replicas` will be
            retrieved from :ref:`api_paddle_distributed_ParallelEnv` .
            Default None.
        rank(int, optional): the rank of the current process among :attr:`num_replicas`
            processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
            :ref:`api_paddle_distributed_ParallelEnv`. Default None.
        shuffle(bool, optional): whether to shuffle indices order before generating
            batch indices. Default False.
        drop_last(bool, optional): whether to drop the last incomplete (less than a
            mini-batch) batch. Default False.

    Returns:
        DistributedBatchSampler, return an iterable object for indices iterating.

    Examples:
        .. code-block:: python

            import numpy as np

            from paddle.io import Dataset, DistributedBatchSampler

            # init with dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = RandomDataset(100)
            sampler = DistributedBatchSampler(dataset, batch_size=64)

            for data in sampler:
                # do something
                break
    """

    def __init__(
        self,
        dataset,
        batch_size,
        num_replicas=None,
        rank=None,
        shuffle=False,
        drop_last=False,
    ):
        self.dataset = dataset

        assert (
            isinstance(batch_size, int) and batch_size > 0
        ), "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(
            drop_last, bool
        ), "drop_last should be a boolean number"

        # ParallelEnv supplies the defaults for num_replicas and rank.
        from paddle.distributed import ParallelEnv

        if num_replicas is not None:
            assert (
                isinstance(num_replicas, int) and num_replicas > 0
            ), "num_replicas should be a positive integer"
            self.nranks = num_replicas
        else:
            self.nranks = ParallelEnv().nranks

        if rank is not None:
            assert (
                isinstance(rank, int) and rank >= 0
            ), "rank should be a non-negative integer"
            self.local_rank = rank
        else:
            self.local_rank = ParallelEnv().local_rank

        self.drop_last = drop_last
        self.epoch = 0
        # Samples per replica (rounded up) and the padded index count summed
        # over all replicas; __iter__ pads the index list to ``total_size``.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def __iter__(self):
        num_samples = len(self.dataset)
        indices = np.arange(num_samples).tolist()

        # Pad by repeating indices so the list splits evenly into
        # ``self.nranks`` shares of ``self.num_samples`` each. The padding can
        # be longer than the dataset itself (e.g. fewer samples than
        # replicas), so repeat the whole list as often as needed instead of
        # appending a single slice of it.
        padding_size = self.total_size - len(indices)
        if padding_size > 0 and indices:
            repeats = int(math.ceil(padding_size / len(indices)))
            indices += (indices * repeats)[:padding_size]
        assert len(indices) == self.total_size
        if self.shuffle:
            # Seeded by the epoch counter: replicas that share the same epoch
            # value produce the same permutation before the list is split.
            np.random.RandomState(self.epoch).shuffle(indices)
            self.epoch += 1

        # subsample
        def _get_indices_by_batch_size(indices):
            # Deal whole mini-batches round-robin: this rank takes batches
            # local_rank, local_rank + nranks, local_rank + 2 * nranks, ...
            subsampled_indices = []
            # The trailing round that cannot give every rank a full batch is
            # instead split evenly, ``last_local_batch_size`` indices per rank.
            last_batch_size = self.total_size % (self.batch_size * self.nranks)
            assert last_batch_size % self.nranks == 0
            last_local_batch_size = last_batch_size // self.nranks

            for i in range(
                self.local_rank * self.batch_size,
                len(indices) - last_batch_size,
                self.batch_size * self.nranks,
            ):
                subsampled_indices.extend(indices[i : i + self.batch_size])

            indices = indices[len(indices) - last_batch_size :]
            start = self.local_rank * last_local_batch_size
            end = (self.local_rank + 1) * last_local_batch_size
            subsampled_indices.extend(indices[start:end])
            return subsampled_indices

        if self.nranks > 1:
            indices = _get_indices_by_batch_size(indices)

        assert len(indices) == self.num_samples

        batch_indices = []
        for idx in indices:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        """Return the number of mini-batches this replica yields per epoch."""
        num_samples = self.num_samples
        # Adding ``batch_size - 1`` before floor division rounds up.
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size

    def set_epoch(self, epoch):
        """
        Sets the epoch number. When :attr:`shuffle=True`, this number is used
        as the seed of the random shuffle. Users do not have to call this:
        the epoch counter is also incremented after every shuffled iteration,
        so each epoch still gets a different ordering. If the same number is
        set at each epoch, this sampler will yield the same ordering at all
        epochs.

        Arguments:
            epoch (int): Epoch number.

        Examples:
            .. code-block:: python

                import numpy as np

                from paddle.io import Dataset, DistributedBatchSampler

                # init with dataset
                class RandomDataset(Dataset):
                    def __init__(self, num_samples):
                        self.num_samples = num_samples

                    def __getitem__(self, idx):
                        image = np.random.random([784]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        return image, label

                    def __len__(self):
                        return self.num_samples

                dataset = RandomDataset(100)
                sampler = DistributedBatchSampler(dataset, batch_size=64)

                for epoch in range(10):
                    sampler.set_epoch(epoch)
        """
        self.epoch = epoch