提交 d4bad711 编写于 作者: M Megvii Engine Team

feat(mge): add jit.trace

GitOrigin-RevId: ec647324c0e207b6185efe118b61a094c959ce7f
上级 0b88ec3c
......@@ -17,15 +17,31 @@ from ..ops.builtin import OpDef
from .core import OpBase, TensorBase, apply
class CompiledFunction:
def __init__(self, graph, function):
self._graph = graph
self._function = function
class Graph(_imperative_rt.ComputingGraph):
def __init__(self):
self._var_cache = weakref.WeakKeyDictionary()
self._op_cache = weakref.WeakKeyDictionary()
self._executor = ThreadPoolExecutor(1)
self._function = None
self._future = None
def _wrap(self, obj):
    """Return the Python wrapper associated with a raw graph node.

    ``obj`` is dispatched on its exact type: a ``_imperative_rt.VarNode``
    is wrapped in ``VarNode`` and looked up in ``self._var_cache``; a
    ``_imperative_rt.OperatorNode`` is wrapped in ``OpNode`` and looked up
    in ``self._op_cache``. A wrapper is created only when the node is not
    in its cache yet, so repeated calls return the same wrapper object.

    Raises:
        TypeError: if ``obj`` is neither of the two supported node types.
    """
    kind = type(obj)
    # Exact type identity (not isinstance) is what selects the wrapper.
    if kind is _imperative_rt.VarNode:
        wrapper, cache = VarNode, self._var_cache
    elif kind is _imperative_rt.OperatorNode:
        wrapper, cache = OpNode, self._op_cache
    else:
        # Without this branch the lookup below would fail with an
        # UnboundLocalError that hides the real problem.
        raise TypeError(
            "cannot wrap object of type {!r}".format(kind.__name__)
        )
    if obj not in cache:
        cache[obj] = wrapper(obj)
    return cache[obj]
def compile(self, *args):
    """Compile the graph for the given outputs and keep the result.

    The arguments are passed through ``_unwrap`` before being handed to the
    base-class ``compile``; the compiled function is stored in
    ``self._function``.

    Returns:
        This graph object itself.
    """
    raw_args = _unwrap(args)
    compiled = super().compile(raw_args)
    self._function = compiled
    return self
def execute(self, *args):
    """Schedule one run of the compiled function on this object's executor.

    ``self._function.execute`` is submitted with ``*args`` to
    ``self._executor`` and the resulting future is stored in
    ``self._future``. Nothing is returned.

    Raises:
        AssertionError: if ``self._future`` is not ``None``, i.e. a
            previously submitted run is still recorded.
    """
    assert self._future is None
    # Submit exactly once, through the executor owned by this object.
    # Going through ``self._graph`` as well would enqueue the same job
    # twice and requires an attribute that is never assigned here.
    self._future = self._executor.submit(self._function.execute, *args)
def wait(self):
assert self._future is not None
......@@ -40,30 +56,23 @@ class CompiledFunction:
return self.wait()
def make_const(self, data, dtype=None, device=None):
    """Create a constant variable in this graph and return its wrapper.

    When ``data`` is a ``_imperative_rt.DeviceTensorND`` it is handed to
    ``_imperative_rt.make_shared``; in that case ``dtype`` and ``device``
    must both be ``None`` (checked with an assertion). Any other ``data``
    is passed to ``_imperative_rt.make_const`` together with ``dtype`` and
    the device obtained from ``as_device(device).to_c()``.

    Returns:
        The result of ``self._wrap`` applied to the created node.
    """
    if not isinstance(data, _imperative_rt.DeviceTensorND):
        c_device = as_device(device).to_c()
        node = _imperative_rt.make_const(self, data, c_device, dtype)
        return self._wrap(node)

    # A device tensor already carries dtype and device information.
    assert dtype is None and device is None
    shared = _imperative_rt.make_shared(self, data)
    return self._wrap(shared)
class Graph(_imperative_rt.ComputingGraph):
def __init__(self):
self._var_cache = weakref.WeakKeyDictionary()
self._op_cache = weakref.WeakKeyDictionary()
self._executor = ThreadPoolExecutor(1)
def _wrap(self, obj):
    """Return the Python wrapper associated with a raw graph node.

    A ``_imperative_rt.VarNode`` maps to ``VarNode`` (cached in
    ``self._var_cache``) and a ``_imperative_rt.OperatorNode`` maps to
    ``OpNode`` (cached in ``self._op_cache``). The match is on the exact
    type of ``obj``. A wrapper is only constructed when the node is not
    already present in the corresponding cache.

    Raises:
        TypeError: if ``obj`` is neither of the two supported node types.
    """
    kind = type(obj)
    if kind is _imperative_rt.VarNode:
        wrapper, cache = VarNode, self._var_cache
    elif kind is _imperative_rt.OperatorNode:
        wrapper, cache = OpNode, self._op_cache
    else:
        # Fail with a clear message instead of reaching the cache lookup
        # with ``wrapper`` and ``cache`` unbound.
        raise TypeError(
            "cannot wrap object of type {!r}".format(kind.__name__)
        )
    if obj not in cache:
        cache[obj] = wrapper(obj)
    return cache[obj]
def compile(self, *args):
    """Compile the graph and wrap the result in a ``CompiledFunction``.

    The arguments go through ``_unwrap`` before being given to the
    base-class ``compile``.

    Returns:
        A ``CompiledFunction`` built from this graph and the compiled
        function.
    """
    compiled = super().compile(_unwrap(args))
    return CompiledFunction(self, compiled)
def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
    """Add an ``InputNode`` to this graph and return its first output.

    All positional arguments and the ``device``, ``dtype`` and ``shape``
    keywords are forwarded to ``InputNode`` unchanged, with ``graph`` set
    to this graph.
    """
    input_node = InputNode(
        *args, graph=self, device=device, dtype=dtype, shape=shape
    )
    return input_node.outputs[0]
class VarNode(TensorBase):
def __init__(self, node: _imperative_rt.VarNode):
self._node = node
self.graph._var_cache[node] = self
def graph(self) -> Graph:
......@@ -81,10 +90,15 @@ class VarNode(TensorBase):
def device(self):
return as_device(self._node.comp_node)
def shape(self):
return self._node.shape
class OpNode:
def __init__(self, node: _imperative_rt.OperatorNode):
self._node = node
self.graph._op_cache[node] = self
def graph(self) -> Graph:
......@@ -117,21 +131,21 @@ def _(op: OpDef, *args: VarNode):
return _wrap(outputs)
def input_callback(callback, *args, device=None, dtype=None, graph=None):
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
outputs = _imperative_rt.input_callback(
callback, as_device(device).to_c(), dtype, _unwrap(args), graph=graph
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
value, dummy = _wrap(outputs)
return value, dummy
class InputNode(OpNode):
def __init__(self, *args: VarNode, device=None, dtype=None, graph=None):
def __init__(self, *args: VarNode, device=None, dtype=None, shape=None, graph=None):
r = _imperative_rt.DeviceTensorNDRendezvous()
if device is not None:
device = as_device(device).to_c()
outputs = _imperative_rt.input_callback(
r, device, dtype, _unwrap(args), graph=graph
r, device, dtype, shape, _unwrap(args), graph=graph
self._rendezvous = r
......@@ -169,6 +183,29 @@ class OutputNode(OpNode):
def get_value(self):
return self._rendezvous.get()
def drop_value(self):
def reset(self):
class ValueOutputNode(OpNode):
def __init__(self, var, *args):
args = (var,) + args
r = _imperative_rt.HostTensorNDRendezvous()
dummy = _imperative_rt.value_output_callback(r, _unwrap(args))
self._rendezvous = r
def get_value(self):
hostnd, event = self._rendezvous.get()
return hostnd.numpy()
def drop_value(self):
def reset(self):
......@@ -192,5 +229,8 @@ class AttrOutputNode(OpNode):
attr = self._rendezvous.get()
return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node))
def drop_value(self):
def reset(self):
......@@ -31,11 +31,13 @@ class RawTensor(TensorBase):
_init_cb = None
_del_cb = None
_handle = None
def __init__(self, handle):
def __init__(self, handle=None):
self._handle = handle
if self._init_cb:
if handle is not None:
if self._init_cb:
def dtype(self):
......@@ -61,9 +63,10 @@ class RawTensor(TensorBase):
def __del__(self):
if self._del_cb:
if self._handle is not None:
if self._del_cb:
......@@ -89,6 +92,11 @@ def as_raw_tensor(obj, dtype=None, device=None):
return as_raw_tensor(obj, device=device)
def _(data: DeviceTensorND):
return RawTensor(put(data))
def _(array: np.ndarray, dtype=None, device=None):
device = None if device is None else as_device(device).to_c()
from .tracing import exclude_from_trace, trace
import contextlib
import functools
import typing
import weakref
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
class TraceMismatchError(RuntimeError):
active_trace = None
skip_tracing = False
def exclude_from_trace():
global skip_tracing
if skip_tracing:
skip_tracing = True
if active_trace is not None:
skip_tracing = False
class TensorInfo:
__slots__ = (
# collected attributes
# resources for execution
def __init__(self):
self.exported = None
self.data_read = None
self.shape_read = None
self.value_read = None
self.bound_data = None
self.data_setter = None
self.shape_reader = None
self.value_reader = None
self.data_reader = None
class trace:
def __new__(cls, *args, **kwargs):
    """Create an instance, or defer creation when no positional argument is given.

    Called with keyword arguments only, this returns
    ``functools.partial(cls, **kwargs)`` so that the positional arguments
    can be supplied by a later call. Otherwise a new object is allocated,
    ``__init__`` is invoked on it explicitly with the given arguments, and
    the object is returned.
    """
    if args:
        instance = super().__new__(cls)
        instance.__init__(*args, **kwargs)
        return instance
    return functools.partial(cls, **kwargs)
def __init__(self, function, symbolic=False, capture_as_const=False):
    """Store the wrapped function and reset all tracing state.

    Args:
        function: the callable being wrapped; kept as ``__wrapped__``.
        symbolic: stored as ``_symbolic``.
        capture_as_const: stored as ``_capture_as_const``.
    """
    # --- configuration -------------------------------------------------
    self.__wrapped__ = function
    self._symbolic = symbolic
    self._capture_as_const = capture_as_const
    self._capture_static_shape = False

    # --- recorded trace ------------------------------------------------
    self._untraced = True
    self._tinfo = []  # handle -> TensorInfo
    self._seq = []
    self._pc = 0

    # --- graphs and bookkeeping, filled in later -----------------------
    self._graph = None
    self._need_reset_nodes = None
    self._lazy_eval_graph = None

    # Weak sets: membership does not keep the tensors alive.
    self._lazy_eval_tensors = weakref.WeakSet()
    self._active_tensors = weakref.WeakSet()
def _new_handle(self):
    """Allocate a new ``TensorInfo`` record and return ``(handle, info)``.

    The handle is the index of the record in ``self._tinfo``.
    """
    handle = len(self._tinfo)
    info = TensorInfo()
    # Store the record under its handle. Without this the list never grows,
    # every handle would be 0 and ``self._tinfo[handle]`` lookups elsewhere
    # in this class could not find the record.
    self._tinfo.append(info)
    return handle, info
def _apply_op(self, op, args):
assert not self._untraced
# check against trace
if self._pc >= len(self._seq):
raise TraceMismatchError("trace should end here, but more op observed")
record = self._seq[self._pc]
op_, ihandles, ohandles = record
if op != op_:
raise TraceMismatchError("op different from last time")
if len(ihandles) != len(args):
raise TraceMismatchError("op input size different from last time")
for h, x in zip(ihandles, args):
info = self._tinfo[h]
if info.external:
if (
x.__class__ is CompiledTensorProxy
and not self._tinfo[x._CompiledTensorProxy__handle].exported
raise TraceMismatchError(
"failed to capture: input was an external tensor "
"last time, got an internal tensor this time"
if info.bound_data:
if x.__class__ is CompiledTensorProxy:
raise TraceMismatchError(
"const capture violated: was an external tensor "
"last time, got an internal tensor this time"
if x._handle != info.bound_data._handle:
raise TraceMismatchError(
"const capture violated: got "
"a different tensor this time"
if info.dtype != x.dtype:
raise TraceMismatchError(
"failed to capture: different dtype from last time"
if info.device != x.device:
raise TraceMismatchError(
"failed to capture: different device from last time"
if x.__class__ is not CompiledTensorProxy:
raise TraceMismatchError(
"unexpected capture: trying to use an external tensor as input, "
"but that input was an internal tensor last time"
if x._CompiledTensorProxy__handle != h:
raise TraceMismatchError(
"mis-wiring: input edge to an data flow "
"graph node is different from last time"
self._pc += 1
outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
return outputs
def _record_op(self, op, inputs, outputs):
if skip_tracing:
for x in inputs:
h = getattr(x, "_TraceMixin__handle", None)
if h is not None:
self._tinfo[h].data_read = True
ihandles = []
for x in inputs:
h = getattr(x, "_TraceMixin__handle", None)
if h is None or (not self._capture_as_const and self._tinfo[h].exported):
h, info = self._new_handle()
info.external = True
info.device = x.device
info.dtype = x.dtype
if self._capture_as_const:
info.bound_data = x
ohandles = []
for x in outputs:
h, info = self._new_handle()
info.external = False
TraceMixin._TraceMixin__inject(x, h)
self._seq.append((op, tuple(ihandles), tuple(ohandles)))
def _setup(self):
global active_trace
if active_trace:
raise NotImplementedError("sorry, not implemented: nested trace")
active_trace = self
if self._untraced:
if self._symbolic:
self._lazy_eval_graph = G.Graph()
if self._graph is None:
escaped_tensors = tuple(self._active_tensors)
if self._untraced:
for x in escaped_tensors:
info = self._tinfo[x._TraceMixin__handle]
info.data_read = True
if self._symbolic:
# eval lazy eval tensors
lazy_eval_tensors = tuple(self._lazy_eval_tensors)
if lazy_eval_tensors:
readers = [
for x in lazy_eval_tensors
for r, x in zip(readers, lazy_eval_tensors):
assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
self._lazy_eval_graph = None
self._lazy_eval_tensors = None
self._untraced = False
if self._pc != len(self._seq):
raise TraceMismatchError("premature end")
for x in escaped_tensors:
assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
self._pc = 0
active_trace = None
def _begin_excluded_region(self):
if self._untraced:
# conditionally reading a compiled tensor in excluded region
# is permitted, so we have to assume every tensor might be read
for x in self._active_tensors:
info = self._tinfo[x._TraceMixin__handle]
info.exported = True
info.data_read = True
def _compile(self):
graph = self._graph = G.Graph()
# graph.options.graph_opt_level = 0
need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes
links = ()
for op, ihandles, ohandles in self._seq:
ivars = []
readers = []
for h in ihandles:
info = self._tinfo[h]
if not hasattr(info, "varnode"):
assert info.external
if info.bound_data:
info.varnode = graph.make_const(info.bound_data._dev_tensor())
opnode = info.data_setter = G.InputNode(
*links, device=info.device, dtype=info.dtype, graph=graph
info.varnode, *links = opnode.outputs
ovars = apply(op, *ivars)
assert len(ovars) == len(ohandles)
for h, v in zip(ohandles, ovars):
info = self._tinfo[h]
info.varnode = v
def add_reader(opnode):
nonlocal links
links = opnode.outputs
if info.data_read:
# Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately
# to leverage eager h2d copy
info.shape_read = False
opnode = info.data_reader = G.OutputNode(v, *links)
if info.value_read:
opnode = info.value_reader = G.ValueOutputNode(v, *links)
if info.shape_read:
opnode = info.shape_reader = G.AttrOutputNode(v, *links)
def _reset_exec_env(self):
for opnode in self._need_reset_nodes:
def _require_shape(self, handle):
    """Flag the tensor record stored under ``handle`` as having its shape read."""
    self._tinfo[handle].shape_read = True
def _require_value(self, handle):
    """Flag the tensor record stored under ``handle`` as having its value read."""
    self._tinfo[handle].value_read = True
def _require_data(self, handle):
    """Flag the tensor record stored under ``handle`` as having its data read."""
    self._tinfo[handle].data_read = True
def __call__(self, *args, **kwargs):
    """Invoke the wrapped function inside the context from ``self._setup()``.

    All arguments are forwarded unchanged and the wrapped function's return
    value is returned.
    """
    with self._setup():
        result = self.__wrapped__(*args, **kwargs)
    return result
class CompiledTensorProxy(RawTensor):
Duck-typed RawTensor
def __init__(self, handle):
    """Bind this proxy to the tensor record of ``handle`` in the active trace.

    The record is looked up in ``active_trace._tinfo``. The cached shape,
    data and value all start out as ``None``.
    """
    self.__handle = handle
    self.__info = active_trace._tinfo[handle]
    # Nothing has been fetched yet.
    self.__shape = self.__data = self.__value = None
def dtype(self):
return self.__info.varnode.dtype
def device(self):
return self.__info.varnode.device
def shape(self):
    """Return the shape of this tensor, fetching and caching it on first use.

    The shape comes from the shape reader when the record has
    ``shape_read`` set, otherwise from the device tensor when ``data_read``
    is set.

    Raises:
        TraceMismatchError: if neither the shape nor the data of this
            tensor was read in the trace.
    """
    if self.__shape is None:
        if self.__info.shape_read:
            self.__shape = self.__info.shape_reader.get_value().shape
        elif self.__info.data_read:
            self.__shape = self._dev_tensor().shape
        else:
            # Only an error when no source for the shape exists; raising
            # unconditionally would discard the shape just fetched.
            raise TraceMismatchError("shape of this tensor is not read in trace")
    return self.__shape
def numpy(self):
    """Return the value of this tensor, fetching and caching it on first use.

    The value comes from the value reader when the record has
    ``value_read`` set, otherwise from ``self._dev_tensor().numpy()`` when
    ``data_read`` is set.

    Raises:
        TraceMismatchError: if neither the value nor the data of this
            tensor was read in the trace.
    """
    if self.__value is None:
        if self.__info.value_read:
            self.__value = self.__info.value_reader.get_value()
        elif self.__info.data_read:
            self.__value = self._dev_tensor().numpy()
        else:
            # Only an error when no source for the value exists; raising
            # unconditionally would discard the value just fetched.
            raise TraceMismatchError("value of this tensor is not read in trace")
    return self.__value
def _dev_tensor(self):
    """Return the raw data of this tensor, fetching and caching it on first use.

    Raises:
        TraceMismatchError: if the data of this tensor was not read in the
            trace.
    """
    if self.__data is not None:
        return self.__data
    if not self.__info.data_read:
        raise TraceMismatchError("raw data of this tensor is not read in trace")
    self.__data = self.__info.data_reader.get_value()
    return self.__data
def __del__(self):
if self.__info.shape_read and self.__shape is not None:
if self.__info.value_read and self.__value is not None:
if self.__info.data_read and self.__data is not None:
class LazyEvalTensor(RawTensor):
def __init__(self, varnode):
self.__varnode = varnode
def dtype(self):
return self.__varnode.dtype
def device(self):
return self.__varnode.device
def shape(self):
return self.__varnode.shape
def numpy(self):
raise RuntimeError("cannot read value during symbolic tracing")
def _dev_tensor(self):
raise RuntimeError("cannot access data during symbolic tracing")
class TraceMixin:
__subclass_cache = {}
def __inject(self, handle):
    """Switch ``self`` to a traced subclass of its current class.

    The traced subclass is named ``"Traced" + <original class name>``,
    derives from this mixin class and the original class, and is created
    once per original class (kept in the subclass cache). The handle and
    the original class are recorded on the object.

    Returns:
        ``self``, now an instance of the traced subclass.
    """
    original_cls = self.__class__
    registry = __class__.__subclass_cache
    traced_cls = registry.get(original_cls)
    if traced_cls is None:
        traced_cls = type(
            "Traced" + original_cls.__name__, (__class__, original_cls), {}
        )
        registry[original_cls] = traced_cls
    self.__class__ = traced_cls
    self.__handle = handle
    self.__cls = original_cls
    return self
def __restore(self):
    """Switch ``self`` back to the class recorded when it was injected.

    The recorded handle and original class are removed from the object
    before its class is reassigned.

    Returns:
        ``self``, again an instance of its original class.
    """
    original_cls = self.__cls
    # Drop the tracing attributes first, then switch the class back.
    del self.__handle
    del self.__cls
    self.__class__ = original_cls
    return self
def shape(self):
if not skip_tracing:
return super().shape
def numpy(self):
if not skip_tracing:
return super().numpy()
def _dev_tensor(self):
if not skip_tracing:
return super()._dev_tensor()
class TracedRawTensor(TraceMixin, RawTensor):
class TracedLazyTensor(TraceMixin, LazyEvalTensor):
def assign_raw_tensor(lhs, rhs):
handle = rhs._handle
lhs.__class__ = RawTensor
# this hook turns RawTensor into LazyEvalTensor
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ivars = [
getattr(x, "_LazyEvalTensor__varnode", None)
or graph.make_const(x._dev_tensor())
for x in args
ovars = apply(op, *ivars)
outputs = [LazyEvalTensor(v) for v in ovars]
return outputs
def apply_compiled_mode(op: OpDef, *args: RawTensor):
if skip_tracing:
args = [
as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
for x in args
return apply.super(op, *args)
return active_trace._apply_op(op, args)
# this hook injects TraceMixin
def apply_with_tracing(op: OpDef, *args: RawTensor):
outputs = apply.super(op, *args)
active_trace._record_op(op, args, outputs)
return outputs
# @apply.register()
# def _(op: Const, *args: RawTensor):
# return active_trace._apply_const(op, args)
class BrokenRawTensor(RawTensor):
    """A ``RawTensor`` subclass on which every attribute access fails.

    Both reading and assigning any attribute of an instance raise
    ``RuntimeError``.
    """

    def __getattribute__(self, _):
        message = "broken due to misuse of tracing"
        raise RuntimeError(message)

    def __setattr__(self, *_):
        message = "broken due to misuse of tracing"
        raise RuntimeError(message)
......@@ -23,10 +23,29 @@ namespace py = pybind11;
using namespace mgb;
using namespace imperative;
namespace {
template<typename XTensorND>
auto def_TensorND(py::object parent, const char* name) {
return py::class_<XTensorND>(parent, name)
.def_property_readonly("shape", py::overload_cast<>(&XTensorND::shape, py::const_))
.def_property_readonly("dtype", py::overload_cast<>(&XTensorND::dtype, py::const_))
.def_property_readonly("comp_node", py::overload_cast<>(&XTensorND::comp_node, py::const_))
.def("copy_from", &XTensorND::template copy_from<DeviceTensorStorage>)
.def("copy_from", &XTensorND::template copy_from<HostTensorStorage>)
.def("copy_from_fixlayout", py::overload_cast<const DeviceTensorND&>(
&XTensorND::template copy_from_fixlayout<DeviceTensorStorage>))
.def("copy_from_fixlayout", py::overload_cast<const HostTensorND&>(
&XTensorND::template copy_from_fixlayout<HostTensorStorage>));
} // namespace
void init_common(py::module m) {
py::class_<CompNode>(m, "CompNode")
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
.def("__str__", &CompNode::to_string_logical)
.def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self)
......@@ -40,19 +59,30 @@ void init_common(py::module m) {
return CompNode::load(cn);
py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event")
.def("record", &CompNode::Event::record)
.def("wait", &CompNode::Event::host_wait);
py::implicitly_convertible<std::string, CompNode>();
py::class_<DeviceTensorND>(m, "DeviceTensorND")
.def_property_readonly("shape", py::overload_cast<>(&DeviceTensorND::shape, py::const_))
.def_property_readonly("dtype", py::overload_cast<>(&DeviceTensorND::dtype, py::const_))
.def_property_readonly("comp_node", py::overload_cast<>(&DeviceTensorND::comp_node, py::const_))
def_TensorND<DeviceTensorND>(m, "DeviceTensorND")
.def("numpy", [](const DeviceTensorND& self) {
HostTensorND hv;
return py::handle(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
def_TensorND<HostTensorND>(m, "HostTensorND")
.def(py::init([](py::array data, CompNode cn, DType dtype) {
if (!cn.valid()) {
throw py::type_error("device must not be None");
return npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
.def("numpy", [](const HostTensorND& self) {
return py::reinterpret_steal<py::object>(npy::ndarray_from_tensor(self, npy::ShareType::TRY_SHARE));
py::class_<cg::OperatorNodeConfig>(m, "OperatorNodeConfig")
......@@ -12,6 +12,7 @@
#include "./graph_rt.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative.h"
#include "./helper.h"
......@@ -29,29 +30,44 @@ auto def_rendezvous(py::object m, const char* name) {
.def(py::init([](){return std::make_shared<Rendezvous<T>>();}))
.def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));})
.def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>())
.def("drop", &Rendezvous<T>::drop)
.def("reset", &Rendezvous<T>::reset);
using TensorAttr = LogicalTensorDesc;
using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
void init_graph_rt(py::module m) {
def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");
py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
.def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
.def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
.def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_))
.def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
.def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();});
.def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
.def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
auto&& mgr = v->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(v);
using InferType = cg::static_infer::InferType;
if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
return nullptr;
return mgr.infer_shape_fallible(v);
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
.def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
.def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_))
.def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
return to_tuple(opr->input());
.def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
return to_tuple(opr->output());
return to_tuple(opr->usable_output());
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
......@@ -117,7 +133,7 @@ void init_graph_rt(py::module m) {
common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
auto opr = OpDef::apply_on_var_node(def, vinputs);
auto outputs = opr->output();
auto outputs = opr->usable_output();
return to_tuple(outputs);
py::arg(), py::arg(), py::arg("graph") = py::none());
......@@ -125,6 +141,7 @@ void init_graph_rt(py::module m) {
auto input_callback = [](auto callback,
const CompNode& comp_node,
const DType& dtype,
const TensorShape& shape,
const std::vector<cg::VarNode*>& inputs,
cg::ComputingGraph* graph) {
if (!graph) {
......@@ -135,7 +152,7 @@ void init_graph_rt(py::module m) {
auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, sinputs);
auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, shape, sinputs);
std::vector<VarNode*> outputs;
for (auto i : soutputs) {
......@@ -144,26 +161,40 @@ void init_graph_rt(py::module m) {
return outputs;
m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) {
return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
if (!cn.valid()) {
throw py::type_error("device must not be None");
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
const CompNode& comp_node,
const DType& dtype,
const TensorShape& shape,
const std::vector<cg::VarNode*>& inputs,
cg::ComputingGraph* graph) {
return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, inputs, graph);
return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, shape, inputs, graph);
py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none());
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p,
const CompNode& comp_node,
const DType& dtype,
const TensorShape& shape,
const std::vector<cg::VarNode*>& inputs,
cg::ComputingGraph* graph) {
auto f = [p]() -> DeviceTensorND {
return p->get();
return input_callback(std::move(f), comp_node, dtype, inputs, graph);
return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph);
py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none());
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) {
SymbolVarArray sinputs;
......@@ -193,6 +224,17 @@ void init_graph_rt(py::module m) {
return output_callback(std::move(f), std::move(inputs));
m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) {
auto f = [p](DeviceTensorND dv) {
HostNDWithEvent hv_with_event;
hv_with_event.second = dv.comp_node().create_event();
return output_callback(std::move(f), std::move(inputs), true);
m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
auto f = [p](DeviceTensorND dv) {
p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
......@@ -39,6 +39,7 @@ template<typename R>
class Rendezvous {
std::mutex m_lock;
int m_read_ahead = 0;
bool m_drop_next = false;
std::promise<R> m_promise;
Rendezvous() = default;
......@@ -47,6 +48,7 @@ public:
Rendezvous& operator=(const Rendezvous& rhs) = delete;
Rendezvous& operator=(Rendezvous&& rhs) {
m_drop_next = rhs.m_drop_next;
m_read_ahead = rhs.m_read_ahead;
m_promise = std::move(rhs.m_promise);
return *this;
......@@ -67,12 +69,28 @@ public:
return f.get();
void drop() {
mgb_assert(m_read_ahead <= 0);
mgb_assert(m_read_ahead >= -1);
if (m_read_ahead == -1) {
m_promise = {};
} else {
m_drop_next = true;
template<typename T>
void set(T&& value) {
mgb_assert(m_read_ahead >= 0);
mgb_assert(m_read_ahead <= 1);
if (m_drop_next) {
m_drop_next = false;
} else {
if (m_read_ahead == 1) {
m_promise = {};
......@@ -83,6 +101,7 @@ public:
m_promise = {};
m_read_ahead = 0;
m_drop_next = false;
......@@ -280,9 +280,12 @@ namespace detail {
bool load(handle src, bool convert) {
auto obj = reinterpret_steal<object>(src);
if (!isinstance<tuple>(obj)) {
if (!convert && !isinstance<tuple>(obj)) {
return false;
if (obj.is_none()) {
return true;
value.ndim = len(obj);
mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM);
size_t i = 0;
......@@ -63,6 +63,7 @@ void init_imperative_rt(py::module m) {
return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
}, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none())
.def("put", py::overload_cast<const DeviceTensorND&>(&Interpreter::Channel::put))
.def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) {
return self.del(handle);
......@@ -24,6 +24,12 @@ constexpr bool has_fastcall = true;
constexpr bool has_fastcall = false;
constexpr bool has_vectorcall = true;
constexpr bool has_vectorcall = false;
template<typename... Args>
struct invocable_with {
template<typename T>
......@@ -55,6 +61,9 @@ private:
std::aligned_storage_t<sizeof(T), alignof(T)> storage;
PyObject* vectorcall_slot;
inline T* inst() {
return reinterpret_cast<T*>(&storage);
......@@ -155,6 +164,51 @@ private:
// polyfills
struct tp_vectorcall {
static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall);
static constexpr bool haskw = [](){if constexpr (valid)
if constexpr (std::is_invocable_v<T::tp_vectorcall, T, PyObject*const*, size_t, PyObject*>)
return true;
return false;}();
template<typename = void>
static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (haskw) {
CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames));
} else {
if (kwnames && PyTuple_GET_SIZE(kwnames)) {
PyErr_SetString(PyExc_TypeError, "expect no keyword argument");
return nullptr;
CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf));
static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot);
else return 0;}();
struct tp_call {
static constexpr bool provided = HAS_MEMBER(T, tp_call);
static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
[](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
static constexpr bool valid = provided || tp_vectorcall::valid;
template<typename = void>
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
else if constexpr (provided) return impl<>;
else if constexpr (valid) return PyVectorcall_Call;
else return nullptr;}();
struct tp_new {
static constexpr bool provided = HAS_MEMBER(T, tp_new);
static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
......@@ -163,11 +217,14 @@ private:
template<typename = void>
static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
auto* self = type->tp_alloc(type, 0);
auto* ptr = reinterpret_cast<wrap_t*>(self)->inst();
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (has_vectorcall && tp_vectorcall::valid) {
reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
if constexpr (varkw) {
new(ptr) T(args, kwargs);
new(inst) T(args, kwargs);
} else {
new(ptr) T();
new(inst) T();
return self;
......@@ -190,22 +247,6 @@ private:
else return impl<>;}();
struct tp_call {
static constexpr bool valid = HAS_MEMBER(T, tp_call);
static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
[](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
template<typename = void>
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
else if constexpr (valid) return impl<>;
else return nullptr;}();
class TypeBuilder {
std::vector<PyMethodDef> m_methods;
......@@ -228,9 +269,17 @@ public:
m_type.tp_name = T::tp_name;
m_type.tp_dealloc = tp_dealloc::value;
m_type.tp_vectorcall_offset = tp_vectorcall::offset;
m_type.tp_call = tp_call::value;
m_type.tp_basicsize = sizeof(wrap_t);
if constexpr (tp_vectorcall::valid) {
m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
m_type.tp_new = tp_new::value;
import numpy as np
from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import exclude_from_trace, trace
def test_trace():
for symbolic in [False, True]:
def f(x):
op = ops.Elemwise(mode="negate")
(y,) = apply(op, x)
return y
x = as_raw_tensor([1]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
def test_exclude_from_trace():
for symbolic in [False, True]:
def f(x):
neg = ops.Elemwise(mode="negate")
(x,) = apply(neg, x)
with exclude_from_trace():
if i % 2:
(x,) = apply(neg, x)
(x,) = apply(neg, x)
return x
x = as_raw_tensor([1]).numpy()
for i in range(3):
y = f.__wrapped__(as_raw_tensor(x)).numpy()
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
def test_print_in_trace():
for symbolic in [False]: # cannot read value in symbolic mode
def f(x):
nonlocal buf
neg = ops.Elemwise(mode="negate")
(x,) = apply(neg, x)
buf = x.numpy()
(x,) = apply(neg, x)
return x
buf = None
x = as_raw_tensor([1]).numpy()
for i in range(3):
y = f.__wrapped__(as_raw_tensor(x)).numpy()
z = buf
buf = None
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(z, buf)
......@@ -37,6 +37,15 @@ void* ChannelImpl::put(const HostTensorND& value) {
return info;
void* ChannelImpl::put(const DeviceTensorND& data) {
auto info = alloc();
info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node();
info->ptr = Tensor::make(data);
return info;
void ChannelImpl::del(void* handle) {
mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
......@@ -55,6 +55,7 @@ struct ChannelImpl : Interpreter::Channel {
~ChannelImpl() override;
Handle put(const HostTensorND& value) override;
Handle put(const DeviceTensorND& value) override;
void del(Handle) override;
......@@ -31,9 +31,10 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback);
InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
const VarNodeArray& inputs,
const TensorShape& output_shape,
const OperatorNodeConfig& config)
: Super(&graph, config, "input_callback", inputs),
m_callback(callback) {
m_output_shape(output_shape), m_callback(callback) {
for (VarNode* i : inputs) {
......@@ -48,7 +49,8 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
callback_t callback, CompNode comp_node,
DType dtype, const SymbolVarArray& inputs) {
DType dtype, const TensorShape& shape,
const SymbolVarArray& inputs) {
OperatorNodeConfig config;
......@@ -56,11 +58,22 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
auto vinputs = to_var_node_array(inputs);
auto opr = graph.insert_opr(
std::make_unique<InputCallback>(graph, callback, vinputs, config));
std::make_unique<InputCallback>(graph, callback, vinputs, shape, config));
return to_symbol_var_array(opr->output());
void InputCallback::init_output_static_infer_desc() {}
void InputCallback::init_output_static_infer_desc() {
if (m_output_shape.ndim) {
using namespace cg::static_infer;
auto &&mgr = owner_graph()->static_infer_manager();
auto infer_shape = [this](TensorShape &dest, const InpVal &) {
dest = m_output_shape;
return true;
{SourceType::CONSTANT, {}, infer_shape});
cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
NodeProp* prop = Super::do_make_node_prop();
......@@ -73,9 +86,23 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
void InputCallback::scn_do_execute() {
auto dev_tensor = m_callback();
if (m_output_shape.ndim) {
cg::OperatorNodeBase* InputCallback::shallow_copy(
const serialization::OprShallowCopyContext &ctx,
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
const OperatorNodeConfig &config) {
auto &&opr = opr_.cast_final_safe<InputCallback>();
auto* graph = ctx.owner_graph(opr, inputs);
return graph->insert_opr(std::make_unique<InputCallback>(*graph, opr.m_callback, inputs, opr.m_output_shape, config));
MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy);
/* ================ OutputCallback ================== */
......@@ -122,6 +149,17 @@ void OutputCallback::scn_do_execute() {
cg::OperatorNodeBase* OutputCallback::shallow_copy(
const serialization::OprShallowCopyContext &ctx,
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
const OperatorNodeConfig &config) {
auto &&opr = opr_.cast_final_safe<OutputCallback>();
auto* graph = ctx.owner_graph(opr, inputs);
return graph->insert_opr(std::make_unique<OutputCallback>(opr.m_param, inputs, config));
MGB_REG_OPR_SHALLOW_COPY(OutputCallback, OutputCallback::shallow_copy);
/* ================ NopCallback ================== */
......@@ -22,6 +22,7 @@ struct Interpreter {
virtual ~Channel() = default;
virtual Handle put(const HostTensorND& value) = 0;
virtual Handle put(const DeviceTensorND& value) = 0;
virtual void del(Handle) = 0;
......@@ -17,6 +17,7 @@
#include "megbrain/opr/internal/param_tag_defs.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/oprs/utils.h"
......@@ -33,17 +34,24 @@ public:
InputCallback(cg::ComputingGraph& graph,
callback_t callback,
const VarNodeArray& inputs,
const TensorShape& output_shape,
const OperatorNodeConfig &config);
static SymbolVarArray make(cg::ComputingGraph& graph,
callback_t callback,
CompNode comp_node,
DType dtype,
const TensorShape& shape,
const SymbolVarArray& inputs = {});
static cg::OperatorNodeBase* shallow_copy(
const serialization::OprShallowCopyContext &ctx,
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
const OperatorNodeConfig &config);
void scn_do_execute() override;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
TensorShape m_output_shape;
callback_t m_callback;
......@@ -63,6 +71,10 @@ public:
SymbolVar input) {
return make(std::move(param), SymbolVarArray{input});
static cg::OperatorNodeBase* shallow_copy(
const serialization::OprShallowCopyContext &ctx,
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
const OperatorNodeConfig &config);
void scn_do_execute() override;
void init_output_static_infer_desc() override;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册