"""Dataset which is transformed from another with a transform.
Args:
dataset (Dataset): the base dataset.
transform (callable): the transform which takes an example of the base dataset as parameter and return a new example.
"""
self._dataset=dataset
self._transform=transform
def__len__(self):
returnlen(self._dataset)
def__getitem__(self,i):
in_data=self._dataset[i]
returnself._transform(in_data)
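

# A hypothetical usage sketch, not part of the original module. It assumes the
# class above is named `TransformDataset` and that a base dataset only needs
# to provide __len__ and __getitem__, so a plain list qualifies.
def _demo_transform_dataset():
    squares = TransformDataset([1, 2, 3], lambda x: x * x)
    assert len(squares) == 3
    assert squares[1] == 4  # the transform is applied lazily, on each access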


class CacheDataset(Dataset):
    def __init__(self, dataset):
        """A lazy cache of the base dataset.

        Args:
            dataset (Dataset): the base dataset to cache.
        """
        self._dataset = dataset
        self._cache = dict()

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, i):
        if i not in self._cache:
            self._cache[i] = self._dataset[i]
        return self._cache[i]
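

# A hypothetical usage sketch, not part of the original module. Each example of
# the base dataset (a plain list here) is fetched at most once; repeated reads
# are served from the in-memory cache.
def _demo_cache_dataset():
    cached = CacheDataset(["a", "b", "c"])
    assert cached[0] == "a"  # first access reads the base dataset
    assert cached[0] == "a"  # second access hits the cache
    assert len(cached) == 3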


class TupleDataset(Dataset):
    def __init__(self, *datasets):
        """A compound dataset made from several datasets of the same length.
        An example of the `TupleDataset` is a tuple of examples from the
        constituent datasets.

        Args:
            datasets: tuple[Dataset], the constituent datasets.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            if len(dataset) != length:
                raise ValueError(
                    "all the datasets should have the same length. "
                    "dataset {} has a different length".format(i))
"""A Dataset which is a slice of the base dataset.
Args:
dataset (Dataset): the base dataset.
start (int): the start of the slice.
finish (int): the end of the slice, not inclusive.
order (List[int], optional): the order, it is a permutation of the valid example ids of the base dataset. If `order` is provided, the slice is taken in `order`. Defaults to None.
"""
        if start < 0 or finish > len(dataset):
            raise ValueError("subset overruns the dataset.")
        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._size = finish - start

        if order is not None and len(order) != len(dataset):
            raise ValueError(
                "order should have the same length as the dataset. "
                "len(order) = {} which does not equal len(dataset) = {}".
                format(len(order), len(dataset)))
        self._order = order

    def __len__(self):
        return self._size

    def __getitem__(self, i):
        if i >= 0:
            if i >= self._size:
                raise IndexError('dataset index out of range')
            index = self._start + i
        else:
            if i < -self._size:
                raise IndexError('dataset index out of range')
            index = self._finish + i

        if self._order is not None:
            index = self._order[index]
        return self._dataset[index]
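

# A hypothetical usage sketch, not part of the original module. The slice keeps
# examples [start, finish) of the base dataset; an optional `order` permutation
# is applied before indexing into the base dataset.
def _demo_slice_dataset():
    base = [10, 11, 12, 13, 14]
    head = SliceDataset(base, 0, 3)
    assert len(head) == 3 and head[-1] == 12
    shuffled = SliceDataset(base, 0, 3, order=[4, 3, 2, 1, 0])
    assert shuffled[0] == base[4]  # index 0 is remapped through `order`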


class SubsetDataset(Dataset):
    def __init__(self, dataset, indices):
        """A Dataset which is a subset of the base dataset.

        Args:
            dataset (Dataset): the base dataset.
            indices (Iterable[int]): the indices of the examples to pick.
        """
        self._dataset = dataset
        if len(indices) > len(dataset):
            raise ValueError("subset's size is larger than dataset's size!")
        self._indices = indices
        self._size = len(indices)

    def __len__(self):
        return self._size

    def __getitem__(self, i):
        index = self._indices[i]
        return self._dataset[index]
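

# A hypothetical usage sketch, not part of the original module. The subset is
# an arbitrary selection (and reordering) of examples of the base dataset.
def _demo_subset_dataset():
    base = ["a", "b", "c", "d"]
    picked = SubsetDataset(base, [3, 1])
    assert len(picked) == 2
    assert picked[0] == "d" and picked[1] == "b"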


class FilterDataset(Dataset):
    def __init__(self, dataset, filter_fn):
        """A filtered dataset.

        Args:
            dataset (Dataset): the base dataset.
            filter_fn (callable): a callable which takes an example of the
                base dataset and returns a boolean.
        """
        self._dataset = dataset
        self._indices = [
            i for i in range(len(dataset)) if filter_fn(dataset[i])
        ]
        self._size = len(self._indices)

    def __len__(self):
        return self._size

    def __getitem__(self, i):
        index = self._indices[i]
        return self._dataset[index]
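

# A hypothetical usage sketch, not part of the original module. Note that the
# filter is applied eagerly in __init__, so every example of the base dataset
# is read once at construction time.
def _demo_filter_dataset():
    evens = FilterDataset([0, 1, 2, 3, 4], lambda x: x % 2 == 0)
    assert len(evens) == 3
    assert [evens[i] for i in range(len(evens))] == [0, 2, 4]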


class ChainDataset(Dataset):
    def __init__(self, *datasets):
        """A concatenation of several datasets with the same structure.

        Args:
            datasets (Iterable[Dataset]): datasets to concatenate.
        """
        self._datasets = datasets

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def __getitem__(self, i):
        if i < 0:
            raise IndexError(
                "ChainDataset does not support negative indexing.")