diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index eaa3d6b2aa17563d00c067a7022cf8f36c1cc602..1bff175b254734107273b33380412a172b4b56f0 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -489,7 +489,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode): graph._make_const_for_backward, args, ) - return _unwrap(outputs) + return outputs set_cpp_apply_backward_varnode(apply_backward_varnode) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 203efe0aab2a60f12a215c4052592f3e4e7f034b..c4b70d6cdefd988acd24dc5dbfa15825e58622f8 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -830,7 +830,10 @@ class trace: name=info.name, ) ivars.append(h2v[h]) - ovars = G.apply_normal_varnode(op, *ivars) + if isinstance(op, BackwardGraph): + ovars = G.apply_backward_varnode(op, *ivars) + else: + ovars = G.apply_normal_varnode(op, *ivars) AutoNaming.record_opnode(ovars[0].op) diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 44ac4f044b28cdd1e8361f330ae041130849a4c5..480587a925843212da548b2bf7252ec2af5299bd 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -247,6 +247,34 @@ def test_dump_volatile(): ) +def test_dump_backward_graph(): + x0 = tensor(np.random.randn(3, 4)) + x1 = tensor(np.random.randn(3, 4)) + + gm = GradManager().attach(x0) + + @trace(symbolic=True, capture_as_const=True) + def f(x0, x1): + with gm: + y = x0 * x1 + gm.backward(y, F.ones_like(y)) + dx0 = x0.grad + return y, dx0 + + y, dx0 = f(x0, x1) + np.testing.assert_equal(dx0.numpy(), x1) + + file = io.BytesIO() + f.dump(file, optimize_for_inference=False) + file.seek(0) + + infer_cg = cgtools.GraphInference(file) + results = list((infer_cg.run(x0, x1)).values()) + + np.testing.assert_equal(results[0], y) + np.testing.assert_equal(results[1], dx0) + + @pytest.mark.parametrize("trace_mode", [False, True]) def test_trace_profiler(trace_mode): @trace(symbolic=trace_mode, profiling=True)