提交 3faba54f 编写于 作者: M Megvii Engine Team

fix(mge): fix segfault with Function returning unused grads

GitOrigin-RevId: 0cce84592337909db3de0d73f566deaf1c229525
上级 75129cf0
......@@ -282,6 +282,10 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
}
};
GradSlotPtr::operator bool() const {
return bool(grad_fn);
}
GradSlot* GradSlotPtr::operator->() {
return &grad_fn->slots[idx];
}
......@@ -537,7 +541,10 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
if (!grad_fn) continue;
auto grad_receiver = [&](size_t i, auto&& g) {
accum_grad(grad_fn->dsts[i]->grad, std::forward<decltype(g)>(g));
auto& dst = grad_fn->dsts[i];
if (dst) {
accum_grad(dst->grad, std::forward<decltype(g)>(g));
}
};
std::visit([&](auto&& backward) {
using T = std::decay_t<decltype(backward)>;
......
......@@ -22,6 +22,7 @@ struct GradSlotPtr {
std::shared_ptr<GradFn> grad_fn;
size_t idx;
operator bool() const;
GradSlot* operator->();
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册