提交 cc7b2f16 编写于 作者: M Megvii Engine Team 提交者: Wanwan1996

fix(data): fix pyarrow.plasma import error in pyarrow1.12

GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a
上级 d4fbffe3
...@@ -691,7 +691,14 @@ def _worker_loop( ...@@ -691,7 +691,14 @@ def _worker_loop(
data = worker_id data = worker_id
iteration_end = True iteration_end = True
else: else:
raise e from .tools._queue import _ExceptionWrapper
exc_info = sys.exc_info()
where = "in DataLoader worker process {}".format(worker_id)
exc_msg = "".join(traceback.format_exception(*exc_info))
data = _ExceptionWrapper(exc_info[0].__name__, exc_msg, where)
data = pickle.dumps(data)
data_queue.put((idx, data)) data_queue.put((idx, data))
del data, idx, place_holder, r del data, idx, place_holder, r
......
...@@ -6,10 +6,16 @@ import subprocess ...@@ -6,10 +6,16 @@ import subprocess
from multiprocessing import Queue from multiprocessing import Queue
import pyarrow import pyarrow
import pyarrow.plasma as plasma
MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB
try:
import pyarrow.plasma as plasma
except ModuleNotFoundError:
raise RuntimeError(
"pyarrow remove plasma in version 12.0.0, please use pyarrow vserion < 12.0.0"
)
# Each process only need to start one plasma store, so we set it as a global variable. # Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes? # TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None MGE_PLASMA_STORE_MANAGER = None
......
...@@ -73,6 +73,79 @@ class MyStream(StreamDataset): ...@@ -73,6 +73,79 @@ class MyStream(StreamDataset):
raise StopIteration raise StopIteration
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
@pytest.mark.skipif(
multiprocessing.get_start_method() != "fork",
reason="the runtime error is only raised when fork",
)
def test_dataloader_worker_signal_exception():
dataset = init_dataset()
class FakeErrorTransform(Transform):
def __init__(self):
pass
def apply(self, input):
pid = os.getpid()
subprocess.run(["kill", "-11", str(pid)])
return input
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
transform=FakeErrorTransform(),
num_workers=2,
)
with pytest.raises(RuntimeError, match=r"DataLoader worker.* exited unexpectedly"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
class IndexErrorTransform(Transform):
def __init__(self):
self.array = [0, 1, 2]
def apply(self, input):
error_item = self.array[3]
return input
class TypeErrorTransform(Transform):
def __init__(self):
self.adda = 1
self.addb = "2"
def apply(self, input):
error_item = self.adda + self.addb
return input
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
@pytest.mark.parametrize("transform", [IndexErrorTransform(), TypeErrorTransform()])
def test_dataloader_worker_baseerror(transform):
dataset = init_dataset()
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
transform=transform,
num_workers=2,
)
with pytest.raises(RuntimeError, match=r"Caught .*Error in DataLoader worker"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(num_workers): def test_stream_dataloader(num_workers):
dataset = MyStream(100) dataset = MyStream(100)
...@@ -116,6 +189,10 @@ def test_dataloader_serial(): ...@@ -116,6 +189,10 @@ def test_dataloader_serial():
assert label.shape == (4,) assert label.shape == (4,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel(): def test_dataloader_parallel():
# set max shared memory to 100M # set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000" os.environ["MGE_PLASMA_MEMORY"] = "100000000"
...@@ -214,6 +291,10 @@ def _multi_instances_parallel_dataloader_worker(): ...@@ -214,6 +291,10 @@ def _multi_instances_parallel_dataloader_worker():
assert val_label.shape == (10,) assert val_label.shape == (10,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel_multi_instances(): def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M # set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000" os.environ["MGE_PLASMA_MEMORY"] = "100000000"
...@@ -265,6 +346,10 @@ class MyPreStream(StreamDataset): ...@@ -265,6 +346,10 @@ class MyPreStream(StreamDataset):
raise StopIteration raise StopIteration
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
@pytest.mark.skipif( @pytest.mark.skipif(
platform.system() == "Windows", platform.system() == "Windows",
reason="dataloader do not support parallel on windows", reason="dataloader do not support parallel on windows",
......
...@@ -78,6 +78,10 @@ class MyStream(StreamDataset): ...@@ -78,6 +78,10 @@ class MyStream(StreamDataset):
raise StopIteration raise StopIteration
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(num_workers): def test_stream_dataloader(num_workers):
dataset = MyStream(100) dataset = MyStream(100)
...@@ -127,6 +131,10 @@ def test_dataloader_serial(): ...@@ -127,6 +131,10 @@ def test_dataloader_serial():
assert label._tuple_shape == (4,) assert label._tuple_shape == (4,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel(): def test_dataloader_parallel():
# set max shared memory to 100M # set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000" os.environ["MGE_PLASMA_MEMORY"] = "100000000"
...@@ -230,6 +238,10 @@ def _multi_instances_parallel_dataloader_worker(): ...@@ -230,6 +238,10 @@ def _multi_instances_parallel_dataloader_worker():
assert val_label._tuple_shape == (10,) assert val_label._tuple_shape == (10,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel_multi_instances(): def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M # set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000" os.environ["MGE_PLASMA_MEMORY"] = "100000000"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册