diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 8019137a618dcd613d79b352d325a473a7ee304d..ba340ab324a70b4e5f23eff9641b38fb7cb29952 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -282,6 +282,10 @@ struct GradFn : std::enable_shared_from_this { } }; +GradSlotPtr::operator bool() const { + return bool(grad_fn); +} + GradSlot* GradSlotPtr::operator->() { return &grad_fn->slots[idx]; } @@ -537,7 +541,10 @@ void GradKey::backward(std::vector tensors, std::vectordsts[i]->grad, std::forward(g)); + auto& dst = grad_fn->dsts[i]; + if (dst) { + accum_grad(dst->grad, std::forward(g)); + } }; std::visit([&](auto&& backward) { using T = std::decay_t; diff --git a/imperative/python/src/grad_info.h b/imperative/python/src/grad_info.h index 676b598bf57147f74d18d834e64b05a8497970cb..9a1c9d939d76b78c4ac47e215ee9e725faae5887 100644 --- a/imperative/python/src/grad_info.h +++ b/imperative/python/src/grad_info.h @@ -22,6 +22,7 @@ struct GradSlotPtr { std::shared_ptr grad_fn; size_t idx; + operator bool() const; GradSlot* operator->(); };