Commit fb18671b, authored by mindspore-ci-bot, committed by Gitee

!506 [Dataset] Multiprocessing support for Pyfunc

Merge pull request !506 from JunhanHu/multiprocess_pyfunc
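
A minimal usage sketch of the new option (the data file paths here are hypothetical):

    import mindspore.dataset as ds

    data = ds.TFRecordDataset("data.tfrecord", "schema.json", shuffle=False)
    # Run the Python callable in a pool of 4 worker processes instead of threads.
    data = data.map(input_columns="col0", output_columns="out",
                    operations=(lambda x: x + x),
                    num_parallel_workers=4, python_multiprocessing=True)
    for item in data.create_dict_iterator():
        print(item["out"])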
@@ -24,6 +24,7 @@ import math
import os
import random
import uuid
import multiprocessing
from enum import Enum
from importlib import import_module

@@ -231,7 +232,7 @@ class Dataset:
    @check_map
    def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
            num_parallel_workers=None, python_multiprocessing=False):
        """
        Applies each operation in operations to this dataset.
@@ -270,6 +271,8 @@ class Dataset:
                same).
            num_parallel_workers (int, optional): Number of threads used to process the dataset in
                parallel (default=None, the value from the config will be used).
            python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker
                processes. This option can be beneficial if the Python operation is computationally
                heavy (default=False).

        Returns:
            MapDataset, dataset after mapping operation.

@@ -383,7 +386,8 @@ class Dataset:
            >>> columns_order = ["mod7", "mod3", "col1"]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
        """
        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
                          python_multiprocessing)

    @check_filter
    def filter(self, predicate, input_columns=None, num_parallel_workers=1):

@@ -1076,6 +1080,55 @@ class ShuffleDataset(DatasetOp):
        return args

# Pyfunc collection for multiprocess pyfunc
# This global variable will only be used within subprocesses
_GLOBAL_PYFUNC_LIST = []


# Pyfunc worker init function
# The Python multiprocessing library forbids sending lambda functions through pipes,
# so this init function adds every Python function to a global collection before the workers fork.
def _pyfunc_worker_init(pyfunc_list):
    global _GLOBAL_PYFUNC_LIST
    _GLOBAL_PYFUNC_LIST = pyfunc_list
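
# A standalone sketch of this initializer pattern (assuming the default "fork"
# start method on Linux, where initargs are inherited by the forked worker
# rather than pickled; under "spawn", a lambda in initargs would fail to pickle):
#
#     import multiprocessing
#
#     def _init(fns):
#         global _FNS
#         _FNS = fns
#
#     def _exec(idx, arg):
#         return _FNS[idx](arg)  # look the callable up by index; never pickle it
#
#     if __name__ == "__main__":
#         with multiprocessing.Pool(2, initializer=_init,
#                                   initargs=([lambda x: x + 1],)) as pool:
#             print(pool.apply(_exec, (0, 41)))  # prints 42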

# Pyfunc worker execution function
# All exceptions will be raised to the main process
def _pyfunc_worker_exec(index, *args):
    try:
        return _GLOBAL_PYFUNC_LIST[index](*args)
    except KeyboardInterrupt:
        raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
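# Note: KeyboardInterrupt is a BaseException. If it escaped the worker uncaught,
# the pool would never receive a result for the job and pool.apply() in the
# parent could block forever; re-raising it as a plain Exception lets the pool
# capture it and surface it in the main process.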

# PythonCallable wrapper for multiprocess pyfunc
class _PythonCallable:
    """
    Internal Python function wrapper for multiprocessing pyfunc.
    """

    def __init__(self, py_callable, idx, pool=None):
        # Original Python callable from the user.
        self.py_callable = py_callable
        # Process pool created for the current iterator.
        self.pool = pool
        # Index of this callable in the subprocess _GLOBAL_PYFUNC_LIST.
        self.idx = idx

    def __call__(self, *args):
        if self.pool is not None:
            try:
                # This call sends the tensors along with the Python callable index to the process pool.
                # It blocks and yields the GIL; the current thread reacquires the GIL once the result returns.
                return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args])
            except KeyboardInterrupt:
                self.pool.terminate()
                self.pool.join()
                raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
        # Invoke the original Python callable in the master process in case the pool is gone.
        return self.py_callable(*args)
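
# A quick sketch of the fallback path: with no pool attached, the wrapper simply
# runs the callable inline in the current process:
#
#     fn = _PythonCallable(lambda x: x * 2, idx=0, pool=None)
#     fn(3)  # returns 6 without touching any worker process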


class MapDataset(DatasetOp):
    """
    The result of applying the Map operator to the input Dataset.

@@ -1095,13 +1148,15 @@ class MapDataset(DatasetOp):
            The argument is mandatory if len(input_columns) != len(output_columns).
        num_parallel_workers (int, optional): Number of workers to process the Dataset
            in parallel (default=None).
        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker
            processes. This option can be beneficial if the Python operation is computationally
            heavy (default=False).

    Raises:
        ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
    """

    def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
                 num_parallel_workers=None, python_multiprocessing=False):
        super().__init__(num_parallel_workers)
        self.input.append(input_dataset)
        if input_columns is not None and not isinstance(input_columns, list):

@@ -1122,6 +1177,8 @@ class MapDataset(DatasetOp):
        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs
        self.python_multiprocessing = python_multiprocessing
        self.process_pool = None

    def get_args(self):
        args = super().get_args()

@@ -1139,6 +1196,40 @@ class MapDataset(DatasetOp):
        """
        return self.input[0].get_dataset_size()

    # Iterator bootstrap will be called on iterator construction.
    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
    # This method creates the per-iterator process pool and binds pyfunc execution to the pool.
    def iterator_bootstrap(self):
        """
        Per iterator bootstrap callback.
        """
        if self.python_multiprocessing:
            iter_specific_operations = []
            callable_list = []

            # Pass #1: look for Python callables and build the list
            for op in self.operations:
                if callable(op):
                    callable_list.append(op)

            if callable_list:
                # Construct the pool with the callable list.
                # The callable list and _pyfunc_worker_init are used to pass lambda functions into the subprocesses.
                self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
                                                         initializer=_pyfunc_worker_init,
                                                         initargs=(callable_list,))
                # Pass #2: wrap each Python callable
                idx = 0
                for op in self.operations:
                    if callable(op):
                        # Wrap the Python callable into a _PythonCallable
                        iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
                        idx += 1
                    else:
                        # C++ ops remain the same
                        iter_specific_operations.append(op)
                self.operations = iter_specific_operations
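
    # For example, with operations = [c_op, f, g] where f and g are Python
    # callables and c_op is a C++ op, pass #1 collects [f, g], the pool is
    # initialized with that list, and pass #2 rewrites the operations to
    # [c_op, _PythonCallable(f, 0, pool), _PythonCallable(g, 1, pool)], so only
    # the Python callables are dispatched to worker processes.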


class FilterDataset(DatasetOp):
    """
    ...

@@ -63,6 +63,10 @@ def _alter_node(node):
        return new_shuffle

    if isinstance(node, de.MapDataset):
        if node.python_multiprocessing:
            # Bootstrap can only be performed on a copy of the original dataset node;
            # bootstrapping the original node would make all iterators share the same process pool.
            node.iterator_bootstrap()
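            # A sketch of the per-iterator flow this relies on (names hypothetical):
            #
            #     import copy
            #     tree = copy.deepcopy(user_dataset)  # private copy per iterator
            #     _alter_node(tree)                   # pool is created on the copy
            #
            # so two concurrent iterators never share one multiprocessing.Pool.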
        if node.columns_order is not None:
            # Remove the connection between the parent node and the current node because we are inserting a node.
            if node.output:
...

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest

import mindspore.dataset as ds
from mindspore import log as logger

@@ -181,6 +182,106 @@ def test_case_6():
        i = i + 4


def test_case_7():
    """
    Test PyFunc
    """
    logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x),
                      num_parallel_workers=4, python_multiprocessing=True)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # In this test, the dataset is 2x2 sequential tensors
        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
        assert np.array_equal(item["out"], golden)
        i = i + 4


def test_case_8():
    """
    Test PyFunc
    """
    logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x, x + y, x + y + 1)")

    col = ["col0", "col1"]
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
                      operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"],
                      python_multiprocessing=True)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # In this test, the dataset is 2x2 sequential tensors
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item["out0"], golden)
        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
        assert np.array_equal(item["out1"], golden)
        golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
        assert np.array_equal(item["out2"], golden)
        i = i + 4


def test_case_9():
    """
    Test PyFunc
    """
    logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.map(input_columns="col0", output_columns="out",
                      operations=[(lambda x: x + x), (lambda x: x + 1), (lambda x: x + 2)],
                      num_parallel_workers=4, python_multiprocessing=True)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # In this test, the dataset is 2x2 sequential tensors
        golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]])
        assert np.array_equal(item["out"], golden)
        i = i + 4


def test_pyfunc_exception():
    logger.info("Test PyFunc Exception Throw: lambda x : raise Exception()")

    def pyfunc(x):
        raise Exception("Pyfunc Throw")

    with pytest.raises(RuntimeError) as info:
        # apply dataset operations
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
                          num_parallel_workers=4)
        for _ in data1:
            pass
    assert "Pyfunc Throw" in str(info.value)


def test_pyfunc_exception_multiprocess():
    logger.info("Test Multiprocess PyFunc Exception Throw: lambda x : raise Exception()")

    def pyfunc(x):
        raise Exception("MP Pyfunc Throw")

    with pytest.raises(RuntimeError) as info:
        # apply dataset operations
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
                          num_parallel_workers=4, python_multiprocessing=True)
        for _ in data1:
            pass
    assert "MP Pyfunc Throw" in str(info.value)


if __name__ == "__main__":
    test_case_0()
    test_case_1()

@@ -189,3 +290,8 @@ if __name__ == "__main__":
    test_case_4()
    test_case_5()
    test_case_6()
    test_case_7()
    test_case_8()
    test_case_9()
    test_pyfunc_exception()
    test_pyfunc_exception_multiprocess()