Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cc7b2f16
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cc7b2f16
编写于
6月 20, 2023
作者:
M
Megvii Engine Team
提交者:
Wanwan1996
7月 25, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(data): fix pyarrow.plasma import error in pyarrow1.12
GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a
上级
d4fbffe3
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
112 addition
and
2 deletion
+112
-2
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+8
-1
imperative/python/megengine/data/tools/_queue.py
imperative/python/megengine/data/tools/_queue.py
+7
-1
imperative/python/test/unit/data/test_dataloader.py
imperative/python/test/unit/data/test_dataloader.py
+85
-0
imperative/python/test/unit/data/test_pre_dataloader.py
imperative/python/test/unit/data/test_pre_dataloader.py
+12
-0
未找到文件。
imperative/python/megengine/data/dataloader.py
浏览文件 @
cc7b2f16
...
...
@@ -691,7 +691,14 @@ def _worker_loop(
data
=
worker_id
iteration_end
=
True
else
:
raise
e
from
.tools._queue
import
_ExceptionWrapper
exc_info
=
sys
.
exc_info
()
where
=
"in DataLoader worker process {}"
.
format
(
worker_id
)
exc_msg
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
data
=
_ExceptionWrapper
(
exc_info
[
0
].
__name__
,
exc_msg
,
where
)
data
=
pickle
.
dumps
(
data
)
data_queue
.
put
((
idx
,
data
))
del
data
,
idx
,
place_holder
,
r
...
...
imperative/python/megengine/data/tools/_queue.py
浏览文件 @
cc7b2f16
...
...
@@ -6,10 +6,16 @@ import subprocess
from
multiprocessing
import
Queue
import
pyarrow
import
pyarrow.plasma
as
plasma
MGE_PLASMA_MEMORY
=
int
(
os
.
environ
.
get
(
"MGE_PLASMA_MEMORY"
,
4000000000
))
# 4GB
try
:
import
pyarrow.plasma
as
plasma
except
ModuleNotFoundError
:
raise
RuntimeError
(
"pyarrow remove plasma in version 12.0.0, please use pyarrow vserion < 12.0.0"
)
# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER
=
None
...
...
imperative/python/test/unit/data/test_dataloader.py
浏览文件 @
cc7b2f16
...
...
@@ -73,6 +73,79 @@ class MyStream(StreamDataset):
raise
StopIteration
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
)
@
pytest
.
mark
.
skipif
(
multiprocessing
.
get_start_method
()
!=
"fork"
,
reason
=
"the runtime error is only raised when fork"
,
)
def
test_dataloader_worker_signal_exception
():
dataset
=
init_dataset
()
class
FakeErrorTransform
(
Transform
):
def
__init__
(
self
):
pass
def
apply
(
self
,
input
):
pid
=
os
.
getpid
()
subprocess
.
run
([
"kill"
,
"-11"
,
str
(
pid
)])
return
input
dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
transform
=
FakeErrorTransform
(),
num_workers
=
2
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"DataLoader worker.* exited unexpectedly"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
class
IndexErrorTransform
(
Transform
):
def
__init__
(
self
):
self
.
array
=
[
0
,
1
,
2
]
def
apply
(
self
,
input
):
error_item
=
self
.
array
[
3
]
return
input
class
TypeErrorTransform
(
Transform
):
def
__init__
(
self
):
self
.
adda
=
1
self
.
addb
=
"2"
def
apply
(
self
,
input
):
error_item
=
self
.
adda
+
self
.
addb
return
input
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
)
@
pytest
.
mark
.
parametrize
(
"transform"
,
[
IndexErrorTransform
(),
TypeErrorTransform
()])
def
test_dataloader_worker_baseerror
(
transform
):
dataset
=
init_dataset
()
dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
transform
=
transform
,
num_workers
=
2
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"Caught .*Error in DataLoader worker"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
...
...
@@ -116,6 +189,10 @@ def test_dataloader_serial():
assert
label
.
shape
==
(
4
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -214,6 +291,10 @@ def _multi_instances_parallel_dataloader_worker():
assert
val_label
.
shape
==
(
10
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel_multi_instances
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -265,6 +346,10 @@ class MyPreStream(StreamDataset):
raise
StopIteration
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
...
...
imperative/python/test/unit/data/test_pre_dataloader.py
浏览文件 @
cc7b2f16
...
...
@@ -78,6 +78,10 @@ class MyStream(StreamDataset):
raise
StopIteration
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
...
...
@@ -127,6 +131,10 @@ def test_dataloader_serial():
assert
label
.
_tuple_shape
==
(
4
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -230,6 +238,10 @@ def _multi_instances_parallel_dataloader_worker():
assert
val_label
.
_tuple_shape
==
(
10
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel_multi_instances
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录