diff --git a/PaddleNLP/paddlenlp/datasets/dataset.py b/PaddleNLP/paddlenlp/datasets/dataset.py
index a44a84ddbd45f24c8d61a914e12b7d865a8fec6c..787e4ac4e3d8db01920a37a049cfd754cdcba299 100644
--- a/PaddleNLP/paddlenlp/datasets/dataset.py
+++ b/PaddleNLP/paddlenlp/datasets/dataset.py
@@ -90,25 +90,20 @@ class MapDatasetWrapper(Dataset):
     def __init__(self, data):
         self.data = data
         self._transform_pipline = []
-        self.new_data = []
+        self.new_data = self.data
 
-    def transform(self, data, pipline):
+    def _transform(self, data, pipline):
         for fn in reversed(pipline):
             data = fn(data)
         return data
 
     def __getitem__(self, idx):
-        if not self.new_data:
-            return self.transform(
-                self.data[idx], self._transform_pipline
-            ) if self._transform_pipline else self.data[idx]
-        else:
-            return self.transform(
-                self.new_data[idx], self._transform_pipline
-            ) if self._transform_pipline else self.new_data[idx]
+        return self._transform(
+            self.new_data[idx], self._transform_pipline
+        ) if self._transform_pipline else self.new_data[idx]
 
     def __len__(self):
-        return len(self.data) if not self.new_data else len(self.new_data)
+        return len(self.new_data)
 
     def filter(self, fn):
         """
@@ -120,16 +115,11 @@ class MapDatasetWrapper(Dataset):
         Returns:
             MapDatasetWrapper: The filtered dataset
         """
-        if not self.new_data:
-            self.new_data = [
-                self.data[idx] for idx in range(len(self.data))
-                if fn(self.data[idx])
-            ]
-        else:
-            self.new_data = [
-                self.data[idx] for idx in range(len(self.new_data))
-                if fn(self.new_data[idx])
-            ]
+
+        self.new_data = [
+            self.new_data[idx] for idx in range(len(self.new_data))
+            if fn(self.new_data[idx])
+        ]
         return self
 
     def shard(self, num_shards=None, index=None):
@@ -150,23 +140,17 @@ class MapDatasetWrapper(Dataset):
             num_shards = dist.get_world_size()
         if index is None:
             index = dist.get_rank()
-        num_samples = int(math.ceil(len(self.data) * 1.0 / num_shards))
+
+        num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
         total_size = num_samples * num_shards
         # add extra samples to make it evenly divisible
-        if not self.new_data:
-            self.new_data = [
-                self.data[idx] for idx in range(len(self.data))
-                if idx % num_shards == index
-            ]
-            if len(self.new_data) < num_samples:
-                self.new_data.append(self.data[index + 1 - num_shards])
-        else:
-            self.new_data = [
-                self.data[idx] for idx in range(len(self.new_data))
-                if idx % num_shards == index
-            ]
-            if len(self.new_data) < num_samples:
-                self.new_data.append(self.new_data[index + 1 - num_shards])
+        self.new_data = [
+            self.new_data[idx] for idx in range(len(self.new_data))
+            if idx % num_shards == index
+        ]
+        if len(self.new_data) < num_samples:
+            self.new_data.append(self.new_data[index + 1 - num_shards])
+
         return self
 
     def apply(self, fn, lazy=False):
@@ -187,15 +171,9 @@ class MapDatasetWrapper(Dataset):
         if lazy:
             self._transform_pipline.append(fn)
         else:
-            if not self.new_data:
-                self.new_data = [
-                    fn(self.data[idx]) for idx in range(len(self.data))
-                ]
-            else:
-                self.new_data = [
-                    fn(self.new_data[idx])
-                    for idx in range(len(self.new_data))
-                ]
+            self.new_data = [
+                fn(self.new_data[idx]) for idx in range(len(self.new_data))
+            ]
         return self
 
     def __getattr__(self, name):