未验证 提交 56f15c43 编写于 作者: W wanghuancoder 提交者: GitHub

refine reduce_all (#48133)

* refine reduce_all
上级 208f625b
...@@ -336,4 +336,15 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -336,4 +336,15 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
}; };
}; };
// Decide whether a reduce kernel should reduce over ALL elements of `x`.
//
// Returns true when any of the following holds:
//   * the caller already requested a full reduction (`reduce_all` is true),
//   * `dims` is empty (the convention for "reduce every axis"), or
//   * `dims` names as many axes as `x` has, i.e. every axis is reduced.
//
// @param x          Input tensor whose rank is compared against `dims`.
// @param dims       Axes requested for reduction (may be empty).
// @param reduce_all Caller-supplied full-reduction flag (default false).
// @return true if the reduction covers the whole tensor.
inline bool recompute_reduce_all(const DenseTensor& x,
                                 const IntArray& dims,
                                 bool reduce_all = false) {
  // `return cond;` instead of `if (cond) return true; else return false;`.
  // Check the cheap boolean first so the dims/rank comparison is skipped
  // when the caller already asked for a full reduction.
  return reduce_all || dims.size() == 0 ||
         static_cast<int>(dims.size()) == x.dims().size();
}
} // namespace phi } // namespace phi
...@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::ProdFunctor>( phi::Reduce<CPUContext, T, phi::funcs::ProdFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
...@@ -30,6 +30,7 @@ void Reduce(const DeviceContext& dev_ctx, ...@@ -30,6 +30,7 @@ void Reduce(const DeviceContext& dev_ctx,
bool keep_dim, bool keep_dim,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// If the dims has full dim, set the reduce_all is True // If the dims has full dim, set the reduce_all is True
const int& input_dim_size = x.dims().size(); const int& input_dim_size = x.dims().size();
std::set<int> dims_set(dims.begin(), dims.end()); std::set<int> dims_set(dims.begin(), dims.end());
...@@ -71,6 +72,7 @@ void BoolReduceKernel(const DeviceContext& dev_ctx, ...@@ -71,6 +72,7 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
phi::DenseTensor* output) { phi::DenseTensor* output) {
reduce_all = recompute_reduce_all(input, dims, reduce_all);
dev_ctx.template Alloc<OutT>(output); dev_ctx.template Alloc<OutT>(output);
// The dims has full dim, set the reduce_all is True // The dims has full dim, set the reduce_all is True
......
...@@ -28,6 +28,7 @@ void AllRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void AllRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
phi::BoolReduceKernel<CPUContext, T, phi::funcs::AllFunctor>( phi::BoolReduceKernel<CPUContext, T, phi::funcs::AllFunctor>(
dev_ctx, x, dims, keep_dim, reduce_all, out); dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>( ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
} }
......
...@@ -28,6 +28,7 @@ void AMaxRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void AMaxRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>( phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
...@@ -28,6 +28,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>( ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
} }
......
...@@ -28,6 +28,7 @@ void AMinRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void AMinRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>( phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
...@@ -28,6 +28,7 @@ void AnyRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void AnyRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
phi::BoolReduceKernel<CPUContext, T, phi::funcs::AnyFunctor>( phi::BoolReduceKernel<CPUContext, T, phi::funcs::AnyFunctor>(
dev_ctx, x, dims, keep_dim, reduce_all, out); dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>( phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
...@@ -28,6 +28,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(dev_ctx, ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(dev_ctx,
x, x,
paddle::none, paddle::none,
......
...@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MeanFunctor>( phi::Reduce<CPUContext, T, phi::funcs::MeanFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
...@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>( phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
...@@ -77,6 +77,7 @@ void ReduceSumGradKernel(const Context& dev_ctx, ...@@ -77,6 +77,7 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
if (dims.size() == 1) { if (dims.size() == 1) {
if (out_grad.dtype() != x.dtype()) { if (out_grad.dtype() != x.dtype()) {
DenseTensorMeta x_grad_meta( DenseTensorMeta x_grad_meta(
......
...@@ -58,6 +58,7 @@ using dim3 = phi::kps::dim3; ...@@ -58,6 +58,7 @@ using dim3 = phi::kps::dim3;
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_utils.h"
#include "paddle/phi/core/utils/array.h" #include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
......
...@@ -26,6 +26,7 @@ void FrobeniusNormKernel(const Context& dev_ctx, ...@@ -26,6 +26,7 @@ void FrobeniusNormKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>( phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
...@@ -36,6 +36,7 @@ void Reduce(const KPDevice& dev_ctx, ...@@ -36,6 +36,7 @@ void Reduce(const KPDevice& dev_ctx,
DataType out_dtype, DataType out_dtype,
DenseTensor* out, DenseTensor* out,
bool is_mean = false) { bool is_mean = false) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
std::vector<int> reduce_dims = std::vector<int> reduce_dims =
phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all); phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
......
...@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceCudaAMaxAMinGrad<T, Context>( ReduceCudaAMaxAMinGrad<T, Context>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
} }
......
...@@ -32,15 +32,13 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx, ...@@ -32,15 +32,13 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto* in_x = &x; auto* in_x = &x;
auto* out_y = &out; auto* out_y = &out;
auto* d_out = &out_grad; auto* d_out = &out_grad;
auto* d_x = x_grad; auto* d_x = x_grad;
// get reduce_dim and reduce_num for reduce_mean_grad // get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size(); int dim_size = in_x->dims().size();
if (dims.size() == 0) {
reduce_all = true;
}
auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all); auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all);
auto update_dims = vectorize(d_x->dims()); auto update_dims = vectorize(d_x->dims());
int reduce_num = 1; int reduce_num = 1;
......
...@@ -29,6 +29,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx, ...@@ -29,6 +29,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceCudaAMaxAMinGrad<T, Context>( ReduceCudaAMaxAMinGrad<T, Context>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
} }
......
...@@ -52,6 +52,7 @@ void ReduceGradKernel(const Context& dev_ctx, ...@@ -52,6 +52,7 @@ void ReduceGradKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad, DenseTensor* x_grad,
Functor functor) { Functor functor) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto* in_x = &x; auto* in_x = &x;
auto* d_out = &out_grad; auto* d_out = &out_grad;
auto* d_x = x_grad; auto* d_x = x_grad;
......
...@@ -29,11 +29,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx, ...@@ -29,11 +29,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// get reduce_dim and reduce_num for reduce_mean_grad // get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = x.dims().size(); int dim_size = x.dims().size();
if (dims.size() == 0) {
reduce_all = true;
}
std::vector<int> reduce_dims = std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all); funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
......
...@@ -29,11 +29,9 @@ void ReduceSumGradKernel(const Context& dev_ctx, ...@@ -29,11 +29,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// get reduce_dim for reduce_mean_grad // get reduce_dim for reduce_mean_grad
int dim_size = x.dims().size(); int dim_size = x.dims().size();
if (dims.size() == 0) {
reduce_all = true;
}
std::vector<int> reduce_dims = std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all); funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
......
...@@ -29,6 +29,7 @@ void FrobeniusNormGradKernel(const Context& ctx, ...@@ -29,6 +29,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* dx) { DenseTensor* dx) {
reduce_all = recompute_reduce_all(x, axis, reduce_all);
ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>( ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
ctx, x, out, dout, axis, keep_dim, reduce_all, dx); ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
} }
......
...@@ -27,6 +27,7 @@ void FrobeniusNormKernel(const Context& ctx, ...@@ -27,6 +27,7 @@ void FrobeniusNormKernel(const Context& ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, axis, reduce_all);
Reduce<Context, T, funcs::FrobeniusNormFunctor>( Reduce<Context, T, funcs::FrobeniusNormFunctor>(
ctx, x, reduce_all, axis, keep_dim, x.dtype(), out); ctx, x, reduce_all, axis, keep_dim, x.dtype(), out);
} }
......
...@@ -60,9 +60,7 @@ void LogsumexpGradKernel(const Context& dev_ctx, ...@@ -60,9 +60,7 @@ void LogsumexpGradKernel(const Context& dev_ctx,
DenseTensor* in_grad) { DenseTensor* in_grad) {
dev_ctx.template Alloc<T>(in_grad); dev_ctx.template Alloc<T>(in_grad);
if (axis.size() == 0 || static_cast<int>(axis.size()) == in.dims().size()) { reduce_all = recompute_reduce_all(in, axis, reduce_all);
reduce_all = true;
}
if (reduce_all) { if (reduce_all) {
auto x = phi::EigenVector<T>::Flatten(in); auto x = phi::EigenVector<T>::Flatten(in);
......
...@@ -69,9 +69,7 @@ void LogsumexpKernel(const Context& dev_ctx, ...@@ -69,9 +69,7 @@ void LogsumexpKernel(const Context& dev_ctx,
DenseTensor* out) { DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
if (axis.size() == 0 || static_cast<int>(axis.size()) == x.dims().size()) { reduce_all = recompute_reduce_all(x, axis, reduce_all);
reduce_all = true;
}
if (reduce_all) { if (reduce_all) {
// Flatten and reduce 1-D tensor // Flatten and reduce 1-D tensor
......
...@@ -30,6 +30,7 @@ void ProdGradKernel(const Context& dev_ctx, ...@@ -30,6 +30,7 @@ void ProdGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::ProdGradFunctor>( ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(
dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad); dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
} }
......
...@@ -34,6 +34,7 @@ void ComputeFromInput(const Context& dev_ctx, ...@@ -34,6 +34,7 @@ void ComputeFromInput(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto* input0 = &x; auto* input0 = &x;
auto* input1 = out.get_ptr(); auto* input1 = out.get_ptr();
auto* output = x_grad; auto* output = x_grad;
...@@ -91,9 +92,8 @@ void ReduceGradKernel(const Context& dev_ctx, ...@@ -91,9 +92,8 @@ void ReduceGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
if (dims.size() == 0) { reduce_all = recompute_reduce_all(x, dims, reduce_all);
reduce_all = true;
}
if (x.dtype() != out_grad.dtype()) { if (x.dtype() != out_grad.dtype()) {
DenseTensorMeta x_grad_meta( DenseTensorMeta x_grad_meta(
out_grad.dtype(), x_grad->dims(), x_grad->layout()); out_grad.dtype(), x_grad->dims(), x_grad->layout());
......
...@@ -29,6 +29,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx, ...@@ -29,6 +29,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>( ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad); dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
} }
......
...@@ -29,6 +29,7 @@ void ReduceMinGradKernel(const Context& dev_ctx, ...@@ -29,6 +29,7 @@ void ReduceMinGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>( ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad); dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
} }
......
...@@ -25,6 +25,7 @@ void ProdRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void ProdRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::MulFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::MulFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/reduce_all_kernel.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/reduce.h" #include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/reduce_all_kernel.h"
namespace phi { namespace phi {
...@@ -25,6 +25,7 @@ void AllRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void AllRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::LogicalAndFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::LogicalAndFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
...@@ -25,6 +25,7 @@ void AMaxRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void AMaxRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::MaxFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::MaxFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
...@@ -25,6 +25,7 @@ void AMinRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void AMinRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::MinFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::MinFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/reduce_any_kernel.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/reduce.h" #include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/reduce_any_kernel.h"
namespace phi { namespace phi {
...@@ -25,6 +25,7 @@ void AnyRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void AnyRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::LogicalOrFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::LogicalOrFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
...@@ -25,6 +25,7 @@ void MaxRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void MaxRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::MaxFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::MaxFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
...@@ -25,6 +25,7 @@ void MeanRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void MeanRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true); dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true);
......
...@@ -25,6 +25,7 @@ void MinRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void MinRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype(); auto out_dtype = x.dtype();
phi::Reduce<T, kps::MinFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::MinFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
...@@ -35,6 +35,7 @@ void ReduceSumEigen(const KPDevice& dev_ctx, ...@@ -35,6 +35,7 @@ void ReduceSumEigen(const KPDevice& dev_ctx,
DataType out_dtype, DataType out_dtype,
DenseTensor* out, DenseTensor* out,
std::vector<int>* reduce_dims) { std::vector<int>* reduce_dims) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// Resize Input Tensor // Resize Input Tensor
auto new_x = x; auto new_x = x;
int added_dims = EigenDimSize - x.dims().size(); int added_dims = EigenDimSize - x.dims().size();
...@@ -79,6 +80,7 @@ void SumRawKernel(const Context& dev_ctx, ...@@ -79,6 +80,7 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) { if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
out_dtype = out->dtype(); out_dtype = out->dtype();
} }
......
...@@ -46,6 +46,7 @@ void ReduceKernel(const Context& dev_ctx, ...@@ -46,6 +46,7 @@ void ReduceKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DenseTensor* out, DenseTensor* out,
dnnl::algorithm reduction_type) { dnnl::algorithm reduction_type) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
auto x_tz = vectorize(x.dims()); auto x_tz = vectorize(x.dims());
auto out_tz = auto out_tz =
...@@ -116,6 +117,7 @@ void ReduceGradKernel(const Context& dev_ctx, ...@@ -116,6 +117,7 @@ void ReduceGradKernel(const Context& dev_ctx,
dnnl::algorithm reduction_type, dnnl::algorithm reduction_type,
float scale_x, float scale_x,
float scale_y) { float scale_y) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
auto out_grad_tz = CalculateReducedDims( auto out_grad_tz = CalculateReducedDims(
x_grad, &out_grad, dims.GetData(), reduce_all, keep_dim); x_grad, &out_grad, dims.GetData(), reduce_all, keep_dim);
......
...@@ -24,6 +24,7 @@ void MaxRawKernel(const Context& dev_ctx, ...@@ -24,6 +24,7 @@ void MaxRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceKernel<T, Context>(dev_ctx, ReduceKernel<T, Context>(dev_ctx,
x, x,
dims, dims,
......
...@@ -25,6 +25,7 @@ void MeanGradKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void MeanGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto input_dims = phi::vectorize(x.dims()); auto input_dims = phi::vectorize(x.dims());
std::vector<int64_t> reduce_dims = dims.GetData(); std::vector<int64_t> reduce_dims = dims.GetData();
int number_of_elements = 1; int number_of_elements = 1;
......
...@@ -24,6 +24,7 @@ void MeanRawKernel(const Context& dev_ctx, ...@@ -24,6 +24,7 @@ void MeanRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceKernel<T, Context>(dev_ctx, ReduceKernel<T, Context>(dev_ctx,
x, x,
dims, dims,
......
...@@ -24,6 +24,7 @@ void MinRawKernel(const Context& dev_ctx, ...@@ -24,6 +24,7 @@ void MinRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceKernel<T, Context>(dev_ctx, ReduceKernel<T, Context>(dev_ctx,
x, x,
dims, dims,
......
...@@ -25,6 +25,7 @@ void SumGradKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void SumGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<T, Context>(dev_ctx, ReduceGradKernel<T, Context>(dev_ctx,
x, x,
out_grad, out_grad,
......
...@@ -25,6 +25,7 @@ void SumRawKernel(const Context& dev_ctx, ...@@ -25,6 +25,7 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceKernel<T, Context>(dev_ctx, ReduceKernel<T, Context>(dev_ctx,
x, x,
dims, dims,
......
...@@ -25,7 +25,7 @@ void ProdKernel(const Context& dev_ctx, ...@@ -25,7 +25,7 @@ void ProdKernel(const Context& dev_ctx,
const IntArray& dims, const IntArray& dims,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = false; // recompute_reduce_all(x, dims);
ProdRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); ProdRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -25,10 +25,7 @@ void AllKernel(const Context& dev_ctx, ...@@ -25,10 +25,7 @@ void AllKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = recompute_reduce_all(x, dims);
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
reduce_all = true;
}
AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -25,10 +25,7 @@ void AMaxKernel(const Context& dev_ctx, ...@@ -25,10 +25,7 @@ void AMaxKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = recompute_reduce_all(x, dims);
if (dims.size() == 0) {
reduce_all = true;
}
AMaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); AMaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -25,10 +25,7 @@ void AMinKernel(const Context& dev_ctx, ...@@ -25,10 +25,7 @@ void AMinKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = recompute_reduce_all(x, dims);
if (dims.size() == 0) {
reduce_all = true;
}
AMinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); AMinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -25,10 +25,7 @@ void AnyKernel(const Context& dev_ctx, ...@@ -25,10 +25,7 @@ void AnyKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims, const std::vector<int64_t>& dims,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = recompute_reduce_all(x, dims);
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
reduce_all = true;
}
AnyRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); AnyRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -25,10 +25,7 @@ void MaxKernel(const Context& dev_ctx, ...@@ -25,10 +25,7 @@ void MaxKernel(const Context& dev_ctx,
const IntArray& dims, const IntArray& dims,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = recompute_reduce_all(x, dims);
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
reduce_all = true;
}
MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -25,10 +25,7 @@ void MeanKernel(const Context& dev_ctx, ...@@ -25,10 +25,7 @@ void MeanKernel(const Context& dev_ctx,
const IntArray& dims, const IntArray& dims,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = recompute_reduce_all(x, dims);
if (dims.size() == 0) {
reduce_all = true;
}
MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -25,10 +25,7 @@ void MinKernel(const Context& dev_ctx, ...@@ -25,10 +25,7 @@ void MinKernel(const Context& dev_ctx,
const IntArray& dims, const IntArray& dims,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = recompute_reduce_all(x, dims);
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
reduce_all = true;
}
MinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); MinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -26,10 +26,7 @@ void SumKernel(const Context& dev_ctx, ...@@ -26,10 +26,7 @@ void SumKernel(const Context& dev_ctx,
DataType out_dtype, DataType out_dtype,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = recompute_reduce_all(x, dims);
if (dims.size() == 0) {
reduce_all = true;
}
SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out); SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
} }
......
...@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx, int r = XPUReduce<Context, T>(dev_ctx,
x, x,
dims.GetData(), dims.GetData(),
......
...@@ -33,6 +33,7 @@ int XPUReduce(const Context& dev_ctx, ...@@ -33,6 +33,7 @@ int XPUReduce(const Context& dev_ctx,
T*, T*,
const std::vector<int>&, const std::vector<int>&,
const std::vector<int>&)> func) { const std::vector<int>&)> func) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
const auto* x_data = x.data<T>(); const auto* x_data = x.data<T>();
......
...@@ -31,6 +31,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx, ...@@ -31,6 +31,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
auto dims = dims_arr.GetData(); auto dims = dims_arr.GetData();
dev_ctx.template Alloc<T>(x_grad); dev_ctx.template Alloc<T>(x_grad);
......
...@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx, int r = XPUReduce<Context, T>(dev_ctx,
x, x,
dims.GetData(), dims.GetData(),
......
...@@ -31,6 +31,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx, ...@@ -31,6 +31,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
dev_ctx.template Alloc<T>(x_grad); dev_ctx.template Alloc<T>(x_grad);
const XPUType* dy_data = reinterpret_cast<const XPUType*>(out_grad.data<T>()); const XPUType* dy_data = reinterpret_cast<const XPUType*>(out_grad.data<T>());
......
...@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx, int r = XPUReduce<Context, T>(dev_ctx,
x, x,
dims.GetData(), dims.GetData(),
......
...@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx, ...@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx, int r = XPUReduce<Context, T>(dev_ctx,
x, x,
dims.GetData(), dims.GetData(),
......
...@@ -28,13 +28,11 @@ void ReduceSumGradKernel(const Context& dev_ctx, ...@@ -28,13 +28,11 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DenseTensor* x_grad) { DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
auto dims = dims_arr.GetData(); auto dims = dims_arr.GetData();
dev_ctx.template Alloc<XPUType>(x_grad); dev_ctx.template Alloc<XPUType>(x_grad);
const auto* out_data = out_grad.data<XPUType>(); const auto* out_data = out_grad.data<XPUType>();
auto* x_grad_data = x_grad->data<XPUType>(); auto* x_grad_data = x_grad->data<XPUType>();
if (dims_arr.size() == 0) {
reduce_all = true;
}
const auto& input_dim_size = x.dims().size(); const auto& input_dim_size = x.dims().size();
std::vector<int> true_dims; std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
......
...@@ -29,6 +29,7 @@ void SumRawKernel(const Context& dev_ctx, ...@@ -29,6 +29,7 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx, int r = XPUReduce<Context, T>(dev_ctx,
x, x,
dims.GetData(), dims.GetData(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册