提交 894bd217 编写于 作者: S smallv0221

refine code

上级 923a3aa7
...@@ -90,25 +90,20 @@ class MapDatasetWrapper(Dataset): ...@@ -90,25 +90,20 @@ class MapDatasetWrapper(Dataset):
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
self._transform_pipline = [] self._transform_pipline = []
self.new_data = [] self.new_data = self.data
def transform(self, data, pipline): def _transform(self, data, pipline):
for fn in reversed(pipline): for fn in reversed(pipline):
data = fn(data) data = fn(data)
return data return data
def __getitem__(self, idx): def __getitem__(self, idx):
if not self.new_data: return self._transform(
return self.transform(
self.data[idx], self._transform_pipline
) if self._transform_pipline else self.data[idx]
else:
return self.transform(
self.new_data[idx], self._transform_pipline self.new_data[idx], self._transform_pipline
) if self._transform_pipline else self.new_data[idx] ) if self._transform_pipline else self.new_data[idx]
def __len__(self): def __len__(self):
return len(self.data) if not self.new_data else len(self.new_data) return len(self.new_data)
def filter(self, fn): def filter(self, fn):
""" """
...@@ -120,14 +115,9 @@ class MapDatasetWrapper(Dataset): ...@@ -120,14 +115,9 @@ class MapDatasetWrapper(Dataset):
Returns: Returns:
MapDatasetWrapper: The filtered dataset MapDatasetWrapper: The filtered dataset
""" """
if not self.new_data:
self.new_data = [
self.data[idx] for idx in range(len(self.data))
if fn(self.data[idx])
]
else:
self.new_data = [ self.new_data = [
self.data[idx] for idx in range(len(self.new_data)) self.new_data[idx] for idx in range(len(self.new_data))
if fn(self.new_data[idx]) if fn(self.new_data[idx])
] ]
return self return self
...@@ -150,23 +140,17 @@ class MapDatasetWrapper(Dataset): ...@@ -150,23 +140,17 @@ class MapDatasetWrapper(Dataset):
num_shards = dist.get_world_size() num_shards = dist.get_world_size()
if index is None: if index is None:
index = dist.get_rank() index = dist.get_rank()
num_samples = int(math.ceil(len(self.data) * 1.0 / num_shards))
num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
total_size = num_samples * num_shards total_size = num_samples * num_shards
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
if not self.new_data:
self.new_data = [ self.new_data = [
self.data[idx] for idx in range(len(self.data)) self.new_data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index
]
if len(self.new_data) < num_samples:
self.new_data.append(self.data[index + 1 - num_shards])
else:
self.new_data = [
self.data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index if idx % num_shards == index
] ]
if len(self.new_data) < num_samples: if len(self.new_data) < num_samples:
self.new_data.append(self.new_data[index + 1 - num_shards]) self.new_data.append(self.new_data[index + 1 - num_shards])
return self return self
def apply(self, fn, lazy=False): def apply(self, fn, lazy=False):
...@@ -186,15 +170,9 @@ class MapDatasetWrapper(Dataset): ...@@ -186,15 +170,9 @@ class MapDatasetWrapper(Dataset):
""" """
if lazy: if lazy:
self._transform_pipline.append(fn) self._transform_pipline.append(fn)
else:
if not self.new_data:
self.new_data = [
fn(self.data[idx]) for idx in range(len(self.data))
]
else: else:
self.new_data = [ self.new_data = [
fn(self.new_data[idx]) fn(self.new_data[idx]) for idx in range(len(self.new_data))
for idx in range(len(self.new_data))
] ]
return self return self
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册