extradatasets.py 3.9 KB
Newer Older
H
huangyuxin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#


"""Train PyTorch models directly from POSIX tar archive.

Code works locally or over HTTP connections.
"""

import itertools as itt
import os
import random
import sys

import braceexpand

from . import utils
from .paddle_utils import IterableDataset
from .utils import PipelineStage


class MockDataset(IterableDataset):
    """Dataset that yields one fixed sample a fixed number of times.

    Intended for performance testing and unit testing.
    """

    def __init__(self, sample, length):
        """Store the sample and the repetition count.

        :param sample: the sample to be returned repeatedly
        :param length: the length of the mock dataset
        """
        self.sample, self.length = sample, length

    def __iter__(self):
        """Yield ``self.sample`` once for each of ``self.length`` steps."""
        count = self.length
        for _ in range(count):
            # Read the attribute on every step, as the sample is not copied.
            yield self.sample


class repeatedly(IterableDataset, PipelineStage):
    """Pipeline stage that repeatedly yields samples from a dataset.

    The repetition itself is delegated to ``utils.repeatedly``; this class
    only records the limits and forwards them.
    """

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        """Create an instance of Repeatedly.

        :param source: source dataset; stored as ``self.source``
        :param nepochs: repeat for a maximum of nepochs
        :param nbatches: repeat for a maximum of nbatches
        :param length: stored as ``self.length``; not read by ``invoke``
        """
        self.source = source
        self.length = length
        # ``invoke`` reads ``self.nepochs``; it must be stored here, otherwise
        # every call to ``invoke`` fails with AttributeError.
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source.

        :param source: the source passed to ``utils.repeatedly`` (the
            ``source`` given to ``__init__`` is not used here)
        """
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )


class with_epoch(IterableDataset):
    """Impose epoch boundaries of a fixed length on an IterableDataset.

    ``invoke`` yields at most ``length`` samples per call, restarting the
    underlying dataset iterator whenever it runs out mid-epoch.
    According to the original authors, this exists mainly as a workaround
    for the odd logic in DataLoader, and is also useful for choosing smaller
    nominal epoch sizes with very large datasets.

    """

    def __init__(self, dataset, length):
        """Set the epoch length.

        :param dataset: IterableDataset (accepted but not stored; the
            dataset to iterate is passed to ``invoke`` instead)
        :param length: number of samples per epoch
        """
        super().__init__()
        self.source = None
        self.length = length

    def __getstate__(self):
        """Return the state to pickle, with the live iterator dropped.

        The ``source`` entry is replaced by ``None`` since an iterator
        can't be pickled.
        """
        return {**self.__dict__, "source": None}

    def invoke(self, dataset):
        """Return an iterator over the dataset.

        The iterator yields up to ``length`` samples. If the dataset runs
        out before that, it is restarted; if the restarted dataset yields
        nothing at all, iteration ends early.
        """
        exhausted = object()  # private marker for "iterator ran dry"
        if self.source is None:
            self.source = iter(dataset)
        for _ in range(self.length):
            item = next(self.source, exhausted)
            if item is exhausted:
                # Start over from the beginning of the dataset.
                self.source = iter(dataset)
                item = next(self.source, exhausted)
                if item is exhausted:
                    # Empty dataset: nothing to yield even after a restart.
                    return
            yield item
        # A completed epoch clears the stored iterator.
        self.source = None


class with_length(IterableDataset, PipelineStage):
    """Attach a user-specified length to a dataset.

    Iteration is passed through unchanged; only ``len()`` is affected.
    """

    def __init__(self, dataset, length):
        """Record the dataset and the length to report.

        :param dataset: source dataset
        :param length: stated length
        """
        super().__init__()
        self.dataset, self.length = dataset, length

    def invoke(self, dataset):
        """Return a plain iterator over the dataset passed in.

        Note that the ``dataset`` argument is iterated, not ``self.dataset``.
        """
        stream = iter(dataset)
        return stream

    def __len__(self):
        """Return the user specified length."""
        return self.length