# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
import numpy as np
from tqdm import tqdm


class DatasetMixin(object):
    """Standard indexing interface for datasets.

    Inherit this class to get the indexing interface. Since it is a mixin
    class which does not define an ``__init__`` method, subclasses do not
    need to call ``super().__init__()``.

    Subclasses are expected to implement ``get_example`` and ``__len__``;
    both raise ``NotImplementedError`` here.
    """

    def __getitem__(self, index):
        """Standard indexing interface for the dataset.

        Args:
            index (slice, list[int], np.ndarray or int): the index. It can be
                an int, a slice, a list of integers, or an ndarray of
                integers. ``get_example`` is called to pick each example.

        Returns:
            Example, or List[Example]: If ``index`` is an integer, a single
                example is returned. If ``index`` is a slice, a list of
                integers or an array of integers, a list of examples is
                returned.
        """
        if isinstance(index, slice):
            # Resolve None / negative slice bounds against the dataset length.
            start, stop, step = index.indices(len(self))
            return [self.get_example(i) for i in range(start, stop, step)]
        elif isinstance(index, (list, np.ndarray)):
            return [self.get_example(i) for i in index]
        else:
            # Assumes ``index`` is an integer.
            return self.get_example(index)

    def get_example(self, i):
        """Get an example from the dataset.

        Custom datasets should implement this method.

        Args:
            i (int): example index.

        Returns:
            Example: the example at index ``i``.
        """
        raise NotImplementedError

    def __len__(self):
        """Return the number of examples; subclasses should implement it."""
        raise NotImplementedError

    def __iter__(self):
        """Iterate over all examples in index order via ``get_example``."""
        for i in range(len(self)):
            yield self.get_example(i)


class TransformDataset(DatasetMixin):
    """Dataset whose examples are computed by applying a transform to the
    examples of a base dataset.

    The transform is applied on every access; results are not cached.
    """

    def __init__(self, dataset, transform):
        """Dataset which is transformed from another with a transform.

        Args:
            dataset (DatasetMixin): the base dataset.
            transform (callable): the transform which takes an example of the
                base dataset as parameter and returns a new example.
        """
        self._dataset = dataset
        self._transform = transform

    def __len__(self):
        return len(self._dataset)

    def get_example(self, i):
        in_data = self._dataset[i]
        return self._transform(in_data)


class CacheDataset(DatasetMixin):
    """Dataset that memoizes the examples of a base dataset on first access."""

    def __init__(self, dataset):
        """A lazy cache of the base dataset.

        Args:
            dataset (DatasetMixin): the base dataset to cache.
        """
        self._dataset = dataset
        self._cache = dict()

    def __len__(self):
        return len(self._dataset)

    def get_example(self, i):
        # Normalize negative indices so that e.g. ``-1`` and ``len - 1``
        # share a single cache entry instead of fetching the example twice.
        if i < 0:
            i += len(self)
        if i not in self._cache:
            self._cache[i] = self._dataset[i]
        return self._cache[i]


class TupleDataset(object):
    """Dataset whose examples are tuples of examples from several datasets."""

    def __init__(self, *datasets):
        """A compound dataset made from several datasets of the same length.

        An example of the ``TupleDataset`` is a tuple of examples from the
        constituent datasets.

        Args:
            datasets: tuple[DatasetMixin], the constituent datasets.

        Raises:
            ValueError: if no dataset is given, or if the datasets do not all
                have the same length.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            # Compare each constituent dataset's length, not the number of
            # datasets.
            if len(dataset) != length:
                raise ValueError(
                    "all the datasets should have the same length. "
                    "dataset {} has a different length".format(i))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        # Gather the per-dataset results first (structure of arrays).
        batches = [dataset[index] for dataset in self._datasets]
        if isinstance(index, (slice, list, np.ndarray)):
            # Batch index: transpose into a list of tuples (array of
            # structures), consistent with DatasetMixin batch indexing.
            length = len(batches[0])
            return [
                tuple(batch[i] for batch in batches) for i in range(length)
            ]
        return tuple(batches)

    def __len__(self):
        return self._length


class DictDataset(object):
    """Dataset whose examples are dicts of examples from several datasets."""

    def __init__(self, **datasets):
        """A compound dataset made from several datasets of the same length.

        An example of the ``DictDataset`` is a dict of examples from the
        constituent datasets.

        Args:
            datasets: Dict[DatasetMixin], the constituent datasets.

        Raises:
            ValueError: if no dataset is given, or if the datasets do not all
                have the same length.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        length = None
        for key, dataset in datasets.items():
            if length is None:
                length = len(dataset)
            # Compare each constituent dataset's length, not the number of
            # datasets.
            elif len(dataset) != length:
                raise ValueError(
                    "all the datasets should have the same length. "
                    "dataset {} has a different length".format(key))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        batches = {
            key: dataset[index]
            for key, dataset in self._datasets.items()
        }
        if isinstance(index, (slice, list, np.ndarray)):
            # Batch index: transpose into a list of dicts, consistent with
            # DatasetMixin batch indexing.
            length = len(next(iter(batches.values())))
            return [{key: batch[i]
                     for key, batch in batches.items()}
                    for i in range(length)]
        return batches

    def __len__(self):
        return self._length


class SliceDataset(DatasetMixin):
    """Dataset exposing the examples ``[start, finish)`` of a base dataset."""

    def __init__(self, dataset, start, finish, order=None):
        """A Dataset which is a slice of the base dataset.

        Args:
            dataset (DatasetMixin): the base dataset.
            start (int): the start of the slice.
            finish (int): the end of the slice, not inclusive.
            order (List[int], optional): the order, it is a permutation of the
                valid example ids of the base dataset. If ``order`` is
                provided, the slice is taken in ``order``. Defaults to None.

        Raises:
            ValueError: if ``[start, finish)`` does not lie within the base
                dataset or ``start > finish``, or if ``order`` does not have
                the same length as the base dataset.
        """
        if start < 0 or finish > len(dataset):
            raise ValueError("subset overruns the dataset.")
        # Without this check the slice size would be negative and len()
        # would fail later.
        if start > finish:
            raise ValueError(
                "start ({}) should not be greater than finish ({}).".format(
                    start, finish))
        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._size = finish - start

        if order is not None and len(order) != len(dataset):
            raise ValueError(
                "order should have the same length as the dataset. "
                "len(order) = {} which does not equal len(dataset) = {}".
                format(len(order), len(dataset)))
        self._order = order

    def __len__(self):
        return self._size

    def get_example(self, i):
        """Return the ``i``-th example of the slice.

        A negative ``i`` counts from the end of the slice.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        if i >= 0:
            if i >= self._size:
                raise IndexError('dataset index out of range')
            index = self._start + i
        else:
            if i < -self._size:
                raise IndexError('dataset index out of range')
            index = self._finish + i

        # Map the position through the optional permutation of the base ids.
        if self._order is not None:
            index = self._order[index]
        return self._dataset[index]


class SubsetDataset(DatasetMixin):
    """Dataset exposing the examples of a base dataset at chosen indices."""

    def __init__(self, dataset, indices):
        """A Dataset which is a subset of the base dataset.

        Args:
            dataset (DatasetMixin): the base dataset.
            indices (Sequence[int]): the indices of the examples to pick. It
                must support ``len`` and integer indexing (e.g. a list).

        Raises:
            ValueError: if more indices are given than the base dataset has
                examples.
        """
        self._dataset = dataset
        if len(indices) > len(dataset):
            raise ValueError("subset's size larger than dataset's size!")
        self._indices = indices
        self._size = len(indices)

    def __len__(self):
        return self._size

    def get_example(self, i):
        index = self._indices[i]
        return self._dataset[index]


class FilterDataset(DatasetMixin):
    """Dataset exposing only the base examples accepted by a predicate."""

    def __init__(self, dataset, filter_fn):
        """A filtered dataset.

        The filter is evaluated eagerly here, once per example of the base
        dataset, and only the indices of the accepted examples are kept.

        Args:
            dataset (DatasetMixin): the base dataset.
            filter_fn (callable): a callable which takes an example of the
                base dataset and returns a boolean.
        """
        self._dataset = dataset
        self._indices = [
            i for i in range(len(dataset)) if filter_fn(dataset[i])
        ]
        self._size = len(self._indices)

    def __len__(self):
        return self._size

    def get_example(self, i):
        # Translate the filtered position into the base dataset's index.
        index = self._indices[i]
        return self._dataset[index]


class ChainDataset(DatasetMixin):
    """Dataset that concatenates several datasets end to end."""

    def __init__(self, *datasets):
        """A concatenation of several datasets with the same structure.

        Args:
            datasets (Iterable[DatasetMixin]): datasets to concat.
        """
        self._datasets = datasets

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def get_example(self, i):
        """Return the ``i``-th example of the concatenation.

        A negative ``i`` counts from the end of the concatenation.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        if i < 0:
            # Translate a negative index into the equivalent non-negative one.
            i += len(self)
            if i < 0:
                raise IndexError("dataset index out of range")

        # Walk the datasets, subtracting each length until ``i`` falls inside
        # one of them.
        for dataset in self._datasets:
            size = len(dataset)
            if i < size:
                return dataset[i]
            i -= size

        raise IndexError("dataset index out of range")