diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index 2a818a29781e1f5c579d09e563eccbf22d72fac9..4372b2b753afaf0a7612d8ec794c6755314137bb 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -9,6 +9,7 @@ import collections import math import multiprocessing +import platform import queue import random import time @@ -113,6 +114,11 @@ class DataLoader: self.__initialized = True def __iter__(self): + if platform.system() == "Windows": + print( + "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero" + ) + self.num_workers = 0 if self.num_workers == 0: return _SerialDataLoaderIter(self) else: diff --git a/imperative/python/test/unit/data/test_dataloader.py b/imperative/python/test/unit/data/test_dataloader.py index 6bb0f3e32f592ea50b4514b6a3c616a4b0f8c117..446dce07c4917ebc05955992e57fa6e71221cb6f 100644 --- a/imperative/python/test/unit/data/test_dataloader.py +++ b/imperative/python/test/unit/data/test_dataloader.py @@ -7,6 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import os +import platform import time import numpy as np @@ -89,6 +90,10 @@ def test_dataloader_parallel(): assert label.shape == (4,) +@pytest.mark.skipif( + platform.system() == "Windows", + reason="dataloader do not support parallel on windows", +) def test_dataloader_parallel_timeout(): dataset = init_dataset() @@ -112,6 +117,10 @@ def test_dataloader_parallel_timeout(): batch_data = next(data_iter) +@pytest.mark.skipif( + platform.system() == "Windows", + reason="dataloader do not support parallel on windows", +) def test_dataloader_parallel_worker_exception(): dataset = init_dataset()