Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
edc92ccf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
edc92ccf
编写于
8月 22, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(imperative/data): improve dataloader preformance
GitOrigin-RevId: 7d8d52aaeb47e7ec6c3efa282ff9014a4b7d1f01
上级
896b0193
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
524 addition
and
732 deletion
+524
-732
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+363
-563
imperative/python/megengine/data/sampler.py
imperative/python/megengine/data/sampler.py
+20
-9
imperative/python/test/unit/data/test_dataloader.py
imperative/python/test/unit/data/test_dataloader.py
+106
-73
imperative/python/test/unit/data/test_pre_dataloader.py
imperative/python/test/unit/data/test_pre_dataloader.py
+35
-87
未找到文件。
imperative/python/megengine/data/dataloader.py
浏览文件 @
edc92ccf
# -*- coding: utf-8 -*-
import
collections
import
gc
import
math
import
itertools
import
multiprocessing
import
os
import
platform
...
...
@@ -9,7 +9,6 @@ import queue
import
random
import
threading
import
time
from
typing
import
Callable
,
Union
import
numpy
as
np
...
...
@@ -67,12 +66,6 @@ class DataLoader:
the batch. ``0`` means using single-process. Default: 0
timeout: if positive, means the timeout value(second) for collecting a
batch from workers. Default: 0
timeout_event: callback function triggered by timeout, default to raise
runtime error.
divide: define the paralleling strategy in multi-processing mode.
``True`` means one batch is divided into :attr:`num_workers` pieces, and
the workers will process these pieces parallelly. ``False`` means
different sub-process will process different batch. Default: False
preload: whether to enable the preloading strategy of the dataloader.
When enabling, the dataloader will preload one batch to the device memory to speed up the whole training process.
...
...
@@ -85,7 +78,6 @@ class DataLoader:
which will improve the training speed at the cost of **higher device memory usage** (due to one more batch data on device memory).
This feature saves more time when your NN training time is short or your machine's host PCIe bandwidth for each device is low.
"""
__initialized
=
False
def
__init__
(
self
,
...
...
@@ -95,9 +87,8 @@ class DataLoader:
collator
:
Collator
=
None
,
num_workers
:
int
=
0
,
timeout
:
int
=
0
,
timeout_event
:
Callable
=
_raise_timeout_error
,
divide
:
bool
=
False
,
preload
:
bool
=
False
,
parallel_stream
:
bool
=
False
,
):
if
num_workers
<
0
:
raise
ValueError
(
"num_workers should not be negative"
)
...
...
@@ -105,23 +96,22 @@ class DataLoader:
if
timeout
<
0
:
raise
ValueError
(
"timeout should not be negative"
)
if
divide
and
num_workers
<=
1
:
raise
ValueError
(
"divide should not be set to True when num_workers <= 1"
)
self
.
dataset
=
dataset
self
.
num_workers
=
num_workers
self
.
timeout
=
timeout
self
.
timeout_event
=
timeout_event
self
.
divide
=
divide
self
.
preload
=
preload
self
.
parallel_stream
=
parallel_stream
if
isinstance
(
dataset
,
StreamDataset
):
self
.
sampler
=
sampler
if
sampler
else
StreamSampler
(
batch_size
=
1
)
assert
isinstance
(
self
.
sampler
,
StreamSampler
),
"types of dataset and sampler do not match"
if
parallel_stream
is
False
and
self
.
num_workers
>
1
:
logger
.
warning
(
"Data time will be affected by getting origin-data, please set parallel_stream in order to speed up dataloader!"
)
self
.
datakind
=
"stream"
else
:
assert
isinstance
(
dataset
,
Dataset
...
...
@@ -134,16 +124,7 @@ class DataLoader:
assert
isinstance
(
self
.
sampler
,
MapSampler
),
"types of dataset and sampler do not match"
if
divide
:
if
self
.
sampler
.
batch_size
<=
self
.
num_workers
:
raise
ValueError
(
"batch size must not smaller than num_workers in divide mode."
)
elif
self
.
sampler
.
batch_size
%
self
.
num_workers
:
logger
.
warning
(
"batch size is not divisible by num_workers, may lose performance in divide mode."
)
self
.
datakind
=
"map"
if
transform
is
None
:
self
.
transform
=
PseudoTransform
()
...
...
@@ -155,7 +136,8 @@ class DataLoader:
else
:
self
.
collator
=
collator
self
.
__initialized
=
True
if
platform
.
system
()
==
"Linux"
and
self
.
num_workers
>
0
:
self
.
check_memory_rationality
()
def
__iter__
(
self
):
if
platform
.
system
()
==
"Windows"
and
self
.
num_workers
>
0
:
...
...
@@ -187,15 +169,50 @@ class DataLoader:
def
__len__
(
self
):
return
len
(
self
.
sampler
)
def
check_memory_rationality
(
self
):
import
psutil
main_memory
=
psutil
.
Process
(
os
.
getpid
()).
memory_info
().
rss
/
1024
/
1024
/
1024
total_memory
=
(
self
.
num_workers
+
1
)
*
main_memory
current_memory
=
(
int
(
os
.
popen
(
"cat /sys/fs/cgroup/memory/memory.limit_in_bytes"
).
read
())
/
1024
/
1024
/
1024
)
if
current_memory
<
total_memory
:
logger
.
warning
(
"Each worker need to read the shared meta-data, which will be increasing the reference count."
"Copy-On-Write propety will lead to 'memory leak', the memory usage will end up being "
+
total_memory
+
" GB"
"However the current requested memory is "
+
current_memory
+
" GB"
"Maybe you can request more memory or uesd np-array to save meta-data rather than List or Tuple"
)
class
_PreLoader
:
def
__init__
(
self
,
preload
):
def
__init__
(
self
,
loader
,
preload
):
self
.
dataset
=
loader
.
dataset
self
.
sampler
=
loader
.
sampler
self
.
seed
=
_random_seed_generator
().
__next__
()
self
.
transform
=
loader
.
transform
self
.
collator
=
loader
.
collator
self
.
num_workers
=
loader
.
num_workers
self
.
timeout
=
loader
.
timeout
self
.
num_processed
=
0
self
.
datakind
=
loader
.
datakind
self
.
parallel_stream
=
loader
.
parallel_stream
if
preload
:
self
.
default_device
=
get_default_device
()
self
.
pre_load_device
=
self
.
default_device
+
":"
+
str
(
_sh
.
get_next
())
self
.
pre_load_device_cache
=
None
self
.
preload
=
preload
def
__iter__
(
self
):
return
self
"""
strategy one: load from numpy data, and generate dtype tensor
"""
...
...
@@ -237,29 +254,176 @@ class _PreLoader:
return
out
class
_BaseMapDataLoaderIter
(
_PreLoader
):
def
__init__
(
self
,
loader
,
preload
):
super
().
__init__
(
preload
)
self
.
dataset
=
loader
.
dataset
self
.
sampler
=
loader
.
sampler
self
.
seed
=
_random_seed_generator
().
__next__
()
self
.
transform
=
loader
.
transform
self
.
collator
=
loader
.
collator
self
.
num_workers
=
loader
.
num_workers
self
.
timeout
=
loader
.
timeout
self
.
timeout_event
=
loader
.
timeout_event
self
.
divide
=
loader
.
divide
self
.
num_processed
=
0
class
_ParallelDataLoaderIter
:
def
__init__
(
self
):
self
.
_worker_queue_idx_cycle
=
itertools
.
cycle
(
range
(
self
.
num_workers
))
from
.tools._queue
import
PlasmaShmQueue
def
_get_next_batch
(
self
):
self
.
_worker_result_queue
=
PlasmaShmQueue
()
self
.
_shutdown
=
False
self
.
_workers_done_event
=
multiprocessing
.
Event
()
self
.
_index_queues
=
[]
self
.
_workers
=
[]
for
i
in
range
(
self
.
num_workers
):
index_queue
=
multiprocessing
.
Queue
()
index_queue
.
cancel_join_thread
()
w
=
multiprocessing
.
Process
(
target
=
_worker_loop
,
args
=
(
self
.
dataset
,
index_queue
,
self
.
_worker_result_queue
,
self
.
_workers_done_event
,
self
.
transform
,
self
.
collator
,
self
.
sampler
.
batch_size
,
self
.
seed
+
i
,
i
,
self
.
num_workers
,
self
.
datakind
,
self
.
parallel_stream
,
),
daemon
=
True
,
)
gc
.
collect
()
w
.
start
()
self
.
_index_queues
.
append
(
index_queue
)
self
.
_workers
.
append
(
w
)
self
.
_data_queue
=
self
.
_worker_result_queue
self
.
_reset
()
def
_try_put_index
(
self
):
raise
NotImplementedError
def
_reset
(
self
):
self
.
_sampler_iter
=
iter
(
self
.
sampler
)
self
.
_send_idx
=
0
self
.
_rcvd_idx
=
0
self
.
_task_info
=
{}
self
.
_workers_status
=
[
True
for
_
in
range
(
self
.
num_workers
)]
for
_
in
range
(
2
*
self
.
num_workers
):
self
.
_try_put_index
()
def
_process_data
(
self
,
data
):
self
.
_rcvd_idx
+=
1
self
.
_try_put_index
()
return
data
def
_get_data
(
self
):
if
self
.
timeout
>
0
:
success
,
data
=
self
.
_try_get_data
(
self
.
timeout
)
if
success
:
return
data
else
:
_raise_timeout_error
()
else
:
while
True
:
success
,
data
=
self
.
_try_get_data
()
if
success
:
return
data
def
_get_next_batch
(
self
):
while
True
:
while
self
.
_rcvd_idx
<
self
.
_send_idx
:
info
=
self
.
_task_info
[
self
.
_rcvd_idx
]
worker_id
=
info
[
0
]
if
(
len
(
info
)
==
2
or
self
.
_workers_status
[
worker_id
]
):
# has data or work is still active
break
del
self
.
_task_info
[
self
.
_rcvd_idx
]
self
.
_rcvd_idx
+=
1
else
:
self
.
_shutdown_workers
()
raise
StopIteration
if
len
(
self
.
_task_info
[
self
.
_rcvd_idx
])
==
2
:
data
=
self
.
_task_info
.
pop
(
self
.
_rcvd_idx
)[
1
]
return
self
.
_process_data
(
data
)
idx
,
data
=
self
.
_get_data
()
if
isinstance
(
data
,
int
):
# Check if StopIteration in StreamDataset
self
.
_mark_worker_as_unavailable
(
data
)
self
.
_try_put_index
()
continue
if
idx
!=
self
.
_rcvd_idx
:
self
.
_task_info
[
idx
]
+=
(
data
,)
else
:
del
self
.
_task_info
[
idx
]
return
self
.
_process_data
(
data
)
def
_try_get_data
(
self
,
timeout
=
GLOBAL_TIMEOUT
):
try
:
data
=
self
.
_data_queue
.
get
(
timeout
=
timeout
)
return
(
True
,
data
)
except
Exception
as
e
:
failed_workers
=
[]
for
worker_id
,
w
in
enumerate
(
self
.
_workers
):
if
self
.
_workers_status
[
worker_id
]
and
not
w
.
is_alive
():
failed_workers
.
append
((
worker_id
,
w
))
self
.
_mark_worker_as_unavailable
(
worker_id
)
if
w
.
exitcode
==
-
9
:
logger
.
debug
(
"Maybe memory is not enough, please request for more memory!"
)
if
len
(
failed_workers
)
>
0
:
pids_str
=
", "
.
join
(
str
(
w_info
[
1
].
pid
)
for
w_info
in
failed_workers
)
w_ids_str
=
", "
.
join
(
str
(
w_info
[
0
])
for
w_info
in
failed_workers
)
exitcode_str
=
", "
.
join
(
str
(
w_info
[
1
].
exitcode
)
for
w_info
in
failed_workers
)
raise
RuntimeError
(
"DataLoader worker (worker(s): {} , pid(s): {}) exited unexpectedly, exitcode(s): {}"
.
format
(
w_ids_str
,
pids_str
,
exitcode_str
)
)
if
isinstance
(
e
,
queue
.
Empty
):
return
(
False
,
None
)
def
_mark_worker_as_unavailable
(
self
,
worker_id
,
shutdown
=
False
):
q
=
self
.
_index_queues
[
worker_id
]
q
.
put
(
None
)
self
.
_workers_status
[
worker_id
]
=
False
assert
self
.
_workers_done_event
.
is_set
()
==
shutdown
def
_shutdown_workers
(
self
):
if
not
self
.
_shutdown
:
self
.
_shutdown
=
True
try
:
self
.
_workers_done_event
.
set
()
for
worker_id
in
range
(
len
(
self
.
_workers
)):
if
self
.
_workers_status
[
worker_id
]:
self
.
_mark_worker_as_unavailable
(
worker_id
,
shutdown
=
True
)
for
w
in
self
.
_workers
:
w
.
join
(
timeout
=
GLOBAL_TIMEOUT
)
for
q
in
self
.
_index_queues
:
q
.
cancel_join_thread
()
q
.
close
()
self
.
_data_queue
.
cancel_join_thread
()
self
.
_data_queue
.
close
()
finally
:
for
w
in
self
.
_workers
:
if
w
.
is_alive
():
w
.
terminate
()
def
__del__
(
self
):
self
.
_shutdown_workers
()
class
_BaseMapDataLoaderIter
(
_PreLoader
):
def
__init__
(
self
,
loader
,
preload
):
super
().
__init__
(
loader
,
preload
)
def
__len__
(
self
):
return
len
(
self
.
sampler
)
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
if
self
.
preload
:
cached
=
self
.
pre_load_device_cache
...
...
@@ -272,11 +436,8 @@ class _BaseMapDataLoaderIter(_PreLoader):
self
.
_try_load_tensor
()
return
out
else
:
if
self
.
num_processed
>=
len
(
self
):
raise
StopIteration
minibatch
=
self
.
_get_next_batch
()
self
.
num_processed
+=
1
return
minibatch
data
=
self
.
_get_next_batch
()
return
data
def
_try_load_tensor
(
self
,
cached
=
True
):
if
self
.
num_processed
>=
len
(
self
):
...
...
@@ -290,199 +451,69 @@ class _BaseMapDataLoaderIter(_PreLoader):
class
_SerialMapDataLoaderIter
(
_BaseMapDataLoaderIter
):
def
__init__
(
self
,
loader
,
preload
):
super
(
_SerialMapDataLoaderIter
,
self
).
__init__
(
loader
,
preload
)
self
.
indices
_iter
=
iter
(
self
.
sampler
)
self
.
_sampler
_iter
=
iter
(
self
.
sampler
)
def
_get_next_batch
(
self
):
indices
=
next
(
self
.
indices
_iter
)
indices
=
next
(
self
.
_sampler
_iter
)
items
=
[
self
.
dataset
[
idx
]
for
idx
in
indices
]
trans_items
=
self
.
transform
.
apply_batch
(
items
)
return
self
.
collator
.
apply
(
trans_items
)
class
_ParallelMapDataLoaderIter
(
_BaseMapDataLoaderIter
):
__initialized
=
False
class
_ParallelMapDataLoaderIter
(
_BaseMapDataLoaderIter
,
_ParallelDataLoaderIter
):
def
__init__
(
self
,
loader
,
preload
):
super
(
_ParallelMapDataLoaderIter
,
self
).
__init__
(
loader
,
preload
)
_BaseMapDataLoaderIter
.
__init__
(
self
,
loader
,
preload
)
_ParallelDataLoaderIter
.
__init__
(
self
)
self
.
task_queues
=
[
multiprocessing
.
Queue
(
maxsize
=
2
)
for
_
in
range
(
self
.
num_workers
)
]
self
.
feed_batch_idx
=
multiprocessing
.
Value
(
"i"
,
0
)
self
.
target_batch_idx
=
multiprocessing
.
Value
(
"i"
,
0
)
self
.
shutdown_flag
=
multiprocessing
.
Value
(
"i"
,
0
)
def
_try_put_index
(
self
):
try
:
index
=
next
(
self
.
_sampler_iter
)
except
StopIteration
:
return
for
_
in
range
(
self
.
num_workers
):
# find the next active worker, if any
worker_queue_idx
=
next
(
self
.
_worker_queue_idx_cycle
)
if
self
.
_workers_status
[
worker_queue_idx
]:
break
self
.
trans_data_queues
=
[
multiprocessing
.
Queue
(
maxsize
=
1
)
for
_
in
range
(
self
.
num_workers
)
]
self
.
_index_queues
[
worker_queue_idx
].
put
((
self
.
_send_idx
,
index
))
self
.
_task_info
[
self
.
_send_idx
]
=
(
worker_queue_idx
,
)
self
.
_send_idx
+=
1
# use shared-memory queue implemented by pyarrow plasma store.
from
.tools._queue
import
PlasmaShmQueue
self
.
batch_queue
=
PlasmaShmQueue
(
maxsize
=
2
)
_worker_info
=
None
self
.
task_feeding_worker
=
multiprocessing
.
Process
(
target
=
_task_feeding_loop
,
args
=
(
iter
(
self
.
sampler
),
self
.
task_queues
,
self
.
num_workers
,
self
.
divide
,
self
.
shutdown_flag
,
self
.
feed_batch_idx
,
),
daemon
=
True
,
)
gc
.
collect
()
self
.
task_feeding_worker
.
start
()
self
.
workers
=
[]
for
worker_id
in
range
(
self
.
num_workers
):
worker
=
multiprocessing
.
Process
(
target
=
_worker_loop
,
args
=
(
self
.
dataset
,
self
.
task_queues
[
worker_id
],
self
.
trans_data_queues
[
worker_id
],
self
.
transform
,
self
.
seed
+
worker_id
+
1
,
self
.
shutdown_flag
,
),
daemon
=
True
,
)
gc
.
collect
()
worker
.
start
()
self
.
workers
.
append
(
worker
)
if
self
.
divide
:
self
.
data_collecting_worker
=
multiprocessing
.
Process
(
target
=
_data_gathering_loop
,
args
=
(
self
.
trans_data_queues
,
self
.
batch_queue
,
self
.
collator
,
len
(
self
),
self
.
num_workers
,
self
.
shutdown_flag
,
self
.
target_batch_idx
,
),
daemon
=
True
,
)
else
:
self
.
data_collecting_worker
=
multiprocessing
.
Process
(
target
=
_data_selecting_loop
,
args
=
(
self
.
trans_data_queues
,
self
.
batch_queue
,
self
.
collator
,
len
(
self
),
self
.
num_workers
,
self
.
shutdown_flag
,
self
.
target_batch_idx
,
),
daemon
=
True
,
)
gc
.
collect
()
self
.
data_collecting_worker
.
start
()
class
WorkerInfo
(
object
):
__initialized
=
False
def
__init__
(
self
,
**
kwargs
):
for
k
,
v
in
kwargs
.
items
():
setattr
(
self
,
k
,
v
)
self
.
__keys
=
tuple
(
kwargs
.
keys
())
self
.
__initialized
=
True
def
_check_workers
(
self
):
# Check the status of each worker.
if
not
self
.
data_collecting_worker
.
is_alive
():
exitcode
=
self
.
data_collecting_worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"data collecting worker died. {}"
.
format
(
exitcode
))
if
not
self
.
task_feeding_worker
.
is_alive
():
exitcode
=
self
.
task_feeding_worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"task feeding worker died. {}"
.
format
(
exitcode
))
for
worker_id
,
worker
in
enumerate
(
self
.
workers
):
if
not
worker
.
is_alive
():
exitcode
=
worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"worker:{} died. {}"
.
format
(
worker_id
,
exitcode
))
logger
.
debug
(
"all workers are alive."
)
def
_get_next_batch
(
self
):
start_time
=
time
.
time
()
while
True
:
self
.
_check_workers
()
try
:
return
self
.
batch_queue
.
get
(
timeout
=
1
)
except
queue
.
Empty
:
logger
.
debug
(
"batch queue empty!"
)
waited_time
=
time
.
time
()
-
start_time
if
self
.
timeout
>
0
:
if
waited_time
>
self
.
timeout
:
raise
RuntimeError
(
"get_next_batch timeout!"
)
def
_shutdown
(
self
):
with
self
.
shutdown_flag
.
get_lock
():
self
.
shutdown_flag
.
value
=
1
if
self
.
task_feeding_worker
.
is_alive
():
self
.
task_feeding_worker
.
terminate
()
self
.
task_feeding_worker
.
join
()
if
self
.
data_collecting_worker
.
is_alive
():
self
.
data_collecting_worker
.
terminate
()
self
.
data_collecting_worker
.
join
()
def
__setattr__
(
self
,
key
,
val
):
if
self
.
__initialized
:
raise
RuntimeError
(
"Cannot assign attributes to {} objects"
.
format
(
self
.
__class__
.
__name__
)
)
return
super
(
WorkerInfo
,
self
).
__setattr__
(
key
,
val
)
for
worker
in
self
.
workers
:
if
worker
.
is_alive
():
worker
.
terminate
()
worker
.
join
()
def
__repr__
(
self
):
items
=
[]
for
k
in
self
.
__keys
:
items
.
append
(
"{}={}"
.
format
(
k
,
getattr
(
self
,
k
)))
return
"{}({})"
.
format
(
self
.
__class__
.
__name__
,
", "
.
join
(
items
))
for
q
in
self
.
trans_data_queues
:
q
.
cancel_join_thread
()
q
.
close
()
for
q
in
self
.
task_queues
:
q
.
cancel_join_thread
()
q
.
close
()
self
.
batch_queue
.
cancel_join_thread
()
self
.
batch_queue
.
close
()
def
__del__
(
self
):
if
self
.
__initialized
:
self
.
_shutdown
()
def
get_worker_info
():
return
_worker_info
class
_BaseStreamDataLoaderIter
(
_PreLoader
):
def
__init__
(
self
,
loader
,
preload
):
super
().
__init__
(
preload
)
self
.
dataset
=
loader
.
dataset
self
.
sampler
=
loader
.
sampler
self
.
transform
=
loader
.
transform
self
.
collator
=
loader
.
collator
self
.
num_workers
=
loader
.
num_workers
self
.
timeout
=
loader
.
timeout
self
.
timeout_event
=
loader
.
timeout_event
def
_get_next_batch
(
self
):
raise
NotImplementedError
def
_process_raw_data
(
self
,
raw_data
):
assert
len
(
raw_data
)
==
2
and
isinstance
(
raw_data
[
0
],
bool
),
"StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if
not
raw_data
[
0
]:
data
=
list
((
x
,)
for
x
in
raw_data
[
1
])
else
:
data
=
raw_data
[
1
]
ret
=
[]
for
idx
in
range
(
len
(
data
[
0
])):
ret
.
append
(
tuple
(
e
[
idx
]
for
e
in
data
))
return
ret
def
__iter__
(
self
):
return
self
super
().
__init__
(
loader
,
preload
)
self
.
dataset_iter
=
iter
(
self
.
dataset
)
def
__next__
(
self
):
if
self
.
preload
:
...
...
@@ -503,8 +534,6 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
def
__init__
(
self
,
loader
,
preload
):
super
().
__init__
(
loader
,
preload
)
self
.
dataset_iter
=
iter
(
self
.
dataset
)
self
.
idx
=
0
self
.
unused
=
[]
def
_try_get_raw_data
(
self
,
start_time
):
raw_data
=
None
...
...
@@ -516,382 +545,153 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
raw_data
=
next
(
self
.
dataset_iter
)
if
self
.
timeout
>
0
:
timer
.
cancel
()
except
KeyboardInterrupt
:
ra
w_data
=
self
.
timeout_event
()
except
AttributeError
as
error
:
ra
ise
error
except
:
if
self
.
timeout
>
0
:
timer
.
cancel
()
waited_time
=
time
.
time
()
-
start_time
if
waited_time
>
self
.
timeout
:
raw_data
=
self
.
timeout_event
()
_raise_timeout_error
()
return
raw_data
def
_get_next_batch
(
self
):
ret
=
[]
start_time
=
time
.
time
()
while
len
(
ret
)
<
self
.
sampler
.
batch_size
:
if
len
(
self
.
unused
)
!=
0
:
batch_data
=
self
.
unused
else
:
raw_data
=
self
.
_try_get_raw_data
(
start_time
)
batch_data
=
self
.
_process_raw_data
(
raw_data
)
while
len
(
batch_data
)
!=
0
and
len
(
ret
)
<
self
.
sampler
.
batch_size
:
data
=
batch_data
.
pop
()
ret
.
append
(
self
.
transform
.
apply
(
data
))
self
.
unused
=
batch_data
ret
.
append
(
self
.
transform
.
apply
(
raw_data
))
return
self
.
collator
.
apply
(
ret
)
class
_ParallelStreamDataLoaderIter
(
_BaseStreamDataLoaderIter
):
__initialized
=
False
class
_ParallelStreamDataLoaderIter
(
_BaseStreamDataLoaderIter
,
_ParallelDataLoaderIter
):
def
__init__
(
self
,
loader
,
preload
):
super
().
__init__
(
loader
,
preload
)
self
.
shutdown_flag
=
multiprocessing
.
Value
(
"i"
,
0
)
self
.
raw_data_queues
=
[
multiprocessing
.
Queue
(
maxsize
=
1
)
for
_
in
range
(
self
.
num_workers
)
]
self
.
trans_data_queues
=
[
multiprocessing
.
Queue
(
maxsize
=
1
)
for
_
in
range
(
self
.
num_workers
)
]
# shared-memory queue implemented by pyarrow plasma store
from
.tools._queue
import
PlasmaShmQueue
self
.
batch_queue
=
PlasmaShmQueue
(
maxsize
=
2
)
self
.
recieve_worker
=
multiprocessing
.
Process
(
target
=
self
.
_worker_to_raw_data_queues
,
daemon
=
True
)
gc
.
collect
()
self
.
recieve_worker
.
start
()
self
.
transform_workers
=
[]
for
worker_id
in
range
(
self
.
num_workers
):
worker
=
multiprocessing
.
Process
(
target
=
self
.
_worker_to_trans_data_queues
,
args
=
(
worker_id
,),
daemon
=
True
)
gc
.
collect
()
worker
.
start
()
self
.
transform_workers
.
append
(
worker
)
self
.
collect_worker
=
multiprocessing
.
Process
(
target
=
self
.
_worker_to_batch_queue
,
daemon
=
True
)
gc
.
collect
()
self
.
collect_worker
.
start
()
self
.
__initialized
=
True
def
_put_raw_data_queues
(
self
,
raw_data
,
qidx
):
batch_data
=
self
.
_process_raw_data
(
raw_data
)
for
data
in
batch_data
:
while
True
:
qidx
=
qidx
%
self
.
num_workers
try
:
self
.
raw_data_queues
[
qidx
].
put
(
data
)
break
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"raw data queue %d is full"
%
qidx
)
finally
:
qidx
+=
1
return
qidx
_BaseStreamDataLoaderIter
.
__init__
(
self
,
loader
,
preload
)
_ParallelDataLoaderIter
.
__init__
(
self
)
def
_worker_to_raw_data_queues
(
self
):
dataset_iter
=
iter
(
self
.
dataset
)
qidx
=
0
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
break
raw_data
=
next
(
dataset_iter
)
qidx
=
self
.
_put_raw_data_queues
(
raw_data
,
qidx
)
def
_get_remaind_data
(
self
,
place_holder
):
num
=
self
.
sampler
.
batch_size
for
_
in
range
(
num
-
1
):
place_holder
.
append
(
next
(
self
.
dataset_iter
))
return
place_holder
def
_worker_to_trans_data_queues
(
self
,
worker_id
):
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
break
try
:
data
=
self
.
raw_data_queues
[
worker_id
].
get
(
timeout
=
GLOBAL_TIMEOUT
)
except
queue
.
Empty
:
continue
trans_data
=
self
.
transform
.
apply
(
data
)
while
True
:
try
:
self
.
trans_data_queues
[
worker_id
].
put
(
trans_data
)
break
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"batch queue if full"
)
def
_worker_to_batch_queue
(
self
):
cnt
=
-
1
trans_items
=
[]
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
break
cnt
+=
1
queue_id
=
cnt
%
self
.
num_workers
try
:
trans_item
=
self
.
trans_data_queues
[
queue_id
].
get
(
timeout
=
GLOBAL_TIMEOUT
)
except
queue
.
Empty
:
continue
trans_items
.
append
(
trans_item
)
if
len
(
trans_items
)
==
self
.
sampler
.
batch_size
:
batch_data
=
self
.
collator
.
apply
(
trans_items
)
while
True
:
def
_try_put_index
(
self
):
try
:
self
.
batch_queue
.
put
(
batch_data
,
timeout
=
1
)
break
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"batch queue is full"
)
trans_items
=
[]
def
_check_workers
(
self
):
if
not
self
.
collect_worker
.
is_alive
():
exitcode
=
self
.
collect_worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"collator worker died. {}"
.
format
(
exitcode
))
for
worker_id
,
worker
in
enumerate
(
self
.
transform_workers
):
if
not
worker
.
is_alive
():
exitcode
=
worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"worker: {} died. {}"
.
format
(
worker_id
,
exitcode
)
)
def
_get_next_batch
(
self
):
if
self
.
parallel_stream
is
False
:
start_time
=
time
.
time
()
while
True
:
self
.
_check_workers
()
try
:
return
self
.
batch_queue
.
get
(
timeout
=
1
)
except
queue
.
Empty
:
logger
.
debug
(
"batch queue empty!"
)
place_holder
=
[
next
(
self
.
dataset_iter
)]
waited_time
=
time
.
time
()
-
start_time
if
self
.
timeout
>
0
and
waited_time
>
self
.
timeout
:
self
.
_put_raw_data_queues
(
self
.
timeout_event
(),
0
)
def
_shutdown
(
self
):
with
self
.
shutdown_flag
.
get_lock
():
self
.
shutdown_flag
.
value
=
1
if
self
.
recieve_worker
.
is_alive
():
self
.
recieve_worker
.
terminate
()
self
.
recieve_worker
.
join
()
if
self
.
collect_worker
.
is_alive
():
self
.
collect_worker
.
terminate
()
self
.
collect_worker
.
join
()
for
worker
in
self
.
transform_workers
:
if
worker
.
is_alive
():
worker
.
terminate
()
worker
.
join
()
for
q
in
self
.
raw_data_queues
:
q
.
cancel_join_thread
()
q
.
close
()
for
q
in
self
.
trans_data_queues
:
q
.
cancel_join_thread
()
q
.
close
()
self
.
batch_queue
.
cancel_join_thread
()
self
.
batch_queue
.
close
()
def
__del__
(
self
):
if
self
.
__initialized
:
self
.
_shutdown
()
def
_task_feeding_loop
(
indices_iter
,
task_queues
,
num_workers
,
divide
,
shutdown_flag
,
feed_batch_idx
):
# Feed the indices into the task queues
while
True
:
if
shutdown_flag
.
value
==
1
:
break
batch_idx
=
feed_batch_idx
.
value
try
:
indices
=
next
(
indices_iter
)
except
StopIteration
:
break
if
divide
:
# make sure all task_queues is ready for put
while
any
([
q
.
full
()
for
q
in
task_queues
]):
if
shutdown_flag
.
value
==
1
:
return
# divide into small pieces, feed to different workers.
sub_num
=
math
.
ceil
(
len
(
indices
)
/
num_workers
)
for
worker_id
in
range
(
num_workers
):
sub_indices
=
indices
[
worker_id
*
sub_num
:
(
worker_id
+
1
)
*
sub_num
]
task_queues
[
worker_id
].
put
((
batch_idx
,
sub_indices
))
_raise_timeout_error
()
place_holder
=
self
.
_get_remaind_data
(
place_holder
)
else
:
# distribute tasks to different workers uniformly.
target_id
=
batch_idx
%
num_workers
while
task_queues
[
target_id
].
full
():
if
shutdown_flag
.
value
==
1
:
place_holder
=
next
(
self
.
_sampler_iter
)
except
StopIteration
:
return
task_queues
[
target_id
].
put
((
batch_idx
,
indices
))
with
feed_batch_idx
.
get_lock
():
feed_batch_idx
.
value
+=
1
def
_worker_loop
(
dataset
,
task_queue
,
trans_data_queue
,
transform
,
seed
,
shutdown_flag
):
# Get dataset items and do the transform
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
while
True
:
if
shutdown_flag
.
value
==
1
:
for
_
in
range
(
self
.
num_workers
):
worker_queue_idx
=
next
(
self
.
_worker_queue_idx_cycle
)
if
self
.
_workers_status
[
worker_queue_idx
]:
break
try
:
batch_idx
,
indices
=
task_queue
.
get
(
timeout
=
GLOBAL_TIMEOUT
)
except
queue
.
Empty
:
continue
if
len
(
indices
)
>
0
:
items
=
[
dataset
[
idx
]
for
idx
in
indices
]
trans_items
=
transform
.
apply_batch
(
items
)
else
:
# in case of incomplete last batch
trans_items
=
()
while
True
:
try
:
trans_data_queue
.
put
((
batch_idx
,
trans_items
),
timeout
=
1
)
break
except
queue
.
Full
:
if
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"batch part queue is full!"
)
return
self
.
_index_queues
[
worker_queue_idx
].
put
((
self
.
_send_idx
,
place_holder
))
self
.
_task_info
[
self
.
_send_idx
]
=
(
worker_queue_idx
,)
self
.
_send_idx
+=
1
def
_data_gathering_loop
(
trans_data_queues
,
batch_queue
,
collator
,
length
,
num_workers
,
shutdown_flag
,
target_idx
,
):
# Gathering the small pieces of batch data into full batch data
while
True
:
if
shutdown_flag
.
value
==
1
:
break
target_batch_idx
=
target_idx
.
value
class
ManagerWatchdog
(
object
):
def
__init__
(
self
):
self
.
manager_pid
=
os
.
getppid
()
self
.
manager_dead
=
False
if
target_batch_idx
>=
length
:
break
def
is_alive
(
self
):
if
not
self
.
manager_dead
:
self
.
manager_dead
=
os
.
getppid
()
!=
self
.
manager_pid
return
not
self
.
manager_dead
full_trans_items
=
[]
for
worker_id
in
range
(
num_workers
):
while
True
:
def
stream_fetcher
(
dataset_iter
,
place_holder
,
transform
,
collate
,
parallel_stream
,
batch_size
):
data
=
[]
for
idx
in
place_holder
:
try
:
batch_idx
,
trans_items
=
trans_data_queues
[
worker_id
].
get
(
timeout
=
GLOBAL_TIMEOUT
)
break
except
queue
.
Empty
:
if
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"worker:{} data queue get timeout! target batch idx:{}"
.
format
(
worker_id
,
target_batch_idx
)
)
if
batch_idx
!=
target_batch_idx
:
raise
RuntimeError
(
"Unexperted batch_idx in data gathering loop. worker_id:{}."
.
format
(
worker_id
)
)
if
parallel_stream
is
False
:
raw_data
=
idx
else
:
full_trans_items
.
extend
(
trans_items
)
# Merge different parts into a batch.
full_batch
=
collator
.
apply
(
full_trans_items
)
raw_data
=
next
(
dataset_iter
)
trans_items
=
transform
.
apply
(
raw_data
)
data
.
append
(
trans_items
)
while
True
:
try
:
batch_queue
.
put
(
full_batch
,
timeout
=
1
)
break
except
queue
.
Full
:
if
shutdown_flag
.
value
==
1
:
except
StopIteration
:
break
logger
.
debug
(
"batch queue is full!"
)
with
target_idx
.
get_lock
():
target_idx
.
value
+=
1
if
len
(
data
)
==
0
:
raise
StopIteration
data
=
collate
.
apply
(
data
)
return
data
batch_queue
.
disconnect_client
()
def
map_fetcher
(
dataset
,
place_holder
,
transform
,
collate
,
parallel_stream
,
batch_size
):
items
=
[
dataset
[
idx
]
for
idx
in
place_holder
]
trans_items
=
transform
.
apply_batch
(
items
)
data
=
collate
.
apply
(
trans_items
)
return
data
def
_data_selecting_loop
(
trans_data_queues
,
batch_queue
,
collator
,
length
,
def
_worker_loop
(
dataset
,
index_queue
,
data_queue
,
done_event
,
transform
,
collate
,
batch_size
,
seed
,
worker_id
,
num_workers
,
shutdown_flag
,
target_idx
,
datakind
,
parallel_stream
,
):
# Make sure that batch is generated exactly with the same order as generated indices
while
True
:
if
shutdown_flag
.
value
==
1
:
break
target_batch_idx
=
target_idx
.
value
if
target_batch_idx
>=
length
:
break
target_worker_id
=
target_batch_idx
%
num_workers
while
True
:
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
watchdog
=
ManagerWatchdog
()
iteration_end
=
False
fetcher
=
map_fetcher
if
datakind
==
"stream"
:
global
_worker_info
_worker_info
=
WorkerInfo
(
idx
=
worker_id
,
worker
=
num_workers
,
seed
=
seed
)
dataset
=
iter
(
dataset
)
fetcher
=
stream_fetcher
while
watchdog
.
is_alive
()
:
try
:
batch_idx
,
trans_items
=
trans_data_queues
[
target_worker_id
].
get
(
timeout
=
GLOBAL_TIMEOUT
)
batch_data
=
collator
.
apply
(
trans_items
)
break
r
=
index_queue
.
get
(
timeout
=
GLOBAL_TIMEOUT
)
except
queue
.
Empty
:
if
shutdown_flag
.
value
==
1
:
continue
if
r
is
None
:
assert
done_event
.
is_set
()
or
iteration_end
break
logger
.
debug
(
"worker:{} data queue get timeout! target batch idx:{}"
.
format
(
target_worker_id
,
target_batch_idx
)
)
if
batch_idx
!=
target_batch_idx
:
raise
RuntimeError
(
"batch_idx {} mismatch the target_batch_idx {}"
.
format
(
batch_idx
,
target_batch_idx
)
)
elif
done_event
.
is_set
()
or
iteration_end
:
continue
while
True
:
idx
,
place_holder
=
r
try
:
batch_queue
.
put
(
batch_data
,
timeout
=
1
)
break
except
queue
.
Full
:
if
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"batch queue is full!"
)
with
target_idx
.
get_lock
():
target_idx
.
value
+=
1
data
=
fetcher
(
dataset
,
place_holder
,
transform
,
collate
,
parallel_stream
,
batch_size
)
except
Exception
as
e
:
if
isinstance
(
e
,
StopIteration
)
and
datakind
==
"stream"
:
data
=
worker_id
iteration_end
=
True
else
:
raise
e
data_queue
.
put
((
idx
,
data
))
del
data
,
idx
,
place_holder
,
r
batch_queue
.
disconnect_client
()
if
done_event
.
is_set
():
data_queue
.
disconnect_client
()
data_queue
.
close
()
imperative/python/megengine/data/sampler.py
浏览文件 @
edc92ccf
...
...
@@ -2,6 +2,7 @@
import
collections.abc
import
math
from
abc
import
ABC
,
abstractmethod
from
itertools
import
count
from
typing
import
Any
,
Generator
,
Iterator
,
List
,
Union
import
numpy
as
np
...
...
@@ -126,13 +127,15 @@ class MapSampler(Sampler):
if
self
.
world_size
>
1
:
indices
=
self
.
scatter
(
indices
)
step
,
length
=
self
.
batch_size
,
len
(
indices
)
batch_index
=
[
indices
[
i
:
i
+
step
]
for
i
in
range
(
0
,
length
,
step
)]
batch
=
[]
for
idx
in
indices
:
batch
.
append
(
idx
)
if
len
(
batch
)
==
self
.
batch_size
:
yield
batch
batch
=
[]
if
self
.
drop_last
and
len
(
batch_index
[
-
1
])
<
self
.
batch_size
:
batch_index
.
pop
()
return
iter
(
batch_index
)
if
len
(
batch
)
>
0
and
not
self
.
drop_last
:
yield
batch
class
StreamSampler
(
Sampler
):
...
...
@@ -151,10 +154,18 @@ class StreamSampler(Sampler):
self
.
batch_size
=
batch_size
def
__iter__
(
self
):
return
self
return
self
.
batch
()
def
__next__
(
self
):
return
iter
(
range
(
self
.
batch_size
))
def
batch
(
self
):
batch
=
[]
for
idx
in
self
.
sample
():
batch
.
append
(
idx
)
if
len
(
batch
)
==
self
.
batch_size
:
yield
batch
batch
=
[]
def
sample
(
self
):
return
count
(
start
=
0
)
class
SequentialSampler
(
MapSampler
):
...
...
imperative/python/test/unit/data/test_dataloader.py
浏览文件 @
edc92ccf
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
math
import
os
import
platform
import
time
...
...
@@ -7,7 +15,7 @@ import numpy as np
import
pytest
from
megengine.data.collator
import
Collator
from
megengine.data.dataloader
import
DataLoader
from
megengine.data.dataloader
import
DataLoader
,
get_worker_info
from
megengine.data.dataset
import
ArrayDataset
,
StreamDataset
from
megengine.data.sampler
import
RandomSampler
,
SequentialSampler
,
StreamSampler
from
megengine.data.transform
import
(
...
...
@@ -29,14 +37,10 @@ def init_dataset():
def
test_dataloader_init
():
dataset
=
init_dataset
()
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=
2
,
divide
=
True
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=-
1
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
timeout
=-
1
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=
0
,
divide
=
True
)
dataloader
=
DataLoader
(
dataset
)
assert
isinstance
(
dataloader
.
sampler
,
SequentialSampler
)
...
...
@@ -54,10 +58,8 @@ def test_dataloader_init():
class
MyStream
(
StreamDataset
):
def
__init__
(
self
,
number
,
b
atch
=
False
,
error_foramt
=
False
,
b
lock
=
False
):
def
__init__
(
self
,
number
,
block
=
False
):
self
.
number
=
number
self
.
batch
=
batch
self
.
error_format
=
error_foramt
self
.
block
=
block
def
__iter__
(
self
):
...
...
@@ -65,22 +67,14 @@ class MyStream(StreamDataset):
if
self
.
block
:
for
_
in
range
(
10
):
time
.
sleep
(
1
)
if
self
.
batch
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
2
,
3
),
dtype
=
"uint8"
)
yield
(
True
,
(
data
,
[
cnt
,
cnt
-
self
.
number
]))
else
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
)
if
self
.
error_format
:
yield
(
data
,
cnt
)
else
:
yield
(
False
,
(
data
,
cnt
))
raise
StopIteration
@
pytest
.
mark
.
parametrize
(
"batch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
batch
,
num_workers
):
dataset
=
MyStream
(
100
,
batch
=
batch
)
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
...
...
@@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers):
)
check_set
=
set
()
for
step
,
data
in
enumerate
(
dataloader
):
if
step
==
10
:
break
...
...
@@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers):
check_set
.
add
(
i
)
def
test_stream_dataloader_error
():
dataset
=
MyStream
(
100
,
error_foramt
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
)
with
pytest
.
raises
(
AssertionError
,
match
=
r
".*tuple.*"
):
data_iter
=
iter
(
dataloader
)
next
(
data_iter
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader_timeout
(
num_workers
):
dataset
=
MyStream
(
100
,
False
,
block
=
True
)
dataset
=
MyStream
(
100
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
2
)
...
...
@@ -140,17 +124,6 @@ def test_dataloader_parallel():
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
False
,
)
for
(
data
,
label
)
in
dataloader
:
assert
data
.
shape
==
(
4
,
1
,
32
,
32
)
assert
label
.
shape
==
(
4
,)
dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
True
,
)
for
(
data
,
label
)
in
dataloader
:
assert
data
.
shape
==
(
4
,
1
,
32
,
32
)
...
...
@@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception():
transform
=
FakeErrorTransform
(),
num_workers
=
2
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"
worker.*died
"
):
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"
exited unexpectedly
"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
...
...
@@ -213,18 +186,15 @@ def test_dataloader_parallel_worker_exception():
def
_multi_instances_parallel_dataloader_worker
():
dataset
=
init_dataset
()
for
divide_flag
in
[
True
,
False
]:
train_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
divide_flag
,
)
val_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
10
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
divide_flag
,
)
for
idx
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
assert
data
.
shape
==
(
4
,
1
,
32
,
32
)
...
...
@@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
assert
p
.
exitcode
==
0
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_timeout_event
(
num_workers
):
def
cb
():
return
(
True
,
(
np
.
zeros
(
shape
=
(
2
,
2
,
2
,
3
)),
np
.
ones
(
shape
=
(
2
,))))
def
partition
(
ls
,
size
):
return
[
ls
[
i
:
i
+
size
]
for
i
in
range
(
0
,
len
(
ls
),
size
)]
class
MyPreStream
(
StreamDataset
):
def
__init__
(
self
,
number
,
block
=
False
):
self
.
number
=
[
i
for
i
in
range
(
number
)]
self
.
block
=
block
self
.
data
=
[]
for
i
in
range
(
100
):
self
.
data
.
append
(
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
))
def
__iter__
(
self
):
worker_info
=
get_worker_info
()
per_worker
=
int
(
math
.
ceil
((
len
(
self
.
data
))
/
float
(
worker_info
.
worker
)))
pre_data
=
iter
(
partition
(
self
.
data
,
per_worker
)[
worker_info
.
idx
])
pre_cnt
=
partition
(
self
.
number
,
per_worker
)[
worker_info
.
idx
]
for
cnt
in
pre_cnt
:
if
self
.
block
:
for
_
in
range
(
10
):
time
.
sleep
(
1
)
yield
(
next
(
pre_data
),
cnt
)
raise
StopIteration
dataset
=
MyStream
(
100
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
)
def
test_prestream_dataloader_multiprocessing
():
dataset
=
MyPreStream
(
100
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
2
,
timeout_event
=
cb
dataset
,
sampler
,
Compose
([
Normalize
(
mean
=
(
103
,
116
,
123
),
std
=
(
57
,
57
,
58
)),
ToMode
(
"CHW"
)]),
num_workers
=
2
,
parallel_stream
=
True
,
)
for
_
,
data
in
enumerate
(
dataloader
):
np
.
testing
.
assert_equal
(
data
[
0
],
np
.
zeros
(
shape
=
(
4
,
2
,
2
,
3
)))
np
.
testing
.
assert_equal
(
data
[
1
],
np
.
ones
(
shape
=
(
4
,)))
check_set
=
set
()
for
step
,
data
in
enumerate
(
dataloader
):
if
step
==
10
:
break
assert
data
[
0
].
shape
==
(
4
,
3
,
2
,
2
)
assert
data
[
1
].
shape
==
(
4
,)
for
i
in
data
[
1
]:
assert
i
not
in
check_set
check_set
.
add
(
i
)
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
)
def
test_predataloader_parallel_worker_exception
():
dataset
=
MyPreStream
(
100
)
class
FakeErrorTransform
(
Transform
):
def
__init__
(
self
):
pass
def
apply
(
self
,
input
):
raise
RuntimeError
(
"test raise error"
)
return
input
dataloader
=
DataLoader
(
dataset
,
sampler
=
StreamSampler
(
batch_size
=
4
),
transform
=
FakeErrorTransform
(),
num_workers
=
2
,
parallel_stream
=
True
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"exited unexpectedly"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
print
(
batch_data
.
shape
)
imperative/python/test/unit/data/test_pre_dataloader.py
浏览文件 @
edc92ccf
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
gc
import
math
import
os
import
platform
import
time
...
...
@@ -8,7 +16,7 @@ import numpy as np
import
pytest
from
megengine.data.collator
import
Collator
from
megengine.data.dataloader
import
DataLoader
from
megengine.data.dataloader
import
DataLoader
,
get_worker_info
from
megengine.data.dataset
import
ArrayDataset
,
StreamDataset
from
megengine.data.sampler
import
RandomSampler
,
SequentialSampler
,
StreamSampler
from
megengine.data.transform
import
(
...
...
@@ -30,14 +38,10 @@ def init_dataset():
def
test_dataloader_init
():
dataset
=
init_dataset
()
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=
2
,
divide
=
True
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=-
1
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
timeout
=-
1
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=
0
,
divide
=
True
)
dataloader
=
DataLoader
(
dataset
,
preload
=
True
)
assert
isinstance
(
dataloader
.
sampler
,
SequentialSampler
)
...
...
@@ -59,10 +63,8 @@ def test_dataloader_init():
class
MyStream
(
StreamDataset
):
def
__init__
(
self
,
number
,
b
atch
=
False
,
error_foramt
=
False
,
b
lock
=
False
):
def
__init__
(
self
,
number
,
block
=
False
):
self
.
number
=
number
self
.
batch
=
batch
self
.
error_format
=
error_foramt
self
.
block
=
block
def
__iter__
(
self
):
...
...
@@ -70,22 +72,14 @@ class MyStream(StreamDataset):
if
self
.
block
:
for
_
in
range
(
10
):
time
.
sleep
(
1
)
if
self
.
batch
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
2
,
3
),
dtype
=
"uint8"
)
yield
(
True
,
(
data
,
[
cnt
,
cnt
-
self
.
number
]))
else
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
)
if
self
.
error_format
:
yield
(
data
,
cnt
)
else
:
yield
(
False
,
(
data
,
cnt
))
raise
StopIteration
@
pytest
.
mark
.
parametrize
(
"batch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
batch
,
num_workers
):
dataset
=
MyStream
(
100
,
batch
=
batch
)
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
...
...
@@ -107,18 +101,9 @@ def test_stream_dataloader(batch, num_workers):
check_set
.
add
(
i
)
def
test_stream_dataloader_error
():
dataset
=
MyStream
(
100
,
error_foramt
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
preload
=
True
)
with
pytest
.
raises
(
AssertionError
,
match
=
r
".*tuple.*"
):
data_iter
=
iter
(
dataloader
)
next
(
data_iter
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader_timeout
(
num_workers
):
dataset
=
MyStream
(
100
,
False
,
block
=
True
)
dataset
=
MyStream
(
100
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
...
...
@@ -150,18 +135,6 @@ def test_dataloader_parallel():
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
False
,
preload
=
True
,
)
for
(
data
,
label
)
in
dataloader
:
assert
data
.
_tuple_shape
==
(
4
,
1
,
32
,
32
)
assert
label
.
_tuple_shape
==
(
4
,)
dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
True
,
preload
=
True
,
)
for
(
data
,
label
)
in
dataloader
:
...
...
@@ -219,7 +192,7 @@ def test_dataloader_parallel_worker_exception():
num_workers
=
2
,
preload
=
True
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"
worker.*died
"
):
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"
exited unexpectedly
"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
...
...
@@ -227,19 +200,16 @@ def test_dataloader_parallel_worker_exception():
def
_multi_instances_parallel_dataloader_worker
():
dataset
=
init_dataset
()
for
divide_flag
in
[
True
,
False
]:
train_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
divide_flag
,
preload
=
True
,
)
val_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
10
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
divide_flag
,
preload
=
True
,
)
for
idx
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
...
...
@@ -276,25 +246,3 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
for
p
in
processes
:
p
.
join
()
assert
p
.
exitcode
==
0
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_timeout_event
(
num_workers
):
def
cb
():
return
(
True
,
(
np
.
zeros
(
shape
=
(
2
,
2
,
2
,
3
)),
np
.
ones
(
shape
=
(
2
,))))
dataset
=
MyStream
(
100
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
2
,
timeout_event
=
cb
,
preload
=
True
,
)
for
_
,
data
in
enumerate
(
dataloader
):
np
.
testing
.
assert_equal
(
data
[
0
],
np
.
zeros
(
shape
=
(
4
,
2
,
2
,
3
)))
np
.
testing
.
assert_equal
(
data
[
1
],
np
.
ones
(
shape
=
(
4
,)))
break
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录