提交 b7e596b4 编写于 作者: M Megvii Engine Team 提交者: huangxinda

perf(autograd): copy inputs before capture in backward_graph_grad_rule

GitOrigin-RevId: 8b9c067b2d4ad52d8a0ba876d8fb5e2e688f291b
上级 2ac3c9dc
......@@ -315,14 +315,25 @@ public:
};
// Apply `ctx` and, when a backward graph exists, install it as the grad rule.
//
// The forward op is applied on *copies* of the inputs rather than on
// ctx.args directly: under trace, capturing the original tensors would
// make the tracer emit a fresh InputNode for each usage, so each input is
// routed through a FastpathCopy op first and the copy is captured instead.
//
// Params:
//   ctx         - forward apply context; ctx.args/ctx.nargs describe inputs.
//   ret_grad_fn - out-slot that receives the BackwardGraphWithClosure when
//                 a backward graph is available.
// Returns: the forward outputs (unchanged whether or not a grad fn is set).
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    // Copy inputs first, or trace will make InputNodes for each usage.
    SmallVector<std::shared_ptr<Tensor>> inputs_copy;
    SmallVector<Tensor*> inputs_copy_weak;   // non-owning view for ApplyContext::args
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
        inputs_copy_weak.push_back(inputs_copy.back().get());
        // FastpathCopy does not propagate grad tracking; carry it over by hand.
        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
    }
    // Shallow-duplicate the context, then point it at the copied inputs.
    ApplyContext ctx_dup = ctx;
    ctx_dup.args = inputs_copy_weak.data();

    auto outputs = apply(ctx_dup);

    // No backward graph (e.g. non-differentiable op): return outputs as-is.
    auto backward_graph = make_backward_graph(ctx_dup, outputs);
    if (!backward_graph) {
        return outputs;
    }
    // Capture the copied-input context in the closure, not the original.
    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);
    return outputs;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册