提交 7af49c98 编写于 作者: M Megvii Engine Team

fix(imperative): fix pyarrow serialization warning

GitOrigin-RevId: 3e61ee70d7b8b57c403a80ba7f7f0064aa22da8b
上级 4365f158
...@@ -4,6 +4,7 @@ import gc ...@@ -4,6 +4,7 @@ import gc
import itertools import itertools
import multiprocessing import multiprocessing
import os import os
import pickle
import platform import platform
import queue import queue
import random import random
...@@ -37,7 +38,7 @@ except: ...@@ -37,7 +38,7 @@ except:
if platform.system() != "Windows": if platform.system() != "Windows":
import pyarrow import pyarrow
from .tools._queue import _ExceptionWrapper, context from .tools._queue import _ExceptionWrapper
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -330,9 +331,9 @@ class _ParallelDataLoaderIter: ...@@ -330,9 +331,9 @@ class _ParallelDataLoaderIter:
def _process_data(self, data): def _process_data(self, data):
self._rcvd_idx += 1 self._rcvd_idx += 1
self._try_put_index() self._try_put_index()
if isinstance(data, pyarrow.lib.Buffer): if isinstance(data, bytes):
exception = pyarrow.deserialize(data, context=context) data = pickle.loads(data)
exception.reraise() data.reraise()
return data return data
def _get_data(self): def _get_data(self):
...@@ -369,8 +370,8 @@ class _ParallelDataLoaderIter: ...@@ -369,8 +370,8 @@ class _ParallelDataLoaderIter:
_get_data = self._get_data() _get_data = self._get_data()
if len(_get_data) == 1: if len(_get_data) == 1:
assert isinstance(_get_data[0], pyarrow.lib.Buffer) assert isinstance(_get_data[0], bytes)
exception = pyarrow.deserialize(_get_data[0], context=context) exception = pickle.loads(_get_data[0])
exception.reraise() exception.reraise()
self._try_put_index() self._try_put_index()
continue continue
...@@ -725,7 +726,7 @@ def _worker_loop( ...@@ -725,7 +726,7 @@ def _worker_loop(
where = "in DataLoader worker process {}".format(worker_id) where = "in DataLoader worker process {}".format(worker_id)
exc_msg = "".join(traceback.format_exception(*exc_info)) exc_msg = "".join(traceback.format_exception(*exc_info))
data = _ExceptionWrapper(exc_info[0].__name__, exc_msg, where) data = _ExceptionWrapper(exc_info[0].__name__, exc_msg, where)
data = pyarrow.serialize(data, context=context).to_buffer() data = pickle.dumps(data)
data_queue.put((idx, data)) data_queue.put((idx, data))
del data, idx, place_holder, r del data, idx, place_holder, r
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import binascii import binascii
import os import os
import pickle
import queue import queue
import subprocess import subprocess
from multiprocessing import Queue from multiprocessing import Queue
...@@ -60,15 +61,6 @@ class _ExceptionWrapper: ...@@ -60,15 +61,6 @@ class _ExceptionWrapper:
return _ExceptionWrapper(data["exc_type"], data["exc_msg"], data["where"]) return _ExceptionWrapper(data["exc_type"], data["exc_msg"], data["where"])
context = pyarrow.SerializationContext()
context.register_type(
_ExceptionWrapper,
"_ExceptionWrapper",
custom_serializer=_ExceptionWrapper._serialize_Exception,
custom_deserializer=_ExceptionWrapper._deserialize_Exception,
)
class _PlasmaStoreManager: class _PlasmaStoreManager:
__initialized = False __initialized = False
...@@ -137,7 +129,7 @@ class PlasmaShmQueue: ...@@ -137,7 +129,7 @@ class PlasmaShmQueue:
def get_error(self, exc_type, where="in background"): def get_error(self, exc_type, where="in background"):
data = _ExceptionWrapper(exc_type=exc_type, where=where) data = _ExceptionWrapper(exc_type=exc_type, where=where)
data_buffer = pyarrow.serialize(data, context=context).to_buffer() data_buffer = pickle.dumps(data)
return data_buffer return data_buffer
def put(self, data, block=True, timeout=None): def put(self, data, block=True, timeout=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册