diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 3d13e733b0a9467ad23ae74f7a6d9cf26e1ffbe6..c076b1a45b7c7bc753c530fea8f736f3ad4ec662 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -30,7 +30,9 @@ from enum import Enum from importlib import import_module import threading +import copy import numpy as np + from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ MindRecordOp, TextFileOp, CBatchInfo from mindspore._c_expression import typing @@ -1376,6 +1378,23 @@ class MapDataset(DatasetOp): """ return self.input[0].get_dataset_size() + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_op = cls.__new__(cls) + memodict[id(self)] = new_op + new_op.input = copy.deepcopy(self.input, memodict) + new_op.input_columns = copy.deepcopy(self.input_columns, memodict) + new_op.output_columns = copy.deepcopy(self.output_columns, memodict) + new_op.columns_order = copy.deepcopy(self.columns_order, memodict) + new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) + new_op.output = copy.deepcopy(self.output, memodict) + new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) + new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) + new_op.operations = self.operations + return new_op + # Iterator bootstrap will be called on iterator construction. # A deep copy of Dataset object is created prior of iterator_bootstrap. # This method will create per iterator process pool and bind pyfunc execution to the pool. 
@@ -2600,6 +2619,23 @@ class GeneratorDataset(SourceDataset): else: raise ValueError('set dataset_size with negative value {}'.format(value)) + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_op = cls.__new__(cls) + memodict[id(self)] = new_op + new_op.input = copy.deepcopy(self.input, memodict) + new_op.output = copy.deepcopy(self.output, memodict) + new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) + new_op.column_types = copy.deepcopy(self.column_types, memodict) + new_op.column_names = copy.deepcopy(self.column_names, memodict) + + new_op.source = self.source + new_op.sampler = self.sampler + + return new_op + class TFRecordDataset(SourceDataset): """ diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index 7c69adf5616d6f6319b2377e8db77e035cd2fb24..58beecbe16fbeb36c0c1e077df4198e5d1aa7b3c 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -14,7 +14,7 @@ # ============================================================================== import numpy as np import pytest - +import copy import mindspore.dataset as ds from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup @@ -81,3 +81,33 @@ def test_iterator_weak_ref(): assert sum(itr() is not None for itr in ITERATORS_LIST) == 2 _cleanup() + + +class MyDict(dict): + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + self[key] = value + + def __call__(self, t): + return t + + +def test_tree_copy(): + # Testing copying the tree with a pyfunc that cannot be pickled + + data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS) + data1 = data.map(operations=[MyDict()]) + + itr = data1.create_tuple_iterator() + + assert id(data1) != id(itr.dataset) + assert id(data) != id(itr.dataset.input[0]) + assert id(data1.operations[0]) == id(itr.dataset.operations[0]) + + 
itr.release() + + +if __name__ == '__main__': + test_tree_copy()