未验证 提交 6f7a80c3 编写于 作者: T Tao Luo 提交者: GitHub

fix amax/amin/max/min write overflow (#47570)

上级 2d058cce
......@@ -177,7 +177,8 @@ struct MaxOrMinGradFunctor {
auto zeros = dx->constant(0);
// If there are multiple minimum or maximum elements, the subgradient of
// each is the set [0, 1], and we pass gradient to all of them here.
dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros);
dx->device(place) = dy->broadcast(dim).reshape(x->dimensions()) *
equals.select(ones, zeros);
}
};
......@@ -259,7 +260,8 @@ struct AMaxOrAMinGradFunctor {
auto equal_number = mask.sum()
.reshape(Eigen::array<int, 1>({1}))
.broadcast(Eigen::array<int, 1>({size}));
dx->device(place) = dy->broadcast(dim) * mask / equal_number;
dx->device(place) =
dy->broadcast(dim).reshape(x->dimensions()) * mask / equal_number;
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册