Commit b310f261 authored by Megvii Engine Team

feat(mge/imperative): implement trace and dump under new core implementation

GitOrigin-RevId: 4edc38eaf217edef4292d141635fdd376a963b4e
Parent 14d8b709
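For context, a minimal usage sketch of the feature this commit lands (tracing a function, then dumping the traced graph for inference), distilled from the updated test_dump further down in this diff; names and values follow that test:

import io

import numpy as np

from megengine import tensor
from megengine.jit import trace

# as in test_dump, capture_as_const=True lets dump() serialize captured tensors
@trace(symbolic=True, capture_as_const=True)
def f(a, b):
    return a + b

y = f(tensor([2]), tensor([4])).numpy()  # first call records the graph
np.testing.assert_equal(f(tensor([2]), tensor([4])).numpy(), y)

f.dump(io.BytesIO())  # serialize the traced graph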
......@@ -20,4 +20,4 @@ class Const:
def __call__(self, *reference):
Wrapper = type(reference[0])
return (Wrapper(self.value, self.dtype, self.device),)
return (Wrapper(self.value, self.dtype, self.device, True),)
......@@ -19,10 +19,11 @@ import numpy as np
from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
from .core import OpBase, TensorBase, apply
from .core import OpBase, TensorBase
class Graph(_imperative_rt.ComputingGraph):
......@@ -269,9 +270,8 @@ def optimize_for_inference(dest_vars, **kwargs):
if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
res_vars = _imperative_rt.optimize_for_inference(
[i._node for i in dest_vars], inference_options
)
dest_vars = [var._node for var in dest_vars]
res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
return [VarNode(i) for i in res_vars]
......@@ -437,19 +437,25 @@ def _unwrap(x):
return x
@apply.register()
def _(op: OpDef, *args: VarNode):
def apply_normal_op(op: OpDef, *args: VarNode):
outputs = _imperative_rt.invoke_op(op, _unwrap(args))
return _wrap(outputs)
@apply.register()
def _(op: BackwardGraph, *args: VarNode):
def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
assert args
graph = args[0].graph
return BackwardGraph.interpret(
op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args
outputs = op.interpret(
op,
lambda op, args: apply_normal_op(op, *args),
graph._make_const_for_backward,
args,
)
outputs = [o._node if hasattr(o, "_node") else o for o in outputs]
return outputs
set_cpp_apply_backward_varnode(apply_backward_varnode)
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
......
......@@ -6,5 +6,23 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import (
set_cpp_apply_compiled_mode,
set_cpp_apply_const_compiled_mode,
set_cpp_apply_const_with_tracing,
set_cpp_apply_with_tracing,
)
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import exclude_from_trace, trace
from .tracing import (
apply_compiled_mode,
apply_const_compiled_mode,
apply_const_with_tracing,
apply_with_tracing,
exclude_from_trace,
trace,
)
set_cpp_apply_with_tracing(apply_with_tracing)
set_cpp_apply_const_with_tracing(apply_const_with_tracing)
set_cpp_apply_compiled_mode(apply_compiled_mode)
set_cpp_apply_const_compiled_mode(apply_const_compiled_mode)
......@@ -28,7 +28,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
dmap_callback = None
q_dict = {"mode": None, "scale": None, "zero_point": None}
def __new__(cls, data, dtype=None, device=None):
def __new__(cls, data, dtype=None, device=None, is_const=False):
if device is None:
cn = get_default_device()
elif isinstance(device, str):
......@@ -40,6 +40,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
assert isinstance(device, CompNode)
cn = device
# import pdb; pdb.set_trace()
if isinstance(data, _Tensor):
obj = _Tensor.__new__(cls, data)
else:
......@@ -47,7 +48,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
if 0 in data.strides:
data = data.squeeze().reshape(data.shape)
obj = _Tensor.__new__(cls, data, dtype, cn)
obj = _Tensor.__new__(cls, data, dtype, cn, is_const)
return obj
@property
......
......@@ -296,7 +296,9 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta)
Tensor* args[2] = {grad.get(), delta.get()};
ctx.args = args;
ctx.flags = grad->m_flags | delta->m_flags;
if (is_tracing) {
ctx.flags |= Tensor::Flags::TRACE;
}
grad = apply(ctx)[0];
}
......@@ -354,6 +356,9 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
}
ctx.args = args;
if (is_tracing)
ctx.flags |= Tensor::Flags::TRACE;
auto grads = apply(ctx);
size_t j = 0;
......
......@@ -11,8 +11,10 @@
#include "./tensor.h"
#include "./grad.h"
#include "./trace.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include "./graph_rt.h"
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
......@@ -23,6 +25,47 @@ namespace mgb::imperative::python {
std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,
cpp_apply_compiled_mode, cpp_apply_const_compiled_mode;
py::object cpp_apply_backward_varnode;
#define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { \
mode = pybind11::reinterpret_steal<py::object>(pyf); \
}
REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_compiled_mode)
REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode)
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
#undef REGISTE_APPLY_FUNC
bool is_tracing = false;
bool is_symbolic = false;
bool is_compiled = false;
int64_t call_level = 0;
#define SET_UNSET_PROP(mode) \
void set_##mode() { \
is_##mode = true; \
} \
void unset_##mode() { \
is_##mode = false; \
}
SET_UNSET_PROP(tracing)
SET_UNSET_PROP(symbolic)
SET_UNSET_PROP(compiled)
#undef SET_UNSET_PROP
bool skip_tracing = false;
apply_result_t apply(ApplyContext& ctx) {
// emulating scalar should be put to specific op's apply, e.g.,
// elementwise, reduce, typecvt. Currently it's still handled at python
......@@ -36,7 +79,7 @@ apply_result_t apply(ApplyContext& ctx) {
}
if (ctx.flags & Tensor::Flags::TRACE) {
// TODO: trace
return apply_trace(ctx);
} else {
SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
......@@ -58,7 +101,6 @@ apply_result_t apply(ApplyContext& ctx) {
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
try {
// if (kwnames && PyTuple_GET_SIZE(kwnames)) {
// PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
// return nullptr;
......@@ -67,6 +109,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
PyErr_SetString(PyExc_TypeError, "expect Op");
return nullptr;
}
auto* op = args[0];
PyTypeObject* pytype = args[1]->ob_type;
......@@ -79,18 +122,23 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
SmallVector<Tensor*, 64> tensors(nargs);
ctx.args = &tensors[0];
ctx.nargs = nargs;
if (strstr(op->ob_type->tp_name, "BackwardGraph")) {
ctx.backward = true;
}
for (size_t i = 0; i < nargs; ++i) {
TensorWrapper* tw = TensorWrapper::cast_safe(args[i]);
if (!tw) {
if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) {
auto* t = tensors[i] = tw->m_tensor.get();
ctx.flags |= t->m_flags;
} else {
PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr;
}
auto* t = tensors[i] = tw->m_tensor.get();
ctx.flags |= t->m_flags;
}
// TODO: set TRACE flag
if (is_tracing) {
ctx.flags |= Tensor::Flags::TRACE;
}
auto outputs = apply(ctx);
size_t nout = outputs.size();
......@@ -99,7 +147,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
}
return ret.release().ptr();
} catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
......@@ -122,36 +169,116 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
}
m_tensor = t->m_tensor;
} else {
if (nargs != 3) {
throw py::type_error("expect 3 arguments");
}
py::detail::loader_life_support life_sup; // required to cast DType
auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();
interpreter::Interpreter::Handle handle;
constexpr auto size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
if (nargs == 1) {
auto arg0 = PyTuple_GetItem(args, 0);
// for lazy_eval_tensor
if (strstr(arg0->ob_type->tp_name, "VarNode")) {
if (PyObject_HasAttrString(arg0, "_node")) {
arg0 = PyObject_GetAttrString(arg0, "_node");
}
m_tensor = std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode *>());
} else {
// for DeviceTensorND
if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
auto dv = py::handle(arg0).cast<DeviceTensorND>();
interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv);
m_tensor = std::make_shared<Tensor>(handle);
} else {
throw py::type_error("single argument is not tensor, varnode or devicetensor");
}
}
} else {
HostTensorND ret(cn);
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
}
py::detail::loader_life_support life_sup; // required to cast DType
auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();
bool is_const = tup[3].cast<bool>();
if (nargs != 4) {
throw py::type_error("expect 4 arguments");
}
// const op
if (is_const && is_tracing) {
py::object pyf;
if (is_compiled) {
pyf = cpp_apply_const_compiled_mode;
} else {
pyf = cpp_apply_const_with_tracing;
}
auto ret = pyf(*tup);
auto py_ret = py::reinterpret_borrow<py::list>(ret);
if (auto* t = cast_safe(py_ret[0].ptr())) {
m_tensor = t->m_tensor;
}
return;
}
interpreter::Interpreter::Handle handle;
constexpr auto size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
} else {
HostTensorND ret(cn);
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
}
m_tensor = std::make_shared<Tensor>(handle);
m_tensor = std::make_shared<Tensor>(handle);
if (data.ndim() == 0) {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
if (data.ndim() == 0) {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
}
}
}
}
#define REGISTE_TENSORWRAPPER_FUNC(type, member) \
PyObject* TensorWrapper::member() { \
return py::cast(m_tensor->m_trace_info.member).release().ptr(); \
} \
void TensorWrapper::set_##member(PyObject* dest) { \
auto py_dest = py::reinterpret_borrow<py::object>(dest); \
type real_dest = py_dest.cast<type>(); \
m_tensor->m_trace_info.member = real_dest; \
}
REGISTE_TENSORWRAPPER_FUNC(bool, data_read)
REGISTE_TENSORWRAPPER_FUNC(bool, value_read)
REGISTE_TENSORWRAPPER_FUNC(bool, shape_read)
REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)
#undef REGISTE_TENSORWRAPPER_FUNC
PyObject* TensorWrapper::handle() {
return py::cast(m_tensor->m_handle).release().ptr();
}
void TensorWrapper::set_handle(PyObject* dest) {
auto py_dest = py::reinterpret_borrow<py::object>(dest);
SharedHandle real_dest = py_dest.cast<SharedHandle>();
auto&& t = std::move(m_tensor->m_handle);
m_tensor->m_handle = std::move(real_dest);
}
PyObject* TensorWrapper::shape() {
if (!skip_tracing) {
set_shape_read(py::cast(true).release().ptr());
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
}
auto&& shape = m_tensor->shape();
TensorShape shape;
if (m_tensor->m_var) {
shape = m_tensor->m_var->shape();
} else {
shape = m_tensor->shape();
}
if (!shape.ndim) {
Py_RETURN_NONE;
}
......@@ -164,16 +291,38 @@ PyObject* TensorWrapper::shape() {
PyObject* TensorWrapper::dtype() {
if (m_tensor->m_var) {
return py::cast(m_tensor->m_var->dtype()).release().ptr();
}
return py::cast(m_tensor->dtype()).release().ptr();
}
PyObject* TensorWrapper::device() {
if (m_tensor->m_var) {
return py::cast(m_tensor->m_var->comp_node()).release().ptr();
}
return py::cast(m_tensor->comp_node()).release().ptr();
}
PyObject* TensorWrapper::numpy() {
if (!skip_tracing) {
set_value_read(py::cast(true).release().ptr());
}
if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(m_tensor->m_var);
using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
return nullptr;
}
auto* val = mgr.infer_value_fallible(m_tensor->m_var);
if (!val) {
return nullptr;
}
return py::cast(*val).attr("numpy")().release().ptr();
}
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) return nullptr;
......@@ -184,6 +333,13 @@ PyObject* TensorWrapper::numpy() {
return arr.release().ptr();
}
PyObject* TensorWrapper::varnode() {
if (m_tensor->m_var) {
return py::cast(m_tensor->m_var).release().ptr();
}
return nullptr;
}
void TensorWrapper::reset(PyObject* tensor) {
TensorWrapper* t = TensorWrapper::cast_safe(tensor);
if (!t) {
......@@ -195,13 +351,22 @@ void TensorWrapper::reset(PyObject* tensor) {
PyObject* TensorWrapper::detach() {
PyObject* self = wrap_t::pycast(this);
PyTypeObject* pytype = self->ob_type;
auto new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
std::shared_ptr<Tensor> new_tensor;
if (m_tensor->m_handle.get()) {
new_tensor = std::make_shared<Tensor>(m_tensor->m_handle);
} else {
new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
}
auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
return ret.release().ptr();
}
PyObject* TensorWrapper::_dev_tensor(){
if (!skip_tracing) {
set_data_read(py::cast(true).release().ptr());
}
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
return py::cast(dev_tensor).release().ptr();
}
......@@ -227,11 +392,14 @@ PyObject* TensorWrapper::isscalar() {
}
}
void TensorWrapper::setscalar() {
m_tensor->m_flags |= Tensor::Flags::SCALAR;
}
PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr};
struct TensorWeakRef {
std::weak_ptr<Tensor> wptr;
......@@ -262,6 +430,12 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_swap_out>("_swap_out")
.def<&TensorWrapper::_swap_in>("_swap_in")
.def<&TensorWrapper::_drop>("_drop")
.def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read")
.def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read")
.def_getset<&TensorWrapper::shape_read, &TensorWrapper::set_shape_read>("shape_read")
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
.finalize();
if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type);
......@@ -296,6 +470,25 @@ void init_tensor(py::module m) {
if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type);
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward));
m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode);
m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
m.attr("skip_tracing") = &skip_tracing;
m.attr("call_level") = &call_level;
py::class_<SharedHandle>(m, "SharedHandle")
.def(py::init<const SharedHandle&>());
m.def("set_tracing", &set_tracing);
m.def("unset_tracing", &unset_tracing);
m.def("set_symbolic", &set_symbolic);
m.def("unset_symbolic", &unset_symbolic);
m.def("set_compiled", &set_compiled);
m.def("unset_compiled", &unset_compiled);
}
} // namespace mgb::imperative::python
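For orientation, a hedged sketch of how the toggles registered above are meant to be driven from the Python side (the real call sites live in megengine/jit/tracing.py, which is not shown in this diff; run_traced is a hypothetical helper for illustration only):

from megengine.core._imperative_rt.core2 import set_tracing, unset_tracing

def run_traced(fn, *args):  # hypothetical helper, not part of this commit
    set_tracing()  # flips the global is_tracing; py_apply now ORs in Tensor::Flags::TRACE
    try:
        return fn(*args)  # every apply() takes the apply_trace() branch instead of the interpreter
    finally:
        unset_tracing()  # restore eager execution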
......@@ -30,13 +30,10 @@ struct ObjectPtr : B {
} // namespace mgb::imperative::python
#include "./grad_info.h" // for struct GradInfo
#include "./trace_info.h" // for struct TraceInfo
namespace mgb::imperative::python {
struct TraceInfo {
};
extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
class SharedHandle {
......@@ -46,7 +43,9 @@ class SharedHandle {
public:
inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){
interpreter_for_py->del(h);
if (h) {
interpreter_for_py->del(h);
}
}) {}
SharedHandle(const SharedHandle&) = default;
SharedHandle& operator=(const SharedHandle&) = default;
......@@ -71,11 +70,14 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
GradInfo m_grad_info;
TraceInfo m_trace_info;
SharedHandle m_handle;
cg::VarNode* m_var;
using Handle = interpreter::Interpreter::Handle;
inline explicit Tensor(Handle handle) : m_handle(handle) {}
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {}
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {}
~Tensor() = default;
inline std::shared_ptr<Tensor> copy() {
......@@ -83,12 +85,28 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
ret->m_flags = m_flags;
ret->m_grad_info = m_grad_info;
ret->m_trace_info = m_trace_info;
ret->m_var = m_var;
return ret;
}
inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());}
inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());}
inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());}
inline DType dtype() {
if (m_var) {
return m_var->dtype();
}
return interpreter_for_py->get_dtype(m_handle.get());
}
inline CompNode comp_node() {
if (m_var) {
return m_var->comp_node();
}
return interpreter_for_py->get_device(m_handle.get());
}
inline TensorShape shape() {
if (m_var) {
return m_var->shape();
}
return interpreter_for_py->get_shape(m_handle.get());
}
};
......@@ -135,6 +153,19 @@ struct TensorWrapper {
void _swap_in();
void _swap_out();
void _drop();
PyObject* varnode();
PyObject* handle();
void set_handle(PyObject *);
PyObject* data_read();
PyObject* value_read();
PyObject* shape_read();
PyObject* mixin_handle();
void set_data_read(PyObject*);
void set_value_read(PyObject*);
void set_shape_read(PyObject*);
void set_mixin_handle(PyObject*);
};
......@@ -145,6 +176,7 @@ struct ApplyContext {
std::shared_ptr<OpDef> op;
Tensor*const* args;
size_t nargs;
bool backward = false;
};
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
......@@ -153,6 +185,14 @@ apply_result_t apply(ApplyContext& ctx);
void init_tensor(pybind11::module);
extern bool is_tracing;
extern bool is_symbolic;
extern bool is_compiled;
extern int64_t call_level;
extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode;
} // namespace mgb::imperative::python
namespace pybind11::detail {
......
/**
* \file imperative/python/src/trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./trace.h"
#include "./helper.h"
#include "megbrain/imperative/ops/autogen.h"
namespace py = pybind11;
namespace mgb::imperative::python {
apply_result_t apply_tensor_on_var_node(ApplyContext& ctx) {
apply_result_t outputs;
cg::VarNodeArray vinputs(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
vinputs[i] = ctx.args[i]->m_var;
}
auto ovars = OpDef::apply_on_var_node(*ctx.op, vinputs);
for (size_t i = 0; i < ovars.size(); i++) {
outputs.emplace_back(std::make_shared<Tensor>(ovars[i]));
}
return outputs;
}
apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs;
bool run_apply_on_var_node = false;
for (size_t i = 0; i < ctx.nargs; i++) {
run_apply_on_var_node |= ((ctx.args[i]->m_handle.get() == nullptr) & (ctx.args[i]->m_var != nullptr));
}
if (ctx.backward) {
// reach here when symbolic=True or compiled=True
// call megbrain_graph.py apply(BackwardGraph, *args)
auto args = py::tuple(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i] = py::cast(ctx.args[i]->m_var);
}
py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args);
if (!ret) {
throw py::value_error("invalid py object call");
}
// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
auto pitem = tup[i].cast<cg::VarNode *>();
outputs.emplace_back(std::make_shared<Tensor>(pitem));
}
return outputs;
}
if (run_apply_on_var_node && !is_symbolic) {
return apply_tensor_on_var_node(ctx);
}
py::object pyf;
if (is_compiled) {
// run apply in compiled mode, step 2, 3, etc
pyf = cpp_apply_compiled_mode;
} else {
// run first step, both symbolic and non symbolic
pyf = cpp_apply_with_tracing;
}
auto args = py::tuple(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
args[i] = TensorWrapper::make(std::move(std::shared_ptr<Tensor>(ctx.args[i]))).release();
}
auto ret = pyf(py::cast(ctx.op), *args);
// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
auto tw = TensorWrapper::cast_safe(tup[i].ptr());
outputs.emplace_back(tw->m_tensor);
}
return outputs;
}
} // namespace mgb::imperative::python
......@@ -9,9 +9,10 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./tensor.h"
namespace mgb::imperative::python {
struct TraceInfo {
};
apply_result_t apply_trace(ApplyContext& ctx);
} // namespace mgb::imperative::python
/**
* \file imperative/python/src/trace_info.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "inttypes.h"
namespace mgb::imperative::python {
struct TraceInfo {
int64_t mixin_handle = -1;
bool data_read = false;
bool value_read = false;
bool shape_read = false;
};
} // namespace mgb::imperative::python
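These per-tensor flags record, while a trace is active, which observable properties of each tensor the user actually read, so the compiled graph knows what it must materialize (see the TensorWrapper::numpy, shape, and _dev_tensor changes earlier in this commit). A hedged illustration, with f and x hypothetical and an active trace assumed:

y = f(x)               # y is produced inside an active trace
_ = y.numpy()          # TensorWrapper::numpy marks value_read = True (unless skip_tracing)
_ = y.shape            # TensorWrapper::shape marks shape_read = True
dev = y._dev_tensor()  # TensorWrapper::_dev_tensor marks data_read = True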
......@@ -19,8 +19,6 @@ from megengine import tensor
from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
......@@ -32,35 +30,32 @@ def test_trace():
@trace(symbolic=symbolic)
def f(x):
op = ops.Elemwise(Elemwise.Mode.NEGATE)
(y,) = apply(op, x)
return y
return -x
x = as_raw_tensor([1]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
x = tensor([1])
y = f(x).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(f(x).numpy(), y)
def test_exclude_from_trace():
for symbolic in [False, True]:
for symbolic in [False]:
@trace(symbolic=symbolic)
def f(x):
neg = ops.Elemwise(Elemwise.Mode.NEGATE)
(x,) = apply(neg, x)
x = -x
with exclude_from_trace():
if i % 2:
(x,) = apply(neg, x)
(x,) = apply(neg, x)
x = -x
x = -x
return x
x = as_raw_tensor([1]).numpy()
x = tensor([1])
for i in range(3):
y = f.__wrapped__(as_raw_tensor(x)).numpy()
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
y = f(x).numpy()
np.testing.assert_equal(f(x).numpy(), y)
def test_print_in_trace():
......@@ -69,36 +64,33 @@ def test_print_in_trace():
@trace(symbolic=symbolic)
def f(x):
nonlocal buf
neg = ops.Elemwise(Elemwise.Mode.NEGATE)
(x,) = apply(neg, x)
x = -x
buf = x.numpy()
(x,) = apply(neg, x)
x = -x
return x
buf = None
x = as_raw_tensor([1]).numpy()
x = tensor([1])
for i in range(3):
y = f.__wrapped__(as_raw_tensor(x)).numpy()
y = f(x).numpy()
z = buf
buf = None
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(f(x).numpy(), y)
np.testing.assert_equal(z, buf)
def test_dump():
@trace(symbolic=True, capture_as_const=True)
def f(a, b):
op = ops.Elemwise(Elemwise.Mode.ADD)
(y,) = apply(op, a, b)
return y
return a + b
a = as_raw_tensor([2]).numpy()
b = as_raw_tensor([4]).numpy()
y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy()
a = tensor([2])
b = tensor([4])
y = f(a, b).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y)
np.testing.assert_equal(f(a, b).numpy(), y)
file = io.BytesIO()
dump_info = f.dump(file)
......@@ -111,19 +103,17 @@ def test_dump():
def test_capture_dump():
a = as_raw_tensor([2])
a = tensor([2])
@trace(symbolic=True, capture_as_const=True)
def f(x):
op = ops.Elemwise(Elemwise.Mode.MUL)
(y,) = apply(op, x, a)
return y
return x * a
x = as_raw_tensor([3]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
x = tensor([3])
y = f(x).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(f(x).numpy(), y)
file = io.BytesIO()
f.dump(file)
......@@ -133,19 +123,17 @@ def test_capture_dump():
def test_dump_volatile():
p = as_raw_tensor([2])
p = tensor([2])
@trace(symbolic=True, capture_as_const=True)
def f(x):
op = ops.Elemwise(Elemwise.Mode.MUL)
(y,) = apply(op, x, p)
return y
return x * p
x = as_raw_tensor([3]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
x = tensor([3])
y = f(x).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(f(x).numpy(), y)
file = io.BytesIO()
f.dump(file, optimize_for_inference=False)
......@@ -163,21 +151,18 @@ def test_trace_profiler():
@trace(symbolic=symbolic, profiling=True)
def f(x):
op = ops.Elemwise(Elemwise.Mode.NEGATE)
(y,) = apply(op, x)
return y
return -x
x = as_raw_tensor([1]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
x = tensor([1])
y = f(x).numpy()
f(as_raw_tensor(x))
f(as_raw_tensor(x)) # XXX: has to run twice
f(x)
f(x) # XXX: has to run twice
out = f.get_profile()
assert out.get("profiler")
@pytest.mark.skip(reason="force opt_level=0 when building graph")
def test_goptions():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x):
......@@ -196,7 +181,6 @@ def test_goptions():
np.testing.assert_equal(g(d).numpy().item(), 1.0)
@pytest.mark.skip(reason="force opt_level=0 when building graph")
def test_goptions_log_sum_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x, y):
......@@ -256,8 +240,7 @@ def test_optimize_for_inference_broadcast():
@trace(capture_as_const=True, symbolic_shape=True)
def f():
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
return b
return a._broadcast(tensor([1, 10], dtype=np.int32))
f()
f.dump(io.BytesIO())
......@@ -387,7 +370,9 @@ def test_trace_nms():
@trace(symbolic=False)
def f(boxes, scores):
# with tracing, max_output must be specified
results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
# without tracing, max output can be inferred inside nms
with exclude_from_trace():
_ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5)
return results
......
......@@ -318,7 +318,6 @@ def optimize_for_inference(args, outputs):
), "optimize_for_inference should be set when {} is given".format(k)
kwargs[v] = True
outputs = [G.VarNode(output) for output in outputs]
if args.optimize_for_inference:
outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)]
......
......@@ -84,7 +84,7 @@ def main():
minibatch = next(val_dataset)
net.eval()
_, loss = val_fun(data, label)
loss = loss.numpy()[0]
loss = loss.numpy()
val_loss.append((step, loss))
print("Step: {} loss={}".format(step, loss))
opt.step()
......