未验证 提交 1644926a 编写于 作者: C Chen Weihang 提交者: GitHub

Polish detail implement of dygraph data loader (#22878)

* polish detail implement of data loader, test=develop

* solve coverage ci problem, test=develop
上级 1daa6655
...@@ -129,9 +129,22 @@ void ThrowErrorIfLoadProcessFailed() { ...@@ -129,9 +129,22 @@ void ThrowErrorIfLoadProcessFailed() {
process_pid, infop.si_status)); process_pid, infop.si_status));
} else if (infop.si_code == CLD_KILLED || } else if (infop.si_code == CLD_KILLED ||
infop.si_code == CLD_DUMPED) { // killed by signal infop.si_code == CLD_DUMPED) { // killed by signal
PADDLE_THROW(platform::errors::Fatal( if (infop.si_status == SIGBUS) {
"DataLoader process (pid %ld) exited is killed by signal: %s.", PADDLE_THROW(platform::errors::Fatal(
process_pid, strsignal(infop.si_status))); "DataLoader process (pid %ld) exited is killed by signal: %s.\n"
" It may be caused by insufficient shared storage space. This "
"problem usually occurs when using docker as a development "
"environment.\n Please use command `df -h` to check the storage "
"space of `/dev/shm`. Shared storage space needs to be greater "
"than (DataLoader Num * DataLoader queue capacity * 1 batch data "
"size).\n You can solve this problem by increasing the shared "
"storage space or reducing the queue capacity appropriately.",
process_pid, strsignal(infop.si_status)));
} else {
PADDLE_THROW(platform::errors::Fatal(
"DataLoader process (pid %ld) exited is killed by signal: %s.",
process_pid, strsignal(infop.si_status)));
}
} }
} }
} }
......
...@@ -544,11 +544,12 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -544,11 +544,12 @@ class DygraphGeneratorLoader(DataLoaderBase):
# Set reader_thread # Set reader_thread
self._thread_done_event = threading.Event() self._thread_done_event = threading.Event()
self._thread = threading.Thread( self._thread = threading.Thread(
target=self._reader_thread_loop_with_process) target=self._reader_thread_loop_for_multiprocess)
self._thread.daemon = True self._thread.daemon = True
self._thread.start() self._thread.start()
else: else:
self._thread = threading.Thread(target=self._reader_thread_loop) self._thread = threading.Thread(
target=self._reader_thread_loop_for_singleprocess)
self._thread.daemon = True self._thread.daemon = True
self._thread.start() self._thread.start()
...@@ -621,7 +622,7 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -621,7 +622,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
except: except:
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
def _reader_thread_loop_with_process(self): def _reader_thread_loop_for_multiprocess(self):
while not self._thread_done_event.is_set(): while not self._thread_done_event.is_set():
try: try:
# NOTE: [ avoid hanging ] Even with carefully designed data dependencies # NOTE: [ avoid hanging ] Even with carefully designed data dependencies
...@@ -636,11 +637,15 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -636,11 +637,15 @@ class DygraphGeneratorLoader(DataLoaderBase):
# start trying to get data from queue. At this time, the child thread needs # start trying to get data from queue. At this time, the child thread needs
# to wait slightly longer # to wait slightly longer
tensor_list = self._data_queue.get(timeout=QUEUE_GET_TIMEOUT) tensor_list = self._data_queue.get(timeout=QUEUE_GET_TIMEOUT)
except queue.Empty: except:
# NOTE [ avoid handing ] After adding the shared memory mechanism, not only
# the queue.Empty exception will occur here, but other exceptions will also
# occur, such as mmap failure. If it is not handled here, it will hang.
self._exit_thread_unexpectedly() self._exit_thread_unexpectedly()
raise RuntimeError( logging.error(
"DataLoader reader thread has not read data for a long time (60s)." "DataLoader reader thread failed to read data from the multiprocessing.Queue."
) )
six.reraise(*sys.exc_info())
if not self._thread_done_event.is_set(): if not self._thread_done_event.is_set():
if tensor_list is not None: if tensor_list is not None:
...@@ -656,7 +661,7 @@ class DygraphGeneratorLoader(DataLoaderBase): ...@@ -656,7 +661,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
else: else:
self._exit_thread_expectedly() self._exit_thread_expectedly()
def _reader_thread_loop(self): def _reader_thread_loop_for_singleprocess(self):
try: try:
for sample in self._batch_reader(): for sample in self._batch_reader():
array = core.LoDTensorArray() array = core.LoDTensorArray()
......
...@@ -197,7 +197,7 @@ list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass) ...@@ -197,7 +197,7 @@ list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
if (APPLE OR WIN32) if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_dataset) list(REMOVE_ITEM TEST_OPS test_dataset)
list(REMOVE_ITEM TEST_OPS test_dataset_dataloader) list(REMOVE_ITEM TEST_OPS test_dataset_dataloader)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader) list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_base)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exception) list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exception)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_process) list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_process)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_fds_clear) list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_fds_clear)
...@@ -366,5 +366,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu ...@@ -366,5 +366,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST") test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
if(NOT WIN32 AND NOT APPLE) if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_imperative_data_loader PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE) set_tests_properties(test_imperative_data_loader_base PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
set_tests_properties(test_imperative_data_loader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
set_tests_properties(test_imperative_data_loader_fds_clear PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" RUN_SERIAL TRUE)
endif() endif()
...@@ -34,16 +34,6 @@ def sample_generator_creator(batch_size, batch_num): ...@@ -34,16 +34,6 @@ def sample_generator_creator(batch_size, batch_num):
return __reader__ return __reader__
def batch_generator_creator(batch_size, batch_num):
def __reader__():
for _ in range(batch_num):
batch_image, batch_label = get_random_images_and_labels(
[batch_size, 784], [batch_size, 1])
yield batch_image, batch_label
return __reader__
class TestDygraphDataLoader(unittest.TestCase): class TestDygraphDataLoader(unittest.TestCase):
def setUp(self): def setUp(self):
self.batch_size = 8 self.batch_size = 8
...@@ -51,7 +41,7 @@ class TestDygraphDataLoader(unittest.TestCase): ...@@ -51,7 +41,7 @@ class TestDygraphDataLoader(unittest.TestCase):
self.epoch_num = 1 self.epoch_num = 1
self.capacity = 5 self.capacity = 5
def test_single_process_reader(self): def test_single_process_loader(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator( loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, iterable=False, use_multiprocess=False) capacity=self.capacity, iterable=False, use_multiprocess=False)
...@@ -66,7 +56,7 @@ class TestDygraphDataLoader(unittest.TestCase): ...@@ -66,7 +56,7 @@ class TestDygraphDataLoader(unittest.TestCase):
self.assertEqual(label.shape, [self.batch_size, 1]) self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784]) self.assertEqual(relu.shape, [self.batch_size, 784])
def test_sample_genarator(self): def test_multi_process_loader(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator( loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True) capacity=self.capacity, use_multiprocess=True)
...@@ -81,20 +71,6 @@ class TestDygraphDataLoader(unittest.TestCase): ...@@ -81,20 +71,6 @@ class TestDygraphDataLoader(unittest.TestCase):
self.assertEqual(label.shape, [self.batch_size, 1]) self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784]) self.assertEqual(relu.shape, [self.batch_size, 784])
def test_batch_genarator(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
capacity=self.capacity, use_multiprocess=True)
loader.set_batch_generator(
batch_generator_creator(self.batch_size, self.batch_num),
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -38,7 +38,7 @@ def set_child_signal_handler(self, child_pid): ...@@ -38,7 +38,7 @@ def set_child_signal_handler(self, child_pid):
class TestDygraphDataLoaderSingalHandler(unittest.TestCase): class TestDygraphDataLoaderSingalHandler(unittest.TestCase):
def test_child_process_exit_will_error(self): def test_child_process_exit_with_error(self):
def __test_process__(): def __test_process__():
core._set_process_signal_handler() core._set_process_signal_handler()
sys.exit(1) sys.exit(1)
...@@ -69,7 +69,25 @@ class TestDygraphDataLoaderSingalHandler(unittest.TestCase): ...@@ -69,7 +69,25 @@ class TestDygraphDataLoaderSingalHandler(unittest.TestCase):
set_child_signal_handler(id(self), test_process.pid) set_child_signal_handler(id(self), test_process.pid)
time.sleep(3) time.sleep(3)
except core.EnforceNotMet as ex: except core.EnforceNotMet as ex:
self.assertIn("FatalError", cpt.get_exception_message(ex)) self.assertIn("Segmentation fault", cpt.get_exception_message(ex))
exception = ex
self.assertIsNotNone(exception)
def test_child_process_killed_by_sigbus(self):
def __test_process__():
core._set_process_signal_handler()
os.kill(os.getpid(), signal.SIGBUS)
exception = None
try:
test_process = multiprocessing.Process(target=__test_process__)
test_process.start()
set_child_signal_handler(id(self), test_process.pid)
time.sleep(3)
except core.EnforceNotMet as ex:
self.assertIn("Bus error", cpt.get_exception_message(ex))
exception = ex exception = ex
self.assertIsNotNone(exception) self.assertIsNotNone(exception)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册