diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h
index 55ea3a31eb3181dc2a12175badd28441339e02ed..05d8e259cff10c383b417c1f73bc7e8bc27150e8 100644
--- a/paddle/phi/core/kernel_utils.h
+++ b/paddle/phi/core/kernel_utils.h
@@ -336,4 +336,15 @@ struct KernelImpl {
   };
 };
 
+inline bool recompute_reduce_all(const DenseTensor& x,
+                                 const IntArray& dims,
+                                 bool reduce_all = false) {
+  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size() ||
+      reduce_all) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/prod_kernel.cc b/paddle/phi/kernels/cpu/prod_kernel.cc
index af5ea5cb9568d2b0daebdaac5aa090f8e9ff666b..d5a07c0057dd710fd1ab2ccbc7d7923c007b0b31 100644
--- a/paddle/phi/kernels/cpu/prod_kernel.cc
+++ b/paddle/phi/kernels/cpu/prod_kernel.cc
@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce.h b/paddle/phi/kernels/cpu/reduce.h
index e5f610b9554097676afe26f94207b7fcd8cb06a2..bfcbe0eee1f60728c01663dd8121fe4f60ad0a09 100644
--- a/paddle/phi/kernels/cpu/reduce.h
+++ b/paddle/phi/kernels/cpu/reduce.h
@@ -30,6 +30,7 @@ void Reduce(const DeviceContext& dev_ctx,
             bool keep_dim,
             DataType out_dtype,
             DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // If the dims has full dim, set the reduce_all is True
   const int& input_dim_size = x.dims().size();
   std::set<int> dims_set(dims.begin(), dims.end());
@@ -71,6 +72,7 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       phi::DenseTensor* output) {
+  reduce_all = recompute_reduce_all(input, dims, reduce_all);
   dev_ctx.template Alloc(output);
 
   // The dims has full dim, set the reduce_all is True
diff --git a/paddle/phi/kernels/cpu/reduce_all_kernel.cc b/paddle/phi/kernels/cpu/reduce_all_kernel.cc
index 3e8e38ee4447e67359e694700504c1041d0a15e7..60094d1345a77a9cd39833a50305b7b52292540e 100644
--- a/paddle/phi/kernels/cpu/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_all_kernel.cc
@@ -28,6 +28,7 @@ void AllRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   phi::BoolReduceKernel(
       dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
index ffe9133d6d94c9cc284910038666f2bb1d37fb6c..731ee3463658072d0bfb05245cf4183f8a78df58 100644
--- a/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amax_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amax_kernel.cc b/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
index ac3b5ce762e293998788610df6df7ee658d4b4a7..72ac780e400715045206b542f971f9435737a642 100644
--- a/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amax_kernel.cc
@@ -28,6 +28,7 @@ void AMaxRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
index 6bb0e5061cc20a6a30622184436099730c2fb34a..1165e4c7545abef867894127531fdc061be06b19 100644
--- a/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amin_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_amin_kernel.cc b/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
index d8f090f93ffd3a8363e493f76b514107c6504a13..47aa5210f3297c86656efe8f65aa06facd95f7dd 100644
--- a/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_amin_kernel.cc
@@ -28,6 +28,7 @@ void AMinRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_any_kernel.cc b/paddle/phi/kernels/cpu/reduce_any_kernel.cc
index 4fd71f1d0b169866376664bdf2b0b89b13c120e1..553393e7dba35a825ab808562fd1b14fb83bad70 100644
--- a/paddle/phi/kernels/cpu/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_any_kernel.cc
@@ -28,6 +28,7 @@ void AnyRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   phi::BoolReduceKernel(
       dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_max_kernel.cc b/paddle/phi/kernels/cpu/reduce_max_kernel.cc
index b15a555a2cf4d21d4d98679463407dd28a2807a3..d71476a0f920d2424f06d024c085db1e08d333ae 100644
--- a/paddle/phi/kernels/cpu/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_max_kernel.cc
@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
index 3ab8a40a85e55679d9ab681ce3e29248a0b63e1e..b19f6ebdad806664a44175ed74292df71b36759e 100644
--- a/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc
@@ -28,6 +28,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(dev_ctx,
                    x,
                    paddle::none,
diff --git a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
index 7164ec8b2bf9914e151f036942938214d2cc2b8d..2ab1b3e5a47392f26e3488025b2425fa14bc5fcd 100644
--- a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc
@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_min_kernel.cc b/paddle/phi/kernels/cpu/reduce_min_kernel.cc
index a11de5ea81ab65380a81f095745cda56e7f1d1f9..286951f6720d7bec39ecf689020fb8e1a46139a1 100644
--- a/paddle/phi/kernels/cpu/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_min_kernel.cc
@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
index 87e3df717b2442f5c41aec6ccd0f05438bc9d02a..e7d73611cf041f88880da9a2f01426bf331c7321 100644
--- a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
@@ -77,6 +77,7 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   if (dims.size() == 1) {
     if (out_grad.dtype() != x.dtype()) {
       DenseTensorMeta x_grad_meta(
diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 1b1a55b25c5eced5f3146ad9f7277e98a76cbf17..be64e3c7db7ddcb99d83f16b6f857e1a8a7072c8 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -58,6 +58,7 @@ using dim3 = phi::kps::dim3;
 
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_utils.h"
 #include "paddle/phi/core/utils/array.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
diff --git a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
index a439711f5d04505c5dee34416e765779be9e1a99..9878aa6ee231d38cfa270df69b8952156fd9f471 100644
--- a/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
@@ -26,6 +26,7 @@ void FrobeniusNormKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h
index bb914defbe8925b87abf34211751e27a84bfd319..0d6edd13ac9d6165e5df40b5da5a3fae9de09ee7 100644
--- a/paddle/phi/kernels/gpu/reduce.h
+++ b/paddle/phi/kernels/gpu/reduce.h
@@ -36,6 +36,7 @@ void Reduce(const KPDevice& dev_ctx,
             DataType out_dtype,
             DenseTensor* out,
             bool is_mean = false) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   std::vector<int> reduce_dims =
       phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
 
diff --git a/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
index a75ef42889da2ea90994fe5f41781c9207e26e6a..db6cb2274cdc68132cb398d04a8204dcde3fbaa1 100644
--- a/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_amax_grad_kernel.cu
@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceCudaAMaxAMinGrad(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
index 5d90433ad22e37f5eb3e3262cfb274a60958d632..ed6e0ef51558ac75ab95e6640d58604e018948c0 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
+++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
@@ -32,15 +32,13 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
                             bool keep_dim,
                             bool reduce_all,
                             DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* in_x = &x;
   auto* out_y = &out;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = in_x->dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all);
   auto update_dims = vectorize(d_x->dims());
   int reduce_num = 1;
diff --git a/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
index 152ef494b4c13090f09dafc42d1da3f16229e541..58598cae56a35e3ca9c1f12e3ba3316afb5bac51 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_amin_grad_kernel.cu
@@ -29,6 +29,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceCudaAMaxAMinGrad(
       dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h
index ed6cc0c3c2022c93f22b9f67ef3509456e156873..01f91924645fae9b8a1a4082ff3383457fa97ade 100644
--- a/paddle/phi/kernels/gpu/reduce_grad.h
+++ b/paddle/phi/kernels/gpu/reduce_grad.h
@@ -52,6 +52,7 @@ void ReduceGradKernel(const Context& dev_ctx,
                       bool reduce_all,
                       DenseTensor* x_grad,
                       Functor functor) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* in_x = &x;
   auto* d_out = &out_grad;
   auto* d_x = x_grad;
diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
index 40c317e1262c5b44f0e74ce53e38ae4faa3abd96..d7b3adfcd6f48d791aea6aaa864082e48fe493ec 100644
--- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
@@ -29,11 +29,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool keep_dim,
                           bool reduce_all,
                           DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // get reduce_dim and reduce_num for reduce_mean_grad
   int dim_size = x.dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
 
diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
index 74209afe374673b90610361e4c361aeab0a7c760..04b3253178902f85462362a39f9485a6d0eadf11 100644
--- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -29,11 +29,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // get reduce_dim for reduce_mean_grad
   int dim_size = x.dims().size();
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
   std::vector<int> reduce_dims =
       funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
 
diff --git a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
index 96cf08af9634fde2e5f3ab008105a158553bb488..385ea68e6e70751f7b57b7871cd2e8cb6fd69489 100644
--- a/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/frobenius_norm_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
                              bool keep_dim,
                              bool reduce_all,
                              DenseTensor* dx) {
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
   ReduceGradKernel(
       ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
 }
diff --git a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
index d1de47e128e57db956b559ba064256fb19157902..7dbc3ab3af7ba68e464605041979d62c56334125 100644
--- a/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
+++ b/paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h
@@ -27,6 +27,7 @@ void FrobeniusNormKernel(const Context& ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
   Reduce(
       ctx, x, reduce_all, axis, keep_dim, x.dtype(), out);
 }
diff --git a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
index 098503f82cd208aee3d9d86732c1cd1ab51e1c41..0db6c12d4a07c20be85762e4ec45fe34d4c67b5a 100644
--- a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
@@ -60,9 +60,7 @@ void LogsumexpGradKernel(const Context& dev_ctx,
                          DenseTensor* in_grad) {
   dev_ctx.template Alloc(in_grad);
 
-  if (axis.size() == 0 || static_cast<int>(axis.size()) == in.dims().size()) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(in, axis, reduce_all);
 
   if (reduce_all) {
     auto x = phi::EigenVector<T>::Flatten(in);
diff --git a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
index 0d16dc7baf62157c3e842e9f6ffddfb55b112cfa..cc5057396265c4f0cf3a1e9edf18b6b67285d4cc 100644
--- a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
+++ b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
@@ -69,9 +69,7 @@ void LogsumexpKernel(const Context& dev_ctx,
                      DenseTensor* out) {
   dev_ctx.template Alloc(out);
 
-  if (axis.size() == 0 || static_cast<int>(axis.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(x, axis, reduce_all);
 
   if (reduce_all) {
     // Flatten and reduce 1-D tensor
diff --git a/paddle/phi/kernels/impl/prod_grad_kernel_impl.h b/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
index 13f517c072c15f2d3b0d87e3cbfc24c4f8f45ea2..208e3362de48a894a2337992b4a50ee3f1b110c7 100644
--- a/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/prod_grad_kernel_impl.h
@@ -30,6 +30,7 @@ void ProdGradKernel(const Context& dev_ctx,
                     bool keep_dim,
                     bool reduce_all,
                     DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/impl/reduce_grad.h b/paddle/phi/kernels/impl/reduce_grad.h
index 40b62cc83fa73e96d01f83061587e104bb206e52..e9d1aec0f09c5e3c6420bb4f50e5947308ed9012 100644
--- a/paddle/phi/kernels/impl/reduce_grad.h
+++ b/paddle/phi/kernels/impl/reduce_grad.h
@@ -34,6 +34,7 @@ void ComputeFromInput(const Context& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto* input0 = &x;
   auto* input1 = out.get_ptr();
   auto* output = x_grad;
@@ -91,9 +92,8 @@ void ReduceGradKernel(const Context& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       DenseTensor* x_grad) {
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
+
   if (x.dtype() != out_grad.dtype()) {
     DenseTensorMeta x_grad_meta(
         out_grad.dtype(), x_grad->dims(), x_grad->layout());
diff --git a/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
index 33730a3717781ea4ac988a26cd29dc5614e45aaf..1d73b582ea0f5018268b756d2c8a39260865521c 100644
--- a/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
index 93afa07ff01af343c984b618f0bfece72928622f..1f27ed10392cc5f9cb1044ea90ce08bb7c075835 100644
--- a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h
@@ -29,6 +29,7 @@ void ReduceMinGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(
       dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
 }
diff --git a/paddle/phi/kernels/kps/prod_kernel.cu b/paddle/phi/kernels/kps/prod_kernel.cu
index 326a351f6dabb20fc492cf2f382298fd655c2f99..79dc76f81c032315e42165ab005f61e6f8c21656 100644
--- a/paddle/phi/kernels/kps/prod_kernel.cu
+++ b/paddle/phi/kernels/kps/prod_kernel.cu
@@ -25,6 +25,7 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_all_kernel.cu b/paddle/phi/kernels/kps/reduce_all_kernel.cu
index 0459acd9822694a5a2bfe8d649e21ca9139d1b8c..d4d4596917bf8563d76a8d6195bce0ed18ed1735 100644
--- a/paddle/phi/kernels/kps/reduce_all_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_all_kernel.cu
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/kernels/reduce_all_kernel.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
-#include "paddle/phi/kernels/reduce_all_kernel.h"
 
 namespace phi {
 
@@ -25,6 +25,7 @@ void AllRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_amax_kernel.cu b/paddle/phi/kernels/kps/reduce_amax_kernel.cu
index 57197fd9d5b8a24f87d8a41e7a18c4f8f3637656..f762a30638f05b39487c14bcc0fc7f6394355201 100644
--- a/paddle/phi/kernels/kps/reduce_amax_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_amax_kernel.cu
@@ -25,6 +25,7 @@ void AMaxRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_amin_kernel.cu b/paddle/phi/kernels/kps/reduce_amin_kernel.cu
index 230adcc829441824b5b9341da83504d884120b34..e5d15b337fa0408e4a191f7e3637b006d45d01d1 100644
--- a/paddle/phi/kernels/kps/reduce_amin_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_amin_kernel.cu
@@ -25,6 +25,7 @@ void AMinRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_any_kernel.cu b/paddle/phi/kernels/kps/reduce_any_kernel.cu
index 480268936f49f17d218640ff49c115df3158c37f..3210f23c3b205951cdf2116a4229f253f3d0b801 100644
--- a/paddle/phi/kernels/kps/reduce_any_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_any_kernel.cu
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/kernels/reduce_any_kernel.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
-#include "paddle/phi/kernels/reduce_any_kernel.h"
 
 namespace phi {
 
@@ -25,6 +25,7 @@ void AnyRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_max_kernel.cu b/paddle/phi/kernels/kps/reduce_max_kernel.cu
index fb47b64f6ecec6e49592068b889aa0a3f7c3b824..9c0fdb52c4279026fdc4f25584c7aa9d1e437ea3 100644
--- a/paddle/phi/kernels/kps/reduce_max_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_max_kernel.cu
@@ -25,6 +25,7 @@ void MaxRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_mean_kernel.cu b/paddle/phi/kernels/kps/reduce_mean_kernel.cu
index 7f7946e0300631a8981aae15ba3dd1386552bba4..8fc63b2256db969e2ecd845f30e552652a0ad3a5 100644
--- a/paddle/phi/kernels/kps/reduce_mean_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_mean_kernel.cu
@@ -25,6 +25,7 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true);
diff --git a/paddle/phi/kernels/kps/reduce_min_kernel.cu b/paddle/phi/kernels/kps/reduce_min_kernel.cu
index 9c3e61d3c0bc531b37b7467dc46ed576700a59d2..450fee16b4ca997babac02d7971a0e4982b565fb 100644
--- a/paddle/phi/kernels/kps/reduce_min_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_min_kernel.cu
@@ -25,6 +25,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
   phi::Reduce(
       dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
diff --git a/paddle/phi/kernels/kps/reduce_sum_kernel.cu b/paddle/phi/kernels/kps/reduce_sum_kernel.cu
index c5a30a6a634a8ea384c0dcb847b8cb77f16f197a..e6030db8aa325ee4a160a0810d0c0a083cb127ca 100644
--- a/paddle/phi/kernels/kps/reduce_sum_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_sum_kernel.cu
@@ -35,6 +35,7 @@ void ReduceSumEigen(const KPDevice& dev_ctx,
                     DataType out_dtype,
                     DenseTensor* out,
                     std::vector* reduce_dims) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   // Resize Input Tensor
   auto new_x = x;
   int added_dims = EigenDimSize - x.dims().size();
@@ -79,6 +80,7 @@ void SumRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DataType out_dtype,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
     out_dtype = out->dtype();
   }
diff --git a/paddle/phi/kernels/onednn/reduce_kernel_impl.h b/paddle/phi/kernels/onednn/reduce_kernel_impl.h
index 4665876469cd591f76eb3e4efe23bb770fdeed4e..7a2f66ec984e5821462e25226b54e838eda4d9f2 100644
--- a/paddle/phi/kernels/onednn/reduce_kernel_impl.h
+++ b/paddle/phi/kernels/onednn/reduce_kernel_impl.h
@@ -46,6 +46,7 @@ void ReduceKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DenseTensor* out,
                   dnnl::algorithm reduction_type) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   const auto& onednn_engine = dev_ctx.GetEngine();
   auto x_tz = vectorize(x.dims());
   auto out_tz =
@@ -116,6 +117,7 @@ void ReduceGradKernel(const Context& dev_ctx,
                       dnnl::algorithm reduction_type,
                       float scale_x,
                       float scale_y) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   const auto& onednn_engine = dev_ctx.GetEngine();
   auto out_grad_tz = CalculateReducedDims(
       x_grad, &out_grad, dims.GetData(), reduce_all, keep_dim);
diff --git a/paddle/phi/kernels/onednn/reduce_max_kernel.cc b/paddle/phi/kernels/onednn/reduce_max_kernel.cc
index 9e3932d7f0b5c0d79cb59cf89321b2ce4de19ae7..3ece76367598a68a01e47fad0a0df725972c9b31 100644
--- a/paddle/phi/kernels/onednn/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_max_kernel.cc
@@ -24,6 +24,7 @@ void MaxRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceKernel(dev_ctx,
                x,
                dims,
diff --git a/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc
index 4395126821bad9244fc9dc057e21548ab7fbc5e2..fd566782b182e79b53e1a568c55d2154861d0574 100644
--- a/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_mean_grad_kernel.cc
@@ -25,6 +25,7 @@ void MeanGradKernel(const Context& dev_ctx,
                     bool keep_dim,
                     bool reduce_all,
                     DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto input_dims = phi::vectorize(x.dims());
   std::vector<int64_t> reduce_dims = dims.GetData();
   int number_of_elements = 1;
diff --git a/paddle/phi/kernels/onednn/reduce_mean_kernel.cc b/paddle/phi/kernels/onednn/reduce_mean_kernel.cc
index 22e6b3f87b1f5c7bb5c31a474b6a9883208bf3bc..a6d72c03e776735e2a9a1425e2f9239a1e4ae904 100644
--- a/paddle/phi/kernels/onednn/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_mean_kernel.cc
@@ -24,6 +24,7 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceKernel(dev_ctx,
                x,
                dims,
diff --git a/paddle/phi/kernels/onednn/reduce_min_kernel.cc b/paddle/phi/kernels/onednn/reduce_min_kernel.cc
index 177e588d38ef6d3a01c096e55e75bc87e9c4d913..d5985efcbaac3c00bdeedc01d13cfbd28a5772cb 100644
--- a/paddle/phi/kernels/onednn/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_min_kernel.cc
@@ -24,6 +24,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceKernel(dev_ctx,
                x,
                dims,
diff --git a/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc
index cd21d36cba7b1880abe769de36c270abc814300e..10b914a2005cd836a540fb06dcda5bc0edd9addb 100644
--- a/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_sum_grad_kernel.cc
@@ -25,6 +25,7 @@ void SumGradKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceGradKernel(dev_ctx,
                    x,
                    out_grad,
diff --git a/paddle/phi/kernels/onednn/reduce_sum_kernel.cc b/paddle/phi/kernels/onednn/reduce_sum_kernel.cc
index e5b1d8b6fb43299acd45ae0260750215e96d06f3..81e77546b490a2ff708f588121b131958bf67fcd 100644
--- a/paddle/phi/kernels/onednn/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/onednn/reduce_sum_kernel.cc
@@ -25,6 +25,7 @@ void SumRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DataType out_dtype,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   ReduceKernel(dev_ctx,
                x,
                dims,
diff --git a/paddle/phi/kernels/prod_kernel.cc b/paddle/phi/kernels/prod_kernel.cc
index 532b6fdaa141fbe3b2109ac32a8f9befeb3ba822..1fce5167da95830b8b969670232d226310d7ef3b 100644
--- a/paddle/phi/kernels/prod_kernel.cc
+++ b/paddle/phi/kernels/prod_kernel.cc
@@ -25,7 +25,7 @@ void ProdKernel(const Context& dev_ctx,
                 const IntArray& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
+  bool reduce_all = false;  // recompute_reduce_all(x, dims);
   ProdRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_all_kernel.cc b/paddle/phi/kernels/reduce_all_kernel.cc
index 5b8d2cbecca5f5c3b85de2f9bac96bc51a0e319b..e1651f12c1c847b51b17844484d92f6b939b4f71 100644
--- a/paddle/phi/kernels/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/reduce_all_kernel.cc
@@ -25,10 +25,7 @@ void AllKernel(const Context& dev_ctx,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AllRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_amax_kernel.cc b/paddle/phi/kernels/reduce_amax_kernel.cc
index 47b5e97467fe747474630c2a80f583a89090d01d..87e432c5c20a7b0da0466042b9bc695f24f605fd 100644
--- a/paddle/phi/kernels/reduce_amax_kernel.cc
+++ b/paddle/phi/kernels/reduce_amax_kernel.cc
@@ -25,10 +25,7 @@ void AMaxKernel(const Context& dev_ctx,
                 const std::vector<int64_t>& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AMaxRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_amin_kernel.cc b/paddle/phi/kernels/reduce_amin_kernel.cc
index 8da4f3afd9f43407798fa2dec523fa7c64586593..a355da64230dcb380471a5b0ed8edae1f80acbdd 100644
--- a/paddle/phi/kernels/reduce_amin_kernel.cc
+++ b/paddle/phi/kernels/reduce_amin_kernel.cc
@@ -25,10 +25,7 @@ void AMinKernel(const Context& dev_ctx,
                 const std::vector<int64_t>& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AMinRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_any_kernel.cc b/paddle/phi/kernels/reduce_any_kernel.cc
index cc70e3968067c81d8abbaf571e6e21ac5a6b0976..2baa1edb094b9380fa77f5b1d7839ba1b52b411c 100644
--- a/paddle/phi/kernels/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/reduce_any_kernel.cc
@@ -25,10 +25,7 @@ void AnyKernel(const Context& dev_ctx,
                const std::vector<int64_t>& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   AnyRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_max_kernel.cc b/paddle/phi/kernels/reduce_max_kernel.cc
index 64079cb2aefad841b0e583147d475dc1b37da123..23da5bd4cd54edf0b53bb825a2a0ebbe7b03d9e3 100644
--- a/paddle/phi/kernels/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/reduce_max_kernel.cc
@@ -25,10 +25,7 @@ void MaxKernel(const Context& dev_ctx,
                const IntArray& dims,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   MaxRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_mean_kernel.cc b/paddle/phi/kernels/reduce_mean_kernel.cc
index aa615a6bb1ef1ccaecf016155c0bd1a6c07a4feb..83906fdfc08536a1553a6623998d5e40a7921909 100644
--- a/paddle/phi/kernels/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/reduce_mean_kernel.cc
@@ -25,10 +25,7 @@ void MeanKernel(const Context& dev_ctx,
                 const IntArray& dims,
                 bool keep_dim,
                 DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   MeanRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc
index 7a14d106c3d74a3ecb22d4e62157a44cfb7334ec..660d3b753e97ea449c0ba8d8d2e07d4829c1f5cb 100644
--- a/paddle/phi/kernels/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/reduce_min_kernel.cc
@@ -25,10 +25,7 @@ void MinKernel(const Context& dev_ctx,
               const IntArray& dims,
               bool keep_dim,
               DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   MinRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc
index 70c88c23585b3d84dca976efbe5d9c33a403027e..c6cfe42566372fe59303b30c09b84a15ab18cd39 100644
--- a/paddle/phi/kernels/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/reduce_sum_kernel.cc
@@ -26,10 +26,7 @@ void SumKernel(const Context& dev_ctx,
                DataType out_dtype,
                bool keep_dim,
                DenseTensor* out) {
-  bool reduce_all = false;
-  if (dims.size() == 0) {
-    reduce_all = true;
-  }
+  bool reduce_all = recompute_reduce_all(x, dims);
   SumRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
 }
 
diff --git a/paddle/phi/kernels/xpu/prod_kernel.cc b/paddle/phi/kernels/xpu/prod_kernel.cc
index 7be48a8bab774fc4fe3815a39cf7f19790500cb1..cf237afb227975b7f7485262e739be9fbfac8d0c 100644
--- a/paddle/phi/kernels/xpu/prod_kernel.cc
+++ b/paddle/phi/kernels/xpu/prod_kernel.cc
@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce.h b/paddle/phi/kernels/xpu/reduce.h
index 81fe362a61a0691e26cbed6315698beefb2b4761..49c9eb5ea684fbd100fc05f84b5129b70c550d2c 100644
--- a/paddle/phi/kernels/xpu/reduce.h
+++ b/paddle/phi/kernels/xpu/reduce.h
@@ -33,6 +33,7 @@ int XPUReduce(const Context& dev_ctx,
                                 T*,
                                 const std::vector&,
                                 const std::vector&)> func) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   dev_ctx.template Alloc(out);
 
   const auto* x_data = x.data();
diff --git a/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc
index 1bfc5ae5f877ec4b2ece64d8b097186bc29f0a46..b1561233ea1d49462cc87751b8c0b06e7816934f 100644
--- a/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc
@@ -31,6 +31,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
                          bool keep_dim,
                          bool reduce_all,
                          DenseTensor* x_grad) {
+  reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
   auto dims = dims_arr.GetData();
 
   dev_ctx.template Alloc(x_grad);
diff --git a/paddle/phi/kernels/xpu/reduce_max_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_kernel.cc
index d0994f580cfbf32b987450549d34a5bf90251ef5..8db710a24adce8b8055b4bb4fe2476e13b5fcb51 100644
--- a/paddle/phi/kernels/xpu/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_max_kernel.cc
@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
index 0c2fe9a9d9e644d01f99709a0c1239067ac0115b..afe84e43d99d140e618227863ac3532f480607b9 100644
--- a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
@@ -31,6 +31,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
                           bool reduce_all,
                           DenseTensor* x_grad) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
   dev_ctx.template Alloc(x_grad);
   const XPUType* dy_data =
       reinterpret_cast(out_grad.data());
diff --git a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
index 4af1ba2da2756f932a8e6bac59e19fb85d3bfc45..d29db35517f3725c879a6e04581744a49b7d79ec 100644
--- a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
                    bool keep_dim,
                    bool reduce_all,
                    DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce_min_kernel.cc b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
index c54aca1830b0a6e889c308b3fe1fe6419da13acb..e330e30becdcfee17c993ba2ee4051c029e0bad8 100644
--- a/paddle/phi/kernels/xpu/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
                   bool keep_dim,
                   bool reduce_all,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
diff --git a/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
index b6e4d1021e47dfda568701d19aa02de4a2e11e9a..0ba67f68bccf3c2d37d280b8266189a35e220563 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
@@ -28,13 +28,11 @@ void ReduceSumGradKernel(const Context& dev_ctx,
                          bool reduce_all,
                          DenseTensor* x_grad) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
   auto dims = dims_arr.GetData();
   dev_ctx.template Alloc(x_grad);
   const auto* out_data = out_grad.data();
   auto* x_grad_data = x_grad->data();
-  if (dims_arr.size() == 0) {
-    reduce_all = true;
-  }
   const auto& input_dim_size = x.dims().size();
   std::vector true_dims;
   for (size_t i = 0; i < dims.size(); ++i) {
diff --git a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
index 74c50304b1407b22eeb25b828bd0af03303148e3..952ed101cdcb8eb9616072a5ed2fc22d91cfc31d 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
@@ -29,6 +29,7 @@ void SumRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DataType out_dtype,
                   DenseTensor* out) {
+  reduce_all = recompute_reduce_all(x, dims, reduce_all);
   int r = XPUReduce(dev_ctx,
                     x,
                     dims.GetData(),
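
Reviewer note: the patch replaces the `if (dims.size() == 0) { reduce_all = true; }` blocks scattered through these kernels with a single helper; it also covers a case the ad-hoc checks in reduce_amax/reduce_amin/reduce_mean/reduce_sum missed, namely `dims` explicitly listing every axis. Below is a minimal standalone sketch of the rule, in plain C++; `recompute_reduce_all_sketch` and `rank` are illustrative stand-ins, not Paddle's `DenseTensor`/`IntArray` types.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the decision rule centralized in phi::recompute_reduce_all:
// reduce over every axis when no axes are given, when all axes are
// listed, or when the caller already requested a full reduction.
bool recompute_reduce_all_sketch(int rank,
                                 const std::vector<int64_t>& dims,
                                 bool reduce_all = false) {
  return dims.empty() || static_cast<int>(dims.size()) == rank || reduce_all;
}

int main() {
  assert(recompute_reduce_all_sketch(3, {}));         // no axes -> reduce all
  assert(recompute_reduce_all_sketch(2, {0, 1}));     // every axis listed
  assert(recompute_reduce_all_sketch(3, {1}, true));  // caller forced it
  assert(!recompute_reduce_all_sketch(3, {1}));       // partial reduction
  return 0;
}
```

Because each `*RawKernel` now recomputes the flag on entry, a caller that passes a stale `reduce_all` still takes the full-reduction path whenever `dims` covers the whole tensor.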