提交 5308b081 编写于 作者: M Megvii Engine Team

fix(data): fix pyarrow.plasma import error in pyarrow 12.0

GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a
上级 6486428f
......@@ -36,11 +36,6 @@ try:
except:
import _thread as thread
if platform.system() != "Windows":
import pyarrow
from .tools._queue import _ExceptionWrapper
logger = get_logger(__name__)
......@@ -722,6 +717,8 @@ def _worker_loop(
data = worker_id
iteration_end = True
else:
from .tools._queue import _ExceptionWrapper
exc_info = sys.exc_info()
where = "in DataLoader worker process {}".format(worker_id)
exc_msg = "".join(traceback.format_exception(*exc_info))
......
......@@ -7,12 +7,18 @@ import subprocess
from multiprocessing import Queue
import pyarrow
import pyarrow.plasma as plasma
from ...logger import get_logger
logger = get_logger(__name__)
# pyarrow.plasma was deprecated in pyarrow 10.0.0 and removed in 12.0.0.
# Fail fast with an actionable message instead of a bare ModuleNotFoundError,
# since the shared-memory queue below cannot work without plasma.
try:
    import pyarrow.plasma as plasma
except ModuleNotFoundError as exc:
    raise RuntimeError(
        "pyarrow removed plasma in version 12.0.0, please use pyarrow version < 12.0.0"
    ) from exc
# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None
......
......@@ -143,6 +143,10 @@ def test_dataloader_worker_baseerror(transform):
batch_data = next(data_iter)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy version 1.20.0",
)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(num_workers):
dataset = MyStream(100)
......@@ -186,6 +190,10 @@ def test_dataloader_serial():
assert label.shape == (4,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy version 1.20.0",
)
def test_dataloader_parallel():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
......@@ -286,6 +294,10 @@ def _multi_instances_parallel_dataloader_worker():
assert val_label.shape == (10,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy version 1.20.0",
)
def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
......@@ -337,6 +349,10 @@ class MyPreStream(StreamDataset):
raise StopIteration
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy version 1.20.0",
)
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
......
......@@ -78,6 +78,10 @@ class MyStream(StreamDataset):
raise StopIteration
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy version 1.20.0",
)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(num_workers):
dataset = MyStream(100)
......@@ -127,6 +131,10 @@ def test_dataloader_serial():
assert label._tuple_shape == (4,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy version 1.20.0",
)
def test_dataloader_parallel():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
......@@ -232,6 +240,10 @@ def _multi_instances_parallel_dataloader_worker():
assert val_label._tuple_shape == (10,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy version 1.20.0",
)
def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册