diff --git a/paddle/phi/kernels/funcs/reduce_functor.h b/paddle/phi/kernels/funcs/reduce_functor.h index 34032e153c0496ca64cfc6ab86cfe5fe64bc37e4..e0e7ec3d403f1d0a1b0617bb254e328c50914c5f 100644 --- a/paddle/phi/kernels/funcs/reduce_functor.h +++ b/paddle/phi/kernels/funcs/reduce_functor.h @@ -177,7 +177,8 @@ struct MaxOrMinGradFunctor { auto zeros = dx->constant(0); // If there are multiple minimum or maximum elements, the subgradient of // each is the set [0, 1], and we pass gradient to all of them here. - dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros); + dx->device(place) = dy->broadcast(dim).reshape(x->dimensions()) * + equals.select(ones, zeros); } }; @@ -259,7 +260,8 @@ struct AMaxOrAMinGradFunctor { auto equal_number = mask.sum() .reshape(Eigen::array({1})) .broadcast(Eigen::array({size})); - dx->device(place) = dy->broadcast(dim) * mask / equal_number; + dx->device(place) = + dy->broadcast(dim).reshape(x->dimensions()) * mask / equal_number; return; }