From edc92ccfd64cd283b66dbd9edcc24bd7b79899ad Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 22 Aug 2022 20:26:46 +0800 Subject: [PATCH] perf(imperative/data): improve dataloader preformance GitOrigin-RevId: 7d8d52aaeb47e7ec6c3efa282ff9014a4b7d1f01 --- .../python/megengine/data/dataloader.py | 926 +++++++----------- imperative/python/megengine/data/sampler.py | 29 +- .../python/test/unit/data/test_dataloader.py | 179 ++-- .../test/unit/data/test_pre_dataloader.py | 122 +-- 4 files changed, 524 insertions(+), 732 deletions(-) diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index e04c7324f..3d671c1e9 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import collections import gc -import math +import itertools import multiprocessing import os import platform @@ -9,7 +9,6 @@ import queue import random import threading import time -from typing import Callable, Union import numpy as np @@ -67,12 +66,6 @@ class DataLoader: the batch. ``0`` means using single-process. Default: 0 timeout: if positive, means the timeout value(second) for collecting a batch from workers. Default: 0 - timeout_event: callback function triggered by timeout, default to raise - runtime error. - divide: define the paralleling strategy in multi-processing mode. - ``True`` means one batch is divided into :attr:`num_workers` pieces, and - the workers will process these pieces parallelly. ``False`` means - different sub-process will process different batch. Default: False preload: whether to enable the preloading strategy of the dataloader. When enabling, the dataloader will preload one batch to the device memory to speed up the whole training process. @@ -85,7 +78,6 @@ class DataLoader: which will improve the training speed at the cost of **higher device memory usage** (due to one more batch data on device memory). This feature saves more time when your NN training time is short or your machine's host PCIe bandwidth for each device is low. """ - __initialized = False def __init__( self, @@ -95,9 +87,8 @@ class DataLoader: collator: Collator = None, num_workers: int = 0, timeout: int = 0, - timeout_event: Callable = _raise_timeout_error, - divide: bool = False, preload: bool = False, + parallel_stream: bool = False, ): if num_workers < 0: raise ValueError("num_workers should not be negative") @@ -105,23 +96,22 @@ class DataLoader: if timeout < 0: raise ValueError("timeout should not be negative") - if divide and num_workers <= 1: - raise ValueError("divide should not be set to True when num_workers <= 1") - self.dataset = dataset - self.num_workers = num_workers self.timeout = timeout - self.timeout_event = timeout_event - - self.divide = divide self.preload = preload + self.parallel_stream = parallel_stream if isinstance(dataset, StreamDataset): self.sampler = sampler if sampler else StreamSampler(batch_size=1) assert isinstance( self.sampler, StreamSampler ), "types of dataset and sampler do not match" + if parallel_stream is False and self.num_workers > 1: + logger.warning( + "Data time will be affected by getting origin-data, please set parallel_stream in order to speed up dataloader!" + ) + self.datakind = "stream" else: assert isinstance( dataset, Dataset @@ -134,16 +124,7 @@ class DataLoader: assert isinstance( self.sampler, MapSampler ), "types of dataset and sampler do not match" - - if divide: - if self.sampler.batch_size <= self.num_workers: - raise ValueError( - "batch size must not smaller than num_workers in divide mode." - ) - elif self.sampler.batch_size % self.num_workers: - logger.warning( - "batch size is not divisible by num_workers, may lose performance in divide mode." - ) + self.datakind = "map" if transform is None: self.transform = PseudoTransform() @@ -155,7 +136,8 @@ class DataLoader: else: self.collator = collator - self.__initialized = True + if platform.system() == "Linux" and self.num_workers > 0: + self.check_memory_rationality() def __iter__(self): if platform.system() == "Windows" and self.num_workers > 0: @@ -187,15 +169,50 @@ class DataLoader: def __len__(self): return len(self.sampler) + def check_memory_rationality(self): + import psutil + + main_memory = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024 + total_memory = (self.num_workers + 1) * main_memory + current_memory = ( + int(os.popen("cat /sys/fs/cgroup/memory/memory.limit_in_bytes").read()) + / 1024 + / 1024 + / 1024 + ) + if current_memory < total_memory: + logger.warning( + "Each worker need to read the shared meta-data, which will be increasing the reference count." + "Copy-On-Write propety will lead to 'memory leak', the memory usage will end up being " + + total_memory + + " GB" + "However the current requested memory is " + current_memory + " GB" + "Maybe you can request more memory or uesd np-array to save meta-data rather than List or Tuple" + ) + class _PreLoader: - def __init__(self, preload): + def __init__(self, loader, preload): + self.dataset = loader.dataset + self.sampler = loader.sampler + self.seed = _random_seed_generator().__next__() + self.transform = loader.transform + self.collator = loader.collator + self.num_workers = loader.num_workers + self.timeout = loader.timeout + self.num_processed = 0 + self.datakind = loader.datakind + self.parallel_stream = loader.parallel_stream + if preload: self.default_device = get_default_device() self.pre_load_device = self.default_device + ":" + str(_sh.get_next()) self.pre_load_device_cache = None self.preload = preload + def __iter__(self): + return self + """ strategy one: load from numpy data, and generate dtype tensor """ @@ -237,29 +254,176 @@ class _PreLoader: return out -class _BaseMapDataLoaderIter(_PreLoader): - def __init__(self, loader, preload): - super().__init__(preload) - self.dataset = loader.dataset - self.sampler = loader.sampler - self.seed = _random_seed_generator().__next__() - self.transform = loader.transform - self.collator = loader.collator - self.num_workers = loader.num_workers - self.timeout = loader.timeout - self.timeout_event = loader.timeout_event - self.divide = loader.divide - self.num_processed = 0 +class _ParallelDataLoaderIter: + def __init__(self): + self._worker_queue_idx_cycle = itertools.cycle(range(self.num_workers)) + from .tools._queue import PlasmaShmQueue - def _get_next_batch(self): + self._worker_result_queue = PlasmaShmQueue() + self._shutdown = False + self._workers_done_event = multiprocessing.Event() + self._index_queues = [] + self._workers = [] + + for i in range(self.num_workers): + index_queue = multiprocessing.Queue() + index_queue.cancel_join_thread() + w = multiprocessing.Process( + target=_worker_loop, + args=( + self.dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self.transform, + self.collator, + self.sampler.batch_size, + self.seed + i, + i, + self.num_workers, + self.datakind, + self.parallel_stream, + ), + daemon=True, + ) + gc.collect() + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + + self._data_queue = self._worker_result_queue + self._reset() + + def _try_put_index(self): raise NotImplementedError + def _reset(self): + self._sampler_iter = iter(self.sampler) + + self._send_idx = 0 + self._rcvd_idx = 0 + + self._task_info = {} + self._workers_status = [True for _ in range(self.num_workers)] + for _ in range(2 * self.num_workers): + self._try_put_index() + + def _process_data(self, data): + self._rcvd_idx += 1 + self._try_put_index() + return data + + def _get_data(self): + if self.timeout > 0: + success, data = self._try_get_data(self.timeout) + if success: + return data + else: + _raise_timeout_error() + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _get_next_batch(self): + while True: + while self._rcvd_idx < self._send_idx: + info = self._task_info[self._rcvd_idx] + worker_id = info[0] + if ( + len(info) == 2 or self._workers_status[worker_id] + ): # has data or work is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + self._shutdown_workers() + raise StopIteration + + if len(self._task_info[self._rcvd_idx]) == 2: + data = self._task_info.pop(self._rcvd_idx)[1] + return self._process_data(data) + + idx, data = self._get_data() + + if isinstance(data, int): # Check if StopIteration in StreamDataset + self._mark_worker_as_unavailable(data) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + self._task_info[idx] += (data,) + else: + del self._task_info[idx] + return self._process_data(data) + + def _try_get_data(self, timeout=GLOBAL_TIMEOUT): + try: + data = self._data_queue.get(timeout=timeout) + return (True, data) + except Exception as e: + failed_workers = [] + for worker_id, w in enumerate(self._workers): + if self._workers_status[worker_id] and not w.is_alive(): + failed_workers.append((worker_id, w)) + self._mark_worker_as_unavailable(worker_id) + if w.exitcode == -9: + logger.debug( + "Maybe memory is not enough, please request for more memory!" + ) + if len(failed_workers) > 0: + pids_str = ", ".join(str(w_info[1].pid) for w_info in failed_workers) + w_ids_str = ", ".join(str(w_info[0]) for w_info in failed_workers) + exitcode_str = ", ".join( + str(w_info[1].exitcode) for w_info in failed_workers + ) + raise RuntimeError( + "DataLoader worker (worker(s): {} , pid(s): {}) exited unexpectedly, exitcode(s): {}".format( + w_ids_str, pids_str, exitcode_str + ) + ) + + if isinstance(e, queue.Empty): + return (False, None) + + def _mark_worker_as_unavailable(self, worker_id, shutdown=False): + q = self._index_queues[worker_id] + q.put(None) + self._workers_status[worker_id] = False + assert self._workers_done_event.is_set() == shutdown + + def _shutdown_workers(self): + if not self._shutdown: + self._shutdown = True + try: + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + if self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: + w.join(timeout=GLOBAL_TIMEOUT) + for q in self._index_queues: + q.cancel_join_thread() + q.close() + self._data_queue.cancel_join_thread() + self._data_queue.close() + finally: + for w in self._workers: + if w.is_alive(): + w.terminate() + + def __del__(self): + self._shutdown_workers() + + +class _BaseMapDataLoaderIter(_PreLoader): + def __init__(self, loader, preload): + super().__init__(loader, preload) + def __len__(self): return len(self.sampler) - def __iter__(self): - return self - def __next__(self): if self.preload: cached = self.pre_load_device_cache @@ -272,11 +436,8 @@ class _BaseMapDataLoaderIter(_PreLoader): self._try_load_tensor() return out else: - if self.num_processed >= len(self): - raise StopIteration - minibatch = self._get_next_batch() - self.num_processed += 1 - return minibatch + data = self._get_next_batch() + return data def _try_load_tensor(self, cached=True): if self.num_processed >= len(self): @@ -290,199 +451,69 @@ class _BaseMapDataLoaderIter(_PreLoader): class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter): def __init__(self, loader, preload): super(_SerialMapDataLoaderIter, self).__init__(loader, preload) - self.indices_iter = iter(self.sampler) + self._sampler_iter = iter(self.sampler) def _get_next_batch(self): - indices = next(self.indices_iter) + indices = next(self._sampler_iter) items = [self.dataset[idx] for idx in indices] trans_items = self.transform.apply_batch(items) return self.collator.apply(trans_items) -class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter): - __initialized = False - +class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter, _ParallelDataLoaderIter): def __init__(self, loader, preload): - super(_ParallelMapDataLoaderIter, self).__init__(loader, preload) - - self.task_queues = [ - multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) - ] + _BaseMapDataLoaderIter.__init__(self, loader, preload) + _ParallelDataLoaderIter.__init__(self) - self.feed_batch_idx = multiprocessing.Value("i", 0) - self.target_batch_idx = multiprocessing.Value("i", 0) - self.shutdown_flag = multiprocessing.Value("i", 0) + def _try_put_index(self): + try: + index = next(self._sampler_iter) + except StopIteration: + return + for _ in range(self.num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + break - self.trans_data_queues = [ - multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) - ] + self._index_queues[worker_queue_idx].put((self._send_idx, index)) + self._task_info[self._send_idx] = (worker_queue_idx,) + self._send_idx += 1 - # use shared-memory queue implemented by pyarrow plasma store. - from .tools._queue import PlasmaShmQueue - self.batch_queue = PlasmaShmQueue(maxsize=2) - - self.task_feeding_worker = multiprocessing.Process( - target=_task_feeding_loop, - args=( - iter(self.sampler), - self.task_queues, - self.num_workers, - self.divide, - self.shutdown_flag, - self.feed_batch_idx, - ), - daemon=True, - ) - gc.collect() - self.task_feeding_worker.start() +_worker_info = None - self.workers = [] - for worker_id in range(self.num_workers): - worker = multiprocessing.Process( - target=_worker_loop, - args=( - self.dataset, - self.task_queues[worker_id], - self.trans_data_queues[worker_id], - self.transform, - self.seed + worker_id + 1, - self.shutdown_flag, - ), - daemon=True, - ) - gc.collect() - worker.start() - self.workers.append(worker) - if self.divide: - self.data_collecting_worker = multiprocessing.Process( - target=_data_gathering_loop, - args=( - self.trans_data_queues, - self.batch_queue, - self.collator, - len(self), - self.num_workers, - self.shutdown_flag, - self.target_batch_idx, - ), - daemon=True, - ) - else: - self.data_collecting_worker = multiprocessing.Process( - target=_data_selecting_loop, - args=( - self.trans_data_queues, - self.batch_queue, - self.collator, - len(self), - self.num_workers, - self.shutdown_flag, - self.target_batch_idx, - ), - daemon=True, - ) - gc.collect() - self.data_collecting_worker.start() +class WorkerInfo(object): + __initialized = False + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__keys = tuple(kwargs.keys()) self.__initialized = True - def _check_workers(self): - # Check the status of each worker. - if not self.data_collecting_worker.is_alive(): - exitcode = self.data_collecting_worker.exitcode - if exitcode != 0: - raise RuntimeError("data collecting worker died. {}".format(exitcode)) + def __setattr__(self, key, val): + if self.__initialized: + raise RuntimeError( + "Cannot assign attributes to {} objects".format(self.__class__.__name__) + ) + return super(WorkerInfo, self).__setattr__(key, val) - if not self.task_feeding_worker.is_alive(): - exitcode = self.task_feeding_worker.exitcode - if exitcode != 0: - raise RuntimeError("task feeding worker died. {}".format(exitcode)) + def __repr__(self): + items = [] + for k in self.__keys: + items.append("{}={}".format(k, getattr(self, k))) + return "{}({})".format(self.__class__.__name__, ", ".join(items)) - for worker_id, worker in enumerate(self.workers): - if not worker.is_alive(): - exitcode = worker.exitcode - if exitcode != 0: - raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode)) - logger.debug("all workers are alive.") - - def _get_next_batch(self): - start_time = time.time() - while True: - self._check_workers() - try: - return self.batch_queue.get(timeout=1) - except queue.Empty: - logger.debug("batch queue empty!") - waited_time = time.time() - start_time - if self.timeout > 0: - if waited_time > self.timeout: - raise RuntimeError("get_next_batch timeout!") - - def _shutdown(self): - with self.shutdown_flag.get_lock(): - self.shutdown_flag.value = 1 - - if self.task_feeding_worker.is_alive(): - self.task_feeding_worker.terminate() - self.task_feeding_worker.join() - - if self.data_collecting_worker.is_alive(): - self.data_collecting_worker.terminate() - self.data_collecting_worker.join() - - for worker in self.workers: - if worker.is_alive(): - worker.terminate() - worker.join() - - for q in self.trans_data_queues: - q.cancel_join_thread() - q.close() - - for q in self.task_queues: - q.cancel_join_thread() - q.close() - - self.batch_queue.cancel_join_thread() - self.batch_queue.close() - - def __del__(self): - if self.__initialized: - self._shutdown() +def get_worker_info(): + return _worker_info class _BaseStreamDataLoaderIter(_PreLoader): def __init__(self, loader, preload): - super().__init__(preload) - self.dataset = loader.dataset - self.sampler = loader.sampler - self.transform = loader.transform - self.collator = loader.collator - self.num_workers = loader.num_workers - self.timeout = loader.timeout - self.timeout_event = loader.timeout_event - - def _get_next_batch(self): - raise NotImplementedError - - def _process_raw_data(self, raw_data): - assert len(raw_data) == 2 and isinstance( - raw_data[0], bool - ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched." - if not raw_data[0]: - data = list((x,) for x in raw_data[1]) - else: - data = raw_data[1] - ret = [] - for idx in range(len(data[0])): - ret.append(tuple(e[idx] for e in data)) - return ret - - def __iter__(self): - return self + super().__init__(loader, preload) + self.dataset_iter = iter(self.dataset) def __next__(self): if self.preload: @@ -503,8 +534,6 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): def __init__(self, loader, preload): super().__init__(loader, preload) self.dataset_iter = iter(self.dataset) - self.idx = 0 - self.unused = [] def _try_get_raw_data(self, start_time): raw_data = None @@ -516,382 +545,153 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): raw_data = next(self.dataset_iter) if self.timeout > 0: timer.cancel() - except KeyboardInterrupt: - raw_data = self.timeout_event() + except AttributeError as error: + raise error except: if self.timeout > 0: timer.cancel() waited_time = time.time() - start_time if waited_time > self.timeout: - raw_data = self.timeout_event() + _raise_timeout_error() return raw_data def _get_next_batch(self): ret = [] start_time = time.time() while len(ret) < self.sampler.batch_size: - if len(self.unused) != 0: - batch_data = self.unused - else: - raw_data = self._try_get_raw_data(start_time) - batch_data = self._process_raw_data(raw_data) - - while len(batch_data) != 0 and len(ret) < self.sampler.batch_size: - data = batch_data.pop() - ret.append(self.transform.apply(data)) - self.unused = batch_data + raw_data = self._try_get_raw_data(start_time) + ret.append(self.transform.apply(raw_data)) return self.collator.apply(ret) -class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): - __initialized = False - +class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter, _ParallelDataLoaderIter): def __init__(self, loader, preload): - super().__init__(loader, preload) - - self.shutdown_flag = multiprocessing.Value("i", 0) - - self.raw_data_queues = [ - multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) - ] - - self.trans_data_queues = [ - multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) - ] - - # shared-memory queue implemented by pyarrow plasma store - from .tools._queue import PlasmaShmQueue - - self.batch_queue = PlasmaShmQueue(maxsize=2) - - self.recieve_worker = multiprocessing.Process( - target=self._worker_to_raw_data_queues, daemon=True - ) - gc.collect() - self.recieve_worker.start() - - self.transform_workers = [] - for worker_id in range(self.num_workers): - worker = multiprocessing.Process( - target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True - ) - gc.collect() - worker.start() - self.transform_workers.append(worker) + _BaseStreamDataLoaderIter.__init__(self, loader, preload) + _ParallelDataLoaderIter.__init__(self) - self.collect_worker = multiprocessing.Process( - target=self._worker_to_batch_queue, daemon=True - ) - gc.collect() - self.collect_worker.start() - - self.__initialized = True + def _get_remaind_data(self, place_holder): + num = self.sampler.batch_size + for _ in range(num - 1): + place_holder.append(next(self.dataset_iter)) + return place_holder - def _put_raw_data_queues(self, raw_data, qidx): - batch_data = self._process_raw_data(raw_data) - for data in batch_data: - while True: - qidx = qidx % self.num_workers - try: - self.raw_data_queues[qidx].put(data) - break - except queue.Full: - if self.shutdown_flag.value == 1: - break - logger.debug("raw data queue %d is full" % qidx) - finally: - qidx += 1 - return qidx - - def _worker_to_raw_data_queues(self): - dataset_iter = iter(self.dataset) - qidx = 0 - while True: - if self.shutdown_flag.value == 1: - break - raw_data = next(dataset_iter) - qidx = self._put_raw_data_queues(raw_data, qidx) + def _try_put_index(self): + try: + if self.parallel_stream is False: + start_time = time.time() + place_holder = [next(self.dataset_iter)] + waited_time = time.time() - start_time + if self.timeout > 0 and waited_time > self.timeout: + _raise_timeout_error() + place_holder = self._get_remaind_data(place_holder) + else: + place_holder = next(self._sampler_iter) + except StopIteration: + return - def _worker_to_trans_data_queues(self, worker_id): - while True: - if self.shutdown_flag.value == 1: - break - try: - data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT) - except queue.Empty: - continue - trans_data = self.transform.apply(data) - while True: - try: - self.trans_data_queues[worker_id].put(trans_data) - break - except queue.Full: - if self.shutdown_flag.value == 1: - break - logger.debug("batch queue if full") - - def _worker_to_batch_queue(self): - cnt = -1 - trans_items = [] - while True: - if self.shutdown_flag.value == 1: + for _ in range(self.num_workers): + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: break - cnt += 1 - queue_id = cnt % self.num_workers - try: - trans_item = self.trans_data_queues[queue_id].get( - timeout=GLOBAL_TIMEOUT - ) - except queue.Empty: - continue - trans_items.append(trans_item) - if len(trans_items) == self.sampler.batch_size: - batch_data = self.collator.apply(trans_items) - while True: - try: - self.batch_queue.put(batch_data, timeout=1) - break - except queue.Full: - if self.shutdown_flag.value == 1: - break - logger.debug("batch queue is full") - trans_items = [] - - def _check_workers(self): - if not self.collect_worker.is_alive(): - exitcode = self.collect_worker.exitcode - if exitcode != 0: - raise RuntimeError("collator worker died. {}".format(exitcode)) - - for worker_id, worker in enumerate(self.transform_workers): - if not worker.is_alive(): - exitcode = worker.exitcode - if exitcode != 0: - raise RuntimeError( - "worker: {} died. {}".format(worker_id, exitcode) - ) - - def _get_next_batch(self): - start_time = time.time() - while True: - self._check_workers() - try: - return self.batch_queue.get(timeout=1) - except queue.Empty: - logger.debug("batch queue empty!") - waited_time = time.time() - start_time - if self.timeout > 0 and waited_time > self.timeout: - self._put_raw_data_queues(self.timeout_event(), 0) - - def _shutdown(self): - with self.shutdown_flag.get_lock(): - self.shutdown_flag.value = 1 - - if self.recieve_worker.is_alive(): - self.recieve_worker.terminate() - self.recieve_worker.join() - - if self.collect_worker.is_alive(): - self.collect_worker.terminate() - self.collect_worker.join() - - for worker in self.transform_workers: - if worker.is_alive(): - worker.terminate() - worker.join() + else: + return - for q in self.raw_data_queues: - q.cancel_join_thread() - q.close() + self._index_queues[worker_queue_idx].put((self._send_idx, place_holder)) + self._task_info[self._send_idx] = (worker_queue_idx,) + self._send_idx += 1 - for q in self.trans_data_queues: - q.cancel_join_thread() - q.close() - self.batch_queue.cancel_join_thread() - self.batch_queue.close() +class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + self.manager_dead = False - def __del__(self): - if self.__initialized: - self._shutdown() + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead -def _task_feeding_loop( - indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx +def stream_fetcher( + dataset_iter, place_holder, transform, collate, parallel_stream, batch_size ): - # Feed the indices into the task queues - while True: - if shutdown_flag.value == 1: - break - batch_idx = feed_batch_idx.value + data = [] + for idx in place_holder: try: - indices = next(indices_iter) + if parallel_stream is False: + raw_data = idx + else: + raw_data = next(dataset_iter) + trans_items = transform.apply(raw_data) + data.append(trans_items) + except StopIteration: break - if divide: - # make sure all task_queues is ready for put - while any([q.full() for q in task_queues]): - if shutdown_flag.value == 1: - return - # divide into small pieces, feed to different workers. - sub_num = math.ceil(len(indices) / num_workers) - for worker_id in range(num_workers): - sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num] - task_queues[worker_id].put((batch_idx, sub_indices)) - else: - # distribute tasks to different workers uniformly. - target_id = batch_idx % num_workers - while task_queues[target_id].full(): - if shutdown_flag.value == 1: - return - task_queues[target_id].put((batch_idx, indices)) - with feed_batch_idx.get_lock(): - feed_batch_idx.value += 1 - - -def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag): - # Get dataset items and do the transform + + if len(data) == 0: + raise StopIteration + data = collate.apply(data) + return data + + +def map_fetcher(dataset, place_holder, transform, collate, parallel_stream, batch_size): + items = [dataset[idx] for idx in place_holder] + trans_items = transform.apply_batch(items) + data = collate.apply(trans_items) + return data + + +def _worker_loop( + dataset, + index_queue, + data_queue, + done_event, + transform, + collate, + batch_size, + seed, + worker_id, + num_workers, + datakind, + parallel_stream, +): random.seed(seed) np.random.seed(seed) - while True: - if shutdown_flag.value == 1: - break + watchdog = ManagerWatchdog() + iteration_end = False + fetcher = map_fetcher + if datakind == "stream": + global _worker_info + _worker_info = WorkerInfo(idx=worker_id, worker=num_workers, seed=seed) + dataset = iter(dataset) + fetcher = stream_fetcher + + while watchdog.is_alive(): try: - batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT) + r = index_queue.get(timeout=GLOBAL_TIMEOUT) except queue.Empty: continue - if len(indices) > 0: - items = [dataset[idx] for idx in indices] - trans_items = transform.apply_batch(items) - else: - # in case of incomplete last batch - trans_items = () - while True: - try: - trans_data_queue.put((batch_idx, trans_items), timeout=1) - break - except queue.Full: - if shutdown_flag.value == 1: - break - logger.debug("batch part queue is full!") - - -def _data_gathering_loop( - trans_data_queues, - batch_queue, - collator, - length, - num_workers, - shutdown_flag, - target_idx, -): - # Gathering the small pieces of batch data into full batch data - while True: - if shutdown_flag.value == 1: - break - - target_batch_idx = target_idx.value - - if target_batch_idx >= length: - break - - full_trans_items = [] - for worker_id in range(num_workers): - while True: - try: - batch_idx, trans_items = trans_data_queues[worker_id].get( - timeout=GLOBAL_TIMEOUT - ) - break - except queue.Empty: - if shutdown_flag.value == 1: - break - logger.debug( - "worker:{} data queue get timeout! target batch idx:{}".format( - worker_id, target_batch_idx - ) - ) - if batch_idx != target_batch_idx: - raise RuntimeError( - "Unexperted batch_idx in data gathering loop. worker_id:{}.".format( - worker_id - ) - ) - else: - full_trans_items.extend(trans_items) - - # Merge different parts into a batch. - full_batch = collator.apply(full_trans_items) - - while True: - try: - batch_queue.put(full_batch, timeout=1) - break - except queue.Full: - if shutdown_flag.value == 1: - break - logger.debug("batch queue is full!") - - with target_idx.get_lock(): - target_idx.value += 1 - - batch_queue.disconnect_client() - - -def _data_selecting_loop( - trans_data_queues, - batch_queue, - collator, - length, - num_workers, - shutdown_flag, - target_idx, -): - # Make sure that batch is generated exactly with the same order as generated indices - while True: - if shutdown_flag.value == 1: - break - - target_batch_idx = target_idx.value - - if target_batch_idx >= length: + if r is None: + assert done_event.is_set() or iteration_end break + elif done_event.is_set() or iteration_end: + continue - target_worker_id = target_batch_idx % num_workers - while True: - try: - batch_idx, trans_items = trans_data_queues[target_worker_id].get( - timeout=GLOBAL_TIMEOUT - ) - batch_data = collator.apply(trans_items) - break - except queue.Empty: - if shutdown_flag.value == 1: - break - logger.debug( - "worker:{} data queue get timeout! target batch idx:{}".format( - target_worker_id, target_batch_idx - ) - ) - - if batch_idx != target_batch_idx: - raise RuntimeError( - "batch_idx {} mismatch the target_batch_idx {}".format( - batch_idx, target_batch_idx - ) + idx, place_holder = r + try: + data = fetcher( + dataset, place_holder, transform, collate, parallel_stream, batch_size ) + except Exception as e: + if isinstance(e, StopIteration) and datakind == "stream": + data = worker_id + iteration_end = True + else: + raise e + data_queue.put((idx, data)) + del data, idx, place_holder, r - while True: - try: - batch_queue.put(batch_data, timeout=1) - break - except queue.Full: - if shutdown_flag.value == 1: - break - logger.debug("batch queue is full!") - - with target_idx.get_lock(): - target_idx.value += 1 - - batch_queue.disconnect_client() + if done_event.is_set(): + data_queue.disconnect_client() + data_queue.close() diff --git a/imperative/python/megengine/data/sampler.py b/imperative/python/megengine/data/sampler.py index 1e7bcc75f..7105f16f7 100644 --- a/imperative/python/megengine/data/sampler.py +++ b/imperative/python/megengine/data/sampler.py @@ -2,6 +2,7 @@ import collections.abc import math from abc import ABC, abstractmethod +from itertools import count from typing import Any, Generator, Iterator, List, Union import numpy as np @@ -126,13 +127,15 @@ class MapSampler(Sampler): if self.world_size > 1: indices = self.scatter(indices) - step, length = self.batch_size, len(indices) - batch_index = [indices[i : i + step] for i in range(0, length, step)] + batch = [] + for idx in indices: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] - if self.drop_last and len(batch_index[-1]) < self.batch_size: - batch_index.pop() - - return iter(batch_index) + if len(batch) > 0 and not self.drop_last: + yield batch class StreamSampler(Sampler): @@ -151,10 +154,18 @@ class StreamSampler(Sampler): self.batch_size = batch_size def __iter__(self): - return self + return self.batch() - def __next__(self): - return iter(range(self.batch_size)) + def batch(self): + batch = [] + for idx in self.sample(): + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + def sample(self): + return count(start=0) class SequentialSampler(MapSampler): diff --git a/imperative/python/test/unit/data/test_dataloader.py b/imperative/python/test/unit/data/test_dataloader.py index c765feb60..a3af43164 100644 --- a/imperative/python/test/unit/data/test_dataloader.py +++ b/imperative/python/test/unit/data/test_dataloader.py @@ -1,4 +1,12 @@ # -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import math import os import platform import time @@ -7,7 +15,7 @@ import numpy as np import pytest from megengine.data.collator import Collator -from megengine.data.dataloader import DataLoader +from megengine.data.dataloader import DataLoader, get_worker_info from megengine.data.dataset import ArrayDataset, StreamDataset from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler from megengine.data.transform import ( @@ -29,14 +37,10 @@ def init_dataset(): def test_dataloader_init(): dataset = init_dataset() - with pytest.raises(ValueError): - dataloader = DataLoader(dataset, num_workers=2, divide=True) with pytest.raises(ValueError): dataloader = DataLoader(dataset, num_workers=-1) with pytest.raises(ValueError): dataloader = DataLoader(dataset, timeout=-1) - with pytest.raises(ValueError): - dataloader = DataLoader(dataset, num_workers=0, divide=True) dataloader = DataLoader(dataset) assert isinstance(dataloader.sampler, SequentialSampler) @@ -54,10 +58,8 @@ def test_dataloader_init(): class MyStream(StreamDataset): - def __init__(self, number, batch=False, error_foramt=False, block=False): + def __init__(self, number, block=False): self.number = number - self.batch = batch - self.error_format = error_foramt self.block = block def __iter__(self): @@ -65,22 +67,14 @@ class MyStream(StreamDataset): if self.block: for _ in range(10): time.sleep(1) - if self.batch: - data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8") - yield (True, (data, [cnt, cnt - self.number])) - else: - data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8") - if self.error_format: - yield (data, cnt) - else: - yield (False, (data, cnt)) + data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8") + yield (data, cnt) raise StopIteration -@pytest.mark.parametrize("batch", [True, False]) @pytest.mark.parametrize("num_workers", [0, 2]) -def test_stream_dataloader(batch, num_workers): - dataset = MyStream(100, batch=batch) +def test_stream_dataloader(num_workers): + dataset = MyStream(100) sampler = StreamSampler(batch_size=4) dataloader = DataLoader( dataset, @@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers): ) check_set = set() - for step, data in enumerate(dataloader): if step == 10: break @@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers): check_set.add(i) -def test_stream_dataloader_error(): - dataset = MyStream(100, error_foramt=True) - sampler = StreamSampler(batch_size=4) - dataloader = DataLoader(dataset, sampler) - with pytest.raises(AssertionError, match=r".*tuple.*"): - data_iter = iter(dataloader) - next(data_iter) - - @pytest.mark.parametrize("num_workers", [0, 2]) def test_stream_dataloader_timeout(num_workers): - dataset = MyStream(100, False, block=True) + dataset = MyStream(100, block=True) sampler = StreamSampler(batch_size=4) dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2) @@ -140,17 +124,6 @@ def test_dataloader_parallel(): dataset, sampler=RandomSampler(dataset, batch_size=4, drop_last=False), num_workers=2, - divide=False, - ) - for (data, label) in dataloader: - assert data.shape == (4, 1, 32, 32) - assert label.shape == (4,) - - dataloader = DataLoader( - dataset, - sampler=RandomSampler(dataset, batch_size=4, drop_last=False), - num_workers=2, - divide=True, ) for (data, label) in dataloader: assert data.shape == (4, 1, 32, 32) @@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception(): transform=FakeErrorTransform(), num_workers=2, ) - with pytest.raises(RuntimeError, match=r"worker.*died"): + with pytest.raises(RuntimeError, match=r"exited unexpectedly"): data_iter = iter(dataloader) batch_data = next(data_iter) @@ -213,26 +186,23 @@ def test_dataloader_parallel_worker_exception(): def _multi_instances_parallel_dataloader_worker(): dataset = init_dataset() - for divide_flag in [True, False]: - train_dataloader = DataLoader( - dataset, - sampler=RandomSampler(dataset, batch_size=4, drop_last=False), - num_workers=2, - divide=divide_flag, - ) - val_dataloader = DataLoader( - dataset, - sampler=RandomSampler(dataset, batch_size=10, drop_last=False), - num_workers=2, - divide=divide_flag, - ) - for idx, (data, label) in enumerate(train_dataloader): - assert data.shape == (4, 1, 32, 32) - assert label.shape == (4,) - if idx % 5 == 0: - for val_data, val_label in val_dataloader: - assert val_data.shape == (10, 1, 32, 32) - assert val_label.shape == (10,) + train_dataloader = DataLoader( + dataset, + sampler=RandomSampler(dataset, batch_size=4, drop_last=False), + num_workers=2, + ) + val_dataloader = DataLoader( + dataset, + sampler=RandomSampler(dataset, batch_size=10, drop_last=False), + num_workers=2, + ) + for idx, (data, label) in enumerate(train_dataloader): + assert data.shape == (4, 1, 32, 32) + assert label.shape == (4,) + if idx % 5 == 0: + for val_data, val_label in val_dataloader: + assert val_data.shape == (10, 1, 32, 32) + assert val_label.shape == (10,) def test_dataloader_parallel_multi_instances(): @@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing(): assert p.exitcode == 0 -@pytest.mark.parametrize("num_workers", [0, 2]) -def test_timeout_event(num_workers): - def cb(): - return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,)))) +def partition(ls, size): + return [ls[i : i + size] for i in range(0, len(ls), size)] - dataset = MyStream(100, block=True) + +class MyPreStream(StreamDataset): + def __init__(self, number, block=False): + self.number = [i for i in range(number)] + self.block = block + self.data = [] + for i in range(100): + self.data.append(np.random.randint(0, 256, (2, 2, 3), dtype="uint8")) + + def __iter__(self): + worker_info = get_worker_info() + per_worker = int(math.ceil((len(self.data)) / float(worker_info.worker))) + pre_data = iter(partition(self.data, per_worker)[worker_info.idx]) + pre_cnt = partition(self.number, per_worker)[worker_info.idx] + for cnt in pre_cnt: + if self.block: + for _ in range(10): + time.sleep(1) + yield (next(pre_data), cnt) + raise StopIteration + + +@pytest.mark.skipif( + platform.system() == "Windows", + reason="dataloader do not support parallel on windows", +) +def test_prestream_dataloader_multiprocessing(): + dataset = MyPreStream(100) sampler = StreamSampler(batch_size=4) + dataloader = DataLoader( + dataset, + sampler, + Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]), + num_workers=2, + parallel_stream=True, + ) + + check_set = set() + + for step, data in enumerate(dataloader): + if step == 10: + break + assert data[0].shape == (4, 3, 2, 2) + assert data[1].shape == (4,) + for i in data[1]: + assert i not in check_set + check_set.add(i) + + +@pytest.mark.skipif( + platform.system() == "Windows", + reason="dataloader do not support parallel on windows", +) +def test_predataloader_parallel_worker_exception(): + dataset = MyPreStream(100) + + class FakeErrorTransform(Transform): + def __init__(self): + pass + + def apply(self, input): + raise RuntimeError("test raise error") + return input dataloader = DataLoader( - dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb + dataset, + sampler=StreamSampler(batch_size=4), + transform=FakeErrorTransform(), + num_workers=2, + parallel_stream=True, ) - for _, data in enumerate(dataloader): - np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3))) - np.testing.assert_equal(data[1], np.ones(shape=(4,))) - break + with pytest.raises(RuntimeError, match=r"exited unexpectedly"): + data_iter = iter(dataloader) + batch_data = next(data_iter) + print(batch_data.shape) diff --git a/imperative/python/test/unit/data/test_pre_dataloader.py b/imperative/python/test/unit/data/test_pre_dataloader.py index 228b93a40..489cb7265 100644 --- a/imperative/python/test/unit/data/test_pre_dataloader.py +++ b/imperative/python/test/unit/data/test_pre_dataloader.py @@ -1,5 +1,13 @@ # -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import gc +import math import os import platform import time @@ -8,7 +16,7 @@ import numpy as np import pytest from megengine.data.collator import Collator -from megengine.data.dataloader import DataLoader +from megengine.data.dataloader import DataLoader, get_worker_info from megengine.data.dataset import ArrayDataset, StreamDataset from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler from megengine.data.transform import ( @@ -30,14 +38,10 @@ def init_dataset(): def test_dataloader_init(): dataset = init_dataset() - with pytest.raises(ValueError): - dataloader = DataLoader(dataset, num_workers=2, divide=True) with pytest.raises(ValueError): dataloader = DataLoader(dataset, num_workers=-1) with pytest.raises(ValueError): dataloader = DataLoader(dataset, timeout=-1) - with pytest.raises(ValueError): - dataloader = DataLoader(dataset, num_workers=0, divide=True) dataloader = DataLoader(dataset, preload=True) assert isinstance(dataloader.sampler, SequentialSampler) @@ -59,10 +63,8 @@ def test_dataloader_init(): class MyStream(StreamDataset): - def __init__(self, number, batch=False, error_foramt=False, block=False): + def __init__(self, number, block=False): self.number = number - self.batch = batch - self.error_format = error_foramt self.block = block def __iter__(self): @@ -70,22 +72,14 @@ class MyStream(StreamDataset): if self.block: for _ in range(10): time.sleep(1) - if self.batch: - data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8") - yield (True, (data, [cnt, cnt - self.number])) - else: - data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8") - if self.error_format: - yield (data, cnt) - else: - yield (False, (data, cnt)) + data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8") + yield (data, cnt) raise StopIteration -@pytest.mark.parametrize("batch", [True, False]) @pytest.mark.parametrize("num_workers", [0, 2]) -def test_stream_dataloader(batch, num_workers): - dataset = MyStream(100, batch=batch) +def test_stream_dataloader(num_workers): + dataset = MyStream(100) sampler = StreamSampler(batch_size=4) dataloader = DataLoader( dataset, @@ -107,18 +101,9 @@ def test_stream_dataloader(batch, num_workers): check_set.add(i) -def test_stream_dataloader_error(): - dataset = MyStream(100, error_foramt=True) - sampler = StreamSampler(batch_size=4) - dataloader = DataLoader(dataset, sampler, preload=True) - with pytest.raises(AssertionError, match=r".*tuple.*"): - data_iter = iter(dataloader) - next(data_iter) - - @pytest.mark.parametrize("num_workers", [0, 2]) def test_stream_dataloader_timeout(num_workers): - dataset = MyStream(100, False, block=True) + dataset = MyStream(100, block=True) sampler = StreamSampler(batch_size=4) dataloader = DataLoader( @@ -150,18 +135,6 @@ def test_dataloader_parallel(): dataset, sampler=RandomSampler(dataset, batch_size=4, drop_last=False), num_workers=2, - divide=False, - preload=True, - ) - for (data, label) in dataloader: - assert data._tuple_shape == (4, 1, 32, 32) - assert label._tuple_shape == (4,) - - dataloader = DataLoader( - dataset, - sampler=RandomSampler(dataset, batch_size=4, drop_last=False), - num_workers=2, - divide=True, preload=True, ) for (data, label) in dataloader: @@ -219,7 +192,7 @@ def test_dataloader_parallel_worker_exception(): num_workers=2, preload=True, ) - with pytest.raises(RuntimeError, match=r"worker.*died"): + with pytest.raises(RuntimeError, match=r"exited unexpectedly"): data_iter = iter(dataloader) batch_data = next(data_iter) @@ -227,28 +200,25 @@ def test_dataloader_parallel_worker_exception(): def _multi_instances_parallel_dataloader_worker(): dataset = init_dataset() - for divide_flag in [True, False]: - train_dataloader = DataLoader( - dataset, - sampler=RandomSampler(dataset, batch_size=4, drop_last=False), - num_workers=2, - divide=divide_flag, - preload=True, - ) - val_dataloader = DataLoader( - dataset, - sampler=RandomSampler(dataset, batch_size=10, drop_last=False), - num_workers=2, - divide=divide_flag, - preload=True, - ) - for idx, (data, label) in enumerate(train_dataloader): - assert data._tuple_shape == (4, 1, 32, 32) - assert label._tuple_shape == (4,) - if idx % 5 == 0: - for val_data, val_label in val_dataloader: - assert val_data._tuple_shape == (10, 1, 32, 32) - assert val_label._tuple_shape == (10,) + train_dataloader = DataLoader( + dataset, + sampler=RandomSampler(dataset, batch_size=4, drop_last=False), + num_workers=2, + preload=True, + ) + val_dataloader = DataLoader( + dataset, + sampler=RandomSampler(dataset, batch_size=10, drop_last=False), + num_workers=2, + preload=True, + ) + for idx, (data, label) in enumerate(train_dataloader): + assert data._tuple_shape == (4, 1, 32, 32) + assert label._tuple_shape == (4,) + if idx % 5 == 0: + for val_data, val_label in val_dataloader: + assert val_data._tuple_shape == (10, 1, 32, 32) + assert val_label._tuple_shape == (10,) def test_dataloader_parallel_multi_instances(): @@ -276,25 +246,3 @@ def test_dataloader_parallel_multi_instances_multiprocessing(): for p in processes: p.join() assert p.exitcode == 0 - - -@pytest.mark.parametrize("num_workers", [0, 2]) -def test_timeout_event(num_workers): - def cb(): - return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,)))) - - dataset = MyStream(100, block=True) - sampler = StreamSampler(batch_size=4) - - dataloader = DataLoader( - dataset, - sampler, - num_workers=num_workers, - timeout=2, - timeout_event=cb, - preload=True, - ) - for _, data in enumerate(dataloader): - np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3))) - np.testing.assert_equal(data[1], np.ones(shape=(4,))) - break -- GitLab