diff --git a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h index cc5057396265c4f0cf3a1e9edf18b6b67285d4cc..9c4ee034b740958c57178b56e32702caecea358c 100644 --- a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h +++ b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h @@ -51,7 +51,7 @@ struct LogsumexpFunctor { auto x_mt = (*x).template cast(); auto y_dim = y->dimensions(); - auto x_max = x_mt.maximum(dim); + auto x_max = x_mt.maximum(dim).eval(); y->device(place) = (x_max + (x_mt - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())