diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index be6c95f94f86a63f8e76e0316edb85a4810e1268..c576aab25fa2a89bbc004e436924837ab9931838 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -408,7 +408,7 @@ std::optional reduce_grad_rule( [shapes = std::move(input_shapes), axis, keepdim](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - if (!keepdim) { + if (!keepdim && grad) { auto&& grad_op = AddAxis::make(std::vector({axis})); grad = imperative::apply(*grad_op, grad)[0]; }