Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e082e277
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e082e277
编写于
11月 05, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/data): Refactor megengine.data.dataset
GitOrigin-RevId: 1d9c61ce70059de9e7e0f804a67f48e54d3891a6
上级
f04e0d77
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
202 additions
and
48 deletions
+202
-48
imperative/python/megengine/data/__init__.py
imperative/python/megengine/data/__init__.py
+1
-0
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+113
-43
imperative/python/megengine/data/sampler.py
imperative/python/megengine/data/sampler.py
+5
-2
imperative/python/test/unit/data/test_dataloader.py
imperative/python/test/unit/data/test_dataloader.py
+83
-3
未找到文件。
imperative/python/megengine/data/__init__.py
浏览文件 @
e082e277
...
@@ -10,6 +10,7 @@ from .collator import Collator
...
@@ -10,6 +10,7 @@ from .collator import Collator
from
.dataloader
import
DataLoader
from
.dataloader
import
DataLoader
from
.sampler
import
(
from
.sampler
import
(
Infinite
,
Infinite
,
MapSampler
,
RandomSampler
,
RandomSampler
,
ReplacementSampler
,
ReplacementSampler
,
Sampler
,
Sampler
,
...
...
imperative/python/megengine/data/dataloader.py
浏览文件 @
e082e277
...
@@ -20,7 +20,7 @@ from ..logger import get_logger
...
@@ -20,7 +20,7 @@ from ..logger import get_logger
from
..random.rng
import
_random_seed_generator
from
..random.rng
import
_random_seed_generator
from
.collator
import
Collator
from
.collator
import
Collator
from
.dataset
import
Dataset
,
MapDataset
,
StreamDataset
from
.dataset
import
Dataset
,
MapDataset
,
StreamDataset
from
.sampler
import
Sampler
,
SequentialSampler
,
StreamSampler
from
.sampler
import
MapSampler
,
Sampler
,
SequentialSampler
,
StreamSampler
from
.transform
import
PseudoTransform
,
Transform
from
.transform
import
PseudoTransform
,
Transform
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
...
@@ -88,17 +88,24 @@ class DataLoader:
...
@@ -88,17 +88,24 @@ class DataLoader:
self
.
divide
=
divide
self
.
divide
=
divide
if
sampler
is
None
:
if
isinstance
(
dataset
,
MapDataset
):
if
isinstance
(
dataset
,
MapDataset
):
self
.
sampler
=
(
self
.
sampler
=
SequentialSampler
(
dataset
,
batch_size
=
1
,
drop_last
=
False
)
sampler
elif
isinstance
(
dataset
,
StreamDataset
):
if
sampler
self
.
sampler
=
StreamSampler
(
batch_size
=
1
)
else
SequentialSampler
(
dataset
,
batch_size
=
1
,
drop_last
=
False
)
else
:
)
raise
TypeError
(
assert
isinstance
(
"can not recognize this kind of dataset: %s"
%
type
(
dataset
)
self
.
sampler
,
MapSampler
)
),
"types of dataset and sampler do not match"
elif
isinstance
(
dataset
,
StreamDataset
):
self
.
sampler
=
sampler
if
sampler
else
StreamSampler
(
batch_size
=
1
)
assert
isinstance
(
self
.
sampler
,
StreamSampler
),
"types of dataset and sampler do not match"
else
:
else
:
self
.
sampler
=
sampler
raise
TypeError
(
"can not recognize this kind of dataset: %s"
%
type
(
dataset
)
)
if
divide
:
if
divide
:
if
self
.
sampler
.
batch_size
<=
self
.
num_workers
:
if
self
.
sampler
.
batch_size
<=
self
.
num_workers
:
...
@@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter:
...
@@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter:
self
.
collator
=
loader
.
collator
self
.
collator
=
loader
.
collator
self
.
num_workers
=
loader
.
num_workers
self
.
num_workers
=
loader
.
num_workers
self
.
timeout
=
loader
.
timeout
self
.
timeout
=
loader
.
timeout
self
.
post_process
=
self
.
dataset
.
post_process
def
_get_next_batch
(
self
):
def
_get_next_batch
(
self
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter:
...
@@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter:
return
self
return
self
def
__next__
(
self
):
def
__next__
(
self
):
return
self
.
post_process
(
self
.
_get_next_batch
()
)
return
self
.
_get_next_batch
(
)
class
_SerialStreamDataLoaderIter
(
_BaseStreamDataLoaderIter
):
class
_SerialStreamDataLoaderIter
(
_BaseStreamDataLoaderIter
):
def
__init__
(
self
,
loader
):
def
__init__
(
self
,
loader
):
super
().
__init__
(
loader
)
super
().
__init__
(
loader
)
self
.
dataset_iter
=
iter
(
self
.
dataset
)
self
.
dataset_iter
=
iter
(
self
.
dataset
)
self
.
idx
=
0
self
.
data
=
None
def
_get_next_batch
(
self
):
def
_get_next_batch
(
self
):
ret
=
[]
ret
=
[]
...
@@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
...
@@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
waited_time
=
time
.
time
()
-
start_time
waited_time
=
time
.
time
()
-
start_time
if
self
.
timeout
>
0
and
waited_time
>
self
.
timeout
:
if
self
.
timeout
>
0
and
waited_time
>
self
.
timeout
:
raise
RuntimeError
(
"get_next_batch timeout!"
)
raise
RuntimeError
(
"get_next_batch timeout!"
)
item
=
next
(
self
.
dataset_iter
)
if
self
.
idx
!=
0
:
for
idx
in
range
(
len
(
item
[
0
])):
data
=
self
.
data
trans_item
=
self
.
transform
.
apply
(
tuple
(
e
[
idx
]
for
e
in
item
))
else
:
ret
.
append
(
trans_item
)
try
:
raw_data
=
next
(
self
.
dataset_iter
)
except
:
continue
assert
len
(
raw_data
)
==
2
and
isinstance
(
raw_data
[
0
],
bool
),
"raw_data must be a tuple"
if
not
raw_data
[
0
]:
data
=
list
((
x
,)
for
x
in
raw_data
[
1
])
else
:
data
=
raw_data
[
1
]
for
idx
in
range
(
self
.
idx
,
len
(
data
[
0
])):
trans_data
=
self
.
transform
.
apply
(
tuple
(
e
[
idx
]
for
e
in
data
))
ret
.
append
(
trans_data
)
if
len
(
ret
)
==
self
.
sampler
.
batch_size
:
if
len
(
ret
)
==
self
.
sampler
.
batch_size
:
if
idx
+
1
==
len
(
data
[
0
]):
self
.
idx
=
0
self
.
data
=
None
else
:
self
.
idx
=
idx
self
.
data
=
data
break
break
return
self
.
collator
.
apply
(
ret
)
return
self
.
collator
.
apply
(
ret
)
...
@@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
...
@@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
self
.
shutdown_flag
=
multiprocessing
.
Value
(
"i"
,
0
)
self
.
shutdown_flag
=
multiprocessing
.
Value
(
"i"
,
0
)
self
.
raw_data_queues
=
[
multiprocessing
.
Queue
(
maxsize
=
1
)
for
_
in
range
(
self
.
num_workers
)
]
self
.
trans_data_queues
=
[
multiprocessing
.
Queue
(
maxsize
=
1
)
for
_
in
range
(
self
.
num_workers
)
]
# shared-memory queue implemented by pyarrow plasma store
# shared-memory queue implemented by pyarrow plasma store
from
._queue
import
PlasmaShmQueue
from
._queue
import
PlasmaShmQueue
self
.
batch_queue
=
PlasmaShmQueue
(
maxsize
=
2
)
self
.
batch_queue
=
PlasmaShmQueue
(
maxsize
=
2
)
self
.
workers
=
[]
self
.
worker_queues
=
[
self
.
recieve_worker
=
multiprocessing
.
Process
(
target
=
self
.
_recieve
,
daemon
=
True
)
multiprocessing
.
Queue
(
maxsize
=
1
)
for
_
in
range
(
self
.
num_workers
)
self
.
recieve_worker
.
start
()
]
self
.
transform_workers
=
[]
for
worker_id
in
range
(
self
.
num_workers
):
for
worker_id
in
range
(
self
.
num_workers
):
worker
=
multiprocessing
.
Process
(
worker
=
multiprocessing
.
Process
(
target
=
self
.
_
gen_data
,
args
=
(
worker_id
,),
daemon
=
True
target
=
self
.
_
transform
,
args
=
(
worker_id
,),
daemon
=
True
)
)
worker
.
start
()
worker
.
start
()
self
.
workers
.
append
(
worker
)
self
.
transform_workers
.
append
(
worker
)
self
.
collator_worker
=
multiprocessing
.
Process
(
target
=
self
.
_gen_batch
,
daemon
=
True
self
.
collect_worker
=
multiprocessing
.
Process
(
target
=
self
.
_collect
,
daemon
=
True
)
)
self
.
collect_worker
.
start
()
self
.
collator_worker
.
start
()
self
.
__initialized
=
True
self
.
__initialized
=
True
def
_
gen_data
(
self
,
worker_id
):
def
_
recieve
(
self
):
dataset_iter
=
iter
(
self
.
dataset
)
dataset_iter
=
iter
(
self
.
dataset
)
cnt
=
-
1
while
True
:
while
True
:
if
self
.
shutdown_flag
.
value
==
1
:
if
self
.
shutdown_flag
.
value
==
1
:
break
break
item
=
next
(
dataset_iter
)
raw_data
=
next
(
dataset_iter
)
for
idx
in
range
(
len
(
item
[
0
])):
assert
len
(
raw_data
)
==
2
and
isinstance
(
trans_item
=
self
.
transform
.
apply
(
tuple
(
e
[
idx
]
for
e
in
item
))
raw_data
[
0
],
bool
),
"raw_data must be a tuple"
if
not
raw_data
[
0
]:
data
=
list
((
x
,)
for
x
in
raw_data
[
1
])
else
:
data
=
raw_data
[
1
]
for
idx
in
range
(
len
(
data
[
0
])):
while
True
:
while
True
:
cnt
+=
1
qid
=
cnt
%
self
.
num_workers
try
:
try
:
self
.
worker_queues
[
worker_id
].
put
(
trans_item
)
self
.
raw_data_queues
[
qid
].
put
(
tuple
(
e
[
idx
]
for
e
in
data
)
)
break
break
except
queue
.
Full
:
except
queue
.
Full
:
if
self
.
shutdown_flag
.
value
==
1
:
if
self
.
shutdown_flag
.
value
==
1
:
break
break
logger
.
debug
(
"
batch part
queue is full"
)
logger
.
debug
(
"
raw data
queue is full"
)
def
_gen_batch
(
self
):
def _transform(self, worker_id):
    """Worker loop that transforms raw samples for one worker slot.

    Repeatedly takes one raw sample from ``self.raw_data_queues[worker_id]``,
    runs ``self.transform.apply`` on it and puts the result into
    ``self.trans_data_queues[worker_id]``. The loop exits once
    ``self.shutdown_flag.value`` equals 1.

    :param worker_id: index selecting this worker's input and output queues.
    """
    raw_queue = self.raw_data_queues[worker_id]
    trans_queue = self.trans_data_queues[worker_id]

    while self.shutdown_flag.value != 1:
        try:
            data = raw_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
        except queue.Empty:
            # Nothing to transform yet; loop around to re-check the shutdown flag.
            continue

        trans_data = self.transform.apply(data)

        while True:
            try:
                # put() must carry a timeout: a blocking put() without one never
                # raises queue.Full, which would make the shutdown check below
                # unreachable while the output queue stays full.
                trans_queue.put(trans_data, timeout=MP_QUEUE_GET_TIMEOUT)
                break
            except queue.Full:
                if self.shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full")
def
_collect
(
self
):
cnt
=
-
1
cnt
=
-
1
trans_items
=
[]
trans_items
=
[]
while
True
:
while
True
:
...
@@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
...
@@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
cnt
+=
1
cnt
+=
1
queue_id
=
cnt
%
self
.
num_workers
queue_id
=
cnt
%
self
.
num_workers
try
:
try
:
trans_item
=
self
.
worker
_queues
[
queue_id
].
get
(
trans_item
=
self
.
trans_data
_queues
[
queue_id
].
get
(
timeout
=
MP_QUEUE_GET_TIMEOUT
timeout
=
MP_QUEUE_GET_TIMEOUT
)
)
except
queue
.
Empty
:
except
queue
.
Empty
:
...
@@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
...
@@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
trans_items
=
[]
trans_items
=
[]
def
_check_workers
(
self
):
def
_check_workers
(
self
):
if
not
self
.
coll
ator
_worker
.
is_alive
():
if
not
self
.
coll
ect
_worker
.
is_alive
():
exitcode
=
self
.
coll
ator
_worker
.
exitcode
exitcode
=
self
.
coll
ect
_worker
.
exitcode
if
exitcode
!=
0
:
if
exitcode
!=
0
:
raise
RuntimeError
(
"collator worker died. {}"
.
format
(
exitcode
))
raise
RuntimeError
(
"collator worker died. {}"
.
format
(
exitcode
))
for
worker_id
,
worker
in
enumerate
(
self
.
workers
):
for
worker_id
,
worker
in
enumerate
(
self
.
transform_
workers
):
if
not
worker
.
is_alive
():
if
not
worker
.
is_alive
():
exitcode
=
worker
.
exitcode
exitcode
=
worker
.
exitcode
if
exitcode
!=
0
:
if
exitcode
!=
0
:
...
@@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
...
@@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
with
self
.
shutdown_flag
.
get_lock
():
with
self
.
shutdown_flag
.
get_lock
():
self
.
shutdown_flag
.
value
=
1
self
.
shutdown_flag
.
value
=
1
if
self
.
collator
_worker
.
is_alive
():
if
self
.
recieve
_worker
.
is_alive
():
self
.
collator
_worker
.
terminate
()
self
.
recieve
_worker
.
terminate
()
self
.
collator
_worker
.
join
()
self
.
recieve
_worker
.
join
()
for
worker
in
self
.
workers
:
if
self
.
collect_worker
.
is_alive
():
self
.
collect_worker
.
terminate
()
self
.
collect_worker
.
join
()
for
worker
in
self
.
transform_workers
:
if
worker
.
is_alive
():
if
worker
.
is_alive
():
worker
.
terminate
()
worker
.
terminate
()
worker
.
join
()
worker
.
join
()
for
q
in
self
.
worker_queues
:
for
q
in
self
.
raw_data_queues
:
q
.
cancel_join_thread
()
q
.
close
()
for
q
in
self
.
trans_data_queues
:
q
.
cancel_join_thread
()
q
.
cancel_join_thread
()
q
.
close
()
q
.
close
()
...
...
imperative/python/megengine/data/sampler.py
浏览文件 @
e082e277
...
@@ -161,10 +161,13 @@ class StreamSampler(Sampler):
...
@@ -161,10 +161,13 @@ class StreamSampler(Sampler):
.. warning::
.. warning::
In the case of multiple
worker
s, sampler should ensure that each worker gets
In the case of multiple
machine
s, sampler should ensure that each worker gets
different data. But this class cannot do it yet, please build your own
different data. But this class cannot do it yet, please build your own
dataset and sampler to achieve this goal.
dataset and sampler to achieve this goal.
Usually, :meth:`~.StreamDataset.__iter__` can return a different iterator depending on
``rank = dist.get_rank()``, so that each worker gets different data.
"""
"""
def
__init__
(
self
,
batch_size
=
1
):
def
__init__
(
self
,
batch_size
=
1
):
...
@@ -174,7 +177,7 @@ class StreamSampler(Sampler):
...
@@ -174,7 +177,7 @@ class StreamSampler(Sampler):
return
self
return
self
def
__next__
(
self
):
def
__next__
(
self
):
return
range
(
self
.
batch_size
)
return
iter
(
range
(
self
.
batch_size
)
)
class
SequentialSampler
(
MapSampler
):
class
SequentialSampler
(
MapSampler
):
...
...
imperative/python/test/unit/data/test_dataloader.py
浏览文件 @
e082e277
...
@@ -15,9 +15,15 @@ import pytest
...
@@ -15,9 +15,15 @@ import pytest
from
megengine.data.collator
import
Collator
from
megengine.data.collator
import
Collator
from
megengine.data.dataloader
import
DataLoader
from
megengine.data.dataloader
import
DataLoader
from
megengine.data.dataset
import
ArrayDataset
from
megengine.data.dataset
import
ArrayDataset
,
StreamDataset
from
megengine.data.sampler
import
RandomSampler
,
SequentialSampler
from
megengine.data.sampler
import
RandomSampler
,
SequentialSampler
,
StreamSampler
from
megengine.data.transform
import
PseudoTransform
,
Transform
from
megengine.data.transform
import
(
Compose
,
Normalize
,
PseudoTransform
,
ToMode
,
Transform
,
)
def
init_dataset
():
def
init_dataset
():
...
@@ -54,6 +60,80 @@ def test_dataloader_init():
...
@@ -54,6 +60,80 @@ def test_dataloader_init():
assert
len
(
dataloader
)
==
16
assert
len
(
dataloader
)
==
16
class MyStream(StreamDataset):
    """Finite stream of random ``uint8`` images used by the stream dataloader tests.

    Iterating yields ``number`` items:

    * ``batch=True``: ``(True, (data, [cnt, cnt - number]))`` where ``data`` has
      shape ``(2, 32, 32, 3)``.
    * ``batch=False``: ``(False, (data, cnt))`` where ``data`` has shape
      ``(32, 32, 3)``.
    * ``batch=False`` and ``error=True``: a bare ``(data, cnt)`` without the
      leading bool flag.

    :param number: how many items to yield before the stream ends.
    :param batch: yield two-sample batches instead of single samples.
    :param error: yield items without the leading bool flag (non-batch mode only).
    """

    def __init__(self, number, batch=False, error=False):
        self.number = number
        self.batch = batch
        self.error = error

    def __iter__(self):
        for cnt in range(self.number):
            if self.batch:
                data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
                yield (True, (data, [cnt, cnt - self.number]))
            else:
                data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
                if self.error:
                    yield (data, cnt)
                else:
                    yield (False, (data, cnt))
        # No explicit ``raise StopIteration``: since PEP 479 that raises
        # RuntimeError inside a generator. Falling off the end of the generator
        # signals exhaustion correctly.
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
    """Check batch shapes and label uniqueness over the first ten batches.

    Builds a ``DataLoader`` over ``MyStream(100, batch)`` with a batch size of 4
    and a ``Normalize`` + ``ToMode("CHW")`` pipeline, then asserts that every
    batch has image shape ``(4, 3, 32, 32)``, label shape ``(4,)`` and that no
    label value is seen twice.
    """
    pipeline = Compose(
        [Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]
    )
    loader = DataLoader(
        MyStream(100, batch),
        StreamSampler(batch_size=4),
        pipeline,
        num_workers=num_workers,
    )

    seen_labels = set()
    for step, batch_data in enumerate(loader):
        if step == 10:
            break
        assert batch_data[0].shape == (4, 3, 32, 32)
        assert batch_data[1].shape == (4,)
        for label in batch_data[1]:
            # Every label across all inspected batches must be distinct.
            assert label not in seen_labels
            seen_labels.add(label)
def test_stream_dataloader_error():
    """Expect an ``AssertionError`` mentioning "tuple" for malformed stream items.

    ``MyStream(100, error=True)`` yields items without the leading bool flag;
    fetching the first batch must raise.
    """
    loader = DataLoader(MyStream(100, error=True), StreamSampler(batch_size=4))
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        next(iter(loader))
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
    """Expect a ``RuntimeError`` mentioning "timeout" when the transform is too slow.

    The transform sleeps for 10 seconds per call while the loader is created
    with ``timeout=5``, so fetching the first batch must raise.
    """

    class TimeoutTransform(Transform):
        """Transform that sleeps for 10 seconds and returns its input unchanged."""

        def __init__(self):
            pass

        def apply(self, input):
            time.sleep(10)
            return input

    loader = DataLoader(
        MyStream(100, False),
        StreamSampler(batch_size=4),
        TimeoutTransform(),
        num_workers=num_workers,
        timeout=5,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        next(iter(loader))
def
test_dataloader_serial
():
def
test_dataloader_serial
():
dataset
=
init_dataset
()
dataset
=
init_dataset
()
dataloader
=
DataLoader
(
dataloader
=
DataLoader
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录