Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2c2caf33
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2c2caf33
编写于
3月 27, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/data/dataloader): refactor the implementation of parallel dataloader
GitOrigin-RevId: 0554ee8427c7d892557422c1ee57597b7c88756b
上级
364dafcc
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
228 addition
and
274 deletion
+228
-274
python_module/megengine/data/dataloader.py
python_module/megengine/data/dataloader.py
+228
-274
未找到文件。
python_module/megengine/data/dataloader.py
浏览文件 @
2c2caf33
...
...
@@ -15,9 +15,8 @@ import time
import
numpy
as
np
import
megengine
as
mge
from
..logger
import
get_logger
from
..random.rng
import
_random_seed_generator
from
.collator
import
Collator
from
.dataset
import
Dataset
from
.sampler
import
Sampler
,
SequentialSampler
...
...
@@ -87,8 +86,6 @@ class DataLoader:
self
.
divide
=
divide
self
.
rng
=
np
.
random
.
RandomState
()
if
sampler
is
None
:
self
.
sampler
=
SequentialSampler
(
dataset
,
batch_size
=
1
,
drop_last
=
False
)
else
:
...
...
@@ -130,7 +127,7 @@ class _BaseDataLoaderIter:
def
__init__
(
self
,
loader
):
self
.
dataset
=
loader
.
dataset
self
.
sampler
=
loader
.
sampler
self
.
seed
=
loader
.
rng
.
randint
(
1e9
)
self
.
seed
=
_random_seed_generator
().
__next__
(
)
self
.
transform
=
loader
.
transform
self
.
collator
=
loader
.
collator
self
.
num_workers
=
loader
.
num_workers
...
...
@@ -173,10 +170,6 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
def
__init__
(
self
,
loader
):
super
(
_ParallelDataLoaderIter
,
self
).
__init__
(
loader
)
# if any worker died, all workers will be shutdown.
self
.
strict
=
True
# TODO: put `strict` into DataLoader args or not?
self
.
task_queues
=
[
multiprocessing
.
Queue
(
maxsize
=
2
)
for
_
in
range
(
self
.
num_workers
)
]
...
...
@@ -185,7 +178,7 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self
.
target_batch_idx
=
multiprocessing
.
Value
(
"i"
,
0
)
self
.
shutdown_flag
=
multiprocessing
.
Value
(
"i"
,
0
)
self
.
batch_part
_queues
=
[
self
.
trans_data
_queues
=
[
multiprocessing
.
Queue
(
maxsize
=
1
)
for
_
in
range
(
self
.
num_workers
)
]
...
...
@@ -195,8 +188,15 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self
.
batch_queue
=
PlasmaShmQueue
(
maxsize
=
2
)
self
.
task_feeding_worker
=
multiprocessing
.
Process
(
target
=
self
.
_task_feeding_loop
,
args
=
(
iter
(
self
.
sampler
),
self
.
divide
),
target
=
_task_feeding_loop
,
args
=
(
iter
(
self
.
sampler
),
self
.
task_queues
,
self
.
num_workers
,
self
.
divide
,
self
.
shutdown_flag
,
self
.
feed_batch_idx
,
),
daemon
=
True
,
)
self
.
task_feeding_worker
.
start
()
...
...
@@ -204,13 +204,14 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self
.
workers
=
[]
for
worker_id
in
range
(
self
.
num_workers
):
worker
=
multiprocessing
.
Process
(
target
=
self
.
_worker_loop
,
target
=
_worker_loop
,
args
=
(
self
.
dataset
,
self
.
task_queues
[
worker_id
],
self
.
batch_part
_queues
[
worker_id
],
self
.
trans_data
_queues
[
worker_id
],
self
.
transform
,
self
.
collator
,
self
.
seed
+
worker_id
+
1
,
self
.
shutdown_flag
,
),
daemon
=
True
,
)
...
...
@@ -219,191 +220,257 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
if
self
.
divide
:
self
.
data_collecting_worker
=
multiprocessing
.
Process
(
target
=
self
.
_data_gathering_loop
,
args
=
(
self
.
batch_part_queues
,
self
.
batch_queue
,),
target
=
_data_gathering_loop
,
args
=
(
self
.
trans_data_queues
,
self
.
batch_queue
,
self
.
collator
,
len
(
self
),
self
.
num_workers
,
self
.
shutdown_flag
,
self
.
target_batch_idx
,
),
daemon
=
True
,
)
else
:
self
.
data_collecting_worker
=
multiprocessing
.
Process
(
target
=
self
.
_data_selecting_loop
,
args
=
(
self
.
batch_part_queues
,
self
.
batch_queue
,),
target
=
_data_selecting_loop
,
args
=
(
self
.
trans_data_queues
,
self
.
batch_queue
,
self
.
collator
,
len
(
self
),
self
.
num_workers
,
self
.
shutdown_flag
,
self
.
target_batch_idx
,
),
daemon
=
True
,
)
self
.
data_collecting_worker
.
start
()
self
.
__initialized
=
True
def
_task_feeding_loop
(
self
,
indices_iter
,
divide
):
def
_check_workers
(
self
):
# Check the status of each worker.
if
not
self
.
data_collecting_worker
.
is_alive
():
exitcode
=
self
.
task_feeding_worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"data collecting worker died. {}"
.
format
(
exitcode
))
if
not
self
.
task_feeding_worker
.
is_alive
():
exitcode
=
self
.
task_feeding_worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"task feeding worker died. {}"
.
format
(
exitcode
))
for
worker_id
,
worker
in
enumerate
(
self
.
workers
):
if
not
worker
.
is_alive
():
exitcode
=
worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"worker:{} died. {}"
.
format
(
worker_id
,
exitcode
))
logger
.
debug
(
"all workers are alive."
)
def
_try_get_next_batch
(
self
):
start_time
=
time
.
time
()
while
True
:
self
.
_check_workers
()
try
:
return
self
.
batch_queue
.
get
(
timeout
=
1
)
except
queue
.
Empty
:
logger
.
debug
(
"batch queue empty!"
)
waited_time
=
time
.
time
()
-
start_time
if
self
.
timeout
>
0
:
if
waited_time
>
self
.
timeout
:
raise
RuntimeError
(
"get_next_batch timeout!"
)
def
_get_next_batch
(
self
):
batch_data
=
self
.
_try_get_next_batch
()
return
batch_data
def
_shutdown
(
self
):
with
self
.
shutdown_flag
.
get_lock
():
self
.
shutdown_flag
.
value
=
1
if
self
.
task_feeding_worker
.
is_alive
():
self
.
task_feeding_worker
.
terminate
()
self
.
task_feeding_worker
.
join
()
if
self
.
data_collecting_worker
.
is_alive
():
self
.
data_collecting_worker
.
terminate
()
self
.
data_collecting_worker
.
join
()
for
worker
in
self
.
workers
:
if
worker
.
is_alive
():
worker
.
terminate
()
worker
.
join
()
for
q
in
self
.
trans_data_queues
:
q
.
cancel_join_thread
()
q
.
close
()
for
q
in
self
.
task_queues
:
q
.
cancel_join_thread
()
q
.
close
()
self
.
batch_queue
.
cancel_join_thread
()
self
.
batch_queue
.
close
()
def
__del__
(
self
):
if
self
.
__initialized
:
self
.
_shutdown
()
def
_task_feeding_loop
(
indices_iter
,
task_queues
,
num_workers
,
divide
,
shutdown_flag
,
feed_batch_idx
):
# Feed the indices into the task queues
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
if
shutdown_flag
.
value
==
1
:
break
batch_idx
=
self
.
feed_batch_idx
.
value
batch_idx
=
feed_batch_idx
.
value
try
:
indices
=
next
(
indices_iter
)
except
StopIteration
:
break
if
divide
:
# make sure all task_queues is ready for put
while
any
([
q
.
full
()
for
q
in
self
.
task_queues
]):
if
self
.
shutdown_flag
.
value
==
1
:
while
any
([
q
.
full
()
for
q
in
task_queues
]):
if
shutdown_flag
.
value
==
1
:
return
# divide into small pieces, feed to different workers.
sub_num
=
math
.
ceil
(
len
(
indices
)
/
self
.
num_workers
)
for
worker_id
in
range
(
self
.
num_workers
):
sub_indices
=
indices
[
worker_id
*
sub_num
:
(
worker_id
+
1
)
*
sub_num
]
self
.
task_queues
[
worker_id
].
put
((
batch_idx
,
sub_indices
))
sub_num
=
math
.
ceil
(
len
(
indices
)
/
num_workers
)
for
worker_id
in
range
(
num_workers
):
sub_indices
=
indices
[
worker_id
*
sub_num
:
(
worker_id
+
1
)
*
sub_num
]
task_queues
[
worker_id
].
put
((
batch_idx
,
sub_indices
))
else
:
# distribute tasks to different workers uniformly.
target_id
=
batch_idx
%
self
.
num_workers
while
self
.
task_queues
[
target_id
].
full
():
if
self
.
shutdown_flag
.
value
==
1
:
target_id
=
batch_idx
%
num_workers
while
task_queues
[
target_id
].
full
():
if
shutdown_flag
.
value
==
1
:
return
self
.
task_queues
[
target_id
].
put
((
batch_idx
,
indices
))
with
self
.
feed_batch_idx
.
get_lock
():
self
.
feed_batch_idx
.
value
+=
1
task_queues
[
target_id
].
put
((
batch_idx
,
indices
))
with
feed_batch_idx
.
get_lock
():
feed_batch_idx
.
value
+=
1
def
_worker_loop
(
self
,
task_queue
,
data_queue
,
transform
,
collator
,
seed
):
def
_worker_loop
(
dataset
,
task_queue
,
trans_data_queue
,
transform
,
seed
,
shutdown_flag
):
# Get dataset items and do the transform
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
if
shutdown_flag
.
value
==
1
:
break
try
:
batch_idx
,
indices
=
task_queue
.
get
(
timeout
=
MP_QUEUE_GET_TIMEOUT
)
except
queue
.
Empty
:
continue
if
len
(
indices
)
>
0
:
items
=
[
self
.
dataset
[
idx
]
for
idx
in
indices
]
items
=
[
dataset
[
idx
]
for
idx
in
indices
]
trans_items
=
transform
.
apply_batch
(
items
)
batch_data
=
collator
.
apply
(
trans_items
)
else
:
# in case of incomplete last batch
batch_data
=
()
trans_items
=
()
while
True
:
try
:
data_queue
.
put
((
np
.
array
([
batch_idx
]),
batch_data
),
timeout
=
1
)
trans_data_queue
.
put
((
batch_idx
,
trans_items
),
timeout
=
1
)
break
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
if
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"batch part queue is full!"
)
continue
def
_data_gathering_loop
(
self
,
batch_part_queues
,
batch_queue
):
r
"""Gathering the small pieces of batch data into full batch data."""
gathered_data
=
collections
.
defaultdict
(
dict
)
def
_data_gathering_loop
(
trans_data_queues
,
batch_queue
,
collator
,
length
,
num_workers
,
shutdown_flag
,
target_idx
,
):
# Gathering the small pieces of batch data into full batch data
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
if
shutdown_flag
.
value
==
1
:
break
target_batch_idx
=
self
.
target_batch
_idx
.
value
target_batch_idx
=
target
_idx
.
value
if
target_batch_idx
>=
len
(
self
)
:
if
target_batch_idx
>=
length
:
break
for
worker_id
in
range
(
self
.
num_workers
):
if
worker_id
in
gathered_data
[
target_batch_idx
]:
continue
full_trans_items
=
[]
for
worker_id
in
range
(
num_workers
):
while
True
:
try
:
(
batch_idx
,),
batch_part
=
batch_part
_queues
[
worker_id
].
get
(
batch_idx
,
trans_items
=
trans_data
_queues
[
worker_id
].
get
(
timeout
=
MP_QUEUE_GET_TIMEOUT
)
break
except
queue
.
Empty
:
if
self
.
shutdown_flag
.
value
==
1
:
if
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"worker:{} data queue get timeout! target batch idx:{}"
.
format
(
worker_id
,
target_batch_idx
)
)
if
batch_idx
<
target_batch_idx
:
if
batch_idx
!=
target_batch_idx
:
raise
RuntimeError
(
"Unexperted batch_idx in data gathering loop. worker_id:{}."
.
format
(
worker_id
)
)
else
:
gathered_data
[
batch_idx
][
worker_id
]
=
batch_part
if
len
(
gathered_data
[
target_batch_idx
])
<
self
.
num_workers
:
length
=
len
(
gathered_data
[
target_batch_idx
])
if
self
.
strict
:
raise
RuntimeError
(
"Parts missing in data gathering loop."
)
logger
.
warning
(
"target_batch_idx:{}, {} part(s) missing."
.
format
(
target_batch_idx
,
self
.
num_workers
-
length
)
)
del
gathered_data
[
target_batch_idx
]
with
self
.
target_batch_idx
.
get_lock
():
self
.
target_batch_idx
.
value
+=
1
continue
full_trans_items
.
extend
(
trans_items
)
# Merge different parts.
full_batch
=
[[]
for
_
in
range
(
len
(
gathered_data
[
target_batch_idx
][
0
]))]
for
idx
in
range
(
self
.
num_workers
):
for
i
,
field
in
enumerate
(
gathered_data
[
target_batch_idx
][
idx
]):
full_batch
[
i
].
append
(
field
)
full_batch
=
tuple
([
np
.
concatenate
(
field
,
axis
=
0
)
for
field
in
full_batch
])
# Merge different parts into a batch.
full_batch
=
collator
.
apply
(
full_trans_items
)
while
True
:
try
:
batch_queue
.
put
(
full_batch
,
timeout
=
1
)
break
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
if
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"batch queue is full!"
)
continue
del
gathered_data
[
target_batch_idx
]
with
self
.
target_batch
_idx
.
get_lock
():
self
.
target_batch
_idx
.
value
+=
1
with
target
_idx
.
get_lock
():
target
_idx
.
value
+=
1
batch_queue
.
disconnect_client
()
def
_data_selecting_loop
(
self
,
batch_part_queues
,
batch_queue
):
r
"""Make sure that batch is generated exactly with the same order as generated indices."""
buffer_batches
=
{}
def
_data_selecting_loop
(
trans_data_queues
,
batch_queue
,
collator
,
length
,
num_workers
,
shutdown_flag
,
target_idx
,
):
# Make sure that batch is generated exactly with the same order as generated indices
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
if
shutdown_flag
.
value
==
1
:
break
target_batch_idx
=
self
.
target_batch
_idx
.
value
target_batch_idx
=
target
_idx
.
value
if
target_batch_idx
>=
len
(
self
)
:
if
target_batch_idx
>=
length
:
break
if
target_batch_idx
in
buffer_batches
:
target_worker_id
=
target_batch_idx
%
num_workers
while
True
:
try
:
batch_queue
.
put
(
buffer_batches
[
target_batch_idx
],
timeout
=
1
,
)
break
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"batch queue is full!"
)
with
self
.
target_batch_idx
.
get_lock
():
self
.
target_batch_idx
.
value
+=
1
del
buffer_batches
[
target_batch_idx
]
continue
target_worker_id
=
target_batch_idx
%
self
.
num_workers
while
True
:
try
:
(
batch_idx
,),
batch_data
=
batch_part_queues
[
target_worker_id
].
get
(
batch_idx
,
trans_items
=
trans_data_queues
[
target_worker_id
].
get
(
timeout
=
MP_QUEUE_GET_TIMEOUT
)
batch_data
=
collator
.
apply
(
trans_items
)
break
except
queue
.
Empty
:
if
self
.
shutdown_flag
.
value
==
1
:
if
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"worker:{} data queue get timeout! target batch idx:{}"
.
format
(
...
...
@@ -411,136 +478,23 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
)
)
if
batch_idx
<
target_batch_idx
:
raise
RuntimeError
(
"batch_idx smaller than target_batch_idx"
)
elif
batch_idx
>
target_batch_idx
:
if
self
.
strict
:
raise
RuntimeError
(
"batch_idx larger than target_batch_idx"
)
logger
.
warning
(
"missing target batch idx:{}, batch idx:{}"
.
format
(
target_batch_idx
,
batch_idx
)
)
buffer_batches
[
batch_idx
]
=
batch_data
else
:
try
:
batch_queue
.
put
(
batch_data
,
timeout
=
1
)
except
queue
.
Full
:
buffer_batches
[
batch_idx
]
=
batch_data
continue
with
self
.
target_batch_idx
.
get_lock
():
self
.
target_batch_idx
.
value
+=
1
batch_queue
.
disconnect_client
()
def
_check_workers
(
self
):
"""Check the status of each worker and restart if necessary."""
if
not
self
.
data_collecting_worker
.
is_alive
():
exitcode
=
self
.
task_feeding_worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"data collecting worker died. {}"
.
format
(
exitcode
))
if
self
.
strict
:
if
not
self
.
task_feeding_worker
.
is_alive
():
exitcode
=
self
.
task_feeding_worker
.
exitcode
if
exitcode
!=
0
:
raise
RuntimeError
(
"task feeding worker died. {}"
.
format
(
exitcode
))
for
worker_id
,
worker
in
enumerate
(
self
.
workers
):
if
not
worker
.
is_alive
():
exitcode
=
worker
.
exitcode
if
exitcode
!=
0
:
if
batch_idx
!=
target_batch_idx
:
raise
RuntimeError
(
"worker:{} died. {}"
.
format
(
worker_id
,
exitcode
)
"batch_idx {} mismatch the target_batch_idx {}"
.
format
(
batch_idx
,
target_batch_idx
)
else
:
if
not
self
.
task_feeding_worker
.
is_alive
():
exitcode
=
self
.
task_feeding_worker
.
exitcode
if
exitcode
!=
0
:
logger
.
error
(
"task feeding worker died {}. Restarting"
.
format
(
exitcode
)
)
self
.
task_feeding_worker
.
join
()
self
.
task_feeding_worker
=
multiprocessing
.
Process
(
target
=
self
.
_task_feeding_loop
,
args
=
(
iter
(
self
.
sampler
),
self
.
divide
),
daemon
=
True
,
)
self
.
task_feeding_worker
.
start
()
failed_num
=
0
for
worker_id
in
range
(
self
.
num_workers
):
if
self
.
workers
[
worker_id
].
is_alive
():
continue
exitcode
=
worker
.
exitcode
if
exitcode
==
0
:
continue
logger
.
error
(
"worker {} died. Restarting"
.
format
(
worker_id
))
failed_num
+=
1
self
.
workers
[
worker_id
].
join
()
worker
=
multiprocessing
.
Process
(
target
=
self
.
_worker_loop
,
args
=
(
self
.
task_queues
[
worker_id
],
self
.
batch_part_queues
[
worker_id
],
self
.
transform
,
self
.
collator
,
self
.
seed
+
worker_id
+
1
,
),
daemon
=
True
,
)
worker
.
start
()
self
.
workers
[
worker_id
]
=
worker
if
failed_num
>
0
:
logger
.
error
(
"{} worker had exited"
.
format
(
failed_num
))
else
:
logger
.
debug
(
"all workers are alive."
)
def
_try_get_next_batch
(
self
):
start_time
=
time
.
time
()
while
True
:
self
.
_check_workers
()
try
:
return
self
.
batch_queue
.
get
(
timeout
=
1
)
except
queue
.
Empty
:
logger
.
debug
(
"batch queue empty!"
)
waited_time
=
time
.
time
()
-
start_time
if
self
.
timeout
>
0
:
if
waited_time
>
self
.
timeout
:
raise
RuntimeError
(
"get_next_batch timeout!"
)
def
_get_next_batch
(
self
):
batch_data
=
self
.
_try_get_next_batch
()
return
batch_data
def
_shutdown
(
self
):
with
self
.
shutdown_flag
.
get_lock
():
self
.
shutdown_flag
.
value
=
1
if
self
.
task_feeding_worker
.
is_alive
():
self
.
task_feeding_worker
.
terminate
()
self
.
task_feeding_worker
.
join
()
if
self
.
data_collecting_worker
.
is_alive
():
self
.
data_collecting_worker
.
terminate
()
self
.
data_collecting_worker
.
join
()
for
worker
in
self
.
workers
:
if
worker
.
is_alive
():
worker
.
terminate
()
worker
.
join
()
for
q
in
self
.
batch_part_queues
:
q
.
cancel_join_thread
()
q
.
close
()
for
q
in
self
.
task_queues
:
q
.
cancel_join_thread
()
q
.
close
()
batch_queue
.
put
(
batch_data
,
timeout
=
1
)
break
except
queue
.
Full
:
if
shutdown_flag
.
value
==
1
:
break
logger
.
debug
(
"batch queue is full!"
)
self
.
batch_queue
.
cancel_join_thread
()
self
.
batch_queue
.
close
()
with
target_idx
.
get_lock
():
target_idx
.
value
+=
1
def
__del__
(
self
):
if
self
.
__initialized
:
self
.
_shutdown
()
batch_queue
.
disconnect_client
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录