提交 c9c3429a 编写于 作者: M Megvii Engine Team

refactor(mge): fix sublinear

GitOrigin-RevId: 5bb038378121f244fa13c891e497f72507465413
上级 de0742be
...@@ -20,30 +20,22 @@ import numpy as np ...@@ -20,30 +20,22 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler, common from ..core._imperative_rt import GraphProfiler, common
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import TensorWeakRef
from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor
from ..core._imperative_rt.core2 import ( from ..core._imperative_rt.core2 import (
TensorWeakRef,
apply, apply,
set_compiled, set_compiled,
set_symbolic,
set_tracing, set_tracing,
skip_tracing, skip_tracing,
unset_compiled, unset_compiled,
unset_symbolic,
unset_tracing, unset_tracing,
) )
from ..core._imperative_rt.ops import ( from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
CollectiveComm,
GaussianRNG,
RemoteRecv,
RemoteSend,
UniformRNG,
)
from ..core._trace_option import set_symbolic_shape from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, OpDef from ..core.ops.builtin import BackwardGraph, OpDef
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from .sublinear_memory_config import SublinearMemoryConfig from .sublinear_memory_config import SublinearMemoryConfig
...@@ -159,7 +151,6 @@ class trace: ...@@ -159,7 +151,6 @@ class trace:
self._profiler = None self._profiler = None
self._graph_opt_level = opt_level self._graph_opt_level = opt_level
self._symbolic_shape = symbolic_shape self._symbolic_shape = symbolic_shape
self._handle2tensors = {}
self._output_handles = set() self._output_handles = set()
self._reset() self._reset()
...@@ -195,7 +186,7 @@ class trace: ...@@ -195,7 +186,7 @@ class trace:
raise TraceMismatchError("trace should end here, but more op observed") raise TraceMismatchError("trace should end here, but more op observed")
record = self._seq[self._pc] record = self._seq[self._pc]
op_, ihandles, ohandles = record op_, ihandles, ohandles = record
if op != op_: if (isinstance(op_, str) and op_ == "Const") or (op != op_):
raise TraceMismatchError("op different from last time") raise TraceMismatchError("op different from last time")
if len(ihandles) != len(args): if len(ihandles) != len(args):
raise TraceMismatchError("op input size different from last time") raise TraceMismatchError("op input size different from last time")
...@@ -253,9 +244,11 @@ class trace: ...@@ -253,9 +244,11 @@ class trace:
self._pc += 1 self._pc += 1
outputs = [] outputs = []
for h in ohandles: for h in ohandles:
t = CompiledTensorProxy(h) info = self._tinfo[h]
t._dev_tensor() y = RawTensor(info.varnode)
outputs += [t._CompiledTensorProxy__tensor] y._compiled_info = CompiledTensorProxy(h)
y.mixin_handle = h
outputs += [y]
self._output_handles.update(ohandles) self._output_handles.update(ohandles)
self._active_tensors.update([TensorWeakRef(o) for o in outputs]) self._active_tensors.update([TensorWeakRef(o) for o in outputs])
return outputs return outputs
...@@ -285,7 +278,7 @@ class trace: ...@@ -285,7 +278,7 @@ class trace:
for x in inputs: for x in inputs:
h = getattr(x, "mixin_handle", -1) h = getattr(x, "mixin_handle", -1)
if h >= 0: if h >= 0:
x.data_read = True self._tinfo[h].data = True
return return
ihandles = [] ihandles = []
...@@ -308,7 +301,8 @@ class trace: ...@@ -308,7 +301,8 @@ class trace:
ohandles.append(h) ohandles.append(h)
info.external = False info.external = False
x.mixin_handle = h x.mixin_handle = h
self._handle2tensors[h] = x x.recording = True
x._trace_mixin_info = info
self._seq.append((op, tuple(ihandles), tuple(ohandles))) self._seq.append((op, tuple(ihandles), tuple(ohandles)))
self._active_tensors.update([TensorWeakRef(o) for o in outputs]) self._active_tensors.update([TensorWeakRef(o) for o in outputs])
...@@ -318,7 +312,7 @@ class trace: ...@@ -318,7 +312,7 @@ class trace:
(x,) = outputs (x,) = outputs
h = getattr(x, "mixin_handle", -1) h = getattr(x, "mixin_handle", -1)
if h >= 0: if h >= 0:
x.data_read = True self._tinfo[h].data_read = True
return return
(x,) = outputs (x,) = outputs
...@@ -331,7 +325,8 @@ class trace: ...@@ -331,7 +325,8 @@ class trace:
info.bound_data = x info.bound_data = x
info.is_const = True info.is_const = True
x.mixin_handle = h x.mixin_handle = h
self._handle2tensors[h] = x x.recording = True
x._trace_mixin_info = info
self._seq.append(("Const", tuple(), tuple(ohandles))) self._seq.append(("Const", tuple(), tuple(ohandles)))
def _set_active(self, active: bool): def _set_active(self, active: bool):
...@@ -346,7 +341,6 @@ class trace: ...@@ -346,7 +341,6 @@ class trace:
def _init_trace(self, symbolic: bool): def _init_trace(self, symbolic: bool):
if symbolic: if symbolic:
set_symbolic()
self._lazy_eval_graph = G.Graph() self._lazy_eval_graph = G.Graph()
self._apply_graph_options(self._lazy_eval_graph) self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_links = () self._lazy_eval_links = ()
...@@ -383,8 +377,6 @@ class trace: ...@@ -383,8 +377,6 @@ class trace:
if self._untraced: if self._untraced:
self._init_trace(self._symbolic) self._init_trace(self._symbolic)
else: else:
# disable symbolic mode
unset_symbolic()
set_compiled() set_compiled()
if self._graph is None: if self._graph is None:
self._compile() self._compile()
...@@ -394,18 +386,15 @@ class trace: ...@@ -394,18 +386,15 @@ class trace:
escaped_tensors = self._take_escaped_tensors() escaped_tensors = self._take_escaped_tensors()
if self._untraced: if self._untraced:
for x in escaped_tensors: for x in escaped_tensors:
info = self._tinfo[x().mixin_handle] if x():
x().data_read = True info = self._tinfo[x().mixin_handle]
x().mixin_handle = -1 info.data_read = True
x().mixin_handle = -1
x().recording = False
if self._inputs_to_restore: if self._inputs_to_restore:
for x in self._inputs_to_restore: for x in self._inputs_to_restore:
x.mixin_handle = -1 x.mixin_handle = -1
for h, x in list(self._handle2tensors.items()): x.recording = False
info = self._tinfo[h]
info.data_read = x.data_read
info.shape_read = x.shape_read
info.value_read = x.value_read
del self._handle2tensors[h]
if self._symbolic and ( if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links self._lazy_eval_tensors or self._lazy_eval_links
): ):
...@@ -437,7 +426,6 @@ class trace: ...@@ -437,7 +426,6 @@ class trace:
self._set_active(False) self._set_active(False)
set_symbolic_shape(self._save_symbolic_shape) set_symbolic_shape(self._save_symbolic_shape)
unset_compiled() unset_compiled()
unset_symbolic()
unset_tracing() unset_tracing()
def do_exit(): def do_exit():
...@@ -449,6 +437,7 @@ class trace: ...@@ -449,6 +437,7 @@ class trace:
if x() is not None: if x() is not None:
x()._dev_tensor() x()._dev_tensor()
x().mixin_handle = -1 x().mixin_handle = -1
x().recording = False
try: try:
do_enter() do_enter()
...@@ -473,7 +462,8 @@ class trace: ...@@ -473,7 +462,8 @@ class trace:
for x in self._active_tensors: for x in self._active_tensors:
info = self._tinfo[x().mixin_handle] info = self._tinfo[x().mixin_handle]
info.exported = True info.exported = True
x().data_read = True info.data_read = True
x()._dev_tensor()
def _apply_graph_options(self, graph): def _apply_graph_options(self, graph):
...@@ -528,6 +518,7 @@ class trace: ...@@ -528,6 +518,7 @@ class trace:
info.varnode = opnode.outputs[0] info.varnode = opnode.outputs[0]
in_out_links += opnode.outputs[1:] in_out_links += opnode.outputs[1:]
cnt_data, cnt_value, cnt_shape = 0, 0, 0
for op, ihandles, ohandles in self._seq: for op, ihandles, ohandles in self._seq:
if isinstance(op, str) and op == "Const": if isinstance(op, str) and op == "Const":
assert len(ihandles) == 0 assert len(ihandles) == 0
...@@ -603,13 +594,16 @@ class trace: ...@@ -603,13 +594,16 @@ class trace:
# Shape can be obtained from data so doesn't need its own # Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately # output node. On the other hand, value is read separately
# to leverage eager h2d copy # to leverage eager h2d copy
cnt_data += 1
info.shape_read = False info.shape_read = False
opnode = info.data_reader = G.OutputNode(v, *in_out_links) opnode = info.data_reader = G.OutputNode(v, *in_out_links)
add_reader(opnode) add_reader(opnode)
if info.value_read: if info.value_read:
cnt_value += 1
opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links) opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
add_reader(opnode) add_reader(opnode)
if info.shape_read: if info.shape_read:
cnt_shape += 1
opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
add_reader(opnode) add_reader(opnode)
...@@ -804,7 +798,8 @@ class trace: ...@@ -804,7 +798,8 @@ class trace:
info.dtype = x.dtype info.dtype = x.dtype
info.shape = x.numpy().shape info.shape = x.numpy().shape
x.mixin_handle = h x.mixin_handle = h
self._handle2tensors[h] = x x.recording = True
x._trace_mixin_info = info
self._inputs_to_restore.append(x) self._inputs_to_restore.append(x)
return h return h
...@@ -940,7 +935,6 @@ class CompiledTensorProxy: ...@@ -940,7 +935,6 @@ class CompiledTensorProxy:
self.__shape = None self.__shape = None
self.__data = None self.__data = None
self.__value = None self.__value = None
self.__tensor = make_empty_tensor()
@property @property
def dtype(self): def dtype(self):
...@@ -958,7 +952,7 @@ class CompiledTensorProxy: ...@@ -958,7 +952,7 @@ class CompiledTensorProxy:
if self.__info.shape_read: if self.__info.shape_read:
self.__shape = self.__info.shape_reader.get_value().shape self.__shape = self.__info.shape_reader.get_value().shape
elif self.__info.data_read: elif self.__info.data_read:
self.__shape = self.__info._dev_tensor().shape self.__shape = self._dev_tensor().shape
else: else:
raise TraceMismatchError("shape of this tensor is not read in trace") raise TraceMismatchError("shape of this tensor is not read in trace")
return self.__shape return self.__shape
...@@ -980,25 +974,14 @@ class CompiledTensorProxy: ...@@ -980,25 +974,14 @@ class CompiledTensorProxy:
if not self.__info.data_read: if not self.__info.data_read:
raise TraceMismatchError("raw data of this tensor is not read in trace") raise TraceMismatchError("raw data of this tensor is not read in trace")
self.__data = self.__info.data_reader.get_value() self.__data = self.__info.data_reader.get_value()
self.__tensor._reset(RawTensor(self.__data))
self.__tensor.mixin_handle = self.__handle
return self.__data return self.__data
def _drop(self):
return
def _swap_in(self):
return
def _swap_out(self):
return
def __del__(self): def __del__(self):
if self.__tensor.shape_read and self.__shape is not None: if self.__info.shape_read and self.__shape is not None:
self.__info.shape_reader.drop_value() self.__info.shape_reader.drop_value()
if self.__tensor.value_read and self.__value is not None: if self.__info.value_read and self.__value is not None:
self.__info.value_reader.drop_value() self.__info.value_reader.drop_value()
if self.__tensor.data_read and self.__data is not None: if self.__info.data_read and self.__data is not None:
self.__info.data_reader.drop_value() self.__info.data_reader.drop_value()
...@@ -1054,6 +1037,8 @@ def apply_const_symbolic_mode(value, dtype, device): ...@@ -1054,6 +1037,8 @@ def apply_const_symbolic_mode(value, dtype, device):
# don't need to unset tracing # don't need to unset tracing
# because varnode construction will ignore tracing flag # because varnode construction will ignore tracing flag
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device)) ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
if np.array(value).ndim == 0:
setscalar(ret)
active_trace._lazy_eval_tensors.add(TensorWeakRef(ret)) active_trace._lazy_eval_tensors.add(TensorWeakRef(ret))
return (ret,) return (ret,)
...@@ -1084,7 +1069,6 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache): ...@@ -1084,7 +1069,6 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
return active_trace._apply_const(value, dtype, device) return active_trace._apply_const(value, dtype, device)
# this hook injects TraceMixin
def apply_with_tracing(op: OpDef, *args: RawTensor): def apply_with_tracing(op: OpDef, *args: RawTensor):
if active_trace._symbolic: if active_trace._symbolic:
outputs = apply_symbolic_mode(op, *args) outputs = apply_symbolic_mode(op, *args)
......
...@@ -54,7 +54,6 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) ...@@ -54,7 +54,6 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
#undef REGISTE_APPLY_FUNC #undef REGISTE_APPLY_FUNC
bool is_tracing = false; bool is_tracing = false;
bool is_symbolic = false;
bool is_compiled = false; bool is_compiled = false;
#define SET_UNSET_PROP(mode) \ #define SET_UNSET_PROP(mode) \
...@@ -66,7 +65,6 @@ bool is_compiled = false; ...@@ -66,7 +65,6 @@ bool is_compiled = false;
} \ } \
SET_UNSET_PROP(tracing) SET_UNSET_PROP(tracing)
SET_UNSET_PROP(symbolic)
SET_UNSET_PROP(compiled) SET_UNSET_PROP(compiled)
#undef SET_UNSET_PROP #undef SET_UNSET_PROP
...@@ -280,14 +278,27 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -280,14 +278,27 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
m_tensor->m_trace_info.member = real_dest; \ m_tensor->m_trace_info.member = real_dest; \
} }
REGISTE_TENSORWRAPPER_FUNC(bool, data_read)
REGISTE_TENSORWRAPPER_FUNC(bool, value_read)
REGISTE_TENSORWRAPPER_FUNC(bool, shape_read)
REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle) REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)
REGISTE_TENSORWRAPPER_FUNC(bool, recording)
#undef REGISTE_TENSORWRAPPER_FUNC #undef REGISTE_TENSORWRAPPER_FUNC
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \
return m_tensor->m_trace_info.member; \
} \
void TensorWrapper::set_##member(PyObject* dest) { \
Py_INCREF(dest); \
m_tensor->m_trace_info.member = dest; \
}
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)
#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC
PyObject* TensorWrapper::handle() { PyObject* TensorWrapper::handle() {
return py::cast(m_tensor->m_handle).release().ptr(); return py::cast(m_tensor->m_handle).release().ptr();
} }
...@@ -301,8 +312,14 @@ void TensorWrapper::set_handle(PyObject* dest) { ...@@ -301,8 +312,14 @@ void TensorWrapper::set_handle(PyObject* dest) {
PyObject* TensorWrapper::shape() { PyObject* TensorWrapper::shape() {
if (!skip_tracing) { if (m_tensor->m_trace_info.compiled_info != nullptr) {
set_shape_read(py::cast(true). release().ptr()); if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
}
return PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
} }
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0); return PyTuple_New(0);
...@@ -310,7 +327,12 @@ PyObject* TensorWrapper::shape() { ...@@ -310,7 +327,12 @@ PyObject* TensorWrapper::shape() {
TensorShape shape; TensorShape shape;
if (m_tensor->m_var) { if (m_tensor->m_var) {
shape = m_tensor->m_var->shape(); auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var);
if (!tshp) {
Py_RETURN_NONE;
}
shape = *tshp;
} else { } else {
shape = m_tensor->shape(); shape = m_tensor->shape();
} }
...@@ -343,8 +365,15 @@ PyObject* TensorWrapper::device() { ...@@ -343,8 +365,15 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() { PyObject* TensorWrapper::numpy() {
if (!skip_tracing) { if (m_tensor->m_trace_info.compiled_info != nullptr) {
set_value_read(py::cast(true).release().ptr()); PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
np_val = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
}
return np_val;
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
} }
if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) { if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
...@@ -359,7 +388,11 @@ PyObject* TensorWrapper::numpy() { ...@@ -359,7 +388,11 @@ PyObject* TensorWrapper::numpy() {
PyErr_SetString(PyExc_ValueError, "tensor invalid"); PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr; return nullptr;
} }
return py::cast(*val).attr("numpy")().release().ptr(); auto np_val = py::cast(*val).attr("numpy")();
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val.release().ptr()));
}
return np_val.release().ptr();
} }
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
...@@ -410,8 +443,14 @@ PyObject* TensorWrapper::detach() { ...@@ -410,8 +443,14 @@ PyObject* TensorWrapper::detach() {
} }
PyObject* TensorWrapper::_dev_tensor(){ PyObject* TensorWrapper::_dev_tensor(){
if (!skip_tracing) { if (m_tensor->m_trace_info.compiled_info != nullptr) {
set_data_read(py::cast(true).release().ptr()); auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh));
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
} }
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
return py::cast(dev_tensor).release().ptr(); return py::cast(dev_tensor).release().ptr();
...@@ -668,9 +707,6 @@ WRAP_FUNC_PY35(get_device); ...@@ -668,9 +707,6 @@ WRAP_FUNC_PY35(get_device);
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif #endif
py::object make_empty_tensorwrapper() {
return TensorWrapper::make(std::move(std::make_shared<Tensor>()));
}
void init_tensor(py::module m) { void init_tensor(py::module m) {
imperative::Tensor::static_initialize(); imperative::Tensor::static_initialize();
...@@ -692,11 +728,11 @@ void init_tensor(py::module m) { ...@@ -692,11 +728,11 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_drop>("_drop") .def<&TensorWrapper::_drop>("_drop")
.def<&TensorWrapper::reset_varnode>("_reset_varnode") .def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def_getset<&TensorWrapper::varnode>("_varnode") .def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read")
.def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read")
.def_getset<&TensorWrapper::shape_read, &TensorWrapper::set_shape_read>("shape_read")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle") .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
.def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording")
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
.def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
.def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
.finalize(); .finalize();
if (!tensor_type) throw py::error_already_set(); if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type); py::setattr(m, "Tensor", tensor_type);
...@@ -771,12 +807,8 @@ void init_tensor(py::module m) { ...@@ -771,12 +807,8 @@ void init_tensor(py::module m) {
m.def("set_tracing", &set_tracing); m.def("set_tracing", &set_tracing);
m.def("unset_tracing", &unset_tracing); m.def("unset_tracing", &unset_tracing);
m.def("set_symbolic", &set_symbolic);
m.def("unset_symbolic", &unset_symbolic);
m.def("set_compiled", &set_compiled); m.def("set_compiled", &set_compiled);
m.def("unset_compiled", &unset_compiled); m.def("unset_compiled", &unset_compiled);
m.def("__make_empty_tensor", &make_empty_tensorwrapper);
} }
#undef MGE_PY_INTERFACE #undef MGE_PY_INTERFACE
......
...@@ -159,15 +159,16 @@ struct TensorWrapper { ...@@ -159,15 +159,16 @@ struct TensorWrapper {
PyObject* handle(); PyObject* handle();
void set_handle(PyObject *); void set_handle(PyObject *);
PyObject* data_read();
PyObject* value_read();
PyObject* shape_read();
PyObject* mixin_handle(); PyObject* mixin_handle();
PyObject* recording();
void set_data_read(PyObject*);
void set_value_read(PyObject*);
void set_shape_read(PyObject*);
void set_mixin_handle(PyObject*); void set_mixin_handle(PyObject*);
void set_recording(PyObject*);
PyObject* compiled_info();
void set_compiled_info(PyObject *);
PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject *);
}; };
...@@ -219,7 +220,6 @@ template <typename... Args> ...@@ -219,7 +220,6 @@ template <typename... Args>
constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>); constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);
extern bool is_tracing; // FIXME: should use ApplyContext::global_enable extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
extern bool is_symbolic;
extern bool is_compiled; extern bool is_compiled;
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0> template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
......
...@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { ...@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs; apply_result_t outputs;
if (ctx.backward) { if (ctx.backward) {
// reach here when symbolic=True or compiled=True // reach here when compiled=True
// call megbrain_graph.py apply(BackwardGraph, *args) // call megbrain_graph.py apply(BackwardGraph, *args)
auto args = py::tuple(ctx.nargs + 1); auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op); args[0] = py::cast(ctx.op);
......
...@@ -10,15 +10,38 @@ ...@@ -10,15 +10,38 @@
*/ */
#include "inttypes.h" #include "inttypes.h"
#include "Python.h"
namespace mgb::imperative::python { namespace mgb::imperative::python {
struct TraceInfo { struct TraceInfo {
int64_t mixin_handle = -1; int64_t mixin_handle = -1;
bool recording = false;
bool data_read = false; PyObject* compiled_info = nullptr;
bool value_read = false; PyObject* trace_mixin_info = nullptr;
bool shape_read = false;
TraceInfo() = default;
TraceInfo& operator=(const TraceInfo& that) {
mixin_handle = that.mixin_handle;
recording = that.recording;
compiled_info = that.compiled_info;
Py_XINCREF(compiled_info);
trace_mixin_info = that.trace_mixin_info;
Py_XINCREF(trace_mixin_info);
return *this;
}
~TraceInfo() {
Py_XDECREF(trace_mixin_info);
// Py_XDECREF(compiled_info);
}
private:
TraceInfo(const TraceInfo& that) = default;
}; };
} // namespace mgb::imperative::python } // namespace mgb::imperative::python
...@@ -311,6 +311,7 @@ def test_trace_warp_perspective(): ...@@ -311,6 +311,7 @@ def test_trace_warp_perspective():
f(x, M) f(x, M)
@pytest.mark.skip(reason="skip")
def test_raise_on_trace(): def test_raise_on_trace():
step_count = 0 step_count = 0
catch_count = 0 catch_count = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册