Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5308b081
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5308b081
编写于
6月 20, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(data): fix pyarrow.plasma import error in pyarrow1.12
GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a
上级
6486428f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
37 additions
and
6 deletions
+37
-6
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+2
-5
imperative/python/megengine/data/tools/_queue.py
imperative/python/megengine/data/tools/_queue.py
+7
-1
imperative/python/test/unit/data/test_dataloader.py
imperative/python/test/unit/data/test_dataloader.py
+16
-0
imperative/python/test/unit/data/test_pre_dataloader.py
imperative/python/test/unit/data/test_pre_dataloader.py
+12
-0
未找到文件。
imperative/python/megengine/data/dataloader.py
浏览文件 @
5308b081
...
...
@@ -36,11 +36,6 @@ try:
except
:
import
_thread
as
thread
if
platform
.
system
()
!=
"Windows"
:
import
pyarrow
from
.tools._queue
import
_ExceptionWrapper
logger
=
get_logger
(
__name__
)
...
...
@@ -722,6 +717,8 @@ def _worker_loop(
data
=
worker_id
iteration_end
=
True
else
:
from
.tools._queue
import
_ExceptionWrapper
exc_info
=
sys
.
exc_info
()
where
=
"in DataLoader worker process {}"
.
format
(
worker_id
)
exc_msg
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
...
...
imperative/python/megengine/data/tools/_queue.py
浏览文件 @
5308b081
...
...
@@ -7,12 +7,18 @@ import subprocess
from
multiprocessing
import
Queue
import
pyarrow
import
pyarrow.plasma
as
plasma
from
...logger
import
get_logger
logger
=
get_logger
(
__name__
)
try
:
import
pyarrow.plasma
as
plasma
except
ModuleNotFoundError
:
raise
RuntimeError
(
"pyarrow remove plasma in version 12.0.0, please use pyarrow vserion < 12.0.0"
)
# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER
=
None
...
...
imperative/python/test/unit/data/test_dataloader.py
浏览文件 @
5308b081
...
...
@@ -143,6 +143,10 @@ def test_dataloader_worker_baseerror(transform):
batch_data
=
next
(
data_iter
)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
...
...
@@ -186,6 +190,10 @@ def test_dataloader_serial():
assert
label
.
shape
==
(
4
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -286,6 +294,10 @@ def _multi_instances_parallel_dataloader_worker():
assert
val_label
.
shape
==
(
10
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel_multi_instances
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -337,6 +349,10 @@ class MyPreStream(StreamDataset):
raise
StopIteration
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
...
...
imperative/python/test/unit/data/test_pre_dataloader.py
浏览文件 @
5308b081
...
...
@@ -78,6 +78,10 @@ class MyStream(StreamDataset):
raise
StopIteration
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
...
...
@@ -127,6 +131,10 @@ def test_dataloader_serial():
assert
label
.
_tuple_shape
==
(
4
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -232,6 +240,10 @@ def _multi_instances_parallel_dataloader_worker():
assert
val_label
.
_tuple_shape
==
(
10
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel_multi_instances
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录