提交 edc92ccf 编写于 作者: M Megvii Engine Team

perf(imperative/data): improve dataloader preformance

GitOrigin-RevId: 7d8d52aaeb47e7ec6c3efa282ff9014a4b7d1f01
上级 896b0193
# -*- coding: utf-8 -*-
import collections
import gc
import math
import itertools
import multiprocessing
import os
import platform
......@@ -9,7 +9,6 @@ import queue
import random
import threading
import time
from typing import Callable, Union
import numpy as np
......@@ -67,12 +66,6 @@ class DataLoader:
the batch. ``0`` means using single-process. Default: 0
timeout: if positive, means the timeout value(second) for collecting a
batch from workers. Default: 0
timeout_event: callback function triggered by timeout, default to raise
runtime error.
divide: define the paralleling strategy in multi-processing mode.
``True`` means one batch is divided into :attr:`num_workers` pieces, and
the workers will process these pieces parallelly. ``False`` means
different sub-process will process different batch. Default: False
preload: whether to enable the preloading strategy of the dataloader.
When enabling, the dataloader will preload one batch to the device memory to speed up the whole training process.
......@@ -85,7 +78,6 @@ class DataLoader:
which will improve the training speed at the cost of **higher device memory usage** (due to one more batch data on device memory).
This feature saves more time when your NN training time is short or your machine's host PCIe bandwidth for each device is low.
"""
__initialized = False
def __init__(
self,
......@@ -95,9 +87,8 @@ class DataLoader:
collator: Collator = None,
num_workers: int = 0,
timeout: int = 0,
timeout_event: Callable = _raise_timeout_error,
divide: bool = False,
preload: bool = False,
parallel_stream: bool = False,
):
if num_workers < 0:
raise ValueError("num_workers should not be negative")
......@@ -105,23 +96,22 @@ class DataLoader:
if timeout < 0:
raise ValueError("timeout should not be negative")
if divide and num_workers <= 1:
raise ValueError("divide should not be set to True when num_workers <= 1")
self.dataset = dataset
self.num_workers = num_workers
self.timeout = timeout
self.timeout_event = timeout_event
self.divide = divide
self.preload = preload
self.parallel_stream = parallel_stream
if isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
if parallel_stream is False and self.num_workers > 1:
logger.warning(
"Data time will be affected by getting origin-data, please set parallel_stream in order to speed up dataloader!"
)
self.datakind = "stream"
else:
assert isinstance(
dataset, Dataset
......@@ -134,16 +124,7 @@ class DataLoader:
assert isinstance(
self.sampler, MapSampler
), "types of dataset and sampler do not match"
if divide:
if self.sampler.batch_size <= self.num_workers:
raise ValueError(
"batch size must not smaller than num_workers in divide mode."
)
elif self.sampler.batch_size % self.num_workers:
logger.warning(
"batch size is not divisible by num_workers, may lose performance in divide mode."
)
self.datakind = "map"
if transform is None:
self.transform = PseudoTransform()
......@@ -155,7 +136,8 @@ class DataLoader:
else:
self.collator = collator
self.__initialized = True
if platform.system() == "Linux" and self.num_workers > 0:
self.check_memory_rationality()
def __iter__(self):
if platform.system() == "Windows" and self.num_workers > 0:
......@@ -187,15 +169,50 @@ class DataLoader:
def __len__(self):
return len(self.sampler)
def check_memory_rationality(self):
import psutil
main_memory = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
total_memory = (self.num_workers + 1) * main_memory
current_memory = (
int(os.popen("cat /sys/fs/cgroup/memory/memory.limit_in_bytes").read())
/ 1024
/ 1024
/ 1024
)
if current_memory < total_memory:
logger.warning(
"Each worker need to read the shared meta-data, which will be increasing the reference count."
"Copy-On-Write propety will lead to 'memory leak', the memory usage will end up being "
+ total_memory
+ " GB"
"However the current requested memory is " + current_memory + " GB"
"Maybe you can request more memory or uesd np-array to save meta-data rather than List or Tuple"
)
class _PreLoader:
def __init__(self, preload):
def __init__(self, loader, preload):
self.dataset = loader.dataset
self.sampler = loader.sampler
self.seed = _random_seed_generator().__next__()
self.transform = loader.transform
self.collator = loader.collator
self.num_workers = loader.num_workers
self.timeout = loader.timeout
self.num_processed = 0
self.datakind = loader.datakind
self.parallel_stream = loader.parallel_stream
if preload:
self.default_device = get_default_device()
self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
self.pre_load_device_cache = None
self.preload = preload
def __iter__(self):
return self
"""
strategy one: load from numpy data, and generate dtype tensor
"""
......@@ -237,29 +254,176 @@ class _PreLoader:
return out
class _BaseMapDataLoaderIter(_PreLoader):
def __init__(self, loader, preload):
super().__init__(preload)
self.dataset = loader.dataset
self.sampler = loader.sampler
self.seed = _random_seed_generator().__next__()
self.transform = loader.transform
self.collator = loader.collator
self.num_workers = loader.num_workers
self.timeout = loader.timeout
self.timeout_event = loader.timeout_event
self.divide = loader.divide
self.num_processed = 0
class _ParallelDataLoaderIter:
def __init__(self):
self._worker_queue_idx_cycle = itertools.cycle(range(self.num_workers))
from .tools._queue import PlasmaShmQueue
def _get_next_batch(self):
self._worker_result_queue = PlasmaShmQueue()
self._shutdown = False
self._workers_done_event = multiprocessing.Event()
self._index_queues = []
self._workers = []
for i in range(self.num_workers):
index_queue = multiprocessing.Queue()
index_queue.cancel_join_thread()
w = multiprocessing.Process(
target=_worker_loop,
args=(
self.dataset,
index_queue,
self._worker_result_queue,
self._workers_done_event,
self.transform,
self.collator,
self.sampler.batch_size,
self.seed + i,
i,
self.num_workers,
self.datakind,
self.parallel_stream,
),
daemon=True,
)
gc.collect()
w.start()
self._index_queues.append(index_queue)
self._workers.append(w)
self._data_queue = self._worker_result_queue
self._reset()
def _try_put_index(self):
raise NotImplementedError
def _reset(self):
self._sampler_iter = iter(self.sampler)
self._send_idx = 0
self._rcvd_idx = 0
self._task_info = {}
self._workers_status = [True for _ in range(self.num_workers)]
for _ in range(2 * self.num_workers):
self._try_put_index()
def _process_data(self, data):
self._rcvd_idx += 1
self._try_put_index()
return data
def _get_data(self):
if self.timeout > 0:
success, data = self._try_get_data(self.timeout)
if success:
return data
else:
_raise_timeout_error()
else:
while True:
success, data = self._try_get_data()
if success:
return data
def _get_next_batch(self):
while True:
while self._rcvd_idx < self._send_idx:
info = self._task_info[self._rcvd_idx]
worker_id = info[0]
if (
len(info) == 2 or self._workers_status[worker_id]
): # has data or work is still active
break
del self._task_info[self._rcvd_idx]
self._rcvd_idx += 1
else:
self._shutdown_workers()
raise StopIteration
if len(self._task_info[self._rcvd_idx]) == 2:
data = self._task_info.pop(self._rcvd_idx)[1]
return self._process_data(data)
idx, data = self._get_data()
if isinstance(data, int): # Check if StopIteration in StreamDataset
self._mark_worker_as_unavailable(data)
self._try_put_index()
continue
if idx != self._rcvd_idx:
self._task_info[idx] += (data,)
else:
del self._task_info[idx]
return self._process_data(data)
def _try_get_data(self, timeout=GLOBAL_TIMEOUT):
try:
data = self._data_queue.get(timeout=timeout)
return (True, data)
except Exception as e:
failed_workers = []
for worker_id, w in enumerate(self._workers):
if self._workers_status[worker_id] and not w.is_alive():
failed_workers.append((worker_id, w))
self._mark_worker_as_unavailable(worker_id)
if w.exitcode == -9:
logger.debug(
"Maybe memory is not enough, please request for more memory!"
)
if len(failed_workers) > 0:
pids_str = ", ".join(str(w_info[1].pid) for w_info in failed_workers)
w_ids_str = ", ".join(str(w_info[0]) for w_info in failed_workers)
exitcode_str = ", ".join(
str(w_info[1].exitcode) for w_info in failed_workers
)
raise RuntimeError(
"DataLoader worker (worker(s): {} , pid(s): {}) exited unexpectedly, exitcode(s): {}".format(
w_ids_str, pids_str, exitcode_str
)
)
if isinstance(e, queue.Empty):
return (False, None)
def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
q = self._index_queues[worker_id]
q.put(None)
self._workers_status[worker_id] = False
assert self._workers_done_event.is_set() == shutdown
def _shutdown_workers(self):
if not self._shutdown:
self._shutdown = True
try:
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
if self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers:
w.join(timeout=GLOBAL_TIMEOUT)
for q in self._index_queues:
q.cancel_join_thread()
q.close()
self._data_queue.cancel_join_thread()
self._data_queue.close()
finally:
for w in self._workers:
if w.is_alive():
w.terminate()
def __del__(self):
self._shutdown_workers()
class _BaseMapDataLoaderIter(_PreLoader):
def __init__(self, loader, preload):
super().__init__(loader, preload)
def __len__(self):
return len(self.sampler)
def __iter__(self):
return self
def __next__(self):
if self.preload:
cached = self.pre_load_device_cache
......@@ -272,11 +436,8 @@ class _BaseMapDataLoaderIter(_PreLoader):
self._try_load_tensor()
return out
else:
if self.num_processed >= len(self):
raise StopIteration
minibatch = self._get_next_batch()
self.num_processed += 1
return minibatch
data = self._get_next_batch()
return data
def _try_load_tensor(self, cached=True):
if self.num_processed >= len(self):
......@@ -290,199 +451,69 @@ class _BaseMapDataLoaderIter(_PreLoader):
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
def __init__(self, loader, preload):
super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
self.indices_iter = iter(self.sampler)
self._sampler_iter = iter(self.sampler)
def _get_next_batch(self):
indices = next(self.indices_iter)
indices = next(self._sampler_iter)
items = [self.dataset[idx] for idx in indices]
trans_items = self.transform.apply_batch(items)
return self.collator.apply(trans_items)
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
__initialized = False
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter, _ParallelDataLoaderIter):
def __init__(self, loader, preload):
super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)
_BaseMapDataLoaderIter.__init__(self, loader, preload)
_ParallelDataLoaderIter.__init__(self)
self.task_queues = [
multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
]
self.feed_batch_idx = multiprocessing.Value("i", 0)
self.target_batch_idx = multiprocessing.Value("i", 0)
self.shutdown_flag = multiprocessing.Value("i", 0)
def _try_put_index(self):
try:
index = next(self._sampler_iter)
except StopIteration:
return
for _ in range(self.num_workers): # find the next active worker, if any
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]:
break
self.trans_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
self._index_queues[worker_queue_idx].put((self._send_idx, index))
self._task_info[self._send_idx] = (worker_queue_idx,)
self._send_idx += 1
# use shared-memory queue implemented by pyarrow plasma store.
from .tools._queue import PlasmaShmQueue
self.batch_queue = PlasmaShmQueue(maxsize=2)
_worker_info = None
self.task_feeding_worker = multiprocessing.Process(
target=_task_feeding_loop,
args=(
iter(self.sampler),
self.task_queues,
self.num_workers,
self.divide,
self.shutdown_flag,
self.feed_batch_idx,
),
daemon=True,
)
gc.collect()
self.task_feeding_worker.start()
self.workers = []
for worker_id in range(self.num_workers):
worker = multiprocessing.Process(
target=_worker_loop,
args=(
self.dataset,
self.task_queues[worker_id],
self.trans_data_queues[worker_id],
self.transform,
self.seed + worker_id + 1,
self.shutdown_flag,
),
daemon=True,
)
gc.collect()
worker.start()
self.workers.append(worker)
if self.divide:
self.data_collecting_worker = multiprocessing.Process(
target=_data_gathering_loop,
args=(
self.trans_data_queues,
self.batch_queue,
self.collator,
len(self),
self.num_workers,
self.shutdown_flag,
self.target_batch_idx,
),
daemon=True,
)
else:
self.data_collecting_worker = multiprocessing.Process(
target=_data_selecting_loop,
args=(
self.trans_data_queues,
self.batch_queue,
self.collator,
len(self),
self.num_workers,
self.shutdown_flag,
self.target_batch_idx,
),
daemon=True,
)
gc.collect()
self.data_collecting_worker.start()
class WorkerInfo(object):
__initialized = False
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
self.__keys = tuple(kwargs.keys())
self.__initialized = True
def _check_workers(self):
# Check the status of each worker.
if not self.data_collecting_worker.is_alive():
exitcode = self.data_collecting_worker.exitcode
if exitcode != 0:
raise RuntimeError("data collecting worker died. {}".format(exitcode))
if not self.task_feeding_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode
if exitcode != 0:
raise RuntimeError("task feeding worker died. {}".format(exitcode))
for worker_id, worker in enumerate(self.workers):
if not worker.is_alive():
exitcode = worker.exitcode
if exitcode != 0:
raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))
logger.debug("all workers are alive.")
def _get_next_batch(self):
start_time = time.time()
while True:
self._check_workers()
try:
return self.batch_queue.get(timeout=1)
except queue.Empty:
logger.debug("batch queue empty!")
waited_time = time.time() - start_time
if self.timeout > 0:
if waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
def _shutdown(self):
with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1
if self.task_feeding_worker.is_alive():
self.task_feeding_worker.terminate()
self.task_feeding_worker.join()
if self.data_collecting_worker.is_alive():
self.data_collecting_worker.terminate()
self.data_collecting_worker.join()
def __setattr__(self, key, val):
if self.__initialized:
raise RuntimeError(
"Cannot assign attributes to {} objects".format(self.__class__.__name__)
)
return super(WorkerInfo, self).__setattr__(key, val)
for worker in self.workers:
if worker.is_alive():
worker.terminate()
worker.join()
def __repr__(self):
items = []
for k in self.__keys:
items.append("{}={}".format(k, getattr(self, k)))
return "{}({})".format(self.__class__.__name__, ", ".join(items))
for q in self.trans_data_queues:
q.cancel_join_thread()
q.close()
for q in self.task_queues:
q.cancel_join_thread()
q.close()
self.batch_queue.cancel_join_thread()
self.batch_queue.close()
def __del__(self):
if self.__initialized:
self._shutdown()
def get_worker_info():
return _worker_info
class _BaseStreamDataLoaderIter(_PreLoader):
def __init__(self, loader, preload):
super().__init__(preload)
self.dataset = loader.dataset
self.sampler = loader.sampler
self.transform = loader.transform
self.collator = loader.collator
self.num_workers = loader.num_workers
self.timeout = loader.timeout
self.timeout_event = loader.timeout_event
def _get_next_batch(self):
raise NotImplementedError
def _process_raw_data(self, raw_data):
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
ret = []
for idx in range(len(data[0])):
ret.append(tuple(e[idx] for e in data))
return ret
def __iter__(self):
return self
super().__init__(loader, preload)
self.dataset_iter = iter(self.dataset)
def __next__(self):
if self.preload:
......@@ -503,8 +534,6 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
def __init__(self, loader, preload):
super().__init__(loader, preload)
self.dataset_iter = iter(self.dataset)
self.idx = 0
self.unused = []
def _try_get_raw_data(self, start_time):
raw_data = None
......@@ -516,382 +545,153 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
raw_data = next(self.dataset_iter)
if self.timeout > 0:
timer.cancel()
except KeyboardInterrupt:
raw_data = self.timeout_event()
except AttributeError as error:
raise error
except:
if self.timeout > 0:
timer.cancel()
waited_time = time.time() - start_time
if waited_time > self.timeout:
raw_data = self.timeout_event()
_raise_timeout_error()
return raw_data
def _get_next_batch(self):
ret = []
start_time = time.time()
while len(ret) < self.sampler.batch_size:
if len(self.unused) != 0:
batch_data = self.unused
else:
raw_data = self._try_get_raw_data(start_time)
batch_data = self._process_raw_data(raw_data)
while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
data = batch_data.pop()
ret.append(self.transform.apply(data))
self.unused = batch_data
ret.append(self.transform.apply(raw_data))
return self.collator.apply(ret)
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
__initialized = False
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter, _ParallelDataLoaderIter):
def __init__(self, loader, preload):
super().__init__(loader, preload)
self.shutdown_flag = multiprocessing.Value("i", 0)
self.raw_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
self.trans_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
# shared-memory queue implemented by pyarrow plasma store
from .tools._queue import PlasmaShmQueue
self.batch_queue = PlasmaShmQueue(maxsize=2)
self.recieve_worker = multiprocessing.Process(
target=self._worker_to_raw_data_queues, daemon=True
)
gc.collect()
self.recieve_worker.start()
self.transform_workers = []
for worker_id in range(self.num_workers):
worker = multiprocessing.Process(
target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
)
gc.collect()
worker.start()
self.transform_workers.append(worker)
self.collect_worker = multiprocessing.Process(
target=self._worker_to_batch_queue, daemon=True
)
gc.collect()
self.collect_worker.start()
self.__initialized = True
def _put_raw_data_queues(self, raw_data, qidx):
batch_data = self._process_raw_data(raw_data)
for data in batch_data:
while True:
qidx = qidx % self.num_workers
try:
self.raw_data_queues[qidx].put(data)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("raw data queue %d is full" % qidx)
finally:
qidx += 1
return qidx
_BaseStreamDataLoaderIter.__init__(self, loader, preload)
_ParallelDataLoaderIter.__init__(self)
def _worker_to_raw_data_queues(self):
dataset_iter = iter(self.dataset)
qidx = 0
while True:
if self.shutdown_flag.value == 1:
break
raw_data = next(dataset_iter)
qidx = self._put_raw_data_queues(raw_data, qidx)
def _get_remaind_data(self, place_holder):
num = self.sampler.batch_size
for _ in range(num - 1):
place_holder.append(next(self.dataset_iter))
return place_holder
def _worker_to_trans_data_queues(self, worker_id):
while True:
if self.shutdown_flag.value == 1:
break
try:
data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
except queue.Empty:
continue
trans_data = self.transform.apply(data)
while True:
try:
self.trans_data_queues[worker_id].put(trans_data)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch queue if full")
def _worker_to_batch_queue(self):
cnt = -1
trans_items = []
while True:
if self.shutdown_flag.value == 1:
break
cnt += 1
queue_id = cnt % self.num_workers
try:
trans_item = self.trans_data_queues[queue_id].get(
timeout=GLOBAL_TIMEOUT
)
except queue.Empty:
continue
trans_items.append(trans_item)
if len(trans_items) == self.sampler.batch_size:
batch_data = self.collator.apply(trans_items)
while True:
def _try_put_index(self):
try:
self.batch_queue.put(batch_data, timeout=1)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch queue is full")
trans_items = []
def _check_workers(self):
if not self.collect_worker.is_alive():
exitcode = self.collect_worker.exitcode
if exitcode != 0:
raise RuntimeError("collator worker died. {}".format(exitcode))
for worker_id, worker in enumerate(self.transform_workers):
if not worker.is_alive():
exitcode = worker.exitcode
if exitcode != 0:
raise RuntimeError(
"worker: {} died. {}".format(worker_id, exitcode)
)
def _get_next_batch(self):
if self.parallel_stream is False:
start_time = time.time()
while True:
self._check_workers()
try:
return self.batch_queue.get(timeout=1)
except queue.Empty:
logger.debug("batch queue empty!")
place_holder = [next(self.dataset_iter)]
waited_time = time.time() - start_time
if self.timeout > 0 and waited_time > self.timeout:
self._put_raw_data_queues(self.timeout_event(), 0)
def _shutdown(self):
with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1
if self.recieve_worker.is_alive():
self.recieve_worker.terminate()
self.recieve_worker.join()
if self.collect_worker.is_alive():
self.collect_worker.terminate()
self.collect_worker.join()
for worker in self.transform_workers:
if worker.is_alive():
worker.terminate()
worker.join()
for q in self.raw_data_queues:
q.cancel_join_thread()
q.close()
for q in self.trans_data_queues:
q.cancel_join_thread()
q.close()
self.batch_queue.cancel_join_thread()
self.batch_queue.close()
def __del__(self):
if self.__initialized:
self._shutdown()
def _task_feeding_loop(
indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
# Feed the indices into the task queues
while True:
if shutdown_flag.value == 1:
break
batch_idx = feed_batch_idx.value
try:
indices = next(indices_iter)
except StopIteration:
break
if divide:
# make sure all task_queues is ready for put
while any([q.full() for q in task_queues]):
if shutdown_flag.value == 1:
return
# divide into small pieces, feed to different workers.
sub_num = math.ceil(len(indices) / num_workers)
for worker_id in range(num_workers):
sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
task_queues[worker_id].put((batch_idx, sub_indices))
_raise_timeout_error()
place_holder = self._get_remaind_data(place_holder)
else:
# distribute tasks to different workers uniformly.
target_id = batch_idx % num_workers
while task_queues[target_id].full():
if shutdown_flag.value == 1:
place_holder = next(self._sampler_iter)
except StopIteration:
return
task_queues[target_id].put((batch_idx, indices))
with feed_batch_idx.get_lock():
feed_batch_idx.value += 1
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
# Get dataset items and do the transform
random.seed(seed)
np.random.seed(seed)
while True:
if shutdown_flag.value == 1:
for _ in range(self.num_workers):
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]:
break
try:
batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
except queue.Empty:
continue
if len(indices) > 0:
items = [dataset[idx] for idx in indices]
trans_items = transform.apply_batch(items)
else:
# in case of incomplete last batch
trans_items = ()
while True:
try:
trans_data_queue.put((batch_idx, trans_items), timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch part queue is full!")
return
self._index_queues[worker_queue_idx].put((self._send_idx, place_holder))
self._task_info[self._send_idx] = (worker_queue_idx,)
self._send_idx += 1
def _data_gathering_loop(
trans_data_queues,
batch_queue,
collator,
length,
num_workers,
shutdown_flag,
target_idx,
):
# Gathering the small pieces of batch data into full batch data
while True:
if shutdown_flag.value == 1:
break
target_batch_idx = target_idx.value
class ManagerWatchdog(object):
def __init__(self):
self.manager_pid = os.getppid()
self.manager_dead = False
if target_batch_idx >= length:
break
def is_alive(self):
if not self.manager_dead:
self.manager_dead = os.getppid() != self.manager_pid
return not self.manager_dead
full_trans_items = []
for worker_id in range(num_workers):
while True:
def stream_fetcher(
dataset_iter, place_holder, transform, collate, parallel_stream, batch_size
):
data = []
for idx in place_holder:
try:
batch_idx, trans_items = trans_data_queues[worker_id].get(
timeout=GLOBAL_TIMEOUT
)
break
except queue.Empty:
if shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
worker_id, target_batch_idx
)
)
if batch_idx != target_batch_idx:
raise RuntimeError(
"Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
worker_id
)
)
if parallel_stream is False:
raw_data = idx
else:
full_trans_items.extend(trans_items)
# Merge different parts into a batch.
full_batch = collator.apply(full_trans_items)
raw_data = next(dataset_iter)
trans_items = transform.apply(raw_data)
data.append(trans_items)
while True:
try:
batch_queue.put(full_batch, timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
except StopIteration:
break
logger.debug("batch queue is full!")
with target_idx.get_lock():
target_idx.value += 1
if len(data) == 0:
raise StopIteration
data = collate.apply(data)
return data
batch_queue.disconnect_client()
def map_fetcher(dataset, place_holder, transform, collate, parallel_stream, batch_size):
items = [dataset[idx] for idx in place_holder]
trans_items = transform.apply_batch(items)
data = collate.apply(trans_items)
return data
def _data_selecting_loop(
trans_data_queues,
batch_queue,
collator,
length,
def _worker_loop(
dataset,
index_queue,
data_queue,
done_event,
transform,
collate,
batch_size,
seed,
worker_id,
num_workers,
shutdown_flag,
target_idx,
datakind,
parallel_stream,
):
# Make sure that batch is generated exactly with the same order as generated indices
while True:
if shutdown_flag.value == 1:
break
target_batch_idx = target_idx.value
if target_batch_idx >= length:
break
target_worker_id = target_batch_idx % num_workers
while True:
random.seed(seed)
np.random.seed(seed)
watchdog = ManagerWatchdog()
iteration_end = False
fetcher = map_fetcher
if datakind == "stream":
global _worker_info
_worker_info = WorkerInfo(idx=worker_id, worker=num_workers, seed=seed)
dataset = iter(dataset)
fetcher = stream_fetcher
while watchdog.is_alive():
try:
batch_idx, trans_items = trans_data_queues[target_worker_id].get(
timeout=GLOBAL_TIMEOUT
)
batch_data = collator.apply(trans_items)
break
r = index_queue.get(timeout=GLOBAL_TIMEOUT)
except queue.Empty:
if shutdown_flag.value == 1:
continue
if r is None:
assert done_event.is_set() or iteration_end
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
target_worker_id, target_batch_idx
)
)
if batch_idx != target_batch_idx:
raise RuntimeError(
"batch_idx {} mismatch the target_batch_idx {}".format(
batch_idx, target_batch_idx
)
)
elif done_event.is_set() or iteration_end:
continue
while True:
idx, place_holder = r
try:
batch_queue.put(batch_data, timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")
with target_idx.get_lock():
target_idx.value += 1
data = fetcher(
dataset, place_holder, transform, collate, parallel_stream, batch_size
)
except Exception as e:
if isinstance(e, StopIteration) and datakind == "stream":
data = worker_id
iteration_end = True
else:
raise e
data_queue.put((idx, data))
del data, idx, place_holder, r
batch_queue.disconnect_client()
if done_event.is_set():
data_queue.disconnect_client()
data_queue.close()
......@@ -2,6 +2,7 @@
import collections.abc
import math
from abc import ABC, abstractmethod
from itertools import count
from typing import Any, Generator, Iterator, List, Union
import numpy as np
......@@ -126,13 +127,15 @@ class MapSampler(Sampler):
if self.world_size > 1:
indices = self.scatter(indices)
step, length = self.batch_size, len(indices)
batch_index = [indices[i : i + step] for i in range(0, length, step)]
batch = []
for idx in indices:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if self.drop_last and len(batch_index[-1]) < self.batch_size:
batch_index.pop()
return iter(batch_index)
if len(batch) > 0 and not self.drop_last:
yield batch
class StreamSampler(Sampler):
......@@ -151,10 +154,18 @@ class StreamSampler(Sampler):
self.batch_size = batch_size
def __iter__(self):
return self
return self.batch()
def __next__(self):
return iter(range(self.batch_size))
def batch(self):
batch = []
for idx in self.sample():
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
def sample(self):
return count(start=0)
class SequentialSampler(MapSampler):
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import os
import platform
import time
......@@ -7,7 +15,7 @@ import numpy as np
import pytest
from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
......@@ -29,14 +37,10 @@ def init_dataset():
def test_dataloader_init():
dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)
dataloader = DataLoader(dataset)
assert isinstance(dataloader.sampler, SequentialSampler)
......@@ -54,10 +58,8 @@ def test_dataloader_init():
class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False):
def __init__(self, number, block=False):
self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block
def __iter__(self):
......@@ -65,22 +67,14 @@ class MyStream(StreamDataset):
if self.block:
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
raise StopIteration
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch=batch)
def test_stream_dataloader(num_workers):
dataset = MyStream(100)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
......@@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers):
)
check_set = set()
for step, data in enumerate(dataloader):
if step == 10:
break
......@@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers):
check_set.add(i)
def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2)
......@@ -140,17 +124,6 @@ def test_dataloader_parallel():
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=False,
)
for (data, label) in dataloader:
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
)
for (data, label) in dataloader:
assert data.shape == (4, 1, 32, 32)
......@@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception():
transform=FakeErrorTransform(),
num_workers=2,
)
with pytest.raises(RuntimeError, match=r"worker.*died"):
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
......@@ -213,18 +186,15 @@ def test_dataloader_parallel_worker_exception():
def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset()
for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data.shape == (4, 1, 32, 32)
......@@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
assert p.exitcode == 0
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
def partition(ls, size):
return [ls[i : i + size] for i in range(0, len(ls), size)]
class MyPreStream(StreamDataset):
def __init__(self, number, block=False):
self.number = [i for i in range(number)]
self.block = block
self.data = []
for i in range(100):
self.data.append(np.random.randint(0, 256, (2, 2, 3), dtype="uint8"))
def __iter__(self):
worker_info = get_worker_info()
per_worker = int(math.ceil((len(self.data)) / float(worker_info.worker)))
pre_data = iter(partition(self.data, per_worker)[worker_info.idx])
pre_cnt = partition(self.number, per_worker)[worker_info.idx]
for cnt in pre_cnt:
if self.block:
for _ in range(10):
time.sleep(1)
yield (next(pre_data), cnt)
raise StopIteration
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_prestream_dataloader_multiprocessing():
dataset = MyPreStream(100)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb
dataset,
sampler,
Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
num_workers=2,
parallel_stream=True,
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
check_set = set()
for step, data in enumerate(dataloader):
if step == 10:
break
assert data[0].shape == (4, 3, 2, 2)
assert data[1].shape == (4,)
for i in data[1]:
assert i not in check_set
check_set.add(i)
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_predataloader_parallel_worker_exception():
dataset = MyPreStream(100)
class FakeErrorTransform(Transform):
def __init__(self):
pass
def apply(self, input):
raise RuntimeError("test raise error")
return input
dataloader = DataLoader(
dataset,
sampler=StreamSampler(batch_size=4),
transform=FakeErrorTransform(),
num_workers=2,
parallel_stream=True,
)
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
print(batch_data.shape)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import math
import os
import platform
import time
......@@ -8,7 +16,7 @@ import numpy as np
import pytest
from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
......@@ -30,14 +38,10 @@ def init_dataset():
def test_dataloader_init():
dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)
dataloader = DataLoader(dataset, preload=True)
assert isinstance(dataloader.sampler, SequentialSampler)
......@@ -59,10 +63,8 @@ def test_dataloader_init():
class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False):
def __init__(self, number, block=False):
self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block
def __iter__(self):
......@@ -70,22 +72,14 @@ class MyStream(StreamDataset):
if self.block:
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
raise StopIteration
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch=batch)
def test_stream_dataloader(num_workers):
dataset = MyStream(100)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
......@@ -107,18 +101,9 @@ def test_stream_dataloader(batch, num_workers):
check_set.add(i)
def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, preload=True)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
......@@ -150,18 +135,6 @@ def test_dataloader_parallel():
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=False,
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
preload=True,
)
for (data, label) in dataloader:
......@@ -219,7 +192,7 @@ def test_dataloader_parallel_worker_exception():
num_workers=2,
preload=True,
)
with pytest.raises(RuntimeError, match=r"worker.*died"):
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
......@@ -227,19 +200,16 @@ def test_dataloader_parallel_worker_exception():
def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset()
for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
for idx, (data, label) in enumerate(train_dataloader):
......@@ -276,25 +246,3 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
for p in processes:
p.join()
assert p.exitcode == 0
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
num_workers=num_workers,
timeout=2,
timeout_event=cb,
preload=True,
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册