diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 855e4609bbd310b23c826cbbed5ebf812ce294b8..3225ebc806af61da2ed1d20f40d0bc748555e2e1 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -24,6 +24,7 @@
 import math
 import os
 import random
 import uuid
+import multiprocessing
 from enum import Enum
 from importlib import import_module
@@ -231,7 +232,7 @@ class Dataset:
 
     @check_map
     def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
-            num_parallel_workers=None):
+            num_parallel_workers=None, python_multiprocessing=False):
         """
         Applies each operation in operations to this dataset.
 
@@ -270,6 +271,8 @@ class Dataset:
                 same).
            num_parallel_workers (int, optional): Number of threads used to process the dataset in
                parallel (default=None, the value from the config will be used).
+            python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes.
+                This option could be beneficial if the Python operation is computationally heavy (default=False).
 
         Returns:
             MapDataset, dataset after mapping operation.
@@ -383,7 +386,8 @@ class Dataset:
             >>> columns_order = ["mod7", "mod3", "col1"]
             >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
         """
-        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
+        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
+                          python_multiprocessing)
 
     @check_filter
     def filter(self, predicate, input_columns=None, num_parallel_workers=1):
@@ -1076,6 +1080,55 @@ class ShuffleDataset(DatasetOp):
         return args
 
 
+# Pyfunc collection for multiprocess pyfunc
+# This global variable will only be used within subprocesses
+_GLOBAL_PYFUNC_LIST = []
+
+
+# Pyfunc worker init function
+# The Python multiprocessing library forbids sending lambda functions through pipes.
+# This init function allows us to add all Python functions to a global collection before forking.
+def _pyfunc_worker_init(pyfunc_list):
+    global _GLOBAL_PYFUNC_LIST
+    _GLOBAL_PYFUNC_LIST = pyfunc_list
+
+
+# Pyfunc worker execution function
+# All exceptions will be raised to the main process
+def _pyfunc_worker_exec(index, *args):
+    try:
+        return _GLOBAL_PYFUNC_LIST[index](*args)
+    except KeyboardInterrupt:
+        raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
+
+
+# PythonCallable wrapper for multiprocess pyfunc
+class _PythonCallable:
+    """
+    Internal Python function wrapper for multiprocess pyfunc.
+    """
+    def __init__(self, py_callable, idx, pool=None):
+        # Original Python callable from the user.
+        self.py_callable = py_callable
+        # Process pool created for the current iterator.
+        self.pool = pool
+        # Python callable index into the subprocess _GLOBAL_PYFUNC_LIST
+        self.idx = idx
+
+    def __call__(self, *args):
+        if self.pool is not None:
+            try:
+                # This call sends the tensors along with the Python callable index to the process pool.
+                # It blocks and yields the GIL; the current thread reacquires the GIL once the result is returned.
+                return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args])
+            except KeyboardInterrupt:
+                self.pool.terminate()
+                self.pool.join()
+                raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
+        # Invoke the original Python callable in the master process in case the pool is gone.
+        return self.py_callable(*args)
+
+
 class MapDataset(DatasetOp):
     """
     The result of applying Map operator to the input Dataset.
@@ -1095,13 +1148,15 @@ class MapDataset(DatasetOp):
             The argument is mandatory if len(input_columns) != len(output_columns).
         num_parallel_workers (int, optional): Number of workers to process the Dataset
             in parallel (default=None).
+        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes.
+            This option could be beneficial if the Python operation is computationally heavy (default=False).
 
     Raises:
         ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
     """
 
     def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
-                 num_parallel_workers=None):
+                 num_parallel_workers=None, python_multiprocessing=False):
         super().__init__(num_parallel_workers)
         self.input.append(input_dataset)
         if input_columns is not None and not isinstance(input_columns, list):
@@ -1122,6 +1177,8 @@ class MapDataset(DatasetOp):
             input_dataset.output.append(self)
         self._input_indexs = input_dataset.input_indexs
+        self.python_multiprocessing = python_multiprocessing
+        self.process_pool = None
 
     def get_args(self):
         args = super().get_args()
@@ -1139,6 +1196,40 @@ class MapDataset(DatasetOp):
         """
         return self.input[0].get_dataset_size()
 
+    # Iterator bootstrap will be called on iterator construction.
+    # A deep copy of the Dataset object is created prior to calling iterator_bootstrap.
+    # This method creates a per-iterator process pool and binds pyfunc execution to that pool.
+    def iterator_bootstrap(self):
+        """
+        Per iterator bootstrap callback.
+        """
+        if self.python_multiprocessing:
+            iter_specific_operations = []
+            callable_list = []
+
+            # Pass #1: look for Python callables and build the list
+            for op in self.operations:
+                if callable(op):
+                    callable_list.append(op)
+
+            if callable_list:
+                # Construct the pool with the callable list.
+                # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses.
+                self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
+                                                         initializer=_pyfunc_worker_init,
+                                                         initargs=(callable_list,))
+                # Pass #2
+                idx = 0
+                for op in self.operations:
+                    if callable(op):
+                        # Wrap the Python callable into _PythonCallable
+                        iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
+                        idx += 1
+                    else:
+                        # CPP ops remain the same
+                        iter_specific_operations.append(op)
+                self.operations = iter_specific_operations
+
 
 class FilterDataset(DatasetOp):
     """
diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py
index 6af6c7dba8e5080a216ff1993b05883080a7f235..d562110c798e9f1323879b86979c0a23c424882d 100644
--- a/mindspore/dataset/engine/iterators.py
+++ b/mindspore/dataset/engine/iterators.py
@@ -63,6 +63,10 @@ def _alter_node(node):
         return new_shuffle
 
     if isinstance(node, de.MapDataset):
+        if node.python_multiprocessing:
+            # Bootstrap can only be performed on a copy of the original dataset node.
+            # Bootstrapping the original dataset node would make all iterators share the same process pool.
+            node.iterator_bootstrap()
         if node.columns_order is not None:
             # Remove the connection between the parent's node to the current node because we are inserting a node.
             if node.output:
diff --git a/tests/ut/python/dataset/test_pyfunc.py b/tests/ut/python/dataset/test_pyfunc.py
index 4b0672a1f2ad90fddc652ba526c6c5e3bcbb7a2b..e7bdc4863992782b2e7eeed7287bd96dc5258ec2 100644
--- a/tests/ut/python/dataset/test_pyfunc.py
+++ b/tests/ut/python/dataset/test_pyfunc.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+import pytest
 
 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -181,6 +182,106 @@ def test_case_6():
         i = i + 4
 
 
+def test_case_7():
+    """
+    Test 1-1 PyFunc with python_multiprocessing=True: lambda x : x + x
+    """
+    logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x")
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x),
+                      num_parallel_workers=4, python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
+        assert np.array_equal(item["out"], golden)
+        i = i + 4
+
+
+def test_case_8():
+    """
+    Test n-m PyFunc with python_multiprocessing=True: lambda x, y : (x, x + y, x + y + 1)
+    """
+    logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x, x + y, x + y + 1)")
+
+    col = ["col0", "col1"]
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
+                      operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"],
+                      python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i, i + 1], [i + 2, i + 3]])
+        assert np.array_equal(item["out0"], golden)
+        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
+        assert np.array_equal(item["out1"], golden)
+        golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
+        assert np.array_equal(item["out2"], golden)
+        i = i + 4
+
+
+def test_case_9():
+    """
+    Test a chain of 1-1 PyFuncs with python_multiprocessing=True: x + x, x + 1, x + 2
+    """
+    logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x")
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+
+    data1 = data1.map(input_columns="col0", output_columns="out",
+                      operations=[(lambda x: x + x), (lambda x: x + 1), (lambda x: x + 2)],
+                      num_parallel_workers=4, python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]])
+        assert np.array_equal(item["out"], golden)
+        i = i + 4
+
+
+def test_pyfunc_exception():
+    logger.info("Test PyFunc Exception Throw: lambda x : raise Exception()")
+
+    def pyfunc(x):
+        raise Exception("Pyfunc Throw")
+
+    with pytest.raises(RuntimeError) as info:
+        # apply dataset operations
+        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
+                          num_parallel_workers=4)
+        for _ in data1:
+            pass
+    assert "Pyfunc Throw" in str(info.value)
+
+
+def test_pyfunc_exception_multiprocess():
+    logger.info("Test Multiprocess PyFunc Exception Throw: lambda x : raise Exception()")
+
+    def pyfunc(x):
+        raise Exception("MP Pyfunc Throw")
+
+    with pytest.raises(RuntimeError) as info:
+        # apply dataset operations
+        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
+                          num_parallel_workers=4, python_multiprocessing=True)
+        for _ in data1:
+            pass
+    assert "MP Pyfunc Throw" in str(info.value)
+
+
 if __name__ == "__main__":
     test_case_0()
     test_case_1()
@@ -189,3 +290,8 @@ if __name__ == "__main__":
     test_case_4()
     test_case_5()
     test_case_6()
+    test_case_7()
+    test_case_8()
+    test_case_9()
+    test_pyfunc_exception()
+    test_pyfunc_exception_multiprocess()
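
Usage sketch (editor's note, not part of the patch): a minimal illustration of the new python_multiprocessing flag, reusing the DATA_DIR/SCHEMA_DIR test data referenced in test_pyfunc.py above; the heavy_pyfunc name is hypothetical. Python callables passed to map() are dispatched to a per-iterator multiprocessing.Pool created in iterator_bootstrap, while C++ ops in the same pipeline run unchanged.

import mindspore.dataset as ds

def heavy_pyfunc(x):
    # Stand-in for a computationally heavy Python operation that benefits from worker processes.
    return x + x

data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data = data.map(input_columns="col0", output_columns="out", operations=heavy_pyfunc,
                num_parallel_workers=4, python_multiprocessing=True)
for item in data.create_dict_iterator():
    pass  # each item is a dict mapping column names to numpy arrays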