diff --git a/ppdet/data/parallel_map.py b/ppdet/data/parallel_map.py
index 7517157cdbeaf597fc7c5cb4c57db4153c90fbf5..c9d74880e35ada3c887d8fe2c17bb16556c1aa67 100644
--- a/ppdet/data/parallel_map.py
+++ b/ppdet/data/parallel_map.py
@@ -35,6 +35,8 @@ import traceback
 
 logger = logging.getLogger(__name__)
 
+worker_set = set()
+
 
 class EndSignal(object):
     """ signal used to notify worker to exit
@@ -120,6 +122,7 @@ class ParallelMap(object):
         self._consumers = []
         self._consumer_endsig = {}
 
+        global worker_set
         for i in range(consumer_num):
             consumer_id = 'consumer-' + id + '-' + str(i)
             p = Worker(
@@ -128,6 +131,7 @@ class ParallelMap(object):
             self._consumers.append(p)
             p.daemon = True
             setattr(p, 'id', consumer_id)
+            worker_set.add(p)
 
         self._epoch = -1
         self._feeding_ev = Event()
@@ -279,16 +283,21 @@ class ParallelMap(object):
         self._feeding_ev.set()
 
 
-# FIXME(dengkaipeng): fix me if you have better impliment
+# FIXME: fix me if you have a better implementation
 # handle terminate reader process, do not print stack frame
 signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
 
 
-def _term_group(sig_num, frame):
-    pid = os.getpid()
-    pg = os.getpgid(os.getpid())
-    logger.info("main proc {} exit, kill process group " "{}".format(pid, pg))
-    os.killpg(pg, signal.SIGKILL)
+def _term_workers(sig_num, frame):
+    """SIGINT handler: log, then SIGKILL every worker in worker_set."""
+    global worker_set
+    pid = os.getpid()
+    # os.kill needs an integer pid; skip workers with no usable pid yet
+    worker_pids = [p for p in (getattr(w, 'pid', None) for w in worker_set) if p is not None]
+    logger.info("main proc {} exit, kill subprocess {}".format(
+        pid, worker_pids))
+    for worker_pid in worker_pids:
+        os.kill(worker_pid, signal.SIGKILL)
 
 
-signal.signal(signal.SIGINT, _term_group)
+signal.signal(signal.SIGINT, _term_workers)