提交 7ac4dbc2 编写于 作者: M Megvii Engine Team

fix(mgb/trace): finalize when exception raise

GitOrigin-RevId: b8ffd00a7ea29add26a2a3d6275ea9f29d877908
上级 2bd84d67
......@@ -125,6 +125,9 @@ class trace:
self._graph_opt_level = opt_level
self._tensor_shape = tensor_shape
self._reset()
def _reset(self):
self._untraced = True
self._tinfo = [] # handle -> TensorInfo
self._seq = []
......@@ -257,31 +260,63 @@ class trace:
def _record_const(self, op, outputs):
pass
@contextlib.contextmanager
def _setup(self):
def _set_active(self, active: bool):
global active_trace
if active:
if active_trace:
raise NotImplementedError("sorry, not implemented: nested trace")
active_trace = self
else:
assert active_trace is self
active_trace = None
if self._untraced:
def _init_trace(self, symbolic: bool):
apply.enable(apply_with_tracing)
apply.enable(apply_const_with_tracing)
if self._symbolic:
if symbolic:
apply.enable(apply_symbolic_mode)
apply.enable(apply_const_symbolic_mode)
self._lazy_eval_graph = G.Graph()
def _take_escaped_tensors(self):
escaped_tensors = tuple(self._active_tensors)
self._active_tensors.clear()
return escaped_tensors
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
active_lazy_eval_tensors = []
visited = set()
readers = []
for x in lazy_eval_tensors:
x = x()
if x is None or x in visited:
continue
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
readers.append(reader)
active_lazy_eval_tensors.append(x)
visited.add(x)
self._apply_graph_options(lazy_eval_graph)
lazy_eval_graph.compile(*readers)
lazy_eval_graph()
for r, x in zip(readers, active_lazy_eval_tensors):
assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
@contextlib.contextmanager
def _setup(self):
interrupted = False
def do_enter():
self._set_active(True)
if self._untraced:
self._init_trace(self._symbolic)
else:
apply.enable(apply_compiled_mode)
if self._graph is None:
self._compile()
self._graph.execute()
yield
escaped_tensors = tuple(self._active_tensors)
self._active_tensors.clear()
def do_finalize():
escaped_tensors = self._take_escaped_tensors()
if self._untraced:
for x in escaped_tensors:
info = self._tinfo[x._TraceMixin__handle]
......@@ -290,44 +325,52 @@ class trace:
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x._TraceMixin__restore()
if self._symbolic:
if self._symbolic and self._lazy_eval_tensors:
# eval lazy eval tensors
if self._lazy_eval_tensors:
lazy_eval_tensors = []
visited = set()
readers = []
for x in self._lazy_eval_tensors:
x = x()
if x is None or x in visited:
continue
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
readers.append(reader)
lazy_eval_tensors.append(x)
visited.add(x)
self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_graph.compile(*readers)
self._lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors):
assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
self._lazy_eval_graph = None
self._lazy_eval_tensors = None
self._untraced = False
else:
if self._pc != len(self._seq):
raise TraceMismatchError("premature end")
# compiled_tensor leaks
if self._pc == len(self._seq):
for x in escaped_tensors:
try:
assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
except TraceMismatchError:
# TraceMismatchError thrown in do_exit
pass
self._graph.wait()
self._reset_exec_env()
self._pc = 0
# reset status
self._pc = 0
self._tensor_remaps = None
apply.disable(apply_with_tracing)
apply.disable(apply_const_with_tracing)
apply.disable(apply_symbolic_mode)
apply.disable(apply_const_symbolic_mode)
apply.disable(apply_compiled_mode)
active_trace = None
self._set_active(False)
def do_exit():
if not self._untraced and self._pc != len(self._seq):
raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced:
for x in self._active_tensors:
x._dev_tensor()
try:
do_enter()
yield
do_exit()
except:
interrupted = True
raise
finally:
do_finalize()
if interrupted:
self._reset()
def _begin_excluded_region(self):
if self._capture_as_const:
......
......@@ -307,3 +307,36 @@ def test_trace_warp_perspective():
for i in range(1):
f(x, M)
def test_raise_on_trace():
step_count = 0
catch_count = 0
bad_step = 10
class CatchMe(Exception):
pass
a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8])
c = tensor([9, 0, 1, 2])
@trace
def add_abc(a, b, c):
print("Hello")
ps = a + b
result = ps + c
if step_count == bad_step:
raise CatchMe("catch me")
return result
for i in range(100):
try:
d = add_abc(a, b, c)
except CatchMe as e:
catch_count += 1
else:
np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
step_count += 1
assert catch_count == 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册