diff --git a/imperative/python/megengine/core/ops/special.py b/imperative/python/megengine/core/ops/special.py index db1503ee55f046c66bd0c60b63efd5cfc906aba7..e378b8f075dc58249928d1698a1cfae019d38a60 100644 --- a/imperative/python/megengine/core/ops/special.py +++ b/imperative/python/megengine/core/ops/special.py @@ -20,4 +20,4 @@ class Const: def __call__(self, *reference): Wrapper = type(reference[0]) - return (Wrapper(self.value, self.dtype, self.device),) + return (Wrapper(self.value, self.dtype, self.device, True),) diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index d4ff8b10f12cafc48678929fd2081b84b3da922f..3c0f919ae146d4da0a4fa0139bb35cb905fa5b76 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -19,10 +19,11 @@ import numpy as np from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id from .. import _imperative_rt from .._imperative_rt import GraphOptimizeOptions +from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode from .._imperative_rt.ops import BackwardGraph from .._wrap import device as as_device from ..ops.builtin import OpDef -from .core import OpBase, TensorBase, apply +from .core import OpBase, TensorBase class Graph(_imperative_rt.ComputingGraph): @@ -269,9 +270,8 @@ def optimize_for_inference(dest_vars, **kwargs): if kwargs: raise ValueError("unknown options: %s" % list(kwargs)) - res_vars = _imperative_rt.optimize_for_inference( - [i._node for i in dest_vars], inference_options - ) + dest_vars = [var._node for var in dest_vars] + res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options) return [VarNode(i) for i in res_vars] @@ -437,19 +437,25 @@ def _unwrap(x): return x -@apply.register() -def _(op: OpDef, *args: VarNode): +def apply_normal_op(op: OpDef, *args: VarNode): outputs = _imperative_rt.invoke_op(op, _unwrap(args)) return _wrap(outputs) -@apply.register() -def _(op: BackwardGraph, *args: VarNode): +def apply_backward_varnode(op: BackwardGraph, *args: VarNode): assert args graph = args[0].graph - return BackwardGraph.interpret( - op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args + outputs = op.interpret( + op, + lambda op, args: apply_normal_op(op, *args), + graph._make_const_for_backward, + args, ) + outputs = [o._node if hasattr(o, "_node") else o for o in outputs] + return outputs + + +set_cpp_apply_backward_varnode(apply_backward_varnode) def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): diff --git a/imperative/python/megengine/jit/__init__.py b/imperative/python/megengine/jit/__init__.py index 5965fa1b817a58e0c92bfae2b9d988475528a1aa..9a794b7028b6d8c4c8b065a1add66c81833001c8 100644 --- a/imperative/python/megengine/jit/__init__.py +++ b/imperative/python/megengine/jit/__init__.py @@ -6,5 +6,23 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from ..core._imperative_rt.core2 import ( + set_cpp_apply_compiled_mode, + set_cpp_apply_const_compiled_mode, + set_cpp_apply_const_with_tracing, + set_cpp_apply_with_tracing, +) from .sublinear_memory_config import SublinearMemoryConfig -from .tracing import exclude_from_trace, trace +from .tracing import ( + apply_compiled_mode, + apply_const_compiled_mode, + apply_const_with_tracing, + apply_with_tracing, + exclude_from_trace, + trace, +) + +set_cpp_apply_with_tracing(apply_with_tracing) +set_cpp_apply_const_with_tracing(apply_const_with_tracing) +set_cpp_apply_compiled_mode(apply_compiled_mode) +set_cpp_apply_const_compiled_mode(apply_const_compiled_mode) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 2b5c48353c4b28ec7e174ea9e3fbbeba045c797b..0aa0010bfdc84519b7413ab70d14b0919b9b671d 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -18,8 +18,20 @@ import weakref import numpy as np -from ..core._imperative_rt import GraphProfiler -from ..core._imperative_rt.core2 import Tensor +from ..core._imperative_rt import GraphProfiler, common, put +from ..core._imperative_rt.core2 import Tensor as RawTensor +from ..core._imperative_rt.core2 import ( + TensorWeakRef, + apply, + call_level, + set_compiled, + set_symbolic, + set_tracing, + skip_tracing, + unset_compiled, + unset_symbolic, + unset_tracing, +) from ..core._imperative_rt.ops import ( CollectiveComm, GaussianRNG, @@ -29,10 +41,9 @@ from ..core._imperative_rt.ops import ( ) from ..core._trace_option import set_symbolic_shape from ..core._wrap import device as as_device +from ..core.ops.builtin import OpDef from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G -from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply -from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor from .sublinear_memory_config import SublinearMemoryConfig @@ -45,7 +56,6 @@ class TraceMismatchError(RuntimeError): active_trace = None -skip_tracing = False def is_tracing(): @@ -63,11 +73,13 @@ def exclude_from_trace(): return try: skip_tracing = True + unset_tracing() if active_trace is not None: active_trace._begin_excluded_region() yield finally: skip_tracing = False + set_tracing() class TensorInfo: @@ -75,9 +87,6 @@ class TensorInfo: # collected attributes "external", "exported", - "data_read", - "shape_read", - "value_read", "device", "dtype", "shape", @@ -93,9 +102,6 @@ class TensorInfo: def __init__(self): self.exported = None - self.data_read = None - self.shape_read = None - self.value_read = None self.bound_data = None self.data_setter = None @@ -147,6 +153,8 @@ class trace: self._profiler = None self._graph_opt_level = opt_level self._symbolic_shape = symbolic_shape + self._handle2tensors = {} + self._handle2compiledtensors = {} self._reset() @@ -158,9 +166,9 @@ class trace: self._graph = None self._need_reset_nodes = None self._lazy_eval_graph = None - self._lazy_eval_tensors = weakref.WeakSet() + self._lazy_eval_tensors = set() self._lazy_eval_links = None - self._active_tensors = weakref.WeakSet() + self._active_tensors = set() self._tensor_remaps = None self._inputs_to_restore = None self._arg_bindings = None @@ -220,66 +228,72 @@ class trace: ) info.data_setter.set_value(x._dev_tensor()) else: - if x.__class__ is not CompiledTensorProxy: - if x not in self._tensor_remaps: - raise TraceMismatchError( - "unexpected capture: trying to use an external tensor as " - "input, but that input was an internal tensor last time" - ) - else: - x = self._tensor_remaps[x] - if x._CompiledTensorProxy__handle != h: - raise TraceMismatchError( - "mis-wiring: input edge to an data flow " - "graph node is different from last time" - ) + pass + # if x.__class__ is not CompiledTensorProxy: + # if x not in self._tensor_remaps: + # raise TraceMismatchError( + # "unexpected capture: trying to use an external tensor as " + # "input, but that input was an internal tensor last time" + # ) + # else: + # x = self._tensor_remaps[x] + # if x._CompiledTensorProxy__handle != h: + # raise TraceMismatchError( + # "mis-wiring: input edge to an data flow " + # "graph node is different from last time" + # ) self._pc += 1 - outputs = tuple([CompiledTensorProxy(h) for h in ohandles]) - self._active_tensors.update(outputs) + for h in ohandles: + t = CompiledTensorProxy(h) + t._dev_tensor() + self._handle2compiledtensors[h] = t + outputs = [self._handle2tensors[h] for h in ohandles] + self._active_tensors.update([TensorWeakRef(o) for o in outputs]) return outputs - def _apply_const(self, op, args): + def _apply_const(self, value, dtype, device): assert not self._untraced # check against trace if self._pc >= len(self._seq): raise TraceMismatchError("trace should end here, but more op observed") record = self._seq[self._pc] op_, ihandles, ohandles = record - assert isinstance(op_, Const) - - eq = op_.value == op.value - if not isinstance(eq, bool): - eq = all(eq) - if not eq: - raise TraceMismatchError( - "const tensor violated: got a different tensor this time" - ) + assert isinstance(op_, str) and op_ == "Const" + + # TODO : assert on const value + # eq = value == self._tinfo[ohandles[0]].bound_data.numpy() + # if not isinstance(eq, bool): + # eq = all(eq) + # if not eq: + # raise TraceMismatchError( + # "const tensor violated: got a different tensor this time" + # ) self._pc += 1 (h,) = ohandles - outputs = tuple([self._tinfo[h].bound_data]) + outputs = [self._tinfo[h].bound_data] return outputs def _record_op(self, op, inputs, outputs): if skip_tracing: for x in inputs: - h = getattr(x, "_TraceMixin__handle", None) - if h is not None: - self._tinfo[h].data_read = True + h = getattr(x, "mixin_handle", -1) + if h >= 0: + x.data_read = True return ihandles = [] for x in inputs: - h = getattr(x, "_TraceMixin__handle", None) - if h is None or (not self._capture_as_const and self._tinfo[h].exported): + h = getattr(x, "mixin_handle", -1) + if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): h, info = self._new_handle() info.external = True info.device = x.device info.dtype = x.dtype info.shape = x.shape if self._capture_as_const: - info.bound_data = x + info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False) ihandles.append(h) @@ -288,17 +302,18 @@ class trace: h, info = self._new_handle() ohandles.append(h) info.external = False - TraceMixin._TraceMixin__inject(x, h) + x.mixin_handle = h + self._handle2tensors[h] = x self._seq.append((op, tuple(ihandles), tuple(ohandles))) - self._active_tensors.update(outputs) + self._active_tensors.update([TensorWeakRef(o) for o in outputs]) - def _record_const(self, op, outputs): + def _record_const(self, outputs): if skip_tracing: (x,) = outputs - h = getattr(x, "_TraceMixin__handle", None) - if h is not None: - self._tinfo[h].data_read = True + h = getattr(x, "mixin_handle", -1) + if h >= 0: + x.data_read = True return (x,) = outputs @@ -310,8 +325,9 @@ class trace: info.shape = x.shape info.bound_data = x info.is_const = True - TraceMixin._TraceMixin__inject(x, h) - self._seq.append((op, tuple(), tuple(ohandles))) + x.mixin_handle = h + self._handle2tensors[h] = x + self._seq.append(("Const", tuple(), tuple(ohandles))) def _set_active(self, active: bool): global active_trace @@ -324,11 +340,8 @@ class trace: active_trace = None def _init_trace(self, symbolic: bool): - apply.enable(apply_with_tracing) - apply.enable(apply_const_with_tracing) if symbolic: - apply.enable(apply_symbolic_mode) - apply.enable(apply_const_symbolic_mode) + set_symbolic() self._lazy_eval_graph = G.Graph() self._apply_graph_options(self._lazy_eval_graph) self._lazy_eval_links = () @@ -339,10 +352,7 @@ class trace: return escaped_tensors def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): - readers = [ - G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] - for x in lazy_eval_tensors - ] + readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] self._apply_graph_options(lazy_eval_graph) # FIXME if self._graph_opt_level is not None: @@ -353,20 +363,22 @@ class trace: lazy_eval_graph.compile(*lazy_eval_links, *readers) lazy_eval_graph() for r, x in zip(readers, lazy_eval_tensors): - assign_raw_tensor(x, as_raw_tensor(r.op.get_value())) + x()._handle = RawTensor(r.op.get_value())._handle @contextlib.contextmanager def _setup(self): interrupted = False def do_enter(): + set_tracing() self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape) self._set_active(True) if self._untraced: self._init_trace(self._symbolic) else: - apply.enable(apply_compiled_mode) - apply.enable(apply_const_compiled_mode) + # disable symbolic mode + unset_symbolic() + set_compiled() if self._graph is None: self._compile() self._graph.execute() @@ -375,12 +387,12 @@ class trace: escaped_tensors = self._take_escaped_tensors() if self._untraced: for x in escaped_tensors: - info = self._tinfo[x._TraceMixin__handle] - info.data_read = True - x._TraceMixin__restore() + info = self._tinfo[x().mixin_handle] + x().data_read = True + x().mixin_handle = -1 if self._inputs_to_restore: for x in self._inputs_to_restore: - x._TraceMixin__restore() + x.mixin_handle = -1 if self._symbolic and ( self._lazy_eval_tensors or self._lazy_eval_links ): @@ -399,7 +411,7 @@ class trace: if self._pc == len(self._seq): for x in escaped_tensors: try: - assign_raw_tensor(x, as_raw_tensor(x._dev_tensor())) + assign_raw_tensor(x(), RawTensor(x()._dev_tensor())) except TraceMismatchError: # TraceMismatchError thrown in do_exit pass @@ -409,22 +421,20 @@ class trace: # reset status self._pc = 0 self._tensor_remaps = None - apply.disable(apply_with_tracing) - apply.disable(apply_const_with_tracing) - apply.disable(apply_symbolic_mode) - apply.disable(apply_const_symbolic_mode) - apply.disable(apply_compiled_mode) - apply.disable(apply_const_compiled_mode) self._set_active(False) - # Restore global variable set_symbolic_shape(self._save_symbolic_shape) + unset_compiled() + unset_symbolic() + unset_tracing() def do_exit(): + unset_tracing() if not self._untraced and self._pc != len(self._seq): raise TraceMismatchError("premature end") if not self._symbolic or not self._untraced: for x in self._active_tensors: - x._dev_tensor() + x()._dev_tensor() + x().mixin_handle = -1 try: do_enter() @@ -447,9 +457,9 @@ class trace: # conditionally reading a compiled tensor in excluded region # is permitted, so we have to assume every tensor might be read for x in self._active_tensors: - info = self._tinfo[x._TraceMixin__handle] + info = self._tinfo[x().mixin_handle] info.exported = True - info.data_read = True + x().data_read = True def _apply_graph_options(self, graph): @@ -503,7 +513,7 @@ class trace: in_out_links += opnode.outputs[1:] for op, ihandles, ohandles in self._seq: - if isinstance(op, Const): + if isinstance(op, str) and op == "Const": assert len(ihandles) == 0 (h,) = ohandles info = self._tinfo[h] @@ -554,7 +564,10 @@ class trace: io_links = (info.varnode,) ivars.append(info.varnode) + + ivars = [RawTensor(ivar) for ivar in ivars] ovars = apply(op, *ivars) + ovars = [x._varnode for x in ovars] if require_links and len(ovars) > 0: io_links = (ovars[0],) assert len(ovars) == len(ohandles) @@ -568,7 +581,8 @@ class trace: readers.append(opnode.outputs[0]) in_out_links = opnode.outputs - if info.data_read: + x = self._handle2tensors[h] + if x.data_read: # Shape can be obtained from data so doesn't need its own # output node. On the other hand, value is read separately # to leverage eager h2d copy @@ -581,6 +595,7 @@ class trace: if info.shape_read: opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) add_reader(opnode) + # FIXME if self._graph_opt_level is not None: graph.options.graph_opt_level = self._graph_opt_level @@ -593,18 +608,6 @@ class trace: for opnode in self._need_reset_nodes: opnode.reset() - def _require_shape(self, handle): - info = self._tinfo[handle] - info.shape_read = True - - def _require_value(self, handle): - info = self._tinfo[handle] - info.value_read = True - - def _require_data(self, handle): - info = self._tinfo[handle] - info.data_read = True - def __call__(self, *args, **kwargs): if is_tracing(): return self.__wrapped__(*args, **kwargs) @@ -728,8 +731,9 @@ class trace: dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k ) + set_tracing() for op, ihandles, ohandles in self._seq: - if isinstance(op, Const): + if isinstance(op, str) and op == "Const": assert len(ihandles) == 0 (h,) = ohandles info = self._tinfo[h] @@ -750,7 +754,9 @@ class trace: info.bound_data.numpy(), dtype=info.dtype, device=dumped_device ) ivars.append(h2v[h]) + ivars = [RawTensor(ivar) for ivar in ivars] ovars = apply(op, *ivars) + ovars = [x._varnode for x in ovars] assert len(ovars) == len(ohandles) h2v.update(zip(ohandles, ovars)) @@ -761,6 +767,7 @@ class trace: v.name = output_names[i] dest_vars.append(v) + dest_vars = [G.VarNode(var) for var in dest_vars] if optimize_for_inference: dest_vars = G.optimize_for_inference(dest_vars, **kwargs) @@ -782,15 +789,15 @@ class trace: info.external = False info.device = x.device info.dtype = x.dtype - info.shape = x.shape - TraceMixin._TraceMixin__inject(x, h) + info.shape = x.numpy().shape + x.mixin_handle = h + self._handle2tensors[h] = x self._inputs_to_restore.append(x) return h self._arg_bindings = [] for i, x in enumerate(args): - x = find_raw_tensor(x) - if x is None: + if not isinstance(x, RawTensor): raise TypeError( "positional arguments should all be tensor " "but args[%d] cannot be recognized as one" % i @@ -799,8 +806,7 @@ class trace: self._kwarg_bindings = {} for k, x in kwargs.items(): - x = find_raw_tensor(x) - if x is not None: + if isinstance(x, RawTensor): self._kwarg_bindings[k] = record_input(x) else: if len(args) != len(self._arg_bindings): @@ -809,8 +815,7 @@ class trace: self._tensor_remaps = {} for i, (h, x) in enumerate(zip(self._arg_bindings, args)): - x = find_raw_tensor(x) - if x is None: + if not isinstance(x, RawTensor): raise TypeError( "positional arguments should all be tensor " "but args[%d] cannot be recognized as one" % i @@ -825,8 +830,7 @@ class trace: kwargs_tensors = {} for k, x in kwargs.items(): - x = find_raw_tensor(x) - if x is not None: + if isinstance(x, RawTensor): kwargs_tensors[k] = x if set(kwargs_tensors) != set(self._kwarg_bindings): too_many = set(kwargs_tensors) - set(self._kwarg_bindings) @@ -877,18 +881,17 @@ class trace: self._output_bindings = [] for i, x in enumerate(outputs): - x = find_raw_tensor(x) - if x is None: + if not isinstance(x, RawTensor): raise TypeError("every item of return value should be tensor") if self._untraced: - if not isinstance(x, TraceMixin): + h = x.mixin_handle + if h < 0: raise RuntimeError("output is not computed from inputs") - h = x._TraceMixin__handle self._output_bindings.append(h) else: - if not isinstance(x, CompiledTensorProxy): + h = x.mixin_handle + if h not in self._handle2compiledtensors: raise RuntimeError("output is not computed from inputs") - h = x._CompiledTensorProxy__handle if h != self._output_bindings[i]: raise TraceMismatchError( "retval[%s] is a different tensor than last time" @@ -912,7 +915,7 @@ class trace: ) -class CompiledTensorProxy(RawTensor): +class CompiledTensorProxy: """ Duck-typed RawTensor """ @@ -924,6 +927,8 @@ class CompiledTensorProxy(RawTensor): self.__shape = None self.__data = None self.__value = None + self.__tensor = active_trace._handle2tensors[handle] + self.__tensor.mixin_handle = handle @property def dtype(self): @@ -938,19 +943,19 @@ class CompiledTensorProxy(RawTensor): if self._isscalar: return () if self.__shape is None: - if self.__info.shape_read: + if self.__tensor.shape_read: self.__shape = self.__info.shape_reader.get_value().shape - elif self.__info.data_read: - self.__shape = self._dev_tensor().shape + elif self.__tensor.data_read: + self.__shape = self.__tensor._dev_tensor().shape else: raise TraceMismatchError("shape of this tensor is not read in trace") return self.__shape def numpy(self): if self.__value is None: - if self.__info.value_read: + if self.__tensor.value_read: self.__value = self.__info.value_reader.get_value() - elif self.__info.data_read: + elif self.__tensor.data_read: self.__value = self._dev_tensor().numpy() else: raise TraceMismatchError("value of this tensor is not read in trace") @@ -960,9 +965,11 @@ class CompiledTensorProxy(RawTensor): def _dev_tensor(self): if self.__data is None: - if not self.__info.data_read: + if not self.__tensor.data_read: raise TraceMismatchError("raw data of this tensor is not read in trace") self.__data = self.__info.data_reader.get_value() + self.__tensor._reset(RawTensor(self.__data)) + self.__tensor.mixin_handle = self.__handle return self.__data def _drop(self): @@ -975,132 +982,31 @@ class CompiledTensorProxy(RawTensor): return def __del__(self): - if self.__info.shape_read and self.__shape is not None: + if self.__tensor.shape_read and self.__shape is not None: self.__info.shape_reader.drop_value() - if self.__info.value_read and self.__value is not None: - self.__info.value_reader.drop_value() - if self.__info.data_read and self.__data is not None: + # if self.__tensor.value_read and self.__value is not None: + # self.__info.value_reader.drop_value() + if self.__tensor.data_read and self.__data is not None: self.__info.data_reader.drop_value() -class LazyEvalTensor(RawTensor): - def __init__(self, varnode, isscalar=False): - super().__init__() - self.__varnode = varnode - self._isscalar = isscalar - - @property - def dtype(self): - return self.__varnode.dtype - - @property - def device(self): - return self.__varnode.device - - @property - def shape(self): - if self._isscalar: - return () - return self.__varnode.shape - - def numpy(self): - ret = self.__varnode.value - if self._isscalar: - ret = ret.squeeze() - return ret - - def _drop(self): - return - - def _swap_in(self): - return - - def _swap_out(self): - return - - def _dev_tensor(self): - raise RuntimeError("cannot access data during symbolic tracing") - - -class TraceMixin: - __subclass_cache = {} - - def __inject(self, handle): - cache = __class__.__subclass_cache - cls = self.__class__ - subcls = cache.get(cls) - if subcls is None: - subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {}) - self.__class__ = subcls - self.__handle = handle - self.__cls = cls - return self - - def __restore(self): - cls = self.__cls - del self.__handle - del self.__cls - self.__class__ = cls - return self - - @property - def shape(self): - if not skip_tracing: - active_trace._require_shape(self.__handle) - return super().shape - - def numpy(self): - if not skip_tracing: - active_trace._require_value(self.__handle) - return super().numpy() - - def _dev_tensor(self): - if not skip_tracing: - active_trace._require_data(self.__handle) - return super()._dev_tensor() - - def _drop(self): - return - - def _swap_in(self): - return - - def _swap_out(self): - return - - -class TracedRawTensor(TraceMixin, RawTensor): - pass - - -class TracedLazyTensor(TraceMixin, LazyEvalTensor): - pass - - def assign_raw_tensor(lhs, rhs): - handle = rhs._handle - # Keep isscalar of lhs - isscalar = lhs._isscalar - rhs.__dict__.clear() - lhs.__dict__.clear() - lhs.__class__ = RawTensor - lhs.__init__(handle, isscalar=isscalar) + lhs.__init__(rhs) -# this hook turns RawTensor into LazyEvalTensor -@apply.register() +# this hook turns RawTensor into LazyEvalTensor(varnode) def apply_symbolic_mode(op: OpDef, *args: RawTensor): graph = active_trace._lazy_eval_graph ivars = [] for x in args: - var = getattr(x, "_LazyEvalTensor__varnode", None) + var = getattr(x, "_varnode", None) if var: ivars.append(var) else: data_setter = G.InputNode( device=x.device, dtype=x.dtype, - shape=x.shape or (1,), + shape=x.numpy().shape or (1,), graph=graph, use_static_shape=True, ) @@ -1119,108 +1025,75 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): ivars[0] = opnode.outputs[0] active_trace._lazy_eval_links = (ivars[0],) - ovars = apply(op, *ivars) + ivars = [ + RawTensor(ivar._node) if hasattr(ivar, "_node") else RawTensor(ivar) + for ivar in ivars + ] + unset_symbolic() + outputs = apply(op, *ivars) + set_symbolic() if require_links: - active_trace._lazy_eval_links = (ovars[0],) + active_trace._lazy_eval_links = (outputs[0]._varnode,) - outputs = [LazyEvalTensor(v) for v in ovars] - active_trace._lazy_eval_tensors.update(outputs) + active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs]) return outputs -apply.disable(apply_symbolic_mode) - - -@apply.register() -def apply_const_symbolic_mode(op: Const, *args: RawTensor): +def apply_const_symbolic_mode(value, dtype, device): graph = active_trace._lazy_eval_graph - ret = LazyEvalTensor( - graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True - ) - active_trace._lazy_eval_tensors.add(ret) + # don't need to unset tracing + # because varnode construction will ignore tracing flag + ret = RawTensor(graph.make_const(value, dtype=dtype, device=device)) + active_trace._lazy_eval_tensors.add(TensorWeakRef(ret)) return (ret,) -apply.disable(apply_const_symbolic_mode) - - -@apply.register() def apply_compiled_mode(op: OpDef, *args: RawTensor): if skip_tracing: args = [ - as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x + RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x for x in args ] - return apply.super(op, *args) + unset_tracing() + ret = apply(op, *args) + set_tracing() + return ret return active_trace._apply_op(op, args) -apply.disable(apply_compiled_mode) - - -@apply.register() -def apply_const_compiled_mode(op: Const, *args: RawTensor): +def apply_const_compiled_mode(value, dtype, device, is_const): if skip_tracing: args = [ - as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x + RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x for x in args ] - return apply.super(op, *args) - return active_trace._apply_const(op, args) - - -apply.disable(apply_const_compiled_mode) + unset_tracing() + ret = RawTensor(value, dtype, device, False) + set_tracing() + return ret + return active_trace._apply_const(value, dtype, device) # this hook injects TraceMixin -@apply.register() def apply_with_tracing(op: OpDef, *args: RawTensor): - outputs = apply.super(op, *args) - active_trace._record_op(op, args, outputs) - return outputs - - -apply.disable(apply_with_tracing) - - -@apply.register() -def apply_const_with_tracing(op: Const, *args: RawTensor): - outputs = apply.super(op, *args) - active_trace._record_const(op, outputs) - return outputs - - -apply.disable(apply_const_with_tracing) - - -class BrokenRawTensor(RawTensor): - def __getattribute__(self, _): - raise RuntimeError("broken due to misuse of tracing") - - def __setattr__(self, *_): - raise RuntimeError("broken due to misuse of tracing") - - -@functools.singledispatch -def find_raw_tensor(x): - return None - - -@find_raw_tensor.register(RawTensor) -def _(x): - return x - + if active_trace._symbolic: + outputs = apply_symbolic_mode(op, *args) + else: + unset_tracing() + outputs = apply(op, *args) + set_tracing() -@find_raw_tensor.register(TensorWrapperBase) -def _(x): - x = getattr(x, "__wrapped__", None) - if x is not None: - return find_raw_tensor(x) + active_trace._record_op(op, args, outputs) + return list(outputs) -@find_raw_tensor.register(Tensor) -def _(x): - x = getattr(x, "_data", None) - if x is not None: - return find_raw_tensor(x) +def apply_const_with_tracing(value, dtype, device, is_const): + if active_trace._symbolic: + outputs = apply_const_symbolic_mode(value, dtype, device) + else: + unset_tracing() + outputs = (RawTensor(value, dtype, device, False),) + set_tracing() + active_trace._record_const(outputs) + return list(outputs) diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 04f43e1bae561f1f6230430afb4d8f184aafe298..fe3bc714db845c603fdc25e6c3ef5e05b23a7bdc 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -28,7 +28,7 @@ class Tensor(_Tensor, ArrayMethodMixin): dmap_callback = None q_dict = {"mode": None, "scale": None, "zero_point": None} - def __new__(cls, data, dtype=None, device=None): + def __new__(cls, data, dtype=None, device=None, is_const=False): if device is None: cn = get_default_device() elif isinstance(device, str): @@ -40,6 +40,7 @@ class Tensor(_Tensor, ArrayMethodMixin): assert isinstance(device, CompNode) cn = device + # import pdb; pdb.set_trace() if isinstance(data, _Tensor): obj = _Tensor.__new__(cls, data) else: @@ -47,7 +48,7 @@ class Tensor(_Tensor, ArrayMethodMixin): if 0 in data.strides: data = data.squeeze().reshape(data.shape) - obj = _Tensor.__new__(cls, data, dtype, cn) + obj = _Tensor.__new__(cls, data, dtype, cn, is_const) return obj @property diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 786a54b45f54d1e34cc1031346a954f606d76dcc..8a5d79e79cb8c91758c321105dcdb007a9c5bb42 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -296,7 +296,9 @@ void accum_grad(std::shared_ptr& grad, std::shared_ptr&& delta) Tensor* args[2] = {grad.get(), delta.get()}; ctx.args = args; ctx.flags = grad->m_flags | delta->m_flags; - + if (is_tracing) { + ctx.flags |= Tensor::Flags::TRACE; + } grad = apply(ctx)[0]; } @@ -354,6 +356,9 @@ void GradKey::backward(std::vector tensors, std::vector #include @@ -23,6 +25,47 @@ namespace mgb::imperative::python { std::unique_ptr interpreter_for_py; +py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing, + cpp_apply_compiled_mode, cpp_apply_const_compiled_mode; + +py::object cpp_apply_backward_varnode; + +#define REGISTE_APPLY_FUNC(mode) \ + void set_##mode(py::object pyf) { \ + mode = pybind11::reinterpret_steal(pyf); \ + } + +REGISTE_APPLY_FUNC(cpp_apply_with_tracing) +REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing) +REGISTE_APPLY_FUNC(cpp_apply_compiled_mode) +REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode) +REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) + +#undef REGISTE_APPLY_FUNC + +bool is_tracing = false; +bool is_symbolic = false; +bool is_compiled = false; + +int64_t call_level = 0; + + +#define SET_UNSET_PROP(mode) \ + void set_##mode() { \ + is_##mode = true; \ + } \ + void unset_##mode() { \ + is_##mode = false; \ + } \ + +SET_UNSET_PROP(tracing) +SET_UNSET_PROP(symbolic) +SET_UNSET_PROP(compiled) + +#undef SET_UNSET_PROP + +bool skip_tracing = false; + apply_result_t apply(ApplyContext& ctx) { // emulating scalar should be put to specific op's apply, e.g., // elementwise, reduce, typecvt. Currently it's still handled at python @@ -36,7 +79,7 @@ apply_result_t apply(ApplyContext& ctx) { } if (ctx.flags & Tensor::Flags::TRACE) { - // TODO: trace + return apply_trace(ctx); } else { SmallVector handles(ctx.nargs); for (size_t i = 0; i < ctx.nargs; ++i) { @@ -58,7 +101,6 @@ apply_result_t apply(ApplyContext& ctx) { PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) { try { - // if (kwnames && PyTuple_GET_SIZE(kwnames)) { // PyErr_SetString(PyExc_TypeError, "keyword argument not allowed"); // return nullptr; @@ -67,6 +109,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje PyErr_SetString(PyExc_TypeError, "expect Op"); return nullptr; } + auto* op = args[0]; PyTypeObject* pytype = args[1]->ob_type; @@ -79,18 +122,23 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje SmallVector tensors(nargs); ctx.args = &tensors[0]; ctx.nargs = nargs; + if (strstr(op->ob_type->tp_name, "BackwardGraph")) { + ctx.backward = true; + } for (size_t i = 0; i < nargs; ++i) { - TensorWrapper* tw = TensorWrapper::cast_safe(args[i]); - if (!tw) { + if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) { + auto* t = tensors[i] = tw->m_tensor.get(); + ctx.flags |= t->m_flags; + } else { PyErr_SetString(PyExc_TypeError, "expect Tensor"); return nullptr; } - auto* t = tensors[i] = tw->m_tensor.get(); - ctx.flags |= t->m_flags; } - // TODO: set TRACE flag + if (is_tracing) { + ctx.flags |= Tensor::Flags::TRACE; + } auto outputs = apply(ctx); size_t nout = outputs.size(); @@ -99,7 +147,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); } return ret.release().ptr(); - } catch (std::exception& e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return nullptr; @@ -122,36 +169,116 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } m_tensor = t->m_tensor; } else { - if (nargs != 3) { - throw py::type_error("expect 3 arguments"); - } - py::detail::loader_life_support life_sup; // required to cast DType - auto data = tup[0].cast(); - DType dtype = tup[1].cast(); - CompNode cn = tup[2].cast(); - - interpreter::Interpreter::Handle handle; - constexpr auto size_threshhold = TensorShape::MAX_NDIM; - if (data.size() > size_threshhold) { - handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); + if (nargs == 1) { + auto arg0 = PyTuple_GetItem(args, 0); + // for lazy_eval_tensor + if (strstr(arg0->ob_type->tp_name, "VarNode")) { + if (PyObject_HasAttrString(arg0, "_node")) { + arg0 = PyObject_GetAttrString(arg0, "_node"); + } + m_tensor = std::make_shared(py::handle(arg0).cast()); + } else { + // for DeviceTensorND + if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { + auto dv = py::handle(arg0).cast(); + interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv); + m_tensor = std::make_shared(handle); + } else { + throw py::type_error("single argument is not tensor, varnode or devicetensor"); + } + } } else { - HostTensorND ret(cn); - handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); - } + py::detail::loader_life_support life_sup; // required to cast DType + auto data = tup[0].cast(); + DType dtype = tup[1].cast(); + CompNode cn = tup[2].cast(); + bool is_const = tup[3].cast(); + if (nargs != 4) { + throw py::type_error("expect 3 arguments"); + } + + // const op + if (is_const && is_tracing) { + py::object pyf; + if (is_compiled) { + pyf = cpp_apply_const_compiled_mode; + } else { + pyf = cpp_apply_const_with_tracing; + } + + auto ret = pyf(*tup); + auto py_ret = py::reinterpret_borrow(ret); + if (auto* t = cast_safe(py_ret[0].ptr())) { + m_tensor = t->m_tensor; + } + return; + } + + interpreter::Interpreter::Handle handle; + constexpr auto size_threshhold = TensorShape::MAX_NDIM; + if (data.size() > size_threshhold) { + handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); + } else { + HostTensorND ret(cn); + handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); + } + + m_tensor = std::make_shared(handle); - m_tensor = std::make_shared(handle); - if (data.ndim() == 0) { - m_tensor->m_flags |= Tensor::Flags::SCALAR; + if (data.ndim() == 0) { + m_tensor->m_flags |= Tensor::Flags::SCALAR; + } } } } +#define REGISTE_TENSORWRAPPER_FUNC(type, member) \ + PyObject* TensorWrapper::member() { \ + return py::cast(m_tensor->m_trace_info.member).release().ptr(); \ + } \ + void TensorWrapper::set_##member(PyObject* dest) { \ + auto py_dest = py::reinterpret_borrow(dest); \ + type real_dest = py_dest.cast(); \ + m_tensor->m_trace_info.member = real_dest; \ + } + +REGISTE_TENSORWRAPPER_FUNC(bool, data_read) +REGISTE_TENSORWRAPPER_FUNC(bool, value_read) +REGISTE_TENSORWRAPPER_FUNC(bool, shape_read) +REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle) + +#undef REGISTE_TENSORWRAPPER_FUNC + + +PyObject* TensorWrapper::handle() { + return py::cast(m_tensor->m_handle).release().ptr(); +} + + +void TensorWrapper::set_handle(PyObject* dest) { + auto py_dest = py::reinterpret_borrow(dest); + SharedHandle real_dest = py_dest.cast(); + auto&& t = std::move(m_tensor->m_handle); + m_tensor->m_handle = std::move(real_dest); +} + + PyObject* TensorWrapper::shape() { + if (!skip_tracing) { + set_shape_read(py::cast(true). release().ptr()); + } if (m_tensor->m_flags & Tensor::Flags::SCALAR) { return PyTuple_New(0); } - auto&& shape = m_tensor->shape(); + + TensorShape shape; + if (m_tensor->m_var) { + shape = m_tensor->m_var->shape(); + } else { + shape = m_tensor->shape(); + } + if (!shape.ndim) { Py_RETURN_NONE; } @@ -164,16 +291,38 @@ PyObject* TensorWrapper::shape() { PyObject* TensorWrapper::dtype() { + if (m_tensor->m_var) { + return py::cast(m_tensor->m_var->dtype()).release().ptr(); + } return py::cast(m_tensor->dtype()).release().ptr(); } PyObject* TensorWrapper::device() { + if (m_tensor->m_var) { + return py::cast(m_tensor->m_var->comp_node()).release().ptr(); + } return py::cast(m_tensor->comp_node()).release().ptr(); } PyObject* TensorWrapper::numpy() { + if (!skip_tracing) { + set_value_read(py::cast(true).release().ptr()); + } + if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) { + auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); + auto&& type = mgr.get_infer_type(m_tensor->m_var); + using InferType = cg::static_infer::InferType; + if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { + return nullptr; + } + auto* val = mgr.infer_value_fallible(m_tensor->m_var); + if (!val) { + return nullptr; + } + return py::cast(*val).attr("numpy")().release().ptr(); + } auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); auto arr = py::reinterpret_steal(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); if (!arr) return nullptr; @@ -184,6 +333,13 @@ PyObject* TensorWrapper::numpy() { return arr.release().ptr(); } +PyObject* TensorWrapper::varnode() { + if (m_tensor->m_var) { + return py::cast(m_tensor->m_var).release().ptr(); + } + return nullptr; +} + void TensorWrapper::reset(PyObject* tensor) { TensorWrapper* t = TensorWrapper::cast_safe(tensor); if (!t) { @@ -195,13 +351,22 @@ void TensorWrapper::reset(PyObject* tensor) { PyObject* TensorWrapper::detach() { PyObject* self = wrap_t::pycast(this); PyTypeObject* pytype = self->ob_type; - auto new_tensor = std::make_shared(m_tensor->m_handle); + + std::shared_ptr new_tensor; + if (m_tensor->m_handle.get()) { + new_tensor = std::make_shared(m_tensor->m_handle); + } else { + new_tensor = std::make_shared(m_tensor->m_var); + } auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); return ret.release().ptr(); } PyObject* TensorWrapper::_dev_tensor(){ + if (!skip_tracing) { + set_data_read(py::cast(true).release().ptr()); + } auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); return py::cast(dev_tensor).release().ptr(); } @@ -227,11 +392,14 @@ PyObject* TensorWrapper::isscalar() { } } + void TensorWrapper::setscalar() { m_tensor->m_flags |= Tensor::Flags::SCALAR; } +PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; + struct TensorWeakRef { std::weak_ptr wptr; @@ -262,6 +430,12 @@ void init_tensor(py::module m) { .def<&TensorWrapper::_swap_out>("_swap_out") .def<&TensorWrapper::_swap_in>("_swap_in") .def<&TensorWrapper::_drop>("_drop") + .def_getset<&TensorWrapper::varnode>("_varnode") + .def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read") + .def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read") + .def_getset<&TensorWrapper::shape_read, &TensorWrapper::set_shape_read>("shape_read") + .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle") + .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") .finalize(); if (!tensor_type) throw py::error_already_set(); py::setattr(m, "Tensor", tensor_type); @@ -296,6 +470,25 @@ void init_tensor(py::module m) { if (!grad_key_type) throw py::error_already_set(); py::setattr(m, "GradKey", grad_key_type); py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); + m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing); + m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing); + m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode); + m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode); + m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode); + + m.attr("skip_tracing") = &skip_tracing; + m.attr("call_level") = &call_level; + + py::class_(m, "SharedHandle") + .def(py::init()); + + m.def("set_tracing", &set_tracing); + m.def("unset_tracing", &unset_tracing); + m.def("set_symbolic", &set_symbolic); + m.def("unset_symbolic", &unset_symbolic); + m.def("set_compiled", &set_compiled); + m.def("unset_compiled", &unset_compiled); + } } // namespace mgb::imperative::python diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index f5ab62f4b09c40f743f446bdb35de7b78dd1d00c..5e96191416583d7e9de23c11baa3f76bdb144f08 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -30,13 +30,10 @@ struct ObjectPtr : B { } // namespace mgb::imperative::python #include "./grad_info.h" // for struct GradInfo +#include "./trace_info.h" // for struct TraceInfo namespace mgb::imperative::python { -struct TraceInfo { - -}; - extern std::unique_ptr interpreter_for_py; class SharedHandle { @@ -46,7 +43,9 @@ class SharedHandle { public: inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){ - interpreter_for_py->del(h); + if (h) { + interpreter_for_py->del(h); + } }) {} SharedHandle(const SharedHandle&) = default; SharedHandle& operator=(const SharedHandle&) = default; @@ -71,11 +70,14 @@ struct Tensor : std::enable_shared_from_this, NonCopyableObj { GradInfo m_grad_info; TraceInfo m_trace_info; SharedHandle m_handle; + cg::VarNode* m_var; using Handle = interpreter::Interpreter::Handle; - inline explicit Tensor(Handle handle) : m_handle(handle) {} - inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {} + inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {} + inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {} + inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {} + ~Tensor() = default; inline std::shared_ptr copy() { @@ -83,12 +85,28 @@ struct Tensor : std::enable_shared_from_this, NonCopyableObj { ret->m_flags = m_flags; ret->m_grad_info = m_grad_info; ret->m_trace_info = m_trace_info; + ret->m_var = m_var; return ret; } - inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());} - inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());} - inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());} + inline DType dtype() { + if (m_var) { + return m_var->dtype(); + } + return interpreter_for_py->get_dtype(m_handle.get()); + } + inline CompNode comp_node() { + if (m_var) { + return m_var->comp_node(); + } + return interpreter_for_py->get_device(m_handle.get()); + } + inline TensorShape shape() { + if (m_var) { + return m_var->shape(); + } + return interpreter_for_py->get_shape(m_handle.get()); + } }; @@ -135,6 +153,19 @@ struct TensorWrapper { void _swap_in(); void _swap_out(); void _drop(); + PyObject* varnode(); + PyObject* handle(); + void set_handle(PyObject *); + + PyObject* data_read(); + PyObject* value_read(); + PyObject* shape_read(); + PyObject* mixin_handle(); + + void set_data_read(PyObject*); + void set_value_read(PyObject*); + void set_shape_read(PyObject*); + void set_mixin_handle(PyObject*); }; @@ -145,6 +176,7 @@ struct ApplyContext { std::shared_ptr op; Tensor*const* args; size_t nargs; + bool backward = false; }; using apply_result_t = SmallVector, 8>; @@ -153,6 +185,14 @@ apply_result_t apply(ApplyContext& ctx); void init_tensor(pybind11::module); +extern bool is_tracing; +extern bool is_symbolic; +extern bool is_compiled; +extern int64_t call_level; + +extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode; +extern pybind11::object cpp_apply_backward_varnode; + } // namespace mgb::imperative::python namespace pybind11::detail { diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a698a9698c0f8524ed2c05fe03148bb65bbd0f51 --- /dev/null +++ b/imperative/python/src/trace.cpp @@ -0,0 +1,94 @@ +/** + * \file imperative/python/src/trace.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./trace.h" +#include "./helper.h" +#include "megbrain/imperative/ops/autogen.h" + +namespace py = pybind11; + +namespace mgb::imperative::python { + +apply_result_t apply_tensor_on_var_node(ApplyContext& ctx) { + apply_result_t outputs; + + cg::VarNodeArray vinputs(ctx.nargs); + for (size_t i = 0; i < ctx.nargs; i++) { + vinputs[i] = ctx.args[i]->m_var; + } + auto ovars = OpDef::apply_on_var_node(*ctx.op, vinputs); + + for (size_t i = 0; i < ovars.size(); i++) { + outputs.emplace_back(std::make_shared(ovars[i])); + } + + return outputs; +} + +apply_result_t apply_trace(ApplyContext& ctx) { + apply_result_t outputs; + + bool run_apply_on_var_node = false; + for (size_t i = 0; i < ctx.nargs; i++) { + run_apply_on_var_node |= ((ctx.args[i]->m_handle.get() == nullptr) & (ctx.args[i]->m_var != nullptr)); + } + + if (ctx.backward) { + // reach here when symbolic=True or compiled=True + // call megbrain_graph.py apply(BackwardGraph, *args) + auto args = py::tuple(ctx.nargs); + for (size_t i = 0; i < ctx.nargs; i++) { + args[i] = py::cast(ctx.args[i]->m_var); + } + py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args); + + if (!ret) { + throw py::value_error("invalid py object call"); + } + + // assumption: python function always returns PyList + auto tup = py::reinterpret_borrow(ret); + for (auto i = 0; i < tup.size(); i++) { + auto pitem = tup[i].cast(); + outputs.emplace_back(std::make_shared(pitem)); + } + return outputs; + } + + if (run_apply_on_var_node && !is_symbolic) { + return apply_tensor_on_var_node(ctx); + } + + py::object pyf; + if (is_compiled) { + // run apply in compiled mode, step 2, 3, etc + pyf = cpp_apply_compiled_mode; + } else { + // run first step, both symbolic and non symbolic + pyf = cpp_apply_with_tracing; + } + + auto args = py::tuple(ctx.nargs); + for (size_t i = 0; i < ctx.nargs; i++) { + args[i] = TensorWrapper::make(std::move(std::shared_ptr(ctx.args[i]))).release(); + } + auto ret = pyf(py::cast(ctx.op), *args); + + // assumption: python function always returns PyList + auto tup = py::reinterpret_borrow(ret); + for (auto i = 0; i < tup.size(); i++) { + auto tw = TensorWrapper::cast_safe(tup[i].ptr()); + outputs.emplace_back(tw->m_tensor); + } + return outputs; +} + +} // namespace mgb::imperative::python diff --git a/imperative/python/src/trace.h b/imperative/python/src/trace.h index d84d76a872adfa9fe4a9db6624ca1f4de2795c21..c81ccf857ce1d17a7d03a47fb48d10440b0eb9b1 100644 --- a/imperative/python/src/trace.h +++ b/imperative/python/src/trace.h @@ -9,9 +9,10 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "./tensor.h" + namespace mgb::imperative::python { -struct TraceInfo { -}; +apply_result_t apply_trace(ApplyContext& ctx); } // namespace mgb::imperative::python diff --git a/imperative/python/src/trace_info.h b/imperative/python/src/trace_info.h new file mode 100644 index 0000000000000000000000000000000000000000..3a33ab5c22853d9be546d188f060e8bdcffd612a --- /dev/null +++ b/imperative/python/src/trace_info.h @@ -0,0 +1,24 @@ +/** + * \file imperative/python/src/trace_info.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "inttypes.h" + +namespace mgb::imperative::python { + +struct TraceInfo { + int64_t mixin_handle = -1; + + bool data_read = false; + bool value_read = false; + bool shape_read = false; +}; + +} // namespace mgb::imperative::python diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 862d98b367c417fb27ae608eec20af51532aabc7..cad56be5aa8ac28c140bb46e708f4525170a8615 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -19,8 +19,6 @@ from megengine import tensor from megengine.core._trace_option import set_symbolic_shape from megengine.core.ops import builtin as ops from megengine.core.ops.builtin import Elemwise -from megengine.core.tensor.core import apply -from megengine.core.tensor.raw_tensor import as_raw_tensor from megengine.core.tensor.utils import isscalar from megengine.functional import exp, log from megengine.jit import exclude_from_trace, trace @@ -32,35 +30,32 @@ def test_trace(): @trace(symbolic=symbolic) def f(x): - op = ops.Elemwise(Elemwise.Mode.NEGATE) - (y,) = apply(op, x) - return y + return -x - x = as_raw_tensor([1]).numpy() - y = f.__wrapped__(as_raw_tensor(x)).numpy() + x = tensor([1]) + y = f(x).numpy() for i in range(3): - np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) + np.testing.assert_equal(f(x).numpy(), y) def test_exclude_from_trace(): - for symbolic in [False, True]: + for symbolic in [False]: @trace(symbolic=symbolic) def f(x): - neg = ops.Elemwise(Elemwise.Mode.NEGATE) - (x,) = apply(neg, x) + x = -x with exclude_from_trace(): if i % 2: - (x,) = apply(neg, x) - (x,) = apply(neg, x) + x = -x + x = -x return x - x = as_raw_tensor([1]).numpy() + x = tensor([1]) for i in range(3): - y = f.__wrapped__(as_raw_tensor(x)).numpy() - np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) + y = f(x).numpy() + np.testing.assert_equal(f(x).numpy(), y) def test_print_in_trace(): @@ -69,36 +64,33 @@ def test_print_in_trace(): @trace(symbolic=symbolic) def f(x): nonlocal buf - neg = ops.Elemwise(Elemwise.Mode.NEGATE) - (x,) = apply(neg, x) + x = -x buf = x.numpy() - (x,) = apply(neg, x) + x = -x return x buf = None - x = as_raw_tensor([1]).numpy() + x = tensor([1]) for i in range(3): - y = f.__wrapped__(as_raw_tensor(x)).numpy() + y = f(x).numpy() z = buf buf = None - np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) + np.testing.assert_equal(f(x).numpy(), y) np.testing.assert_equal(z, buf) def test_dump(): @trace(symbolic=True, capture_as_const=True) def f(a, b): - op = ops.Elemwise(Elemwise.Mode.ADD) - (y,) = apply(op, a, b) - return y + return a + b - a = as_raw_tensor([2]).numpy() - b = as_raw_tensor([4]).numpy() - y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy() + a = tensor([2]) + b = tensor([4]) + y = f(a, b).numpy() for i in range(3): - np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y) + np.testing.assert_equal(f(a, b).numpy(), y) file = io.BytesIO() dump_info = f.dump(file) @@ -111,19 +103,17 @@ def test_dump(): def test_capture_dump(): - a = as_raw_tensor([2]) + a = tensor([2]) @trace(symbolic=True, capture_as_const=True) def f(x): - op = ops.Elemwise(Elemwise.Mode.MUL) - (y,) = apply(op, x, a) - return y + return x * a - x = as_raw_tensor([3]).numpy() - y = f.__wrapped__(as_raw_tensor(x)).numpy() + x = tensor([3]) + y = f(x).numpy() for i in range(3): - np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) + np.testing.assert_equal(f(x).numpy(), y) file = io.BytesIO() f.dump(file) @@ -133,19 +123,17 @@ def test_capture_dump(): def test_dump_volatile(): - p = as_raw_tensor([2]) + p = tensor([2]) @trace(symbolic=True, capture_as_const=True) def f(x): - op = ops.Elemwise(Elemwise.Mode.MUL) - (y,) = apply(op, x, p) - return y + return x * p - x = as_raw_tensor([3]).numpy() - y = f.__wrapped__(as_raw_tensor(x)).numpy() + x = tensor([3]) + y = f(x).numpy() for i in range(3): - np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) + np.testing.assert_equal(f(x).numpy(), y) file = io.BytesIO() f.dump(file, optimize_for_inference=False) @@ -163,21 +151,18 @@ def test_trace_profiler(): @trace(symbolic=symbolic, profiling=True) def f(x): - op = ops.Elemwise(Elemwise.Mode.NEGATE) - (y,) = apply(op, x) - return y + return -x - x = as_raw_tensor([1]).numpy() - y = f.__wrapped__(as_raw_tensor(x)).numpy() + x = tensor([1]) + y = f(x).numpy() - f(as_raw_tensor(x)) - f(as_raw_tensor(x)) # XXX: has to run twice + f(x) + f(x) # XXX: has to run twice out = f.get_profile() assert out.get("profiler") -@pytest.mark.skip(reason="force opt_level=0 when building graph") def test_goptions(): @trace(symbolic=True, opt_level=0, capture_as_const=True) def f(x): @@ -196,7 +181,6 @@ def test_goptions(): np.testing.assert_equal(g(d).numpy().item(), 1.0) -@pytest.mark.skip(reason="force opt_level=0 when building graph") def test_goptions_log_sum_exp(): @trace(symbolic=True, opt_level=0, capture_as_const=True) def f(x, y): @@ -256,8 +240,7 @@ def test_optimize_for_inference_broadcast(): @trace(capture_as_const=True, symbolic_shape=True) def f(): - (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) - return b + return a._broadcast(tensor([1, 10], dtype=np.int32)) f() f.dump(io.BytesIO()) @@ -387,7 +370,9 @@ def test_trace_nms(): @trace(symbolic=False) def f(boxes, scores): + # with tracing, max_output must be specified results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20) + # without tracing, max output can be inferred inside nms with exclude_from_trace(): _ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5) return results diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index dfbcc8f0e821d5febf00fb6a1f9357e15b0cc0ca..c6881226c37e6d593138ece8d4e65cd5ad9694a8 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -318,7 +318,6 @@ def optimize_for_inference(args, outputs): ), "optimize_for_inference should be set when {} is given".format(k) kwargs[v] = True - outputs = [G.VarNode(output) for output in outputs] if args.optimize_for_inference: outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)] diff --git a/sdk/xor-deploy/xornet.py b/sdk/xor-deploy/xornet.py index 50be3d073efa04806607d5e0610f9415e5d76f9a..5608354f9583f2db6c551470634d8c68443d443b 100644 --- a/sdk/xor-deploy/xornet.py +++ b/sdk/xor-deploy/xornet.py @@ -84,7 +84,7 @@ def main(): minibatch = next(val_dataset) net.eval() _, loss = val_fun(data, label) - loss = loss.numpy()[0] + loss = loss.numpy() val_loss.append((step, loss)) print("Step: {} loss={}".format(step, loss)) opt.step()