From 3faba54f28615f69158a230c6430b33448ea842e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 31 Dec 2020 16:30:54 +0800 Subject: [PATCH] fix(mge): fix segfault with Function returning unused grads GitOrigin-RevId: 0cce84592337909db3de0d73f566deaf1c229525 --- imperative/python/src/grad.cpp | 9 ++++++++- imperative/python/src/grad_info.h | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 8019137a6..ba340ab32 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -282,6 +282,10 @@ struct GradFn : std::enable_shared_from_this { } }; +GradSlotPtr::operator bool() const { + return bool(grad_fn); +} + GradSlot* GradSlotPtr::operator->() { return &grad_fn->slots[idx]; } @@ -537,7 +541,10 @@ void GradKey::backward(std::vector tensors, std::vectordsts[i]->grad, std::forward(g)); + auto& dst = grad_fn->dsts[i]; + if (dst) { + accum_grad(dst->grad, std::forward(g)); + } }; std::visit([&](auto&& backward) { using T = std::decay_t; diff --git a/imperative/python/src/grad_info.h b/imperative/python/src/grad_info.h index 676b598bf..9a1c9d939 100644 --- a/imperative/python/src/grad_info.h +++ b/imperative/python/src/grad_info.h @@ -22,6 +22,7 @@ struct GradSlotPtr { std::shared_ptr grad_fn; size_t idx; + operator bool() const; GradSlot* operator->(); }; -- GitLab