diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 2a78896061853f67c5341ddd9c9a522f41715e0e..dd789fdfe7eaa6a1f2bf9e7bce0d9e4be623b4f7 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -186,6 +186,9 @@ class trace: self._seq.append((op, tuple(ihandles), tuple(ohandles))) self._active_tensors.update(outputs) + def _record_const(self, op, outputs): + pass + @contextlib.contextmanager def _setup(self): global active_trace @@ -195,8 +198,10 @@ class trace: if self._untraced: apply.enable(apply_with_tracing) + apply.enable(apply_const_with_tracing) if self._symbolic: apply.enable(apply_symbolic_mode) + apply.enable(apply_const_symbolic_mode) self._lazy_eval_graph = G.Graph() else: apply.enable(apply_compiled_mode) @@ -239,7 +244,9 @@ class trace: self._pc = 0 apply.disable(apply_with_tracing) + apply.disable(apply_const_with_tracing) apply.disable(apply_symbolic_mode) + apply.disable(apply_const_symbolic_mode) apply.disable(apply_compiled_mode) active_trace = None @@ -477,6 +484,16 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): apply.disable(apply_symbolic_mode) +@apply.register() +def apply_const_symbolic_mode(op: Const, *args: RawTensor): + graph = active_trace._lazy_eval_graph + ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) + return (ret,) + + +apply.disable(apply_const_symbolic_mode) + + @apply.register() def apply_compiled_mode(op: OpDef, *args: RawTensor): if skip_tracing: @@ -502,9 +519,14 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): apply.disable(apply_with_tracing) -# @apply.register() -# def _(op: Const, *args: RawTensor): -# return active_trace._apply_const(op, args) +@apply.register() +def apply_const_with_tracing(op: Const, *args: RawTensor): + outputs = apply.super(op, *args) + active_trace._record_const(op, outputs) + return outputs + + +apply.disable(apply_const_with_tracing) class BrokenRawTensor(RawTensor):