Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7af49c98
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7af49c98
编写于
6月 13, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): fix pyarrow serialization warning
GitOrigin-RevId: 3e61ee70d7b8b57c403a80ba7f7f0064aa22da8b
上级
4365f158
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
10 addition
and
17 deletion
+10
-17
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+8
-7
imperative/python/megengine/data/tools/_queue.py
imperative/python/megengine/data/tools/_queue.py
+2
-10
未找到文件。
imperative/python/megengine/data/dataloader.py
浏览文件 @
7af49c98
...
@@ -4,6 +4,7 @@ import gc
...
@@ -4,6 +4,7 @@ import gc
import
itertools
import
itertools
import
multiprocessing
import
multiprocessing
import
os
import
os
import
pickle
import
platform
import
platform
import
queue
import
queue
import
random
import
random
...
@@ -37,7 +38,7 @@ except:
...
@@ -37,7 +38,7 @@ except:
if
platform
.
system
()
!=
"Windows"
:
if
platform
.
system
()
!=
"Windows"
:
import
pyarrow
import
pyarrow
from
.tools._queue
import
_ExceptionWrapper
,
context
from
.tools._queue
import
_ExceptionWrapper
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
...
@@ -330,9 +331,9 @@ class _ParallelDataLoaderIter:
...
@@ -330,9 +331,9 @@ class _ParallelDataLoaderIter:
def
_process_data
(
self
,
data
):
def
_process_data
(
self
,
data
):
self
.
_rcvd_idx
+=
1
self
.
_rcvd_idx
+=
1
self
.
_try_put_index
()
self
.
_try_put_index
()
if
isinstance
(
data
,
pyarrow
.
lib
.
Buffer
):
if
isinstance
(
data
,
bytes
):
exception
=
pyarrow
.
deserialize
(
data
,
context
=
context
)
data
=
pickle
.
loads
(
data
)
exception
.
reraise
()
data
.
reraise
()
return
data
return
data
def
_get_data
(
self
):
def
_get_data
(
self
):
...
@@ -369,8 +370,8 @@ class _ParallelDataLoaderIter:
...
@@ -369,8 +370,8 @@ class _ParallelDataLoaderIter:
_get_data
=
self
.
_get_data
()
_get_data
=
self
.
_get_data
()
if
len
(
_get_data
)
==
1
:
if
len
(
_get_data
)
==
1
:
assert
isinstance
(
_get_data
[
0
],
pyarrow
.
lib
.
Buffer
)
assert
isinstance
(
_get_data
[
0
],
bytes
)
exception
=
p
yarrow
.
deserialize
(
_get_data
[
0
],
context
=
context
)
exception
=
p
ickle
.
loads
(
_get_data
[
0
]
)
exception
.
reraise
()
exception
.
reraise
()
self
.
_try_put_index
()
self
.
_try_put_index
()
continue
continue
...
@@ -725,7 +726,7 @@ def _worker_loop(
...
@@ -725,7 +726,7 @@ def _worker_loop(
where
=
"in DataLoader worker process {}"
.
format
(
worker_id
)
where
=
"in DataLoader worker process {}"
.
format
(
worker_id
)
exc_msg
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
exc_msg
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
data
=
_ExceptionWrapper
(
exc_info
[
0
].
__name__
,
exc_msg
,
where
)
data
=
_ExceptionWrapper
(
exc_info
[
0
].
__name__
,
exc_msg
,
where
)
data
=
p
yarrow
.
serialize
(
data
,
context
=
context
).
to_buffer
(
)
data
=
p
ickle
.
dumps
(
data
)
data_queue
.
put
((
idx
,
data
))
data_queue
.
put
((
idx
,
data
))
del
data
,
idx
,
place_holder
,
r
del
data
,
idx
,
place_holder
,
r
...
...
imperative/python/megengine/data/tools/_queue.py
浏览文件 @
7af49c98
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import
binascii
import
binascii
import
os
import
os
import
pickle
import
queue
import
queue
import
subprocess
import
subprocess
from
multiprocessing
import
Queue
from
multiprocessing
import
Queue
...
@@ -60,15 +61,6 @@ class _ExceptionWrapper:
...
@@ -60,15 +61,6 @@ class _ExceptionWrapper:
return
_ExceptionWrapper
(
data
[
"exc_type"
],
data
[
"exc_msg"
],
data
[
"where"
])
return
_ExceptionWrapper
(
data
[
"exc_type"
],
data
[
"exc_msg"
],
data
[
"where"
])
context
=
pyarrow
.
SerializationContext
()
context
.
register_type
(
_ExceptionWrapper
,
"_ExceptionWrapper"
,
custom_serializer
=
_ExceptionWrapper
.
_serialize_Exception
,
custom_deserializer
=
_ExceptionWrapper
.
_deserialize_Exception
,
)
class
_PlasmaStoreManager
:
class
_PlasmaStoreManager
:
__initialized
=
False
__initialized
=
False
...
@@ -137,7 +129,7 @@ class PlasmaShmQueue:
...
@@ -137,7 +129,7 @@ class PlasmaShmQueue:
def
get_error
(
self
,
exc_type
,
where
=
"in background"
):
def
get_error
(
self
,
exc_type
,
where
=
"in background"
):
data
=
_ExceptionWrapper
(
exc_type
=
exc_type
,
where
=
where
)
data
=
_ExceptionWrapper
(
exc_type
=
exc_type
,
where
=
where
)
data_buffer
=
p
yarrow
.
serialize
(
data
,
context
=
context
).
to_buffer
(
)
data_buffer
=
p
ickle
.
dumps
(
data
)
return
data_buffer
return
data_buffer
def
put
(
self
,
data
,
block
=
True
,
timeout
=
None
):
def
put
(
self
,
data
,
block
=
True
,
timeout
=
None
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录