提交 5308b081 编写于 作者: M Megvii Engine Team

fix(data): fix pyarrow.plasma import error in pyarrow1.12

GitOrigin-RevId: b5e1cd3be59cc80a3cc5bf6a83855ede2a2cd38a
上级 6486428f
...@@ -36,11 +36,6 @@ try: ...@@ -36,11 +36,6 @@ try:
except: except:
import _thread as thread import _thread as thread
if platform.system() != "Windows":
import pyarrow
from .tools._queue import _ExceptionWrapper
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -722,6 +717,8 @@ def _worker_loop( ...@@ -722,6 +717,8 @@ def _worker_loop(
data = worker_id data = worker_id
iteration_end = True iteration_end = True
else: else:
from .tools._queue import _ExceptionWrapper
exc_info = sys.exc_info() exc_info = sys.exc_info()
where = "in DataLoader worker process {}".format(worker_id) where = "in DataLoader worker process {}".format(worker_id)
exc_msg = "".join(traceback.format_exception(*exc_info)) exc_msg = "".join(traceback.format_exception(*exc_info))
......
...@@ -7,12 +7,18 @@ import subprocess ...@@ -7,12 +7,18 @@ import subprocess
from multiprocessing import Queue from multiprocessing import Queue
import pyarrow import pyarrow
import pyarrow.plasma as plasma
from ...logger import get_logger from ...logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
try:
import pyarrow.plasma as plasma
except ModuleNotFoundError:
raise RuntimeError(
"pyarrow remove plasma in version 12.0.0, please use pyarrow vserion < 12.0.0"
)
# Each process only need to start one plasma store, so we set it as a global variable. # Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes? # TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None MGE_PLASMA_STORE_MANAGER = None
......
...@@ -143,6 +143,10 @@ def test_dataloader_worker_baseerror(transform): ...@@ -143,6 +143,10 @@ def test_dataloader_worker_baseerror(transform):
batch_data = next(data_iter) batch_data = next(data_iter)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(num_workers): def test_stream_dataloader(num_workers):
dataset = MyStream(100) dataset = MyStream(100)
...@@ -186,6 +190,10 @@ def test_dataloader_serial(): ...@@ -186,6 +190,10 @@ def test_dataloader_serial():
assert label.shape == (4,) assert label.shape == (4,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel(): def test_dataloader_parallel():
# set max shared memory to 100M # set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000" os.environ["MGE_PLASMA_MEMORY"] = "100000000"
...@@ -286,6 +294,10 @@ def _multi_instances_parallel_dataloader_worker(): ...@@ -286,6 +294,10 @@ def _multi_instances_parallel_dataloader_worker():
assert val_label.shape == (10,) assert val_label.shape == (10,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel_multi_instances(): def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M # set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000" os.environ["MGE_PLASMA_MEMORY"] = "100000000"
...@@ -337,6 +349,10 @@ class MyPreStream(StreamDataset): ...@@ -337,6 +349,10 @@ class MyPreStream(StreamDataset):
raise StopIteration raise StopIteration
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
@pytest.mark.skipif( @pytest.mark.skipif(
platform.system() == "Windows", platform.system() == "Windows",
reason="dataloader do not support parallel on windows", reason="dataloader do not support parallel on windows",
......
...@@ -78,6 +78,10 @@ class MyStream(StreamDataset): ...@@ -78,6 +78,10 @@ class MyStream(StreamDataset):
raise StopIteration raise StopIteration
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(num_workers): def test_stream_dataloader(num_workers):
dataset = MyStream(100) dataset = MyStream(100)
...@@ -127,6 +131,10 @@ def test_dataloader_serial(): ...@@ -127,6 +131,10 @@ def test_dataloader_serial():
assert label._tuple_shape == (4,) assert label._tuple_shape == (4,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel(): def test_dataloader_parallel():
# set max shared memory to 100M # set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000" os.environ["MGE_PLASMA_MEMORY"] = "100000000"
...@@ -232,6 +240,10 @@ def _multi_instances_parallel_dataloader_worker(): ...@@ -232,6 +240,10 @@ def _multi_instances_parallel_dataloader_worker():
assert val_label._tuple_shape == (10,) assert val_label._tuple_shape == (10,)
@pytest.mark.skipif(
np.__version__ >= "1.20.0",
reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel_multi_instances(): def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M # set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000" os.environ["MGE_PLASMA_MEMORY"] = "100000000"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册