From 6f7a80c34ea3f5fb8e0c063c195cbc04bdcc531d Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Wed, 2 Nov 2022 14:58:47 +0800 Subject: [PATCH] fix amax/amin/max/min write overflow (#47570) --- paddle/phi/kernels/funcs/reduce_functor.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/funcs/reduce_functor.h b/paddle/phi/kernels/funcs/reduce_functor.h index 34032e153c0..e0e7ec3d403 100644 --- a/paddle/phi/kernels/funcs/reduce_functor.h +++ b/paddle/phi/kernels/funcs/reduce_functor.h @@ -177,7 +177,8 @@ struct MaxOrMinGradFunctor { auto zeros = dx->constant(0); // If there are multiple minimum or maximum elements, the subgradient of // each is the set [0, 1], and we pass gradient to all of them here. - dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros); + dx->device(place) = dy->broadcast(dim).reshape(x->dimensions()) * + equals.select(ones, zeros); } }; @@ -259,7 +260,8 @@ struct AMaxOrAMinGradFunctor { auto equal_number = mask.sum() .reshape(Eigen::array({1})) .broadcast(Eigen::array({size})); - dx->device(place) = dy->broadcast(dim) * mask / equal_number; + dx->device(place) = + dy->broadcast(dim).reshape(x->dimensions()) * mask / equal_number; return; } -- GitLab