From 8d531727943f44861b17c522fbdebf3e659a905a Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 14 Sep 2020 13:51:31 +0800 Subject: [PATCH] move DataLoader._worker_loop to top level (#27247) * move worker loop to top level * move reader process loop to top level * fix failed unittests --- .../fluid/dataloader/dataloader_iter.py | 174 +++++++++--------- python/paddle/fluid/reader.py | 49 ++--- .../test_imperative_data_loader_process.py | 5 +- .../test_multiprocess_dataloader_exception.py | 15 +- 4 files changed, 126 insertions(+), 117 deletions(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 6a996493e4d..1ef0d494e07 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -347,6 +347,92 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): return self.__next__() +# NOTE(chenweihang): _worker_loop must be top level method to be pickled +def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, + collate_fn, init_fn, worker_id, num_workers, + use_shared_memory): + try: + # NOTE: [ mmap files clear ] When the child process exits unexpectedly, + # some shared memory objects may have been applied for but have not yet + # been put into the inter-process Queue. This part of the object needs + # to be cleaned up when the process ends. + CleanupFuncRegistrar.register(_cleanup_mmap) + + # set signal handler + core._set_process_signal_handler() + + global _worker_info + _worker_info = WorkerInfo( + id=worker_id, num_workers=num_workers, dataset=dataset) + + init_exception = None + try: + if init_fn is not None: + init_fn(worker_id) + fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, + collate_fn, True) + except: + init_exception = Exception("init_fn failed in worker {}: " \ + "{}".format(worker_id, sys.exc_info())) + + iterator_drained = False + parent_watch_dog = ParentWatchDog() + + while parent_watch_dog.is_alive(): + try: + data = indices_queue.get(MP_INDICES_CHECK_INTERVAL) + except queue.Empty: + continue + + # None as poison piil, so worker event should be set + if data is None: + assert done_event.is_set() or iterator_drained, \ + "get None when worker done_event set" + break + # If worker done event is set but get still get data in + # indices_queue, remaining data should be get and skipped. + if done_event.is_set() or iterator_drained: + continue + + idx, indices = data + try: + if init_exception is not None: + batch = init_exception + init_exception = None + else: + batch = fetcher.fetch(indices) + except Exception as e: + if isinstance( + e, StopIteration) and dataset_kind == _DatasetKind.ITER: + out_queue.put(_IterableDatasetStopIteration(worker_id)) + iterator_drained = True + else: + out_queue.put((idx, e)) + else: + if use_shared_memory: + # FIXME(dkp): _convert_to_tensor_list only support np.array + # list now, should support paddle.Tensor list + if isinstance(batch[0][0], paddle.Tensor): + np_batch = [] + for sample in batch: + np_batch.append([s.numpy() for s in sample]) + batch = np_batch + + tensor_list = core._convert_to_tensor_list(batch) + out_queue.put((idx, tensor_list)) + core._remove_tensor_list_mmap_fds(tensor_list) + else: + out_queue.put((idx, batch)) + except KeyboardInterrupt: + # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process + pass + except: + six.reraise(*sys.exc_info()) + finally: + if use_shared_memory: + _cleanup_mmap() + + class _DataLoaderIterMultiProcess(_DataLoaderIterBase): def __init__(self, loader): super(_DataLoaderIterMultiProcess, self).__init__(loader) @@ -404,11 +490,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): indices_queue = multiprocessing.Queue() self._indices_queues.append(indices_queue) worker = multiprocessing.Process( - target=self._worker_loop, + target=_worker_loop, args=(self._dataset, self._dataset_kind, indices_queue, self._data_queue, self._workers_done_event, self._collate_fn, self._worker_init_fn, i, - self._num_workers)) + self._num_workers, self._use_shared_memory)) worker.daemon = True worker.start() self._workers.append(worker) @@ -483,90 +569,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._blocking_queue.kill() logging.error("DataLoader reader thread raised an exception!") - def _worker_loop(self, dataset, dataset_kind, indices_queue, out_queue, - done_event, collate_fn, init_fn, worker_id, num_workers): - try: - # NOTE: [ mmap files clear ] When the child process exits unexpectedly, - # some shared memory objects may have been applied for but have not yet - # been put into the inter-process Queue. This part of the object needs - # to be cleaned up when the process ends. - CleanupFuncRegistrar.register(_cleanup_mmap) - - # set signal handler - core._set_process_signal_handler() - - global _worker_info - _worker_info = WorkerInfo( - id=worker_id, num_workers=num_workers, dataset=dataset) - - init_exception = None - try: - if init_fn is not None: - init_fn(worker_id) - fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, - collate_fn, True) - except: - init_exception = Exception("init_fn failed in worker {}: " \ - "{}".format(worker_id, sys.exc_info())) - - iterator_drained = False - parent_watch_dog = ParentWatchDog() - - while parent_watch_dog.is_alive(): - try: - data = indices_queue.get(MP_INDICES_CHECK_INTERVAL) - except queue.Empty: - continue - - # None as poison piil, so worker event should be set - if data is None: - assert done_event.is_set() or iterator_drained, \ - "get None when worker done_event set" - break - # If worker done event is set but get still get data in - # indices_queue, remaining data should be get and skipped. - if done_event.is_set() or iterator_drained: - continue - - idx, indices = data - try: - if init_exception is not None: - batch = init_exception - init_exception = None - else: - batch = fetcher.fetch(indices) - except Exception as e: - if isinstance( - e, - StopIteration) and dataset_kind == _DatasetKind.ITER: - out_queue.put(_IterableDatasetStopIteration(worker_id)) - iterator_drained = True - else: - out_queue.put((idx, e)) - else: - if self._use_shared_memory: - # FIXME(dkp): _convert_to_tensor_list only support np.array - # list now, should support paddle.Tensor list - if isinstance(batch[0][0], paddle.Tensor): - np_batch = [] - for sample in batch: - np_batch.append([s.numpy() for s in sample]) - batch = np_batch - - tensor_list = core._convert_to_tensor_list(batch) - out_queue.put((idx, tensor_list)) - core._remove_tensor_list_mmap_fds(tensor_list) - else: - out_queue.put((idx, batch)) - except KeyboardInterrupt: - # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process - pass - except: - six.reraise(*sys.exc_info()) - finally: - if self._use_shared_memory: - _cleanup_mmap() - def _thread_loop(self): while not self._thread_done_event.is_set(): batch = self._get_data() diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 76c95be75d6..f2bb567b95b 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -85,6 +85,30 @@ def _convert_places(places): return ret +# NOTE(chenweihang): _reader_process_loop must be top level method to be pickled +def _reader_process_loop(batch_reader, data_queue): + try: + # set signal handler + core._set_process_signal_handler() + + # NOTE: [ mmap files clear ] When the child process exits unexpectedly, + # some shared memory objects may have been applied for but have not yet + # been put into the inter-process Queue. This part of the object needs + # to be cleaned up when the process ends. + CleanupFuncRegistrar.register(_cleanup_mmap) + + for batch in batch_reader(): + tensor_list = core._convert_to_tensor_list(batch) + data_queue.put(tensor_list) + core._remove_tensor_list_mmap_fds(tensor_list) + data_queue.put(None) + except KeyboardInterrupt: + # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process + pass + except: + six.reraise(*sys.exc_info()) + + class DataLoaderBase(object): def __init__(self): self._places = None @@ -811,7 +835,8 @@ class DygraphGeneratorLoader(DataLoaderBase): global multiprocess_queue_set multiprocess_queue_set.add(self._data_queue) self._process = multiprocessing.Process( - target=self._reader_process_loop) + target=_reader_process_loop, + args=(self._batch_reader, self._data_queue)) self._process.daemon = True self._process.start() @@ -867,28 +892,6 @@ class DygraphGeneratorLoader(DataLoaderBase): self._blocking_queue.kill() logging.error("DataLoader reader thread raised an exception!") - def _reader_process_loop(self): - try: - # set signal handler - core._set_process_signal_handler() - - # NOTE: [ mmap files clear ] When the child process exits unexpectedly, - # some shared memory objects may have been applied for but have not yet - # been put into the inter-process Queue. This part of the object needs - # to be cleaned up when the process ends. - CleanupFuncRegistrar.register(_cleanup_mmap) - - for batch in self._batch_reader(): - tensor_list = core._convert_to_tensor_list(batch) - self._data_queue.put(tensor_list) - core._remove_tensor_list_mmap_fds(tensor_list) - self._data_queue.put(None) - except KeyboardInterrupt: - # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process - pass - except: - six.reraise(*sys.exc_info()) - def _reader_thread_loop_for_multiprocess(self): while not self._thread_done_event.is_set(): try: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_process.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_process.py index 7fb2cb0090d..9b2d71c9f90 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_process.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_process.py @@ -18,6 +18,7 @@ import multiprocessing import numpy as np import paddle.fluid as fluid from paddle.fluid import core +from paddle.fluid.reader import _reader_process_loop if sys.version_info[0] == 2: import Queue as queue @@ -66,7 +67,7 @@ class TestDygraphDataLoaderProcess(unittest.TestCase): batch_generator_creator(self.batch_size, self.batch_num), places=fluid.CPUPlace()) loader._data_queue = queue.Queue(self.batch_num + 1) - loader._reader_process_loop() + _reader_process_loop(loader._batch_reader, loader._data_queue) # For clean memory mapped files util_queue = multiprocessing.Queue(self.batch_num + 1) for _ in range(self.batch_num): @@ -94,7 +95,7 @@ class TestDygraphDataLoaderProcess(unittest.TestCase): loader._data_queue = queue.Queue(self.batch_num + 1) exception = None try: - loader._reader_process_loop() + _reader_process_loop(loader._batch_reader, loader._data_queue) except core.EnforceNotMet as ex: exception = ex self.assertIsNotNone(exception) diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py index 3a8867f6bd2..6fd14b40bc9 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py @@ -27,6 +27,7 @@ import paddle.fluid.core as core from paddle.io import Dataset, IterableDataset, BatchSampler, DataLoader from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.dataloader.dataloader_iter import _worker_loop class RandomDataset(Dataset): @@ -185,9 +186,10 @@ class TestDataLoaderWorkerLoop(unittest.TestCase): for i in range(10): indices_queue.put([i, i + 10]) indices_queue.put(None) - loader._worker_loop( - loader._dataset, 0, indices_queue, loader._data_queue, - loader._workers_done_event, _collate_fn, _init_fn, 0, 1) + _worker_loop(loader._dataset, 0, indices_queue, + loader._data_queue, loader._workers_done_event, + _collate_fn, _init_fn, 0, 1, + loader._use_shared_memory) self.assertTrue(False) except AssertionError: pass @@ -228,9 +230,10 @@ class TestDataLoaderWorkerLoop(unittest.TestCase): indices_queue.put([i, i + 10]) indices_queue.put(None) loader._workers_done_event.set() - loader._worker_loop( - loader._dataset, 0, indices_queue, loader._data_queue, - loader._workers_done_event, _collate_fn, _init_fn, 0, 1) + _worker_loop(loader._dataset, 0, indices_queue, + loader._data_queue, loader._workers_done_event, + _collate_fn, _init_fn, 0, 1, + loader._use_shared_memory) self.assertTrue(True) except AssertionError: pass -- GitLab