提交 7af49c98 编写于 作者: M Megvii Engine Team

fix(imperative): fix pyarrow serialization warning

GitOrigin-RevId: 3e61ee70d7b8b57c403a80ba7f7f0064aa22da8b
上级 4365f158
......@@ -4,6 +4,7 @@ import gc
import itertools
import multiprocessing
import os
import pickle
import platform
import queue
import random
......@@ -37,7 +38,7 @@ except:
if platform.system() != "Windows":
import pyarrow
from .tools._queue import _ExceptionWrapper, context
from .tools._queue import _ExceptionWrapper
logger = get_logger(__name__)
......@@ -330,9 +331,9 @@ class _ParallelDataLoaderIter:
def _process_data(self, data):
self._rcvd_idx += 1
self._try_put_index()
if isinstance(data, pyarrow.lib.Buffer):
exception = pyarrow.deserialize(data, context=context)
exception.reraise()
if isinstance(data, bytes):
data = pickle.loads(data)
data.reraise()
return data
def _get_data(self):
......@@ -369,8 +370,8 @@ class _ParallelDataLoaderIter:
_get_data = self._get_data()
if len(_get_data) == 1:
assert isinstance(_get_data[0], pyarrow.lib.Buffer)
exception = pyarrow.deserialize(_get_data[0], context=context)
assert isinstance(_get_data[0], bytes)
exception = pickle.loads(_get_data[0])
exception.reraise()
self._try_put_index()
continue
......@@ -725,7 +726,7 @@ def _worker_loop(
where = "in DataLoader worker process {}".format(worker_id)
exc_msg = "".join(traceback.format_exception(*exc_info))
data = _ExceptionWrapper(exc_info[0].__name__, exc_msg, where)
data = pyarrow.serialize(data, context=context).to_buffer()
data = pickle.dumps(data)
data_queue.put((idx, data))
del data, idx, place_holder, r
......
# -*- coding: utf-8 -*-
import binascii
import os
import pickle
import queue
import subprocess
from multiprocessing import Queue
......@@ -60,15 +61,6 @@ class _ExceptionWrapper:
return _ExceptionWrapper(data["exc_type"], data["exc_msg"], data["where"])
context = pyarrow.SerializationContext()
context.register_type(
_ExceptionWrapper,
"_ExceptionWrapper",
custom_serializer=_ExceptionWrapper._serialize_Exception,
custom_deserializer=_ExceptionWrapper._deserialize_Exception,
)
class _PlasmaStoreManager:
__initialized = False
......@@ -137,7 +129,7 @@ class PlasmaShmQueue:
def get_error(self, exc_type, where="in background"):
data = _ExceptionWrapper(exc_type=exc_type, where=where)
data_buffer = pyarrow.serialize(data, context=context).to_buffer()
data_buffer = pickle.dumps(data)
return data_buffer
def put(self, data, block=True, timeout=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册