未验证 提交 4ccd9a0a 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix dataloader exit hang when join re-enter (#32835)

* fix dataloader exit hang when join re-enter. test=develop

* double check _shutdown. test=develop
上级 02513207
...@@ -289,10 +289,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -289,10 +289,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# if user exit python program when dataloader is still # if user exit python program when dataloader is still
# iterating, resource may no release safely, so we # iterating, resource may no release safely, so we
# add __del__ function to to CleanupFuncRegistrar # add _shutdown_on_exit function to to CleanupFuncRegistrar
# to make sure __del__ is always called when program # to make sure _try_shutdown_all is always called when program
# exit for resoure releasing safely # exit for resoure releasing safely
CleanupFuncRegistrar.register(self.__del__) # worker join may hang for in _try_shutdown_all call in atexit
# for main process is in atexit state in some OS, so we add
# timeout=1 for shutdown function call in atexit, for shutdown
# function call in __del__, we keep it as it is
CleanupFuncRegistrar.register(self._shutdown_on_exit)
def _init_workers(self): def _init_workers(self):
# multiprocess worker and indice queue list initial as empty # multiprocess worker and indice queue list initial as empty
...@@ -363,7 +367,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -363,7 +367,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._indices_queues[worker_id].put(None) self._indices_queues[worker_id].put(None)
self._worker_status[worker_id] = False self._worker_status[worker_id] = False
def _try_shutdown_all(self): def _try_shutdown_all(self, timeout=None):
if not self._shutdown: if not self._shutdown:
try: try:
self._exit_thread_expectedly() self._exit_thread_expectedly()
...@@ -376,8 +380,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -376,8 +380,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
for i in range(self._num_workers): for i in range(self._num_workers):
self._shutdown_worker(i) self._shutdown_worker(i)
if not self._shutdown:
for w in self._workers: for w in self._workers:
w.join() w.join(timeout)
for q in self._indices_queues: for q in self._indices_queues:
q.cancel_join_thread() q.cancel_join_thread()
q.close() q.close()
...@@ -560,6 +565,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -560,6 +565,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
def __del__(self): def __del__(self):
self._try_shutdown_all() self._try_shutdown_all()
def _shutdown_on_exit(self):
self._try_shutdown_all(1)
def __next__(self): def __next__(self):
try: try:
# _batches_outstanding here record the total batch data number # _batches_outstanding here record the total batch data number
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册