// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/layer_norm_kernel.h"

#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/funcs/layer_norm_util.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
    !defined(__OSX__)
#include "paddle/fluid/operators/jit/kernels.h"
#endif
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void LayerNormKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const paddle::optional<DenseTensor>& scale_opt,
                     const paddle::optional<DenseTensor>& bias_opt,
                     float epsilon,
                     int begin_norm_axis,
                     bool is_test,
                     DenseTensor* y,
                     DenseTensor* mean,
                     DenseTensor* var) {
  const auto x_dims = x.dims();
  auto* scale = scale_opt.get_ptr();
  auto* bias = bias_opt.get_ptr();

  dev_ctx.template Alloc<T>(y);
  dev_ctx.template Alloc<T>(mean);
  dev_ctx.template Alloc<T>(var);

  // Collapse the input to a 2-D view [left, right]: dims before
  // begin_norm_axis form the rows, the normalized dims form the columns.
  auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
  int left = static_cast<int>(matrix_dim[0]);
  int right = static_cast<int>(matrix_dim[1]);
  DDim matrix_shape({left, right});

  auto x_tmp = x;
  x_tmp.Resize(matrix_shape);
  DenseTensor out;
  out.ShareDataWith(*y);
  out.Resize(matrix_shape);

// On builds without the JIT kernels (CUDA builds, Windows, macOS), compose
// layer norm from generic rowwise/elementwise primitives; otherwise
// dispatch to the fused JIT implementation in the #else branch.
#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
    defined(__OSX__)
  funcs::RowwiseMean2D<phi::CPUContext, T> row_mean(left, right, dev_ctx);

  // get mean
  row_mean(dev_ctx, x_tmp, mean);

  // get variance
  phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T, T>(
      dev_ctx, x_tmp, *mean, 0, funcs::SubAndSquareFunctor<T>(), &out);
  row_mean(dev_ctx, out, var);

  // get x_norm
  phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
      dev_ctx, x_tmp, *mean, 0, funcs::SubtractFunctor<T>(), &out);
  phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
      dev_ctx,
      out,
      *var,
      0,
      funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
      &out);

  if (scale) {
    phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
        dev_ctx, out, *scale, 1, funcs::MultiplyFunctor<T>(), &out);
  }
  if (bias) {
    phi::funcs::ElementwiseCompute<funcs::AddFunctor<T>, T, T>(
        dev_ctx, out, *bias, 1, funcs::AddFunctor<T>(), &out);
  }
#else
  PADDLE_ENFORCE_EQ(mean->numel(),
                    left,
                    phi::errors::InvalidArgument(
                        "mean's length (%d) is not equal with expected (%d).",
                        mean->numel(),
                        left));
  PADDLE_ENFORCE_EQ(var->numel(),
                    left,
                    phi::errors::InvalidArgument(
                        "var's length (%d) is not equal with expected (%d).",
                        var->numel(),
                        left));
  if (scale) {
    PADDLE_ENFORCE_EQ(
        scale->numel(),
        right,
        phi::errors::InvalidArgument(
            "scale's length (%d) is not equal with expected (%d).",
            scale->numel(),
            right));
  }
  if (bias) {
    PADDLE_ENFORCE_EQ(bias->numel(),
                      right,
                      phi::errors::InvalidArgument(
                          "bias's length (%d) is not equal with expected (%d).",
                          bias->numel(),
                          right));
  }

  // Fetch a cached JIT kernel specialized for `right` columns; it fuses
  // mean, variance, normalization, scale, and bias over all `left` rows.
  auto ker = paddle::operators::jit::KernelFuncs<
                 paddle::operators::jit::LayerNormTuple<T>,
                 phi::CPUPlace>::Cache()
                 .At(right);
  ker(x_tmp.data<T>(),
      out.data<T>(),
      mean->data<T>(),
      var->data<T>(),
      scale ? scale->data<T>() : nullptr,
      bias ? bias->data<T>() : nullptr,
      static_cast<int>(left),
      static_cast<const float>(epsilon),
      right);
#endif
}

}  // namespace phi

PD_REGISTER_KERNEL(
    layer_norm, CPU, ALL_LAYOUT, phi::LayerNormKernel, float, double) {}