From 5308b081699298f6049763c33fc1dad553596a80 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 20 Jun 2023 17:08:39 +0800 Subject: [PATCH] fix(data): fix pyarrow.plasma import error in pyarrow 12.0 GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a --- imperative/python/megengine/data/dataloader.py | 7 ++----- imperative/python/megengine/data/tools/_queue.py | 8 +++++++- .../python/test/unit/data/test_dataloader.py | 16 ++++++++++++++++ .../python/test/unit/data/test_pre_dataloader.py | 12 ++++++++++++ 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index 256fa0c8c..853d556b0 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -36,11 +36,6 @@ try: except: import _thread as thread -if platform.system() != "Windows": - import pyarrow - from .tools._queue import _ExceptionWrapper - - logger = get_logger(__name__) @@ -722,6 +717,8 @@ def _worker_loop( data = worker_id iteration_end = True else: + from .tools._queue import _ExceptionWrapper + exc_info = sys.exc_info() where = "in DataLoader worker process {}".format(worker_id) exc_msg = "".join(traceback.format_exception(*exc_info)) diff --git a/imperative/python/megengine/data/tools/_queue.py b/imperative/python/megengine/data/tools/_queue.py index eb34685d0..2fab0a6fc 100644 --- a/imperative/python/megengine/data/tools/_queue.py +++ b/imperative/python/megengine/data/tools/_queue.py @@ -7,12 +7,18 @@ import subprocess from multiprocessing import Queue import pyarrow -import pyarrow.plasma as plasma from ...logger import get_logger logger = get_logger(__name__) +try: + import pyarrow.plasma as plasma +except ModuleNotFoundError: + raise RuntimeError( + "pyarrow removed plasma in version 12.0.0, please use pyarrow version < 12.0.0" + ) + # Each process only need to start one plasma store, so we set it as a global variable. 
# TODO: how to share between different processes? MGE_PLASMA_STORE_MANAGER = None diff --git a/imperative/python/test/unit/data/test_dataloader.py b/imperative/python/test/unit/data/test_dataloader.py index 43498f840..91f8b4654 100644 --- a/imperative/python/test/unit/data/test_dataloader.py +++ b/imperative/python/test/unit/data/test_dataloader.py @@ -143,6 +143,10 @@ def test_dataloader_worker_baseerror(transform): batch_data = next(data_iter) +@pytest.mark.skipif( + np.__version__ >= "1.20.0", + reason="pyarrow is incompatible with numpy version 1.20.0", +) @pytest.mark.parametrize("num_workers", [0, 2]) def test_stream_dataloader(num_workers): dataset = MyStream(100) @@ -186,6 +190,10 @@ def test_dataloader_serial(): assert label.shape == (4,) +@pytest.mark.skipif( + np.__version__ >= "1.20.0", + reason="pyarrow is incompatible with numpy version 1.20.0", +) def test_dataloader_parallel(): # set max shared memory to 100M os.environ["MGE_PLASMA_MEMORY"] = "100000000" @@ -286,6 +294,10 @@ def _multi_instances_parallel_dataloader_worker(): assert val_label.shape == (10,) +@pytest.mark.skipif( + np.__version__ >= "1.20.0", + reason="pyarrow is incompatible with numpy version 1.20.0", +) def test_dataloader_parallel_multi_instances(): # set max shared memory to 100M os.environ["MGE_PLASMA_MEMORY"] = "100000000" @@ -337,6 +349,10 @@ class MyPreStream(StreamDataset): raise StopIteration +@pytest.mark.skipif( + np.__version__ >= "1.20.0", + reason="pyarrow is incompatible with numpy version 1.20.0", +) @pytest.mark.skipif( platform.system() == "Windows", reason="dataloader do not support parallel on windows", ) diff --git a/imperative/python/test/unit/data/test_pre_dataloader.py b/imperative/python/test/unit/data/test_pre_dataloader.py index f1f18530e..eb0fd5a46 100644 --- a/imperative/python/test/unit/data/test_pre_dataloader.py +++ b/imperative/python/test/unit/data/test_pre_dataloader.py @@ -78,6 +78,10 @@ class MyStream(StreamDataset): raise StopIteration 
+@pytest.mark.skipif( + np.__version__ >= "1.20.0", + reason="pyarrow is incompatible with numpy version 1.20.0", +) @pytest.mark.parametrize("num_workers", [0, 2]) def test_stream_dataloader(num_workers): dataset = MyStream(100) @@ -127,6 +131,10 @@ def test_dataloader_serial(): assert label._tuple_shape == (4,) +@pytest.mark.skipif( + np.__version__ >= "1.20.0", + reason="pyarrow is incompatible with numpy version 1.20.0", +) def test_dataloader_parallel(): # set max shared memory to 100M os.environ["MGE_PLASMA_MEMORY"] = "100000000" @@ -232,6 +240,10 @@ def _multi_instances_parallel_dataloader_worker(): assert val_label._tuple_shape == (10,) +@pytest.mark.skipif( + np.__version__ >= "1.20.0", + reason="pyarrow is incompatible with numpy version 1.20.0", +) def test_dataloader_parallel_multi_instances(): # set max shared memory to 100M os.environ["MGE_PLASMA_MEMORY"] = "100000000" -- GitLab