提交 243a05b4 编写于 作者: M Megvii Engine Team

fix(mge): fix cpp trace function release

GitOrigin-RevId: 73f96428216bfb4cb0c0573ff6e85c40fcf392b3
上级 9fb5581f
......@@ -71,7 +71,7 @@ if sys.platform == "win32":
kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.core2 import sync
from .core._imperative_rt.core2 import sync, release_trace_apply_func
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
......@@ -90,7 +90,9 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()
atexit.register(sync)
atexit.register(release_trace_apply_func)
del sync
del release_trace_apply_func
del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins
......@@ -39,6 +39,14 @@ py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,
py::object cpp_apply_backward_varnode;
void release_trace_apply_func(){
cpp_apply_with_tracing.release();
cpp_apply_const_with_tracing.release();
cpp_apply_compiled_mode.release();
cpp_apply_const_compiled_mode.release();
cpp_apply_backward_varnode.release();
}
#define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { \
mode = pybind11::reinterpret_steal<py::object>(pyf); \
......@@ -720,6 +728,8 @@ void init_tensor(py::module m) {
py_task_q.wait_all_task_finish();
},
py::call_guard<py::gil_scoped_release>());
m.def("release_trace_apply_func", &release_trace_apply_func);
py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册