未验证 提交 5d96b6e0 编写于 作者: C Chen Weihang 提交者: GitHub

Add Queue.get delay for multiprocess data loader (#22604) (#22640)

上级 750c6f42
...@@ -34,8 +34,9 @@ if sys.version_info[0] == 2: ...@@ -34,8 +34,9 @@ if sys.version_info[0] == 2:
import Queue as queue import Queue as queue
else: else:
import queue import queue
# NOTE: [ avoid hanging ] This value is used in getting data from another process # NOTE: [ avoid hanging ] These value is used in getting data from another process
MP_CHECK_TIMEOUT = 10 QUEUE_GET_TIMEOUT = 5
MAX_GET_FAILED_TIME = 12
__all__ = ['PyReader', 'DataLoader'] __all__ = ['PyReader', 'DataLoader']
...@@ -485,6 +486,17 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -485,6 +486,17 @@ class DygraphGeneratorLoader(DataLoaderBase):
signal.signal(signal.SIGCHLD, __handler__) signal.signal(signal.SIGCHLD, __handler__)
def _exit_thread_expectedly(self):
self._thread_done_event.set()
self._blocking_queue.close()
self._data_queue.close()
def _exit_thread_unexpectedly(self):
self._thread_done_event.set()
self._blocking_queue.kill()
self._data_queue.close()
logging.error("DataLoader reader thread raised an exception!")
def _reader_process_loop(self): def _reader_process_loop(self):
try: try:
# set signal handler # set signal handler
...@@ -506,6 +518,7 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -506,6 +518,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
def _reader_thread_loop_with_process(self): def _reader_thread_loop_with_process(self):
get_sample_try_time = 0
while not self._thread_done_event.is_set(): while not self._thread_done_event.is_set():
try: try:
# NOTE: [ avoid hanging ] Even with carefully designed data dependencies # NOTE: [ avoid hanging ] Even with carefully designed data dependencies
...@@ -513,10 +526,21 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -513,10 +526,21 @@ class DygraphGeneratorLoader(DataLoaderBase):
# still happen when data in queue is corrupted (e.g., due to # still happen when data in queue is corrupted (e.g., due to
# Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever # Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever
# we try to get data from `data_queue` # we try to get data from `data_queue`
sample = self._data_queue.get(timeout=MP_CHECK_TIMEOUT) sample = self._data_queue.get(timeout=QUEUE_GET_TIMEOUT)
get_sample_try_time = 0
except queue.Empty: except queue.Empty:
self._thread_done_event.set() get_sample_try_time += 1
logging.error("The reader has not read data for a long time.") if get_sample_try_time > MAX_GET_FAILED_TIME:
self._exit_thread_unexpectedly()
raise RuntimeError(
"DataLoader reader thread has not read data for a long time (60s)."
)
else:
# NOTE: [ avoid failed quickly ] Sometimes if the reader child process has a heavy burden,
# the child process has no enough time to put the data in the queue when the main process
# start trying to get data from queue. At this time, failure to read data should not be
# counted as a fatal error, there should be a certain number of attempts.
continue
if not self._thread_done_event.is_set(): if not self._thread_done_event.is_set():
if sample is not None: if sample is not None:
...@@ -532,20 +556,10 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -532,20 +556,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
if not self._blocking_queue.push(array): if not self._blocking_queue.push(array):
self._blocking_queue.close() self._blocking_queue.close()
except: except:
self._thread_done_event.set() self._exit_thread_unexpectedly()
self._blocking_queue.kill()
self._data_queue.close()
logging.warning(
"DygraphDataLoader reader thread raised an exception."
)
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
else: else:
self._thread_done_event.set() self._exit_thread_expectedly()
self._blocking_queue.close()
self._data_queue.close()
else:
self._blocking_queue.kill()
self._data_queue.close()
def _reader_thread_loop(self): def _reader_thread_loop(self):
try: try:
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
import time
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -20,10 +21,18 @@ from paddle.fluid import core ...@@ -20,10 +21,18 @@ from paddle.fluid import core
import paddle.compat as cpt import paddle.compat as cpt
def get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
class TestDygraphhDataLoaderWithException(unittest.TestCase): class TestDygraphhDataLoaderWithException(unittest.TestCase):
def setUp(self): def setUp(self):
self.batch_size = 8
self.batch_num = 4 self.batch_num = 4
self.capacity = 2 self.epoch_num = 1
self.capacity = 5
def test_not_capacity(self): def test_not_capacity(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
...@@ -77,6 +86,34 @@ class TestDygraphhDataLoaderWithException(unittest.TestCase): ...@@ -77,6 +86,34 @@ class TestDygraphhDataLoaderWithException(unittest.TestCase):
exception = ex exception = ex
self.assertIsNotNone(exception) self.assertIsNotNone(exception)
def test_multi_process_with_get_timeout(self):
def slow_batch_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
time.sleep(80)
batch_image, batch_label = get_random_images_and_labels(
[batch_size, 784], [batch_size, 1])
yield batch_image, batch_label
return __reader__
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_batch_generator(
slow_batch_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
exception = None
try:
for _ in range(self.epoch_num):
for image, _ in loader():
fluid.layers.relu(image)
except core.EnforceNotMet as ex:
self.assertIn("Blocking queue is killed",
cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册