diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 4f0ec56d052c5e70816e0afccdbf3d457aad66b8..85acc0dc197f3d7db0e5a83a3f068d0de78572e4 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -350,12 +350,16 @@ class trace: self._lazy_eval_links = () def _take_escaped_tensors(self): - escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values())) + escaped_tensors = tuple( + filter(lambda x: x() is not None, self._active_tensors.values()) + ) self._active_tensors.clear() return escaped_tensors def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): - lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values())) + lazy_eval_tensors = list( + filter(lambda x: x() is not None, lazy_eval_tensors.values()) + ) readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] self._apply_graph_options(lazy_eval_graph) # FIXME @@ -443,6 +447,7 @@ class trace: x()._reset_varnode() x().mixin_handle = -1 x().recording = False + x()._trace_mixin_info = None try: do_enter() diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 40746e335bc36a492b15d2c30bf14a484b56bfd9..558e6229a84e42b593abbec40520098f1dab0b9f 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -294,8 +294,13 @@ PyObject* TensorWrapper::copied() { return m_tensor->m_trace_info.member; \ } \ void TensorWrapper::set_##member(PyObject* dest) { \ - Py_INCREF(dest); \ - m_tensor->m_trace_info.member = dest; \ + if (dest == Py_None) { \ + Py_XDECREF(m_tensor->m_trace_info.member); \ + m_tensor->m_trace_info.member = nullptr; \ + } else { \ + Py_INCREF(dest); \ + m_tensor->m_trace_info.member = dest; \ + } \ } REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info) @@ -463,6 +468,8 @@ PyObject* TensorWrapper::_dev_tensor(){ auto py_dev_tensor = py::reinterpret_borrow(dev_tensor); auto sh = interpreter_for_py->put(py_dev_tensor.cast()); m_tensor->m_handle = std::move(SharedHandle(sh)); + Py_DECREF(m_tensor->m_trace_info.compiled_info); + m_tensor->m_trace_info.compiled_info = nullptr; return dev_tensor; } diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp index 73de092bf985a668af107663c5bff00bc2cca77d..09d1c1753244a64e3fdda01074cdde6904ce6502 100644 --- a/imperative/python/src/trace.cpp +++ b/imperative/python/src/trace.cpp @@ -55,10 +55,9 @@ apply_result_t apply_trace(ApplyContext& ctx) { auto args = py::tuple(ctx.nargs + 1); args[0] = py::cast(ctx.op); + py::tuple args(ctx.nargs); for (size_t i = 0; i < ctx.nargs; i++) { - args[i + 1] = TensorWrapper::make( - std::move(std::shared_ptr(ctx.args[i]))) - .release(); + args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); } auto ret = py::reinterpret_steal( PyObject_Call(pyf, args.ptr(), nullptr)); diff --git a/imperative/python/src/trace_info.h b/imperative/python/src/trace_info.h index b7c9a0f6a0080b237afde9ed626463f39abeb466..d61d9024ca26a026eb20b67e6f357fbfe5f10a23 100644 --- a/imperative/python/src/trace_info.h +++ b/imperative/python/src/trace_info.h @@ -28,10 +28,10 @@ struct TraceInfo { mixin_handle = that.mixin_handle; recording = that.recording; - compiled_info = that.compiled_info; - Py_XINCREF(compiled_info); trace_mixin_info = that.trace_mixin_info; Py_XINCREF(trace_mixin_info); + compiled_info = that.compiled_info; + Py_XINCREF(compiled_info); copied = true; return *this; @@ -39,7 +39,7 @@ struct TraceInfo { ~TraceInfo() { Py_XDECREF(trace_mixin_info); - // Py_XDECREF(compiled_info); + Py_XDECREF(compiled_info); } private: