Commit 7ac4dbc2 authored by Megvii Engine Team

fix(mgb/trace): finalize when exception raise

GitOrigin-RevId: b8ffd00a7ea29add26a2a3d6275ea9f29d877908
Parent 2bd84d67
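The change below wraps the body of trace._setup() in a try/except/finally so that finalization always runs, and trace state is reset when the traced function raises. Here is a minimal, self-contained sketch of that pattern; the names (Tracer, setup, untraced, seq) are illustrative only and are not the MegEngine API:

import contextlib

class Tracer:
    """Toy stand-in for the trace object; hypothetical names for illustration."""

    def __init__(self):
        self._reset()

    def _reset(self):
        # forget any partially recorded state so the next call can re-trace
        self.untraced = True
        self.seq = []

    @contextlib.contextmanager
    def setup(self):
        interrupted = False
        try:
            self.untraced = False  # "do_enter": tracing hooks become active
            yield
            # "do_exit": consistency checks, reached only on the success path
        except:
            interrupted = True
            raise
        finally:
            # "do_finalize": always runs, even when user code raised
            if interrupted:
                self._reset()

tracer = Tracer()
try:
    with tracer.setup():
        raise RuntimeError("user code failed")
except RuntimeError:
    pass
assert tracer.untraced  # state was reset, so the next call can trace again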
@@ -125,6 +125,9 @@ class trace:
        self._graph_opt_level = opt_level
        self._tensor_shape = tensor_shape

        self._reset()

    def _reset(self):
        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
@@ -257,77 +260,117 @@ class trace:
    def _record_const(self, op, outputs):
        pass

    def _set_active(self, active: bool):
        global active_trace
        if active:
            if active_trace:
                raise NotImplementedError("sorry, not implemented: nested trace")
            active_trace = self
        else:
            assert active_trace is self
            active_trace = None

    def _init_trace(self, symbolic: bool):
        apply.enable(apply_with_tracing)
        apply.enable(apply_const_with_tracing)
        if symbolic:
            apply.enable(apply_symbolic_mode)
            apply.enable(apply_const_symbolic_mode)
            self._lazy_eval_graph = G.Graph()

    def _take_escaped_tensors(self):
        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()
        return escaped_tensors

    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
        active_lazy_eval_tensors = []
        visited = set()
        readers = []
        for x in lazy_eval_tensors:
            x = x()
            if x is None or x in visited:
                continue
            reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
            readers.append(reader)
            active_lazy_eval_tensors.append(x)
            visited.add(x)
        self._apply_graph_options(lazy_eval_graph)
        lazy_eval_graph.compile(*readers)
        lazy_eval_graph()
        for r, x in zip(readers, active_lazy_eval_tensors):
            assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))

    @contextlib.contextmanager
    def _setup(self):
        interrupted = False

        def do_enter():
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)
            else:
                apply.enable(apply_compiled_mode)
                if self._graph is None:
                    self._compile()
                self._graph.execute()

        def do_finalize():
            escaped_tensors = self._take_escaped_tensors()
            if self._untraced:
                for x in escaped_tensors:
                    info = self._tinfo[x._TraceMixin__handle]
                    info.data_read = True
                    x._TraceMixin__restore()
                if self._inputs_to_restore:
                    for x in self._inputs_to_restore:
                        x._TraceMixin__restore()
                if self._symbolic and self._lazy_eval_tensors:
                    # eval lazy eval tensors
                    self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                self._untraced = False
            else:
                # compiled_tensor leaks
                if self._pc == len(self._seq):
                    for x in escaped_tensors:
                        try:
                            assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
                        except TraceMismatchError:
                            # TraceMismatchError thrown in do_exit
                            pass
                    self._graph.wait()
                    self._reset_exec_env()

            # reset status
            self._pc = 0
            self._tensor_remaps = None
            apply.disable(apply_with_tracing)
            apply.disable(apply_const_with_tracing)
            apply.disable(apply_symbolic_mode)
            apply.disable(apply_const_symbolic_mode)
            apply.disable(apply_compiled_mode)
            self._set_active(False)

        def do_exit():
            if not self._untraced and self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            if not self._symbolic or not self._untraced:
                for x in self._active_tensors:
                    x._dev_tensor()

        try:
            do_enter()
            yield
            do_exit()
        except:
            interrupted = True
            raise
        finally:
            do_finalize()
            if interrupted:
                self._reset()

    def _begin_excluded_region(self):
        if self._capture_as_const:
......
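The _lazy_eval helper above iterates over weak references: x = x() dereferences each entry and skips tensors that have already been garbage collected or already visited. A small standalone sketch of that dereference-and-dedup pattern in plain Python (LazyTensor and the variable names are hypothetical, not MegEngine types):

import weakref

class LazyTensor:
    pass

kept = LazyTensor()
refs = [weakref.ref(kept), weakref.ref(LazyTensor()), weakref.ref(kept)]
# The second referent has no strong reference left, so (in CPython) it is
# collected right away and its weakref dereferences to None.

live, visited = [], set()
for r in refs:
    x = r()                       # dereference the weakref
    if x is None or x in visited:
        continue                  # dead or duplicate entry
    visited.add(x)
    live.append(x)

assert live == [kept]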
@@ -307,3 +307,36 @@ def test_trace_warp_perspective():
    for i in range(1):
        f(x, M)


def test_raise_on_trace():
    step_count = 0
    catch_count = 0
    bad_step = 10

    class CatchMe(Exception):
        pass

    a = tensor([1, 2, 3, 4])
    b = tensor([5, 6, 7, 8])
    c = tensor([9, 0, 1, 2])

    @trace
    def add_abc(a, b, c):
        print("Hello")
        ps = a + b
        result = ps + c
        if step_count == bad_step:
            raise CatchMe("catch me")
        return result

    for i in range(100):
        try:
            d = add_abc(a, b, c)
        except CatchMe as e:
            catch_count += 1
        else:
            np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
        step_count += 1

    assert catch_count == 1