From 7af49c989435c545ab91a234b2874662583dbe0f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 13 Jun 2023 20:15:55 +0800 Subject: [PATCH] fix(imperative): fix pyarrow serialization warning GitOrigin-RevId: 3e61ee70d7b8b57c403a80ba7f7f0064aa22da8b --- imperative/python/megengine/data/dataloader.py | 15 ++++++++------- imperative/python/megengine/data/tools/_queue.py | 12 ++---------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index d8e5929b4..256fa0c8c 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -4,6 +4,7 @@ import gc import itertools import multiprocessing import os +import pickle import platform import queue import random @@ -37,7 +38,7 @@ except: if platform.system() != "Windows": import pyarrow - from .tools._queue import _ExceptionWrapper, context + from .tools._queue import _ExceptionWrapper logger = get_logger(__name__) @@ -330,9 +331,9 @@ class _ParallelDataLoaderIter: def _process_data(self, data): self._rcvd_idx += 1 self._try_put_index() - if isinstance(data, pyarrow.lib.Buffer): - exception = pyarrow.deserialize(data, context=context) - exception.reraise() + if isinstance(data, bytes): + data = pickle.loads(data) + data.reraise() return data def _get_data(self): @@ -369,8 +370,8 @@ class _ParallelDataLoaderIter: _get_data = self._get_data() if len(_get_data) == 1: - assert isinstance(_get_data[0], pyarrow.lib.Buffer) - exception = pyarrow.deserialize(_get_data[0], context=context) + assert isinstance(_get_data[0], bytes) + exception = pickle.loads(_get_data[0]) exception.reraise() self._try_put_index() continue @@ -725,7 +726,7 @@ def _worker_loop( where = "in DataLoader worker process {}".format(worker_id) exc_msg = "".join(traceback.format_exception(*exc_info)) data = _ExceptionWrapper(exc_info[0].__name__, exc_msg, where) - data = pyarrow.serialize(data, context=context).to_buffer() + data = pickle.dumps(data) data_queue.put((idx, data)) del data, idx, place_holder, r diff --git a/imperative/python/megengine/data/tools/_queue.py b/imperative/python/megengine/data/tools/_queue.py index b4e174f89..eb34685d0 100644 --- a/imperative/python/megengine/data/tools/_queue.py +++ b/imperative/python/megengine/data/tools/_queue.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import binascii import os +import pickle import queue import subprocess from multiprocessing import Queue @@ -60,15 +61,6 @@ class _ExceptionWrapper: return _ExceptionWrapper(data["exc_type"], data["exc_msg"], data["where"]) -context = pyarrow.SerializationContext() -context.register_type( - _ExceptionWrapper, - "_ExceptionWrapper", - custom_serializer=_ExceptionWrapper._serialize_Exception, - custom_deserializer=_ExceptionWrapper._deserialize_Exception, -) - - class _PlasmaStoreManager: __initialized = False @@ -137,7 +129,7 @@ class PlasmaShmQueue: def get_error(self, exc_type, where="in background"): data = _ExceptionWrapper(exc_type=exc_type, where=where) - data_buffer = pyarrow.serialize(data, context=context).to_buffer() + data_buffer = pickle.dumps(data) return data_buffer def put(self, data, block=True, timeout=None): -- GitLab