From e027dcbf2c5478c7b3e4a591c68d0babc0d42d91 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 1 Sep 2020 17:18:38 +0800 Subject: [PATCH] chore(mge): improve symbolic tracing value/shape inference GitOrigin-RevId: d1a6baac741726604c799752b19d2ed90e399639 --- imperative/python/megengine/jit/tracing.py | 28 +++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 2a7889606..dd789fdfe 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -186,6 +186,9 @@ class trace: self._seq.append((op, tuple(ihandles), tuple(ohandles))) self._active_tensors.update(outputs) + def _record_const(self, op, outputs): + pass + @contextlib.contextmanager def _setup(self): global active_trace @@ -195,8 +198,10 @@ class trace: if self._untraced: apply.enable(apply_with_tracing) + apply.enable(apply_const_with_tracing) if self._symbolic: apply.enable(apply_symbolic_mode) + apply.enable(apply_const_symbolic_mode) self._lazy_eval_graph = G.Graph() else: apply.enable(apply_compiled_mode) @@ -239,7 +244,9 @@ class trace: self._pc = 0 apply.disable(apply_with_tracing) + apply.disable(apply_const_with_tracing) apply.disable(apply_symbolic_mode) + apply.disable(apply_const_symbolic_mode) apply.disable(apply_compiled_mode) active_trace = None @@ -477,6 +484,16 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): apply.disable(apply_symbolic_mode) +@apply.register() +def apply_const_symbolic_mode(op: Const, *args: RawTensor): + graph = active_trace._lazy_eval_graph + ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) + return (ret,) + + +apply.disable(apply_const_symbolic_mode) + + @apply.register() def apply_compiled_mode(op: OpDef, *args: RawTensor): if skip_tracing: @@ -502,9 +519,14 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): apply.disable(apply_with_tracing) -# @apply.register() -# def _(op: Const, *args: RawTensor): -# return active_trace._apply_const(op, args) +@apply.register() +def apply_const_with_tracing(op: Const, *args: RawTensor): + outputs = apply.super(op, *args) + active_trace._record_const(op, outputs) + return outputs + + +apply.disable(apply_const_with_tracing) class BrokenRawTensor(RawTensor): -- GitLab