diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index 97e0fcc4fe13fbaf3f3f0c4c4a9a2176aa1bf1bf..45ce932d08454dfd0bc41d35c53c267702a8e5bf 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -315,14 +315,25 @@ public:
 };
 
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
-    auto outputs = apply(ctx);
+    // copy inputs first, or trace will make InputNodes for each usage
+    SmallVector<std::shared_ptr<Tensor>> inputs_copy;
+    SmallVector<Tensor*> inputs_copy_weak;
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
+        inputs_copy_weak.push_back(inputs_copy.back().get());
+        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
+    }
+    ApplyContext ctx_dup = ctx;
+    ctx_dup.args = inputs_copy_weak.data();
+
+    auto outputs = apply(ctx_dup);
 
-    auto backward_graph = make_backward_graph(ctx, outputs);
+    auto backward_graph = make_backward_graph(ctx_dup, outputs);
     if (!backward_graph) {
         return outputs;
     }
-    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx, outputs);
+    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);
 
     return outputs;
 }
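
For context, a minimal self-contained sketch of the ownership pattern the new hunk relies on: keep shared_ptr copies of the inputs alive in a local vector, hand raw-pointer views of them to a shallow-copied context, and propagate the grad metadata onto the copies before applying. All names below (Tensor, Context, fastpath_copy, grad_info, apply_with_copied_inputs) are hypothetical stand-ins for illustration, not the MegEngine API.

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct Tensor {
    int grad_info = 0;  // stand-in for m_grad_info
};

struct Context {
    Tensor** args = nullptr;  // non-owning views, like ApplyContext::args
    size_t nargs = 0;
};

// Hypothetical cheap identity op: produces a fresh node for the same value,
// standing in for python::apply(FastpathCopy::make(), ...).
std::shared_ptr<Tensor> fastpath_copy(const Tensor& t) {
    return std::make_shared<Tensor>(t);
}

void apply_with_copied_inputs(Context& ctx) {
    std::vector<std::shared_ptr<Tensor>> inputs_copy;  // owns the copies
    std::vector<Tensor*> inputs_copy_weak;             // raw views for the context
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs_copy.push_back(fastpath_copy(*ctx.args[i]));
        inputs_copy_weak.push_back(inputs_copy.back().get());
        inputs_copy.back()->grad_info = ctx.args[i]->grad_info;  // propagate grad metadata
    }
    Context ctx_dup = ctx;                   // shallow copy of the context...
    ctx_dup.args = inputs_copy_weak.data();  // ...retargeted at the copies
    // ...apply(ctx_dup) and make_backward_graph(ctx_dup, outputs) would go here;
    // inputs_copy keeps the copies alive for the duration of the call.
    std::cout << "duplicated " << ctx_dup.nargs << " inputs\n";
}

int main() {
    Tensor a{1}, b{2};
    Tensor* args[] = {&a, &b};
    Context ctx{args, 2};
    apply_with_copied_inputs(ctx);
}

The key design point, mirrored from the diff: because each input flows through exactly one copy op, a tracer that materializes an input node per usage sees a single usage per original tensor, while the backward machinery still finds the grad info it expects on the duplicated inputs.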