Commit c70a49ed authored by Megvii Engine Team

fix(mge): correct trace outputs when grad does copy

GitOrigin-RevId: 65c8956a7df80bea78aa1ff9baa26c31490e2492
Parent: d4ada69d
@@ -163,9 +163,9 @@ class trace:
         self._graph = None
         self._need_reset_nodes = None
         self._lazy_eval_graph = None
-        self._lazy_eval_tensors = set()
+        self._lazy_eval_tensors = {}
         self._lazy_eval_links = None
-        self._active_tensors = set()
+        self._active_tensors = {}
         self._tensor_remaps = None
         self._inputs_to_restore = None
         self._arg_bindings = None
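Note on the hunk above: `_lazy_eval_tensors` and `_active_tensors` change from a `set` of weak references to a dict keyed by `mixin_handle`. When the grad transform copies an output tensor, the copy shares the original's handle, so keyed storage lets the copy overwrite the stale entry instead of accumulating next to it. A toy sketch of the difference (illustrative `Tensor` class and handle value, not MegEngine code):

```python
# Toy sketch (not MegEngine code): why a dict keyed by handle beats a set of
# weak references once grad starts copying tensors.
import weakref

class Tensor:
    def __init__(self, handle):
        self.mixin_handle = handle

t = Tensor(7)
active_set = {weakref.ref(t)}                           # old scheme
active_dict = {t.mixin_handle: weakref.ref(t)}          # new scheme

t_copy = Tensor(7)                                      # grad copies t; same handle
active_set.add(weakref.ref(t_copy))                     # two entries for one handle
active_dict[t_copy.mixin_handle] = weakref.ref(t_copy)  # copy replaces original

del t                                                   # the original dies
assert sum(r() is not None for r in active_set) == 1    # dead ref still parked
assert active_dict[7]() is t_copy                       # dict sees the live copy
```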
@@ -249,8 +249,8 @@ class trace:
             y._compiled_info = CompiledTensorProxy(h)
             y.mixin_handle = h
             outputs += [y]
+            self._active_tensors[h] = TensorWeakRef(y)
         self._output_handles.update(ohandles)
-        self._active_tensors.update([TensorWeakRef(o) for o in outputs])
         return outputs

     def _apply_const(self, value, dtype, device):
@@ -303,9 +303,11 @@ class trace:
             x.mixin_handle = h
             x.recording = True
             x._trace_mixin_info = info
+            self._active_tensors[h] = TensorWeakRef(x)
+            if self._symbolic:
+                self._lazy_eval_tensors[h] = TensorWeakRef(x)
         self._seq.append((op, tuple(ihandles), tuple(ohandles)))
-        self._active_tensors.update([TensorWeakRef(o) for o in outputs])

     def _record_const(self, outputs):
         if skip_tracing:
@@ -327,6 +329,8 @@ class trace:
             x.mixin_handle = h
             x.recording = True
             x._trace_mixin_info = info
+            if self._symbolic:
+                self._lazy_eval_tensors[h] = TensorWeakRef(x)
         self._seq.append(("Const", tuple(), tuple(ohandles)))

     def _set_active(self, active: bool):
@@ -346,12 +350,12 @@ class trace:
         self._lazy_eval_links = ()

     def _take_escaped_tensors(self):
-        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
+        escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors.values()))
         self._active_tensors.clear()
         return escaped_tensors

     def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
-        lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors))
+        lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors.values()))
         readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
         self._apply_graph_options(lazy_eval_graph)
         # FIXME
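With the dict layout, every consumer that used to iterate the set now walks `.values()` and filters out dead weak references, as `_take_escaped_tensors` and `_lazy_eval` do above. A minimal sketch of that harvest pattern (toy class, not the real tensors):

```python
# Minimal sketch of the harvest pattern: keep only weak refs that still
# resolve, then clear the registry.
import weakref

class T:
    pass

a, b = T(), T()
active = {1: weakref.ref(a), 2: weakref.ref(b)}
del b                                    # b is collected; its weak ref goes dead
escaped = tuple(filter(lambda x: x() is not None, active.values()))
active.clear()
assert len(escaped) == 1 and escaped[0]() is a
```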
@@ -401,7 +405,7 @@ class trace:
             # eval lazy eval tensors
             self._lazy_eval(
                 self._lazy_eval_graph,
-                tuple(self._lazy_eval_tensors),
+                self._lazy_eval_tensors,
                 self._lazy_eval_links,
             )
             self._lazy_eval_graph = None
@@ -433,9 +437,10 @@ class trace:
         if not self._untraced and self._pc != len(self._seq):
             raise TraceMismatchError("premature end")
         if not self._symbolic or not self._untraced:
-            for x in self._active_tensors:
+            for x in self._active_tensors.values():
                 if x() is not None:
                     x()._dev_tensor()
+                    x()._reset_varnode()
                     x().mixin_handle = -1
                     x().recording = False
@@ -459,7 +464,7 @@ class trace:
         if self._untraced:
             # conditionally reading a compiled tensor in excluded region
             # is permitted, so we have to assume every tensor might be read
-            for x in self._active_tensors:
+            for x in self._active_tensors.values():
                 info = self._tinfo[x().mixin_handle]
                 info.exported = True
                 info.data_read = True
@@ -626,8 +631,20 @@ class trace:
             if self._capture_as_const:
                 self._process_inputs(*args, **kwargs)
             outputs = self.__wrapped__(*args, **kwargs)
+            transform = False
+            if outputs is not None:
+                if not isinstance(outputs, collections.abc.Sequence):
+                    transform = True
+                    outputs = (outputs,)
+                for o in outputs:
+                    if o._copied:
+                        self._active_tensors[o.mixin_handle] = TensorWeakRef(o)
+                        if self._untraced and self._symbolic:
+                            self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o)
             if self._capture_as_const:
                 self._process_outputs(outputs)
+            if transform:
+                outputs = outputs[0]
             return outputs

     def dump(
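The `__call__` hunk above is the heart of the fix: a single tensor result is wrapped into a 1-tuple so it can be scanned, any output whose `_copied` flag is set (see the `TraceInfo` change below) is re-registered under its `mixin_handle`, and the wrapping is undone before returning. Roughly, as a standalone sketch (hypothetical `register` callback and `FakeTensor` class, not the actual method):

```python
# Standalone sketch of the normalize / re-register / unwrap flow.
import collections.abc

def fixup_outputs(outputs, register):
    transform = False
    if outputs is not None:
        if not isinstance(outputs, collections.abc.Sequence):
            transform = True              # single tensor: wrap so we can iterate
            outputs = (outputs,)
        for o in outputs:
            if getattr(o, "_copied", False):
                register(o)               # overwrite the stale entry for o's handle
    if transform:
        outputs = outputs[0]              # restore the caller-visible shape
    return outputs

class FakeTensor:
    _copied = True
    mixin_handle = 3

seen = {}
out = fixup_outputs(FakeTensor(), lambda o: seen.update({o.mixin_handle: o}))
assert isinstance(out, FakeTensor) and 3 in seen
```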
@@ -1031,7 +1048,6 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
     if require_links:
         active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

-    active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs])
     return outputs
@@ -1042,7 +1058,6 @@ def apply_const_symbolic_mode(value, dtype, device):
     ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
     if np.array(value).ndim == 0:
         setscalar(ret)
-    active_trace._lazy_eval_tensors.add(TensorWeakRef(ret))
     return (ret,)
@@ -284,6 +284,11 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)

 #undef REGISTE_TENSORWRAPPER_FUNC

+PyObject* TensorWrapper::copied() {
+    return py::cast(m_tensor->m_trace_info.copied).release().ptr();
+}
+
 #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
     PyObject* TensorWrapper::member() { \
         return m_tensor->m_trace_info.member; \
@@ -740,6 +745,7 @@ void init_tensor(py::module m) {
         .def<&TensorWrapper::_drop>("_drop")
         .def<&TensorWrapper::reset_varnode>("_reset_varnode")
         .def_getset<&TensorWrapper::varnode>("_varnode")
+        .def_getset<&TensorWrapper::copied>("_copied")
         .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
         .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording")
         .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
@@ -161,6 +161,7 @@ struct TensorWrapper {
     PyObject* mixin_handle();
     PyObject* recording();
+    PyObject* copied();

     void set_mixin_handle(PyObject*);
     void set_recording(PyObject*);
@@ -17,6 +17,7 @@ namespace mgb::imperative::python {
 struct TraceInfo {
     int64_t mixin_handle = -1;
     bool recording = false;
+    bool copied = false;

     PyObject* compiled_info = nullptr;
     PyObject* trace_mixin_info = nullptr;
@@ -32,6 +33,7 @@ struct TraceInfo {
         trace_mixin_info = that.trace_mixin_info;
         Py_XINCREF(trace_mixin_info);
+        copied = true;

         return *this;
     }
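On the C++ side, `TraceInfo` gains a `copied` flag: it defaults to false, flips to true in the copy-assignment operator (the path taken when grad duplicates a tensor), and is exposed to Python as the read-only `_copied` property checked earlier in `trace.__call__`. A rough Python analogue of the assignment semantics (a sketch; the member set is trimmed to what the hunks show):

```python
# Rough Python analogue of TraceInfo's copy-assignment: the destination is
# marked as a copy, while the source keeps copied == False.
class TraceInfo:
    def __init__(self):
        self.mixin_handle = -1
        self.recording = False
        self.copied = False
        self.trace_mixin_info = None

    def assign_from(self, that):                       # models C++ operator=
        self.trace_mixin_info = that.trace_mixin_info  # Py_XINCREF in the C++
        self.copied = True                             # this info describes a copy
        return self

src, dst = TraceInfo(), TraceInfo()
dst.assign_from(src)
assert dst.copied and not src.copied
```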