提交 2ddb6305 编写于 作者: S smallv0221

upgrade MapDatasetWrapper

上级 bec2933a
......@@ -89,14 +89,26 @@ class MapDatasetWrapper(Dataset):
def __init__(self, data):
self.data = data
self._transform_func = None
self._transform_pipline = []
self.new_data = []
def transform(self, data, pipline):
for fn in reversed(pipline):
data = fn(data)
return data
def __getitem__(self, idx):
return self._transform_func(self.data[
idx]) if self._transform_func else self.data[idx]
if not self.new_data:
return self.transform(
self.data[idx], self._transform_pipline
) if self._transform_pipline else self.data[idx]
else:
return self.transform(
self.new_data[idx], self._transform_pipline
) if self._transform_pipline else self.new_data[idx]
def __len__(self):
return len(self.data)
return len(self.data) if not self.new_data else len(self.new_data)
def filter(self, fn):
"""
......@@ -108,11 +120,17 @@ class MapDatasetWrapper(Dataset):
Returns:
MapDatasetWrapper: The filtered dataset
"""
filted_data = [
self.data[idx] for idx in range(len(self.data))
if fn(self.data[idx])
]
return type(self)(filted_data)
if not self.new_data:
self.new_data = [
self.data[idx] for idx in range(len(self.data))
if fn(self.data[idx])
]
else:
self.new_data = [
self.data[idx] for idx in range(len(self.new_data))
if fn(self.new_data[idx])
]
return self
def shard(self, num_shards=None, index=None):
"""
......@@ -135,13 +153,53 @@ class MapDatasetWrapper(Dataset):
num_samples = int(math.ceil(len(self.data) * 1.0 / num_shards))
total_size = num_samples * num_shards
# add extra samples to make it evenly divisible
sharded_data = [
self.data[idx] for idx in range(len(self.data))
if idx % num_shards == index
]
if len(sharded_data) < num_samples:
sharded_data.append(self.data[index + 1 - num_shards])
return type(self)(sharded_data)
if not self.new_data:
self.new_data = [
self.data[idx] for idx in range(len(self.data))
if idx % num_shards == index
]
if len(self.new_data) < num_samples:
self.new_data.append(self.data[index + 1 - num_shards])
else:
self.new_data = [
self.data[idx] for idx in range(len(self.new_data))
if idx % num_shards == index
]
if len(self.new_data) < num_samples:
self.new_data.append(self.new_data[index + 1 - num_shards])
return self
def apply(self, fn, lazy=False):
"""
Performs specific function on the dataset to transform every sample.
Args:
fn (callable): Transformations to be performed. It receives single
sample as argument rather than dataset.
lazy (bool, optional): If True, transformations would be delayed and
performed on demand. Otherwise, transforms all samples at once
and return a new MapDatasetWrapper instance. Note that if `fn` is
stochastic, `lazy` should be True or you will get the same
result on all epochs. Defalt: False.
Returns:
MapDatasetWrapper: A new MapDatasetWrapper instance if `lazy` is True, \
otherwise bind `fn` as a property to transform on demand.
"""
if lazy:
self._transform_pipline.append(fn)
else:
if not self.new_data:
self.new_data = [
fn(self.data[idx]) for idx in range(len(self.data))
]
else:
self.new_data = [
fn(self.new_data[idx])
for idx in range(len(self.new_data))
]
return self
def __getattr__(self, name):
return getattr(self.data, name)
def apply(self, fn, lazy=False):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册