Commit 2c2caf33 authored by Megvii Engine Team

refactor(mge/data/dataloader): refactor the implementation of parallel dataloader

GitOrigin-RevId: 0554ee8427c7d892557422c1ee57597b7c88756b
Parent 364dafcc
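The diff below turns the per-process loops of the parallel DataLoader iterator (`_task_feeding_loop`, `_worker_loop`, `_data_gathering_loop`, `_data_selecting_loop`) from bound methods into module-level functions that receive their queues, flags and counters as explicit arguments, and renames `batch_part_queues` to `trans_data_queues`: workers now ship transformed items and collation happens once in the collecting process. A minimal usage sketch follows; the constructor arguments and import paths are assumed from the attributes this iterator reads off `loader` (sampler, transform, collator, num_workers, divide, timeout), since the full DataLoader signature is not shown in this diff.

import numpy as np
from megengine.data import DataLoader
from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import SequentialSampler

# Toy data; ArrayDataset/SequentialSampler paths are assumptions, not part of this diff.
data = np.random.rand(64, 3, 8, 8).astype("float32")
label = np.random.randint(0, 10, size=(64,)).astype("int32")
dataset = ArrayDataset(data, label)

sampler = SequentialSampler(dataset, batch_size=16, drop_last=False)
# num_workers > 0 selects the parallel iterator refactored in this commit.
loader = DataLoader(dataset, sampler=sampler, num_workers=4)

for batch_data, batch_label in loader:
    # with the default transform/collator each element comes out as a batched
    # numpy array, e.g. batch_data.shape == (16, 3, 8, 8)
    break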
@@ -15,9 +15,8 @@ import time
import numpy as np
import megengine as mge
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset
from .sampler import Sampler, SequentialSampler
@@ -87,8 +86,6 @@ class DataLoader:
self.divide = divide
self.rng = np.random.RandomState()
if sampler is None:
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
else:
@@ -130,7 +127,7 @@ class _BaseDataLoaderIter:
def __init__(self, loader):
self.dataset = loader.dataset
self.sampler = loader.sampler
self.seed = loader.rng.randint(1e9)
self.seed = _random_seed_generator().__next__()
self.transform = loader.transform
self.collator = loader.collator
self.num_workers = loader.num_workers
@@ -173,10 +170,6 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_ParallelDataLoaderIter, self).__init__(loader)
# if any worker died, all workers will be shutdown.
self.strict = True
# TODO: put `strict` into DataLoader args or not?
self.task_queues = [
multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
]
@@ -185,7 +178,7 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self.target_batch_idx = multiprocessing.Value("i", 0)
self.shutdown_flag = multiprocessing.Value("i", 0)
self.batch_part_queues = [
self.trans_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
@@ -195,8 +188,15 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self.batch_queue = PlasmaShmQueue(maxsize=2)
self.task_feeding_worker = multiprocessing.Process(
target=self._task_feeding_loop,
args=(iter(self.sampler), self.divide),
target=_task_feeding_loop,
args=(
iter(self.sampler),
self.task_queues,
self.num_workers,
self.divide,
self.shutdown_flag,
self.feed_batch_idx,
),
daemon=True,
)
self.task_feeding_worker.start()
@@ -204,13 +204,14 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self.workers = []
for worker_id in range(self.num_workers):
worker = multiprocessing.Process(
target=self._worker_loop,
target=_worker_loop,
args=(
self.dataset,
self.task_queues[worker_id],
self.batch_part_queues[worker_id],
self.trans_data_queues[worker_id],
self.transform,
self.collator,
self.seed + worker_id + 1,
self.shutdown_flag,
),
daemon=True,
)
@@ -219,191 +220,257 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
if self.divide:
self.data_collecting_worker = multiprocessing.Process(
target=self._data_gathering_loop,
args=(self.batch_part_queues, self.batch_queue,),
target=_data_gathering_loop,
args=(
self.trans_data_queues,
self.batch_queue,
self.collator,
len(self),
self.num_workers,
self.shutdown_flag,
self.target_batch_idx,
),
daemon=True,
)
else:
self.data_collecting_worker = multiprocessing.Process(
target=self._data_selecting_loop,
args=(self.batch_part_queues, self.batch_queue,),
target=_data_selecting_loop,
args=(
self.trans_data_queues,
self.batch_queue,
self.collator,
len(self),
self.num_workers,
self.shutdown_flag,
self.target_batch_idx,
),
daemon=True,
)
self.data_collecting_worker.start()
self.__initialized = True
def _task_feeding_loop(self, indices_iter, divide):
def _check_workers(self):
# Check the status of each worker.
if not self.data_collecting_worker.is_alive():
exitcode = self.data_collecting_worker.exitcode
if exitcode != 0:
raise RuntimeError("data collecting worker died. {}".format(exitcode))
if not self.task_feeding_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode
if exitcode != 0:
raise RuntimeError("task feeding worker died. {}".format(exitcode))
for worker_id, worker in enumerate(self.workers):
if not worker.is_alive():
exitcode = worker.exitcode
if exitcode != 0:
raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))
logger.debug("all workers are alive.")
def _try_get_next_batch(self):
start_time = time.time()
while True:
self._check_workers()
try:
return self.batch_queue.get(timeout=1)
except queue.Empty:
logger.debug("batch queue empty!")
waited_time = time.time() - start_time
if self.timeout > 0:
if waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
def _get_next_batch(self):
batch_data = self._try_get_next_batch()
return batch_data
def _shutdown(self):
with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1
if self.task_feeding_worker.is_alive():
self.task_feeding_worker.terminate()
self.task_feeding_worker.join()
if self.data_collecting_worker.is_alive():
self.data_collecting_worker.terminate()
self.data_collecting_worker.join()
for worker in self.workers:
if worker.is_alive():
worker.terminate()
worker.join()
for q in self.trans_data_queues:
q.cancel_join_thread()
q.close()
for q in self.task_queues:
q.cancel_join_thread()
q.close()
self.batch_queue.cancel_join_thread()
self.batch_queue.close()
def __del__(self):
if self.__initialized:
self._shutdown()
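The `_shutdown` path above relies on a cooperative pattern: a shared `multiprocessing.Value("i", 0)` flag is flipped under its lock, every loop polls it between timed queue operations, and `terminate()`/`join()` only clean up whatever is still alive. A standalone sketch of that pattern (the names and the sleep stand-in are illustrative, not taken from this file):

import multiprocessing
import time

def loop(flag):
    # each child loop polls the shared flag between blocking operations
    while True:
        if flag.value == 1:
            break
        time.sleep(0.01)  # stand-in for a queue get/put with a short timeout

if __name__ == "__main__":
    shutdown_flag = multiprocessing.Value("i", 0)
    p = multiprocessing.Process(target=loop, args=(shutdown_flag,), daemon=True)
    p.start()
    with shutdown_flag.get_lock():  # flip the flag, as _shutdown does
        shutdown_flag.value = 1
    p.join(timeout=1)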
def _task_feeding_loop(
indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
# Feed the indices into the task queues
while True:
if self.shutdown_flag.value == 1:
if shutdown_flag.value == 1:
break
batch_idx = self.feed_batch_idx.value
batch_idx = feed_batch_idx.value
try:
indices = next(indices_iter)
except StopIteration:
break
if divide:
# make sure all task_queues are ready for put
while any([q.full() for q in self.task_queues]):
if self.shutdown_flag.value == 1:
while any([q.full() for q in task_queues]):
if shutdown_flag.value == 1:
return
# divide into small pieces, feed to different workers.
sub_num = math.ceil(len(indices) / self.num_workers)
for worker_id in range(self.num_workers):
sub_indices = indices[
worker_id * sub_num : (worker_id + 1) * sub_num
]
self.task_queues[worker_id].put((batch_idx, sub_indices))
sub_num = math.ceil(len(indices) / num_workers)
for worker_id in range(num_workers):
sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
task_queues[worker_id].put((batch_idx, sub_indices))
else:
# distribute tasks to different workers uniformly.
target_id = batch_idx % self.num_workers
while self.task_queues[target_id].full():
if self.shutdown_flag.value == 1:
target_id = batch_idx % num_workers
while task_queues[target_id].full():
if shutdown_flag.value == 1:
return
self.task_queues[target_id].put((batch_idx, indices))
with self.feed_batch_idx.get_lock():
self.feed_batch_idx.value += 1
task_queues[target_id].put((batch_idx, indices))
with feed_batch_idx.get_lock():
feed_batch_idx.value += 1
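A standalone illustration of the two dispatch modes in the feeding loop above: with divide=True one batch's indices are split evenly across the workers, while with divide=False whole batches go round-robin, batch b being handled entirely by worker b % num_workers. The values below are made up:

import math

indices = list(range(10))        # one batch of 10 sample indices
num_workers = 4

# divide=True: split this batch across all workers
sub_num = math.ceil(len(indices) / num_workers)   # 3
parts = [indices[w * sub_num:(w + 1) * sub_num] for w in range(num_workers)]
# parts == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]; a trailing part may even be
# empty, which the worker loop reports as an empty batch part

# divide=False: batch b is assigned to worker b % num_workers
assignment = {b: b % num_workers for b in range(6)}
# assignment == {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1}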
def _worker_loop(self, task_queue, data_queue, transform, collator, seed):
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
# Get dataset items and do the transform
random.seed(seed)
np.random.seed(seed)
while True:
if self.shutdown_flag.value == 1:
if shutdown_flag.value == 1:
break
try:
batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
except queue.Empty:
continue
if len(indices) > 0:
items = [self.dataset[idx] for idx in indices]
items = [dataset[idx] for idx in indices]
trans_items = transform.apply_batch(items)
batch_data = collator.apply(trans_items)
else:
# in case of incomplete last batch
batch_data = ()
trans_items = ()
while True:
try:
data_queue.put((np.array([batch_idx]), batch_data), timeout=1)
trans_data_queue.put((batch_idx, trans_items), timeout=1)
break
except queue.Full:
if self.shutdown_flag.value == 1:
if shutdown_flag.value == 1:
break
logger.debug("batch part queue is full!")
continue
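The worker loop above only transforms its items; collation now happens once in the collecting process. A toy sketch of the apply_batch / apply contract it assumes (these classes are illustrative stand-ins, not MegEngine's Transform/Collator implementations):

import numpy as np

class AddOneTransform:
    def apply_batch(self, inputs):
        # transform every sample; the samples stay individual items here
        return [(img + 1, label) for img, label in inputs]

class StackCollator:
    def apply(self, inputs):
        # merge a list of (img, label) samples into batched arrays
        imgs, labels = zip(*inputs)
        return np.stack(imgs), np.stack(labels)

items = [(np.zeros((2, 2)), 0), (np.ones((2, 2)), 1)]
trans_items = AddOneTransform().apply_batch(items)  # what _worker_loop produces
batch = StackCollator().apply(trans_items)          # done later by the collecting process
# batch[0].shape == (2, 2, 2); batch[1].shape == (2,)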
def _data_gathering_loop(self, batch_part_queues, batch_queue):
r"""Gathering the small pieces of batch data into full batch data."""
gathered_data = collections.defaultdict(dict)
def _data_gathering_loop(
trans_data_queues,
batch_queue,
collator,
length,
num_workers,
shutdown_flag,
target_idx,
):
# Gathering the small pieces of batch data into full batch data
while True:
if self.shutdown_flag.value == 1:
if shutdown_flag.value == 1:
break
target_batch_idx = self.target_batch_idx.value
target_batch_idx = target_idx.value
if target_batch_idx >= len(self):
if target_batch_idx >= length:
break
for worker_id in range(self.num_workers):
if worker_id in gathered_data[target_batch_idx]:
continue
full_trans_items = []
for worker_id in range(num_workers):
while True:
try:
(batch_idx,), batch_part = batch_part_queues[worker_id].get(
batch_idx, trans_items = trans_data_queues[worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
)
break
except queue.Empty:
if self.shutdown_flag.value == 1:
if shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
worker_id, target_batch_idx
)
)
if batch_idx < target_batch_idx:
if batch_idx != target_batch_idx:
raise RuntimeError(
"Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
worker_id
)
)
else:
gathered_data[batch_idx][worker_id] = batch_part
if len(gathered_data[target_batch_idx]) < self.num_workers:
length = len(gathered_data[target_batch_idx])
if self.strict:
raise RuntimeError("Parts missing in data gathering loop.")
logger.warning(
"target_batch_idx:{}, {} part(s) missing.".format(
target_batch_idx, self.num_workers - length
)
)
del gathered_data[target_batch_idx]
with self.target_batch_idx.get_lock():
self.target_batch_idx.value += 1
continue
full_trans_items.extend(trans_items)
# Merge different parts.
full_batch = [[] for _ in range(len(gathered_data[target_batch_idx][0]))]
for idx in range(self.num_workers):
for i, field in enumerate(gathered_data[target_batch_idx][idx]):
full_batch[i].append(field)
full_batch = tuple([np.concatenate(field, axis=0) for field in full_batch])
# Merge different parts into a batch.
full_batch = collator.apply(full_trans_items)
while True:
try:
batch_queue.put(full_batch, timeout=1)
break
except queue.Full:
if self.shutdown_flag.value == 1:
if shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")
continue
del gathered_data[target_batch_idx]
with self.target_batch_idx.get_lock():
self.target_batch_idx.value += 1
with target_idx.get_lock():
target_idx.value += 1
batch_queue.disconnect_client()
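Compared with the old gathering loop, which concatenated already-collated numpy parts field by field, the new loop simply extends the transformed items from each worker in worker order and runs the collator once on the full list. A toy illustration of that merge step (values are made up):

part_from_worker = {
    0: [("img0", 0), ("img1", 1)],
    1: [("img2", 2)],
    2: [],                       # an incomplete last batch leaves empty parts
}
full_trans_items = []
for worker_id in range(3):
    full_trans_items.extend(part_from_worker[worker_id])
# full_trans_items keeps the sampler's order; collator.apply(full_trans_items)
# then produces the batch that is pushed into the shared-memory batch_queue.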
def _data_selecting_loop(self, batch_part_queues, batch_queue):
r"""Make sure that batch is generated exactly with the same order as generated indices."""
buffer_batches = {}
def _data_selecting_loop(
trans_data_queues,
batch_queue,
collator,
length,
num_workers,
shutdown_flag,
target_idx,
):
# Make sure batches are generated in exactly the same order as the sampler's indices
while True:
if self.shutdown_flag.value == 1:
if shutdown_flag.value == 1:
break
target_batch_idx = self.target_batch_idx.value
target_batch_idx = target_idx.value
if target_batch_idx >= len(self):
if target_batch_idx >= length:
break
if target_batch_idx in buffer_batches:
target_worker_id = target_batch_idx % num_workers
while True:
try:
batch_queue.put(
buffer_batches[target_batch_idx], timeout=1,
)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")
with self.target_batch_idx.get_lock():
self.target_batch_idx.value += 1
del buffer_batches[target_batch_idx]
continue
target_worker_id = target_batch_idx % self.num_workers
while True:
try:
(batch_idx,), batch_data = batch_part_queues[target_worker_id].get(
batch_idx, trans_items = trans_data_queues[target_worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
)
batch_data = collator.apply(trans_items)
break
except queue.Empty:
if self.shutdown_flag.value == 1:
if shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
@@ -411,136 +478,23 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
)
)
if batch_idx < target_batch_idx:
raise RuntimeError("batch_idx smaller than target_batch_idx")
elif batch_idx > target_batch_idx:
if self.strict:
raise RuntimeError("batch_idx larger than target_batch_idx")
logger.warning(
"missing target batch idx:{}, batch idx:{}".format(
target_batch_idx, batch_idx
)
)
buffer_batches[batch_idx] = batch_data
else:
try:
batch_queue.put(batch_data, timeout=1)
except queue.Full:
buffer_batches[batch_idx] = batch_data
continue
with self.target_batch_idx.get_lock():
self.target_batch_idx.value += 1
batch_queue.disconnect_client()
def _check_workers(self):
"""Check the status of each worker and restart if necessary."""
if not self.data_collecting_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode
if exitcode != 0:
raise RuntimeError("data collecting worker died. {}".format(exitcode))
if self.strict:
if not self.task_feeding_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode
if exitcode != 0:
raise RuntimeError("task feeding worker died. {}".format(exitcode))
for worker_id, worker in enumerate(self.workers):
if not worker.is_alive():
exitcode = worker.exitcode
if exitcode != 0:
if batch_idx != target_batch_idx:
raise RuntimeError(
"worker:{} died. {}".format(worker_id, exitcode)
"batch_idx {} mismatch the target_batch_idx {}".format(
batch_idx, target_batch_idx
)
else:
if not self.task_feeding_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode
if exitcode != 0:
logger.error(
"task feeding worker died {}. Restarting".format(exitcode)
)
self.task_feeding_worker.join()
self.task_feeding_worker = multiprocessing.Process(
target=self._task_feeding_loop,
args=(iter(self.sampler), self.divide),
daemon=True,
)
self.task_feeding_worker.start()
failed_num = 0
for worker_id in range(self.num_workers):
if self.workers[worker_id].is_alive():
continue
exitcode = worker.exitcode
if exitcode == 0:
continue
logger.error("worker {} died. Restarting".format(worker_id))
failed_num += 1
self.workers[worker_id].join()
worker = multiprocessing.Process(
target=self._worker_loop,
args=(
self.task_queues[worker_id],
self.batch_part_queues[worker_id],
self.transform,
self.collator,
self.seed + worker_id + 1,
),
daemon=True,
)
worker.start()
self.workers[worker_id] = worker
if failed_num > 0:
logger.error("{} worker had exited".format(failed_num))
else:
logger.debug("all workers are alive.")
def _try_get_next_batch(self):
start_time = time.time()
while True:
self._check_workers()
try:
return self.batch_queue.get(timeout=1)
except queue.Empty:
logger.debug("batch queue empty!")
waited_time = time.time() - start_time
if self.timeout > 0:
if waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
def _get_next_batch(self):
batch_data = self._try_get_next_batch()
return batch_data
def _shutdown(self):
with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1
if self.task_feeding_worker.is_alive():
self.task_feeding_worker.terminate()
self.task_feeding_worker.join()
if self.data_collecting_worker.is_alive():
self.data_collecting_worker.terminate()
self.data_collecting_worker.join()
for worker in self.workers:
if worker.is_alive():
worker.terminate()
worker.join()
for q in self.batch_part_queues:
q.cancel_join_thread()
q.close()
for q in self.task_queues:
q.cancel_join_thread()
q.close()
batch_queue.put(batch_data, timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")
self.batch_queue.cancel_join_thread()
self.batch_queue.close()
with target_idx.get_lock():
target_idx.value += 1
def __del__(self):
if self.__initialized:
self._shutdown()
batch_queue.disconnect_client()