From 2ddb6305211706b604208eec2d5dfddd691f56ce Mon Sep 17 00:00:00 2001 From: smallv0221 <397551318@qq.com> Date: Tue, 22 Dec 2020 13:25:14 +0000 Subject: [PATCH] upgrade MapDatasetWrapper --- PaddleNLP/paddlenlp/datasets/dataset.py | 90 ++++++++++++++++++++----- 1 file changed, 74 insertions(+), 16 deletions(-) diff --git a/PaddleNLP/paddlenlp/datasets/dataset.py b/PaddleNLP/paddlenlp/datasets/dataset.py index 27a2f84a..6c5ae6d5 100644 --- a/PaddleNLP/paddlenlp/datasets/dataset.py +++ b/PaddleNLP/paddlenlp/datasets/dataset.py @@ -89,14 +89,26 @@ class MapDatasetWrapper(Dataset): def __init__(self, data): self.data = data - self._transform_func = None + self._transform_pipline = [] + self.new_data = [] + + def transform(self, data, pipline): + for fn in reversed(pipline): + data = fn(data) + return data def __getitem__(self, idx): - return self._transform_func(self.data[ - idx]) if self._transform_func else self.data[idx] + if not self.new_data: + return self.transform( + self.data[idx], self._transform_pipline + ) if self._transform_pipline else self.data[idx] + else: + return self.transform( + self.new_data[idx], self._transform_pipline + ) if self._transform_pipline else self.new_data[idx] def __len__(self): - return len(self.data) + return len(self.data) if not self.new_data else len(self.new_data) def filter(self, fn): """ @@ -108,11 +120,17 @@ class MapDatasetWrapper(Dataset): Returns: MapDatasetWrapper: The filtered dataset """ - filted_data = [ - self.data[idx] for idx in range(len(self.data)) - if fn(self.data[idx]) - ] - return type(self)(filted_data) + if not self.new_data: + self.new_data = [ + self.data[idx] for idx in range(len(self.data)) + if fn(self.data[idx]) + ] + else: + self.new_data = [ + self.data[idx] for idx in range(len(self.new_data)) + if fn(self.new_data[idx]) + ] + return self def shard(self, num_shards=None, index=None): """ @@ -135,13 +153,53 @@ class MapDatasetWrapper(Dataset): num_samples = int(math.ceil(len(self.data) * 1.0 / num_shards)) total_size = num_samples * num_shards # add extra samples to make it evenly divisible - sharded_data = [ - self.data[idx] for idx in range(len(self.data)) - if idx % num_shards == index - ] - if len(sharded_data) < num_samples: - sharded_data.append(self.data[index + 1 - num_shards]) - return type(self)(sharded_data) + if not self.new_data: + self.new_data = [ + self.data[idx] for idx in range(len(self.data)) + if idx % num_shards == index + ] + if len(self.new_data) < num_samples: + self.new_data.append(self.data[index + 1 - num_shards]) + else: + self.new_data = [ + self.data[idx] for idx in range(len(self.new_data)) + if idx % num_shards == index + ] + if len(self.new_data) < num_samples: + self.new_data.append(self.new_data[index + 1 - num_shards]) + return self + + def apply(self, fn, lazy=False): + """ + Performs specific function on the dataset to transform every sample. + Args: + fn (callable): Transformations to be performed. It receives single + sample as argument rather than dataset. + lazy (bool, optional): If True, transformations would be delayed and + performed on demand. Otherwise, transforms all samples at once + and return a new MapDatasetWrapper instance. Note that if `fn` is + stochastic, `lazy` should be True or you will get the same + result on all epochs. Defalt: False. + Returns: + MapDatasetWrapper: A new MapDatasetWrapper instance if `lazy` is True, \ + otherwise bind `fn` as a property to transform on demand. + """ + if lazy: + self._transform_pipline.append(fn) + else: + if not self.new_data: + self.new_data = [ + fn(self.data[idx]) for idx in range(len(self.data)) + ] + else: + self.new_data = [ + fn(self.new_data[idx]) + for idx in range(len(self.new_data)) + ] + return self + + def __getattr__(self, name): + return getattr(self.data, name) def apply(self, fn, lazy=False): """ -- GitLab