Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
比较版本
d4fbffe38a5bde21f3a5847f8d5a4df5dad04370...e6c9ddd362ab4fd1f8e81831101d041aea3d4ad6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
8 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
源分支
e6c9ddd362ab4fd1f8e81831101d041aea3d4ad6
选择Git版本
...
目标分支
d4fbffe38a5bde21f3a5847f8d5a4df5dad04370
选择Git版本
比较
Commits (3)
https://gitcode.net/megvii/megengine/-/commit/cc7b2f169807146bcaf351843012c6637d86940b
fix(data): fix pyarrow.plasma import error in pyarrow1.12
2023-07-25T16:17:55+08:00
Megvii Engine Team
megengine@megvii.com
GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a
https://gitcode.net/megvii/megengine/-/commit/5fb2e8e158d3d6176b97e13d8d5ae6fed7ab91a0
fix(data): fix pyarrow.plasma import error in pyarrow1.12
2023-07-25T16:22:21+08:00
Megvii Engine Team
megengine@megvii.com
GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a
https://gitcode.net/megvii/megengine/-/commit/e6c9ddd362ab4fd1f8e81831101d041aea3d4ad6
fix(ci): fix pyarrow package version
2023-07-25T16:23:09+08:00
Megvii Engine Team
megengine@megvii.com
GitOrigin-RevId: 2ecdd7267c5d78d2eb0551115c5c5ed66d0b4b9c
隐藏空白更改
内联
并排
Showing
5 changed file
with
120 addition
and
4 deletion
+120
-4
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+8
-2
imperative/python/megengine/data/tools/_queue.py
imperative/python/megengine/data/tools/_queue.py
+14
-1
imperative/python/requires.txt
imperative/python/requires.txt
+1
-1
imperative/python/test/unit/data/test_dataloader.py
imperative/python/test/unit/data/test_dataloader.py
+85
-0
imperative/python/test/unit/data/test_pre_dataloader.py
imperative/python/test/unit/data/test_pre_dataloader.py
+12
-0
未找到文件。
imperative/python/megengine/data/dataloader.py
浏览文件 @
e6c9ddd3
...
...
@@ -27,7 +27,6 @@ try:
except
:
import
_thread
as
thread
logger
=
get_logger
(
__name__
)
...
...
@@ -691,7 +690,14 @@ def _worker_loop(
data
=
worker_id
iteration_end
=
True
else
:
raise
e
from
.tools._queue
import
_ExceptionWrapper
exc_info
=
sys
.
exc_info
()
where
=
"in DataLoader worker process {}"
.
format
(
worker_id
)
exc_msg
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
data
=
_ExceptionWrapper
(
exc_info
[
0
].
__name__
,
exc_msg
,
where
)
data
=
pickle
.
dumps
(
data
)
data_queue
.
put
((
idx
,
data
))
del
data
,
idx
,
place_holder
,
r
...
...
imperative/python/megengine/data/tools/_queue.py
浏览文件 @
e6c9ddd3
...
...
@@ -6,10 +6,23 @@ import subprocess
from
multiprocessing
import
Queue
import
pyarrow
import
pyarrow.plasma
as
plasma
MGE_PLASMA_MEMORY
=
int
(
os
.
environ
.
get
(
"MGE_PLASMA_MEMORY"
,
4000000000
))
# 4GB
try
:
import
pyarrow.plasma
as
plasma
except
ModuleNotFoundError
:
raise
RuntimeError
(
"pyarrow remove plasma in version 12.0.0, please use pyarrow vserion < 12.0.0"
)
try
:
import
pyarrow.plasma
as
plasma
except
ModuleNotFoundError
:
raise
RuntimeError
(
"pyarrow remove plasma in version 12.0.0, please use pyarrow vserion < 12.0.0"
)
# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER
=
None
...
...
imperative/python/requires.txt
浏览文件 @
e6c9ddd3
numpy>=1.18
opencv-python
pyarrow
pyarrow
<=11.0.0
requests
tabulate
tqdm
...
...
imperative/python/test/unit/data/test_dataloader.py
浏览文件 @
e6c9ddd3
...
...
@@ -73,6 +73,79 @@ class MyStream(StreamDataset):
raise
StopIteration
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
)
@
pytest
.
mark
.
skipif
(
multiprocessing
.
get_start_method
()
!=
"fork"
,
reason
=
"the runtime error is only raised when fork"
,
)
def
test_dataloader_worker_signal_exception
():
dataset
=
init_dataset
()
class
FakeErrorTransform
(
Transform
):
def
__init__
(
self
):
pass
def
apply
(
self
,
input
):
pid
=
os
.
getpid
()
subprocess
.
run
([
"kill"
,
"-11"
,
str
(
pid
)])
return
input
dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
transform
=
FakeErrorTransform
(),
num_workers
=
2
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"DataLoader worker.* exited unexpectedly"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
class
IndexErrorTransform
(
Transform
):
def
__init__
(
self
):
self
.
array
=
[
0
,
1
,
2
]
def
apply
(
self
,
input
):
error_item
=
self
.
array
[
3
]
return
input
class
TypeErrorTransform
(
Transform
):
def
__init__
(
self
):
self
.
adda
=
1
self
.
addb
=
"2"
def
apply
(
self
,
input
):
error_item
=
self
.
adda
+
self
.
addb
return
input
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
)
@
pytest
.
mark
.
parametrize
(
"transform"
,
[
IndexErrorTransform
(),
TypeErrorTransform
()])
def
test_dataloader_worker_baseerror
(
transform
):
dataset
=
init_dataset
()
dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
transform
=
transform
,
num_workers
=
2
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"Caught .*Error in DataLoader worker"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
...
...
@@ -116,6 +189,10 @@ def test_dataloader_serial():
assert
label
.
shape
==
(
4
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -214,6 +291,10 @@ def _multi_instances_parallel_dataloader_worker():
assert
val_label
.
shape
==
(
10
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel_multi_instances
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -265,6 +346,10 @@ class MyPreStream(StreamDataset):
raise
StopIteration
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
...
...
imperative/python/test/unit/data/test_pre_dataloader.py
浏览文件 @
e6c9ddd3
...
...
@@ -78,6 +78,10 @@ class MyStream(StreamDataset):
raise
StopIteration
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
...
...
@@ -127,6 +131,10 @@ def test_dataloader_serial():
assert
label
.
_tuple_shape
==
(
4
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...
@@ -230,6 +238,10 @@ def _multi_instances_parallel_dataloader_worker():
assert
val_label
.
_tuple_shape
==
(
10
,)
@
pytest
.
mark
.
skipif
(
np
.
__version__
>=
"1.20.0"
,
reason
=
"pyarrow is incompatible with numpy vserion 1.20.0"
,
)
def
test_dataloader_parallel_multi_instances
():
# set max shared memory to 100M
os
.
environ
[
"MGE_PLASMA_MEMORY"
]
=
"100000000"
...
...