Commit f04e0d77
Authored Oct 21, 2020 by Megvii Engine Team
feat(mge/data): dpflow dataset, stream sampler and loader
GitOrigin-RevId: cbb4510a13625e7c2203cd1358a96208849029ca
Parent: be511a56
Showing 3 changed files with 236 additions and 16 deletions (+236 -16)
imperative/python/megengine/data/__init__.py    +1   -0
imperative/python/megengine/data/dataloader.py  +198 -11
imperative/python/megengine/data/sampler.py     +37  -5
imperative/python/megengine/data/__init__.py

@@ -14,4 +14,5 @@ from .sampler import (
     ReplacementSampler,
     Sampler,
     SequentialSampler,
+    StreamSampler,
 )
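With the one-line export above, StreamSampler joins the package-level imports. A minimal sketch of the resulting import surface, using only names visible in this fragment of __init__.py:

```python
# StreamSampler is now re-exported from megengine.data next to the
# existing samplers shown in this hunk.
from megengine.data import (
    ReplacementSampler,
    Sampler,
    SequentialSampler,
    StreamSampler,  # newly exported by this commit
)

sampler = StreamSampler(batch_size=4)  # see sampler.py below for its behavior
```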
imperative/python/megengine/data/dataloader.py
@@ -19,8 +19,8 @@ import numpy as np
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
 from .collator import Collator
-from .dataset import Dataset
-from .sampler import Sampler, SequentialSampler
+from .dataset import Dataset, MapDataset, StreamDataset
+from .sampler import Sampler, SequentialSampler, StreamSampler
 from .transform import PseudoTransform, Transform

 logger = get_logger(__name__)
@@ -82,13 +82,21 @@ class DataLoader:
             raise ValueError("divide should not be set to True when num_workers <= 1")

         self.dataset = dataset
         self.num_workers = num_workers
         self.timeout = timeout
         self.divide = divide

         if sampler is None:
-            self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
+            if isinstance(dataset, MapDataset):
+                self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
+            elif isinstance(dataset, StreamDataset):
+                self.sampler = StreamSampler(batch_size=1)
+            else:
+                raise TypeError(
+                    "can not recognize this kind of dataset: %s" % type(dataset)
+                )
         else:
             self.sampler = sampler
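A usage sketch of the new default-sampler dispatch: with sampler=None, the loader now keys off the dataset type. MyStream below is hypothetical; dataset.py is not part of this commit, so the StreamDataset interface (__iter__ yielding tuples of per-field arrays, plus a post_process hook) is an assumption inferred from the iterator code further down.

```python
# Hedged sketch: DataLoader picks StreamSampler for stream datasets.
# MyStream and its __iter__/post_process contract are assumptions,
# not confirmed API.
import numpy as np
from megengine.data import DataLoader, StreamSampler
from megengine.data.dataset import StreamDataset

class MyStream(StreamDataset):
    def __iter__(self):
        while True:
            # one "item": a tuple of equal-length per-field arrays
            yield (np.zeros((2, 4), "float32"), np.zeros((2,), "int32"))

    def post_process(self, batch):
        return batch  # identity post-processing

loader = DataLoader(MyStream())                   # sampler is None
assert isinstance(loader.sampler, StreamSampler)  # chosen by isinstance dispatch
```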
@@ -120,16 +128,26 @@ class DataLoader:
                 "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
             )
             self.num_workers = 0

-        if self.num_workers == 0:
-            return _SerialDataLoaderIter(self)
-        else:
-            return _ParallelDataLoaderIter(self)
+        if isinstance(self.dataset, StreamDataset):
+            if not self.num_workers:
+                return _SerialStreamDataLoaderIter(self)
+            else:
+                return _ParallelStreamDataLoaderIter(self)
+        elif isinstance(self.dataset, MapDataset):
+            if not self.num_workers:
+                return _SerialMapDataLoaderIter(self)
+            else:
+                return _ParallelMapDataLoaderIter(self)
+        else:
+            raise TypeError(
+                "can not recognize this kind of dataset: %s" % type(self.dataset)
+            )

     def __len__(self):
         return len(self.sampler)

-class _BaseDataLoaderIter:
+class _BaseMapDataLoaderIter:
     def __init__(self, loader):
         self.dataset = loader.dataset
         self.sampler = loader.sampler
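Together with the constructor change above, iterator selection is now a four-way matrix:

```python
# Iterator chosen by DataLoader.__iter__ after this commit:
#
#   dataset type   | num_workers == 0            | num_workers > 0
#   ---------------+-----------------------------+------------------------------
#   MapDataset     | _SerialMapDataLoaderIter    | _ParallelMapDataLoaderIter
#   StreamDataset  | _SerialStreamDataLoaderIter | _ParallelStreamDataLoaderIter
#
# Any other dataset type raises TypeError, both here and in __init__.
```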
@@ -158,9 +176,9 @@ class _BaseDataLoaderIter:
         return minibatch


-class _SerialDataLoaderIter(_BaseDataLoaderIter):
+class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
     def __init__(self, loader):
-        super(_SerialDataLoaderIter, self).__init__(loader)
+        super(_SerialMapDataLoaderIter, self).__init__(loader)
         self.indices_iter = iter(self.sampler)

     def _get_next_batch(self):
@@ -170,11 +188,11 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter):
         return self.collator.apply(trans_items)


-class _ParallelDataLoaderIter(_BaseDataLoaderIter):
+class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
     __initialized = False

     def __init__(self, loader):
-        super(_ParallelDataLoaderIter, self).__init__(loader)
+        super(_ParallelMapDataLoaderIter, self).__init__(loader)
         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
         ]
@@ -326,6 +344,175 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
             self._shutdown()


+class _BaseStreamDataLoaderIter:
+    def __init__(self, loader):
+        self.dataset = loader.dataset
+        self.sampler = loader.sampler
+        self.transform = loader.transform
+        self.collator = loader.collator
+        self.num_workers = loader.num_workers
+        self.timeout = loader.timeout
+        self.post_process = self.dataset.post_process
+
+    def _get_next_batch(self):
+        raise NotImplementedError
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.post_process(self._get_next_batch())
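Note the hook order in __next__: a subclass assembles the batch in _get_next_batch, then the dataset's own post_process runs on the collated result. A stand-in illustration of that contract (plain Python, no MegEngine required):

```python
# Toy imitation of _BaseStreamDataLoaderIter.__next__: assemble, then
# post-process. Names mirror the diff; the classes here are stand-ins.
class ToyStreamIter:
    def __init__(self, post_process):
        self.post_process = post_process  # taken from the dataset in the real code

    def _get_next_batch(self):
        return [1, 2, 3]  # subclasses supply the real batching logic

    def __next__(self):
        return self.post_process(self._get_next_batch())

it = ToyStreamIter(post_process=lambda batch: [x * 10 for x in batch])
assert next(it) == [10, 20, 30]
```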
+class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
+    def __init__(self, loader):
+        super().__init__(loader)
+        self.dataset_iter = iter(self.dataset)
+
+    def _get_next_batch(self):
+        ret = []
+        start_time = time.time()
+        while len(ret) != self.sampler.batch_size:
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
+            item = next(self.dataset_iter)
+            for idx in range(len(item[0])):
+                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+                ret.append(trans_item)
+                if len(ret) == self.sampler.batch_size:
+                    break
+        return self.collator.apply(ret)
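Each item pulled from the dataset iterator carries several samples: it is a tuple of per-field arrays, so len(item[0]) counts the samples and tuple(e[idx] for e in item) slices sample idx across all fields. The slicing step in isolation:

```python
# How one stream item is unpacked into samples (pure NumPy, runnable
# standalone; the (data, label) layout is illustrative).
import numpy as np

item = (np.arange(6).reshape(3, 2), np.array([10, 11, 12]))  # 2 fields, 3 samples
samples = [tuple(e[idx] for e in item) for idx in range(len(item[0]))]
assert len(samples) == 3
# samples[0] == (array([0, 1]), 10): one (data, label) pair per index
```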
+class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
+    __initialized = False
+
+    def __init__(self, loader):
+        super().__init__(loader)
+
+        self.shutdown_flag = multiprocessing.Value("i", 0)
+
+        # shared-memory queue implemented by pyarrow plasma store
+        from ._queue import PlasmaShmQueue
+
+        self.batch_queue = PlasmaShmQueue(maxsize=2)
+
+        self.workers = []
+        self.worker_queues = [
+            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
+        ]
+        for worker_id in range(self.num_workers):
+            worker = multiprocessing.Process(
+                target=self._gen_data, args=(worker_id,), daemon=True
+            )
+            worker.start()
+            self.workers.append(worker)
+
+        self.collator_worker = multiprocessing.Process(
+            target=self._gen_batch, daemon=True
+        )
+        self.collator_worker.start()
+
+        self.__initialized = True
+
+    def _gen_data(self, worker_id):
+        dataset_iter = iter(self.dataset)
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            item = next(dataset_iter)
+            for idx in range(len(item[0])):
+                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+                while True:
+                    try:
+                        self.worker_queues[worker_id].put(trans_item)
+                        break
+                    except queue.Full:
+                        if self.shutdown_flag.value == 1:
+                            break
+                        logger.debug("batch part queue is full")
+
+    def _gen_batch(self):
+        cnt = -1
+        trans_items = []
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            cnt += 1
+            queue_id = cnt % self.num_workers
+            try:
+                trans_item = self.worker_queues[queue_id].get(
+                    timeout=MP_QUEUE_GET_TIMEOUT
+                )
+            except queue.Empty:
+                continue
+            trans_items.append(trans_item)
+            if len(trans_items) == self.sampler.batch_size:
+                batch_data = self.collator.apply(trans_items)
+                while True:
+                    try:
+                        self.batch_queue.put(batch_data, timeout=1)
+                        break
+                    except queue.Full:
+                        if self.shutdown_flag.value == 1:
+                            break
+                        logger.debug("batch queue is full")
+                trans_items = []
+
+    def _check_workers(self):
+        if not self.collator_worker.is_alive():
+            exitcode = self.collator_worker.exitcode
+            if exitcode != 0:
+                raise RuntimeError("collator worker died. {}".format(exitcode))
+        for worker_id, worker in enumerate(self.workers):
+            if not worker.is_alive():
+                exitcode = worker.exitcode
+                if exitcode != 0:
+                    raise RuntimeError(
+                        "worker: {} died. {}".format(worker_id, exitcode)
+                    )
+
+    def _try_get_next_batch(self):
+        start_time = time.time()
+        while True:
+            self._check_workers()
+            try:
+                return self.batch_queue.get(timeout=1)
+            except queue.Empty:
+                logger.debug("batch queue empty!")
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
+
+    def _get_next_batch(self):
+        batch_data = self._try_get_next_batch()
+        return batch_data
+
+    def _shutdown(self):
+        with self.shutdown_flag.get_lock():
+            self.shutdown_flag.value = 1
+        if self.collator_worker.is_alive():
+            self.collator_worker.terminate()
+        self.collator_worker.join()
+        for worker in self.workers:
+            if worker.is_alive():
+                worker.terminate()
+            worker.join()
+        for q in self.worker_queues:
+            q.cancel_join_thread()
+            q.close()
+        self.batch_queue.cancel_join_thread()
+        self.batch_queue.close()
+
+    def __del__(self):
+        if self.__initialized:
+            self._shutdown()
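Why queue_id = cnt % self.num_workers: the collator process drains the per-worker queues in strict round-robin, so batch composition is deterministic rather than depending on which worker happens to be fastest. The drain order in isolation:

```python
# Round-robin drain order used by _gen_batch. cnt starts at -1 and is
# incremented before use, so the first queue polled is always 0.
num_workers = 3
order = [cnt % num_workers for cnt in range(8)]
assert order == [0, 1, 2, 0, 1, 2, 0, 1]
```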

 def _task_feeding_loop(
     indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
 ):
 ...
imperative/python/megengine/data/sampler.py
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections.abc
 import math
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Any, Generator, Iterator, List, Union

 import numpy as np
@@ -17,6 +17,16 @@ import megengine.distributed as dist


 class Sampler(ABC):
+    r"""
+    An abstract class for all Sampler
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+
+class MapSampler(Sampler):
     def __init__(
         self,
         dataset,
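The net effect of this hunk, combined with the rebasing hunks below, is a two-level hierarchy: Sampler shrinks to a bare ABC and every index-based sampler moves under the new MapSampler:

```python
# Sampler hierarchy after this commit (all names from this diff):
#
#   Sampler (ABC)
#   |-- MapSampler              # index-based sampling over a map-style dataset
#   |     |-- SequentialSampler
#   |     |-- RandomSampler
#   |     |-- ReplacementSampler
#   |     \-- Infinite
#   \-- StreamSampler           # placeholder batches for stream datasets
```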
@@ -145,7 +155,29 @@ class Sampler(ABC):
         return iter(batch_index)


-class SequentialSampler(Sampler):
+class StreamSampler(Sampler):
+    """
+    Sampler for stream dataset.
+
+    .. warning::
+
+        In the case of multiple workers, sampler should ensure that each worker gets
+        different data. But this class cannot do it yet, please build your own
+        dataset and sampler to achieve this goal.
+    """
+
+    def __init__(self, batch_size=1):
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return range(self.batch_size)
+
+
+class SequentialSampler(MapSampler):
     def __init__(
         self,
         dataset,
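StreamSampler never indexes a dataset; it emits range(batch_size) forever, which only tells the stream loader how many samples to pull per batch. A minimal check:

```python
# StreamSampler yields the same placeholder index range on every step.
from megengine.data import StreamSampler

sampler = StreamSampler(batch_size=4)
it = iter(sampler)                     # __iter__ returns the sampler itself
assert list(next(it)) == [0, 1, 2, 3]  # __next__ returns range(batch_size)
```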
@@ -176,7 +208,7 @@ class SequentialSampler(Sampler):
         return self.indices


-class RandomSampler(Sampler):
+class RandomSampler(MapSampler):
     def __init__(
         self,
         dataset,
@@ -205,7 +237,7 @@ class RandomSampler(Sampler):
         return self.rng.permutation(self.indices).tolist()


-class ReplacementSampler(Sampler):
+class ReplacementSampler(MapSampler):
     def __init__(
         self,
         dataset,
@@ -249,7 +281,7 @@ class ReplacementSampler(Sampler):
         return self.rng.multinomial(n, self.weights, self.num_samples).tolist()


-class Infinite(Sampler):
+class Infinite(MapSampler):
     r"""Infinite Sampler warper for basic sampler."""

     def sample(self):
 ...