diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 6e81e2bf83e9b0782be83435031b77648264b6de..8b7b98b32fa502791d01b7e2f2d3aaf5f1af58bc 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -34,8 +34,9 @@ if sys.version_info[0] == 2: import Queue as queue else: import queue -# NOTE: [ avoid hanging ] This value is used in getting data from another process -MP_CHECK_TIMEOUT = 10 +# NOTE: [ avoid hanging ] These value is used in getting data from another process +QUEUE_GET_TIMEOUT = 5 +MAX_GET_FAILED_TIME = 12 __all__ = ['PyReader', 'DataLoader'] @@ -485,6 +486,17 @@ class DygraphGeneratorLoader(DataLoaderBase): signal.signal(signal.SIGCHLD, __handler__) + def _exit_thread_expectedly(self): + self._thread_done_event.set() + self._blocking_queue.close() + self._data_queue.close() + + def _exit_thread_unexpectedly(self): + self._thread_done_event.set() + self._blocking_queue.kill() + self._data_queue.close() + logging.error("DataLoader reader thread raised an exception!") + def _reader_process_loop(self): try: # set signal handler @@ -506,6 +518,7 @@ class DygraphGeneratorLoader(DataLoaderBase): six.reraise(*sys.exc_info()) def _reader_thread_loop_with_process(self): + get_sample_try_time = 0 while not self._thread_done_event.is_set(): try: # NOTE: [ avoid hanging ] Even with carefully designed data dependencies @@ -513,10 +526,21 @@ class DygraphGeneratorLoader(DataLoaderBase): # still happen when data in queue is corrupted (e.g., due to # Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever # we try to get data from `data_queue` - sample = self._data_queue.get(timeout=MP_CHECK_TIMEOUT) + sample = self._data_queue.get(timeout=QUEUE_GET_TIMEOUT) + get_sample_try_time = 0 except queue.Empty: - self._thread_done_event.set() - logging.error("The reader has not read data for a long time.") + get_sample_try_time += 1 + if get_sample_try_time > MAX_GET_FAILED_TIME: + self._exit_thread_unexpectedly() + raise RuntimeError( + "DataLoader reader thread has not read data for a long time (60s)." + ) + else: + # NOTE: [ avoid failed quickly ] Sometimes if the reader child process has a heavy burden, + # the child process has no enough time to put the data in the queue when the main process + # start trying to get data from queue. At this time, failure to read data should not be + # counted as a fatal error, there should be a certain number of attempts. + continue if not self._thread_done_event.is_set(): if sample is not None: @@ -532,20 +556,10 @@ class DygraphGeneratorLoader(DataLoaderBase): if not self._blocking_queue.push(array): self._blocking_queue.close() except: - self._thread_done_event.set() - self._blocking_queue.kill() - self._data_queue.close() - logging.warning( - "DygraphDataLoader reader thread raised an exception." - ) + self._exit_thread_unexpectedly() six.reraise(*sys.exc_info()) else: - self._thread_done_event.set() - self._blocking_queue.close() - self._data_queue.close() - else: - self._blocking_queue.kill() - self._data_queue.close() + self._exit_thread_expectedly() def _reader_thread_loop(self): try: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py index 1f7c0c2f90e424fc18c5062d5a7583c8e02f915d..ca3995d602ead92e461f440ef83d9e8a6015f6d3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py @@ -13,6 +13,7 @@ # limitations under the License. import sys +import time import unittest import numpy as np import paddle.fluid as fluid @@ -20,10 +21,18 @@ from paddle.fluid import core import paddle.compat as cpt +def get_random_images_and_labels(image_shape, label_shape): + image = np.random.random(size=image_shape).astype('float32') + label = np.random.random(size=label_shape).astype('int64') + return image, label + + class TestDygraphhDataLoaderWithException(unittest.TestCase): def setUp(self): + self.batch_size = 8 self.batch_num = 4 - self.capacity = 2 + self.epoch_num = 1 + self.capacity = 5 def test_not_capacity(self): with fluid.dygraph.guard(): @@ -77,6 +86,34 @@ class TestDygraphhDataLoaderWithException(unittest.TestCase): exception = ex self.assertIsNotNone(exception) + def test_multi_process_with_get_timeout(self): + def slow_batch_generator_creator(batch_size, batch_num): + def __reader__(): + for _ in range(batch_num): + time.sleep(80) + batch_image, batch_label = get_random_images_and_labels( + [batch_size, 784], [batch_size, 1]) + yield batch_image, batch_label + + return __reader__ + + with fluid.dygraph.guard(): + loader = fluid.io.DataLoader.from_generator( + capacity=self.capacity, use_multiprocess=True) + loader.set_batch_generator( + slow_batch_generator_creator(self.batch_size, self.batch_num), + places=fluid.CPUPlace()) + exception = None + try: + for _ in range(self.epoch_num): + for image, _ in loader(): + fluid.layers.relu(image) + except core.EnforceNotMet as ex: + self.assertIn("Blocking queue is killed", + cpt.get_exception_message(ex)) + exception = ex + self.assertIsNotNone(exception) + if __name__ == '__main__': unittest.main()