diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py
index 52ab83698592ab65e820b4bdf1f717667ded12c9..1f928bfc8a6890ac2ba1bcc597538d41c5bfc0b2 100644
--- a/python/paddle/fluid/dataloader/dataloader_iter.py
+++ b/python/paddle/fluid/dataloader/dataloader_iter.py
@@ -289,10 +289,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
 
         # if user exit python program when dataloader is still
         # iterating, resource may no release safely, so we
-        # add __del__ function to to CleanupFuncRegistrar
-        # to make sure __del__ is always called when program
+        # add _shutdown_on_exit function to CleanupFuncRegistrar
+        # to make sure _try_shutdown_all is always called when program
         # exit for resoure releasing safely
-        CleanupFuncRegistrar.register(self.__del__)
+        # worker join may hang in the _try_shutdown_all call in atexit
+        # because the main process is already exiting on some OS, so we
+        # add timeout=1 for the shutdown call in atexit; the shutdown
+        # call in __del__ is kept as it is
+        CleanupFuncRegistrar.register(self._shutdown_on_exit)
 
     def _init_workers(self):
         # multiprocess worker and indice queue list initial as empty
@@ -363,7 +367,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             self._indices_queues[worker_id].put(None)
             self._worker_status[worker_id] = False
 
-    def _try_shutdown_all(self):
+    def _try_shutdown_all(self, timeout=None):
         if not self._shutdown:
             try:
                 self._exit_thread_expectedly()
@@ -376,11 +380,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                 for i in range(self._num_workers):
                     self._shutdown_worker(i)
 
-                for w in self._workers:
-                    w.join()
-                for q in self._indices_queues:
-                    q.cancel_join_thread()
-                    q.close()
+                if not self._shutdown:
+                    for w in self._workers:
+                        w.join(timeout)
+                    for q in self._indices_queues:
+                        q.cancel_join_thread()
+                        q.close()
             finally:
                 core._erase_process_pids(id(self))
                 self._shutdown = True
@@ -560,6 +565,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
     def __del__(self):
         self._try_shutdown_all()
 
+    def _shutdown_on_exit(self):
+        self._try_shutdown_all(1)
+
     def __next__(self):
         try:
             # _batches_outstanding here record the total batch data number
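
For review context, here is a minimal, self-contained sketch of the shutdown pattern this patch adopts: the cleanup registered for interpreter exit joins workers with a bounded timeout, while the `__del__` path still waits indefinitely. The `Loader` class, `_worker` function, and single-worker setup below are hypothetical illustrations, not Paddle's actual DataLoader implementation.

```python
import atexit
import multiprocessing as mp


def _worker(q):
    # Hypothetical worker loop: drain the queue until a None sentinel arrives.
    while q.get() is not None:
        pass


class Loader:
    """Hypothetical sketch of the shutdown pattern, not Paddle's DataLoader."""

    def __init__(self):
        self._queue = mp.Queue()
        self._proc = mp.Process(target=_worker, args=(self._queue,), daemon=True)
        self._proc.start()
        self._shutdown = False
        # Register the bounded variant for interpreter exit, where an
        # unbounded join() may hang because the main process is exiting.
        atexit.register(self._shutdown_on_exit)

    def _try_shutdown_all(self, timeout=None):
        if not self._shutdown:
            self._queue.put(None)       # ask the worker to stop
            self._proc.join(timeout)    # timeout=None waits forever
            self._queue.cancel_join_thread()
            self._queue.close()
            self._shutdown = True

    def __del__(self):
        # Normal teardown: wait as long as needed for the worker to exit.
        self._try_shutdown_all()

    def _shutdown_on_exit(self):
        # atexit teardown: give the worker at most one second, then move on.
        self._try_shutdown_all(1)


if __name__ == "__main__":
    loader = Loader()
    loader._queue.put("work")
    # interpreter exit triggers _shutdown_on_exit via atexit
```

The point of the bounded `join(1)` in the atexit path is that the interpreter may already be tearing down resources the worker depends on, so an unbounded join can deadlock; the one-second cap lets cleanup finish even if a worker never exits cleanly, while `__del__` during normal garbage collection keeps the original unbounded wait.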