提交 ae47fd4e 编写于 作者: M Megvii Engine Team

fix(mge): fix none return value for attrs, add test_correctness

GitOrigin-RevId: 1bb96373f450c74ccf9e640cd4b2c73579f3c398
上级 97d12b3e
...@@ -20,10 +20,10 @@ import numpy as np ...@@ -20,10 +20,10 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler, common, put from ..core._imperative_rt import GraphProfiler, common, put
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import TensorWeakRef
from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor
from ..core._imperative_rt.core2 import ( from ..core._imperative_rt.core2 import (
TensorWeakRef,
apply, apply,
call_level,
set_compiled, set_compiled,
set_symbolic, set_symbolic,
set_tracing, set_tracing,
...@@ -86,6 +86,9 @@ class TensorInfo: ...@@ -86,6 +86,9 @@ class TensorInfo:
__slots__ = ( __slots__ = (
# collected attributes # collected attributes
"external", "external",
"data_read",
"shape_read",
"value_read",
"exported", "exported",
"device", "device",
"dtype", "dtype",
...@@ -102,6 +105,9 @@ class TensorInfo: ...@@ -102,6 +105,9 @@ class TensorInfo:
def __init__(self): def __init__(self):
self.exported = None self.exported = None
self.data_read = None
self.shape_read = None
self.value_read = None
self.bound_data = None self.bound_data = None
self.data_setter = None self.data_setter = None
...@@ -154,7 +160,7 @@ class trace: ...@@ -154,7 +160,7 @@ class trace:
self._graph_opt_level = opt_level self._graph_opt_level = opt_level
self._symbolic_shape = symbolic_shape self._symbolic_shape = symbolic_shape
self._handle2tensors = {} self._handle2tensors = {}
self._handle2compiledtensors = {} self._output_handles = set()
self._reset() self._reset()
...@@ -244,11 +250,12 @@ class trace: ...@@ -244,11 +250,12 @@ class trace:
# ) # )
self._pc += 1 self._pc += 1
outputs = []
for h in ohandles: for h in ohandles:
t = CompiledTensorProxy(h) t = CompiledTensorProxy(h)
t._dev_tensor() t._dev_tensor()
self._handle2compiledtensors[h] = t outputs += [t._CompiledTensorProxy__tensor]
outputs = [self._handle2tensors[h] for h in ohandles] self._output_handles.update(ohandles)
self._active_tensors.update([TensorWeakRef(o) for o in outputs]) self._active_tensors.update([TensorWeakRef(o) for o in outputs])
return outputs return outputs
...@@ -347,11 +354,12 @@ class trace: ...@@ -347,11 +354,12 @@ class trace:
self._lazy_eval_links = () self._lazy_eval_links = ()
def _take_escaped_tensors(self): def _take_escaped_tensors(self):
escaped_tensors = tuple(self._active_tensors) escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
self._active_tensors.clear() self._active_tensors.clear()
return escaped_tensors return escaped_tensors
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors))
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
self._apply_graph_options(lazy_eval_graph) self._apply_graph_options(lazy_eval_graph)
# FIXME # FIXME
...@@ -393,6 +401,12 @@ class trace: ...@@ -393,6 +401,12 @@ class trace:
if self._inputs_to_restore: if self._inputs_to_restore:
for x in self._inputs_to_restore: for x in self._inputs_to_restore:
x.mixin_handle = -1 x.mixin_handle = -1
for h, x in list(self._handle2tensors.items()):
info = self._tinfo[h]
info.data_read = x.data_read
info.shape_read = x.shape_read
info.value_read = x.value_read
del self._handle2tensors[h]
if self._symbolic and ( if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links self._lazy_eval_tensors or self._lazy_eval_links
): ):
...@@ -433,8 +447,9 @@ class trace: ...@@ -433,8 +447,9 @@ class trace:
raise TraceMismatchError("premature end") raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced: if not self._symbolic or not self._untraced:
for x in self._active_tensors: for x in self._active_tensors:
x()._dev_tensor() if x() is not None:
x().mixin_handle = -1 x()._dev_tensor()
x().mixin_handle = -1
try: try:
do_enter() do_enter()
...@@ -581,8 +596,7 @@ class trace: ...@@ -581,8 +596,7 @@ class trace:
readers.append(opnode.outputs[0]) readers.append(opnode.outputs[0])
in_out_links = opnode.outputs in_out_links = opnode.outputs
x = self._handle2tensors[h] if info.data_read:
if x.data_read:
# Shape can be obtained from data so doesn't need its own # Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately # output node. On the other hand, value is read separately
# to leverage eager h2d copy # to leverage eager h2d copy
...@@ -890,7 +904,7 @@ class trace: ...@@ -890,7 +904,7 @@ class trace:
self._output_bindings.append(h) self._output_bindings.append(h)
else: else:
h = x.mixin_handle h = x.mixin_handle
if h not in self._handle2compiledtensors: if h not in self._output_handles:
raise RuntimeError("output is not computed from inputs") raise RuntimeError("output is not computed from inputs")
if h != self._output_bindings[i]: if h != self._output_bindings[i]:
raise TraceMismatchError( raise TraceMismatchError(
...@@ -927,8 +941,7 @@ class CompiledTensorProxy: ...@@ -927,8 +941,7 @@ class CompiledTensorProxy:
self.__shape = None self.__shape = None
self.__data = None self.__data = None
self.__value = None self.__value = None
self.__tensor = active_trace._handle2tensors[handle] self.__tensor = make_empty_tensor()
self.__tensor.mixin_handle = handle
@property @property
def dtype(self): def dtype(self):
...@@ -943,19 +956,19 @@ class CompiledTensorProxy: ...@@ -943,19 +956,19 @@ class CompiledTensorProxy:
if self._isscalar: if self._isscalar:
return () return ()
if self.__shape is None: if self.__shape is None:
if self.__tensor.shape_read: if self.__info.shape_read:
self.__shape = self.__info.shape_reader.get_value().shape self.__shape = self.__info.shape_reader.get_value().shape
elif self.__tensor.data_read: elif self.__info.data_read:
self.__shape = self.__tensor._dev_tensor().shape self.__shape = self.__info._dev_tensor().shape
else: else:
raise TraceMismatchError("shape of this tensor is not read in trace") raise TraceMismatchError("shape of this tensor is not read in trace")
return self.__shape return self.__shape
def numpy(self): def numpy(self):
if self.__value is None: if self.__value is None:
if self.__tensor.value_read: if self.__info.value_read:
self.__value = self.__info.value_reader.get_value() self.__value = self.__info.value_reader.get_value()
elif self.__tensor.data_read: elif self.__info.data_read:
self.__value = self._dev_tensor().numpy() self.__value = self._dev_tensor().numpy()
else: else:
raise TraceMismatchError("value of this tensor is not read in trace") raise TraceMismatchError("value of this tensor is not read in trace")
...@@ -965,7 +978,7 @@ class CompiledTensorProxy: ...@@ -965,7 +978,7 @@ class CompiledTensorProxy:
def _dev_tensor(self): def _dev_tensor(self):
if self.__data is None: if self.__data is None:
if not self.__tensor.data_read: if not self.__info.data_read:
raise TraceMismatchError("raw data of this tensor is not read in trace") raise TraceMismatchError("raw data of this tensor is not read in trace")
self.__data = self.__info.data_reader.get_value() self.__data = self.__info.data_reader.get_value()
self.__tensor._reset(RawTensor(self.__data)) self.__tensor._reset(RawTensor(self.__data))
......
...@@ -53,9 +53,6 @@ bool is_tracing = false; ...@@ -53,9 +53,6 @@ bool is_tracing = false;
bool is_symbolic = false; bool is_symbolic = false;
bool is_compiled = false; bool is_compiled = false;
int64_t call_level = 0;
#define SET_UNSET_PROP(mode) \ #define SET_UNSET_PROP(mode) \
void set_##mode() { \ void set_##mode() { \
is_##mode = true; \ is_##mode = true; \
...@@ -321,17 +318,22 @@ PyObject* TensorWrapper::numpy() { ...@@ -321,17 +318,22 @@ PyObject* TensorWrapper::numpy() {
auto&& type = mgr.get_infer_type(m_tensor->m_var); auto&& type = mgr.get_infer_type(m_tensor->m_var);
using InferType = cg::static_infer::InferType; using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr; return nullptr;
} }
auto* val = mgr.infer_value_fallible(m_tensor->m_var); auto* val = mgr.infer_value_fallible(m_tensor->m_var);
if (!val) { if (!val) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr; return nullptr;
} }
return py::cast(*val).attr("numpy")().release().ptr(); return py::cast(*val).attr("numpy")().release().ptr();
} }
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) return nullptr; if (!arr) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr;
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
mgb_assert(PyArray_Check(arr.ptr())); mgb_assert(PyArray_Check(arr.ptr()));
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr())); return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
...@@ -343,7 +345,7 @@ PyObject* TensorWrapper::varnode() { ...@@ -343,7 +345,7 @@ PyObject* TensorWrapper::varnode() {
if (m_tensor->m_var) { if (m_tensor->m_var) {
return py::cast(m_tensor->m_var).release().ptr(); return py::cast(m_tensor->m_var).release().ptr();
} }
return nullptr; return py::none().release().ptr();
} }
void TensorWrapper::reset(PyObject* tensor) { void TensorWrapper::reset(PyObject* tensor) {
...@@ -364,6 +366,7 @@ PyObject* TensorWrapper::detach() { ...@@ -364,6 +366,7 @@ PyObject* TensorWrapper::detach() {
} else { } else {
new_tensor = std::make_shared<Tensor>(m_tensor->m_var); new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
} }
new_tensor->m_trace_info = m_tensor->m_trace_info;
auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
return ret.release().ptr(); return ret.release().ptr();
...@@ -628,6 +631,10 @@ WRAP_FUNC_PY35(get_device); ...@@ -628,6 +631,10 @@ WRAP_FUNC_PY35(get_device);
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif #endif
py::object make_empty_tensorwrapper() {
return TensorWrapper::make(std::move(std::make_shared<Tensor>()));
}
void init_tensor(py::module m) { void init_tensor(py::module m) {
interpreter_for_py = interpreter::Interpreter::inst().create_channel(); interpreter_for_py = interpreter::Interpreter::inst().create_channel();
...@@ -699,7 +706,6 @@ void init_tensor(py::module m) { ...@@ -699,7 +706,6 @@ void init_tensor(py::module m) {
m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode); m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
m.attr("skip_tracing") = &skip_tracing; m.attr("skip_tracing") = &skip_tracing;
m.attr("call_level") = &call_level;
py::class_<SharedHandle>(m, "SharedHandle") py::class_<SharedHandle>(m, "SharedHandle")
.def(py::init<const SharedHandle&>()); .def(py::init<const SharedHandle&>());
...@@ -711,6 +717,7 @@ void init_tensor(py::module m) { ...@@ -711,6 +717,7 @@ void init_tensor(py::module m) {
m.def("set_compiled", &set_compiled); m.def("set_compiled", &set_compiled);
m.def("unset_compiled", &unset_compiled); m.def("unset_compiled", &unset_compiled);
m.def("__make_empty_tensor", &make_empty_tensorwrapper);
} }
#undef MGE_PY_INTERFACE #undef MGE_PY_INTERFACE
......
...@@ -74,6 +74,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { ...@@ -74,6 +74,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
using Handle = interpreter::Interpreter::Handle; using Handle = interpreter::Interpreter::Handle;
inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {} inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {} inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {} inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {}
...@@ -188,7 +189,6 @@ void init_tensor(pybind11::module); ...@@ -188,7 +189,6 @@ void init_tensor(pybind11::module);
extern bool is_tracing; extern bool is_tracing;
extern bool is_symbolic; extern bool is_symbolic;
extern bool is_compiled; extern bool is_compiled;
extern int64_t call_level;
extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode; extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode; extern pybind11::object cpp_apply_backward_varnode;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册