From 243a05b4102c56bdd40f15656e9db5cacf49c982 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 29 Dec 2020 14:10:35 +0800 Subject: [PATCH] fix(mge): fix cpp trace function release GitOrigin-RevId: 73f96428216bfb4cb0c0573ff6e85c40fcf392b3 --- imperative/python/megengine/__init__.py | 4 +++- imperative/python/src/tensor.cpp | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index 204dac0f8..f4e9f4855 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -71,7 +71,7 @@ if sys.platform == "win32": kernel32.SetErrorMode(old_error_mode) -from .core._imperative_rt.core2 import sync +from .core._imperative_rt.core2 import sync, release_trace_apply_func from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func from .device import * from .logger import enable_debug_log, get_logger, set_log_file, set_log_level @@ -90,7 +90,9 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() _persistent_cache_impl_ins.reg() atexit.register(sync) +atexit.register(release_trace_apply_func) del sync +del release_trace_apply_func del _set_fork_exec_path_for_timed_func del _persistent_cache_impl_ins diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 6f68d4839..87e7deda0 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -39,6 +39,14 @@ py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing, py::object cpp_apply_backward_varnode; +void release_trace_apply_func(){ + cpp_apply_with_tracing.release(); + cpp_apply_const_with_tracing.release(); + cpp_apply_compiled_mode.release(); + cpp_apply_const_compiled_mode.release(); + cpp_apply_backward_varnode.release(); +} + #define REGISTE_APPLY_FUNC(mode) \ void set_##mode(py::object pyf) { \ mode = pybind11::reinterpret_steal(pyf); \ @@ -720,6 +728,8 @@ void init_tensor(py::module m) { py_task_q.wait_all_task_finish(); }, py::call_guard()); + + m.def("release_trace_apply_func", &release_trace_apply_func); py::handle grad_key_type = GradKeyWrapper::wrap_t::type() .def<&GradKeyWrapper::attach>("attach") -- GitLab