提交 b13e7bc3 编写于 作者: J Junhan Hu

Add python multiprocessing support for Mindspore.dataset

上级 822a3160
......@@ -24,6 +24,7 @@ import math
import os
import random
import uuid
import multiprocessing
from enum import Enum
from importlib import import_module
......@@ -231,7 +232,7 @@ class Dataset:
@check_map
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None):
num_parallel_workers=None, python_multiprocessing=False):
"""
Applies each operation in operations to this dataset.
......@@ -270,6 +271,8 @@ class Dataset:
same).
num_parallel_workers (int, optional): Number of threads used to process the dataset in
parallel (default=None, the value from the config will be used).
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
option could be beneficial if the python operation is computational heavy (default=False).
Returns:
MapDataset, dataset after mapping operation.
......@@ -383,7 +386,8 @@ class Dataset:
>>> columns_order = ["mod7", "mod3", "col1"]
>>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
"""
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
python_multiprocessing)
@check_repeat
def repeat(self, count=None):
......@@ -1041,6 +1045,55 @@ class ShuffleDataset(DatasetOp):
return args
# Pyfunc collection for multiprocess pyfunc
# This global variable will only be used within subprocesses
_GLOBAL_PYFUNC_LIST = []
# Pyfunc worker init function
# Python multiprocessing library forbid sending lambda function through pipe.
# This init function allow us to add all python function to a global collection and then fork afterwards.
def _pyfunc_worker_init(pyfunc_list):
global _GLOBAL_PYFUNC_LIST
_GLOBAL_PYFUNC_LIST = pyfunc_list
# Pyfunc worker execution function
# All exceptions will be raised to main processes
def _pyfunc_worker_exec(index, *args):
try:
return _GLOBAL_PYFUNC_LIST[index](*args)
except KeyboardInterrupt:
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
# PythonCallable wrapper for multiprocess pyfunc
class _PythonCallable:
"""
Internal python function wrapper for multiprocessing pyfunc
"""
def __init__(self, py_callable, idx, pool=None):
# Original python callable from user.
self.py_callable = py_callable
# Process pool created for current iterator.
self.pool = pool
# Python callable index for subprocess _GLOBAL_PYFUNC_LIST
self.idx = idx
def __call__(self, *args):
if self.pool is not None:
try:
# This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args])
except KeyboardInterrupt:
self.pool.terminate()
self.pool.join()
raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
# Invoke original python callable in master process in case the pool is gone.
return self.py_callable(*args)
class MapDataset(DatasetOp):
"""
The result of applying Map operator to the input Dataset.
......@@ -1060,13 +1113,15 @@ class MapDataset(DatasetOp):
The argument is mandatory if len(input_columns) != len(output_columns).
num_parallel_workers (int, optional): Number of workers to process the Dataset
in parallel (default=None).
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
option could be beneficial if the python operation is computational heavy (default=False).
Raises:
ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
"""
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None):
num_parallel_workers=None, python_multiprocessing=False):
super().__init__(num_parallel_workers)
self.input.append(input_dataset)
if input_columns is not None and not isinstance(input_columns, list):
......@@ -1087,6 +1142,8 @@ class MapDataset(DatasetOp):
input_dataset.output.append(self)
self._input_indexs = input_dataset.input_indexs
self.python_multiprocessing = python_multiprocessing
self.process_pool = None
def get_args(self):
args = super().get_args()
......@@ -1104,6 +1161,40 @@ class MapDataset(DatasetOp):
"""
return self.input[0].get_dataset_size()
# Iterator bootstrap will be called on iterator construction.
# A deep copy of Dataset object is created prior of iterator_bootstrap.
# This method will create per iterator process pool and bind pyfunc execution to the pool.
def iterator_bootstrap(self):
"""
Per iterator bootstrap callback.
"""
if self.python_multiprocessing:
iter_specific_operations = []
callable_list = []
# Pass #1, look for python callables and build list
for op in self.operations:
if callable(op):
callable_list.append(op)
if callable_list:
# Construct pool with the callable list
# The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
initializer=_pyfunc_worker_init,
initargs=(callable_list,))
# Pass #2
idx = 0
for op in self.operations:
if callable(op):
# Wrap python callable into _PythonCallable
iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
idx += 1
else:
# CPP ops remain the same
iter_specific_operations.append(op)
self.operations = iter_specific_operations
class RepeatDataset(DatasetOp):
"""
......
......@@ -63,6 +63,10 @@ def _alter_node(node):
return new_shuffle
if isinstance(node, de.MapDataset):
if node.python_multiprocessing:
# Bootstrap can only be performed on a copy of the original dataset node.
# Bootstrap on original dataset node will make all iterators share the same process pool
node.iterator_bootstrap()
if node.columns_order is not None:
# Remove the connection between the parent's node to the current node because we are inserting a node.
if node.output:
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
......@@ -181,6 +182,106 @@ def test_case_6():
i = i + 4
def test_case_7():
"""
Test PyFunc
"""
logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x")
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x),
num_parallel_workers=4, python_multiprocessing = True)
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# In this test, the dataset is 2x2 sequential tensors
golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
assert np.array_equal(item["out"], golden)
i = i + 4
def test_case_8():
"""
Test PyFunc
"""
logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x , x + 1, x + y)")
col = ["col0", "col1"]
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"],
python_multiprocessing=True)
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# In this test, the dataset is 2x2 sequential tensors
golden = np.array([[i, i + 1], [i + 2, i + 3]])
assert np.array_equal(item["out0"], golden)
golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
assert np.array_equal(item["out1"], golden)
golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
assert np.array_equal(item["out2"], golden)
i = i + 4
def test_case_9():
"""
Test PyFunc
"""
logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x")
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.map(input_columns="col0", output_columns="out", operations=[(lambda x: x + x), (lambda x: x + 1),
(lambda x: x + 2)],
num_parallel_workers=4, python_multiprocessing=True)
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# In this test, the dataset is 2x2 sequential tensors
golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]])
assert np.array_equal(item["out"], golden)
i = i + 4
def test_pyfunc_execption():
logger.info("Test PyFunc Execption Throw: lambda x : raise Execption()")
def pyfunc(x):
raise Exception("Pyfunc Throw")
with pytest.raises(RuntimeError) as info:
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.map(input_columns="col0", output_columns="out", operations= pyfunc,
num_parallel_workers=4)
for _ in data1:
pass
assert "Pyfunc Throw" in str(info.value)
def test_pyfunc_execption_multiprocess():
logger.info("Test Multiprocess PyFunc Execption Throw: lambda x : raise Execption()")
def pyfunc(x):
raise Exception("MP Pyfunc Throw")
with pytest.raises(RuntimeError) as info:
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.map(input_columns="col0", output_columns="out", operations= pyfunc,
num_parallel_workers=4, python_multiprocessing = True)
for _ in data1:
pass
assert "MP Pyfunc Throw" in str(info.value)
if __name__ == "__main__":
test_case_0()
test_case_1()
......@@ -189,3 +290,8 @@ if __name__ == "__main__":
test_case_4()
test_case_5()
test_case_6()
test_case_7()
test_case_8()
test_case_9()
test_pyfunc_execption()
test_pyfunc_execption_multiprocess()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册