#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from .. import framework

# Public names exported by ``from paddle.io.dataset import *``.
__all__ = [
    "Dataset", "IterableDataset", "TensorDataset", "ComposeDataset",
    "ChainDataset", "random_split", "Subset"
]


class Dataset(object):
    """
    An abstract class to encapsulate methods and behaviors of datasets.

    All map-style datasets (datasets whose samples can be fetched with a given
    key) should be a subclass of `paddle.io.Dataset`. All subclasses should
    implement the following methods:

    :code:`__getitem__`: get a sample from the dataset with a given index. This
    method is required for reading dataset samples in :code:`paddle.io.DataLoader`.

    :code:`__len__`: return the number of samples in the dataset. This method is
    required by some implementations of :code:`paddle.io.BatchSampler`.

    see :code:`paddle.io.DataLoader`.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset

            # define a random dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = RandomDataset(10)
            for i in range(len(dataset)):
                print(dataset[i])

    """

    def __init__(self):
        # The abstract base holds no state of its own.
        pass

    def __getitem__(self, idx):
        # Subclasses must override this; the base only reports what is missing.
        raise NotImplementedError("'%s' not implement in class %s" %
                                  ('__getitem__', self.__class__.__name__))

    def __len__(self):
        # Subclasses must override this; the base only reports what is missing.
        raise NotImplementedError("'%s' not implement in class %s" %
                                  ('__len__', self.__class__.__name__))


class IterableDataset(Dataset):
    """
    An abstract class to encapsulate methods and behaviors of iterable datasets.

    All iterable-style datasets (samples can only be fetched one by one,
    sequentially, like a Python iterator) should be a subclass of
    `paddle.io.IterableDataset`. All subclasses should implement the following
    method:

    :code:`__iter__`: yield samples sequentially. This method is required for
    reading dataset samples in :code:`paddle.io.DataLoader`.

    .. note::
        do not implement :code:`__getitem__` and :code:`__len__` in
        IterableDataset; they should not be called either (the base
        implementations raise ``RuntimeError``).

    see :code:`paddle.io.DataLoader`.

    Examples:

        .. code-block:: python

            import numpy as np
            from paddle.io import IterableDataset

            # define a random dataset
            class RandomDataset(IterableDataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __iter__(self):
                    for i in range(self.num_samples):
                        image = np.random.random([784]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        yield image, label

            dataset = RandomDataset(10)
            for img, lbl in dataset:
                print(img, lbl)

    When :attr:`num_workers > 0`, each worker has a different copy of the
    dataset object and will yield all dataset samples, which means every
    sample is repeated :attr:`num_workers` times. If each sample must be
    yielded only once, there are two ways to configure the copy in each worker
    process to avoid duplicate data among workers, shown below. Both ways need
    the worker information, which can be obtained in a worker process via
    `paddle.io.get_worker_info`.

    Example 1: splitting the data copy in each worker in :code:`__iter__`

        .. code-block:: python

            import math
            import paddle
            import numpy as np
            from paddle.io import IterableDataset, DataLoader, get_worker_info

            class SplitedIterableDataset(IterableDataset):
                def __init__(self, start, end):
                    self.start = start
                    self.end = end

                def __iter__(self):
                    worker_info = get_worker_info()
                    if worker_info is None:
                        iter_start = self.start
                        iter_end = self.end
                    else:
                        per_worker = int(
                            math.ceil((self.end - self.start) / float(
                                worker_info.num_workers)))
                        worker_id = worker_info.id
                        iter_start = self.start + worker_id * per_worker
                        iter_end = min(iter_start + per_worker, self.end)

                    for i in range(iter_start, iter_end):
                        yield np.array([i])

            dataset = SplitedIterableDataset(start=2, end=9)
            dataloader = DataLoader(
                dataset,
                num_workers=2,
                batch_size=1,
                drop_last=True)

            for data in dataloader:
                print(data)
                # outputs: [2, 5, 3, 6, 4, 7]

    Example 2: splitting the data copy in each worker by :code:`worker_init_fn`

        .. code-block:: python

            import math
            import paddle
            import numpy as np
            from paddle.io import IterableDataset, DataLoader, get_worker_info

            class RangeIterableDataset(IterableDataset):
                def __init__(self, start, end):
                    self.start = start
                    self.end = end

                def __iter__(self):
                    for i in range(self.start, self.end):
                        yield np.array([i])

            dataset = RangeIterableDataset(start=2, end=9)

            def worker_init_fn(worker_id):
                worker_info = get_worker_info()

                dataset = worker_info.dataset
                start = dataset.start
                end = dataset.end
                num_per_worker = int(
                    math.ceil((end - start) / float(worker_info.num_workers)))

                worker_id = worker_info.id
                dataset.start = start + worker_id * num_per_worker
                dataset.end = min(dataset.start + num_per_worker, end)

            dataloader = DataLoader(
                dataset,
                num_workers=2,
                batch_size=1,
                drop_last=True,
                worker_init_fn=worker_init_fn)

            for data in dataloader:
                print(data)
            # outputs: [2, 5, 3, 6, 4, 7]

    """

    def __init__(self):
        pass

    def __iter__(self):
        raise NotImplementedError("'{}' not implement in class "
                                  "{}".format('__iter__',
                                              self.__class__.__name__))

    def __getitem__(self, idx):
        # Iterable-style datasets have no random access. The trailing space
        # keeps "IterableDataset" and the class name from running together.
        raise RuntimeError("'{}' should not be called for IterableDataset "
                           "{}".format('__getitem__', self.__class__.__name__))

    def __len__(self):
        # Iterable-style datasets have no known length; same separator fix.
        raise RuntimeError("'{}' should not be called for IterableDataset "
                           "{}".format('__len__', self.__class__.__name__))


class TensorDataset(Dataset):
    """
    Dataset defined by a list of tensors.

    Each tensor should be in shape of [N, ...], where N is the sample number,
    and each tensor contains one field of a sample. :code:`TensorDataset`
    retrieves each sample by indexing every tensor in the 1st dimension.

    Args:
        tensors(list|tuple): A list/tuple of tensors with same shape in the 1st dimension.

    Returns:
        Dataset: a Dataset instance wrapping tensors.

    Raises:
        RuntimeError: if not running in dynamic-graph (imperative) mode.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.io import TensorDataset


            input_np = np.random.random([2, 3, 4]).astype('float32')
            input = paddle.to_tensor(input_np)
            label_np = np.random.random([2, 1]).astype('int32')
            label = paddle.to_tensor(label_np)

            dataset = TensorDataset([input, label])

            for i in range(len(dataset)):
                input, label = dataset[i]
                print(input, label)

    """

    def __init__(self, tensors):
        # Only supported outside static-graph mode, per framework._non_static_mode().
        if not framework._non_static_mode():
            raise RuntimeError(
                "TensorDataset can only be used in imperative mode")
        # Every tensor must have the same size along dim 0 (the sample count).
        assert all(tensor.shape[0] == tensors[0].shape[0] for tensor in tensors), \
                "tensors not have same shape of the 1st dimension"
        self.tensors = tensors

    def __getitem__(self, index):
        # One sample = the index-th slice of every wrapped tensor, in order.
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].shape[0]


def to_list(value):
    """
    Normalize ``value`` into a list.

    ``None`` is returned unchanged, a list or tuple is copied into a new list,
    and any other object is wrapped in a single-element list.
    """
    if value is None:
        return None
    return list(value) if isinstance(value, (list, tuple)) else [value]


class ComposeDataset(Dataset):
    """
    A Dataset which composes fields of multiple datasets.

    This dataset is used for composing the fields of multiple map-style
    datasets of the same length. For a given index, each dataset's sample is
    normalized with ``to_list`` (a list/tuple is expanded into its items, any
    other value becomes a single field), and the fields of all datasets are
    concatenated, in dataset order, into one tuple.

    Args:
        datasets(list of Dataset): List of datasets to be composed.

    Returns:
        Dataset: A Dataset which composes fields of multiple datasets.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.io import Dataset, ComposeDataset


            # define a random dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([32]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = ComposeDataset([RandomDataset(10), RandomDataset(10)])
            for i in range(len(dataset)):
                image1, label1, image2, label2 = dataset[i]
                print(image1)
                print(label1)
                print(image2)
                print(label2)

    """

    def __init__(self, datasets):
        self.datasets = list(datasets)
        assert self.datasets, "input datasets shoule not be empty"
        for position, member in enumerate(self.datasets):
            assert isinstance(member, Dataset), \
                "each input dataset should be paddle.io.Dataset"
            assert not isinstance(member, IterableDataset), \
                "paddle.io.IterableDataset not supported"
            # Compare each dataset with its predecessor so all lengths match.
            if position > 0:
                previous = self.datasets[position - 1]
                assert len(member) == len(previous), \
                    "lengths of datasets should be same"

    def __len__(self):
        return len(self.datasets[0])

    def __getitem__(self, idx):
        # Flatten each dataset's sample into fields and join them in order.
        return tuple(field
                     for member in self.datasets
                     for field in to_list(member[idx]))


class ChainDataset(IterableDataset):
    """
    A Dataset which chains multiple iterable-style datasets.

    This dataset is used for assembling multiple datasets, each of which must
    be a :code:`paddle.io.IterableDataset`. Iterating it yields every sample of
    the first dataset, then every sample of the second, and so on, in the
    order the datasets were given.

    Args:
        datasets(list of Dataset): List of datasets to be chained.

    Returns:
        Dataset: A Dataset which chains the samples of multiple datasets.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle
            from paddle.io import IterableDataset, ChainDataset


            # define a random dataset
            class RandomDataset(IterableDataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __iter__(self):
                    for i in range(self.num_samples):
                        image = np.random.random([32]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        yield image, label

            dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
            for image, label in iter(dataset):
                print(image, label)

    """

    def __init__(self, datasets):
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "input datasets shoule not be empty"
        for member in self.datasets:
            assert isinstance(member, IterableDataset), \
                "ChainDataset only support paddle.io.IterableDataset"

    def __iter__(self):
        # Exhaust each member dataset in turn before moving to the next one.
        for member in self.datasets:
            for item in member:
                yield item


class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Position ``i`` of the subset maps to ``dataset[indices[i]]``; indices may
    repeat and need not be sorted.

    Args:
        dataset (Dataset): The whole Dataset.
        indices (sequence): Indices in the whole set selected for subset.

    Returns:
        Dataset: A Dataset which is the subset of the original dataset.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.io import Subset

            # example 1:
            a = paddle.io.Subset(dataset=range(1, 4), indices=[0, 2])
            print(list(a))
            # [1, 3]

            # example 2:
            b = paddle.io.Subset(dataset=range(1, 4), indices=[1, 1])
            print(list(b))
            # [2, 2]
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Translate the subset-local position into an index of the wrapped dataset.
        source_index = self.indices[idx]
        return self.dataset[source_index]

    def __len__(self):
        return len(self.indices)


def random_split(dataset, lengths, generator=None):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Args:
        dataset (Dataset): Dataset to be split.
        lengths (sequence): Lengths of the splits to be produced. Any iterable
            of non-negative integers is accepted; it is read exactly once.
        generator (Generator, optional): Intended generator for the random
            permutation. NOTE: this argument is currently not used; the
            permutation always comes from ``paddle.randperm``. Default: None.

    Returns:
        list[Subset]: The non-overlapping subsets of the original dataset, one
        per entry of ``lengths``, in the same order.

    Raises:
        ValueError: if the lengths do not sum to ``len(dataset)``, or if any
            length is negative.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.io import random_split

            a_list = paddle.io.random_split(range(10), [3, 7])
            print(len(a_list))
            # 2

            for idx, v in enumerate(a_list[0]):
                print(idx, v)

            # output of the first subset
            # 0 1
            # 1 3
            # 2 9

            for idx, v in enumerate(a_list[1]):
                print(idx, v)
            # output of the second subset
            # 0 5
            # 1 7
            # 2 8
            # 3 6
            # 4 0
            # 5 2
            # 6 4
    """
    # Materialize once: lengths is read several times below (sum, offsets,
    # slice sizes), which would silently misbehave for one-shot iterators.
    lengths = list(lengths)
    total = sum(lengths)
    # Cannot verify that dataset is Sized
    if total != len(dataset):  # type: ignore
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )
    # A negative length would still pass the sum check but produce
    # overlapping / wrongly sized slices below.
    if any(length < 0 for length in lengths):
        raise ValueError(
            "Lengths of splits should be non-negative, but got {}".format(
                lengths))
    # TODO(@Joejiong): support Variable or Tensor type with .tolist class member function.
    # For example var.item() and var.tolist()
    indices = paddle.randperm(total).numpy().tolist()
    # Each split takes the slice [offset - length, offset) of the permutation,
    # where offset is the running total of the lengths so far.
    return [
        Subset(dataset, indices[offset - length:offset])
        for offset, length in zip(_accumulate(lengths), lengths)
    ]


def _accumulate(iterable, fn=lambda x, y: x + y):
    """
    Return running totals
503

504 505 506 507 508 509 510 511 512
    Args:
        iterable: any iterable object for example dataset.
        y (x): one element in the iterable object.
        fn (x, y): Defaults to lambdax.

    Yields:
        yields total from beginning iterator to current iterator.

    Example code:
513

514
        .. code-block:: python
515

516 517 518 519 520 521 522 523 524 525 526 527 528
            _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
            _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    """

    it = iter(iterable)
    try:
        total = next(it)
    except StopIteration:
        return
    yield total
    for element in it:
        total = fn(total, element)
        yield total