Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
161bb240
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
161bb240
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!512 [Dataset] Multiprocessing support for GeneratorDataset
Merge pull request !512 from JunhanHu/multiprocess_generator
上级
fb18671b
78001ac9
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
245 addition
and
2 deletion
+245
-2
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+146
-2
tests/ut/python/dataset/test_generator.py
tests/ut/python/dataset/test_generator.py
+99
-0
未找到文件。
mindspore/dataset/engine/datasets.py
浏览文件 @
161bb240
...
...
@@ -25,6 +25,7 @@ import os
import
random
import
uuid
import
multiprocessing
import
queue
from
enum
import
Enum
from
importlib
import
import_module
...
...
@@ -2124,6 +2125,142 @@ def _cpp_sampler_fn(sampler, dataset):
yield
tuple
([
np
.
array
(
x
)
for
x
in
val
])
def
_cpp_sampler_fn_mp
(
sampler
,
dataset
,
num_worker
):
"""
Multiprocessing generator function wrapper for mappable dataset with cpp sampler
"""
indices
=
sampler
.
get_indices
()
return
_sampler_fn_mp
(
indices
,
dataset
,
num_worker
)
def
_py_sampler_fn_mp
(
sampler
,
num_samples
,
dataset
,
num_worker
):
"""
Multiprocessing generator function wrapper for mappable dataset with python sampler
"""
indices
=
_fetch_py_sampler_indices
(
sampler
,
num_samples
)
return
_sampler_fn_mp
(
indices
,
dataset
,
num_worker
)
def
_fetch_py_sampler_indices
(
sampler
,
num_samples
):
"""
Indices fetcher for python sampler
"""
if
num_samples
is
not
None
:
sampler_iter
=
iter
(
sampler
)
ret
=
[]
for
_
in
range
(
num_samples
):
try
:
val
=
next
(
sampler_iter
)
ret
.
append
(
val
)
except
StopIteration
:
break
return
ret
return
[
i
for
i
in
sampler
]
def
_fill_worker_indices
(
workers
,
indices
,
idx
):
"""
Worker index queue filler, fill worker index queue in round robin order
"""
num_worker
=
len
(
workers
)
while
idx
<
len
(
indices
):
try
:
workers
[
idx
%
num_worker
].
put
(
indices
[
idx
])
idx
+=
1
except
queue
.
Full
:
break
return
idx
def
_sampler_fn_mp
(
indices
,
dataset
,
num_worker
):
"""
Multiprocessing generator function wrapper master process
"""
workers
=
[]
# Event for end of epoch
eoe
=
multiprocessing
.
Event
()
# Create workers
for
_
in
range
(
num_worker
):
worker
=
_GeneratorWorker
(
dataset
,
eoe
)
worker
.
daemon
=
True
workers
.
append
(
worker
)
# Fill initial index queues
idx_cursor
=
0
idx_cursor
=
_fill_worker_indices
(
workers
,
indices
,
idx_cursor
)
# Start all workers
for
w
in
workers
:
w
.
start
()
# Fetch results
for
i
in
range
(
len
(
indices
)):
# Fetch result and put index
try
:
result
=
workers
[
i
%
num_worker
].
get
()
except
queue
.
Empty
:
raise
Exception
(
"Generator worker process timeout"
)
except
KeyboardInterrupt
:
for
w
in
workers
:
w
.
terminate
()
w
.
join
()
raise
Exception
(
"Generator worker receives KeyboardInterrupt"
)
if
idx_cursor
<
len
(
indices
):
idx_cursor
=
_fill_worker_indices
(
workers
,
indices
,
idx_cursor
)
# Set eoe event once all indices are sent
if
idx_cursor
==
len
(
indices
)
and
not
eoe
.
is_set
():
eoe
.
set
()
yield
tuple
([
np
.
array
(
x
)
for
x
in
result
])
def
_generator_worker_loop
(
dataset
,
idx_queue
,
result_queue
,
eoe
):
"""
Multiprocessing generator worker process loop
"""
while
True
:
# Fetch index, block
try
:
idx
=
idx_queue
.
get
()
except
KeyboardInterrupt
:
raise
Exception
(
"Generator worker receives KeyboardInterrupt"
)
if
idx
is
None
:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert
eoe
.
is_set
(),
""
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result
=
dataset
[
idx
]
# Send data, block
try
:
result_queue
.
put
(
result
)
except
KeyboardInterrupt
:
raise
Exception
(
"Generator worker receives KeyboardInterrupt"
)
del
result
,
idx
class
_GeneratorWorker
(
multiprocessing
.
Process
):
"""
Worker process for multiprocess Generator
"""
def
__init__
(
self
,
dataset
,
eoe
):
self
.
idx_queue
=
multiprocessing
.
Queue
(
16
)
self
.
res_queue
=
multiprocessing
.
Queue
(
16
)
super
().
__init__
(
target
=
_generator_worker_loop
,
args
=
(
dataset
,
self
.
idx_queue
,
self
.
res_queue
,
eoe
))
def
put
(
self
,
item
):
"""
Put function for worker index queue. Never block. Raise queue.Full on failure.
"""
self
.
idx_queue
.
put_nowait
(
item
)
def
get
(
self
):
"""
Get function for worker result queue. Block with timeout.
"""
return
self
.
res_queue
.
get
(
timeout
=
5
)
class
GeneratorDataset
(
SourceDataset
):
"""
A source dataset that generate data from python by invoking python data source each epoch.
...
...
@@ -2171,6 +2308,7 @@ class GeneratorDataset(SourceDataset):
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, all images).
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
(default=None, expected order behavior shown in the table).
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
...
...
@@ -2229,7 +2367,13 @@ class GeneratorDataset(SourceDataset):
sampler_instance
.
set_num_rows
(
len
(
source
))
sampler_instance
.
set_num_samples
(
num_samples
)
sampler_instance
.
initialize
()
if
num_parallel_workers
>
1
:
self
.
source
=
(
lambda
:
_cpp_sampler_fn_mp
(
sampler_instance
,
source
,
num_parallel_workers
))
else
:
self
.
source
=
(
lambda
:
_cpp_sampler_fn
(
sampler_instance
,
source
))
else
:
if
num_parallel_workers
>
1
:
self
.
source
=
(
lambda
:
_py_sampler_fn_mp
(
self
.
sampler
,
num_samples
,
source
,
num_parallel_workers
))
else
:
self
.
source
=
(
lambda
:
_py_sampler_fn
(
self
.
sampler
,
num_samples
,
source
))
else
:
...
...
This diff is collapsed.
Click to expand it.
tests/ut/python/dataset/test_generator.py
浏览文件 @
161bb240
...
...
@@ -391,6 +391,80 @@ def test_case_13():
i
=
i
+
1
def
test_case_14
():
"""
Test 1D Generator MP + CPP sampler
"""
logger
.
info
(
"Test 1D Generator MP : 0 - 63"
)
source
=
[(
np
.
array
([
x
]),)
for
x
in
range
(
256
)]
ds1
=
ds
.
GeneratorDataset
(
source
,
[
"data"
],
sampler
=
ds
.
SequentialSampler
(),
num_parallel_workers
=
4
).
repeat
(
2
)
i
=
0
for
data
in
ds1
.
create_dict_iterator
():
# each data is a dictionary
golden
=
np
.
array
([
i
])
assert
np
.
array_equal
(
data
[
"data"
],
golden
)
i
=
i
+
1
if
i
==
256
:
i
=
0
def
test_case_15
():
"""
Test 1D Generator MP + Python sampler
"""
logger
.
info
(
"Test 1D Generator MP : 0 - 63"
)
sampler
=
[
x
for
x
in
range
(
256
)]
source
=
[(
np
.
array
([
x
]),)
for
x
in
range
(
256
)]
ds1
=
ds
.
GeneratorDataset
(
source
,
[
"data"
],
sampler
=
sampler
,
num_parallel_workers
=
4
).
repeat
(
2
)
i
=
0
for
data
in
ds1
.
create_dict_iterator
():
# each data is a dictionary
golden
=
np
.
array
([
i
])
assert
np
.
array_equal
(
data
[
"data"
],
golden
)
i
=
i
+
1
if
i
==
256
:
i
=
0
def
test_case_16
():
"""
Test multi column generator Mp + CPP sampler
"""
logger
.
info
(
"Test multi column generator"
)
source
=
[(
np
.
array
([
x
]),
np
.
array
([
x
+
1
]))
for
x
in
range
(
256
)]
# apply dataset operations
data1
=
ds
.
GeneratorDataset
(
source
,
[
"col0"
,
"col1"
],
sampler
=
ds
.
SequentialSampler
())
i
=
0
for
item
in
data1
.
create_dict_iterator
():
# each data is a dictionary
golden
=
np
.
array
([
i
])
assert
np
.
array_equal
(
item
[
"col0"
],
golden
)
golden
=
np
.
array
([
i
+
1
])
assert
np
.
array_equal
(
item
[
"col1"
],
golden
)
i
=
i
+
1
def
test_case_17
():
"""
Test multi column generator Mp + Python sampler
"""
logger
.
info
(
"Test multi column generator"
)
sampler
=
[
x
for
x
in
range
(
256
)]
source
=
[(
np
.
array
([
x
]),
np
.
array
([
x
+
1
]))
for
x
in
range
(
256
)]
# apply dataset operations
data1
=
ds
.
GeneratorDataset
(
source
,
[
"col0"
,
"col1"
],
sampler
=
sampler
)
i
=
0
for
item
in
data1
.
create_dict_iterator
():
# each data is a dictionary
golden
=
np
.
array
([
i
])
assert
np
.
array_equal
(
item
[
"col0"
],
golden
)
golden
=
np
.
array
([
i
+
1
])
assert
np
.
array_equal
(
item
[
"col1"
],
golden
)
i
=
i
+
1
def
test_case_error_1
():
def
generator_np
():
for
i
in
range
(
64
):
...
...
@@ -506,6 +580,25 @@ def test_num_samples_underflow():
count
=
count
+
1
assert
count
==
64
def
manual_test_keyborad_interrupt
():
"""
Test keyborad_interrupt
"""
logger
.
info
(
"Test 1D Generator MP : 0 - 63"
)
class
MyDS
():
def
__getitem__
(
self
,
item
):
while
True
:
pass
def
__len__
(
self
):
return
1024
ds1
=
ds
.
GeneratorDataset
(
MyDS
(),
[
"data"
],
num_parallel_workers
=
4
).
repeat
(
2
)
i
=
0
for
data
in
ds1
.
create_dict_iterator
():
# each data is a dictionary
pass
if
__name__
==
"__main__"
:
test_case_0
()
...
...
@@ -522,6 +615,10 @@ if __name__ == "__main__":
test_case_11
()
test_case_12
()
test_case_13
()
test_case_14
()
test_case_15
()
test_case_16
()
test_case_17
()
test_case_error_1
()
test_case_error_2
()
test_case_error_3
()
...
...
@@ -529,3 +626,5 @@ if __name__ == "__main__":
test_sequential_sampler
()
test_distributed_sampler
()
test_random_sampler
()
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部