From 56f15c439781b135257aab648739d2d80b6ae009 Mon Sep 17 00:00:00 2001
From: wanghuancoder
Date: Mon, 21 Nov 2022 16:18:11 +0800
Subject: [PATCH] refine reduce_all (#48133)

* refine reduce_all
---
 paddle/phi/core/kernel_utils.h | 11 +++++++++++
 paddle/phi/kernels/cpu/prod_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce.h | 2 ++
 paddle/phi/kernels/cpu/reduce_all_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_amax_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_amin_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_any_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_max_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_mean_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_min_kernel.cc | 1 +
 paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc | 1 +
 paddle/phi/kernels/funcs/reduce_function.h | 1 +
 paddle/phi/kernels/gpu/frobenius_norm_kernel.cu | 1 +
 paddle/phi/kernels/gpu/reduce.h | 1 +
 paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu | 1 +
 paddle/phi/kernels/gpu/reduce_amin_amax_common.h | 4 +---
 paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu | 1 +
 paddle/phi/kernels/gpu/reduce_grad.h | 1 +
 paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu | 4 +---
 paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu | 4 +---
 .../kernels/impl/frobenius_norm_grad_kernel_impl.h | 1 +
 paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h | 1 +
 paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h | 4 +---
 paddle/phi/kernels/impl/logsumexp_kernel_impl.h | 4 +---
 paddle/phi/kernels/impl/prod_grad_kernel_impl.h | 1 +
 paddle/phi/kernels/impl/reduce_grad.h | 6 +++---
 paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h | 1 +
 paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h | 1 +
 paddle/phi/kernels/kps/prod_kernel.cu | 1 +
 paddle/phi/kernels/kps/reduce_all_kernel.cu | 3 ++-
 paddle/phi/kernels/kps/reduce_amax_kernel.cu | 1 +
 paddle/phi/kernels/kps/reduce_amin_kernel.cu | 1 +
 paddle/phi/kernels/kps/reduce_any_kernel.cu | 3 ++-
 paddle/phi/kernels/kps/reduce_max_kernel.cu | 1 +
 paddle/phi/kernels/kps/reduce_mean_kernel.cu | 1 +
 paddle/phi/kernels/kps/reduce_min_kernel.cu | 1 +
 paddle/phi/kernels/kps/reduce_sum_kernel.cu | 2 ++
 paddle/phi/kernels/onednn/reduce_kernel_impl.h | 2 ++
 paddle/phi/kernels/onednn/reduce_max_kernel.cc | 1 +
 paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc | 1 +
 paddle/phi/kernels/onednn/reduce_mean_kernel.cc | 1 +
 paddle/phi/kernels/onednn/reduce_min_kernel.cc | 1 +
 paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc | 1 +
 paddle/phi/kernels/onednn/reduce_sum_kernel.cc | 1 +
 paddle/phi/kernels/prod_kernel.cc | 2 +-
 paddle/phi/kernels/reduce_all_kernel.cc | 5 +----
 paddle/phi/kernels/reduce_amax_kernel.cc | 5 +----
 paddle/phi/kernels/reduce_amin_kernel.cc | 5 +----
 paddle/phi/kernels/reduce_any_kernel.cc | 5 +----
 paddle/phi/kernels/reduce_max_kernel.cc | 5 +----
 paddle/phi/kernels/reduce_mean_kernel.cc | 5 +----
 paddle/phi/kernels/reduce_min_kernel.cc | 5 +----
 paddle/phi/kernels/reduce_sum_kernel.cc | 5 +----
 paddle/phi/kernels/xpu/prod_kernel.cc | 1 +
 paddle/phi/kernels/xpu/reduce.h | 1 +
 paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc | 1 +
 paddle/phi/kernels/xpu/reduce_max_kernel.cc | 1 +
 paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc | 1 +
 paddle/phi/kernels/xpu/reduce_mean_kernel.cc | 1 +
 paddle/phi/kernels/xpu/reduce_min_kernel.cc | 1 +
 paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc | 4 +---
 paddle/phi/kernels/xpu/reduce_sum_kernel.cc | 1 +
 65 files changed, 82 insertions(+), 56 deletions(-)
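Note (ignored by git am, kept outside the hunks): the patch replaces the
per-kernel "if (dims.size() == 0) reduce_all = true;" checks with one shared
helper, recompute_reduce_all() in paddle/phi/core/kernel_utils.h. Raw kernels
now recompute the flag defensively on entry, and the public wrapper kernels
derive it with recompute_reduce_all(x, dims). The stand-alone sketch below
mirrors the helper's decision rule only; it uses stand-in scalar arguments
(the real helper takes const DenseTensor& and const IntArray&) and is
illustrative, not part of the patch.

// sketch.cc -- hypothetical stand-alone mirror of recompute_reduce_all.
// rank stands in for x.dims().size(); num_reduce_axes for dims.size().
#include <cassert>
#include <cstddef>

inline bool recompute_reduce_all(int rank,
                                 std::size_t num_reduce_axes,
                                 bool reduce_all = false) {
  // An empty axis list, or an axis list naming every dimension, means a
  // full reduction no matter what flag the caller passed in.
  return num_reduce_axes == 0 || static_cast<int>(num_reduce_axes) == rank ||
         reduce_all;
}

int main() {
  assert(recompute_reduce_all(3, 0));        // no axes -> full reduction
  assert(recompute_reduce_all(3, 3));        // every axis -> full reduction
  assert(!recompute_reduce_all(3, 1));       // partial reduction stays partial
  assert(recompute_reduce_all(3, 1, true));  // explicit flag still wins
  return 0;
}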
diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h
index 55ea3a31eb..05d8e259cf 100644
--- a/paddle/phi/core/kernel_utils.h
+++ b/paddle/phi/core/kernel_utils.h
@@ -336,4 +336,15 @@ struct KernelImpl {
   };
 };
 
+inline bool recompute_reduce_all(const DenseTensor& x,
+                                 const IntArray& dims,
+                                 bool reduce_all = false) {
+  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size() ||
+      reduce_all) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/prod_kernel.cc b/paddle/phi/kernels/cpu/prod_kernel.cc
index af5ea5cb95..d5a07c0057 100644
--- a/paddle/phi/kernels/cpu/prod_kernel.cc
+++ b/paddle/phi/kernels/cpu/prod_kernel.cc
@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<CPUContext, T, phi::funcs::ProdFunctor>(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce.h b/paddle/phi/kernels/cpu/reduce.h
index e5f610b955..bfcbe0eee1 100644
--- a/paddle/phi/kernels/cpu/reduce.h
+++ b/paddle/phi/kernels/cpu/reduce.h
@@ -30,6 +30,7 @@ void Reduce(const DeviceContext& dev_ctx,
             bool keep_dim,
             DataType out_dtype,
             DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // If the dims has full dim, set the reduce_all is True
   const int& input_dim_size = x.dims().size();
   std::set<int> dims_set(dims.begin(), dims.end());
@@ -71,6 +72,7 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       phi::DenseTensor* output) {
+  reduce_all = recompute_reduce_all(input, dims, reduce_all);
   dev_ctx.template Alloc<bool>(output);
 
   // The dims has full dim, set the reduce_all is True
diff --git a/paddle/phi/kernels/cpu/reduce_all_kernel.cc b/paddle/phi/kernels/cpu/reduce_all_kernel.cc
index 3e8e38ee44..60094d1345 100644
--- a/paddle/phi/kernels/cpu/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_all_kernel.cc
@@ -28,6 +28,7 @@ void AllRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   phi::BoolReduceKernel<CPUContext, T, phi::funcs::AllFunctor>(
       dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
index ffe9133d6d..731ee34636 100644
--- a/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amax_kernel.cc b/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
index ac3b5ce762..72ac780e40 100644
--- a/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
@@ -28,6 +28,7 @@ void AMaxRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
index 6bb0e5061c..1165e4c754 100644
--- a/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amin_kernel.cc b/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
index d8f090f93f..47aa5210f3 100644
--- a/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
@@ -28,6 +28,7 @@ void AMinRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_any_kernel.cc b/paddle/phi/kernels/cpu/reduce_any_kernel.cc
index 4fd71f1d0b..553393e7db 100644
--- a/paddle/phi/kernels/cpu/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_any_kernel.cc
@@ -28,6 +28,7 @@ void AnyRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   phi::BoolReduceKernel<CPUContext, T, phi::funcs::AnyFunctor>(
       dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_max_kernel.cc b/paddle/phi/kernels/cpu/reduce_max_kernel.cc
index b15a555a2c..d71476a0f9 100644
--- a/paddle/phi/kernels/cpu/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_max_kernel.cc
@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
index 3ab8a40a85..b19f6ebdad 100644
--- a/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(dev_ctx,
                                                              x,
                                                              paddle::none,
diff --git a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
index 7164ec8b2b..2ab1b3e5a4 100644
--- a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<CPUContext, T, phi::funcs::MeanFunctor>(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_min_kernel.cc b/paddle/phi/kernels/cpu/reduce_min_kernel.cc
index a11de5ea81..286951f672 100644
--- a/paddle/phi/kernels/cpu/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_min_kernel.cc
@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
index 87e3df717b..e7d73611cf 100644
--- a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
@@ -77,6 +77,7 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   if (dims.size() == 1) {
     if (out_grad.dtype() != x.dtype()) {
       DenseTensorMeta x_grad_meta(
diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 1b1a55b25c..be64e3c7db 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -58,6 +58,7 @@ using dim3 = phi::kps::dim3;
 
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_utils.h"
 #include "paddle/phi/core/utils/array.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
diff --git a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
index a439711f5d..9878aa6ee2 100644
--- a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
@@ -26,6 +26,7 @@ void FrobeniusNormKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h
index bb914defbe..0d6edd13ac 100644
--- a/paddle/phi/kernels/gpu/reduce.h
+++ b/paddle/phi/kernels/gpu/reduce.h
@@ -36,6 +36,7 @@ void Reduce(const KPDevice& dev_ctx,
             DataType out_dtype,
             DenseTensor* out,
             bool is_mean = false) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   std::vector<int> reduce_dims =
       phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
 
diff --git a/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
index a75ef42889..db6cb2274c 100644
--- a/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceCudaAMaxAMinGrad<T, Context>(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
index 5d90433ad2..ed6e0ef515 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
+++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
@@ -32,15 +32,13 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
                             bool keep_dim,
                             bool reduce_all,
                             DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* in_x = &x;
   auto* out_y = &out;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all);
   auto update_dims = vectorize(d_x->dims());
   int reduce_num = 1;
diff --git a/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
index 152ef494b4..58598cae56 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
@@ -29,6 +29,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceCudaAMaxAMinGrad<T, Context>(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h
index ed6cc0c3c2..01f9192464 100644
--- a/paddle/phi/kernels/gpu/reduce_grad.h
+++ b/paddle/phi/kernels/gpu/reduce_grad.h
@@ -52,6 +52,7 @@ void ReduceGradKernel(const Context& dev_ctx,
                       bool reduce_all,
                       DenseTensor* x_grad,
                       Functor functor) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* in_x = &x;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
index 40c317e126..d7b3adfcd6 100644
--- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
@@ -29,11 +29,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = x.dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
index 74209afe37..04b3253178 100644
--- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -29,11 +29,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // get reduce_dim for reduce_mean_grad
   int dim_size = x.dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
diff --git a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
index 96cf08af96..385ea68e6e 100644
--- a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
                              bool keep_dim,
                              bool reduce_all,
                              DenseTensor* dx) {
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
   ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
       ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
 }
diff --git a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
index d1de47e128..7dbc3ab3af 100644
--- a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
+++ b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
@@ -27,6 +27,7 @@ void FrobeniusNormKernel(const Context& ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
   Reduce<Context, T, funcs::FrobeniusNormFunctor>(
       ctx, x, reduce_all, axis, keep_dim, x.dtype(), out);
 }
diff --git a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
index 098503f82c..0db6c12d4a 100644
--- a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
@@ -60,9 +60,7 @@ void LogsumexpGradKernel(const Context& dev_ctx,
                          DenseTensor* in_grad) {
   dev_ctx.template Alloc<T>(in_grad);
 
-  if (axis.size() == 0 || static_cast<int>(axis.size()) == in.dims().size()) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(in, axis, reduce_all);
 
   if (reduce_all) {
     auto x = phi::EigenVector<T>::Flatten(in);
diff --git a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
index 0d16dc7baf..cc50573962 100644
--- a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
+++ b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
@@ -69,9 +69,7 @@ void LogsumexpKernel(const Context& dev_ctx,
                      DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
 
-  if (axis.size() == 0 || static_cast<int>(axis.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
 
   if (reduce_all) {
     // Flatten and reduce 1-D tensor
diff --git a/paddle/phi/kernels/impl/prod_grad_kernel_impl.h b/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
index 13f517c072..208e3362de 100644
--- a/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
@@ -30,6 +30,7 @@ void ProdGradKernel(const Context& dev_ctx,
                     bool keep_dim,
                     bool reduce_all,
                     DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/impl/reduce_grad.h b/paddle/phi/kernels/impl/reduce_grad.h
index 40b62cc83f..e9d1aec0f0 100644
--- a/paddle/phi/kernels/impl/reduce_grad.h
+++ b/paddle/phi/kernels/impl/reduce_grad.h
@@ -34,6 +34,7 @@ void ComputeFromInput(const Context& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* input0 = &x;
   auto* input1 = out.get_ptr();
   auto* output = x_grad;
@@ -91,9 +92,8 @@ void ReduceGradKernel(const Context& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       DenseTensor* x_grad) {
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
+
   if (x.dtype() != out_grad.dtype()) {
     DenseTensorMeta x_grad_meta(
         out_grad.dtype(), x_grad->dims(), x_grad->layout());
diff --git a/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
index 33730a3717..1d73b582ea 100644
--- a/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
index 93afa07ff0..1f27ed1039 100644
--- a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void ReduceMinGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/kps/prod_kernel.cu b/paddle/phi/kernels/kps/prod_kernel.cu
index 326a351f6d..79dc76f81c 100644
--- a/paddle/phi/kernels/kps/prod_kernel.cu
+++ b/paddle/phi/kernels/kps/prod_kernel.cu
@@ -25,6 +25,7 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<T, kps::MulFunctor, kps::IdentityFunctor>(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_all_kernel.cu b/paddle/phi/kernels/kps/reduce_all_kernel.cu
index 0459acd982..d4d4596917 100644
--- a/paddle/phi/kernels/kps/reduce_all_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_all_kernel.cu
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/kernels/reduce_all_kernel.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
-#include "paddle/phi/kernels/reduce_all_kernel.h"
 
 namespace phi {
 
@@ -25,6 +25,7 @@ void AllRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<T, kps::LogicalAndFunctor, kps::IdentityFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_amax_kernel.cu b/paddle/phi/kernels/kps/reduce_amax_kernel.cu
index 57197fd9d5..f762a30638 100644
--- a/paddle/phi/kernels/kps/reduce_amax_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_amax_kernel.cu
@@ -25,6 +25,7 @@ void AMaxRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<T, kps::MaxFunctor, kps::IdentityFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_amin_kernel.cu b/paddle/phi/kernels/kps/reduce_amin_kernel.cu
index 230adcc829..e5d15b337f 100644
--- a/paddle/phi/kernels/kps/reduce_amin_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_amin_kernel.cu
@@ -25,6 +25,7 @@ void AMinRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce<T, kps::MinFunctor, kps::IdentityFunctor>(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_any_kernel.cu b/paddle/phi/kernels/kps/reduce_any_kernel.cu
index 480268936f..3210f23c3b 100644
--- a/paddle/phi/kernels/kps/reduce_any_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_any_kernel.cu
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/kernels/reduce_any_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/reduce.h" -#include "paddle/phi/kernels/reduce_any_kernel.h" namespace phi { @@ -25,6 +25,7 @@ void AnyRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); diff --git a/paddle/phi/kernels/kps/reduce_max_kernel.cu b/paddle/phi/kernels/kps/reduce_max_kernel.cu index fb47b64f6e..9c0fdb52c4 100644 --- a/paddle/phi/kernels/kps/reduce_max_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_max_kernel.cu @@ -25,6 +25,7 @@ void MaxRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); diff --git a/paddle/phi/kernels/kps/reduce_mean_kernel.cu b/paddle/phi/kernels/kps/reduce_mean_kernel.cu index 7f7946e030..8fc63b2256 100644 --- a/paddle/phi/kernels/kps/reduce_mean_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_mean_kernel.cu @@ -25,6 +25,7 @@ void MeanRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true); diff --git a/paddle/phi/kernels/kps/reduce_min_kernel.cu b/paddle/phi/kernels/kps/reduce_min_kernel.cu index 9c3e61d3c0..450fee16b4 100644 --- a/paddle/phi/kernels/kps/reduce_min_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_min_kernel.cu @@ -25,6 +25,7 @@ void MinRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); diff --git a/paddle/phi/kernels/kps/reduce_sum_kernel.cu b/paddle/phi/kernels/kps/reduce_sum_kernel.cu index c5a30a6a63..e6030db8aa 100644 --- a/paddle/phi/kernels/kps/reduce_sum_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_sum_kernel.cu @@ -35,6 +35,7 @@ void ReduceSumEigen(const KPDevice& dev_ctx, DataType out_dtype, DenseTensor* out, std::vector* reduce_dims) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); // Resize Input Tensor auto new_x = x; int added_dims = EigenDimSize - x.dims().size(); @@ -79,6 +80,7 @@ void SumRawKernel(const Context& dev_ctx, bool reduce_all, DataType out_dtype, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) { out_dtype = out->dtype(); } diff --git a/paddle/phi/kernels/onednn/reduce_kernel_impl.h b/paddle/phi/kernels/onednn/reduce_kernel_impl.h index 4665876469..7a2f66ec98 100644 --- a/paddle/phi/kernels/onednn/reduce_kernel_impl.h +++ b/paddle/phi/kernels/onednn/reduce_kernel_impl.h @@ -46,6 +46,7 @@ void ReduceKernel(const Context& dev_ctx, bool reduce_all, DenseTensor* out, dnnl::algorithm reduction_type) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); const auto& onednn_engine = dev_ctx.GetEngine(); auto x_tz = vectorize(x.dims()); auto out_tz = @@ -116,6 +117,7 @@ void ReduceGradKernel(const Context& dev_ctx, dnnl::algorithm reduction_type, float scale_x, float scale_y) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); const 
auto& onednn_engine = dev_ctx.GetEngine(); auto out_grad_tz = CalculateReducedDims( x_grad, &out_grad, dims.GetData(), reduce_all, keep_dim); diff --git a/paddle/phi/kernels/onednn/reduce_max_kernel.cc b/paddle/phi/kernels/onednn/reduce_max_kernel.cc index 9e3932d7f0..3ece763675 100644 --- a/paddle/phi/kernels/onednn/reduce_max_kernel.cc +++ b/paddle/phi/kernels/onednn/reduce_max_kernel.cc @@ -24,6 +24,7 @@ void MaxRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); ReduceKernel(dev_ctx, x, dims, diff --git a/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc index 4395126821..fd566782b1 100644 --- a/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc @@ -25,6 +25,7 @@ void MeanGradKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto input_dims = phi::vectorize(x.dims()); std::vector reduce_dims = dims.GetData(); int number_of_elements = 1; diff --git a/paddle/phi/kernels/onednn/reduce_mean_kernel.cc b/paddle/phi/kernels/onednn/reduce_mean_kernel.cc index 22e6b3f87b..a6d72c03e7 100644 --- a/paddle/phi/kernels/onednn/reduce_mean_kernel.cc +++ b/paddle/phi/kernels/onednn/reduce_mean_kernel.cc @@ -24,6 +24,7 @@ void MeanRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); ReduceKernel(dev_ctx, x, dims, diff --git a/paddle/phi/kernels/onednn/reduce_min_kernel.cc b/paddle/phi/kernels/onednn/reduce_min_kernel.cc index 177e588d38..d5985efcba 100644 --- a/paddle/phi/kernels/onednn/reduce_min_kernel.cc +++ b/paddle/phi/kernels/onednn/reduce_min_kernel.cc @@ -24,6 +24,7 @@ void MinRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); ReduceKernel(dev_ctx, x, dims, diff --git a/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc index cd21d36cba..10b914a200 100644 --- a/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc @@ -25,6 +25,7 @@ void SumGradKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); ReduceGradKernel(dev_ctx, x, out_grad, diff --git a/paddle/phi/kernels/onednn/reduce_sum_kernel.cc b/paddle/phi/kernels/onednn/reduce_sum_kernel.cc index e5b1d8b6fb..81e77546b4 100644 --- a/paddle/phi/kernels/onednn/reduce_sum_kernel.cc +++ b/paddle/phi/kernels/onednn/reduce_sum_kernel.cc @@ -25,6 +25,7 @@ void SumRawKernel(const Context& dev_ctx, bool reduce_all, DataType out_dtype, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); ReduceKernel(dev_ctx, x, dims, diff --git a/paddle/phi/kernels/prod_kernel.cc b/paddle/phi/kernels/prod_kernel.cc index 532b6fdaa1..1fce5167da 100644 --- a/paddle/phi/kernels/prod_kernel.cc +++ b/paddle/phi/kernels/prod_kernel.cc @@ -25,7 +25,7 @@ void ProdKernel(const Context& dev_ctx, const IntArray& dims, bool keep_dim, DenseTensor* out) { - bool reduce_all = false; + bool reduce_all = false; // recompute_reduce_all(x, dims); ProdRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); } diff --git a/paddle/phi/kernels/reduce_all_kernel.cc 
index 5b8d2cbecc..e1651f12c1 100644
--- a/paddle/phi/kernels/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/reduce_all_kernel.cc
@@ -25,10 +25,7 @@ void AllKernel(const Context& dev_ctx,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_amax_kernel.cc b/paddle/phi/kernels/reduce_amax_kernel.cc
index 47b5e97467..87e432c5c2 100644
--- a/paddle/phi/kernels/reduce_amax_kernel.cc
+++ b/paddle/phi/kernels/reduce_amax_kernel.cc
@@ -25,10 +25,7 @@ void AMaxKernel(const Context& dev_ctx,
                 const std::vector<int64_t>& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AMaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_amin_kernel.cc b/paddle/phi/kernels/reduce_amin_kernel.cc
index 8da4f3afd9..a355da6423 100644
--- a/paddle/phi/kernels/reduce_amin_kernel.cc
+++ b/paddle/phi/kernels/reduce_amin_kernel.cc
@@ -25,10 +25,7 @@ void AMinKernel(const Context& dev_ctx,
                 const std::vector<int64_t>& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AMinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_any_kernel.cc b/paddle/phi/kernels/reduce_any_kernel.cc
index cc70e39680..2baa1edb09 100644
--- a/paddle/phi/kernels/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/reduce_any_kernel.cc
@@ -25,10 +25,7 @@ void AnyKernel(const Context& dev_ctx,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AnyRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_max_kernel.cc b/paddle/phi/kernels/reduce_max_kernel.cc
index 64079cb2ae..23da5bd4cd 100644
--- a/paddle/phi/kernels/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/reduce_max_kernel.cc
@@ -25,10 +25,7 @@ void MaxKernel(const Context& dev_ctx,
                const IntArray& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_mean_kernel.cc b/paddle/phi/kernels/reduce_mean_kernel.cc
index aa615a6bb1..83906fdfc0 100644
--- a/paddle/phi/kernels/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/reduce_mean_kernel.cc
@@ -25,10 +25,7 @@ void MeanKernel(const Context& dev_ctx,
                 const IntArray& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc
index 7a14d106c3..660d3b753e 100644
--- a/paddle/phi/kernels/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/reduce_min_kernel.cc
@@ -25,10 +25,7 @@ void MinKernel(const Context& dev_ctx,
Context& dev_ctx, const IntArray& dims, bool keep_dim, DenseTensor* out) { - bool reduce_all = false; - if (dims.size() == 0 || static_cast(dims.size()) == x.dims().size()) { - reduce_all = true; - } + bool reduce_all = recompute_reduce_all(x, dims); MinRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); } diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc index 70c88c2358..c6cfe42566 100644 --- a/paddle/phi/kernels/reduce_sum_kernel.cc +++ b/paddle/phi/kernels/reduce_sum_kernel.cc @@ -26,10 +26,7 @@ void SumKernel(const Context& dev_ctx, DataType out_dtype, bool keep_dim, DenseTensor* out) { - bool reduce_all = false; - if (dims.size() == 0) { - reduce_all = true; - } + bool reduce_all = recompute_reduce_all(x, dims); SumRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out); } diff --git a/paddle/phi/kernels/xpu/prod_kernel.cc b/paddle/phi/kernels/xpu/prod_kernel.cc index 7be48a8bab..cf237afb22 100644 --- a/paddle/phi/kernels/xpu/prod_kernel.cc +++ b/paddle/phi/kernels/xpu/prod_kernel.cc @@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); int r = XPUReduce(dev_ctx, x, dims.GetData(), diff --git a/paddle/phi/kernels/xpu/reduce.h b/paddle/phi/kernels/xpu/reduce.h index 81fe362a61..49c9eb5ea6 100644 --- a/paddle/phi/kernels/xpu/reduce.h +++ b/paddle/phi/kernels/xpu/reduce.h @@ -33,6 +33,7 @@ int XPUReduce(const Context& dev_ctx, T*, const std::vector&, const std::vector&)> func) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); dev_ctx.template Alloc(out); const auto* x_data = x.data(); diff --git a/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc index 1bfc5ae5f8..b1561233ea 100644 --- a/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc @@ -31,6 +31,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { + reduce_all = recompute_reduce_all(x, dims_arr, reduce_all); auto dims = dims_arr.GetData(); dev_ctx.template Alloc(x_grad); diff --git a/paddle/phi/kernels/xpu/reduce_max_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_kernel.cc index d0994f580c..8db710a24a 100644 --- a/paddle/phi/kernels/xpu/reduce_max_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_max_kernel.cc @@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); int r = XPUReduce(dev_ctx, x, dims.GetData(), diff --git a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc index 0c2fe9a9d9..afe84e43d9 100644 --- a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc @@ -31,6 +31,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx, bool reduce_all, DenseTensor* x_grad) { using XPUType = typename XPUTypeTrait::Type; + reduce_all = recompute_reduce_all(x, dims_arr, reduce_all); dev_ctx.template Alloc(x_grad); const XPUType* dy_data = reinterpret_cast(out_grad.data()); diff --git a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc index 4af1ba2da2..d29db35517 100644 --- a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc @@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx, bool keep_dim, bool 
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce<Context, T>(dev_ctx,
                                 x,
                                 dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce_min_kernel.cc b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
index c54aca1830..e330e30bec 100644
--- a/paddle/phi/kernels/xpu/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce<Context, T>(dev_ctx,
                                 x,
                                 dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
index b6e4d1021e..0ba67f68bc 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
@@ -28,13 +28,11 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool reduce_all,
                          DenseTensor* x_grad) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
   auto dims = dims_arr.GetData();
   dev_ctx.template Alloc<T>(x_grad);
   const auto* out_data = out_grad.data<T>();
   auto* x_grad_data = x_grad->data<T>();
-  if (dims_arr.size() == 0) {
-    reduce_all = true;
-  }
   const auto& input_dim_size = x.dims().size();
   std::vector<int> true_dims;
   for (size_t i = 0; i < dims.size(); ++i) {
diff --git a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
index 74c50304b1..952ed101cd 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
@@ -29,6 +29,7 @@ void SumRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DataType out_dtype,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce<Context, T>(dev_ctx,
                                 x,
                                 dims.GetData(),
-- 
GitLab