未验证 提交 e93c18a3 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix dataloader exit terminate error (#34501)

* fix DataLoader exit with SIGABRT/SIGSEGV. test=develop
上级 2df74aa6
...@@ -43,6 +43,36 @@ from .flat import _flatten_batch, _restore_batch ...@@ -43,6 +43,36 @@ from .flat import _flatten_batch, _restore_batch
__all__ = ['get_worker_info'] __all__ = ['get_worker_info']
# NOTE: fix `terminate called without an active exception`
# if for loop break and program exit immediately(with no model
# layers processing) after iterate **the first few data** in
# distributed lauch mode, distributed launch will call
# terminate() to kill main process on each devices, but thread
# is still iterating to fullfill blocking queue caches, which
# may cause thread error `terminate called without an active
# exception` for terminate is a strong singal and `__del__`
# of DataLoader may not be called, so we add a global link to
# the last DataLoader instance to call `__del__` to clean up
# resources
# NOTE: cannot simply as `__del__` to CleanupFuncRegistrar,
# for this will remain a link to each DataLoader instance in
# global, and will precludes GC to auto collect DataLoader
# instance and will cause memory leak
_loader = None
def _clear_loader():
global _loader
if _loader is not None:
try:
_loader.__del__()
del _loader
except:
pass
CleanupFuncRegistrar.register(_clear_loader)
class _DataLoaderIterBase(object): class _DataLoaderIterBase(object):
""" """
...@@ -100,6 +130,16 @@ class _DataLoaderIterBase(object): ...@@ -100,6 +130,16 @@ class _DataLoaderIterBase(object):
def __len__(self): def __len__(self):
return len(self._batch_sampler) return len(self._batch_sampler)
def _exit_thread_expectedly(self):
self._thread_done_event.set()
if self._blocking_queue:
self._blocking_queue.close()
def _exit_thread_unexpectedly(self):
self._thread_done_event.set()
if self._blocking_queue:
self._blocking_queue.kill()
class _DataLoaderIterSingleProcess(_DataLoaderIterBase): class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
""" """
...@@ -125,9 +165,13 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -125,9 +165,13 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
# NOTE: len(self._places) batch data compose as an output # NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas # iteration, set blocking_queue can cache 2 iteration datas
# at most here # at most here
self._blocking_queue_capacity = 2 * len(self._places) self._blocking_queue_capacity = 1 * len(self._places)
self._init_thread() self._init_thread()
self._shutdown = False
global _loader
_loader = self
def _init_thread(self): def _init_thread(self):
self._var_names = [v.name for v in self._feed_list] self._var_names = [v.name for v in self._feed_list]
...@@ -151,7 +195,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -151,7 +195,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._thread.start() self._thread.start()
def _thread_loop(self, legacy_expected_place): def _thread_loop(self, legacy_expected_place):
try:
#NOTE(zhiqiu): Set the expected place for new thread as the same as father thread, #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread,
# and it will call platform::SetDeviceId() in c++ internally. # and it will call platform::SetDeviceId() in c++ internally.
# If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0, # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0,
...@@ -159,14 +202,28 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -159,14 +202,28 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
# APIs in this thread. # APIs in this thread.
_set_expected_place(legacy_expected_place) _set_expected_place(legacy_expected_place)
for indices in self._sampler_iter: while not self._thread_done_event.is_set():
try:
indices = next(self._sampler_iter)
# read data from dataset in mini-batch # read data from dataset in mini-batch
batch = self._dataset_fetcher.fetch(indices) # with paddle.fluid.dygraph.guard(place=paddle.CPUPlace()):
# read data from dataset in mini-batch
batch = self._dataset_fetcher.fetch(indices,
self._thread_done_event)
except StopIteration:
self._exit_thread_expectedly()
return
if batch is None or self._thread_done_event.is_set(): break
# flat batch and record structure infos # flat batch and record structure infos
batch, structure = _flatten_batch(batch) batch, structure = _flatten_batch(batch)
self._structure_infos.append(structure) self._structure_infos.append(structure)
if self._thread_done_event.is_set(): break
try:
# pack as LoDTensorArray # pack as LoDTensorArray
array = core.LoDTensorArray() array = core.LoDTensorArray()
for slot in batch: for slot in batch:
...@@ -179,22 +236,19 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -179,22 +236,19 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
array.append(slot) array.append(slot)
if not self._blocking_queue.push(array): if self._thread_done_event.is_set(): break
break
if self._thread_done_event.is_set(): try:
break self._blocking_queue.push(array)
except:
self._exit_thread_expectedly()
self._blocking_queue.close() except:
self._shutdown_thread() self._exit_thread_unexpectedly()
except StopIteration:
self._blocking_queue.close()
except Exception:
self._blocking_queue.kill()
self._shutdown_thread()
logging.warning("DataLoader reader thread raised an exception.")
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
self._exit_thread_expectedly()
def __next__(self): def __next__(self):
try: try:
if in_dygraph_mode(): if in_dygraph_mode():
...@@ -221,28 +275,46 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -221,28 +275,46 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
return data return data
except StopIteration: except StopIteration:
self._reader.shutdown() self._reader.shutdown()
self._try_shutdown_all()
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
def _shutdown_thread(self): def _shutdown_thread(self):
if self._thread: if self._thread:
self._thread_done_event.set() self._thread_done_event.set()
# NOTE: we wait for _thread exit for 3 seconds, if
# thread not exit normally, force kill it
for _ in range(3):
if self._thread.is_alive():
time.sleep(1)
else:
break
else:
if self._thread is not threading.current_thread(): if self._thread is not threading.current_thread():
self._thread.join() self._thread.join()
self._thread = None self._thread = None
# python2 compatibility # python2 compatibility
def next(self): def next(self):
return self.__next__() return self.__next__()
def __del__(self): def _try_shutdown_all(self):
# _blocking_queue in keep order mode holds sub-threads if not self._shutdown:
# need to release thread resources on unexpected exit try:
# # _blocking_queue in keep order mode holds sub-threads
# # need to release thread resources on unexpected exit
if self._blocking_queue: if self._blocking_queue:
self._blocking_queue.close() self._blocking_queue.close()
self._blocking_queue = None
# NOTE: blocking queue should be closed firstly for # NOTE: blocking queue should be closed firstly for
# blocking queue read may hang and _thread_done_event # blocking queue read may hang and _thread_done_event
# cannot be checked # cannot be checked
self._shutdown_thread() self._shutdown_thread()
finally:
self._shutdown = True
def __del__(self):
self._try_shutdown_all()
class _DataLoaderIterMultiProcess(_DataLoaderIterBase): class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
...@@ -421,15 +493,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -421,15 +493,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
core._erase_process_pids(id(self)) core._erase_process_pids(id(self))
self._shutdown = True self._shutdown = True
def _exit_thread_expectedly(self):
self._thread_done_event.set()
self._blocking_queue.close()
def _exit_thread_unexpectedly(self):
self._thread_done_event.set()
self._blocking_queue.kill()
logging.error("DataLoader reader thread raised an exception!")
def _thread_loop(self, legacy_expected_place): def _thread_loop(self, legacy_expected_place):
#NOTE(zhiqiu): Set the expected place for new thread as the same as father thread, #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread,
# and it will call platform::SetDeviceId() in c++ internally. # and it will call platform::SetDeviceId() in c++ internally.
......
...@@ -26,7 +26,16 @@ class _DatasetFetcher(object): ...@@ -26,7 +26,16 @@ class _DatasetFetcher(object):
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.drop_last = drop_last self.drop_last = drop_last
def fetch(self, batch_indices): # NOTE: fetch function here perform the whole pipeline of dataset
# reading and data trasforms of a batch in each calling, this
# may take a long time inside, if DataLoader is exit outside,
# fetch need to perceive exit situation, so we pass done_event
# here for fetch to check exit status
# NOTE: if DataLoadet exit by `break`, performing GPU tensor operations,
# e.g. to_tensor may cause SIGSEGV in thread, so we pass the
# done_event argument to check DataLoader exit status between
# ecah sample processing in the batch
def fetch(self, batch_indices, done_event=None):
raise NotImplementedError("'fetch' not implement for class {}".format( raise NotImplementedError("'fetch' not implement for class {}".format(
self.__class__.__name__)) self.__class__.__name__))
...@@ -69,15 +78,18 @@ class _IterableDatasetFetcher(_DatasetFetcher): ...@@ -69,15 +78,18 @@ class _IterableDatasetFetcher(_DatasetFetcher):
dataset, auto_collate_batch, collate_fn, drop_last) dataset, auto_collate_batch, collate_fn, drop_last)
self.dataset_iter = iter(dataset) self.dataset_iter = iter(dataset)
def fetch(self, batch_indices): def fetch(self, batch_indices, done_event=None):
if self.auto_collate_batch: if self.auto_collate_batch:
data = [] data = []
for _ in batch_indices: for _ in batch_indices:
if done_event is None or not done_event.is_set():
try: try:
data.append(next(self.dataset_iter)) data.append(next(self.dataset_iter))
except StopIteration: except StopIteration:
break break
else:
return None
if len(data) == 0 or (self.drop_last and if len(data) == 0 or (self.drop_last and
len(data) < len(batch_indices)): len(data) < len(batch_indices)):
...@@ -101,9 +113,14 @@ class _MapDatasetFetcher(_DatasetFetcher): ...@@ -101,9 +113,14 @@ class _MapDatasetFetcher(_DatasetFetcher):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch,
collate_fn, drop_last) collate_fn, drop_last)
def fetch(self, batch_indices): def fetch(self, batch_indices, done_event=None):
if self.auto_collate_batch: if self.auto_collate_batch:
data = [self.dataset[idx] for idx in batch_indices] data = []
for idx in batch_indices:
if done_event is None or not done_event.is_set():
data.append(self.dataset[idx])
else:
return None
global _WARNING_TO_LOG global _WARNING_TO_LOG
if not isinstance(data[0], (Sequence, Mapping)) \ if not isinstance(data[0], (Sequence, Mapping)) \
......
...@@ -43,14 +43,18 @@ class TestDatasetAbstract(unittest.TestCase): ...@@ -43,14 +43,18 @@ class TestDatasetAbstract(unittest.TestCase):
class TestDatasetWithDiffOutputPlace(unittest.TestCase): class TestDatasetWithDiffOutputPlace(unittest.TestCase):
def get_dataloader(self, num_workers): def get_dataloader(self, num_workers):
dataset = paddle.vision.datasets.MNIST( dataset = paddle.vision.datasets.MNIST(
mode='test', transform=transforms.ToTensor()) mode='test',
transform=transforms.Compose([
transforms.CenterCrop(20), transforms.RandomResizedCrop(14),
transforms.Normalize(), transforms.ToTensor()
]))
loader = paddle.io.DataLoader( loader = paddle.io.DataLoader(
dataset, batch_size=32, num_workers=num_workers, shuffle=True) dataset, batch_size=32, num_workers=num_workers, shuffle=True)
return loader return loader
def run_check_on_cpu(self): def run_check_on_cpu(self):
paddle.set_device('cpu') paddle.set_device('cpu')
loader = self.get_dataloader(0) loader = self.get_dataloader(1)
for image, label in loader: for image, label in loader:
self.assertTrue(image.place.is_cpu_place()) self.assertTrue(image.place.is_cpu_place())
self.assertTrue(label.place.is_cpu_place()) self.assertTrue(label.place.is_cpu_place())
...@@ -66,12 +70,7 @@ class TestDatasetWithDiffOutputPlace(unittest.TestCase): ...@@ -66,12 +70,7 @@ class TestDatasetWithDiffOutputPlace(unittest.TestCase):
for image, label in loader: for image, label in loader:
self.assertTrue(image.place.is_gpu_place()) self.assertTrue(image.place.is_gpu_place())
self.assertTrue(label.place.is_cuda_pinned_place()) self.assertTrue(label.place.is_cuda_pinned_place())
# FIXME(dkp): when input tensor is in GPU place and break
# iteration break in the median, it seems the GPU
# tensor put into blocking_queue cannot be safely
# released and may cause ABRT/SEGV, this should
# be fixed
# break
def test_multi_process(self): def test_multi_process(self):
# DataLoader with multi-process mode is not supported on MacOs and Windows currently # DataLoader with multi-process mode is not supported on MacOs and Windows currently
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册