From 1dbc2855b6615d156f7f4a57bb21091cf2a986a1 Mon Sep 17 00:00:00 2001 From: smallv0221 <33639025+smallv0221@users.noreply.github.com> Date: Fri, 25 Dec 2020 10:39:42 +0800 Subject: [PATCH] upgrade MapDatasetWrapper (#5132) * upgrade MapDatasetWrapper * minor fix * refine code --- PaddleNLP/paddlenlp/datasets/dataset.py | 45 +++++++++++++++---------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/PaddleNLP/paddlenlp/datasets/dataset.py b/PaddleNLP/paddlenlp/datasets/dataset.py index 27a2f84a..787e4ac4 100644 --- a/PaddleNLP/paddlenlp/datasets/dataset.py +++ b/PaddleNLP/paddlenlp/datasets/dataset.py @@ -89,14 +89,21 @@ class MapDatasetWrapper(Dataset): def __init__(self, data): self.data = data - self._transform_func = None + self._transform_pipline = [] + self.new_data = self.data + + def _transform(self, data, pipline): + for fn in reversed(pipline): + data = fn(data) + return data def __getitem__(self, idx): - return self._transform_func(self.data[ - idx]) if self._transform_func else self.data[idx] + return self._transform( + self.new_data[idx], self._transform_pipline + ) if self._transform_pipline else self.new_data[idx] def __len__(self): - return len(self.data) + return len(self.new_data) def filter(self, fn): """ @@ -108,11 +115,12 @@ class MapDatasetWrapper(Dataset): Returns: MapDatasetWrapper: The filtered dataset """ - filted_data = [ - self.data[idx] for idx in range(len(self.data)) - if fn(self.data[idx]) + + self.new_data = [ + self.new_data[idx] for idx in range(len(self.new_data)) + if fn(self.new_data[idx]) ] - return type(self)(filted_data) + return self def shard(self, num_shards=None, index=None): """ @@ -132,16 +140,18 @@ class MapDatasetWrapper(Dataset): num_shards = dist.get_world_size() if index is None: index = dist.get_rank() - num_samples = int(math.ceil(len(self.data) * 1.0 / num_shards)) + + num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards)) total_size = num_samples * num_shards # add extra samples to make it evenly divisible - sharded_data = [ - self.data[idx] for idx in range(len(self.data)) + self.new_data = [ + self.new_data[idx] for idx in range(len(self.new_data)) if idx % num_shards == index ] - if len(sharded_data) < num_samples: - sharded_data.append(self.data[index + 1 - num_shards]) - return type(self)(sharded_data) + if len(self.new_data) < num_samples: + self.new_data.append(self.new_data[index + 1 - num_shards]) + + return self def apply(self, fn, lazy=False): """ @@ -159,10 +169,11 @@ class MapDatasetWrapper(Dataset): otherwise bind `fn` as a property to transform on demand. """ if lazy: - self._transform_func = fn + self._transform_pipline.append(fn) else: - applied_data = [fn(self.data[idx]) for idx in range(len(self.data))] - return type(self)(applied_data) + self.new_data = [ + fn(self.new_data[idx]) for idx in range(len(self.new_data)) + ] return self def __getattr__(self, name): -- GitLab