batch_sampler.py 6.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import division

import numpy as np
19
from .sampler import Sampler, SequenceSampler, RandomSampler
20
from .dataset import Dataset, IterableDataset
21 22 23 24

__all__ = ["BatchSampler"]


25
class BatchSampler(Sampler):
    """
    Base batch sampler used by :code:`paddle.io.DataLoader`. Iterating
    over it yields mini-batch indices: lists that hold up to
    :attr:`batch_size` sample indices each.

    A batch sampler handed to :code:`paddle.io.DataLoader` should be a
    subclass of :code:`paddle.io.BatchSampler` and implement the
    following methods:

    :code:`__iter__`: yield mini-batch indices one batch at a time.

    :code:`__len__`: return the number of mini-batches in an epoch.


    Args:
        dataset(Dataset): a :code:`paddle.io.Dataset` instance (must not
                be a :code:`paddle.io.IterableDataset`). Its sample
                indices are drawn through a :code:`RandomSampler` when
                :attr:`shuffle` is True and through a
                :code:`SequenceSampler` otherwise. Default None.
        sampler (Sampler): a :code:`paddle.io.Sampler` instance whose
                :code:`__iter__` yields sample indices. :attr:`sampler`
                and :attr:`dataset` can not be set at the same time,
                but one of them must be set. If :attr:`sampler` is set,
                :attr:`shuffle` must be left False. Default None.
        shuffle(bool): whether to shuffle the indices order before
                generating batch indices. Only used together with
                :attr:`dataset`. Default False.
        batch_size(int): number of sample indices in each mini-batch,
                must be a positive integer. Default 1.
        drop_last(bool): whether to drop the last incomplete batch if
                the number of samples is not divisible by the batch
                size. Default False.

    Returns:
        BatchSampler: an iterable object yielding mini-batch indices

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import RandomSampler, BatchSampler, Dataset

            # init with dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            bs = BatchSampler(dataset=RandomDataset(100),
                              shuffle=False,
                              batch_size=16,
                              drop_last=False)

            for batch_indices in bs:
                print(batch_indices)

            # init with sampler
            sampler = RandomSampler(RandomDataset(100))
            bs = BatchSampler(sampler=sampler,
                              batch_size=8,
                              drop_last=True)

            for batch_indices in bs:
                print(batch_indices)


    see `paddle.io.DataLoader`

    """

    def __init__(self,
                 dataset=None,
                 sampler=None,
                 shuffle=False,
                 batch_size=1,
                 drop_last=False):
        if dataset is not None:
            # Dataset mode: validate the dataset, then build the index
            # sampler internally according to ``shuffle``.
            assert isinstance(dataset, Dataset), \
                "dataset should be a paddle.io.Dataset"
            assert not isinstance(dataset, IterableDataset), \
                "dataset should not be a paddle.io.IterableDataset"
            assert sampler is None, \
                "should not set both dataset and sampler"
            assert isinstance(shuffle, bool), \
                "shuffle should be a boolean value, but got {}".format(
                    type(shuffle))
            sampler_cls = RandomSampler if shuffle else SequenceSampler
            self.sampler = sampler_cls(dataset)
        else:
            # Sampler mode: the caller supplies the index sampler, so the
            # ordering is the sampler's business and ``shuffle`` must be
            # left unset.
            assert sampler is not None, \
                "either dataset or sampler should be set"
            assert isinstance(sampler, Sampler), \
                "sampler should be a paddle.io.Sampler, but got {}".format(
                    type(sampler))
            assert not shuffle, "shuffle should be False when sampler is set"
            self.sampler = sampler

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer, but got {}".format(
                batch_size)
        self.batch_size = batch_size

        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean value, but got {}".format(
                type(drop_last))
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of sample indices drawn from ``self.sampler``.

        Every yielded list holds ``self.batch_size`` indices, except
        possibly the final one, which is emitted only when
        ``self.drop_last`` is False and at least one index is left over.
        """
        pending = []
        for sample_idx in self.sampler:
            pending.append(sample_idx)
            if len(pending) != self.batch_size:
                continue
            yield pending
            # Start a fresh list so the batch just handed out is never
            # mutated afterwards.
            pending = []

        # Nothing left over, or leftovers are to be discarded.
        if self.drop_last or len(pending) == 0:
            return
        yield pending

    def __len__(self):
        """Return the number of mini-batches produced per epoch.

        This is the sampler length divided by ``self.batch_size``,
        rounded down when ``self.drop_last`` is True and rounded up
        otherwise.
        """
        total = len(self.sampler)
        if not self.drop_last:
            # Padding by (batch_size - 1) turns the floor division below
            # into a ceiling division, counting the trailing partial batch.
            total += self.batch_size - 1
        return total // self.batch_size
148 149 150 151 152 153 154 155 156 157 158 159 160


class _InfiniteIterableSampler(object):
    """
    Endless batch sampler stand-in for :code:`paddle.io.IterableDataset`.

    Iterating over it never terminates; each step yields a new list of
    :attr:`batch_size` ``None`` placeholders instead of real sample
    indices.

    Args:
        dataset(IterableDataset): must be an instance of
                :code:`paddle.io.IterableDataset`; it is stored on the
                instance as :attr:`dataset`.
        batch_size(int): length of every yielded placeholder list.
                Default 1.
    """

    def __init__(self, dataset, batch_size=1):
        assert isinstance(dataset, IterableDataset), \
            "dataset should be an instance of paddle.io.IterableDataset"
        self.batch_size = batch_size
        self.dataset = dataset

    def __iter__(self):
        """Yield placeholder batches forever."""
        while True:
            # A new list is built on every step, and ``batch_size`` is
            # re-read each time.
            placeholder_batch = [None] * self.batch_size
            yield placeholder_batch