Commit 1b568517 authored by Megvii Engine Team

fix(mge/trace): fix graph option in trace

GitOrigin-RevId: 7bec84f56d61cea4ad50f66224ffa11255458d0b
Parent: f2f5f9ac
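In brief: the graph options configured on a trace (such as opt_level) were previously applied to the lazy-eval graph only inside _lazy_eval, right before compilation; this commit applies them when self._lazy_eval_graph is created, and additionally applies them to the fresh graph built by dump(), where only graph_opt_level takes effect. A minimal usage sketch of the affected option, assuming the MegEngine 1.x API that the updated tests below use:

    import numpy as np
    from megengine import tensor
    from megengine.jit import trace

    # opt_level=0 disables graph optimization, so 0/0 is computed literally
    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x):
        y = 2.0 * x
        return y / y

    print(np.isfinite(f(tensor(0.0)).numpy()))  # expected: False (0/0 stays NaN)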
@@ -284,6 +284,7 @@ class trace:
             apply.enable(apply_symbolic_mode)
             apply.enable(apply_const_symbolic_mode)
             self._lazy_eval_graph = G.Graph()
+            self._apply_graph_options(self._lazy_eval_graph)

     def _take_escaped_tensors(self):
         escaped_tensors = tuple(self._active_tensors)
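The body of _apply_graph_options is not shown in this diff; a hedged sketch of what such a helper presumably does, forwarding trace-level settings onto the freshly created graph (the attribute names below are assumptions, not the actual MegEngine implementation):

    # Hypothetical sketch only; the real option handling lives in tracing.py.
    def _apply_graph_options(self, graph):
        # forward the opt_level passed to trace(...) to the computing graph
        if self._graph_opt_level is not None:  # assumed attribute name
            graph.options.graph_opt_level = self._graph_opt_level

Calling it right after G.Graph() is constructed means the options are in place before any operator is recorded into the lazy-eval graph.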
@@ -302,7 +303,6 @@ class trace:
             readers.append(reader)
             active_lazy_eval_tensors.append(x)
             visited.add(x)
-        self._apply_graph_options(lazy_eval_graph)
         lazy_eval_graph.compile(*readers)
         lazy_eval_graph()
         for r, x in zip(readers, active_lazy_eval_tensors):
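The call is removed here because the hunk above now configures the lazy-eval graph at creation time; presumably some options only take effect if they are set before operators are recorded into the graph, which setting them just before compile() could not guarantee.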
@@ -599,6 +599,8 @@ class trace:
         h2v = {}
         graph = G.Graph()
+        # only graph_opt_level takes effect in dump
+        self._apply_graph_options(graph)

         for i, h in enumerate(self._arg_bindings):
             info = self._tinfo[h]
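dump() builds a fresh G.Graph(), so without this call a function traced with opt_level=0 would still be serialized under the default optimization level; per the new comment, only graph_opt_level carries over here. The intended dump-side usage, mirroring the test at the bottom of this diff (note that test is still skipped with "could not use opt_level=0 with dump"):

    from tempfile import mkstemp
    from megengine import tensor
    from megengine.functional import exp, log
    from megengine.jit import trace

    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x):
        return log(exp(x))

    f(tensor(1.0))                             # run once so the graph is traced
    _, out = mkstemp()                         # temp file path for the dump
    f.dump(out, optimize_for_inference=False)  # opt_level=0 now reaches this graph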
@@ -174,31 +174,24 @@ def test_trace_profiler():
     assert out.get("profiler")


-@pytest.mark.skip(reason="could not disable opt_level")
-def test_goptions_log_exp():
+def test_goptions():
     @trace(symbolic=True, opt_level=0, capture_as_const=True)
     def f(x):
-        return log(exp(x))
+        # directly return x / x will not trigger gopt
+        # since there's no way to tell the two x are the same
+        y = 2.0 * x
+        return y / y

     @trace(symbolic=True, opt_level=1, capture_as_const=True)
     def g(x):
-        return log(exp(x))
-
-    f(tensor(1.0))
-    _, out = mkstemp()
-    f.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
-    oprs_1 = cgtools.get_oprs_seq(outputs)
-
-    g(tensor(1.0))
-    g.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
-    oprs_2 = cgtools.get_oprs_seq(outputs)
-
-    assert len(oprs_1) - len(oprs_2) == 2
+        y = 2.0 * x
+        return y / y
+
+    d = tensor(0.0)
+    assert not np.isfinite(f(d).numpy())
+    np.testing.assert_equal(g(d).numpy().item(), 1.0)


-@pytest.mark.skip(reason="could not disable opt_level")
 def test_goptions_log_sum_exp():
     @trace(symbolic=True, opt_level=0, capture_as_const=True)
     def f(x, y):
@@ -208,13 +201,30 @@ def test_goptions_log_sum_exp():
     def g(x, y):
         return log(exp(x) + exp(y))

-    f(tensor(1.0), tensor(2.0))
+    val = 1.0e4
+    d = tensor(val)
+    o = tensor(0.0)
+    assert not np.isfinite(f(d, o).numpy())
+    np.testing.assert_almost_equal(g(d, o), val)
+
+
+@pytest.mark.skip(reason="could not use opt_level=0 with dump")
+def test_goptions_log_exp():
+    @trace(symbolic=True, opt_level=0, capture_as_const=True)
+    def f(x):
+        return log(exp(x))
+
+    @trace(symbolic=True, opt_level=1, capture_as_const=True)
+    def g(x):
+        return log(exp(x))
+
+    f(tensor(1.0))
     _, out = mkstemp()
     f.dump(out, optimize_for_inference=False)
     *_, outputs = G.load_graph(out)
     oprs_1 = cgtools.get_oprs_seq(outputs)

-    g(tensor(1.0), tensor(2.0))
+    g(tensor(1.0))
     g.dump(out, optimize_for_inference=False)
     *_, outputs = G.load_graph(out)
     oprs_2 = cgtools.get_oprs_seq(outputs)
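Why the updated assertions in test_goptions_log_sum_exp hold: at opt_level=0 the graph evaluates log(exp(x) + exp(y)) literally and exp(1e4) overflows to inf, so f(d, o) is not finite; at opt_level=1 the optimizer presumably rewrites the expression into a numerically stable log-sum-exp, so g(d, o) comes out at roughly 1e4. A plain NumPy check of that arithmetic (no MegEngine needed):

    import numpy as np

    val, other = 1.0e4, 0.0
    with np.errstate(over="ignore"):
        naive = np.log(np.exp(val) + np.exp(other))  # exp(1e4) overflows -> inf
    # log-sum-exp trick: factor out the max before exponentiating
    m = max(val, other)
    stable = m + np.log(np.exp(val - m) + np.exp(other - m))
    print(np.isfinite(naive), stable)  # False 10000.0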