提交 3b71bd0d 编写于 作者: X xulei2020

Rename the Dataset attributes `input` to `children` and `output` to `parent`

上级 5dac9c4c
...@@ -134,8 +134,8 @@ class Dataset: ...@@ -134,8 +134,8 @@ class Dataset:
""" """
def __init__(self, num_parallel_workers=None): def __init__(self, num_parallel_workers=None):
self.input = [] self.children = []
self.output = [] self.parent = []
self.num_parallel_workers = num_parallel_workers self.num_parallel_workers = num_parallel_workers
self._device_iter = 0 self._device_iter = 0
self._input_indexs = () self._input_indexs = ()
...@@ -1006,9 +1006,9 @@ class Dataset: ...@@ -1006,9 +1006,9 @@ class Dataset:
dev_id = output_dataset.shard_id dev_id = output_dataset.shard_id
return "", dev_id return "", dev_id
if not output_dataset.input: if not output_dataset.children:
raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset))) raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
input_dataset = output_dataset.input[0] input_dataset = output_dataset.children[0]
return get_distribution(input_dataset) return get_distribution(input_dataset)
distribution_path, device_id = get_distribution(self) distribution_path, device_id = get_distribution(self)
...@@ -1129,8 +1129,8 @@ class Dataset: ...@@ -1129,8 +1129,8 @@ class Dataset:
Return: Return:
Number, number of batches. Number, number of batches.
""" """
if self.input: if self.children:
return self.input[0].get_dataset_size() return self.children[0].get_dataset_size()
return None return None
def num_classes(self): def num_classes(self):
...@@ -1140,23 +1140,23 @@ class Dataset: ...@@ -1140,23 +1140,23 @@ class Dataset:
Return: Return:
Number, number of classes. Number, number of classes.
""" """
if self.input: if self.children:
return self.input[0].num_classes() return self.children[0].num_classes()
return None return None
def get_sync_notifiers(self): def get_sync_notifiers(self):
if self.input: if self.children:
return self.input[0].get_sync_notifiers() return self.children[0].get_sync_notifiers()
return {} return {}
def disable_sync(self): def disable_sync(self):
if self.input: if self.children:
return self.input[0].disable_sync() return self.children[0].disable_sync()
return {} return {}
def is_sync(self): def is_sync(self):
if self.input: if self.children:
return self.input[0].is_sync() return self.children[0].is_sync()
return False return False
def sync_update(self, condition_name, num_batch=None, data=None): def sync_update(self, condition_name, num_batch=None, data=None):
...@@ -1190,8 +1190,8 @@ class Dataset: ...@@ -1190,8 +1190,8 @@ class Dataset:
Return: Return:
Number, the number of data in a batch. Number, the number of data in a batch.
""" """
if self.input: if self.children:
return self.input[0].get_batch_size() return self.children[0].get_batch_size()
return 1 return 1
def get_repeat_count(self): def get_repeat_count(self):
...@@ -1201,8 +1201,8 @@ class Dataset: ...@@ -1201,8 +1201,8 @@ class Dataset:
Return: Return:
Number, the count of repeat. Number, the count of repeat.
""" """
if self.input: if self.children:
return self.input[0].get_repeat_count() return self.children[0].get_repeat_count()
return 1 return 1
def get_class_indexing(self): def get_class_indexing(self):
...@@ -1212,22 +1212,22 @@ class Dataset: ...@@ -1212,22 +1212,22 @@ class Dataset:
Return: Return:
Dict, A str-to-int mapping from label name to index. Dict, A str-to-int mapping from label name to index.
""" """
if self.input: if self.children:
return self.input[0].get_class_indexing() return self.children[0].get_class_indexing()
raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self))) raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))
def reset(self): def reset(self):
"""Reset the dataset for next epoch.""" """Reset the dataset for next epoch."""
def is_shuffled(self): def is_shuffled(self):
for input_dataset in self.input: for input_dataset in self.children:
if input_dataset.is_shuffled(): if input_dataset.is_shuffled():
return True return True
return False return False
def is_sharded(self): def is_sharded(self):
for input_dataset in self.input: for input_dataset in self.children:
if input_dataset.is_sharded(): if input_dataset.is_sharded():
return True return True
...@@ -1466,8 +1466,8 @@ class BucketBatchByLengthDataset(DatasetOp): ...@@ -1466,8 +1466,8 @@ class BucketBatchByLengthDataset(DatasetOp):
self.pad_to_bucket_boundary = pad_to_bucket_boundary self.pad_to_bucket_boundary = pad_to_bucket_boundary
self.drop_remainder = drop_remainder self.drop_remainder = drop_remainder
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
def get_args(self): def get_args(self):
...@@ -1529,8 +1529,8 @@ class BatchDataset(DatasetOp): ...@@ -1529,8 +1529,8 @@ class BatchDataset(DatasetOp):
self.per_batch_map = per_batch_map self.per_batch_map = per_batch_map
self.input_columns = input_columns self.input_columns = input_columns
self.pad_info = pad_info self.pad_info = pad_info
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
def get_args(self): def get_args(self):
...@@ -1549,7 +1549,7 @@ class BatchDataset(DatasetOp): ...@@ -1549,7 +1549,7 @@ class BatchDataset(DatasetOp):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
child_size = self.input[0].get_dataset_size() child_size = self.children[0].get_dataset_size()
if child_size is not None: if child_size is not None:
if self.drop_remainder: if self.drop_remainder:
return math.floor(child_size / self.batch_size) return math.floor(child_size / self.batch_size)
...@@ -1578,7 +1578,7 @@ class BatchDataset(DatasetOp): ...@@ -1578,7 +1578,7 @@ class BatchDataset(DatasetOp):
if isinstance(dataset, RepeatDataset): if isinstance(dataset, RepeatDataset):
return True return True
flag = False flag = False
for input_dataset in dataset.input: for input_dataset in dataset.children:
flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset) flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
return flag return flag
...@@ -1593,7 +1593,7 @@ class BatchDataset(DatasetOp): ...@@ -1593,7 +1593,7 @@ class BatchDataset(DatasetOp):
""" """
if isinstance(dataset, SyncWaitDataset): if isinstance(dataset, SyncWaitDataset):
dataset.update_sync_batch_size(batch_size) dataset.update_sync_batch_size(batch_size)
for input_dataset in dataset.input: for input_dataset in dataset.children:
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
...@@ -1699,21 +1699,21 @@ class SyncWaitDataset(DatasetOp): ...@@ -1699,21 +1699,21 @@ class SyncWaitDataset(DatasetOp):
def __init__(self, input_dataset, condition_name, num_batch, callback=None): def __init__(self, input_dataset, condition_name, num_batch, callback=None):
super().__init__() super().__init__()
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
# set to the default value, waiting for the batch to update it # set to the default value, waiting for the batch to update it
self._condition_name = condition_name self._condition_name = condition_name
if isinstance(num_batch, int) and num_batch <= 0: if isinstance(num_batch, int) and num_batch <= 0:
raise ValueError("num_batch need to be greater than 0.") raise ValueError("num_batch need to be greater than 0.")
self._pair = BlockReleasePair(num_batch, callback) self._pair = BlockReleasePair(num_batch, callback)
if self._condition_name in self.input[0].get_sync_notifiers(): if self._condition_name in self.children[0].get_sync_notifiers():
raise RuntimeError("Condition name is already in use") raise RuntimeError("Condition name is already in use")
logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging", logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging",
condition_name) condition_name)
def get_sync_notifiers(self): def get_sync_notifiers(self):
return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
def is_sync(self): def is_sync(self):
return True return True
...@@ -1746,7 +1746,7 @@ class SyncWaitDataset(DatasetOp): ...@@ -1746,7 +1746,7 @@ class SyncWaitDataset(DatasetOp):
if isinstance(dataset, BatchDataset): if isinstance(dataset, BatchDataset):
return True return True
flag = False flag = False
for input_dataset in dataset.input: for input_dataset in dataset.children:
flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
return flag return flag
...@@ -1766,9 +1766,9 @@ class ShuffleDataset(DatasetOp): ...@@ -1766,9 +1766,9 @@ class ShuffleDataset(DatasetOp):
def __init__(self, input_dataset, buffer_size): def __init__(self, input_dataset, buffer_size):
super().__init__() super().__init__()
self.buffer_size = buffer_size self.buffer_size = buffer_size
self.input.append(input_dataset) self.children.append(input_dataset)
self.reshuffle_each_epoch = None self.reshuffle_each_epoch = None
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
if self.is_sync(): if self.is_sync():
raise RuntimeError("No shuffle after sync operators") raise RuntimeError("No shuffle after sync operators")
...@@ -1864,7 +1864,7 @@ class MapDataset(DatasetOp): ...@@ -1864,7 +1864,7 @@ class MapDataset(DatasetOp):
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False): num_parallel_workers=None, python_multiprocessing=False):
super().__init__(num_parallel_workers) super().__init__(num_parallel_workers)
self.input.append(input_dataset) self.children.append(input_dataset)
if input_columns is not None and not isinstance(input_columns, list): if input_columns is not None and not isinstance(input_columns, list):
input_columns = [input_columns] input_columns = [input_columns]
self.input_columns = input_columns self.input_columns = input_columns
...@@ -1881,7 +1881,7 @@ class MapDataset(DatasetOp): ...@@ -1881,7 +1881,7 @@ class MapDataset(DatasetOp):
and self.columns_order is None: and self.columns_order is None:
raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.") raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.")
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
self.python_multiprocessing = python_multiprocessing self.python_multiprocessing = python_multiprocessing
self.process_pool = None self.process_pool = None
...@@ -1901,7 +1901,7 @@ class MapDataset(DatasetOp): ...@@ -1901,7 +1901,7 @@ class MapDataset(DatasetOp):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
return self.input[0].get_dataset_size() return self.children[0].get_dataset_size()
def __deepcopy__(self, memodict): def __deepcopy__(self, memodict):
if id(self) in memodict: if id(self) in memodict:
...@@ -1909,12 +1909,12 @@ class MapDataset(DatasetOp): ...@@ -1909,12 +1909,12 @@ class MapDataset(DatasetOp):
cls = self.__class__ cls = self.__class__
new_op = cls.__new__(cls) new_op = cls.__new__(cls)
memodict[id(self)] = new_op memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict) new_op.children = copy.deepcopy(self.children, memodict)
new_op.input_columns = copy.deepcopy(self.input_columns, memodict) new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
new_op.output_columns = copy.deepcopy(self.output_columns, memodict) new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
new_op.columns_order = copy.deepcopy(self.columns_order, memodict) new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.output = copy.deepcopy(self.output, memodict) new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
new_op.operations = self.operations new_op.operations = self.operations
...@@ -1975,8 +1975,8 @@ class FilterDataset(DatasetOp): ...@@ -1975,8 +1975,8 @@ class FilterDataset(DatasetOp):
def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None): def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
super().__init__(num_parallel_workers) super().__init__(num_parallel_workers)
self.predicate = lambda *args: bool(predicate(*args)) self.predicate = lambda *args: bool(predicate(*args))
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
if input_columns is not None and not isinstance(input_columns, list): if input_columns is not None and not isinstance(input_columns, list):
input_columns = [input_columns] input_columns = [input_columns]
self.input_columns = input_columns self.input_columns = input_columns
...@@ -2012,8 +2012,8 @@ class RepeatDataset(DatasetOp): ...@@ -2012,8 +2012,8 @@ class RepeatDataset(DatasetOp):
self.count = -1 self.count = -1
else: else:
self.count = count self.count = count
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
def get_args(self): def get_args(self):
...@@ -2028,7 +2028,7 @@ class RepeatDataset(DatasetOp): ...@@ -2028,7 +2028,7 @@ class RepeatDataset(DatasetOp):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
child_size = self.input[0].get_dataset_size() child_size = self.children[0].get_dataset_size()
if child_size is not None: if child_size is not None:
return child_size return child_size
return None return None
...@@ -2055,8 +2055,8 @@ class SkipDataset(DatasetOp): ...@@ -2055,8 +2055,8 @@ class SkipDataset(DatasetOp):
def __init__(self, input_dataset, count): def __init__(self, input_dataset, count):
super().__init__() super().__init__()
self.count = count self.count = count
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
def get_args(self): def get_args(self):
...@@ -2071,7 +2071,7 @@ class SkipDataset(DatasetOp): ...@@ -2071,7 +2071,7 @@ class SkipDataset(DatasetOp):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
child_size = self.input[0].get_dataset_size() child_size = self.children[0].get_dataset_size()
output_size = 0 output_size = 0
if self.count >= 0 and self.count < child_size: if self.count >= 0 and self.count < child_size:
output_size = child_size - self.count output_size = child_size - self.count
...@@ -2090,8 +2090,8 @@ class TakeDataset(DatasetOp): ...@@ -2090,8 +2090,8 @@ class TakeDataset(DatasetOp):
def __init__(self, input_dataset, count): def __init__(self, input_dataset, count):
super().__init__() super().__init__()
self.count = count self.count = count
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
def get_args(self): def get_args(self):
...@@ -2106,7 +2106,7 @@ class TakeDataset(DatasetOp): ...@@ -2106,7 +2106,7 @@ class TakeDataset(DatasetOp):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
child_size = self.input[0].get_dataset_size() child_size = self.children[0].get_dataset_size()
if child_size < self.count: if child_size < self.count:
return child_size return child_size
return self.count return self.count
...@@ -2130,8 +2130,8 @@ class ZipDataset(DatasetOp): ...@@ -2130,8 +2130,8 @@ class ZipDataset(DatasetOp):
raise TypeError("The parameter %s of zip has type error!" % (dataset)) raise TypeError("The parameter %s of zip has type error!" % (dataset))
self.datasets = datasets self.datasets = datasets
for data in datasets: for data in datasets:
self.input.append(data) self.children.append(data)
data.output.append(self) data.parent.append(self)
def get_dataset_size(self): def get_dataset_size(self):
""" """
...@@ -2140,7 +2140,7 @@ class ZipDataset(DatasetOp): ...@@ -2140,7 +2140,7 @@ class ZipDataset(DatasetOp):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
children_sizes = [c.get_dataset_size() for c in self.input] children_sizes = [c.get_dataset_size() for c in self.children]
if all(c is not None for c in children_sizes): if all(c is not None for c in children_sizes):
return min(children_sizes) return min(children_sizes)
return None return None
...@@ -2155,7 +2155,7 @@ class ZipDataset(DatasetOp): ...@@ -2155,7 +2155,7 @@ class ZipDataset(DatasetOp):
return None return None
def is_sync(self): def is_sync(self):
return any([c.is_sync() for c in self.input]) return any([c.is_sync() for c in self.children])
def get_args(self): def get_args(self):
args = super().get_args() args = super().get_args()
...@@ -2180,8 +2180,8 @@ class ConcatDataset(DatasetOp): ...@@ -2180,8 +2180,8 @@ class ConcatDataset(DatasetOp):
raise TypeError("The parameter %s of concat has type error!" % (dataset)) raise TypeError("The parameter %s of concat has type error!" % (dataset))
self.datasets = datasets self.datasets = datasets
for data in datasets: for data in datasets:
self.input.append(data) self.children.append(data)
data.output.append(self) data.parent.append(self)
def get_dataset_size(self): def get_dataset_size(self):
""" """
...@@ -2190,7 +2190,7 @@ class ConcatDataset(DatasetOp): ...@@ -2190,7 +2190,7 @@ class ConcatDataset(DatasetOp):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
children_sizes = [c.get_dataset_size() for c in self.input] children_sizes = [c.get_dataset_size() for c in self.children]
dataset_size = sum(children_sizes) dataset_size = sum(children_sizes)
return dataset_size return dataset_size
...@@ -2213,8 +2213,8 @@ class RenameDataset(DatasetOp): ...@@ -2213,8 +2213,8 @@ class RenameDataset(DatasetOp):
output_columns = [output_columns] output_columns = [output_columns]
self.input_column_names = input_columns self.input_column_names = input_columns
self.output_column_names = output_columns self.output_column_names = output_columns
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
def get_args(self): def get_args(self):
...@@ -2240,10 +2240,10 @@ class ProjectDataset(DatasetOp): ...@@ -2240,10 +2240,10 @@ class ProjectDataset(DatasetOp):
if not isinstance(columns, list): if not isinstance(columns, list):
columns = [columns] columns = [columns]
self.columns = columns self.columns = columns
self.input.append(input_dataset) self.children.append(input_dataset)
self.prefetch_size = prefetch_size self.prefetch_size = prefetch_size
input_dataset.output.append(self) input_dataset.parent.append(self)
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
def get_args(self): def get_args(self):
...@@ -2267,8 +2267,8 @@ class TransferDataset(DatasetOp): ...@@ -2267,8 +2267,8 @@ class TransferDataset(DatasetOp):
def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None): def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
super().__init__() super().__init__()
self.input.append(input_dataset) self.children.append(input_dataset)
input_dataset.output.append(self) input_dataset.parent.append(self)
self.queue_name = queue_name self.queue_name = queue_name
self._input_indexs = input_dataset.input_indexs self._input_indexs = input_dataset.input_indexs
self._device_type = device_type self._device_type = device_type
...@@ -3170,8 +3170,8 @@ class GeneratorDataset(MappableDataset): ...@@ -3170,8 +3170,8 @@ class GeneratorDataset(MappableDataset):
cls = self.__class__ cls = self.__class__
new_op = cls.__new__(cls) new_op = cls.__new__(cls)
memodict[id(self)] = new_op memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict) new_op.children = copy.deepcopy(self.children, memodict)
new_op.output = copy.deepcopy(self.output, memodict) new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.column_types = copy.deepcopy(self.column_types, memodict) new_op.column_types = copy.deepcopy(self.column_types, memodict)
new_op.column_names = copy.deepcopy(self.column_names, memodict) new_op.column_names = copy.deepcopy(self.column_names, memodict)
...@@ -4879,14 +4879,14 @@ class BuildVocabDataset(DatasetOp): ...@@ -4879,14 +4879,14 @@ class BuildVocabDataset(DatasetOp):
prefetch_size=None): prefetch_size=None):
super().__init__() super().__init__()
self.columns = columns self.columns = columns
self.input.append(input_dataset) self.children.append(input_dataset)
self.prefetch_size = prefetch_size self.prefetch_size = prefetch_size
self.vocab = vocab self.vocab = vocab
self.freq_range = freq_range self.freq_range = freq_range
self.top_k = top_k self.top_k = top_k
self.special_tokens = special_tokens self.special_tokens = special_tokens
self.special_first = special_first self.special_first = special_first
input_dataset.output.append(self) input_dataset.parent.append(self)
def get_args(self): def get_args(self):
args = super().get_args() args = super().get_args()
...@@ -4905,11 +4905,11 @@ class BuildVocabDataset(DatasetOp): ...@@ -4905,11 +4905,11 @@ class BuildVocabDataset(DatasetOp):
cls = self.__class__ cls = self.__class__
new_op = cls.__new__(cls) new_op = cls.__new__(cls)
memodict[id(self)] = new_op memodict[id(self)] = new_op
new_op.input = copy.deepcopy(self.input, memodict) new_op.children = copy.deepcopy(self.children, memodict)
new_op.columns = copy.deepcopy(self.columns, memodict) new_op.columns = copy.deepcopy(self.columns, memodict)
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
new_op.output = copy.deepcopy(self.output, memodict) new_op.parent = copy.deepcopy(self.parent, memodict)
new_op.freq_range = copy.deepcopy(self.freq_range, memodict) new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
new_op.top_k = copy.deepcopy(self.top_k, memodict) new_op.top_k = copy.deepcopy(self.top_k, memodict)
new_op.vocab = self.vocab new_op.vocab = self.vocab
......
...@@ -38,13 +38,13 @@ def _cleanup(): ...@@ -38,13 +38,13 @@ def _cleanup():
def alter_tree(node): def alter_tree(node):
"""Traversing the python Dataset tree/graph to perform some alteration to some specific nodes.""" """Traversing the python Dataset tree/graph to perform some alteration to some specific nodes."""
if not node.input: if not node.children:
return _alter_node(node) return _alter_node(node)
converted_children = [] converted_children = []
for input_op in node.input: for input_op in node.children:
converted_children.append(alter_tree(input_op)) converted_children.append(alter_tree(input_op))
node.input = converted_children node.children = converted_children
return _alter_node(node) return _alter_node(node)
...@@ -86,14 +86,14 @@ class Iterator: ...@@ -86,14 +86,14 @@ class Iterator:
def __is_tree_node(self, node): def __is_tree_node(self, node):
"""Check if a node is tree node.""" """Check if a node is tree node."""
if not node.input: if not node.children:
if len(node.output) > 1: if len(node.parent) > 1:
return False return False
if len(node.output) > 1: if len(node.parent) > 1:
return False return False
for input_node in node.input: for input_node in node.children:
cls = self.__is_tree_node(input_node) cls = self.__is_tree_node(input_node)
if not cls: if not cls:
return False return False
...@@ -174,7 +174,7 @@ class Iterator: ...@@ -174,7 +174,7 @@ class Iterator:
op_type = self.__get_dataset_type(node) op_type = self.__get_dataset_type(node)
c_node = self.depipeline.AddNodeToTree(op_type, node.get_args()) c_node = self.depipeline.AddNodeToTree(op_type, node.get_args())
for py_child in node.input: for py_child in node.children:
c_child = self.__convert_node_postorder(py_child) c_child = self.__convert_node_postorder(py_child)
self.depipeline.AddChildToParentNode(c_child, c_node) self.depipeline.AddChildToParentNode(c_child, c_node)
...@@ -184,7 +184,7 @@ class Iterator: ...@@ -184,7 +184,7 @@ class Iterator:
"""Recursively get batch node in the dataset tree.""" """Recursively get batch node in the dataset tree."""
if isinstance(dataset, de.BatchDataset): if isinstance(dataset, de.BatchDataset):
return return
for input_op in dataset.input: for input_op in dataset.children:
self.__batch_node(input_op, level + 1) self.__batch_node(input_op, level + 1)
@staticmethod @staticmethod
...@@ -194,11 +194,11 @@ class Iterator: ...@@ -194,11 +194,11 @@ class Iterator:
ptr = hex(id(dataset)) ptr = hex(id(dataset))
for _ in range(level): for _ in range(level):
logger.info("\t", end='') logger.info("\t", end='')
if not dataset.input: if not dataset.children:
logger.info("-%s (%s)", name, ptr) logger.info("-%s (%s)", name, ptr)
else: else:
logger.info("+%s (%s)", name, ptr) logger.info("+%s (%s)", name, ptr)
for input_op in dataset.input: for input_op in dataset.children:
Iterator.__print_local(input_op, level + 1) Iterator.__print_local(input_op, level + 1)
def print(self): def print(self):
......
...@@ -182,11 +182,11 @@ def traverse(node): ...@@ -182,11 +182,11 @@ def traverse(node):
node_repr['shard_id'] = None node_repr['shard_id'] = None
# Leaf node doesn't have input attribute. # Leaf node doesn't have input attribute.
if not node.input: if not node.children:
return node_repr return node_repr
# Recursively traverse the child and assign it to the current node_repr['children']. # Recursively traverse the child and assign it to the current node_repr['children'].
for child in node.input: for child in node.children:
node_repr["children"].append(traverse(child)) node_repr["children"].append(traverse(child))
return node_repr return node_repr
...@@ -226,11 +226,11 @@ def construct_pipeline(node): ...@@ -226,11 +226,11 @@ def construct_pipeline(node):
# Instantiate python Dataset object based on the current dictionary element # Instantiate python Dataset object based on the current dictionary element
dataset = create_node(node) dataset = create_node(node)
# Initially it is not connected to any other object. # Initially it is not connected to any other object.
dataset.input = [] dataset.children = []
# Construct the children too and add edge between the children and parent. # Construct the children too and add edge between the children and parent.
for child in node['children']: for child in node['children']:
dataset.input.append(construct_pipeline(child)) dataset.children.append(construct_pipeline(child))
return dataset return dataset
......
...@@ -103,7 +103,7 @@ def test_tree_copy(): ...@@ -103,7 +103,7 @@ def test_tree_copy():
itr = data1.create_tuple_iterator() itr = data1.create_tuple_iterator()
assert id(data1) != id(itr.dataset) assert id(data1) != id(itr.dataset)
assert id(data) != id(itr.dataset.input[0]) assert id(data) != id(itr.dataset.children[0])
assert id(data1.operations[0]) == id(itr.dataset.operations[0]) assert id(data1.operations[0]) == id(itr.dataset.operations[0])
itr.release() itr.release()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册