Unverified commit 8d531727, authored by Chen Weihang, committed by GitHub

move DataLoader._worker_loop to top level (#27247)

* move worker loop to top level

* move reader process loop to top level

* fix failed unittests
Parent aae41c6f
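For context on the change itself: under the spawn start method, `multiprocessing` pickles the `target` (and its `args`) when it starts a worker process, and a bound method such as `self._worker_loop` can only be pickled by pickling the whole iterator object, which holds unpicklable state (queues, threads, events). Moving the loop to module top level lets it be pickled by reference. A minimal sketch of the constraint, with hypothetical names that are not part of this PR:

# sketch.py -- minimal illustration, assuming Python 3 (get_context is py3-only)
import multiprocessing


def _top_level_loop(worker_id, num_workers):
    # A module-level function is pickled by reference (module name +
    # qualified name), so it survives the 'spawn' start method.
    print("worker {} of {} running".format(worker_id, num_workers))


class _Loader(object):
    def _method_loop(self):
        # target=self._method_loop would force `self` to be pickled; if the
        # instance holds locks, threads or generators, spawning fails.
        pass


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=_top_level_loop, args=(0, 1))
    p.start()
    p.join()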
@@ -347,6 +347,92 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
            return self.__next__()
# NOTE(chenweihang): _worker_loop must be a top-level function to be pickled
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
collate_fn, init_fn, worker_id, num_workers,
use_shared_memory):
try:
        # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
        # some shared memory objects may have been allocated but not yet put
        # into the inter-process Queue; these objects need to be cleaned up
        # when the process ends.
CleanupFuncRegistrar.register(_cleanup_mmap)
# set signal handler
core._set_process_signal_handler()
global _worker_info
_worker_info = WorkerInfo(
id=worker_id, num_workers=num_workers, dataset=dataset)
init_exception = None
try:
if init_fn is not None:
init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
collate_fn, True)
except:
init_exception = Exception("init_fn failed in worker {}: " \
"{}".format(worker_id, sys.exc_info()))
iterator_drained = False
parent_watch_dog = ParentWatchDog()
while parent_watch_dog.is_alive():
try:
data = indices_queue.get(MP_INDICES_CHECK_INTERVAL)
except queue.Empty:
continue
            # None as poison pill, so worker done_event should be set
if data is None:
assert done_event.is_set() or iterator_drained, \
"get None when worker done_event set"
break
            # If worker done_event is set but we still get data from
            # indices_queue, the remaining data should be fetched and skipped.
if done_event.is_set() or iterator_drained:
continue
idx, indices = data
try:
if init_exception is not None:
batch = init_exception
init_exception = None
else:
batch = fetcher.fetch(indices)
except Exception as e:
if isinstance(
e, StopIteration) and dataset_kind == _DatasetKind.ITER:
out_queue.put(_IterableDatasetStopIteration(worker_id))
iterator_drained = True
else:
out_queue.put((idx, e))
else:
if use_shared_memory:
                # FIXME(dkp): _convert_to_tensor_list only supports np.array
                # list now; it should support paddle.Tensor list as well
if isinstance(batch[0][0], paddle.Tensor):
np_batch = []
for sample in batch:
np_batch.append([s.numpy() for s in sample])
batch = np_batch
tensor_list = core._convert_to_tensor_list(batch)
out_queue.put((idx, tensor_list))
core._remove_tensor_list_mmap_fds(tensor_list)
else:
out_queue.put((idx, batch))
except KeyboardInterrupt:
        # NOTE: Main process will raise KeyboardInterrupt anyway, ignore it in child process
pass
except:
six.reraise(*sys.exc_info())
finally:
if use_shared_memory:
_cleanup_mmap()
class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
    def __init__(self, loader):
        super(_DataLoaderIterMultiProcess, self).__init__(loader)
@@ -404,11 +490,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
            indices_queue = multiprocessing.Queue()
            self._indices_queues.append(indices_queue)
            worker = multiprocessing.Process(
-                target=self._worker_loop,
+                target=_worker_loop,
                args=(self._dataset, self._dataset_kind, indices_queue,
                      self._data_queue, self._workers_done_event,
                      self._collate_fn, self._worker_init_fn, i,
-                      self._num_workers))
+                      self._num_workers, self._use_shared_memory))
            worker.daemon = True
            worker.start()
            self._workers.append(worker)
@@ -483,90 +569,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
        self._blocking_queue.kill()
        logging.error("DataLoader reader thread raised an exception!")
def _worker_loop(self, dataset, dataset_kind, indices_queue, out_queue,
done_event, collate_fn, init_fn, worker_id, num_workers):
try:
            # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
            # some shared memory objects may have been allocated but not yet put
            # into the inter-process Queue; these objects need to be cleaned up
            # when the process ends.
CleanupFuncRegistrar.register(_cleanup_mmap)
# set signal handler
core._set_process_signal_handler()
global _worker_info
_worker_info = WorkerInfo(
id=worker_id, num_workers=num_workers, dataset=dataset)
init_exception = None
try:
if init_fn is not None:
init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
collate_fn, True)
except:
init_exception = Exception("init_fn failed in worker {}: " \
"{}".format(worker_id, sys.exc_info()))
iterator_drained = False
parent_watch_dog = ParentWatchDog()
while parent_watch_dog.is_alive():
try:
data = indices_queue.get(MP_INDICES_CHECK_INTERVAL)
except queue.Empty:
continue
                # None as poison pill, so worker done_event should be set
if data is None:
assert done_event.is_set() or iterator_drained, \
"get None when worker done_event set"
break
                # If worker done_event is set but we still get data from
                # indices_queue, the remaining data should be fetched and skipped.
if done_event.is_set() or iterator_drained:
continue
idx, indices = data
try:
if init_exception is not None:
batch = init_exception
init_exception = None
else:
batch = fetcher.fetch(indices)
except Exception as e:
if isinstance(
e,
StopIteration) and dataset_kind == _DatasetKind.ITER:
out_queue.put(_IterableDatasetStopIteration(worker_id))
iterator_drained = True
else:
out_queue.put((idx, e))
else:
if self._use_shared_memory:
                    # FIXME(dkp): _convert_to_tensor_list only supports np.array
                    # list now; it should support paddle.Tensor list as well
if isinstance(batch[0][0], paddle.Tensor):
np_batch = []
for sample in batch:
np_batch.append([s.numpy() for s in sample])
batch = np_batch
tensor_list = core._convert_to_tensor_list(batch)
out_queue.put((idx, tensor_list))
core._remove_tensor_list_mmap_fds(tensor_list)
else:
out_queue.put((idx, batch))
except KeyboardInterrupt:
            # NOTE: Main process will raise KeyboardInterrupt anyway, ignore it in child process
pass
except:
six.reraise(*sys.exc_info())
finally:
if self._use_shared_memory:
_cleanup_mmap()
    def _thread_loop(self):
        while not self._thread_done_event.is_set():
            batch = self._get_data()
...
@@ -85,6 +85,30 @@ def _convert_places(places):
    return ret
# NOTE(chenweihang): _reader_process_loop must be a top-level function to be pickled
def _reader_process_loop(batch_reader, data_queue):
try:
# set signal handler
core._set_process_signal_handler()
        # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
        # some shared memory objects may have been allocated but not yet put
        # into the inter-process Queue; these objects need to be cleaned up
        # when the process ends.
CleanupFuncRegistrar.register(_cleanup_mmap)
for batch in batch_reader():
tensor_list = core._convert_to_tensor_list(batch)
data_queue.put(tensor_list)
core._remove_tensor_list_mmap_fds(tensor_list)
data_queue.put(None)
except KeyboardInterrupt:
        # NOTE: Main process will raise KeyboardInterrupt anyway, ignore it in child process
pass
except:
six.reraise(*sys.exc_info())
class DataLoaderBase(object):
    def __init__(self):
        self._places = None
@@ -811,7 +835,8 @@ class DygraphGeneratorLoader(DataLoaderBase):
            global multiprocess_queue_set
            multiprocess_queue_set.add(self._data_queue)
            self._process = multiprocessing.Process(
-                target=self._reader_process_loop)
+                target=_reader_process_loop,
+                args=(self._batch_reader, self._data_queue))
            self._process.daemon = True
            self._process.start()
@@ -867,28 +892,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
        self._blocking_queue.kill()
        logging.error("DataLoader reader thread raised an exception!")
def _reader_process_loop(self):
try:
# set signal handler
core._set_process_signal_handler()
            # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
            # some shared memory objects may have been allocated but not yet put
            # into the inter-process Queue; these objects need to be cleaned up
            # when the process ends.
CleanupFuncRegistrar.register(_cleanup_mmap)
for batch in self._batch_reader():
tensor_list = core._convert_to_tensor_list(batch)
self._data_queue.put(tensor_list)
core._remove_tensor_list_mmap_fds(tensor_list)
self._data_queue.put(None)
except KeyboardInterrupt:
            # NOTE: Main process will raise KeyboardInterrupt anyway, ignore it in child process
pass
except:
six.reraise(*sys.exc_info())
    def _reader_thread_loop_for_multiprocess(self):
        while not self._thread_done_event.is_set():
            try:
...
@@ -18,6 +18,7 @@ import multiprocessing
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
+from paddle.fluid.reader import _reader_process_loop

if sys.version_info[0] == 2:
    import Queue as queue
@@ -66,7 +67,7 @@ class TestDygraphDataLoaderProcess(unittest.TestCase):
            batch_generator_creator(self.batch_size, self.batch_num),
            places=fluid.CPUPlace())
        loader._data_queue = queue.Queue(self.batch_num + 1)
-        loader._reader_process_loop()
+        _reader_process_loop(loader._batch_reader, loader._data_queue)
        # To clean up memory mapped files
        util_queue = multiprocessing.Queue(self.batch_num + 1)
        for _ in range(self.batch_num):
@@ -94,7 +95,7 @@ class TestDygraphDataLoaderProcess(unittest.TestCase):
        loader._data_queue = queue.Queue(self.batch_num + 1)
        exception = None
        try:
-            loader._reader_process_loop()
+            _reader_process_loop(loader._batch_reader, loader._data_queue)
        except core.EnforceNotMet as ex:
            exception = ex
        self.assertIsNotNone(exception)
...
@@ -27,6 +27,7 @@ import paddle.fluid.core as core
from paddle.io import Dataset, IterableDataset, BatchSampler, DataLoader
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.base import to_variable
+from paddle.fluid.dataloader.dataloader_iter import _worker_loop


class RandomDataset(Dataset):
@@ -185,9 +186,10 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
            for i in range(10):
                indices_queue.put([i, i + 10])
            indices_queue.put(None)
-            loader._worker_loop(
-                loader._dataset, 0, indices_queue, loader._data_queue,
-                loader._workers_done_event, _collate_fn, _init_fn, 0, 1)
+            _worker_loop(loader._dataset, 0, indices_queue,
+                         loader._data_queue, loader._workers_done_event,
+                         _collate_fn, _init_fn, 0, 1,
+                         loader._use_shared_memory)
            self.assertTrue(False)
        except AssertionError:
            pass
@@ -228,9 +230,10 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
                indices_queue.put([i, i + 10])
            indices_queue.put(None)
            loader._workers_done_event.set()
-            loader._worker_loop(
-                loader._dataset, 0, indices_queue, loader._data_queue,
-                loader._workers_done_event, _collate_fn, _init_fn, 0, 1)
+            _worker_loop(loader._dataset, 0, indices_queue,
+                         loader._data_queue, loader._workers_done_event,
+                         _collate_fn, _init_fn, 0, 1,
+                         loader._use_shared_memory)
            self.assertTrue(True)
        except AssertionError:
            pass
...