提交 7191c4bd 编写于 作者: M Megvii Engine Team

fix(mge/data): support timeout for serial stream dataloader

GitOrigin-RevId: 1ae5a8cfdace417c470dd7191ddd3d17ba03dc70
上级 fe6af7cb
......@@ -12,6 +12,7 @@ import multiprocessing
import platform
import queue
import random
import threading
import time
import numpy as np
......@@ -23,10 +24,16 @@ from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform
try:
import thread
except:
import _thread as thread
logger = get_logger(__name__)
MP_QUEUE_GET_TIMEOUT = 5
GLOBAL_TIMEOUT = 5
class DataLoader:
......@@ -39,7 +46,7 @@ class DataLoader:
transform: Transform = None,
collator: Collator = None,
num_workers: int = 0,
timeout: int = 0,
timeout: int = GLOBAL_TIMEOUT,
divide: bool = False,
):
r"""
......@@ -377,21 +384,23 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
def _get_next_batch(self):
ret = []
start_time = time.time()
while len(ret) != self.sampler.batch_size:
waited_time = time.time() - start_time
if self.timeout > 0 and waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
if self.idx != 0:
data = self.data
else:
try:
timer = threading.Timer(self.timeout, thread.interrupt_main)
timer.start()
raw_data = next(self.dataset_iter)
timer.cancel()
except KeyboardInterrupt:
raise RuntimeError("get_next_batch timeout!")
except:
timer.cancel()
continue
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple"
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
......@@ -456,7 +465,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
raw_data = next(dataset_iter)
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple"
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
......@@ -478,7 +487,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
if self.shutdown_flag.value == 1:
break
try:
data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
except queue.Empty:
continue
trans_data = self.transform.apply(data)
......@@ -501,7 +510,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
queue_id = cnt % self.num_workers
try:
trans_item = self.trans_data_queues[queue_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
timeout=GLOBAL_TIMEOUT
)
except queue.Empty:
continue
......@@ -622,7 +631,7 @@ def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdow
if shutdown_flag.value == 1:
break
try:
batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
except queue.Empty:
continue
if len(indices) > 0:
......@@ -665,7 +674,7 @@ def _data_gathering_loop(
while True:
try:
batch_idx, trans_items = trans_data_queues[worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
timeout=GLOBAL_TIMEOUT
)
break
except queue.Empty:
......@@ -726,7 +735,7 @@ def _data_selecting_loop(
while True:
try:
batch_idx, trans_items = trans_data_queues[target_worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
timeout=GLOBAL_TIMEOUT
)
batch_data = collator.apply(trans_items)
break
......
......@@ -61,13 +61,17 @@ def test_dataloader_init():
class MyStream(StreamDataset):
def __init__(self, number, batch=False, error=False):
def __init__(self, number, batch=False, error=False, block=False):
self.number = number
self.batch = batch
self.error = error
self.block = block
def __iter__(self):
for cnt in range(self.number):
if self.block:
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
......@@ -115,20 +119,10 @@ def test_stream_dataloader_error():
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False)
dataset = MyStream(100, False, block=True)
sampler = StreamSampler(batch_size=4)
class TimeoutTransform(Transform):
def __init__(self):
pass
def apply(self, input):
time.sleep(10)
return input
dataloader = DataLoader(
dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5
)
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=5)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
next(data_iter)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册