From cc7b2f169807146bcaf351843012c6637d86940b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 20 Jun 2023 17:08:39 +0800
Subject: [PATCH] fix(data): fix pyarrow.plasma import error in pyarrow 12.0

GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a
---
 .../python/megengine/data/dataloader.py       |  9 +-
 .../python/megengine/data/tools/_queue.py     |  8 +-
 .../python/test/unit/data/test_dataloader.py  | 85 +++++++++++++++++++
 .../test/unit/data/test_pre_dataloader.py     | 12 +++
 4 files changed, 112 insertions(+), 2 deletions(-)

diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py
index 766342198..6cbe88dc0 100644
--- a/imperative/python/megengine/data/dataloader.py
+++ b/imperative/python/megengine/data/dataloader.py
@@ -691,7 +691,14 @@ def _worker_loop(
                 data = worker_id
                 iteration_end = True
             else:
-                raise e
+                from .tools._queue import _ExceptionWrapper
+
+                exc_info = sys.exc_info()
+                where = "in DataLoader worker process {}".format(worker_id)
+                exc_msg = "".join(traceback.format_exception(*exc_info))
+                data = _ExceptionWrapper(exc_info[0].__name__, exc_msg, where)
+                data = pickle.dumps(data)
+
         data_queue.put((idx, data))
         del data, idx, place_holder, r

diff --git a/imperative/python/megengine/data/tools/_queue.py b/imperative/python/megengine/data/tools/_queue.py
index f44641b07..bc3ad98bd 100644
--- a/imperative/python/megengine/data/tools/_queue.py
+++ b/imperative/python/megengine/data/tools/_queue.py
@@ -6,10 +6,16 @@ import subprocess
 from multiprocessing import Queue

 import pyarrow
-import pyarrow.plasma as plasma

 MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000))  # 4GB

+try:
+    import pyarrow.plasma as plasma
+except ModuleNotFoundError:
+    raise RuntimeError(
+        "pyarrow removed plasma in version 12.0.0, please use pyarrow version < 12.0.0"
+    )
+
 # Each process only need to start one plasma store, so we set it as a global variable.
 # TODO: how to share between different processes?
 MGE_PLASMA_STORE_MANAGER = None
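A note on the wrapper used in the dataloader.py hunk above: _ExceptionWrapper is defined in megengine/data/tools/_queue.py and its body is not part of this patch. The sketch below is a hypothetical stand-in, inferred only from the three-argument constructor call in _worker_loop and from the "Caught ... in DataLoader worker" message the new tests assert on; it illustrates why the worker ships formatted traceback text rather than the exception object itself, which may not survive pickling.

    import pickle

    class ExceptionWrapperSketch:
        """Hypothetical stand-in for _ExceptionWrapper (not MegEngine's real class).

        Carries only strings, so it stays picklable even when the original
        exception object is not.
        """

        def __init__(self, exc_type_name, exc_msg, where):
            self.exc_type_name = exc_type_name  # e.g. "IndexError"
            self.exc_msg = exc_msg              # formatted traceback text
            self.where = where                  # e.g. "in DataLoader worker process 0"

        def reraise(self):
            # Re-raise in the parent, keeping the worker's traceback visible.
            raise RuntimeError(
                "Caught {} {}.\nOriginal traceback:\n{}".format(
                    self.exc_type_name, self.where, self.exc_msg
                )
            )

    # Round trip, as the data queue would see it:
    payload = pickle.dumps(
        ExceptionWrapperSketch("IndexError", "Traceback ...", "in DataLoader worker process 0")
    )
    pickle.loads(payload).reraise()  # raises: RuntimeError: Caught IndexError in DataLoader worker process 0. ...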
diff --git a/imperative/python/test/unit/data/test_dataloader.py b/imperative/python/test/unit/data/test_dataloader.py
index e90679989..5f8f6c3c7 100644
--- a/imperative/python/test/unit/data/test_dataloader.py
+++ b/imperative/python/test/unit/data/test_dataloader.py
@@ -73,6 +73,79 @@ class MyStream(StreamDataset):
             raise StopIteration


+@pytest.mark.skipif(
+    platform.system() == "Windows",
+    reason="dataloader does not support parallel on windows",
+)
+@pytest.mark.skipif(
+    multiprocessing.get_start_method() != "fork",
+    reason="the runtime error is only raised when using fork",
+)
+def test_dataloader_worker_signal_exception():
+    dataset = init_dataset()
+
+    class FakeErrorTransform(Transform):
+        def __init__(self):
+            pass
+
+        def apply(self, input):
+            pid = os.getpid()
+            subprocess.run(["kill", "-11", str(pid)])
+            return input
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        transform=FakeErrorTransform(),
+        num_workers=2,
+    )
+    with pytest.raises(RuntimeError, match=r"DataLoader worker.* exited unexpectedly"):
+        data_iter = iter(dataloader)
+        batch_data = next(data_iter)
+
+
+class IndexErrorTransform(Transform):
+    def __init__(self):
+        self.array = [0, 1, 2]
+
+    def apply(self, input):
+        error_item = self.array[3]
+        return input
+
+
+class TypeErrorTransform(Transform):
+    def __init__(self):
+        self.adda = 1
+        self.addb = "2"
+
+    def apply(self, input):
+        error_item = self.adda + self.addb
+        return input
+
+
+@pytest.mark.skipif(
+    platform.system() == "Windows",
+    reason="dataloader does not support parallel on windows",
+)
+@pytest.mark.parametrize("transform", [IndexErrorTransform(), TypeErrorTransform()])
+def test_dataloader_worker_baseerror(transform):
+    dataset = init_dataset()
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        transform=transform,
+        num_workers=2,
+    )
+    with pytest.raises(RuntimeError, match=r"Caught .*Error in DataLoader worker"):
+        data_iter = iter(dataloader)
+        batch_data = next(data_iter)
+
+
+@pytest.mark.skipif(
+    np.__version__ >= "1.20.0",
+    reason="pyarrow is incompatible with numpy version >= 1.20.0",
+)
 @pytest.mark.parametrize("num_workers", [0, 2])
 def test_stream_dataloader(num_workers):
     dataset = MyStream(100)
@@ -116,6 +189,10 @@ def test_dataloader_serial():
     assert label.shape == (4,)


+@pytest.mark.skipif(
+    np.__version__ >= "1.20.0",
+    reason="pyarrow is incompatible with numpy version >= 1.20.0",
+)
 def test_dataloader_parallel():
     # set max shared memory to 100M
     os.environ["MGE_PLASMA_MEMORY"] = "100000000"
@@ -214,6 +291,10 @@ def _multi_instances_parallel_dataloader_worker():
     assert val_label.shape == (10,)


+@pytest.mark.skipif(
+    np.__version__ >= "1.20.0",
+    reason="pyarrow is incompatible with numpy version >= 1.20.0",
+)
 def test_dataloader_parallel_multi_instances():
     # set max shared memory to 100M
     os.environ["MGE_PLASMA_MEMORY"] = "100000000"
@@ -265,6 +346,10 @@ class MyPreStream(StreamDataset):
         raise StopIteration


+@pytest.mark.skipif(
+    np.__version__ >= "1.20.0",
+    reason="pyarrow is incompatible with numpy version >= 1.20.0",
+)
 @pytest.mark.skipif(
     platform.system() == "Windows",
     reason="dataloader do not support parallel on windows",
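One caveat about the skip conditions added above: np.__version__ >= "1.20.0" compares version strings lexicographically, so for example "1.9.0" >= "1.20.0" evaluates to True and the tests would also be skipped on numpy 1.9, which pyarrow supports. A sketch of a more robust condition, assuming the `packaging` distribution is available in the test environment:

    import numpy as np
    from packaging.version import Version  # assumption: `packaging` is installed

    # Version() compares release components numerically, so 1.9.0 < 1.20.0
    # as intended, unlike the lexicographic string comparison above.
    NUMPY_TOO_NEW_FOR_PYARROW = Version(np.__version__) >= Version("1.20.0")

The markers would then read @pytest.mark.skipif(NUMPY_TOO_NEW_FOR_PYARROW, ...) with the same reason string.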
diff --git a/imperative/python/test/unit/data/test_pre_dataloader.py b/imperative/python/test/unit/data/test_pre_dataloader.py
index 390752870..c7db34d65 100644
--- a/imperative/python/test/unit/data/test_pre_dataloader.py
+++ b/imperative/python/test/unit/data/test_pre_dataloader.py
@@ -78,6 +78,10 @@ class MyStream(StreamDataset):
             raise StopIteration


+@pytest.mark.skipif(
+    np.__version__ >= "1.20.0",
+    reason="pyarrow is incompatible with numpy version >= 1.20.0",
+)
 @pytest.mark.parametrize("num_workers", [0, 2])
 def test_stream_dataloader(num_workers):
     dataset = MyStream(100)
@@ -127,6 +131,10 @@ def test_dataloader_serial():
     assert label._tuple_shape == (4,)


+@pytest.mark.skipif(
+    np.__version__ >= "1.20.0",
+    reason="pyarrow is incompatible with numpy version >= 1.20.0",
+)
 def test_dataloader_parallel():
     # set max shared memory to 100M
     os.environ["MGE_PLASMA_MEMORY"] = "100000000"
@@ -230,6 +238,10 @@ def _multi_instances_parallel_dataloader_worker():
     assert val_label._tuple_shape == (10,)


+@pytest.mark.skipif(
+    np.__version__ >= "1.20.0",
+    reason="pyarrow is incompatible with numpy version >= 1.20.0",
+)
 def test_dataloader_parallel_multi_instances():
     # set max shared memory to 100M
     os.environ["MGE_PLASMA_MEMORY"] = "100000000"
--
GitLab
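For users who hit the new RuntimeError from _queue.py, the practical fix is pinning pyarrow below 12.0.0, e.g. pip install "pyarrow<12.0.0". A minimal sketch of the same guard outside MegEngine (message text illustrative, not the in-tree wording); chaining with `from exc` additionally preserves the original ModuleNotFoundError in the traceback, which the in-tree guard does not do:

    try:
        import pyarrow.plasma as plasma  # plasma was removed in pyarrow 12.0.0
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "pyarrow removed plasma in version 12.0.0; "
            "install a compatible build, e.g. `pip install 'pyarrow<12.0.0'`"
        ) from exc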