Unverified commit 1644926a, authored by Chen Weihang, committed by GitHub

Polish detail implement of dygraph data loader (#22878)

* polish detail implement of data loader, test=develop

* solve coverage ci problem, test=develop
Parent 1daa6655
@@ -129,9 +129,22 @@ void ThrowErrorIfLoadProcessFailed() {
           process_pid, infop.si_status));
     } else if (infop.si_code == CLD_KILLED ||
                infop.si_code == CLD_DUMPED) {  // killed by signal
-      PADDLE_THROW(platform::errors::Fatal(
-          "DataLoader process (pid %ld) exited is killed by signal: %s.",
-          process_pid, strsignal(infop.si_status)));
+      if (infop.si_status == SIGBUS) {
+        PADDLE_THROW(platform::errors::Fatal(
+            "DataLoader process (pid %ld) exited and was killed by signal: %s.\n"
+            " It may be caused by insufficient shared storage space. This "
+            "problem usually occurs when using Docker as a development "
+            "environment.\n Please use the command `df -h` to check the storage "
+            "space of `/dev/shm`. The shared storage space needs to be greater "
+            "than (DataLoader Num * DataLoader queue capacity * 1 batch data "
+            "size).\n You can solve this problem by increasing the shared "
+            "storage space or reducing the queue capacity appropriately.",
+            process_pid, strsignal(infop.si_status)));
+      } else {
+        PADDLE_THROW(platform::errors::Fatal(
+            "DataLoader process (pid %ld) exited and was killed by signal: %s.",
+            process_pid, strsignal(infop.si_status)));
+      }
     }
   }
 }
......
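Note on the new SIGBUS branch above: each in-flight batch of the multiprocess DataLoader lives in shared memory, so the required `/dev/shm` space is roughly (DataLoader num * queue capacity * one batch size), which is why the error message suggests tuning the queue capacity. A minimal usage sketch of that knob (not part of this patch; the random reader and shapes are made up for illustration):

```python
import numpy as np
import paddle.fluid as fluid

def random_batch_reader(batch_size, batch_num):
    # Hypothetical batch generator: yields (image, label) numpy batches.
    def __reader__():
        for _ in range(batch_num):
            image = np.random.random([batch_size, 784]).astype('float32')
            label = np.random.randint(0, 10, [batch_size, 1]).astype('int64')
            yield image, label
    return __reader__

with fluid.dygraph.guard():
    # A smaller capacity keeps fewer batches alive in /dev/shm at once.
    loader = fluid.io.DataLoader.from_generator(
        capacity=2, use_multiprocess=True)
    loader.set_batch_generator(
        random_batch_reader(batch_size=8, batch_num=4),
        places=fluid.CPUPlace())
    for image, label in loader():
        pass  # consume batches as usual
```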
@@ -544,11 +544,12 @@ class DygraphGeneratorLoader(DataLoaderBase):
             # Set reader_thread
             self._thread_done_event = threading.Event()
             self._thread = threading.Thread(
-                target=self._reader_thread_loop_with_process)
+                target=self._reader_thread_loop_for_multiprocess)
             self._thread.daemon = True
             self._thread.start()
         else:
-            self._thread = threading.Thread(target=self._reader_thread_loop)
+            self._thread = threading.Thread(
+                target=self._reader_thread_loop_for_singleprocess)
             self._thread.daemon = True
             self._thread.start()
@@ -621,7 +622,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
         except:
             six.reraise(*sys.exc_info())

-    def _reader_thread_loop_with_process(self):
+    def _reader_thread_loop_for_multiprocess(self):
         while not self._thread_done_event.is_set():
             try:
                 # NOTE: [ avoid hanging ] Even with carefully designed data dependencies
@@ -636,11 +637,15 @@ class DygraphGeneratorLoader(DataLoaderBase):
                 # start trying to get data from queue. At this time, the child thread needs
                 # to wait slightly longer
                 tensor_list = self._data_queue.get(timeout=QUEUE_GET_TIMEOUT)
-            except queue.Empty:
+            except:
+                # NOTE [ avoid hanging ] After adding the shared memory mechanism, not only
+                # the queue.Empty exception will occur here, but other exceptions will also
+                # occur, such as mmap failure. If it is not handled here, it will hang.
                 self._exit_thread_unexpectedly()
-                raise RuntimeError(
-                    "DataLoader reader thread has not read data for a long time (60s)."
+                logging.error(
+                    "DataLoader reader thread failed to read data from the multiprocessing.Queue."
                 )
+                six.reraise(*sys.exc_info())

             if not self._thread_done_event.is_set():
                 if tensor_list is not None:
@@ -656,7 +661,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
             else:
                 self._exit_thread_expectedly()

-    def _reader_thread_loop(self):
+    def _reader_thread_loop_for_singleprocess(self):
         try:
             for sample in self._batch_reader():
                 array = core.LoDTensorArray()
......
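The broadened `except` in `_reader_thread_loop_for_multiprocess` above is the behavioral change of this file: once shared memory is involved, failures other than `queue.Empty` (for example a failed mmap) can surface from `Queue.get`, and a narrow handler left the reader thread hanging. A small generic sketch of the same pattern, independent of Paddle (the function and parameter names here are illustrative only):

```python
import logging
import sys

import six

QUEUE_GET_TIMEOUT = 60  # seconds, mirrors the timeout used in the hunk above

def reader_loop(data_queue, thread_done_event, consume):
    # data_queue: a multiprocessing.Queue fed by a worker process
    # thread_done_event: a threading.Event used to stop this loop
    while not thread_done_event.is_set():
        try:
            item = data_queue.get(timeout=QUEUE_GET_TIMEOUT)
        except:  # deliberately broad: timeouts, mmap errors, EOF on the queue
            logging.error("reader thread failed to read data from the queue")
            six.reraise(*sys.exc_info())
        if item is None:  # sentinel from the producer: no more data
            thread_done_event.set()
            break
        consume(item)
```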
@@ -197,7 +197,7 @@ list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
 if (APPLE OR WIN32)
     list(REMOVE_ITEM TEST_OPS test_dataset)
     list(REMOVE_ITEM TEST_OPS test_dataset_dataloader)
-    list(REMOVE_ITEM TEST_OPS test_imperative_data_loader)
+    list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_base)
     list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exception)
     list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_process)
     list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_fds_clear)
@@ -366,5 +366,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
                      test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")

 if(NOT WIN32 AND NOT APPLE)
-    set_tests_properties(test_imperative_data_loader PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
+    set_tests_properties(test_imperative_data_loader_base PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
+    set_tests_properties(test_imperative_data_loader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
+    set_tests_properties(test_imperative_data_loader_fds_clear PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
 endif()
@@ -34,16 +34,6 @@ def sample_generator_creator(batch_size, batch_num):
     return __reader__


-def batch_generator_creator(batch_size, batch_num):
-    def __reader__():
-        for _ in range(batch_num):
-            batch_image, batch_label = get_random_images_and_labels(
-                [batch_size, 784], [batch_size, 1])
-            yield batch_image, batch_label
-
-    return __reader__
-
-
 class TestDygraphDataLoader(unittest.TestCase):
     def setUp(self):
         self.batch_size = 8
@@ -51,7 +41,7 @@ class TestDygraphDataLoader(unittest.TestCase):
         self.epoch_num = 1
         self.capacity = 5

-    def test_single_process_reader(self):
+    def test_single_process_loader(self):
         with fluid.dygraph.guard():
             loader = fluid.io.DataLoader.from_generator(
                 capacity=self.capacity, iterable=False, use_multiprocess=False)
@@ -66,7 +56,7 @@ class TestDygraphDataLoader(unittest.TestCase):
                     self.assertEqual(label.shape, [self.batch_size, 1])
                     self.assertEqual(relu.shape, [self.batch_size, 784])

-    def test_sample_genarator(self):
+    def test_multi_process_loader(self):
         with fluid.dygraph.guard():
             loader = fluid.io.DataLoader.from_generator(
                 capacity=self.capacity, use_multiprocess=True)
@@ -81,20 +71,6 @@ class TestDygraphDataLoader(unittest.TestCase):
                     self.assertEqual(label.shape, [self.batch_size, 1])
                     self.assertEqual(relu.shape, [self.batch_size, 784])

-    def test_batch_genarator(self):
-        with fluid.dygraph.guard():
-            loader = fluid.io.DataLoader.from_generator(
-                capacity=self.capacity, use_multiprocess=True)
-            loader.set_batch_generator(
-                batch_generator_creator(self.batch_size, self.batch_num),
-                places=fluid.CPUPlace())
-            for _ in range(self.epoch_num):
-                for image, label in loader():
-                    relu = fluid.layers.relu(image)
-                    self.assertEqual(image.shape, [self.batch_size, 784])
-                    self.assertEqual(label.shape, [self.batch_size, 1])
-                    self.assertEqual(relu.shape, [self.batch_size, 784])
-

 if __name__ == '__main__':
     unittest.main()
@@ -38,7 +38,7 @@ def set_child_signal_handler(self, child_pid):


 class TestDygraphDataLoaderSingalHandler(unittest.TestCase):
-    def test_child_process_exit_will_error(self):
+    def test_child_process_exit_with_error(self):
         def __test_process__():
             core._set_process_signal_handler()
             sys.exit(1)
@@ -69,7 +69,25 @@ class TestDygraphDataLoaderSingalHandler(unittest.TestCase):
             set_child_signal_handler(id(self), test_process.pid)
             time.sleep(3)
         except core.EnforceNotMet as ex:
-            self.assertIn("FatalError", cpt.get_exception_message(ex))
+            self.assertIn("Segmentation fault", cpt.get_exception_message(ex))
             exception = ex

         self.assertIsNotNone(exception)

+    def test_child_process_killed_by_sigbus(self):
+        def __test_process__():
+            core._set_process_signal_handler()
+            os.kill(os.getpid(), signal.SIGBUS)
+
+        exception = None
+        try:
+            test_process = multiprocessing.Process(target=__test_process__)
+            test_process.start()
+            set_child_signal_handler(id(self), test_process.pid)
+            time.sleep(3)
+        except core.EnforceNotMet as ex:
+            self.assertIn("Bus error", cpt.get_exception_message(ex))
+            exception = ex
+
+        self.assertIsNotNone(exception)
+
......
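For reference, the essence of the new `test_child_process_killed_by_sigbus` case can be reproduced outside the Paddle test harness: on a POSIX system, a child killed by SIGBUS is observable from the parent, which is the condition the C++ handler in the first hunk detects via `waitid` and reports through `strsignal` as the "Bus error" message checked above. A minimal sketch (not Paddle code, POSIX only):

```python
import multiprocessing
import os
import signal

def _child():
    # Simulate the failure mode: the child kills itself with SIGBUS.
    os.kill(os.getpid(), signal.SIGBUS)

if __name__ == '__main__':
    p = multiprocessing.Process(target=_child)
    p.start()
    p.join()
    # multiprocessing reports death-by-signal as a negative exitcode.
    assert p.exitcode == -signal.SIGBUS
    print("child was killed by signal %d (SIGBUS)" % signal.SIGBUS)
```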