Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c418d3cd
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c418d3cd
编写于
1月 14, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/data): add timeout event
GitOrigin-RevId: 43f2ba1456ce027e59ea2e09b9f9b795bb2e802f
上级
0f739c11
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
116 additions
and
78 deletions
+116
-78
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+89
-68
imperative/python/megengine/data/dataset/meta_dataset.py
imperative/python/megengine/data/dataset/meta_dataset.py
+1
-1
imperative/python/test/unit/data/test_dataloader.py
imperative/python/test/unit/data/test_dataloader.py
+26
-9
未找到文件。
imperative/python/megengine/data/dataloader.py
浏览文件 @
c418d3cd
...
...
@@ -14,6 +14,7 @@ import queue
import
random
import
threading
import
time
from
typing
import
Callable
import
numpy
as
np
...
...
@@ -36,6 +37,10 @@ logger = get_logger(__name__)
GLOBAL_TIMEOUT
=
5
def
raise_timeout_error
():
raise
RuntimeError
(
"dataloader timeout"
)
class
DataLoader
:
__initialized
=
False
...
...
@@ -46,7 +51,8 @@ class DataLoader:
transform
:
Transform
=
None
,
collator
:
Collator
=
None
,
num_workers
:
int
=
0
,
timeout
:
int
=
GLOBAL_TIMEOUT
,
timeout
:
int
=
0
,
timeout_event
:
Callable
=
raise_timeout_error
,
divide
:
bool
=
False
,
):
r
"""
...
...
@@ -71,6 +77,9 @@ class DataLoader:
:type timeout: int
:param timeout: if positive, the timeout in seconds for collecting a
batch from workers. Default: 0
:type timeout_event: Callable
:param timeout_event: callback function invoked when a timeout occurs; by default it
raises a RuntimeError.
:type divide: bool
:param divide: define the paralleling strategy in multi-processing mode.
``True`` means one batch is divided into :attr:`num_workers` pieces, and
...
...
@@ -92,6 +101,7 @@ class DataLoader:
self
.
num_workers
=
num_workers
self
.
timeout
=
timeout
self
.
timeout_event
=
timeout_event
self
.
divide
=
divide
...
...
@@ -168,6 +178,7 @@ class _BaseMapDataLoaderIter:
self
.
collator
=
loader
.
collator
self
.
num_workers
=
loader
.
num_workers
self
.
timeout
=
loader
.
timeout
self
.
timeout_event
=
loader
.
timeout_event
self
.
divide
=
loader
.
divide
self
.
num_processed
=
0
...
...
@@ -306,7 +317,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
logger
.
debug
(
"all workers are alive."
)
def
_
try_
get_next_batch
(
self
):
def
_get_next_batch
(
self
):
start_time
=
time
.
time
()
while
True
:
self
.
_check_workers
()
...
...
@@ -319,10 +330,6 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
if
waited_time
>
self
.
timeout
:
raise
RuntimeError
(
"get_next_batch timeout!"
)
def
_get_next_batch
(
self
):
batch_data
=
self
.
_try_get_next_batch
()
return
batch_data
def
_shutdown
(
self
):
with
self
.
shutdown_flag
.
get_lock
():
self
.
shutdown_flag
.
value
=
1
...
...
@@ -364,10 +371,24 @@ class _BaseStreamDataLoaderIter:
self
.
collator
=
loader
.
collator
self
.
num_workers
=
loader
.
num_workers
self
.
timeout
=
loader
.
timeout
self
.
timeout_event
=
loader
.
timeout_event
def
_get_next_batch
(
self
):
raise
NotImplementedError
def
_process_raw_data
(
self
,
raw_data
):
assert
len
(
raw_data
)
==
2
and
isinstance
(
raw_data
[
0
],
bool
),
"StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if
not
raw_data
[
0
]:
data
=
list
((
x
,)
for
x
in
raw_data
[
1
])
else
:
data
=
raw_data
[
1
]
ret
=
[]
for
idx
in
range
(
len
(
data
[
0
])):
ret
.
append
(
tuple
(
e
[
idx
]
for
e
in
data
))
return
ret
def
__iter__
(
self
):
return
self
...
...
@@ -380,42 +401,43 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
super
().
__init__
(
loader
)
self
.
dataset_iter
=
iter
(
self
.
dataset
)
self
.
idx
=
0
self
.
data
=
None
self
.
unused
=
[]
def
_get_next_batch
(
self
):
ret
=
[]
while
len
(
ret
)
!=
self
.
sampler
.
batch_size
:
if
self
.
idx
!=
0
:
data
=
self
.
data
else
:
try
:
def
_try_get_raw_data
(
self
,
start_time
):
raw_data
=
None
while
not
raw_data
:
try
:
if
self
.
timeout
>
0
:
timer
=
threading
.
Timer
(
self
.
timeout
,
thread
.
interrupt_main
)
timer
.
start
()
raw_data
=
next
(
self
.
dataset_iter
)
raw_data
=
next
(
self
.
dataset_iter
)
if
self
.
timeout
>
0
:
timer
.
cancel
()
except
KeyboardInterrupt
:
raise
RuntimeError
(
"get_next_batch timeout!"
)
except
:
except
KeyboardInterrupt
:
raw_data
=
self
.
timeout_event
()
except
:
if
self
.
timeout
>
0
:
timer
.
cancel
()
continue
assert
len
(
raw_data
)
==
2
and
isinstance
(
raw_data
[
0
],
bool
),
"StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if
not
raw_data
[
0
]:
data
=
list
((
x
,)
for
x
in
raw_data
[
1
])
else
:
data
=
raw_data
[
1
]
for
idx
in
range
(
self
.
idx
,
len
(
data
[
0
])):
trans_data
=
self
.
transform
.
apply
(
tuple
(
e
[
idx
]
for
e
in
data
))
ret
.
append
(
trans_data
)
if
len
(
ret
)
==
self
.
sampler
.
batch_size
:
if
idx
+
1
==
len
(
data
[
0
]):
self
.
idx
=
0
self
.
data
=
None
else
:
self
.
idx
=
idx
self
.
data
=
data
break
waited_time
=
time
.
time
()
-
start_time
if
waited_time
>
self
.
timeout
:
raw_data
=
self
.
timeout_event
()
return
raw_data
def
_get_next_batch
(
self
):
ret
=
[]
start_time
=
time
.
time
()
while
len
(
ret
)
<
self
.
sampler
.
batch_size
:
if
len
(
self
.
unused
)
!=
0
:
batch_data
=
self
.
unused
else
:
raw_data
=
self
.
_try_get_raw_data
(
start_time
)
batch_data
=
self
.
_process_raw_data
(
raw_data
)
while
len
(
batch_data
)
!=
0
and
len
(
ret
)
<
self
.
sampler
.
batch_size
:
data
=
batch_data
.
pop
()
ret
.
append
(
self
.
transform
.
apply
(
data
))
self
.
unused
=
batch_data
return
self
.
collator
.
apply
(
ret
)
...
...
@@ -440,49 +462,52 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
self
.
batch_queue
=
PlasmaShmQueue
(
maxsize
=
2
)
self
.
recieve_worker
=
multiprocessing
.
Process
(
target
=
self
.
_recieve
,
daemon
=
True
)
self
.
recieve_worker
=
multiprocessing
.
Process
(
target
=
self
.
_worker_to_raw_data_queues
,
daemon
=
True
)
self
.
recieve_worker
.
start
()
self
.
transform_workers
=
[]
for
worker_id
in
range
(
self
.
num_workers
):
worker
=
multiprocessing
.
Process
(
target
=
self
.
_
transform
,
args
=
(
worker_id
,),
daemon
=
True
target
=
self
.
_
worker_to_trans_data_queues
,
args
=
(
worker_id
,),
daemon
=
True
)
worker
.
start
()
self
.
transform_workers
.
append
(
worker
)
self
.
collect_worker
=
multiprocessing
.
Process
(
target
=
self
.
_collect
,
daemon
=
True
)
self
.
collect_worker
=
multiprocessing
.
Process
(
target
=
self
.
_worker_to_batch_queue
,
daemon
=
True
)
self
.
collect_worker
.
start
()
self
.
__initialized
=
True
def
_recieve
(
self
):
def
_put_raw_data_queues
(
self
,
raw_data
,
qidx
):
batch_data
=
self
.
_process_raw_data
(
raw_data
)
for
data
in
batch_data
:
while
True
:
qidx
=
qidx
%
self
.
num_workers
try
:
self
.
raw_data_queues
[
qidx
].
put
(
data
)
break
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"raw data queue %d is full"
%
qidx
)
finally
:
qidx
+=
1
return
qidx
def
_worker_to_raw_data_queues
(
self
):
dataset_iter
=
iter
(
self
.
dataset
)
cnt
=
-
1
qidx
=
0
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
break
raw_data
=
next
(
dataset_iter
)
assert
len
(
raw_data
)
==
2
and
isinstance
(
raw_data
[
0
],
bool
),
"StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if
not
raw_data
[
0
]:
data
=
list
((
x
,)
for
x
in
raw_data
[
1
])
else
:
data
=
raw_data
[
1
]
for
idx
in
range
(
len
(
data
[
0
])):
while
True
:
cnt
+=
1
qid
=
cnt
%
self
.
num_workers
try
:
self
.
raw_data_queues
[
qid
].
put
(
tuple
(
e
[
idx
]
for
e
in
data
))
break
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"raw data queue is full"
)
qidx
=
self
.
_put_raw_data_queues
(
raw_data
,
qidx
)
def
_
transform
(
self
,
worker_id
):
def
_
worker_to_trans_data_queues
(
self
,
worker_id
):
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
break
...
...
@@ -500,7 +525,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
break
logger
.
debug
(
"batch queue if full"
)
def
_
collect
(
self
):
def
_
worker_to_batch_queue
(
self
):
cnt
=
-
1
trans_items
=
[]
while
True
:
...
...
@@ -541,7 +566,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
"worker: {} died. {}"
.
format
(
worker_id
,
exitcode
)
)
def
_
try_
get_next_batch
(
self
):
def
_get_next_batch
(
self
):
start_time
=
time
.
time
()
while
True
:
self
.
_check_workers
()
...
...
@@ -551,11 +576,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
logger
.
debug
(
"batch queue empty!"
)
waited_time
=
time
.
time
()
-
start_time
if
self
.
timeout
>
0
and
waited_time
>
self
.
timeout
:
raise
RuntimeError
(
"get_next_batch timeout!"
)
def
_get_next_batch
(
self
):
batch_data
=
self
.
_try_get_next_batch
()
return
batch_data
self
.
_put_raw_data_queues
(
self
.
timeout_event
(),
0
)
def
_shutdown
(
self
):
with
self
.
shutdown_flag
.
get_lock
():
...
...
imperative/python/megengine/data/dataset/meta_dataset.py
浏览文件 @
c418d3cd
...
...
@@ -43,7 +43,7 @@ class StreamDataset(Dataset):
def
__iter__
(
self
):
pass
def
__getitem__
(
self
):
def
__getitem__
(
self
,
idx
):
raise
AssertionError
(
"can not get item from StreamDataset by index"
)
def
__len__
(
self
):
...
...
imperative/python/test/unit/data/test_dataloader.py
浏览文件 @
c418d3cd
...
...
@@ -61,10 +61,10 @@ def test_dataloader_init():
class
MyStream
(
StreamDataset
):
def
__init__
(
self
,
number
,
batch
=
False
,
error
=
False
,
block
=
False
):
def
__init__
(
self
,
number
,
batch
=
False
,
error
_foramt
=
False
,
block
=
False
):
self
.
number
=
number
self
.
batch
=
batch
self
.
error
=
error
self
.
error
_format
=
error_foramt
self
.
block
=
block
def
__iter__
(
self
):
...
...
@@ -73,11 +73,11 @@ class MyStream(StreamDataset):
for
_
in
range
(
10
):
time
.
sleep
(
1
)
if
self
.
batch
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
32
,
3
2
,
3
),
dtype
=
"uint8"
)
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
2
,
3
),
dtype
=
"uint8"
)
yield
(
True
,
(
data
,
[
cnt
,
cnt
-
self
.
number
]))
else
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
32
,
3
2
,
3
),
dtype
=
"uint8"
)
if
self
.
error
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
)
if
self
.
error
_format
:
yield
(
data
,
cnt
)
else
:
yield
(
False
,
(
data
,
cnt
))
...
...
@@ -87,7 +87,7 @@ class MyStream(StreamDataset):
@
pytest
.
mark
.
parametrize
(
"batch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
batch
,
num_workers
):
dataset
=
MyStream
(
100
,
batch
)
dataset
=
MyStream
(
100
,
batch
=
batch
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
...
...
@@ -101,7 +101,7 @@ def test_stream_dataloader(batch, num_workers):
for
step
,
data
in
enumerate
(
dataloader
):
if
step
==
10
:
break
assert
data
[
0
].
shape
==
(
4
,
3
,
32
,
3
2
)
assert
data
[
0
].
shape
==
(
4
,
3
,
2
,
2
)
assert
data
[
1
].
shape
==
(
4
,)
for
i
in
data
[
1
]:
assert
i
not
in
check_set
...
...
@@ -109,7 +109,7 @@ def test_stream_dataloader(batch, num_workers):
def
test_stream_dataloader_error
():
dataset
=
MyStream
(
100
,
error
=
True
)
dataset
=
MyStream
(
100
,
error
_foramt
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
)
with
pytest
.
raises
(
AssertionError
,
match
=
r
".*tuple.*"
):
...
...
@@ -122,7 +122,7 @@ def test_stream_dataloader_timeout(num_workers):
dataset
=
MyStream
(
100
,
False
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
5
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
2
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
".*timeout.*"
):
data_iter
=
iter
(
dataloader
)
next
(
data_iter
)
...
...
@@ -264,3 +264,20 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
for
p
in
processes
:
p
.
join
()
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_timeout_event
(
num_workers
):
def
cb
():
return
(
True
,
(
np
.
zeros
(
shape
=
(
2
,
2
,
2
,
3
)),
np
.
ones
(
shape
=
(
2
,))))
dataset
=
MyStream
(
100
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
2
,
timeout_event
=
cb
)
for
_
,
data
in
enumerate
(
dataloader
):
np
.
testing
.
assert_equal
(
data
[
0
],
np
.
zeros
(
shape
=
(
4
,
2
,
2
,
3
)))
np
.
testing
.
assert_equal
(
data
[
1
],
np
.
ones
(
shape
=
(
4
,)))
break
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录