提交 c418d3cd 编写于 作者: M Megvii Engine Team

fix(mge/data): add timeout event

GitOrigin-RevId: 43f2ba1456ce027e59ea2e09b9f9b795bb2e802f
上级 0f739c11
......@@ -14,6 +14,7 @@ import queue
import random
import threading
import time
from typing import Callable
import numpy as np
......@@ -36,6 +37,10 @@ logger = get_logger(__name__)
GLOBAL_TIMEOUT = 5
def raise_timeout_error():
raise RuntimeError("dataloader timeout")
class DataLoader:
__initialized = False
......@@ -46,7 +51,8 @@ class DataLoader:
transform: Transform = None,
collator: Collator = None,
num_workers: int = 0,
timeout: int = GLOBAL_TIMEOUT,
timeout: int = 0,
timeout_event: Callable = raise_timeout_error,
divide: bool = False,
):
r"""
......@@ -71,6 +77,9 @@ class DataLoader:
:type timeout: int
:param timeout: if positive, means the timeout value(second) for collecting a
batch from workers. Default: 0
:type timeout_event: Callable
:param timeout_event: callback function triggered by timeout, default to raise
runtime error.
:type divide: bool
:param divide: define the paralleling strategy in multi-processing mode.
``True`` means one batch is divided into :attr:`num_workers` pieces, and
......@@ -92,6 +101,7 @@ class DataLoader:
self.num_workers = num_workers
self.timeout = timeout
self.timeout_event = timeout_event
self.divide = divide
......@@ -168,6 +178,7 @@ class _BaseMapDataLoaderIter:
self.collator = loader.collator
self.num_workers = loader.num_workers
self.timeout = loader.timeout
self.timeout_event = loader.timeout_event
self.divide = loader.divide
self.num_processed = 0
......@@ -306,7 +317,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
logger.debug("all workers are alive.")
def _try_get_next_batch(self):
def _get_next_batch(self):
start_time = time.time()
while True:
self._check_workers()
......@@ -319,10 +330,6 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
if waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
def _get_next_batch(self):
batch_data = self._try_get_next_batch()
return batch_data
def _shutdown(self):
with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1
......@@ -364,10 +371,24 @@ class _BaseStreamDataLoaderIter:
self.collator = loader.collator
self.num_workers = loader.num_workers
self.timeout = loader.timeout
self.timeout_event = loader.timeout_event
def _get_next_batch(self):
raise NotImplementedError
def _process_raw_data(self, raw_data):
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
ret = []
for idx in range(len(data[0])):
ret.append(tuple(e[idx] for e in data))
return ret
def __iter__(self):
return self
......@@ -380,42 +401,43 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
super().__init__(loader)
self.dataset_iter = iter(self.dataset)
self.idx = 0
self.data = None
self.unused = []
def _get_next_batch(self):
ret = []
while len(ret) != self.sampler.batch_size:
if self.idx != 0:
data = self.data
else:
try:
def _try_get_raw_data(self, start_time):
raw_data = None
while not raw_data:
try:
if self.timeout > 0:
timer = threading.Timer(self.timeout, thread.interrupt_main)
timer.start()
raw_data = next(self.dataset_iter)
raw_data = next(self.dataset_iter)
if self.timeout > 0:
timer.cancel()
except KeyboardInterrupt:
raise RuntimeError("get_next_batch timeout!")
except:
except KeyboardInterrupt:
raw_data = self.timeout_event()
except:
if self.timeout > 0:
timer.cancel()
continue
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(self.idx, len(data[0])):
trans_data = self.transform.apply(tuple(e[idx] for e in data))
ret.append(trans_data)
if len(ret) == self.sampler.batch_size:
if idx + 1 == len(data[0]):
self.idx = 0
self.data = None
else:
self.idx = idx
self.data = data
break
waited_time = time.time() - start_time
if waited_time > self.timeout:
raw_data = self.timeout_event()
return raw_data
def _get_next_batch(self):
ret = []
start_time = time.time()
while len(ret) < self.sampler.batch_size:
if len(self.unused) != 0:
batch_data = self.unused
else:
raw_data = self._try_get_raw_data(start_time)
batch_data = self._process_raw_data(raw_data)
while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
data = batch_data.pop()
ret.append(self.transform.apply(data))
self.unused = batch_data
return self.collator.apply(ret)
......@@ -440,49 +462,52 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
self.batch_queue = PlasmaShmQueue(maxsize=2)
self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True)
self.recieve_worker = multiprocessing.Process(
target=self._worker_to_raw_data_queues, daemon=True
)
self.recieve_worker.start()
self.transform_workers = []
for worker_id in range(self.num_workers):
worker = multiprocessing.Process(
target=self._transform, args=(worker_id,), daemon=True
target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
)
worker.start()
self.transform_workers.append(worker)
self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
self.collect_worker = multiprocessing.Process(
target=self._worker_to_batch_queue, daemon=True
)
self.collect_worker.start()
self.__initialized = True
def _recieve(self):
def _put_raw_data_queues(self, raw_data, qidx):
batch_data = self._process_raw_data(raw_data)
for data in batch_data:
while True:
qidx = qidx % self.num_workers
try:
self.raw_data_queues[qidx].put(data)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("raw data queue %d is full" % qidx)
finally:
qidx += 1
return qidx
def _worker_to_raw_data_queues(self):
dataset_iter = iter(self.dataset)
cnt = -1
qidx = 0
while True:
if self.shutdown_flag.value == 1:
break
raw_data = next(dataset_iter)
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(len(data[0])):
while True:
cnt += 1
qid = cnt % self.num_workers
try:
self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("raw data queue is full")
qidx = self._put_raw_data_queues(raw_data, qidx)
def _transform(self, worker_id):
def _worker_to_trans_data_queues(self, worker_id):
while True:
if self.shutdown_flag.value == 1:
break
......@@ -500,7 +525,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
break
logger.debug("batch queue if full")
def _collect(self):
def _worker_to_batch_queue(self):
cnt = -1
trans_items = []
while True:
......@@ -541,7 +566,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
"worker: {} died. {}".format(worker_id, exitcode)
)
def _try_get_next_batch(self):
def _get_next_batch(self):
start_time = time.time()
while True:
self._check_workers()
......@@ -551,11 +576,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
logger.debug("batch queue empty!")
waited_time = time.time() - start_time
if self.timeout > 0 and waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
def _get_next_batch(self):
batch_data = self._try_get_next_batch()
return batch_data
self._put_raw_data_queues(self.timeout_event(), 0)
def _shutdown(self):
with self.shutdown_flag.get_lock():
......
......@@ -43,7 +43,7 @@ class StreamDataset(Dataset):
def __iter__(self):
pass
def __getitem__(self):
def __getitem__(self, idx):
raise AssertionError("can not get item from StreamDataset by index")
def __len__(self):
......
......@@ -61,10 +61,10 @@ def test_dataloader_init():
class MyStream(StreamDataset):
def __init__(self, number, batch=False, error=False, block=False):
def __init__(self, number, batch=False, error_foramt=False, block=False):
self.number = number
self.batch = batch
self.error = error
self.error_format = error_foramt
self.block = block
def __iter__(self):
......@@ -73,11 +73,11 @@ class MyStream(StreamDataset):
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
if self.error:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
......@@ -87,7 +87,7 @@ class MyStream(StreamDataset):
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch)
dataset = MyStream(100, batch=batch)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
......@@ -101,7 +101,7 @@ def test_stream_dataloader(batch, num_workers):
for step, data in enumerate(dataloader):
if step == 10:
break
assert data[0].shape == (4, 3, 32, 32)
assert data[0].shape == (4, 3, 2, 2)
assert data[1].shape == (4,)
for i in data[1]:
assert i not in check_set
......@@ -109,7 +109,7 @@ def test_stream_dataloader(batch, num_workers):
def test_stream_dataloader_error():
dataset = MyStream(100, error=True)
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler)
with pytest.raises(AssertionError, match=r".*tuple.*"):
......@@ -122,7 +122,7 @@ def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=5)
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
next(data_iter)
......@@ -264,3 +264,20 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
for p in processes:
p.join()
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册