diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index fc352f11aa84e51adb9cfdd2b6a4a4614b15c17c..cf3c988388713f6f59b2d404d3bdbf9be0a8e979 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -95,6 +95,27 @@ atexit.register(_close) del _set_fork_exec_path_for_timed_func +_exit_handlers = [] + + +def _run_exit_handlers(): + for handler in _exit_handlers: + handler() + _exit_handlers.clear() + + +atexit.register(_run_exit_handlers) + + +def _exit(code): + _run_exit_handlers() + sys.exit(code) + + +def _atexit(handler): + _exit_handlers.append(handler) + + # subpackages import megengine.autodiff import megengine.data diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index 1e2a3dff9d025413515b8d3148601deef5acf146..f88f9ab459b6da22c558a949b226a19de5a76cb1 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -11,6 +11,7 @@ import multiprocessing as mp import os import queue +from .. import _exit from ..core._imperative_rt.core2 import sync from ..logger import get_logger from .group import group_barrier, init_process_group @@ -53,6 +54,7 @@ def _run_wrapped( sync() if is_multimachine: group_barrier() + _exit(0) class launcher: