# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In most cases we have a non-stream dataset, which means we can randomly access it with __getitem__ and get the length of the dataset with __len__.

This suffices for a sampler. We implement a sampler as an iterable of valid indices. By valid, we mean 0 <= index < N, where N is the length of the dataset. We then collect several indices into a batch and use them to fetch examples from the dataset with __getitem__, and then transform these examples into a batch.

So the sampler is only responsible for generating valid indices.
"""

import numpy as np
import random

class Sampler(object):
    """Base class for all samplers.

    A sampler is an iterable; subclasses must override ``__iter__``.
    """

    def __iter__(self):
        # Subclasses return an iterator of int indices, or an iterator of
        # list[int] in the case of a batch-level sampler (BatchSampler).
        raise NotImplementedError()


class SequentialSampler(Sampler):
    def __init__(self, data_source):
        """Sequential sampler, the simplest sampler: it yields indices from 0 to N - 1 in order, where N is the dataset's length.

        Args:
            data_source (DatasetMixin): the dataset. This is used to get the dataset's length.
        """
        self.data_source = data_source

    def __iter__(self):
        """Return an iterator over ``0, 1, ..., len(data_source) - 1``."""
        return iter(range(len(self.data_source)))

    def __len__(self):
        """Return the number of indices produced, i.e. ``len(data_source)``."""
        return len(self.data_source)


class RandomSampler(Sampler):
    def __init__(self, data_source, replacement=False, num_samples=None):
        """Random sampler.

        Without replacement, every iteration yields a fresh random permutation of all indices. With replacement, `num_samples` indices are drawn uniformly and independently from [0, len(data_source)).

        Args:
            data_source (DatasetMixin): the dataset. This is used to get the dataset's length.
            replacement (bool, optional): whether replacement is enabled in sampling. Defaults to False.
            num_samples (int, optional): number of indices to draw. This option may only be provided when replacement is True; when omitted, it falls back to len(data_source). Defaults to None.

        Raises:
            ValueError: if `replacement` is not a bool, if `num_samples` is given while `replacement` is False, or if the effective number of samples is not a positive integer (this includes an empty `data_source`).
        """
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError(
                "With replacement=False, num_samples should not be specified, "
                "since a random permutation will be performed.")

        # Validate the effective value (property), so the len(data_source)
        # fallback is checked as well.
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(
                                 self.num_samples))

    @property
    def num_samples(self):
        """Number of indices yielded per iteration.

        Evaluated lazily so that it follows the current length of the dataset when `num_samples` was not given explicitly.
        """
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        """Return an iterator over randomly drawn indices (plain Python ints)."""
        n = len(self.data_source)
        if self.replacement:
            return iter(
                np.random.randint(
                    0, n, size=(self.num_samples, ), dtype=np.int64).tolist())
        return iter(np.random.permutation(n).tolist())

    def __len__(self):
        """Return the number of indices yielded per iteration."""
        return self.num_samples


class SubsetRandomSampler(Sampler):
    """Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        """
        Args:
            indices (List[int]): indices to sample from.
        """
        self.indices = indices

    def __iter__(self):
        """Return a generator yielding every element of `indices` exactly once, in a random order drawn anew for each iteration."""
        return (self.indices[i]
                for i in np.random.permutation(len(self.indices)))

    def __len__(self):
        """Return the number of indices in the subset."""
        return len(self.indices)


class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
    """Partially randomized sampler, implemented as an example sampler
    1. Sort by lengths
    2. Pick a small patch and randomize it
    3. Permutate mini-batches
    """

    def __init__(self,
                 lengths,
                 batch_size=4,
                 batch_group_size=None,
                 permutate=True):
        """Build the sampler from the per-example lengths.

        Args:
            lengths (List[int]): The length of the examples of the dataset. This is the key to be considered as 'time length'.
            batch_size (int, optional): batch size. Defaults to 4.
            batch_group_size (int, optional): the size of a small patch. Random shuffling is applied within such patches. If `batch_group_size` is not provided, it is set to min(batch_size * 32, len(lengths)) rounded down to a multiple of batch_size. Batch_group_size should be perfectly divided by batch_size. Defaults to None.
            permutate (bool, optional): permutate batches. Defaults to True.

        Raises:
            AssertionError: if `batch_group_size` is not a multiple of `batch_size`.
        """
        _lengths = np.array(
            lengths,
            dtype=np.int64)  # maybe better implement length as a sort key
        self.lengths = np.sort(_lengths)
        self.sorted_indices = np.argsort(_lengths)

        self.batch_size = batch_size
        if batch_group_size is None:
            batch_group_size = min(batch_size * 32, len(self.lengths))
            # Round down to a multiple of batch_size. This yields 0 when the
            # dataset is smaller than one batch; __iter__ handles that case.
            if batch_group_size % batch_size != 0:
                batch_group_size -= batch_group_size % batch_size

        self.batch_group_size = batch_group_size
        assert batch_group_size % batch_size == 0, (
            "batch_group_size should be a multiple of batch_size")
        self.permutate = permutate

    def __iter__(self):
        """Return an iterator over the partially randomized indices.

        Indices sorted by length are shuffled inside consecutive groups of `batch_group_size`; optionally whole batches inside the grouped region are permuted; the remaining tail (which does not fill a group) is shuffled on its own.
        """
        indices = np.copy(self.sorted_indices)
        group_size = self.batch_group_size

        # A group size of 0 (dataset smaller than one batch) means there is
        # no full group at all; avoid dividing by zero.
        num_groups = len(indices) // group_size if group_size > 0 else 0
        for group in range(num_groups):
            start = group * group_size
            # A slice of a 1-D ndarray is a view, so this shuffles in place.
            random.shuffle(indices[start:start + group_size])

        # End of the region covered by full groups.
        end = num_groups * group_size

        # Permutate batches
        if self.permutate and end > 0:
            perm = np.arange(end // self.batch_size)
            random.shuffle(perm)
            indices[:end] = indices[:end].reshape(
                -1, self.batch_size)[perm, :].reshape(-1)

        # Handle last elements: everything after the last full group,
        # including the whole array when no full group exists.
        if end < len(indices):
            random.shuffle(indices[end:])

        return iter(indices)

    def __len__(self):
        """Return the number of examples (one index is yielded per example)."""
        return len(self.sorted_indices)


class WeightedRandomSampler(Sampler):
    """Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (List[float]): a sequence of non-negative weights, not necessarily summing up to 1; they are normalized before sampling.
        num_samples (int): number of samples to draw.
        replacement (bool): whether samples are drawn with replacement. When replacement is False, num_samples should not be larger than len(weights).

    Raises:
        ValueError: if `num_samples` is not a positive integer, or if `replacement` is False and `num_samples` exceeds the number of weights.

    Example (outputs are random; the values below are only illustrative):
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement):
        if not isinstance(num_samples, int) or num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(
                                 num_samples))
        # Stored as given (unnormalized); normalization happens in __iter__.
        self.weights = np.array(weights, dtype=np.float64)
        self.num_samples = num_samples
        self.replacement = replacement
        if replacement is False and num_samples > len(weights):
            raise ValueError(
                "when replacement is False, num_samples should not be "
                "larger than length of weight.")

    def __iter__(self):
        """Return an iterator over `num_samples` indices drawn according to the weights.

        Raises:
            ValueError: if the weights do not sum to a positive value.
        """
        total = self.weights.sum()
        # `not total > 0` also rejects NaN sums.
        if not total > 0:
            raise ValueError("weights should sum to a positive value, but "
                             "got sum={}".format(total))
        # np.random.choice requires probabilities that sum to 1.
        probabilities = self.weights / total
        return iter(
            np.random.choice(
                len(self.weights),
                size=(self.num_samples, ),
                replace=self.replacement,
                p=probabilities).tolist())

    def __len__(self):
        """Return the number of indices yielded per iteration."""
        return self.num_samples


class DistributedSampler(Sampler):
    def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
        """Sampler used for data parallel training. Indices are divided into num_trainers parts. Each trainer gets a subset and iterates over that subset. If the dataset has 16 examples and there are 4 trainers (without shuffling):

        Trainer 0 gets [0, 4, 8, 12];
        Trainer 1 gets [1, 5, 9, 13];
        Trainer 2 gets [2, 6, 10, 14];
        Trainer 3 gets [3, 7, 11, 15].

        It ensures that trainers get different parts of the dataset. If the dataset's length cannot be perfectly divided by num_trainers, some examples are repeated (taken cyclically from the start of the index list) and appended, to ensure that every trainer gets the same amount of examples.

        Note: shuffling uses the global state of the `random` module in the current process; whether all trainers see the same permutation depends on how the caller seeds it.

        Args:
            dataset_size (int): the length of the dataset.
            num_trainers (int): number of trainers(training processes).
            rank (int): local rank of the trainer.
            shuffle (bool, optional): whether to shuffle the indices before iteration. Defaults to True.
        """
        self.dataset_size = dataset_size
        self.num_trainers = num_trainers
        self.rank = rank
        # Integer ceiling division: samples per trainer after padding.
        self.num_samples = int(
            (dataset_size + num_trainers - 1) // num_trainers)
        self.total_size = self.num_samples * num_trainers
        assert self.total_size >= self.dataset_size
        self.shuffle = shuffle

    def __iter__(self):
        """Return an iterator over the indices assigned to this trainer."""
        indices = list(range(self.dataset_size))
        if self.shuffle:
            random.shuffle(indices)

        # Append extra samples to make it evenly distributed on all trainers.
        # The padding may exceed the dataset size (fewer examples than
        # trainers), so repeat the index list cyclically as often as needed.
        padding = self.total_size - len(indices)
        if padding > 0:
            repeats = -(-padding // len(indices))  # ceil(padding / len)
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size

        # Subset samples for each trainer.
        indices = indices[self.rank:self.total_size:self.num_trainers]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Return the number of indices each trainer receives."""
        return self.num_samples


class BatchSampler(Sampler):
    """Wraps another sampler to yield a mini-batch of indices."""

    def __init__(self, sampler, batch_size, drop_last):
        """
        Args:
            sampler (Sampler): Base sampler.
            batch_size (int): Size of mini-batch.
            drop_last (bool): If True, the sampler will drop the last batch if its size is less than batch_size.

        Raises:
            ValueError: if `sampler` is not a Sampler, `batch_size` is not a positive integer, or `drop_last` is not a bool.

        Example:
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        """
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "Sampler, but got sampler={}".format(sampler))
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of indices, each of length `batch_size` except possibly the last."""
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Emit the incomplete trailing batch unless asked to drop it.
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        """Return the number of batches: floor division when dropping the last partial batch, ceiling division otherwise."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size