提交 3b71bd0d 编写于 作者: X xulei2020

rename input to children, output to parent

上级 5dac9c4c
......@@ -134,8 +134,8 @@ class Dataset:
"""
def __init__(self, num_parallel_workers=None):
self.input = []
self.output = []
self.children = []
self.parent = []
self.num_parallel_workers = num_parallel_workers
self._device_iter = 0
self._input_indexs = ()
......@@ -1006,9 +1006,9 @@ class Dataset:
dev_id = output_dataset.shard_id
return "", dev_id
if not output_dataset.input:
if not output_dataset.children:
raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
input_dataset = output_dataset.input[0]
input_dataset = output_dataset.children[0]
return get_distribution(input_dataset)
distribution_path, device_id = get_distribution(self)
......@@ -1129,8 +1129,8 @@ class Dataset:
Return:
Number, number of batches.
"""
if self.input:
return self.input[0].get_dataset_size()
if self.children:
return self.children[0].get_dataset_size()
return None
def num_classes(self):
......@@ -1140,23 +1140,23 @@ class Dataset:
Return:
Number, number of classes.
"""
if self.input:
return self.input[0].num_classes()
if self.children:
return self.children[0].num_classes()
return None
def get_sync_notifiers(self):
if self.input:
return self.input[0].get_sync_notifiers()
if self.children:
return self.children[0].get_sync_notifiers()
return {}
def disable_sync(self):
if self.input:
return self.input[0].disable_sync()
if self.children:
return self.children[0].disable_sync()
return {}
def is_sync(self):
if self.input:
return self.input[0].is_sync()
if self.children:
return self.children[0].is_sync()
return False
def sync_update(self, condition_name, num_batch=None, data=None):
......@@ -1190,8 +1190,8 @@ class Dataset:
Return:
Number, the number of data in a batch.
"""
if self.input:
return self.input[0].get_batch_size()
if self.children:
return self.children[0].get_batch_size()
return 1
def get_repeat_count(self):
......@@ -1201,8 +1201,8 @@ class Dataset:
Return:
Number, the count of repeat.
"""
if self.input:
return self.input[0].get_repeat_count()
if self.children:
return self.children[0].get_repeat_count()
return 1
def get_class_indexing(self):
......@@ -1212,22 +1212,22 @@ class Dataset:
Return:
Dict, A str-to-int mapping from label name to index.
"""
if self.input:
return self.input[0].get_class_indexing()
if self.children:
return self.children[0].get_class_indexing()
raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))
def reset(self):
"""Reset the dataset for next epoch."""
def is_shuffled(self):
for input_dataset in self.input:
for input_dataset in self.children:
if input_dataset.is_shuffled():
return True
return False
def is_sharded(self):
for input_dataset in self.input:
for input_dataset in self.children:
if input_dataset.is_sharded():
return True
......@@ -1466,8 +1466,8 @@ class BucketBatchByLengthDataset(DatasetOp):
self.pad_to_bucket_boundary = pad_to_bucket_boundary
self.drop_remainder = drop_remainder
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
......@@ -1529,8 +1529,8 @@ class BatchDataset(DatasetOp):
self.per_batch_map = per_batch_map
self.input_columns = input_columns
self.pad_info = pad_info
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
......@@ -1549,7 +1549,7 @@ class BatchDataset(DatasetOp):
Return:
Number, number of batches.
"""
child_size = self.input[0].get_dataset_size()
child_size = self.children[0].get_dataset_size()
if child_size is not None:
if self.drop_remainder:
return math.floor(child_size / self.batch_size)
......@@ -1578,7 +1578,7 @@ class BatchDataset(DatasetOp):
if isinstance(dataset, RepeatDataset):
return True
flag = False
for input_dataset in dataset.input:
for input_dataset in dataset.children:
flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
return flag
......@@ -1593,7 +1593,7 @@ class BatchDataset(DatasetOp):
"""
if isinstance(dataset, SyncWaitDataset):
dataset.update_sync_batch_size(batch_size)
for input_dataset in dataset.input:
for input_dataset in dataset.children:
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
......@@ -1699,21 +1699,21 @@ class SyncWaitDataset(DatasetOp):
def __init__(self, input_dataset, condition_name, num_batch, callback=None):
super().__init__()
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
# set to the default value, waiting for the batch to update it
self._condition_name = condition_name
if isinstance(num_batch, int) and num_batch <= 0:
raise ValueError("num_batch need to be greater than 0.")
self._pair = BlockReleasePair(num_batch, callback)
if self._condition_name in self.input[0].get_sync_notifiers():
if self._condition_name in self.children[0].get_sync_notifiers():
raise RuntimeError("Condition name is already in use")
logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging",
condition_name)
def get_sync_notifiers(self):
return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
def is_sync(self):
return True
......@@ -1746,7 +1746,7 @@ class SyncWaitDataset(DatasetOp):
if isinstance(dataset, BatchDataset):
return True
flag = False
for input_dataset in dataset.input:
for input_dataset in dataset.children:
flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
return flag
......@@ -1766,9 +1766,9 @@ class ShuffleDataset(DatasetOp):
def __init__(self, input_dataset, buffer_size):
super().__init__()
self.buffer_size = buffer_size
self.input.append(input_dataset)
self.children.append(input_dataset)
self.reshuffle_each_epoch = None
input_dataset.output.append(self)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
if self.is_sync():
raise RuntimeError("No shuffle after sync operators")
......@@ -1864,7 +1864,7 @@ class MapDataset(DatasetOp):
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False):
super().__init__(num_parallel_workers)
self.input.append(input_dataset)
self.children.append(input_dataset)
if input_columns is not None and not isinstance(input_columns, list):
input_columns = [input_columns]
self.input_columns = input_columns
......@@ -1881,7 +1881,7 @@ class MapDataset(DatasetOp):
and self.columns_order is None:
raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.")
input_dataset.output.append(self)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
self.python_multiprocessing = python_multiprocessing
self.process_pool = None
......@@ -1901,7 +1901,7 @@ class MapDataset(DatasetOp):
Return:
Number, number of batches.
"""
return self.input[0].get_dataset_size()
return self.children[0].get_dataset_size()
def __deepcopy__(self, memodict):
if id(self) in memodict:
......@@ -1909,12 +1909,12 @@ class MapDataset(DatasetOp):
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict)
new_op.children = copy.deepcopy(self.children, memodict)
new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.output = copy.deepcopy(self.output, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
new_op.operations = self.operations
......@@ -1975,8 +1975,8 @@ class FilterDataset(DatasetOp):
def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
super().__init__(num_parallel_workers)
self.predicate = lambda *args: bool(predicate(*args))
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
if input_columns is not None and not isinstance(input_columns, list):
input_columns = [input_columns]
self.input_columns = input_columns
......@@ -2012,8 +2012,8 @@ class RepeatDataset(DatasetOp):
self.count = -1
else:
self.count = count
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
......@@ -2028,7 +2028,7 @@ class RepeatDataset(DatasetOp):
Return:
Number, number of batches.
"""
child_size = self.input[0].get_dataset_size()
child_size = self.children[0].get_dataset_size()
if child_size is not None:
return child_size
return None
......@@ -2055,8 +2055,8 @@ class SkipDataset(DatasetOp):
def __init__(self, input_dataset, count):
super().__init__()
self.count = count
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
......@@ -2071,7 +2071,7 @@ class SkipDataset(DatasetOp):
Return:
Number, number of batches.
"""
child_size = self.input[0].get_dataset_size()
child_size = self.children[0].get_dataset_size()
output_size = 0
if self.count >= 0 and self.count < child_size:
output_size = child_size - self.count
......@@ -2090,8 +2090,8 @@ class TakeDataset(DatasetOp):
def __init__(self, input_dataset, count):
super().__init__()
self.count = count
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
......@@ -2106,7 +2106,7 @@ class TakeDataset(DatasetOp):
Return:
Number, number of batches.
"""
child_size = self.input[0].get_dataset_size()
child_size = self.children[0].get_dataset_size()
if child_size < self.count:
return child_size
return self.count
......@@ -2130,8 +2130,8 @@ class ZipDataset(DatasetOp):
raise TypeError("The parameter %s of zip has type error!" % (dataset))
self.datasets = datasets
for data in datasets:
self.input.append(data)
data.output.append(self)
self.children.append(data)
data.parent.append(self)
def get_dataset_size(self):
"""
......@@ -2140,7 +2140,7 @@ class ZipDataset(DatasetOp):
Return:
Number, number of batches.
"""
children_sizes = [c.get_dataset_size() for c in self.input]
children_sizes = [c.get_dataset_size() for c in self.children]
if all(c is not None for c in children_sizes):
return min(children_sizes)
return None
......@@ -2155,7 +2155,7 @@ class ZipDataset(DatasetOp):
return None
def is_sync(self):
return any([c.is_sync() for c in self.input])
return any([c.is_sync() for c in self.children])
def get_args(self):
args = super().get_args()
......@@ -2180,8 +2180,8 @@ class ConcatDataset(DatasetOp):
raise TypeError("The parameter %s of concat has type error!" % (dataset))
self.datasets = datasets
for data in datasets:
self.input.append(data)
data.output.append(self)
self.children.append(data)
data.parent.append(self)
def get_dataset_size(self):
"""
......@@ -2190,7 +2190,7 @@ class ConcatDataset(DatasetOp):
Return:
Number, number of batches.
"""
children_sizes = [c.get_dataset_size() for c in self.input]
children_sizes = [c.get_dataset_size() for c in self.children]
dataset_size = sum(children_sizes)
return dataset_size
......@@ -2213,8 +2213,8 @@ class RenameDataset(DatasetOp):
output_columns = [output_columns]
self.input_column_names = input_columns
self.output_column_names = output_columns
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
......@@ -2240,10 +2240,10 @@ class ProjectDataset(DatasetOp):
if not isinstance(columns, list):
columns = [columns]
self.columns = columns
self.input.append(input_dataset)
self.children.append(input_dataset)
self.prefetch_size = prefetch_size
input_dataset.output.append(self)
input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
......@@ -2267,8 +2267,8 @@ class TransferDataset(DatasetOp):
def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
super().__init__()
self.input.append(input_dataset)
input_dataset.output.append(self)
self.children.append(input_dataset)
input_dataset.parent.append(self)
self.queue_name = queue_name
self._input_indexs = input_dataset.input_indexs
self._device_type = device_type
......@@ -3170,8 +3170,8 @@ class GeneratorDataset(MappableDataset):
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict)
new_op.output = copy.deepcopy(self.output, memodict)
new_op.children = copy.deepcopy(self.children, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.column_types = copy.deepcopy(self.column_types, memodict)
new_op.column_names = copy.deepcopy(self.column_names, memodict)
......@@ -4879,14 +4879,14 @@ class BuildVocabDataset(DatasetOp):
prefetch_size=None):
super().__init__()
self.columns = columns
self.input.append(input_dataset)
self.children.append(input_dataset)
self.prefetch_size = prefetch_size
self.vocab = vocab
self.freq_range = freq_range
self.top_k = top_k
self.special_tokens = special_tokens
self.special_first = special_first
input_dataset.output.append(self)
input_dataset.parent.append(self)
def get_args(self):
args = super().get_args()
......@@ -4905,11 +4905,11 @@ class BuildVocabDataset(DatasetOp):
cls = self.__class__
new_op = cls.__new__(cls)
memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict)
new_op.children = copy.deepcopy(self.children, memodict)
new_op.columns = copy.deepcopy(self.columns, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
new_op.output = copy.deepcopy(self.output, memodict)
new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
new_op.top_k = copy.deepcopy(self.top_k, memodict)
new_op.vocab = self.vocab
......
......@@ -38,13 +38,13 @@ def _cleanup():
def alter_tree(node):
"""Traversing the python Dataset tree/graph to perform some alteration to some specific nodes."""
if not node.input:
if not node.children:
return _alter_node(node)
converted_children = []
for input_op in node.input:
for input_op in node.children:
converted_children.append(alter_tree(input_op))
node.input = converted_children
node.children = converted_children
return _alter_node(node)
......@@ -86,14 +86,14 @@ class Iterator:
def __is_tree_node(self, node):
"""Check if a node is tree node."""
if not node.input:
if len(node.output) > 1:
if not node.children:
if len(node.parent) > 1:
return False
if len(node.output) > 1:
if len(node.parent) > 1:
return False
for input_node in node.input:
for input_node in node.children:
cls = self.__is_tree_node(input_node)
if not cls:
return False
......@@ -174,7 +174,7 @@ class Iterator:
op_type = self.__get_dataset_type(node)
c_node = self.depipeline.AddNodeToTree(op_type, node.get_args())
for py_child in node.input:
for py_child in node.children:
c_child = self.__convert_node_postorder(py_child)
self.depipeline.AddChildToParentNode(c_child, c_node)
......@@ -184,7 +184,7 @@ class Iterator:
"""Recursively get batch node in the dataset tree."""
if isinstance(dataset, de.BatchDataset):
return
for input_op in dataset.input:
for input_op in dataset.children:
self.__batch_node(input_op, level + 1)
@staticmethod
......@@ -194,11 +194,11 @@ class Iterator:
ptr = hex(id(dataset))
for _ in range(level):
logger.info("\t", end='')
if not dataset.input:
if not dataset.children:
logger.info("-%s (%s)", name, ptr)
else:
logger.info("+%s (%s)", name, ptr)
for input_op in dataset.input:
for input_op in dataset.children:
Iterator.__print_local(input_op, level + 1)
def print(self):
......
......@@ -182,11 +182,11 @@ def traverse(node):
node_repr['shard_id'] = None
# Leaf node doesn't have input attribute.
if not node.input:
if not node.children:
return node_repr
# Recursively traverse the child and assign it to the current node_repr['children'].
for child in node.input:
for child in node.children:
node_repr["children"].append(traverse(child))
return node_repr
......@@ -226,11 +226,11 @@ def construct_pipeline(node):
# Instantiate python Dataset object based on the current dictionary element
dataset = create_node(node)
# Initially it is not connected to any other object.
dataset.input = []
dataset.children = []
# Construct the children too and add edge between the children and parent.
for child in node['children']:
dataset.input.append(construct_pipeline(child))
dataset.children.append(construct_pipeline(child))
return dataset
......
......@@ -103,7 +103,7 @@ def test_tree_copy():
itr = data1.create_tuple_iterator()
assert id(data1) != id(itr.dataset)
assert id(data) != id(itr.dataset.input[0])
assert id(data) != id(itr.dataset.children[0])
assert id(data1.operations[0]) == id(itr.dataset.operations[0])
itr.release()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册