diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 1eab184fd8c9592c146488a2bab3644e004c49ac..1fe93c715e58aab7b757a9aecf4400cb241fbcb2 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -125,6 +125,9 @@ class trace: self._graph_opt_level = opt_level self._tensor_shape = tensor_shape + self._reset() + + def _reset(self): self._untraced = True self._tinfo = [] # handle -> TensorInfo self._seq = [] @@ -257,77 +260,117 @@ class trace: def _record_const(self, op, outputs): pass - @contextlib.contextmanager - def _setup(self): + def _set_active(self, active: bool): global active_trace - if active_trace: - raise NotImplementedError("sorry, not implemented: nested trace") - active_trace = self - - if self._untraced: - apply.enable(apply_with_tracing) - apply.enable(apply_const_with_tracing) - if self._symbolic: - apply.enable(apply_symbolic_mode) - apply.enable(apply_const_symbolic_mode) - self._lazy_eval_graph = G.Graph() + if active: + if active_trace: + raise NotImplementedError("sorry, not implemented: nested trace") + active_trace = self else: - apply.enable(apply_compiled_mode) - if self._graph is None: - self._compile() - self._graph.execute() - - yield - + assert active_trace is self + active_trace = None + + def _init_trace(self, symbolic: bool): + apply.enable(apply_with_tracing) + apply.enable(apply_const_with_tracing) + if symbolic: + apply.enable(apply_symbolic_mode) + apply.enable(apply_const_symbolic_mode) + self._lazy_eval_graph = G.Graph() + + def _take_escaped_tensors(self): escaped_tensors = tuple(self._active_tensors) self._active_tensors.clear() + return escaped_tensors - if self._untraced: - for x in escaped_tensors: - info = self._tinfo[x._TraceMixin__handle] - info.data_read = True - x._TraceMixin__restore() - if self._inputs_to_restore: - for x in self._inputs_to_restore: + def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors): + active_lazy_eval_tensors = [] + visited = set() + readers = [] + for x in lazy_eval_tensors: + x = x() + if x is None or x in visited: + continue + reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] + readers.append(reader) + active_lazy_eval_tensors.append(x) + visited.add(x) + self._apply_graph_options(lazy_eval_graph) + lazy_eval_graph.compile(*readers) + lazy_eval_graph() + for r, x in zip(readers, active_lazy_eval_tensors): + assign_raw_tensor(x, as_raw_tensor(r.op.get_value())) + + @contextlib.contextmanager + def _setup(self): + interrupted = False + + def do_enter(): + self._set_active(True) + if self._untraced: + self._init_trace(self._symbolic) + else: + apply.enable(apply_compiled_mode) + if self._graph is None: + self._compile() + self._graph.execute() + + def do_finalize(): + escaped_tensors = self._take_escaped_tensors() + if self._untraced: + for x in escaped_tensors: + info = self._tinfo[x._TraceMixin__handle] + info.data_read = True x._TraceMixin__restore() - if self._symbolic: - # eval lazy eval tensors - if self._lazy_eval_tensors: - lazy_eval_tensors = [] - visited = set() - readers = [] - for x in self._lazy_eval_tensors: - x = x() - if x is None or x in visited: - continue - reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] - readers.append(reader) - lazy_eval_tensors.append(x) - visited.add(x) - self._apply_graph_options(self._lazy_eval_graph) - self._lazy_eval_graph.compile(*readers) - self._lazy_eval_graph() - for r, x in zip(readers, lazy_eval_tensors): - assign_raw_tensor(x, as_raw_tensor(r.op.get_value())) + if self._inputs_to_restore: + for x in self._inputs_to_restore: + x._TraceMixin__restore() + if self._symbolic and self._lazy_eval_tensors: + # eval lazy eval tensors + self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors) self._lazy_eval_graph = None self._lazy_eval_tensors = None - self._untraced = False - else: - if self._pc != len(self._seq): - raise TraceMismatchError("premature end") - for x in escaped_tensors: - assign_raw_tensor(x, as_raw_tensor(x._dev_tensor())) - self._graph.wait() - self._reset_exec_env() + self._untraced = False + else: + # compiled_tensor leaks + if self._pc == len(self._seq): + for x in escaped_tensors: + try: + assign_raw_tensor(x, as_raw_tensor(x._dev_tensor())) + except TraceMismatchError: + # TraceMismatchError thrown in do_exit + pass + self._graph.wait() + self._reset_exec_env() + + # reset status self._pc = 0 - - self._tensor_remaps = None - apply.disable(apply_with_tracing) - apply.disable(apply_const_with_tracing) - apply.disable(apply_symbolic_mode) - apply.disable(apply_const_symbolic_mode) - apply.disable(apply_compiled_mode) - active_trace = None + self._tensor_remaps = None + apply.disable(apply_with_tracing) + apply.disable(apply_const_with_tracing) + apply.disable(apply_symbolic_mode) + apply.disable(apply_const_symbolic_mode) + apply.disable(apply_compiled_mode) + self._set_active(False) + + def do_exit(): + if not self._untraced and self._pc != len(self._seq): + raise TraceMismatchError("premature end") + if not self._symbolic or not self._untraced: + for x in self._active_tensors: + x._dev_tensor() + + try: + do_enter() + yield + do_exit() + except: + interrupted = True + raise + finally: + do_finalize() + if interrupted: + self._reset() def _begin_excluded_region(self): if self._capture_as_const: diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 443222c4ca16de0d4e5b12877b0082c6e417e636..a022b9eb7047a942ad9314f45e7d5eccceb496e2 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -307,3 +307,36 @@ def test_trace_warp_perspective(): for i in range(1): f(x, M) + + +def test_raise_on_trace(): + step_count = 0 + catch_count = 0 + bad_step = 10 + + class CatchMe(Exception): + pass + + a = tensor([1, 2, 3, 4]) + b = tensor([5, 6, 7, 8]) + c = tensor([9, 0, 1, 2]) + + @trace + def add_abc(a, b, c): + print("Hello") + ps = a + b + result = ps + c + if step_count == bad_step: + raise CatchMe("catch me") + return result + + for i in range(100): + try: + d = add_abc(a, b, c) + except CatchMe as e: + catch_count += 1 + else: + np.testing.assert_equal(d.numpy(), (a + b + c).numpy()) + step_count += 1 + + assert catch_count == 1