Commit d3bfb0e9, authored by Megvii Engine Team

fix(mge): fix trace exit code and reformat

GitOrigin-RevId: 145c06b7e7a7f98f40f0e7acc1b555c16f27e2ba
Parent 23b9a98f
......@@ -292,8 +292,6 @@ def remote_recv(
op = RemoteRecv()
op.key = key
op.cn = device
if isinstance(shape, Tensor):
shape = shape.numpy()
op.shape = shape
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
......
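For context on the two lines removed above: `remote_recv` must receive a host-side shape, so a device Tensor is converted via `.numpy()` before being stored on the op descriptor. A minimal standalone sketch of that conversion, with a hypothetical `_as_plain_shape` helper that is not part of the source:

import numpy as np

def _as_plain_shape(shape):
    # hypothetical helper (not in the source): convert a device Tensor
    # shape to plain host values before storing it on an op descriptor
    if hasattr(shape, "numpy"):
        shape = shape.numpy()
    return tuple(int(s) for s in np.atleast_1d(shape))

assert _as_plain_shape(np.array([2, 3])) == (2, 3)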
......@@ -191,19 +191,20 @@ class trace:
if len(ihandles) != len(args):
raise TraceMismatchError("op input size different from last time")
# check all inputs of the current op
for h, x in zip(ihandles, args):
info = self._tinfo[h]
if info.external:
if (
x.__class__ is CompiledTensorProxy
and not self._tinfo[x._CompiledTensorProxy__handle].exported
x._compiled_info is not None
and not self._tinfo[x._mixin_handle].exported
):
raise TraceMismatchError(
"failed to capture: input was an external tensor "
"last time, got an internal tensor this time"
)
if info.bound_data:
if x.__class__ is CompiledTensorProxy:
if x._compiled_info is not None:
raise TraceMismatchError(
"const capture violated: was an external tensor "
"last time, got an internal tensor this time"
......@@ -225,17 +226,17 @@ class trace:
)
info.data_setter.set_value(x._dev_tensor())
else:
if x.mixin_handle == -1:
if x._mixin_handle == -1:
if x._handle not in self._tensor_remaps:
raise TraceMismatchError(
"unexpected capture: trying to use an external tensor as "
"input, but that input was an internal tensor last time"
)
else:
x.mixin_handle = self._tensor_remaps[
x._mixin_handle = self._tensor_remaps[
x._handle
]._CompiledTensorProxy__handle
if x.mixin_handle != h:
if x._mixin_handle != h:
raise TraceMismatchError(
"mis-wiring: input edge to an data flow "
"graph node is different from last time"
......@@ -245,9 +246,10 @@ class trace:
outputs = []
for h in ohandles:
info = self._tinfo[h]
# generate the output tensor and create compiled info
y = RawTensor(info.varnode)
y._compiled_info = CompiledTensorProxy(h)
y.mixin_handle = h
y._mixin_handle = h
outputs += [y]
self._active_tensors[h] = TensorWeakRef(y)
self._output_handles.update(ohandles)
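The outputs created above are tracked only through `TensorWeakRef`. A minimal sketch of such a wrapper, assuming its sole job is to observe liveness without extending it:

import weakref

class TensorWeakRef:
    # minimal sketch: the trace stores weak references in
    # self._active_tensors so it never keeps outputs alive itself;
    # x() returning None means user code has dropped the tensor
    def __init__(self, tensor):
        self._ref = weakref.ref(tensor)

    def __call__(self):
        return self._ref()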
......@@ -260,6 +262,7 @@ class trace:
raise TraceMismatchError("trace should end here, but more op observed")
record = self._seq[self._pc]
op_, ihandles, ohandles = record
# a Const op is represented by the string "Const"
assert isinstance(op_, str) and op_ == "Const"
eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
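A self-contained restatement of the equality check above, using plain numpy values in place of the recorded `bound_data`:

import numpy as np

def const_matches(value, recorded):
    # a replayed "Const" must equal, element by element, the value
    # bound during the first trace step
    return bool(np.all(np.atleast_1d(value) == np.atleast_1d(recorded)))

assert const_matches(3, [3])
assert not const_matches([1, 2], [1, 3])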
......@@ -273,17 +276,18 @@ class trace:
outputs = [self._tinfo[h].bound_data]
return outputs
# runs in the first trace step; records op information for replay
def _record_op(self, op, inputs, outputs):
if skip_tracing:
for x in inputs:
h = getattr(x, "mixin_handle", -1)
h = getattr(x, "_mixin_handle", -1)
if h >= 0:
self._tinfo[h].data = True
return
ihandles = []
for x in inputs:
h = getattr(x, "mixin_handle", -1)
h = getattr(x, "_mixin_handle", -1)
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
h, info = self._new_handle()
info.external = True
......@@ -300,8 +304,8 @@ class trace:
h, info = self._new_handle()
ohandles.append(h)
info.external = False
x.mixin_handle = h
x.recording = True
x._mixin_handle = h
x._recording = True
x._trace_mixin_info = info
self._active_tensors[h] = TensorWeakRef(x)
if self._symbolic:
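Elsewhere in this file the recorded sequence is unpacked as `op_, ihandles, ohandles = record`, so a minimal stand-in model of what `_record_op` appends (not the source itself) is:

# minimal model: the first trace step logs every apply as an
# (op, input_handles, output_handles) triple; replay walks the same
# list with a program counter and demands identical wiring
seq = []

def record_op(op, ihandles, ohandles):
    seq.append((op, list(ihandles), list(ohandles)))

record_op("Elemwise(ADD)", [0, 1], [2])
op_, ihandles, ohandles = seq[0]
assert (ihandles, ohandles) == ([0, 1], [2])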
......@@ -312,7 +316,7 @@ class trace:
def _record_const(self, outputs):
if skip_tracing:
(x,) = outputs
h = getattr(x, "mixin_handle", -1)
h = getattr(x, "_mixin_handle", -1)
if h >= 0:
self._tinfo[h].data_read = True
return
......@@ -326,8 +330,8 @@ class trace:
info.shape = x.shape
info.bound_data = x
info.is_const = True
x.mixin_handle = h
x.recording = True
x._mixin_handle = h
x._recording = True
x._trace_mixin_info = info
if self._symbolic:
self._lazy_eval_tensors[h] = TensorWeakRef(x)
......@@ -371,6 +375,7 @@ class trace:
lazy_eval_graph.compile(*lazy_eval_links, *readers)
lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors):
# get values from lazy_eval_graph and assign them to the lazy-eval tensors
x()._handle = RawTensor(r.op.get_value())._handle
x()._reset_varnode()
......@@ -395,14 +400,14 @@ class trace:
if self._untraced:
for x in escaped_tensors:
if x():
info = self._tinfo[x().mixin_handle]
info = self._tinfo[x()._mixin_handle]
info.data_read = True
x().mixin_handle = -1
x().recording = False
x()._mixin_handle = -1
x()._recording = False
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x.mixin_handle = -1
x.recording = False
x._mixin_handle = -1
x._recording = False
if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links
):
......@@ -441,12 +446,13 @@ class trace:
if not self._untraced and self._pc != len(self._seq):
raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced:
# reset output tensors
for x in self._active_tensors.values():
if x() is not None:
x()._dev_tensor()
x()._reset_varnode()
x().mixin_handle = -1
x().recording = False
x()._mixin_handle = -1
x()._recording = False
x()._trace_mixin_info = None
try:
......@@ -470,10 +476,14 @@ class trace:
# conditionally reading a compiled tensor in an excluded region
# is permitted, so we have to assume every tensor might be read
for x in self._active_tensors.values():
info = self._tinfo[x().mixin_handle]
info.exported = True
info.data_read = True
x()._dev_tensor()
if x():
info = self._tinfo[x()._mixin_handle]
info.exported = True
info.data_read = True
else:
for x in self._active_tensors.values():
if x():
x()._dev_tensor()
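The conservative `data_read = True` above exists because an excluded region re-executes eagerly on every call and may touch any live tensor. A hedged usage-level illustration with `exclude_from_trace` (the same API imported by the tests below):

import numpy as np
from megengine import tensor
from megengine.jit import exclude_from_trace, trace

@trace(symbolic=False)
def f(x):
    x = x + 1
    with exclude_from_trace():
        # this block re-executes eagerly on every call and may read any
        # live tensor, so the trace must assume every output is read
        if x.numpy().sum() > 1e9:
            print("overflow guard hit")
    return x * 2

out = f(tensor(np.ones(3, dtype="float32")))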
def _apply_graph_options(self, graph):
......@@ -528,7 +538,6 @@ class trace:
info.varnode = opnode.outputs[0]
in_out_links += opnode.outputs[1:]
cnt_data, cnt_value, cnt_shape = 0, 0, 0
for op, ihandles, ohandles in self._seq:
if isinstance(op, str) and op == "Const":
assert len(ihandles) == 0
......@@ -604,16 +613,13 @@ class trace:
# Shape can be obtained from data, so it doesn't need its own
# output node. On the other hand, value is read separately
# to leverage eager h2d copy
cnt_data += 1
info.shape_read = False
opnode = info.data_reader = G.OutputNode(v, *in_out_links)
add_reader(opnode)
if info.value_read:
cnt_value += 1
opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
add_reader(opnode)
if info.shape_read:
cnt_shape += 1
opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
add_reader(opnode)
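A condensed sketch of the reader selection above, using the same `G.*OutputNode` classes (import as in the tests); data subsumes shape, so a dedicated shape reader is only built when the data itself is not read:

import megengine.core.tensor.megbrain_graph as G

def attach_readers(info, v, in_out_links, add_reader):
    # condensed sketch of the branch above, same names as the source
    if info.data_read:
        info.shape_read = False  # shape is recoverable from data
        info.data_reader = G.OutputNode(v, *in_out_links)
        add_reader(info.data_reader)
    if info.value_read:
        # value is read separately to leverage eager h2d copy
        info.value_reader = G.ValueOutputNode(v, *in_out_links)
        add_reader(info.value_reader)
    if info.shape_read:
        info.shape_reader = G.AttrOutputNode(v, *in_out_links)
        add_reader(info.shape_reader)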
......@@ -637,15 +643,17 @@ class trace:
self._process_inputs(*args, **kwargs)
outputs = self.__wrapped__(*args, **kwargs)
transform = False
# outputs can be None
if outputs is not None:
if not isinstance(outputs, collections.abc.Sequence):
transform = True
outputs = (outputs,)
for o in outputs:
# if an output was copied, track the newest copy in the trace data structure
if o._copied:
self._active_tensors[o.mixin_handle] = TensorWeakRef(o)
self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
if self._untraced and self._symbolic:
self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o)
self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
if self._capture_as_const:
self._process_outputs(outputs)
if transform:
......@@ -819,8 +827,8 @@ class trace:
info.device = x.device
info.dtype = x.dtype
info.shape = x.numpy().shape
x.mixin_handle = h
x.recording = True
x._mixin_handle = h
x._recording = True
x._trace_mixin_info = info
self._inputs_to_restore.append(x)
return h
......@@ -914,12 +922,12 @@ class trace:
if not isinstance(x, RawTensor):
raise TypeError("every item of return value should be tensor")
if self._untraced:
h = x.mixin_handle
h = x._mixin_handle
if h < 0:
raise RuntimeError("output is not computed from inputs")
self._output_bindings.append(h)
else:
h = x.mixin_handle
h = x._mixin_handle
if h not in self._output_handles:
raise RuntimeError("output is not computed from inputs")
if h != self._output_bindings[i]:
......@@ -938,6 +946,11 @@ class trace:
raise RuntimeError("trace is not set with profiling=True")
return json.loads(self._profiler.get())
def __del__(self):
for x in self._tinfo:
if getattr(x, "bound_data", None):
x.bound_data = None
def trace(self, *args, **kwargs):
raise NotImplementedError(
"trace is deemed unbeneficial with the new "
......
......@@ -291,7 +291,11 @@ PyObject* TensorWrapper::copied() {
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \
return m_tensor->m_trace_info.member; \
if (m_tensor->m_trace_info.member) { \
return m_tensor->m_trace_info.member; \
} else { \
Py_RETURN_NONE; \
} \
} \
void TensorWrapper::set_##member(PyObject* dest) { \
if (dest == Py_None) { \
......@@ -322,6 +326,7 @@ void TensorWrapper::set_handle(PyObject* dest) {
PyObject* TensorWrapper::shape() {
// in compiled-trace mode, get the shape from compiled_info
if (m_tensor->m_trace_info.compiled_info != nullptr) {
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
......@@ -332,15 +337,18 @@ PyObject* TensorWrapper::shape() {
}
return shp;
}
// inside trace, if the tensor's shape is needed by other operations, set shape_read = true
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
}
TensorShape shape;
if (m_tensor->m_var) {
if (m_tensor->m_var) { // get shape from m_var
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var);
if (!tshp) {
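At the Python level, the `shape_read` flag set above corresponds to user code touching `.shape` while the trace is recording. A hedged illustration:

import numpy as np
from megengine import tensor
from megengine.jit import trace

@trace(symbolic=False)
def f(x):
    # reading .shape while recording (and not skip_tracing) marks
    # shape_read = True on this tensor's trace info, so the compiled
    # graph will keep a shape reader for it
    s = x.shape
    return x + 1

out = f(tensor(np.ones((2, 3), dtype="float32")))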
......@@ -389,9 +397,11 @@ PyObject* TensorWrapper::numpy() {
}
return np_val;
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
}
if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(m_tensor->m_var);
......@@ -411,12 +421,14 @@ PyObject* TensorWrapper::numpy() {
}
return np_val.release().ptr();
}
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr;
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
mgb_assert(PyArray_Check(arr.ptr()));
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
......@@ -428,7 +440,7 @@ PyObject* TensorWrapper::varnode() {
if (m_tensor->m_var) {
return py::cast(m_tensor->m_var).release().ptr();
}
return py::none().release().ptr();
Py_RETURN_NONE;
}
void TensorWrapper::reset(PyObject* tensor) {
......@@ -465,9 +477,13 @@ PyObject* TensorWrapper::_dev_tensor(){
if (dev_tensor == Py_None) {
throw TraceReadError("raw data of this tensor is not read in trace");
}
// set m_handle to make it a real tensor
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh));
// compiled_info is no longer needed once m_handle is set
Py_DECREF(m_tensor->m_trace_info.compiled_info);
m_tensor->m_trace_info.compiled_info = nullptr;
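A minimal stand-in model (Python, not the C++ source) of the promotion implemented above: once its device value is materialized, the tensor stops being a compiled proxy:

class _FakeCompiledTensor:
    # stand-in model: a traced tensor starts with only proxy metadata
    # (compiled_info) and gains a real handle when its device value is read
    def __init__(self, compiled_info):
        self.compiled_info = compiled_info
        self.handle = None

    def materialize(self, dev_tensor):
        self.handle = dev_tensor        # now backed by real storage
        self.compiled_info = None       # proxy metadata is dropped

t = _FakeCompiledTensor(compiled_info={"handle": 7})
t.materialize(dev_tensor="device-buffer")
assert t.handle == "device-buffer" and t.compiled_info is None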
......@@ -753,8 +769,8 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::copied>("_copied")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
.def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
.def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
......
......@@ -55,7 +55,6 @@ apply_result_t apply_trace(ApplyContext& ctx) {
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
py::tuple args(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
}
......
......@@ -19,7 +19,9 @@ struct TraceInfo {
bool recording = false;
bool copied = false;
// refers to CompiledTensorProxy in tracing.py; used from the second trace step onward
PyObject* compiled_info = nullptr;
// refers to TensorInfo in tracing.py; only used during the first trace step
PyObject* trace_mixin_info = nullptr;
TraceInfo() = default;
......@@ -37,7 +39,7 @@ struct TraceInfo {
return *this;
}
~TraceInfo() {
~TraceInfo() {
Py_XDECREF(trace_mixin_info);
Py_XDECREF(compiled_info);
}
......
......@@ -14,14 +14,17 @@ import pytest
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.optimizer as optim
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine import Parameter, tensor
from megengine.autodiff import GradManager
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
from megengine.module import Module
from megengine.random import normal, uniform
......@@ -39,8 +42,48 @@ def test_trace():
np.testing.assert_equal(f(x).numpy(), y)
def test_output_copy_trace():
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter([1.0], dtype=np.float32)
def forward(self, x):
x = x * self.a
# will result in a copy of the output in grad
x = F.exp(x)
return x
net = Simple()
gm = GradManager().attach(net.parameters())
opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
@trace(symbolic=False)
def train_f1(d):
with gm:
loss = net(d)
gm.backward(loss)
opt.step().clear_grad()
return loss
@trace(symbolic=True)
def train_f2(d):
with gm:
loss = net(d)
gm.backward(loss)
opt.step().clear_grad()
return loss
for i in range(2):
y1 = train_f1(data).numpy()
y2 = train_f2(data).numpy()
np.testing.assert_equal(y1, y2)
def test_exclude_from_trace():
for symbolic in [False]:
for symbolic in [False, True]:
@trace(symbolic=symbolic)
def f(x):
......