提交 894bd217 编写于 作者: S smallv0221

refine code

上级 923a3aa7
......@@ -90,25 +90,20 @@ class MapDatasetWrapper(Dataset):
def __init__(self, data):
self.data = data
self._transform_pipline = []
self.new_data = []
self.new_data = self.data
def transform(self, data, pipline):
def _transform(self, data, pipline):
for fn in reversed(pipline):
data = fn(data)
return data
def __getitem__(self, idx):
if not self.new_data:
return self.transform(
self.data[idx], self._transform_pipline
) if self._transform_pipline else self.data[idx]
else:
return self.transform(
self.new_data[idx], self._transform_pipline
) if self._transform_pipline else self.new_data[idx]
return self._transform(
self.new_data[idx], self._transform_pipline
) if self._transform_pipline else self.new_data[idx]
def __len__(self):
return len(self.data) if not self.new_data else len(self.new_data)
return len(self.new_data)
def filter(self, fn):
"""
......@@ -120,16 +115,11 @@ class MapDatasetWrapper(Dataset):
Returns:
MapDatasetWrapper: The filtered dataset
"""
if not self.new_data:
self.new_data = [
self.data[idx] for idx in range(len(self.data))
if fn(self.data[idx])
]
else:
self.new_data = [
self.data[idx] for idx in range(len(self.new_data))
if fn(self.new_data[idx])
]
self.new_data = [
self.new_data[idx] for idx in range(len(self.new_data))
if fn(self.new_data[idx])
]
return self
def shard(self, num_shards=None, index=None):
......@@ -150,23 +140,17 @@ class MapDatasetWrapper(Dataset):
num_shards = dist.get_world_size()
if index is None:
index = dist.get_rank()
num_samples = int(math.ceil(len(self.data) * 1.0 / num_shards))
num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards))
total_size = num_samples * num_shards
# add extra samples to make it evenly divisible
if not self.new_data:
self.new_data = [
self.data[idx] for idx in range(len(self.data))
if idx % num_shards == index
]
if len(self.new_data) < num_samples:
self.new_data.append(self.data[index + 1 - num_shards])
else:
self.new_data = [
self.data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index
]
if len(self.new_data) < num_samples:
self.new_data.append(self.new_data[index + 1 - num_shards])
self.new_data = [
self.new_data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index
]
if len(self.new_data) < num_samples:
self.new_data.append(self.new_data[index + 1 - num_shards])
return self
def apply(self, fn, lazy=False):
......@@ -187,15 +171,9 @@ class MapDatasetWrapper(Dataset):
if lazy:
self._transform_pipline.append(fn)
else:
if not self.new_data:
self.new_data = [
fn(self.data[idx]) for idx in range(len(self.data))
]
else:
self.new_data = [
fn(self.new_data[idx])
for idx in range(len(self.new_data))
]
self.new_data = [
fn(self.new_data[idx]) for idx in range(len(self.new_data))
]
return self
def __getattr__(self, name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册