From b7e596b4a1b4bf773a7cd278b111566bf0b9e6d1 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Wed, 23 Jun 2021 13:54:04 +0800
Subject: [PATCH] perf(autograd): copy inputs before capture in
 backward_graph_grad_rule

GitOrigin-RevId: 8b9c067b2d4ad52d8a0ba876d8fb5e2e688f291b
---
 imperative/python/src/grad.cpp | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index 97e0fcc4f..45ce932d0 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -315,14 +315,25 @@ public:
 };
 
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
-    auto outputs = apply(ctx);
+    // copy inputs first, or trace will make InputNodes for each usage
+    SmallVector<std::shared_ptr<Tensor>> inputs_copy;
+    SmallVector<Tensor*> inputs_copy_weak;
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
+        inputs_copy_weak.push_back(inputs_copy.back().get());
+        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
+    }
+    ApplyContext ctx_dup = ctx;
+    ctx_dup.args = inputs_copy_weak.data();
+
+    auto outputs = apply(ctx_dup);
 
-    auto backward_graph = make_backward_graph(ctx, outputs);
+    auto backward_graph = make_backward_graph(ctx_dup, outputs);
     if (!backward_graph) {
         return outputs;
     }
 
-    ret_grad_fn.emplace(std::move(backward_graph), ctx, outputs);
+    ret_grad_fn.emplace(std::move(backward_graph), ctx_dup, outputs);
 
     return outputs;
 }
-- 
GitLab