提交 7b4b94fd 编写于 作者: M Megvii Engine Team

fix(imperative): fix the segmentfault when reduce backward

GitOrigin-RevId: 8a3e63d4f538ace0e06b3df6aaaa080633a3e525
上级 24c5c19b
......@@ -408,7 +408,7 @@ std::optional<ValueRefList> reduce_grad_rule(
[shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
if (!keepdim) {
if (!keepdim && grad) {
auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis}));
grad = imperative::apply(*grad_op, grad)[0];
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册