From 56f15c439781b135257aab648739d2d80b6ae009 Mon Sep 17 00:00:00 2001
From: wanghuancoder
Date: Mon, 21 Nov 2022 16:18:11 +0800
Subject: [PATCH] refine reduce_all (#48133)

* refine reduce_all
---
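Every reduce kernel used to decide for itself whether a reduction covers the
whole tensor, and the scattered checks disagreed: some promoted reduce_all
only when dims was empty, others also when dims listed every axis. This patch
centralizes that decision in one helper, phi::recompute_reduce_all (added in
kernel_utils.h below), and has every kernel normalize its reduce_all flag
through it before doing any work. A minimal standalone sketch of the intended
semantics, with plain types standing in for DenseTensor and IntArray
(illustrative only, not part of the patch):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in: `rank` plays the role of x.dims().size() and
    // `dims` the role of the IntArray of reduction axes.
    inline bool recompute_reduce_all_sketch(int rank,
                                            const std::vector<int64_t>& dims,
                                            bool reduce_all = false) {
      // Reduce over everything when no axes are given, when every axis is
      // listed, or when the caller already requested a full reduction.
      return dims.empty() || static_cast<int>(dims.size()) == rank || reduce_all;
    }

    int main() {
      assert(recompute_reduce_all_sketch(3, {}));         // empty dims => all
      assert(recompute_reduce_all_sketch(3, {0, 1, 2}));  // full cover => all
      assert(!recompute_reduce_all_sketch(3, {1}));       // partial => keep
      assert(recompute_reduce_all_sketch(3, {1}, true));  // explicit => all
      return 0;
    }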
 paddle/phi/core/kernel_utils.h                        | 11 +++++++++++
 paddle/phi/kernels/cpu/prod_kernel.cc                 |  1 +
 paddle/phi/kernels/cpu/reduce.h                       |  2 ++
 paddle/phi/kernels/cpu/reduce_all_kernel.cc           |  1 +
 paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc     |  1 +
 paddle/phi/kernels/cpu/reduce_amax_kernel.cc          |  1 +
 paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc     |  1 +
 paddle/phi/kernels/cpu/reduce_amin_kernel.cc          |  1 +
 paddle/phi/kernels/cpu/reduce_any_kernel.cc           |  1 +
 paddle/phi/kernels/cpu/reduce_max_kernel.cc           |  1 +
 paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc     |  1 +
 paddle/phi/kernels/cpu/reduce_mean_kernel.cc          |  1 +
 paddle/phi/kernels/cpu/reduce_min_kernel.cc           |  1 +
 paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc      |  1 +
 paddle/phi/kernels/funcs/reduce_function.h            |  1 +
 paddle/phi/kernels/gpu/frobenius_norm_kernel.cu       |  1 +
 paddle/phi/kernels/gpu/reduce.h                       |  1 +
 paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu     |  1 +
 paddle/phi/kernels/gpu/reduce_amin_amax_common.h      |  4 +---
 paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu     |  1 +
 paddle/phi/kernels/gpu/reduce_grad.h                  |  1 +
 paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu     |  4 +---
 paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu      |  4 +---
 .../kernels/impl/frobenius_norm_grad_kernel_impl.h    |  1 +
 paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h  |  1 +
 paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h  |  4 +---
 paddle/phi/kernels/impl/logsumexp_kernel_impl.h       |  4 +---
 paddle/phi/kernels/impl/prod_grad_kernel_impl.h       |  1 +
 paddle/phi/kernels/impl/reduce_grad.h                 |  6 +++---
 paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h |  1 +
 paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h |  1 +
 paddle/phi/kernels/kps/prod_kernel.cu                 |  1 +
 paddle/phi/kernels/kps/reduce_all_kernel.cu           |  3 ++-
 paddle/phi/kernels/kps/reduce_amax_kernel.cu          |  1 +
 paddle/phi/kernels/kps/reduce_amin_kernel.cu          |  1 +
 paddle/phi/kernels/kps/reduce_any_kernel.cu           |  3 ++-
 paddle/phi/kernels/kps/reduce_max_kernel.cu           |  1 +
 paddle/phi/kernels/kps/reduce_mean_kernel.cu          |  1 +
 paddle/phi/kernels/kps/reduce_min_kernel.cu           |  1 +
 paddle/phi/kernels/kps/reduce_sum_kernel.cu           |  2 ++
 paddle/phi/kernels/onednn/reduce_kernel_impl.h        |  2 ++
 paddle/phi/kernels/onednn/reduce_max_kernel.cc        |  1 +
 paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc  |  1 +
 paddle/phi/kernels/onednn/reduce_mean_kernel.cc       |  1 +
 paddle/phi/kernels/onednn/reduce_min_kernel.cc        |  1 +
 paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc   |  1 +
 paddle/phi/kernels/onednn/reduce_sum_kernel.cc        |  1 +
 paddle/phi/kernels/prod_kernel.cc                     |  2 +-
 paddle/phi/kernels/reduce_all_kernel.cc               |  5 +----
 paddle/phi/kernels/reduce_amax_kernel.cc              |  5 +----
 paddle/phi/kernels/reduce_amin_kernel.cc              |  5 +----
 paddle/phi/kernels/reduce_any_kernel.cc               |  5 +----
 paddle/phi/kernels/reduce_max_kernel.cc               |  5 +----
 paddle/phi/kernels/reduce_mean_kernel.cc              |  5 +----
 paddle/phi/kernels/reduce_min_kernel.cc               |  5 +----
 paddle/phi/kernels/reduce_sum_kernel.cc               |  5 +----
 paddle/phi/kernels/xpu/prod_kernel.cc                 |  1 +
 paddle/phi/kernels/xpu/reduce.h                       |  1 +
 paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc      |  1 +
 paddle/phi/kernels/xpu/reduce_max_kernel.cc           |  1 +
 paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc     |  1 +
 paddle/phi/kernels/xpu/reduce_mean_kernel.cc          |  1 +
 paddle/phi/kernels/xpu/reduce_min_kernel.cc           |  1 +
 paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc      |  4 +---
 paddle/phi/kernels/xpu/reduce_sum_kernel.cc           |  1 +
 65 files changed, 82 insertions(+), 56 deletions(-)

diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h
index 55ea3a31eb3..05d8e259cff 100644
--- a/paddle/phi/core/kernel_utils.h
+++ b/paddle/phi/core/kernel_utils.h
@@ -336,4 +336,15 @@ struct KernelImpl {
   };
 };
 
+inline bool recompute_reduce_all(const DenseTensor& x,
+                                 const IntArray& dims,
+                                 bool reduce_all = false) {
+  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size() ||
+      reduce_all) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/prod_kernel.cc b/paddle/phi/kernels/cpu/prod_kernel.cc
index af5ea5cb956..d5a07c0057d 100644
--- a/paddle/phi/kernels/cpu/prod_kernel.cc
+++ b/paddle/phi/kernels/cpu/prod_kernel.cc
@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce.h b/paddle/phi/kernels/cpu/reduce.h
index e5f610b9554..bfcbe0eee1f 100644
--- a/paddle/phi/kernels/cpu/reduce.h
+++ b/paddle/phi/kernels/cpu/reduce.h
@@ -30,6 +30,7 @@ void Reduce(const DeviceContext& dev_ctx,
             bool keep_dim,
             DataType out_dtype,
             DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // If the dims has full dim, set the reduce_all is True
   const int& input_dim_size = x.dims().size();
   std::set<int64_t> dims_set(dims.begin(), dims.end());
@@ -71,6 +72,7 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       phi::DenseTensor* output) {
+  reduce_all = recompute_reduce_all(input, dims, reduce_all);
   dev_ctx.template Alloc<T>(output);
 
   // The dims has full dim, set the reduce_all is True
diff --git a/paddle/phi/kernels/cpu/reduce_all_kernel.cc b/paddle/phi/kernels/cpu/reduce_all_kernel.cc
index 3e8e38ee444..60094d1345a 100644
--- a/paddle/phi/kernels/cpu/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_all_kernel.cc
@@ -28,6 +28,7 @@ void AllRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   phi::BoolReduceKernel(
       dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
index ffe9133d6d9..731ee346365 100644
--- a/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amax_kernel.cc b/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
index ac3b5ce762e..72ac780e400 100644
--- a/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
@@ -28,6 +28,7 @@ void AMaxRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
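The same normalization is re-applied at every layer (public kernel, Raw
kernel, shared Reduce/BoolReduceKernel), which is safe because the helper is
idempotent: feeding its own result back in never changes the answer. A small
self-check of that property, under the same simplified stand-in types as the
sketch above:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Simplified stand-in for phi::recompute_reduce_all (see kernel_utils.h
    // above); `rank` replaces x.dims().size().
    static bool normalize(int rank, const std::vector<int64_t>& dims,
                          bool reduce_all) {
      return dims.empty() || static_cast<int>(dims.size()) == rank || reduce_all;
    }

    int main() {
      const std::vector<std::vector<int64_t>> cases = {{}, {0}, {1, 2}, {0, 1, 2}};
      for (int rank = 1; rank <= 4; ++rank) {
        for (const auto& dims : cases) {
          for (bool r : {false, true}) {
            bool once = normalize(rank, dims, r);
            // Re-normalizing an already-normalized flag never changes it, so
            // each layer can call the helper without coordinating with the
            // others.
            assert(normalize(rank, dims, once) == once);
            // The helper never downgrades an explicit full reduction.
            assert(!r || once);
          }
        }
      }
      return 0;
    }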
diff --git a/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
index 6bb0e5061cc..1165e4c7545 100644
--- a/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amin_kernel.cc b/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
index d8f090f93ff..47aa5210f32 100644
--- a/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
@@ -28,6 +28,7 @@ void AMinRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_any_kernel.cc b/paddle/phi/kernels/cpu/reduce_any_kernel.cc
index 4fd71f1d0b1..553393e7dba 100644
--- a/paddle/phi/kernels/cpu/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_any_kernel.cc
@@ -28,6 +28,7 @@ void AnyRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   phi::BoolReduceKernel(
       dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_max_kernel.cc b/paddle/phi/kernels/cpu/reduce_max_kernel.cc
index b15a555a2cf..d71476a0f92 100644
--- a/paddle/phi/kernels/cpu/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_max_kernel.cc
@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
index 3ab8a40a85e..b19f6ebdad8 100644
--- a/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(dev_ctx,
                    x,
                    paddle::none,
diff --git a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
index 7164ec8b2bf..2ab1b3e5a47 100644
--- a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_min_kernel.cc b/paddle/phi/kernels/cpu/reduce_min_kernel.cc
index a11de5ea81a..286951f6720 100644
--- a/paddle/phi/kernels/cpu/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_min_kernel.cc
@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
index 87e3df717b2..e7d73611cf0 100644
--- a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
@@ -77,6 +77,7 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   if (dims.size() == 1) {
     if (out_grad.dtype() != x.dtype()) {
       DenseTensorMeta x_grad_meta(
diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 1b1a55b25c5..be64e3c7db7 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -58,6 +58,7 @@ using dim3 = phi::kps::dim3;
 
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_utils.h"
 #include "paddle/phi/core/utils/array.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
diff --git a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
index a439711f5d0..9878aa6ee23 100644
--- a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
@@ -26,6 +26,7 @@ void FrobeniusNormKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h
index bb914defbe8..0d6edd13ac9 100644
--- a/paddle/phi/kernels/gpu/reduce.h
+++ b/paddle/phi/kernels/gpu/reduce.h
@@ -36,6 +36,7 @@ void Reduce(const KPDevice& dev_ctx,
             DataType out_dtype,
             DenseTensor* out,
             bool is_mean = false) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   std::vector<int> reduce_dims =
       phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
diff --git a/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
index a75ef42889d..db6cb2274cd 100644
--- a/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceCudaAMaxAMinGrad(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
index 5d90433ad22..ed6e0ef5155 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
+++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
@@ -32,15 +32,13 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
                             bool keep_dim,
                             bool reduce_all,
                             DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* in_x = &x;
   auto* out_y = &out;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all);
   auto update_dims = vectorize(d_x->dims());
   int reduce_num = 1;
diff --git a/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
index 152ef494b4c..58598cae56a 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
@@ -29,6 +29,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceCudaAMaxAMinGrad(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h
index ed6cc0c3c20..01f91924645 100644
--- a/paddle/phi/kernels/gpu/reduce_grad.h
+++ b/paddle/phi/kernels/gpu/reduce_grad.h
@@ -52,6 +52,7 @@ void ReduceGradKernel(const Context& dev_ctx,
                       bool reduce_all,
                       DenseTensor* x_grad,
                       Functor functor) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* in_x = &x;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
index 40c317e1262..d7b3adfcd6f 100644
--- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
@@ -29,11 +29,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = x.dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
index 74209afe374..04b32531789 100644
--- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -29,11 +29,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // get reduce_dim for reduce_mean_grad
   int dim_size = x.dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
diff --git a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
index 96cf08af963..385ea68e6e7 100644
--- a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
                              bool keep_dim,
                              bool reduce_all,
                              DenseTensor* dx) {
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
   ReduceGradKernel(
       ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
 }
diff --git a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
index d1de47e128e..7dbc3ab3af7 100644
--- a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
+++ b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
@@ -27,6 +27,7 @@ void FrobeniusNormKernel(const Context& ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
   Reduce(
       ctx, x, reduce_all, axis, keep_dim, x.dtype(), out);
 }
diff --git a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
index 098503f82cd..0db6c12d4a0 100644
--- a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
@@ -60,9 +60,7 @@ void LogsumexpGradKernel(const Context& dev_ctx,
                          DenseTensor* in_grad) {
   dev_ctx.template Alloc<T>(in_grad);
 
-  if (axis.size() == 0 || static_cast<int>(axis.size()) == in.dims().size()) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(in, axis, reduce_all);
 
   if (reduce_all) {
     auto x = phi::EigenVector<T>::Flatten(in);
diff --git a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
index 0d16dc7baf6..cc505739626 100644
--- a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
+++ b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
@@ -69,9 +69,7 @@ void LogsumexpKernel(const Context& dev_ctx,
                      DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
 
-  if (axis.size() == 0 || static_cast<int>(axis.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
 
   if (reduce_all) {
     // Flatten and reduce 1-D tensor
diff --git a/paddle/phi/kernels/impl/prod_grad_kernel_impl.h b/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
index 13f517c072c..208e3362de4 100644
--- a/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
@@ -30,6 +30,7 @@ void ProdGradKernel(const Context& dev_ctx,
                     bool keep_dim,
                     bool reduce_all,
                     DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/impl/reduce_grad.h b/paddle/phi/kernels/impl/reduce_grad.h
index 40b62cc83fa..e9d1aec0f09 100644
--- a/paddle/phi/kernels/impl/reduce_grad.h
+++ b/paddle/phi/kernels/impl/reduce_grad.h
@@ -34,6 +34,7 @@ void ComputeFromInput(const Context& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* input0 = &x;
   auto* input1 = out.get_ptr();
   auto* output = x_grad;
@@ -91,9 +92,8 @@ void ReduceGradKernel(const Context& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       DenseTensor* x_grad) {
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
+
   if (x.dtype() != out_grad.dtype()) {
     DenseTensorMeta x_grad_meta(
         out_grad.dtype(), x_grad->dims(), x_grad->layout());
diff --git a/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
index 33730a37177..1d73b582ea0 100644
--- a/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
index 93afa07ff01..1f27ed10392 100644
--- a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void ReduceMinGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
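The hunks above are the behavioral core of the patch: ReduceGradKernel and
several other kernels used to promote reduce_all only for an empty dims,
while the helper also promotes it when dims names every axis. A sketch
contrasting the removed check with the new one (simplified types; old_check
mirrors the deleted lines):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Old inline check used by several grad kernels (see the removed lines
    // above) vs. the new shared helper; plain types stand in for phi types.
    static bool old_check(const std::vector<int64_t>& dims, bool reduce_all) {
      if (dims.size() == 0) reduce_all = true;
      return reduce_all;
    }

    static bool new_check(int rank, const std::vector<int64_t>& dims,
                          bool reduce_all) {
      return dims.empty() || static_cast<int>(dims.size()) == rank || reduce_all;
    }

    int main() {
      // Identical for the empty-dims and explicit cases...
      assert(old_check({}, false) == new_check(2, {}, false));
      assert(old_check({0}, true) == new_check(2, {0}, true));
      // ...but a reduction listing every axis of a rank-2 tensor is now
      // normalized to reduce_all = true instead of being left to the
      // per-axis path.
      assert(!old_check({0, 1}, false));
      assert(new_check(2, {0, 1}, false));
      return 0;
    }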
diff --git a/paddle/phi/kernels/kps/prod_kernel.cu b/paddle/phi/kernels/kps/prod_kernel.cu
index 326a351f6da..79dc76f81c0 100644
--- a/paddle/phi/kernels/kps/prod_kernel.cu
+++ b/paddle/phi/kernels/kps/prod_kernel.cu
@@ -25,6 +25,7 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_all_kernel.cu b/paddle/phi/kernels/kps/reduce_all_kernel.cu
index 0459acd9822..d4d4596917b 100644
--- a/paddle/phi/kernels/kps/reduce_all_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_all_kernel.cu
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/kernels/reduce_all_kernel.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
-#include "paddle/phi/kernels/reduce_all_kernel.h"
 
 namespace phi {
 
@@ -25,6 +25,7 @@ void AllRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_amax_kernel.cu b/paddle/phi/kernels/kps/reduce_amax_kernel.cu
index 57197fd9d5b..f762a30638f 100644
--- a/paddle/phi/kernels/kps/reduce_amax_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_amax_kernel.cu
@@ -25,6 +25,7 @@ void AMaxRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_amin_kernel.cu b/paddle/phi/kernels/kps/reduce_amin_kernel.cu
index 230adcc8294..e5d15b337fa 100644
--- a/paddle/phi/kernels/kps/reduce_amin_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_amin_kernel.cu
@@ -25,6 +25,7 @@ void AMinRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_any_kernel.cu b/paddle/phi/kernels/kps/reduce_any_kernel.cu
index 480268936f4..3210f23c3b2 100644
--- a/paddle/phi/kernels/kps/reduce_any_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_any_kernel.cu
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/phi/kernels/reduce_any_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/reduce.h" -#include "paddle/phi/kernels/reduce_any_kernel.h" namespace phi { @@ -25,6 +25,7 @@ void AnyRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); diff --git a/paddle/phi/kernels/kps/reduce_max_kernel.cu b/paddle/phi/kernels/kps/reduce_max_kernel.cu index fb47b64f6ec..9c0fdb52c42 100644 --- a/paddle/phi/kernels/kps/reduce_max_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_max_kernel.cu @@ -25,6 +25,7 @@ void MaxRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); diff --git a/paddle/phi/kernels/kps/reduce_mean_kernel.cu b/paddle/phi/kernels/kps/reduce_mean_kernel.cu index 7f7946e0300..8fc63b2256d 100644 --- a/paddle/phi/kernels/kps/reduce_mean_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_mean_kernel.cu @@ -25,6 +25,7 @@ void MeanRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true); diff --git a/paddle/phi/kernels/kps/reduce_min_kernel.cu b/paddle/phi/kernels/kps/reduce_min_kernel.cu index 9c3e61d3c0b..450fee16b4c 100644 --- a/paddle/phi/kernels/kps/reduce_min_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_min_kernel.cu @@ -25,6 +25,7 @@ void MinRawKernel(const Context& dev_ctx, bool keep_dim, bool reduce_all, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); diff --git a/paddle/phi/kernels/kps/reduce_sum_kernel.cu b/paddle/phi/kernels/kps/reduce_sum_kernel.cu index c5a30a6a634..e6030db8aa3 100644 --- a/paddle/phi/kernels/kps/reduce_sum_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_sum_kernel.cu @@ -35,6 +35,7 @@ void ReduceSumEigen(const KPDevice& dev_ctx, DataType out_dtype, DenseTensor* out, std::vector* reduce_dims) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); // Resize Input Tensor auto new_x = x; int added_dims = EigenDimSize - x.dims().size(); @@ -79,6 +80,7 @@ void SumRawKernel(const Context& dev_ctx, bool reduce_all, DataType out_dtype, DenseTensor* out) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) { out_dtype = out->dtype(); } diff --git a/paddle/phi/kernels/onednn/reduce_kernel_impl.h b/paddle/phi/kernels/onednn/reduce_kernel_impl.h index 4665876469c..7a2f66ec984 100644 --- a/paddle/phi/kernels/onednn/reduce_kernel_impl.h +++ b/paddle/phi/kernels/onednn/reduce_kernel_impl.h @@ -46,6 +46,7 @@ void ReduceKernel(const Context& dev_ctx, bool reduce_all, DenseTensor* out, dnnl::algorithm reduction_type) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); const auto& onednn_engine = dev_ctx.GetEngine(); auto x_tz = vectorize(x.dims()); auto out_tz = @@ -116,6 +117,7 @@ void ReduceGradKernel(const Context& dev_ctx, dnnl::algorithm reduction_type, float scale_x, float scale_y) { + reduce_all = recompute_reduce_all(x, dims, reduce_all); 
   const auto& onednn_engine = dev_ctx.GetEngine();
   auto out_grad_tz = CalculateReducedDims(
       x_grad, &out_grad, dims.GetData(), reduce_all, keep_dim);
diff --git a/paddle/phi/kernels/onednn/reduce_max_kernel.cc b/paddle/phi/kernels/onednn/reduce_max_kernel.cc
index 9e3932d7f0b..3ece7636759 100644
--- a/paddle/phi/kernels/onednn/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_max_kernel.cc
@@ -24,6 +24,7 @@ void MaxRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceKernel(dev_ctx,
                x,
                dims,
diff --git a/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc
index 4395126821b..fd566782b18 100644
--- a/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc
@@ -25,6 +25,7 @@ void MeanGradKernel(const Context& dev_ctx,
                     bool keep_dim,
                     bool reduce_all,
                     DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto input_dims = phi::vectorize(x.dims());
   std::vector<int64_t> reduce_dims = dims.GetData();
   int number_of_elements = 1;
diff --git a/paddle/phi/kernels/onednn/reduce_mean_kernel.cc b/paddle/phi/kernels/onednn/reduce_mean_kernel.cc
index 22e6b3f87b1..a6d72c03e77 100644
--- a/paddle/phi/kernels/onednn/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_mean_kernel.cc
@@ -24,6 +24,7 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceKernel(dev_ctx,
                x,
                dims,
diff --git a/paddle/phi/kernels/onednn/reduce_min_kernel.cc b/paddle/phi/kernels/onednn/reduce_min_kernel.cc
index 177e588d38e..d5985efcbaa 100644
--- a/paddle/phi/kernels/onednn/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_min_kernel.cc
@@ -24,6 +24,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceKernel(dev_ctx,
                x,
                dims,
diff --git a/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc
index cd21d36cba7..10b914a2005 100644
--- a/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc
@@ -25,6 +25,7 @@ void SumGradKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(dev_ctx,
                    x,
                    out_grad,
diff --git a/paddle/phi/kernels/onednn/reduce_sum_kernel.cc b/paddle/phi/kernels/onednn/reduce_sum_kernel.cc
index e5b1d8b6fb4..81e77546b49 100644
--- a/paddle/phi/kernels/onednn/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_sum_kernel.cc
@@ -25,6 +25,7 @@ void SumRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DataType out_dtype,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceKernel(dev_ctx,
                x,
                dims,
diff --git a/paddle/phi/kernels/prod_kernel.cc b/paddle/phi/kernels/prod_kernel.cc
index 532b6fdaa14..1fce5167da9 100644
--- a/paddle/phi/kernels/prod_kernel.cc
+++ b/paddle/phi/kernels/prod_kernel.cc
@@ -25,7 +25,7 @@ void ProdKernel(const Context& dev_ctx,
                 const IntArray& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
+  bool reduce_all = false;  // recompute_reduce_all(x, dims);
   ProdRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
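From here on the patch rewrites the top-level wrappers: each one now computes
reduce_all with the two-argument form of the helper (reduce_all defaulting to
false) instead of an inline test, with ProdKernel as the lone holdout that
keeps a hard-coded false and only carries the helper call in a comment. The
shape of the pattern, sketched with simplified signatures (not the real phi
API):

    #include <cstdint>
    #include <vector>

    namespace sketch {

    // Stand-in for the helper; `rank` replaces x.dims().size().
    bool recompute_reduce_all(int rank, const std::vector<int64_t>& dims,
                              bool reduce_all = false) {
      return dims.empty() || static_cast<int>(dims.size()) == rank || reduce_all;
    }

    void MaxRawKernel(int rank, const std::vector<int64_t>& dims,
                      bool reduce_all) {
      // The Raw kernel re-normalizes too, so both entry points agree.
      reduce_all = recompute_reduce_all(rank, dims, reduce_all);
      // ... dispatch to the actual reduction ...
    }

    void MaxKernel(int rank, const std::vector<int64_t>& dims) {
      bool reduce_all = recompute_reduce_all(rank, dims);  // two-arg form
      MaxRawKernel(rank, dims, reduce_all);
    }

    }  // namespace sketch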
diff --git a/paddle/phi/kernels/reduce_all_kernel.cc b/paddle/phi/kernels/reduce_all_kernel.cc
index 5b8d2cbecca..e1651f12c1c 100644
--- a/paddle/phi/kernels/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/reduce_all_kernel.cc
@@ -25,10 +25,7 @@ void AllKernel(const Context& dev_ctx,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AllRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_amax_kernel.cc b/paddle/phi/kernels/reduce_amax_kernel.cc
index 47b5e97467f..87e432c5c20 100644
--- a/paddle/phi/kernels/reduce_amax_kernel.cc
+++ b/paddle/phi/kernels/reduce_amax_kernel.cc
@@ -25,10 +25,7 @@ void AMaxKernel(const Context& dev_ctx,
                 const std::vector<int64_t>& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AMaxRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_amin_kernel.cc b/paddle/phi/kernels/reduce_amin_kernel.cc
index 8da4f3afd9f..a355da64230 100644
--- a/paddle/phi/kernels/reduce_amin_kernel.cc
+++ b/paddle/phi/kernels/reduce_amin_kernel.cc
@@ -25,10 +25,7 @@ void AMinKernel(const Context& dev_ctx,
                 const std::vector<int64_t>& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AMinRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_any_kernel.cc b/paddle/phi/kernels/reduce_any_kernel.cc
index cc70e396806..2baa1edb094 100644
--- a/paddle/phi/kernels/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/reduce_any_kernel.cc
@@ -25,10 +25,7 @@ void AnyKernel(const Context& dev_ctx,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AnyRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_max_kernel.cc b/paddle/phi/kernels/reduce_max_kernel.cc
index 64079cb2aef..23da5bd4cd5 100644
--- a/paddle/phi/kernels/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/reduce_max_kernel.cc
@@ -25,10 +25,7 @@ void MaxKernel(const Context& dev_ctx,
                const IntArray& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   MaxRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_mean_kernel.cc b/paddle/phi/kernels/reduce_mean_kernel.cc
index aa615a6bb1e..83906fdfc08 100644
--- a/paddle/phi/kernels/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/reduce_mean_kernel.cc
@@ -25,10 +25,7 @@ void MeanKernel(const Context& dev_ctx,
                 const IntArray& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   MeanRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc
index 7a14d106c3d..660d3b753e9 100644
--- a/paddle/phi/kernels/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/reduce_min_kernel.cc
@@ -25,10 +25,7 @@ void MinKernel(const Context& dev_ctx,
                const IntArray& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   MinRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc
index 70c88c23585..c6cfe425663 100644
--- a/paddle/phi/kernels/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/reduce_sum_kernel.cc
@@ -26,10 +26,7 @@ void SumKernel(const Context& dev_ctx,
                DataType out_dtype,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   SumRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
 }
diff --git a/paddle/phi/kernels/xpu/prod_kernel.cc b/paddle/phi/kernels/xpu/prod_kernel.cc
index 7be48a8bab7..cf237afb227 100644
--- a/paddle/phi/kernels/xpu/prod_kernel.cc
+++ b/paddle/phi/kernels/xpu/prod_kernel.cc
@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce.h b/paddle/phi/kernels/xpu/reduce.h
index 81fe362a61a..49c9eb5ea68 100644
--- a/paddle/phi/kernels/xpu/reduce.h
+++ b/paddle/phi/kernels/xpu/reduce.h
@@ -33,6 +33,7 @@ int XPUReduce(const Context& dev_ctx,
                                T*,
                                const std::vector<int>&,
                                const std::vector<int>&)> func) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   dev_ctx.template Alloc<T>(out);
 
   const auto* x_data = x.data();
diff --git a/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc
index 1bfc5ae5f87..b1561233ea1 100644
--- a/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc
@@ -31,6 +31,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
   auto dims = dims_arr.GetData();
 
   dev_ctx.template Alloc<T>(x_grad);
diff --git a/paddle/phi/kernels/xpu/reduce_max_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_kernel.cc
index d0994f580cf..8db710a24ad 100644
--- a/paddle/phi/kernels/xpu/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_max_kernel.cc
@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
index 0c2fe9a9d9e..afe84e43d99 100644
--- a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
@@ -31,6 +31,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool reduce_all,
                           DenseTensor* x_grad) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
   dev_ctx.template Alloc<T>(x_grad);
   const XPUType* dy_data =
       reinterpret_cast<const XPUType*>(out_grad.data());
diff --git a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
index 4af1ba2da27..d29db35517f 100644
--- a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce_min_kernel.cc b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
index c54aca1830b..e330e30becd 100644
--- a/paddle/phi/kernels/xpu/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
index b6e4d1021e4..0ba67f68bcc 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
@@ -28,13 +28,11 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool reduce_all,
                          DenseTensor* x_grad) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
   auto dims = dims_arr.GetData();
 
   dev_ctx.template Alloc<T>(x_grad);
   const auto* out_data = out_grad.data();
   auto* x_grad_data = x_grad->data();
-  if (dims_arr.size() == 0) {
-    reduce_all = true;
-  }
   const auto& input_dim_size = x.dims().size();
   std::vector<int> true_dims;
   for (size_t i = 0; i < dims.size(); ++i) {
diff --git a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
index 74c50304b14..952ed101cdc 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
@@ -29,6 +29,7 @@ void SumRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DataType out_dtype,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
--
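One detail worth noting from the XPU sum-grad hunk: the kernel builds a
true_dims vector from possibly negative reduction axes before using them. The
loop body is cut off in the hunk, so the following is an assumption-labeled
reconstruction of the usual Python-style wrap-around, not the upstream code:

    #include <cstdint>
    #include <vector>

    // Hedged sketch: map axes from [-rank, rank) into [0, rank), as the
    // true_dims loop in the XPU sum-grad kernel above appears to do.
    static std::vector<int> true_dims_of(const std::vector<int64_t>& dims,
                                         int rank) {
      std::vector<int> true_dims;
      true_dims.reserve(dims.size());
      for (size_t i = 0; i < dims.size(); ++i) {
        int64_t d = dims[i];
        // Negative axes count from the end, e.g. -1 is the last axis.
        true_dims.push_back(static_cast<int>(d < 0 ? d + rank : d));
      }
      return true_dims;
    }

    int main() {
      auto t = true_dims_of({-1, 0}, 3);  // -> {2, 0}
      return (t[0] == 2 && t[1] == 0) ? 0 : 1;
    }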