# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
import numpy as np

try:
    # NOTE(review): nothing in this module uses tqdm. The name is kept so
    # that code importing it from here keeps working, but a missing tqdm
    # installation should not make the dataset utilities unimportable.
    from tqdm import tqdm
except ImportError:
    tqdm = None


class DatasetMixin(object):
    """Standard indexing interface for dataset. Inherit this class to get
    the indexing interface.

    Since it is a mixin class which does not have an `__init__` method, the
    subclass does not need to call `super().__init__()`. Subclasses must
    implement `get_example` and `__len__`.
    """

    def __getitem__(self, index):
        """Standard indexing interface for dataset.

        Args:
            index (slice, list[int], np.array or int): the index. It can be
                an int, a slice, a list of integers, or an ndarray of
                integers. It calls `get_example` to pick an example.

        Returns:
            Example, or List[Example]: If `index` is an integer, it returns
                an example. If `index` is a slice, a list of integers or an
                array of integers, it returns a list of examples.
        """
        if isinstance(index, slice):
            # slice.indices clips the slice to the dataset's length and
            # resolves negative / omitted bounds.
            start, stop, step = index.indices(len(self))
            return [
                self.get_example(i)
                for i in six.moves.range(start, stop, step)
            ]
        elif isinstance(index, (list, np.ndarray)):
            return [self.get_example(i) for i in index]
        else:
            # assumes it is an integer
            return self.get_example(index)

    def get_example(self, i):
        """Get an example from the dataset. Custom datasets should have
        this method implemented.

        Args:
            i (int): example index.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        """Yield the examples one by one, from index 0 to `len(self) - 1`."""
        for i in range(len(self)):
            yield self.get_example(i)


class TransformDataset(DatasetMixin):
    def __init__(self, dataset, transform):
        """Dataset which is transformed from another with a transform.

        Args:
            dataset (DatasetMixin): the base dataset.
            transform (callable): the transform which takes an example of
                the base dataset as parameter and returns a new example.
        """
        self._dataset = dataset
        self._transform = transform

    def __len__(self):
        return len(self._dataset)

    def get_example(self, i):
        """Fetch example `i` of the base dataset and apply the transform."""
        in_data = self._dataset[i]
        return self._transform(in_data)


class CacheDataset(DatasetMixin):
    def __init__(self, dataset):
        """A lazy cache of the base dataset.

        Args:
            dataset (DatasetMixin): the base dataset to cache.
        """
        self._dataset = dataset
        self._cache = dict()

    def __len__(self):
        return len(self._dataset)

    def get_example(self, i):
        """Return example `i`, reading the base dataset only on first use.

        The cache is keyed by the index exactly as given.
        """
        if i not in self._cache:
            self._cache[i] = self._dataset[i]
        return self._cache[i]


class TupleDataset(object):
    def __init__(self, *datasets):
        """A compound dataset made from several datasets of the same
        length. An example of the `TupleDataset` is a tuple of examples
        from the constituent datasets.

        Args:
            datasets: tuple[DatasetMixin], the constituent datasets.

        Raises:
            ValueError: if no dataset is given, or the datasets do not all
                have the same length.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            # Compare each constituent's own length (not the number of
            # constituents) against the first one.
            if len(dataset) != length:
                raise ValueError(
                    "all the datasets should have the same length. "
                    "dataset {} has a different length".format(i))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        """Index every constituent dataset with `index`.

        Returns:
            For a slice, a list of tuples (one tuple per example). For any
            other index, a tuple holding what each constituent dataset
            returned for `index`.
        """
        # SOA
        batches = [dataset[index] for dataset in self._datasets]
        if isinstance(index, slice):
            length = len(batches[0])
            # AOS
            return [
                tuple([batch[i] for batch in batches])
                for i in six.moves.range(length)
            ]
        else:
            return tuple(batches)

    def __len__(self):
        return self._length


class DictDataset(object):
    def __init__(self, **datasets):
        """A compound dataset made from several datasets of the same
        length. An example of the `DictDataset` is a dict of examples from
        the constituent datasets.

        Args:
            datasets: Dict[DatasetMixin], the constituent datasets.

        Raises:
            ValueError: if no dataset is given, or the datasets do not all
                have the same length.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        length = None
        for key, dataset in six.iteritems(datasets):
            if length is None:
                # The first dataset visited sets the reference length.
                length = len(dataset)
            elif len(dataset) != length:
                raise ValueError(
                    "all the datasets should have the same length. "
                    "dataset {} has a different length".format(key))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        """Index every constituent dataset with `index`.

        Returns:
            For a slice, a list of dicts (one dict per example). For any
            other index, a dict mapping each key to what its dataset
            returned for `index`.
        """
        batches = {
            key: dataset[index]
            for key, dataset in six.iteritems(self._datasets)
        }
        if isinstance(index, slice):
            length = len(six.next(six.itervalues(batches)))
            return [{key: batch[i]
                     for key, batch in six.iteritems(batches)}
                    for i in six.moves.range(length)]
        else:
            return batches

    def __len__(self):
        return self._length


class SliceDataset(DatasetMixin):
    def __init__(self, dataset, start, finish, order=None):
        """A Dataset which is a slice of the base dataset.

        Args:
            dataset (DatasetMixin): the base dataset.
            start (int): the start of the slice.
            finish (int): the end of the slice, not inclusive.
            order (List[int], optional): the order, it is a permutation of
                the valid example ids of the base dataset. If `order` is
                provided, the slice is taken in `order`. Defaults to None.

        Raises:
            ValueError: if the slice falls outside the dataset, if `start`
                is greater than `finish`, or if `order` does not have the
                same length as the dataset.
        """
        if start < 0 or finish > len(dataset):
            raise ValueError("subset overruns the dataset.")
        if start > finish:
            # A negative size would make len() fail and every index invalid.
            raise ValueError("start should not be greater than finish.")
        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._size = finish - start

        if order is not None and len(order) != len(dataset):
            raise ValueError(
                "order should have the same length as the dataset, "
                "len(order) = {} which does not euqals len(dataset) = {} ".
                format(len(order), len(dataset)))
        self._order = order

    def __len__(self):
        return self._size

    def get_example(self, i):
        """Return the `i`-th example of the slice.

        Negative `i` counts from the end of the slice.

        Raises:
            IndexError: if `i` is outside `[-len(self), len(self))`.
        """
        if i >= 0:
            if i >= self._size:
                raise IndexError('dataset index out of range')
            index = self._start + i
        else:
            if i < -self._size:
                raise IndexError('dataset index out of range')
            index = self._finish + i

        if self._order is not None:
            # Map the position inside the slice through the permutation.
            index = self._order[index]
        return self._dataset[index]


class SubsetDataset(DatasetMixin):
    def __init__(self, dataset, indices):
        """A Dataset which is a subset of the base dataset.

        Args:
            dataset (DatasetMixin): the base dataset.
            indices (Sequence[int]): the indices of the examples to pick.
                It must support `len()` and integer indexing.

        Raises:
            ValueError: if there are more indices than examples in the
                base dataset.
        """
        self._dataset = dataset
        if len(indices) > len(dataset):
            raise ValueError("subset's size larger that dataset's size!")
        self._indices = indices
        self._size = len(indices)

    def __len__(self):
        return self._size

    def get_example(self, i):
        """Return the base-dataset example selected by `indices[i]`."""
        index = self._indices[i]
        return self._dataset[index]


class FilterDataset(DatasetMixin):
    def __init__(self, dataset, filter_fn):
        """A filtered dataset.

        The filter is evaluated on every example of the base dataset at
        construction time; only the indices of the kept examples are
        stored.

        Args:
            dataset (DatasetMixin): the base dataset.
            filter_fn (callable): a callable which takes an example of the
                base dataset and returns a boolean.
        """
        self._dataset = dataset
        self._indices = [
            i for i in range(len(dataset)) if filter_fn(dataset[i])
        ]
        self._size = len(self._indices)

    def __len__(self):
        return self._size

    def get_example(self, i):
        """Return the `i`-th example that passed the filter."""
        index = self._indices[i]
        return self._dataset[index]


class ChainDataset(DatasetMixin):
    def __init__(self, *datasets):
        """A concatenation of several datasets which have the same
        structure.

        Args:
            datasets (Iterable[DatasetMixin]): datasets to concat.
        """
        self._datasets = datasets

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def get_example(self, i):
        """Return the `i`-th example of the concatenation.

        Raises:
            IndexError: if `i` is negative or not smaller than the total
                length.
        """
        if i < 0:
            raise IndexError("ChainDataset doesnot support negative indexing.")

        for dataset in self._datasets:
            size = len(dataset)
            if i < size:
                return dataset[i]
            # Skip this dataset: make `i` relative to the next one.
            i -= size

        raise IndexError("dataset index out of range")