提交 243a05b4 编写于 作者: M Megvii Engine Team

fix(mge): fix cpp trace function release

GitOrigin-RevId: 73f96428216bfb4cb0c0573ff6e85c40fcf392b3
上级 9fb5581f
...@@ -71,7 +71,7 @@ if sys.platform == "win32": ...@@ -71,7 +71,7 @@ if sys.platform == "win32":
kernel32.SetErrorMode(old_error_mode) kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.core2 import sync from .core._imperative_rt.core2 import sync, release_trace_apply_func
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import * from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
...@@ -90,7 +90,9 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() ...@@ -90,7 +90,9 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg() _persistent_cache_impl_ins.reg()
atexit.register(sync) atexit.register(sync)
atexit.register(release_trace_apply_func)
del sync del sync
del release_trace_apply_func
del _set_fork_exec_path_for_timed_func del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins del _persistent_cache_impl_ins
...@@ -39,6 +39,14 @@ py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing, ...@@ -39,6 +39,14 @@ py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,
py::object cpp_apply_backward_varnode; py::object cpp_apply_backward_varnode;
void release_trace_apply_func(){
cpp_apply_with_tracing.release();
cpp_apply_const_with_tracing.release();
cpp_apply_compiled_mode.release();
cpp_apply_const_compiled_mode.release();
cpp_apply_backward_varnode.release();
}
#define REGISTE_APPLY_FUNC(mode) \ #define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { \ void set_##mode(py::object pyf) { \
mode = pybind11::reinterpret_steal<py::object>(pyf); \ mode = pybind11::reinterpret_steal<py::object>(pyf); \
...@@ -720,6 +728,8 @@ void init_tensor(py::module m) { ...@@ -720,6 +728,8 @@ void init_tensor(py::module m) {
py_task_q.wait_all_task_finish(); py_task_q.wait_all_task_finish();
}, },
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("release_trace_apply_func", &release_trace_apply_func);
py::handle grad_key_type = GradKeyWrapper::wrap_t::type() py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach") .def<&GradKeyWrapper::attach>("attach")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册