// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/pten/kernels/norm_kernel.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/kernels/funcs/common_shape.h" #include "paddle/pten/kernels/funcs/eigen/eigen_function.h" namespace pten { template void NormKernel(const Context& ctx, const DenseTensor& x, int axis, float epsilon, bool is_test, DenseTensor* out, DenseTensor* norm) { auto xdim = x.dims(); T eps = epsilon; if (axis < 0) axis = xdim.size() + axis; int pre, n, post; funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post); DenseTensor* out_norm; DenseTensor out_norm_tmp; if (is_test) { auto out_dim = x.dims(); out_dim[axis] = 1; out_norm = &out_norm_tmp; out_norm->Resize(out_dim); } else { out_norm = norm; } ctx.template Alloc(out); ctx.template Alloc(out_norm); auto* place = ctx.eigen_device(); Eigen::DSizes shape(pre, n, post); Eigen::DSizes norm_shape(pre, post); auto x_e = paddle::framework::EigenVector::Flatten(x); auto y_e = paddle::framework::EigenVector::Flatten(*out); auto norm_e = paddle::framework::EigenVector::Flatten(*out_norm); auto x_r = x_e.reshape(shape); auto y = y_e.reshape(shape); auto norm_reshape = norm_e.reshape(norm_shape); Eigen::DSizes rdim(1); // y = x / sqrt((sum(x * x) + epsilon)) // norm = sqrt(sum(x * x) + epsilon) auto x2 = x_r * x_r; auto sum = x2.sum(rdim) + eps; norm_reshape.device(*place) = sum.sqrt(); // y = x / norm Eigen::DSizes rshape(pre, 1, post); Eigen::DSizes bcast(1, n, 1); y.device(*place) = x_r / norm_reshape.reshape(rshape).broadcast(bcast); } } // namespace pten PT_REGISTER_KERNEL(norm, CPU, ALL_LAYOUT, pten::NormKernel, float, double) {}