未验证 提交 1dbc2855 编写于 作者: S smallv0221 提交者: GitHub

upgrade MapDatasetWrapper (#5132)

* upgrade MapDatasetWrapper

* minor fix

* refine code
上级 61e6247b
...@@ -89,14 +89,21 @@ class MapDatasetWrapper(Dataset): ...@@ -89,14 +89,21 @@ class MapDatasetWrapper(Dataset):
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
self._transform_func = None self._transform_pipline = []
self.new_data = self.data
def _transform(self, data, pipline):
for fn in reversed(pipline):
data = fn(data)
return data
def __getitem__(self, idx): def __getitem__(self, idx):
return self._transform_func(self.data[ return self._transform(
idx]) if self._transform_func else self.data[idx] self.new_data[idx], self._transform_pipline
) if self._transform_pipline else self.new_data[idx]
def __len__(self): def __len__(self):
return len(self.data) return len(self.new_data)
def filter(self, fn): def filter(self, fn):
""" """
...@@ -108,11 +115,12 @@ class MapDatasetWrapper(Dataset): ...@@ -108,11 +115,12 @@ class MapDatasetWrapper(Dataset):
Returns: Returns:
MapDatasetWrapper: The filtered dataset MapDatasetWrapper: The filtered dataset
""" """
filted_data = [
self.data[idx] for idx in range(len(self.data)) self.new_data = [
if fn(self.data[idx]) self.new_data[idx] for idx in range(len(self.new_data))
if fn(self.new_data[idx])
] ]
return type(self)(filted_data) return self
def shard(self, num_shards=None, index=None): def shard(self, num_shards=None, index=None):
""" """
...@@ -132,16 +140,18 @@ class MapDatasetWrapper(Dataset): ...@@ -132,16 +140,18 @@ class MapDatasetWrapper(Dataset):
num_shards = dist.get_world_size() num_shards = dist.get_world_size()
if index is None: if index is None:
index = dist.get_rank() index = dist.get_rank()
num_samples = int(math.ceil(len(self.data) * 1.0 / num_shards))
num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
total_size = num_samples * num_shards total_size = num_samples * num_shards
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
sharded_data = [ self.new_data = [
self.data[idx] for idx in range(len(self.data)) self.new_data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index if idx % num_shards == index
] ]
if len(sharded_data) < num_samples: if len(self.new_data) < num_samples:
sharded_data.append(self.data[index + 1 - num_shards]) self.new_data.append(self.new_data[index + 1 - num_shards])
return type(self)(sharded_data)
return self
def apply(self, fn, lazy=False): def apply(self, fn, lazy=False):
""" """
...@@ -159,10 +169,11 @@ class MapDatasetWrapper(Dataset): ...@@ -159,10 +169,11 @@ class MapDatasetWrapper(Dataset):
otherwise bind `fn` as a property to transform on demand. otherwise bind `fn` as a property to transform on demand.
""" """
if lazy: if lazy:
self._transform_func = fn self._transform_pipline.append(fn)
else: else:
applied_data = [fn(self.data[idx]) for idx in range(len(self.data))] self.new_data = [
return type(self)(applied_data) fn(self.new_data[idx]) for idx in range(len(self.new_data))
]
return self return self
def __getattr__(self, name): def __getattr__(self, name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册