提交 894bd217 · 作者: smallv0221

refine code

上级 923a3aa7
...@@ -90,25 +90,20 @@ class MapDatasetWrapper(Dataset): ...@@ -90,25 +90,20 @@ class MapDatasetWrapper(Dataset):
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
self._transform_pipline = [] self._transform_pipline = []
self.new_data = [] self.new_data = self.data
def transform(self, data, pipline): def _transform(self, data, pipline):
for fn in reversed(pipline): for fn in reversed(pipline):
data = fn(data) data = fn(data)
return data return data
def __getitem__(self, idx): def __getitem__(self, idx):
if not self.new_data: return self._transform(
return self.transform( self.new_data[idx], self._transform_pipline
self.data[idx], self._transform_pipline ) if self._transform_pipline else self.new_data[idx]
) if self._transform_pipline else self.data[idx]
else:
return self.transform(
self.new_data[idx], self._transform_pipline
) if self._transform_pipline else self.new_data[idx]
def __len__(self): def __len__(self):
return len(self.data) if not self.new_data else len(self.new_data) return len(self.new_data)
def filter(self, fn): def filter(self, fn):
""" """
...@@ -120,16 +115,11 @@ class MapDatasetWrapper(Dataset): ...@@ -120,16 +115,11 @@ class MapDatasetWrapper(Dataset):
Returns: Returns:
MapDatasetWrapper: The filtered dataset MapDatasetWrapper: The filtered dataset
""" """
if not self.new_data:
self.new_data = [ self.new_data = [
self.data[idx] for idx in range(len(self.data)) self.new_data[idx] for idx in range(len(self.new_data))
if fn(self.data[idx]) if fn(self.new_data[idx])
] ]
else:
self.new_data = [
self.data[idx] for idx in range(len(self.new_data))
if fn(self.new_data[idx])
]
return self return self
def shard(self, num_shards=None, index=None): def shard(self, num_shards=None, index=None):
...@@ -150,23 +140,17 @@ class MapDatasetWrapper(Dataset): ...@@ -150,23 +140,17 @@ class MapDatasetWrapper(Dataset):
num_shards = dist.get_world_size() num_shards = dist.get_world_size()
if index is None: if index is None:
index = dist.get_rank() index = dist.get_rank()
num_samples = int(math.ceil(len(self.data) * 1.0 / num_shards))
num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
total_size = num_samples * num_shards total_size = num_samples * num_shards
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
if not self.new_data: self.new_data = [
self.new_data = [ self.new_data[idx] for idx in range(len(self.new_data))
self.data[idx] for idx in range(len(self.data)) if idx % num_shards == index
if idx % num_shards == index ]
] if len(self.new_data) < num_samples:
if len(self.new_data) < num_samples: self.new_data.append(self.new_data[index + 1 - num_shards])
self.new_data.append(self.data[index + 1 - num_shards])
else:
self.new_data = [
self.data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index
]
if len(self.new_data) < num_samples:
self.new_data.append(self.new_data[index + 1 - num_shards])
return self return self
def apply(self, fn, lazy=False): def apply(self, fn, lazy=False):
...@@ -187,15 +171,9 @@ class MapDatasetWrapper(Dataset): ...@@ -187,15 +171,9 @@ class MapDatasetWrapper(Dataset):
if lazy: if lazy:
self._transform_pipline.append(fn) self._transform_pipline.append(fn)
else: else:
if not self.new_data: self.new_data = [
self.new_data = [ fn(self.new_data[idx]) for idx in range(len(self.new_data))
fn(self.data[idx]) for idx in range(len(self.data)) ]
]
else:
self.new_data = [
fn(self.new_data[idx])
for idx in range(len(self.new_data))
]
return self return self
def __getattr__(self, name): def __getattr__(self, name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册