From 7191c4bd9f87fb8f340ebbc41ed2840510fc80c8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 12 Jan 2021 11:06:08 +0800 Subject: [PATCH] fix(mge/data): support timeout for serial stream dataloader GitOrigin-RevId: 1ae5a8cfdace417c470dd7191ddd3d17ba03dc70 --- .../python/megengine/data/dataloader.py | 35 ++++++++++++------- .../python/test/unit/data/test_dataloader.py | 20 ++++------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index c1fb54320..ae7113262 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -12,6 +12,7 @@ import multiprocessing import platform import queue import random +import threading import time import numpy as np @@ -23,10 +24,16 @@ from .dataset import Dataset, StreamDataset from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler from .transform import PseudoTransform, Transform +try: + import thread +except: + import _thread as thread + + logger = get_logger(__name__) -MP_QUEUE_GET_TIMEOUT = 5 +GLOBAL_TIMEOUT = 5 class DataLoader: @@ -39,7 +46,7 @@ class DataLoader: transform: Transform = None, collator: Collator = None, num_workers: int = 0, - timeout: int = 0, + timeout: int = GLOBAL_TIMEOUT, divide: bool = False, ): r""" @@ -377,21 +384,23 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): def _get_next_batch(self): ret = [] - start_time = time.time() while len(ret) != self.sampler.batch_size: - waited_time = time.time() - start_time - if self.timeout > 0 and waited_time > self.timeout: - raise RuntimeError("get_next_batch timeout!") if self.idx != 0: data = self.data else: try: + timer = threading.Timer(self.timeout, thread.interrupt_main) + timer.start() raw_data = next(self.dataset_iter) + timer.cancel() + except KeyboardInterrupt: + raise RuntimeError("get_next_batch timeout!") except: + timer.cancel() continue assert len(raw_data) == 2 and isinstance( raw_data[0], bool - ), "raw_data must be a tuple" + ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched." if not raw_data[0]: data = list((x,) for x in raw_data[1]) else: @@ -456,7 +465,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): raw_data = next(dataset_iter) assert len(raw_data) == 2 and isinstance( raw_data[0], bool - ), "raw_data must be a tuple" + ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched." if not raw_data[0]: data = list((x,) for x in raw_data[1]) else: @@ -478,7 +487,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): if self.shutdown_flag.value == 1: break try: - data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT) + data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT) except queue.Empty: continue trans_data = self.transform.apply(data) @@ -501,7 +510,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): queue_id = cnt % self.num_workers try: trans_item = self.trans_data_queues[queue_id].get( - timeout=MP_QUEUE_GET_TIMEOUT + timeout=GLOBAL_TIMEOUT ) except queue.Empty: continue @@ -622,7 +631,7 @@ def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdow if shutdown_flag.value == 1: break try: - batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT) + batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT) except queue.Empty: continue if len(indices) > 0: @@ -665,7 +674,7 @@ def _data_gathering_loop( while True: try: batch_idx, trans_items = trans_data_queues[worker_id].get( - timeout=MP_QUEUE_GET_TIMEOUT + timeout=GLOBAL_TIMEOUT ) break except queue.Empty: @@ -726,7 +735,7 @@ def _data_selecting_loop( while True: try: batch_idx, trans_items = trans_data_queues[target_worker_id].get( - timeout=MP_QUEUE_GET_TIMEOUT + timeout=GLOBAL_TIMEOUT ) batch_data = collator.apply(trans_items) break diff --git a/imperative/python/test/unit/data/test_dataloader.py b/imperative/python/test/unit/data/test_dataloader.py index a8f152d37..8dd25b68f 100644 --- a/imperative/python/test/unit/data/test_dataloader.py +++ b/imperative/python/test/unit/data/test_dataloader.py @@ -61,13 +61,17 @@ def test_dataloader_init(): class MyStream(StreamDataset): - def __init__(self, number, batch=False, error=False): + def __init__(self, number, batch=False, error=False, block=False): self.number = number self.batch = batch self.error = error + self.block = block def __iter__(self): for cnt in range(self.number): + if self.block: + for _ in range(10): + time.sleep(1) if self.batch: data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8") yield (True, (data, [cnt, cnt - self.number])) @@ -115,20 +119,10 @@ def test_stream_dataloader_error(): @pytest.mark.parametrize("num_workers", [0, 2]) def test_stream_dataloader_timeout(num_workers): - dataset = MyStream(100, False) + dataset = MyStream(100, False, block=True) sampler = StreamSampler(batch_size=4) - class TimeoutTransform(Transform): - def __init__(self): - pass - - def apply(self, input): - time.sleep(10) - return input - - dataloader = DataLoader( - dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5 - ) + dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=5) with pytest.raises(RuntimeError, match=r".*timeout.*"): data_iter = iter(dataloader) next(data_iter) -- GitLab