Commit 87d6ff22 — authored by Megvii Engine Team

fix(mge): fix dumping backward graph

GitOrigin-RevId: 430f110053911dbb7719badb6463a8280376ed42
Parent: f31752d5
...@@ -489,7 +489,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode): ...@@ -489,7 +489,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
graph._make_const_for_backward, graph._make_const_for_backward,
args, args,
) )
return _unwrap(outputs) return outputs
set_cpp_apply_backward_varnode(apply_backward_varnode) set_cpp_apply_backward_varnode(apply_backward_varnode)
......
...@@ -830,7 +830,10 @@ class trace: ...@@ -830,7 +830,10 @@ class trace:
name=info.name, name=info.name,
) )
ivars.append(h2v[h]) ivars.append(h2v[h])
ovars = G.apply_normal_varnode(op, *ivars) if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)
AutoNaming.record_opnode(ovars[0].op) AutoNaming.record_opnode(ovars[0].op)
......
...@@ -247,6 +247,34 @@ def test_dump_volatile(): ...@@ -247,6 +247,34 @@ def test_dump_volatile():
) )
def test_dump_backward_graph():
    # Regression test added by this commit: a traced function whose graph
    # contains a BackwardGraph op (produced by GradManager.backward) must be
    # dumpable, and the dumped graph must reproduce both the forward output
    # and the gradient when re-executed.
    # Relies on module-level imports not visible in this chunk: `tensor`,
    # `GradManager`, `trace`, `F`, `io`, `np`, `cgtools` — presumably from
    # megengine / numpy; verify against the full test file.
    # NOTE(review): the scrape stripped all indentation; the block structure
    # below is reconstructed. In particular, whether `dx0 = x0.grad` sits
    # inside or just after the `with gm:` block should be confirmed against
    # the upstream repository.
    x0 = tensor(np.random.randn(3, 4))
    x1 = tensor(np.random.randn(3, 4))
    # Only x0 is attached, so only d/dx0 is computed.
    gm = GradManager().attach(x0)

    @trace(symbolic=True, capture_as_const=True)
    def f(x0, x1):
        with gm:
            y = x0 * x1
            # Seed the backward pass with ones, i.e. plain dL/dy = 1.
            gm.backward(y, F.ones_like(y))
        dx0 = x0.grad
        return y, dx0

    y, dx0 = f(x0, x1)
    # For y = x0 * x1, dy/dx0 == x1.
    # NOTE(review): x1 is a Tensor while dx0.numpy() is an ndarray —
    # presumably np.testing coerces it; confirm this is intentional.
    np.testing.assert_equal(dx0.numpy(), x1)
    # Serialize the captured (forward + backward) graph to an in-memory file.
    file = io.BytesIO()
    f.dump(file, optimize_for_inference=False)
    file.seek(0)
    # Reload the dumped graph and check both outputs match the traced run.
    infer_cg = cgtools.GraphInference(file)
    results = list((infer_cg.run(x0, x1)).values())
    np.testing.assert_equal(results[0], y)
    np.testing.assert_equal(results[1], dx0)
@pytest.mark.parametrize("trace_mode", [False, True]) @pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode): def test_trace_profiler(trace_mode):
@trace(symbolic=trace_mode, profiling=True) @trace(symbolic=trace_mode, profiling=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册