From 786399c5c96dd12c12e2f32b50348b30411909e8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 14 Oct 2020 12:11:47 +0800 Subject: [PATCH] fix(pytest/windows/impertive): fix impertive pytest failed on windows GitOrigin-RevId: 02f4c0a0be0be4a4e88ddc26363b19f3469bd4d5 --- imperative/python/megengine/data/dataloader.py | 6 ++++++ imperative/python/test/unit/data/test_dataloader.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index 2a818a297..4372b2b75 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -9,6 +9,7 @@ import collections import math import multiprocessing +import platform import queue import random import time @@ -113,6 +114,11 @@ class DataLoader: self.__initialized = True def __iter__(self): + if platform.system() == "Windows": + print( + "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero" + ) + self.num_workers = 0 if self.num_workers == 0: return _SerialDataLoaderIter(self) else: diff --git a/imperative/python/test/unit/data/test_dataloader.py b/imperative/python/test/unit/data/test_dataloader.py index 6bb0f3e32..446dce07c 100644 --- a/imperative/python/test/unit/data/test_dataloader.py +++ b/imperative/python/test/unit/data/test_dataloader.py @@ -7,6 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import os +import platform import time import numpy as np @@ -89,6 +90,10 @@ def test_dataloader_parallel(): assert label.shape == (4,) +@pytest.mark.skipif( + platform.system() == "Windows", + reason="dataloader do not support parallel on windows", +) def test_dataloader_parallel_timeout(): dataset = init_dataset() @@ -112,6 +117,10 @@ def test_dataloader_parallel_timeout(): batch_data = next(data_iter) +@pytest.mark.skipif( + platform.system() == "Windows", + reason="dataloader do not support parallel on windows", +) def test_dataloader_parallel_worker_exception(): dataset = init_dataset() -- GitLab