提交 d3bfb0e9 编写于 作者: M Megvii Engine Team

fix(mge): fix trace exit code and reformat

GitOrigin-RevId: 145c06b7e7a7f98f40f0e7acc1b555c16f27e2ba
上级 23b9a98f
...@@ -292,8 +292,6 @@ def remote_recv( ...@@ -292,8 +292,6 @@ def remote_recv(
op = RemoteRecv() op = RemoteRecv()
op.key = key op.key = key
op.cn = device op.cn = device
if isinstance(shape, Tensor):
shape = shape.numpy()
op.shape = shape op.shape = shape
op.dtype = dtype op.dtype = dtype
op.addr, op.port = get_mm_server_addr() op.addr, op.port = get_mm_server_addr()
......
...@@ -191,19 +191,20 @@ class trace: ...@@ -191,19 +191,20 @@ class trace:
if len(ihandles) != len(args): if len(ihandles) != len(args):
raise TraceMismatchError("op input size different from last time") raise TraceMismatchError("op input size different from last time")
# check all inputs of crrent op
for h, x in zip(ihandles, args): for h, x in zip(ihandles, args):
info = self._tinfo[h] info = self._tinfo[h]
if info.external: if info.external:
if ( if (
x.__class__ is CompiledTensorProxy x._compiled_info is not None
and not self._tinfo[x._CompiledTensorProxy__handle].exported and not self._tinfo[x._mixin_handle].exported
): ):
raise TraceMismatchError( raise TraceMismatchError(
"failed to capture: input was an external tensor " "failed to capture: input was an external tensor "
"last time, got an internal tensor this time" "last time, got an internal tensor this time"
) )
if info.bound_data: if info.bound_data:
if x.__class__ is CompiledTensorProxy: if x._compiled_info is not None:
raise TraceMismatchError( raise TraceMismatchError(
"const capture violated: was an external tensor " "const capture violated: was an external tensor "
"last time, got an internal tensor this time" "last time, got an internal tensor this time"
...@@ -225,17 +226,17 @@ class trace: ...@@ -225,17 +226,17 @@ class trace:
) )
info.data_setter.set_value(x._dev_tensor()) info.data_setter.set_value(x._dev_tensor())
else: else:
if x.mixin_handle == -1: if x._mixin_handle == -1:
if x._handle not in self._tensor_remaps: if x._handle not in self._tensor_remaps:
raise TraceMismatchError( raise TraceMismatchError(
"unexpected capture: trying to use an external tensor as " "unexpected capture: trying to use an external tensor as "
"input, but that input was an internal tensor last time" "input, but that input was an internal tensor last time"
) )
else: else:
x.mixin_handle = self._tensor_remaps[ x._mixin_handle = self._tensor_remaps[
x._handle x._handle
]._CompiledTensorProxy__handle ]._CompiledTensorProxy__handle
if x.mixin_handle != h: if x._mixin_handle != h:
raise TraceMismatchError( raise TraceMismatchError(
"mis-wiring: input edge to an data flow " "mis-wiring: input edge to an data flow "
"graph node is different from last time" "graph node is different from last time"
...@@ -245,9 +246,10 @@ class trace: ...@@ -245,9 +246,10 @@ class trace:
outputs = [] outputs = []
for h in ohandles: for h in ohandles:
info = self._tinfo[h] info = self._tinfo[h]
# generate output tensor and create compied info
y = RawTensor(info.varnode) y = RawTensor(info.varnode)
y._compiled_info = CompiledTensorProxy(h) y._compiled_info = CompiledTensorProxy(h)
y.mixin_handle = h y._mixin_handle = h
outputs += [y] outputs += [y]
self._active_tensors[h] = TensorWeakRef(y) self._active_tensors[h] = TensorWeakRef(y)
self._output_handles.update(ohandles) self._output_handles.update(ohandles)
...@@ -260,6 +262,7 @@ class trace: ...@@ -260,6 +262,7 @@ class trace:
raise TraceMismatchError("trace should end here, but more op observed") raise TraceMismatchError("trace should end here, but more op observed")
record = self._seq[self._pc] record = self._seq[self._pc]
op_, ihandles, ohandles = record op_, ihandles, ohandles = record
# Const op is represented by a str
assert isinstance(op_, str) and op_ == "Const" assert isinstance(op_, str) and op_ == "Const"
eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy()) eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
...@@ -273,17 +276,18 @@ class trace: ...@@ -273,17 +276,18 @@ class trace:
outputs = [self._tinfo[h].bound_data] outputs = [self._tinfo[h].bound_data]
return outputs return outputs
# run in first step, record information for trace
def _record_op(self, op, inputs, outputs): def _record_op(self, op, inputs, outputs):
if skip_tracing: if skip_tracing:
for x in inputs: for x in inputs:
h = getattr(x, "mixin_handle", -1) h = getattr(x, "_mixin_handle", -1)
if h >= 0: if h >= 0:
self._tinfo[h].data = True self._tinfo[h].data = True
return return
ihandles = [] ihandles = []
for x in inputs: for x in inputs:
h = getattr(x, "mixin_handle", -1) h = getattr(x, "_mixin_handle", -1)
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
h, info = self._new_handle() h, info = self._new_handle()
info.external = True info.external = True
...@@ -300,8 +304,8 @@ class trace: ...@@ -300,8 +304,8 @@ class trace:
h, info = self._new_handle() h, info = self._new_handle()
ohandles.append(h) ohandles.append(h)
info.external = False info.external = False
x.mixin_handle = h x._mixin_handle = h
x.recording = True x._recording = True
x._trace_mixin_info = info x._trace_mixin_info = info
self._active_tensors[h] = TensorWeakRef(x) self._active_tensors[h] = TensorWeakRef(x)
if self._symbolic: if self._symbolic:
...@@ -312,7 +316,7 @@ class trace: ...@@ -312,7 +316,7 @@ class trace:
def _record_const(self, outputs): def _record_const(self, outputs):
if skip_tracing: if skip_tracing:
(x,) = outputs (x,) = outputs
h = getattr(x, "mixin_handle", -1) h = getattr(x, "_mixin_handle", -1)
if h >= 0: if h >= 0:
self._tinfo[h].data_read = True self._tinfo[h].data_read = True
return return
...@@ -326,8 +330,8 @@ class trace: ...@@ -326,8 +330,8 @@ class trace:
info.shape = x.shape info.shape = x.shape
info.bound_data = x info.bound_data = x
info.is_const = True info.is_const = True
x.mixin_handle = h x._mixin_handle = h
x.recording = True x._recording = True
x._trace_mixin_info = info x._trace_mixin_info = info
if self._symbolic: if self._symbolic:
self._lazy_eval_tensors[h] = TensorWeakRef(x) self._lazy_eval_tensors[h] = TensorWeakRef(x)
...@@ -371,6 +375,7 @@ class trace: ...@@ -371,6 +375,7 @@ class trace:
lazy_eval_graph.compile(*lazy_eval_links, *readers) lazy_eval_graph.compile(*lazy_eval_links, *readers)
lazy_eval_graph() lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors): for r, x in zip(readers, lazy_eval_tensors):
# get values from lazy_eval_graph and assign to lazy_eval tensor
x()._handle = RawTensor(r.op.get_value())._handle x()._handle = RawTensor(r.op.get_value())._handle
x()._reset_varnode() x()._reset_varnode()
...@@ -395,14 +400,14 @@ class trace: ...@@ -395,14 +400,14 @@ class trace:
if self._untraced: if self._untraced:
for x in escaped_tensors: for x in escaped_tensors:
if x(): if x():
info = self._tinfo[x().mixin_handle] info = self._tinfo[x()._mixin_handle]
info.data_read = True info.data_read = True
x().mixin_handle = -1 x()._mixin_handle = -1
x().recording = False x()._recording = False
if self._inputs_to_restore: if self._inputs_to_restore:
for x in self._inputs_to_restore: for x in self._inputs_to_restore:
x.mixin_handle = -1 x._mixin_handle = -1
x.recording = False x._recording = False
if self._symbolic and ( if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links self._lazy_eval_tensors or self._lazy_eval_links
): ):
...@@ -441,12 +446,13 @@ class trace: ...@@ -441,12 +446,13 @@ class trace:
if not self._untraced and self._pc != len(self._seq): if not self._untraced and self._pc != len(self._seq):
raise TraceMismatchError("premature end") raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced: if not self._symbolic or not self._untraced:
# reset output tensors
for x in self._active_tensors.values(): for x in self._active_tensors.values():
if x() is not None: if x() is not None:
x()._dev_tensor() x()._dev_tensor()
x()._reset_varnode() x()._reset_varnode()
x().mixin_handle = -1 x()._mixin_handle = -1
x().recording = False x()._recording = False
x()._trace_mixin_info = None x()._trace_mixin_info = None
try: try:
...@@ -470,9 +476,13 @@ class trace: ...@@ -470,9 +476,13 @@ class trace:
# conditionally reading a compiled tensor in excluded region # conditionally reading a compiled tensor in excluded region
# is permitted, so we have to assume every tensor might be read # is permitted, so we have to assume every tensor might be read
for x in self._active_tensors.values(): for x in self._active_tensors.values():
info = self._tinfo[x().mixin_handle] if x():
info = self._tinfo[x()._mixin_handle]
info.exported = True info.exported = True
info.data_read = True info.data_read = True
else:
for x in self._active_tensors.values():
if x():
x()._dev_tensor() x()._dev_tensor()
def _apply_graph_options(self, graph): def _apply_graph_options(self, graph):
...@@ -528,7 +538,6 @@ class trace: ...@@ -528,7 +538,6 @@ class trace:
info.varnode = opnode.outputs[0] info.varnode = opnode.outputs[0]
in_out_links += opnode.outputs[1:] in_out_links += opnode.outputs[1:]
cnt_data, cnt_value, cnt_shape = 0, 0, 0
for op, ihandles, ohandles in self._seq: for op, ihandles, ohandles in self._seq:
if isinstance(op, str) and op == "Const": if isinstance(op, str) and op == "Const":
assert len(ihandles) == 0 assert len(ihandles) == 0
...@@ -604,16 +613,13 @@ class trace: ...@@ -604,16 +613,13 @@ class trace:
# Shape can be obtained from data so doesn't need its own # Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately # output node. On the other hand, value is read separately
# to leverage eager h2d copy # to leverage eager h2d copy
cnt_data += 1
info.shape_read = False info.shape_read = False
opnode = info.data_reader = G.OutputNode(v, *in_out_links) opnode = info.data_reader = G.OutputNode(v, *in_out_links)
add_reader(opnode) add_reader(opnode)
if info.value_read: if info.value_read:
cnt_value += 1
opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links) opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
add_reader(opnode) add_reader(opnode)
if info.shape_read: if info.shape_read:
cnt_shape += 1
opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
add_reader(opnode) add_reader(opnode)
...@@ -637,15 +643,17 @@ class trace: ...@@ -637,15 +643,17 @@ class trace:
self._process_inputs(*args, **kwargs) self._process_inputs(*args, **kwargs)
outputs = self.__wrapped__(*args, **kwargs) outputs = self.__wrapped__(*args, **kwargs)
transform = False transform = False
# outputs can be None
if outputs is not None: if outputs is not None:
if not isinstance(outputs, collections.abc.Sequence): if not isinstance(outputs, collections.abc.Sequence):
transform = True transform = True
outputs = (outputs,) outputs = (outputs,)
for o in outputs: for o in outputs:
# if outputs are copied, then use the newest info in trace data structure
if o._copied: if o._copied:
self._active_tensors[o.mixin_handle] = TensorWeakRef(o) self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
if self._untraced and self._symbolic: if self._untraced and self._symbolic:
self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o) self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
if self._capture_as_const: if self._capture_as_const:
self._process_outputs(outputs) self._process_outputs(outputs)
if transform: if transform:
...@@ -819,8 +827,8 @@ class trace: ...@@ -819,8 +827,8 @@ class trace:
info.device = x.device info.device = x.device
info.dtype = x.dtype info.dtype = x.dtype
info.shape = x.numpy().shape info.shape = x.numpy().shape
x.mixin_handle = h x._mixin_handle = h
x.recording = True x._recording = True
x._trace_mixin_info = info x._trace_mixin_info = info
self._inputs_to_restore.append(x) self._inputs_to_restore.append(x)
return h return h
...@@ -914,12 +922,12 @@ class trace: ...@@ -914,12 +922,12 @@ class trace:
if not isinstance(x, RawTensor): if not isinstance(x, RawTensor):
raise TypeError("every item of return value should be tensor") raise TypeError("every item of return value should be tensor")
if self._untraced: if self._untraced:
h = x.mixin_handle h = x._mixin_handle
if h < 0: if h < 0:
raise RuntimeError("output is not computed from inputs") raise RuntimeError("output is not computed from inputs")
self._output_bindings.append(h) self._output_bindings.append(h)
else: else:
h = x.mixin_handle h = x._mixin_handle
if h not in self._output_handles: if h not in self._output_handles:
raise RuntimeError("output is not computed from inputs") raise RuntimeError("output is not computed from inputs")
if h != self._output_bindings[i]: if h != self._output_bindings[i]:
...@@ -938,6 +946,11 @@ class trace: ...@@ -938,6 +946,11 @@ class trace:
raise RuntimeError("trace is not set with profiling=True") raise RuntimeError("trace is not set with profiling=True")
return json.loads(self._profiler.get()) return json.loads(self._profiler.get())
def __del__(self):
for x in self._tinfo:
if getattr(x, "bound_data", None):
x.bound_data = None
def trace(self, *args, **kwargs): def trace(self, *args, **kwargs):
raise NotImplementedError( raise NotImplementedError(
"trace is deemed unbeneficial with the new " "trace is deemed unbeneficial with the new "
......
...@@ -291,7 +291,11 @@ PyObject* TensorWrapper::copied() { ...@@ -291,7 +291,11 @@ PyObject* TensorWrapper::copied() {
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \ #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \ PyObject* TensorWrapper::member() { \
if (m_tensor->m_trace_info.member) { \
return m_tensor->m_trace_info.member; \ return m_tensor->m_trace_info.member; \
} else { \
Py_RETURN_NONE; \
} \
} \ } \
void TensorWrapper::set_##member(PyObject* dest) { \ void TensorWrapper::set_##member(PyObject* dest) { \
if (dest == Py_None) { \ if (dest == Py_None) { \
...@@ -322,6 +326,7 @@ void TensorWrapper::set_handle(PyObject* dest) { ...@@ -322,6 +326,7 @@ void TensorWrapper::set_handle(PyObject* dest) {
PyObject* TensorWrapper::shape() { PyObject* TensorWrapper::shape() {
// if it's tracing compiled mode, get value from compiled_info
if (m_tensor->m_trace_info.compiled_info != nullptr) { if (m_tensor->m_trace_info.compiled_info != nullptr) {
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0); return PyTuple_New(0);
...@@ -332,15 +337,18 @@ PyObject* TensorWrapper::shape() { ...@@ -332,15 +337,18 @@ PyObject* TensorWrapper::shape() {
} }
return shp; return shp;
} }
// inside trace, if tensor shape is useful for other operations, set shape_read = true
if (m_tensor->m_trace_info.recording && !skip_tracing) { if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr()); PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
} }
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0); return PyTuple_New(0);
} }
TensorShape shape; TensorShape shape;
if (m_tensor->m_var) { if (m_tensor->m_var) { // get shape from m_var
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var); auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var);
if (!tshp) { if (!tshp) {
...@@ -389,9 +397,11 @@ PyObject* TensorWrapper::numpy() { ...@@ -389,9 +397,11 @@ PyObject* TensorWrapper::numpy() {
} }
return np_val; return np_val;
} }
if (m_tensor->m_trace_info.recording && !skip_tracing) { if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr()); PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
} }
if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) { if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(m_tensor->m_var); auto&& type = mgr.get_infer_type(m_tensor->m_var);
...@@ -411,12 +421,14 @@ PyObject* TensorWrapper::numpy() { ...@@ -411,12 +421,14 @@ PyObject* TensorWrapper::numpy() {
} }
return np_val.release().ptr(); return np_val.release().ptr();
} }
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) { if (!arr) {
PyErr_SetString(PyExc_ValueError, "tensor invalid"); PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr; return nullptr;
} }
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
mgb_assert(PyArray_Check(arr.ptr())); mgb_assert(PyArray_Check(arr.ptr()));
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr())); return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
...@@ -428,7 +440,7 @@ PyObject* TensorWrapper::varnode() { ...@@ -428,7 +440,7 @@ PyObject* TensorWrapper::varnode() {
if (m_tensor->m_var) { if (m_tensor->m_var) {
return py::cast(m_tensor->m_var).release().ptr(); return py::cast(m_tensor->m_var).release().ptr();
} }
return py::none().release().ptr(); Py_RETURN_NONE;
} }
void TensorWrapper::reset(PyObject* tensor) { void TensorWrapper::reset(PyObject* tensor) {
...@@ -465,9 +477,13 @@ PyObject* TensorWrapper::_dev_tensor(){ ...@@ -465,9 +477,13 @@ PyObject* TensorWrapper::_dev_tensor(){
if (dev_tensor == Py_None) { if (dev_tensor == Py_None) {
throw TraceReadError("raw data of this tensor is not read in trace"); throw TraceReadError("raw data of this tensor is not read in trace");
} }
// set m_handle to make it a real tensor
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh)); m_tensor->m_handle = std::move(SharedHandle(sh));
// compiled info is useless after m_handle is set
Py_DECREF(m_tensor->m_trace_info.compiled_info); Py_DECREF(m_tensor->m_trace_info.compiled_info);
m_tensor->m_trace_info.compiled_info = nullptr; m_tensor->m_trace_info.compiled_info = nullptr;
...@@ -753,8 +769,8 @@ void init_tensor(py::module m) { ...@@ -753,8 +769,8 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::reset_varnode>("_reset_varnode") .def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def_getset<&TensorWrapper::varnode>("_varnode") .def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::copied>("_copied") .def_getset<&TensorWrapper::copied>("_copied")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle") .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording") .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
.def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info") .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
.def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info") .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
......
...@@ -55,7 +55,6 @@ apply_result_t apply_trace(ApplyContext& ctx) { ...@@ -55,7 +55,6 @@ apply_result_t apply_trace(ApplyContext& ctx) {
auto args = py::tuple(ctx.nargs + 1); auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op); args[0] = py::cast(ctx.op);
py::tuple args(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) { for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this()); args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
} }
......
...@@ -19,7 +19,9 @@ struct TraceInfo { ...@@ -19,7 +19,9 @@ struct TraceInfo {
bool recording = false; bool recording = false;
bool copied = false; bool copied = false;
// refer to CompiledTensorProxy in tracing.py, works from second trace step
PyObject* compiled_info = nullptr; PyObject* compiled_info = nullptr;
// refer to TensorInfo in tracing.py, only works in first trace step
PyObject* trace_mixin_info = nullptr; PyObject* trace_mixin_info = nullptr;
TraceInfo() = default; TraceInfo() = default;
......
...@@ -14,14 +14,17 @@ import pytest ...@@ -14,14 +14,17 @@ import pytest
import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F import megengine.functional as F
import megengine.optimizer as optim
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor from megengine import Parameter, tensor
from megengine.autodiff import GradManager
from megengine.core._trace_option import set_symbolic_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace from megengine.jit import exclude_from_trace, trace
from megengine.module import Module
from megengine.random import normal, uniform from megengine.random import normal, uniform
...@@ -39,8 +42,48 @@ def test_trace(): ...@@ -39,8 +42,48 @@ def test_trace():
np.testing.assert_equal(f(x).numpy(), y) np.testing.assert_equal(f(x).numpy(), y)
def test_output_copy_trace():
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter([1.0], dtype=np.float32)
def forward(self, x):
x = x * self.a
# will result into a copy of output in grad
x = F.exp(x)
return x
net = Simple()
gm = GradManager().attach(net.parameters())
opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
@trace(symbolic=False)
def train_f1(d):
with gm:
loss = net(d)
gm.backward(loss)
opt.step().clear_grad()
return loss
@trace(symbolic=True)
def train_f2(d):
with gm:
loss = net(d)
gm.backward(loss)
opt.step().clear_grad()
return loss
for i in range(2):
y1 = train_f1(data).numpy()
y2 = train_f2(data).numpy()
np.testing.assert_equal(y1, y2)
def test_exclude_from_trace(): def test_exclude_from_trace():
for symbolic in [False]: for symbolic in [False, True]:
@trace(symbolic=symbolic) @trace(symbolic=symbolic)
def f(x): def f(x):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册