diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 05ae3d9618ea6072522275ae55347d33fe0d0f6c..4f0ec56d052c5e70816e0afccdbf3d457aad66b8 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -163,9 +163,9 @@ class trace: self._graph = None self._need_reset_nodes = None self._lazy_eval_graph = None - self._lazy_eval_tensors = set() + self._lazy_eval_tensors = {} self._lazy_eval_links = None - self._active_tensors = set() + self._active_tensors = {} self._tensor_remaps = None self._inputs_to_restore = None self._arg_bindings = None @@ -249,8 +249,8 @@ class trace: y._compiled_info = CompiledTensorProxy(h) y.mixin_handle = h outputs += [y] + self._active_tensors[h] = TensorWeakRef(y) self._output_handles.update(ohandles) - self._active_tensors.update([TensorWeakRef(o) for o in outputs]) return outputs def _apply_const(self, value, dtype, device): @@ -303,9 +303,11 @@ class trace: x.mixin_handle = h x.recording = True x._trace_mixin_info = info + self._active_tensors[h] = TensorWeakRef(x) + if self._symbolic: + self._lazy_eval_tensors[h] = TensorWeakRef(x) self._seq.append((op, tuple(ihandles), tuple(ohandles))) - self._active_tensors.update([TensorWeakRef(o) for o in outputs]) def _record_const(self, outputs): if skip_tracing: @@ -327,6 +329,8 @@ class trace: x.mixin_handle = h x.recording = True x._trace_mixin_info = info + if self._symbolic: + self._lazy_eval_tensors[h] = TensorWeakRef(x) self._seq.append(("Const", tuple(), tuple(ohandles))) def _set_active(self, active: bool): @@ -346,12 +350,12 @@ class trace: self._lazy_eval_links = () def _take_escaped_tensors(self): - escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors)) + escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values())) self._active_tensors.clear() return escaped_tensors def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): - lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors)) + lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values())) readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] self._apply_graph_options(lazy_eval_graph) # FIXME @@ -401,7 +405,7 @@ class trace: # eval lazy eval tensors self._lazy_eval( self._lazy_eval_graph, - tuple(self._lazy_eval_tensors), + self._lazy_eval_tensors, self._lazy_eval_links, ) self._lazy_eval_graph = None @@ -433,9 +437,10 @@ class trace: if not self._untraced and self._pc != len(self._seq): raise TraceMismatchError("premature end") if not self._symbolic or not self._untraced: - for x in self._active_tensors: + for x in self._active_tensors.values(): if x() is not None: x()._dev_tensor() + x()._reset_varnode() x().mixin_handle = -1 x().recording = False @@ -459,7 +464,7 @@ class trace: if self._untraced: # conditionally reading a compiled tensor in excluded region # is permitted, so we have to assume every tensor might be read - for x in self._active_tensors: + for x in self._active_tensors.values(): info = self._tinfo[x().mixin_handle] info.exported = True info.data_read = True @@ -626,8 +631,20 @@ class trace: if self._capture_as_const: self._process_inputs(*args, **kwargs) outputs = self.__wrapped__(*args, **kwargs) + transform = False + if outputs is not None: + if not isinstance(outputs, collections.abc.Sequence): + transform = True + outputs = (outputs,) + for o in outputs: + if o._copied: + self._active_tensors[o.mixin_handle] = TensorWeakRef(o) + if self._untraced and self._symbolic: + self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o) if self._capture_as_const: self._process_outputs(outputs) + if transform: + outputs = outputs[0] return outputs def dump( @@ -1031,7 +1048,6 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): if require_links: active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),) - active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs]) return outputs @@ -1042,7 +1058,6 @@ def apply_const_symbolic_mode(value, dtype, device): ret = RawTensor(graph.make_const(value, dtype=dtype, device=device)) if np.array(value).ndim == 0: setscalar(ret) - active_trace._lazy_eval_tensors.add(TensorWeakRef(ret)) return (ret,) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 6faa030af021668068ceeb788d6f3bbf44406818..40746e335bc36a492b15d2c30bf14a484b56bfd9 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -284,6 +284,11 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording) #undef REGISTE_TENSORWRAPPER_FUNC +PyObject* TensorWrapper::copied() { + return py::cast(m_tensor->m_trace_info.copied).release().ptr(); +} + + #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ PyObject* TensorWrapper::member() { \ return m_tensor->m_trace_info.member; \ @@ -740,6 +745,7 @@ void init_tensor(py::module m) { .def<&TensorWrapper::_drop>("_drop") .def<&TensorWrapper::reset_varnode>("_reset_varnode") .def_getset<&TensorWrapper::varnode>("_varnode") + .def_getset<&TensorWrapper::copied>("_copied") .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle") .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording") .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 6b0c22e240211f83047d1e5a3c15b3516eba87fb..3d78bb9006d16176730e41ab4f8d4228f4ebbcda 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -161,6 +161,7 @@ struct TensorWrapper { PyObject* mixin_handle(); PyObject* recording(); + PyObject* copied(); void set_mixin_handle(PyObject*); void set_recording(PyObject*); diff --git a/imperative/python/src/trace_info.h b/imperative/python/src/trace_info.h index 3ab057fc395107c002b28e7dac6c3556cdf90b3b..b7c9a0f6a0080b237afde9ed626463f39abeb466 100644 --- a/imperative/python/src/trace_info.h +++ b/imperative/python/src/trace_info.h @@ -17,6 +17,7 @@ namespace mgb::imperative::python { struct TraceInfo { int64_t mixin_handle = -1; bool recording = false; + bool copied = false; PyObject* compiled_info = nullptr; PyObject* trace_mixin_info = nullptr; @@ -32,6 +33,7 @@ struct TraceInfo { trace_mixin_info = that.trace_mixin_info; Py_XINCREF(trace_mixin_info); + copied = true; return *this; }