Unverified commit 56f15c43 authored by wanghuancoder, committed by GitHub

refine reduce_all (#48133)

* refine reduce_all
Parent 208f625b
......@@ -336,4 +336,15 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
};
};
inline bool recompute_reduce_all(const DenseTensor& x,
const IntArray& dims,
bool reduce_all = false) {
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size() ||
reduce_all) {
return true;
} else {
return false;
}
}
} // namespace phi
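
The new helper centralizes the scattered `if (dims.size() == 0) reduce_all = true;` checks that the hunks below delete. A minimal standalone sketch of the same decision logic, with plain ints standing in for `phi::DenseTensor` and `phi::IntArray` (names here are illustrative, not Paddle API):

```cpp
#include <cassert>
#include <cstddef>

// reduce_all is forced on when no axes are given, when every axis of the
// tensor is listed, or when the caller already asked for a full reduction.
bool recompute_reduce_all_sketch(int tensor_rank,
                                 std::size_t num_reduce_dims,
                                 bool reduce_all = false) {
  return num_reduce_dims == 0 ||
         static_cast<int>(num_reduce_dims) == tensor_rank || reduce_all;
}

int main() {
  assert(recompute_reduce_all_sketch(3, 0));        // empty dims: reduce all
  assert(recompute_reduce_all_sketch(3, 3));        // every axis listed
  assert(!recompute_reduce_all_sketch(3, 1));       // partial reduction
  assert(recompute_reduce_all_sketch(3, 1, true));  // explicit flag wins
}
```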
......@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::ProdFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
......@@ -30,6 +30,7 @@ void Reduce(const DeviceContext& dev_ctx,
bool keep_dim,
DataType out_dtype,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// If dims covers every dimension of x, set reduce_all to true
const int& input_dim_size = x.dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
......@@ -71,6 +72,7 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
bool keep_dim,
bool reduce_all,
phi::DenseTensor* output) {
reduce_all = recompute_reduce_all(input, dims, reduce_all);
dev_ctx.template Alloc<OutT>(output);
// If dims covers every dimension, set reduce_all to true
......
......@@ -28,6 +28,7 @@ void AllRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
phi::BoolReduceKernel<CPUContext, T, phi::funcs::AllFunctor>(
dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
}
......
......@@ -28,6 +28,7 @@ void AMaxRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
......@@ -28,6 +28,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::AMaxOrAMinGradFunctor>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
}
......
......@@ -28,6 +28,7 @@ void AMinRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
......@@ -28,6 +28,7 @@ void AnyRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
phi::BoolReduceKernel<CPUContext, T, phi::funcs::AnyFunctor>(
dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
......@@ -28,6 +28,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::MeanGradFunctor, true>(dev_ctx,
x,
paddle::none,
......
......@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MeanFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
......@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
......@@ -77,6 +77,7 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
if (dims.size() == 1) {
if (out_grad.dtype() != x.dtype()) {
DenseTensorMeta x_grad_meta(
......
......@@ -58,6 +58,7 @@ using dim3 = phi::kps::dim3;
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_utils.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
......
......@@ -26,6 +26,7 @@ void FrobeniusNormKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::AddFunctor, kps::SquareFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
......@@ -36,6 +36,7 @@ void Reduce(const KPDevice& dev_ctx,
DataType out_dtype,
DenseTensor* out,
bool is_mean = false) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
std::vector<int> reduce_dims =
phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
......
......@@ -28,6 +28,7 @@ void ReduceAMaxGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceCudaAMaxAMinGrad<T, Context>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
}
......
......@@ -32,15 +32,13 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto* in_x = &x;
auto* out_y = &out;
auto* d_out = &out_grad;
auto* d_x = x_grad;
// get reduce_dim and reduce_num for the amax/amin grad
int dim_size = in_x->dims().size();
-if (dims.size() == 0) {
-reduce_all = true;
-}
auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all);
auto update_dims = vectorize(d_x->dims());
int reduce_num = 1;
......
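
Once the flag is settled, `funcs::details::GetReduceDim` turns it into the concrete axis list. A plausible simplified sketch of that expansion, inferred from the call sites above (hypothetical signature, not the Paddle implementation):

```cpp
#include <cstdint>
#include <vector>

// Sketch: with reduce_all set, reduce over every axis; otherwise normalize
// possibly-negative axes into [0, dim_size). Inferred behavior only.
std::vector<int> GetReduceDimSketch(const std::vector<int64_t>& dims,
                                    int dim_size,
                                    bool reduce_all) {
  std::vector<int> reduce_dims;
  if (reduce_all) {
    for (int i = 0; i < dim_size; ++i) reduce_dims.push_back(i);
  } else {
    for (const auto d : dims) {
      reduce_dims.push_back(static_cast<int>(d < 0 ? d + dim_size : d));
    }
  }
  return reduce_dims;
}
```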
......@@ -29,6 +29,7 @@ void ReduceAMinGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceCudaAMaxAMinGrad<T, Context>(
dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad);
}
......
......@@ -52,6 +52,7 @@ void ReduceGradKernel(const Context& dev_ctx,
bool reduce_all,
DenseTensor* x_grad,
Functor functor) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto* in_x = &x;
auto* d_out = &out_grad;
auto* d_x = x_grad;
......
......@@ -29,11 +29,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = x.dims().size();
-if (dims.size() == 0) {
-reduce_all = true;
-}
std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
......
......@@ -29,11 +29,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// get reduce_dim for reduce_sum_grad
int dim_size = x.dims().size();
-if (dims.size() == 0) {
-reduce_all = true;
-}
std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
......
......@@ -29,6 +29,7 @@ void FrobeniusNormGradKernel(const Context& ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* dx) {
reduce_all = recompute_reduce_all(x, axis, reduce_all);
ReduceGradKernel<Context, T, funcs::FrobeniusNormGradFunctor>(
ctx, x, out, dout, axis, keep_dim, reduce_all, dx);
}
......
......@@ -27,6 +27,7 @@ void FrobeniusNormKernel(const Context& ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, axis, reduce_all);
Reduce<Context, T, funcs::FrobeniusNormFunctor>(
ctx, x, reduce_all, axis, keep_dim, x.dtype(), out);
}
......
......@@ -60,9 +60,7 @@ void LogsumexpGradKernel(const Context& dev_ctx,
DenseTensor* in_grad) {
dev_ctx.template Alloc<T>(in_grad);
-if (axis.size() == 0 || static_cast<int>(axis.size()) == in.dims().size()) {
-reduce_all = true;
-}
reduce_all = recompute_reduce_all(in, axis, reduce_all);
if (reduce_all) {
auto x = phi::EigenVector<T>::Flatten(in);
......
......@@ -69,9 +69,7 @@ void LogsumexpKernel(const Context& dev_ctx,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
-if (axis.size() == 0 || static_cast<int>(axis.size()) == x.dims().size()) {
-reduce_all = true;
-}
reduce_all = recompute_reduce_all(x, axis, reduce_all);
if (reduce_all) {
// Flatten and reduce 1-D tensor
......
......@@ -30,6 +30,7 @@ void ProdGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::ProdGradFunctor>(
dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
}
......
......@@ -34,6 +34,7 @@ void ComputeFromInput(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto* input0 = &x;
auto* input1 = out.get_ptr();
auto* output = x_grad;
......@@ -91,9 +92,8 @@ void ReduceGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
-if (dims.size() == 0) {
-reduce_all = true;
-}
reduce_all = recompute_reduce_all(x, dims, reduce_all);
if (x.dtype() != out_grad.dtype()) {
DenseTensorMeta x_grad_meta(
out_grad.dtype(), x_grad->dims(), x_grad->layout());
......
......@@ -29,6 +29,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
}
......
......@@ -29,6 +29,7 @@ void ReduceMinGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<Context, T, funcs::MaxOrMinGradFunctor>(
dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad);
}
......
......@@ -25,6 +25,7 @@ void ProdRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::MulFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_all_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/reduce_all_kernel.h"
namespace phi {
......@@ -25,6 +25,7 @@ void AllRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::LogicalAndFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
......@@ -25,6 +25,7 @@ void AMaxRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::MaxFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
......@@ -25,6 +25,7 @@ void AMinRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::MinFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_any_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/reduce_any_kernel.h"
namespace phi {
......@@ -25,6 +25,7 @@ void AnyRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::LogicalOrFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
......
......@@ -25,6 +25,7 @@ void MaxRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::MaxFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
......@@ -25,6 +25,7 @@ void MeanRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true);
......
......@@ -25,6 +25,7 @@ void MinRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::MinFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
......
......@@ -35,6 +35,7 @@ void ReduceSumEigen(const KPDevice& dev_ctx,
DataType out_dtype,
DenseTensor* out,
std::vector<int>* reduce_dims) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// Resize Input Tensor
auto new_x = x;
int added_dims = EigenDimSize - x.dims().size();
......@@ -79,6 +80,7 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all,
DataType out_dtype,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
out_dtype = out->dtype();
}
......
......@@ -46,6 +46,7 @@ void ReduceKernel(const Context& dev_ctx,
bool reduce_all,
DenseTensor* out,
dnnl::algorithm reduction_type) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
const auto& onednn_engine = dev_ctx.GetEngine();
auto x_tz = vectorize(x.dims());
auto out_tz =
......@@ -116,6 +117,7 @@ void ReduceGradKernel(const Context& dev_ctx,
dnnl::algorithm reduction_type,
float scale_x,
float scale_y) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
const auto& onednn_engine = dev_ctx.GetEngine();
auto out_grad_tz = CalculateReducedDims(
x_grad, &out_grad, dims.GetData(), reduce_all, keep_dim);
......
......@@ -24,6 +24,7 @@ void MaxRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceKernel<T, Context>(dev_ctx,
x,
dims,
......
......@@ -25,6 +25,7 @@ void MeanGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto input_dims = phi::vectorize(x.dims());
std::vector<int64_t> reduce_dims = dims.GetData();
int number_of_elements = 1;
......
......@@ -24,6 +24,7 @@ void MeanRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceKernel<T, Context>(dev_ctx,
x,
dims,
......
......@@ -24,6 +24,7 @@ void MinRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceKernel<T, Context>(dev_ctx,
x,
dims,
......
......@@ -25,6 +25,7 @@ void SumGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceGradKernel<T, Context>(dev_ctx,
x,
out_grad,
......
......@@ -25,6 +25,7 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all,
DataType out_dtype,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
ReduceKernel<T, Context>(dev_ctx,
x,
dims,
......
......@@ -25,7 +25,7 @@ void ProdKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
bool reduce_all = false; // recompute_reduce_all(x, dims);
ProdRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -25,10 +25,7 @@ void AllKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
-if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-reduce_all = true;
-}
bool reduce_all = recompute_reduce_all(x, dims);
AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -25,10 +25,7 @@ void AMaxKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
-if (dims.size() == 0) {
-reduce_all = true;
-}
bool reduce_all = recompute_reduce_all(x, dims);
AMaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -25,10 +25,7 @@ void AMinKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
-if (dims.size() == 0) {
-reduce_all = true;
-}
bool reduce_all = recompute_reduce_all(x, dims);
AMinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -25,10 +25,7 @@ void AnyKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
-if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-reduce_all = true;
-}
bool reduce_all = recompute_reduce_all(x, dims);
AnyRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -25,10 +25,7 @@ void MaxKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
-if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-reduce_all = true;
-}
bool reduce_all = recompute_reduce_all(x, dims);
MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -25,10 +25,7 @@ void MeanKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
-if (dims.size() == 0) {
-reduce_all = true;
-}
bool reduce_all = recompute_reduce_all(x, dims);
MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -25,10 +25,7 @@ void MinKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
-if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
-reduce_all = true;
-}
bool reduce_all = recompute_reduce_all(x, dims);
MinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -26,10 +26,7 @@ void SumKernel(const Context& dev_ctx,
DataType out_dtype,
bool keep_dim,
DenseTensor* out) {
-bool reduce_all = false;
-if (dims.size() == 0) {
-reduce_all = true;
-}
bool reduce_all = recompute_reduce_all(x, dims);
SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
}
......
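
All of the public kernels above follow the same two-level pattern: the public entry computes `reduce_all` from scratch through the helper's default argument, while the raw kernel re-derives it from the incoming flag, so it stays correct even when called directly with a stale value. A self-contained sketch of that convention (mock types, illustrative names):

```cpp
#include <vector>

struct TensorMock { int rank; };  // stand-in for phi::DenseTensor
using Dims = std::vector<int>;    // stand-in for phi::IntArray

bool recompute_reduce_all(const TensorMock& x, const Dims& dims,
                          bool reduce_all = false) {
  return dims.empty() || static_cast<int>(dims.size()) == x.rank || reduce_all;
}

// Raw kernel: defensively re-derives the flag from its own inputs.
void SumRawKernelMock(const TensorMock& x, const Dims& dims, bool reduce_all) {
  reduce_all = recompute_reduce_all(x, dims, reduce_all);
  // ... dispatch the actual reduction here ...
}

// Public kernel: derives the flag from scratch via the default argument.
void SumKernelMock(const TensorMock& x, const Dims& dims) {
  SumRawKernelMock(x, dims, recompute_reduce_all(x, dims));
}

int main() {
  TensorMock x{3};
  SumKernelMock(x, {});      // empty dims: full reduction
  SumKernelMock(x, {0, 2});  // partial reduction
}
```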
......@@ -28,6 +28,7 @@ void ProdRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx,
x,
dims.GetData(),
......
......@@ -33,6 +33,7 @@ int XPUReduce(const Context& dev_ctx,
T*,
const std::vector<int>&,
const std::vector<int>&)> func) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
dev_ctx.template Alloc<T>(out);
const auto* x_data = x.data<T>();
......
......@@ -31,6 +31,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
auto dims = dims_arr.GetData();
dev_ctx.template Alloc<T>(x_grad);
......
......@@ -28,6 +28,7 @@ void MaxRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx,
x,
dims.GetData(),
......
......@@ -31,6 +31,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
bool reduce_all,
DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
dev_ctx.template Alloc<T>(x_grad);
const XPUType* dy_data = reinterpret_cast<const XPUType*>(out_grad.data<T>());
......
......@@ -28,6 +28,7 @@ void MeanRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx,
x,
dims.GetData(),
......
......@@ -28,6 +28,7 @@ void MinRawKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx,
x,
dims.GetData(),
......
......@@ -28,13 +28,11 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool reduce_all,
DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
reduce_all = recompute_reduce_all(x, dims_arr, reduce_all);
auto dims = dims_arr.GetData();
dev_ctx.template Alloc<XPUType>(x_grad);
const auto* out_data = out_grad.data<XPUType>();
auto* x_grad_data = x_grad->data<XPUType>();
-if (dims_arr.size() == 0) {
-reduce_all = true;
-}
const auto& input_dim_size = x.dims().size();
std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) {
......
......@@ -29,6 +29,7 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all,
DataType out_dtype,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
int r = XPUReduce<Context, T>(dev_ctx,
x,
dims.GetData(),
......