Commit fb18671b
Authored Apr 22, 2020 by mindspore-ci-bot
Committed by Gitee on Apr 22, 2020
!506 [Dataset] Multiprocessing support for Pyfunc
Merge pull request !506 from JunhanHu/multiprocess_pyfunc
Parents: 4a0b2b4a, b13e7bc3
Showing 3 changed files with 204 additions and 3 deletions (+204 -3)
mindspore/dataset/engine/datasets.py    +94 -3
mindspore/dataset/engine/iterators.py    +4 -0
tests/ut/python/dataset/test_pyfunc.py  +106 -0
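Before the diff itself, a quick orientation: the change adds a python_multiprocessing flag to Dataset.map. A minimal usage sketch, mirroring the new tests below (the data and schema paths are placeholders, not part of the commit):

    import mindspore.dataset as ds

    DATA_DIR = ["path/to/data.tfrecord"]   # placeholder TFRecord file
    SCHEMA_DIR = "path/to/schema.json"     # placeholder schema

    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    # With python_multiprocessing=True, the PyFunc below runs in a pool of
    # num_parallel_workers subprocesses rather than Python threads, which
    # sidesteps the GIL for CPU-heavy Python operations.
    data = data.map(input_columns="col0", output_columns="out",
                    operations=(lambda x: x + x),
                    num_parallel_workers=4, python_multiprocessing=True)

    for item in data.create_dict_iterator():
        print(item["out"])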
mindspore/dataset/engine/datasets.py @ fb18671b
@@ -24,6 +24,7 @@ import math
 import os
 import random
 import uuid
+import multiprocessing
 from enum import Enum
 from importlib import import_module
@@ -231,7 +232,7 @@ class Dataset:
     @check_map
     def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
-            num_parallel_workers=None):
+            num_parallel_workers=None, python_multiprocessing=False):
         """
         Applies each operation in operations to this dataset.
@@ -270,6 +271,8 @@ class Dataset:
                 same).
             num_parallel_workers (int, optional): Number of threads used to process the dataset in
                 parallel (default=None, the value from the config will be used).
+            python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes.
+                This option can be beneficial if the Python operation is computationally heavy (default=False).

         Returns:
             MapDataset, dataset after mapping operation.
@@ -383,7 +386,8 @@ class Dataset:
             >>> columns_order = ["mod7", "mod3", "col1"]
             >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
         """
-        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
+        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
+                          python_multiprocessing)

     @check_filter
     def filter(self, predicate, input_columns=None, num_parallel_workers=1):
@@ -1076,6 +1080,55 @@ class ShuffleDataset(DatasetOp):
         return args


+# Pyfunc collection for multiprocess pyfunc
+# This global variable will only be used within subprocesses
+_GLOBAL_PYFUNC_LIST = []
+
+
+# Pyfunc worker init function
+# The Python multiprocessing library forbids sending lambda functions through a pipe.
+# This init function lets us add all Python functions to a global collection and fork afterwards.
+def _pyfunc_worker_init(pyfunc_list):
+    global _GLOBAL_PYFUNC_LIST
+    _GLOBAL_PYFUNC_LIST = pyfunc_list
+
+
+# Pyfunc worker execution function
+# All exceptions will be raised to the main process
+def _pyfunc_worker_exec(index, *args):
+    try:
+        return _GLOBAL_PYFUNC_LIST[index](*args)
+    except KeyboardInterrupt:
+        raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
+
+
+# PythonCallable wrapper for multiprocess pyfunc
+class _PythonCallable:
+    """
+    Internal Python function wrapper for multiprocessing pyfunc.
+    """
+    def __init__(self, py_callable, idx, pool=None):
+        # Original Python callable from user.
+        self.py_callable = py_callable
+        # Process pool created for current iterator.
+        self.pool = pool
+        # Python callable index for subprocess _GLOBAL_PYFUNC_LIST.
+        self.idx = idx
+
+    def __call__(self, *args):
+        if self.pool is not None:
+            try:
+                # This call sends the tensors along with the Python callable index to the process pool.
+                # Block, yield GIL. The current thread will reacquire the GIL once the result is returned.
+                return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args])
+            except KeyboardInterrupt:
+                self.pool.terminate()
+                self.pool.join()
+                raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
+        # Invoke the original Python callable in the master process in case the pool is gone.
+        return self.py_callable(*args)
+
+
 class MapDataset(DatasetOp):
     """
     The result of applying Map operator to the input Dataset.
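The _pyfunc_worker_init / _GLOBAL_PYFUNC_LIST pair above exists because pool tasks travel through a pickled pipe, and lambdas cannot be pickled; the callables are instead handed to each worker once, at pool construction, and addressed by index afterwards. A self-contained sketch of that pattern (illustrative names, not MindSpore API; it assumes a fork start method, as the comment above does):

    import multiprocessing

    _WORKER_FUNCS = []  # filled in once per worker by the initializer

    def _worker_init(funcs):
        # Runs in each worker at pool startup; keeps the callables in a
        # module-level list so tasks can refer to them by index.
        global _WORKER_FUNCS
        _WORKER_FUNCS = funcs

    def _worker_exec(index, arg):
        # Only the index and the argument cross the task pipe.
        return _WORKER_FUNCS[index](arg)

    if __name__ == "__main__":
        funcs = [lambda x: x + x, lambda x: x * 10]  # not picklable
        # fork lets initargs reach the workers without pickling;
        # spawn (e.g. on Windows) would reject the lambdas here.
        ctx = multiprocessing.get_context("fork")
        with ctx.Pool(2, initializer=_worker_init, initargs=(funcs,)) as pool:
            print(pool.apply(_worker_exec, [0, 3]))  # -> 6
            print(pool.apply(_worker_exec, [1, 3]))  # -> 30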
@@ -1095,13 +1148,15 @@ class MapDataset(DatasetOp):
             The argument is mandatory if len(input_columns) != len(output_columns).
         num_parallel_workers (int, optional): Number of workers to process the Dataset
             in parallel (default=None).
+        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes.
+            This option can be beneficial if the Python operation is computationally heavy (default=False).

     Raises:
         ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
     """

     def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
-                 num_parallel_workers=None):
+                 num_parallel_workers=None, python_multiprocessing=False):
         super().__init__(num_parallel_workers)
         self.input.append(input_dataset)
         if input_columns is not None and not isinstance(input_columns, list):
@@ -1122,6 +1177,8 @@ class MapDataset(DatasetOp):
         input_dataset.output.append(self)
         self._input_indexs = input_dataset.input_indexs
+        self.python_multiprocessing = python_multiprocessing
+        self.process_pool = None

     def get_args(self):
         args = super().get_args()
@@ -1139,6 +1196,40 @@ class MapDataset(DatasetOp):
         """
         return self.input[0].get_dataset_size()

+    # Iterator bootstrap will be called on iterator construction.
+    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
+    # This method creates a per-iterator process pool and binds pyfunc execution to the pool.
+    def iterator_bootstrap(self):
+        """
+        Per iterator bootstrap callback.
+        """
+        if self.python_multiprocessing:
+            iter_specific_operations = []
+            callable_list = []
+
+            # Pass #1: look for Python callables and build a list of them.
+            for op in self.operations:
+                if callable(op):
+                    callable_list.append(op)
+
+            if callable_list:
+                # Construct the pool with the callable list.
+                # The callable list and _pyfunc_worker_init are used to pass lambda
+                # functions into the subprocesses.
+                self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
+                                                         initializer=_pyfunc_worker_init,
+                                                         initargs=(callable_list,))
+                # Pass #2: wrap the Python callables; C++ ops stay as they are.
+                idx = 0
+                for op in self.operations:
+                    if callable(op):
+                        # Wrap Python callable into _PythonCallable
+                        iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
+                        idx += 1
+                    else:
+                        # CPP ops remain the same
+                        iter_specific_operations.append(op)
+                self.operations = iter_specific_operations
+

 class FilterDataset(DatasetOp):
     """
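One subtlety in _PythonCallable.__call__ above: pool.apply is the blocking variant, yet parallelism is preserved because each of the map operator's worker threads blocks in its own apply call and releases the GIL while waiting. A rough stand-alone illustration of that concurrency model (not MindSpore code):

    import multiprocessing
    import threading
    import time

    def slow_double(x):
        time.sleep(1)  # stand-in for a computationally heavy PyFunc
        return x + x

    if __name__ == "__main__":
        with multiprocessing.Pool(processes=4) as pool:
            results = [None] * 4

            def run(i):
                # Blocks this thread only; the GIL is released while waiting.
                results[i] = pool.apply(slow_double, (i,))

            start = time.time()
            threads = [threading.Thread(target=run, args=(i,)) for i in range(4)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            # Wall time is ~1s, not ~4s: the four blocking applies overlap.
            print(results, round(time.time() - start, 1))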
mindspore/dataset/engine/iterators.py @ fb18671b
@@ -63,6 +63,10 @@ def _alter_node(node):
         return new_shuffle

     if isinstance(node, de.MapDataset):
+        if node.python_multiprocessing:
+            # Bootstrap can only be performed on a copy of the original dataset node.
+            # Bootstrapping the original dataset node would make all iterators share the same process pool.
+            node.iterator_bootstrap()
         if node.columns_order is not None:
             # Remove the connection between the parent node and the current node because we are inserting a node.
             if node.output:
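The guard above runs the bootstrap only on the deep copy of the pipeline that iterator construction works on. A toy model of why that matters (hypothetical Node class, purely illustrative):

    import copy

    class Node:
        """Stand-in for a dataset node; not the real MapDataset."""
        def __init__(self):
            self.python_multiprocessing = True
            self.process_pool = None

        def iterator_bootstrap(self):
            # In the real code this creates a multiprocessing.Pool.
            self.process_pool = object()  # placeholder for a fresh pool

    def make_iterator(root):
        # Deep-copy first, bootstrap the copy: each iterator owns its pool,
        # and the user's original pipeline object is never mutated.
        tree = copy.deepcopy(root)
        if tree.python_multiprocessing:
            tree.iterator_bootstrap()
        return tree

    root = Node()
    it1, it2 = make_iterator(root), make_iterator(root)
    assert it1.process_pool is not it2.process_pool  # independent pools
    assert root.process_pool is None                 # original untouched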
tests/ut/python/dataset/test_pyfunc.py @ fb18671b
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+import pytest

 import mindspore.dataset as ds
 from mindspore import log as logger
@@ -181,6 +182,106 @@ def test_case_6():
         i = i + 4


+def test_case_7():
+    """
+    Test PyFunc
+    """
+    logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x")
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+    data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x),
+                      num_parallel_workers=4, python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
+        assert np.array_equal(item["out"], golden)
+        i = i + 4
+
+
+def test_case_8():
+    """
+    Test PyFunc
+    """
+    logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x, x + y, x + y + 1)")
+
+    col = ["col0", "col1"]
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+    data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
+                      operations=(lambda x, y: (x, x + y, x + y + 1)),
+                      columns_order=["out0", "out1", "out2"], python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i, i + 1], [i + 2, i + 3]])
+        assert np.array_equal(item["out0"], golden)
+        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
+        assert np.array_equal(item["out1"], golden)
+        golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
+        assert np.array_equal(item["out2"], golden)
+        i = i + 4
+
+
+def test_case_9():
+    """
+    Test PyFunc
+    """
+    logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x")
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+    data1 = data1.map(input_columns="col0", output_columns="out",
+                      operations=[(lambda x: x + x), (lambda x: x + 1), (lambda x: x + 2)],
+                      num_parallel_workers=4, python_multiprocessing=True)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each data is a dictionary
+        # In this test, the dataset is 2x2 sequential tensors
+        golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]])
+        assert np.array_equal(item["out"], golden)
+        i = i + 4
+
+
+def test_pyfunc_exception():
+    logger.info("Test PyFunc Exception Throw: lambda x : raise Exception()")
+
+    def pyfunc(x):
+        raise Exception("Pyfunc Throw")
+
+    with pytest.raises(RuntimeError) as info:
+        # apply dataset operations
+        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
+                          num_parallel_workers=4)
+        for _ in data1:
+            pass
+    assert "Pyfunc Throw" in str(info.value)
+
+
+def test_pyfunc_exception_multiprocess():
+    logger.info("Test Multiprocess PyFunc Exception Throw: lambda x : raise Exception()")
+
+    def pyfunc(x):
+        raise Exception("MP Pyfunc Throw")
+
+    with pytest.raises(RuntimeError) as info:
+        # apply dataset operations
+        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
+        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
+                          num_parallel_workers=4, python_multiprocessing=True)
+        for _ in data1:
+            pass
+    assert "MP Pyfunc Throw" in str(info.value)
+
+
 if __name__ == "__main__":
     test_case_0()
     test_case_1()
@@ -189,3 +290,8 @@ if __name__ == "__main__":
     test_case_4()
     test_case_5()
     test_case_6()
+    test_case_7()
+    test_case_8()
+    test_case_9()
+    test_pyfunc_exception()
+    test_pyfunc_exception_multiprocess()
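A note on the expected values in the new tests: the source tensors are 2x2 sequential blocks [[i, i + 1], [i + 2, i + 3]], so in test_case_9 the chained PyFuncs compose to ((x + x) + 1) + 2 = 2x + 3, which is exactly the i * 2 + 3 pattern asserted against the golden arrays. The two exception tests pin down propagation behavior: a Python exception raised inside a PyFunc reaches the iterating thread as a RuntimeError whose message contains the original exception text, in both the threaded and the multiprocess paths.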