/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cmath>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
    !defined(__OSX__)
#include "paddle/fluid/operators/jit/kernels.h"
#endif
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace platform {
class CPUDeviceContext;
class CUDADeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

// Wrap RowwiseMean and ColwiseMean.
// Reuse the cpu codes and replace the gpu codes with cublas_gemv, which is
// significantly faster. Unlike the RowwiseMean and ColwiseMean, the
// implementation only considers 2D.
template struct RowwiseMean2D { RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx); void operator()(const platform::DeviceContext& context, const framework::Tensor& input, framework::Tensor* vec); }; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template class RowwiseMean2D { public: RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) : left_(left), right_(right) { framework::DDim ones_dim({right_}); divisor_.mutable_data(ones_dim, dev_ctx.GetPlace()); phi::funcs::set_constant(dev_ctx, &divisor_, 1.0 / right); } void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { phi::funcs::GetBlas(context).GEMV( false, left_, right_, 1., input.data(), divisor_.data(), 0., out->data()); } private: int left_; int right_; framework::Tensor divisor_; }; #endif template class RowwiseMean2D { public: RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) {} void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { row_mean_(context, input, out); } private: phi::funcs::RowwiseMean row_mean_; }; template struct ColwiseSum2D { ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx); void operator()(const platform::DeviceContext& context, const framework::Tensor& input, framework::Tensor* vec); }; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template class ColwiseSum2D { public: ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) : left_(left), right_(right) { framework::DDim ones_dim({left_}); divisor_.mutable_data(ones_dim, dev_ctx.GetPlace()); phi::funcs::set_constant(dev_ctx, &divisor_, 1.0); } void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { phi::funcs::GetBlas(context).GEMV( true, left_, right_, 1., input.data(), divisor_.data(), 0., out->data()); } private: int 
left_; int right_; framework::Tensor divisor_; }; #endif template class ColwiseSum2D { public: ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) {} void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { col_wise_(context, input, out); } private: phi::funcs::ColwiseSum col_wise_; }; template struct SubAndSquareFunctor { inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); } }; template struct DivAndSqrtFunctor { explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; } inline HOSTDEVICE T operator()(T a, T b) const { return a / (sqrt(b + epsilon_)); } private: T epsilon_; }; template struct MulInvVarFunctor { inline HOSTDEVICE T operator()(T a, T b) const { return a * std::sqrt(1.0 / b); } }; using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; using DataLayout = framework::DataLayout; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template class LayerNormDirectCUDAFunctor { public: void operator()(gpuStream_t stream, const T* input, std::vector input_shape, const T* bias, const T* scale, T* output, T* mean, T* variance, int begin_norm_axis, float eps); }; #endif template class LayerNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { const float epsilon = ctx.Attr("epsilon"); auto* scale = ctx.Input("Scale"); auto* bias = ctx.Input("Bias"); auto x = *ctx.Input("X"); auto* y = ctx.Output("Y"); auto* mean = ctx.Output("Mean"); auto* var = ctx.Output("Variance"); const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); const auto x_dims = x.dims(); y->mutable_data(ctx.GetPlace()); mean->mutable_data(ctx.GetPlace()); var->mutable_data(ctx.GetPlace()); auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); int left = static_cast(matrix_dim[0]); int right = static_cast(matrix_dim[1]); framework::DDim matrix_shape({left, right}); x.Resize(matrix_shape); 
Tensor out; out.ShareDataWith(*y); out.Resize(matrix_shape); #if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \ defined(__OSX__) auto& dev_ctx = ctx.template device_context(); RowwiseMean2D row_mean(left, right, ctx.device_context()); // get mean row_mean(dev_ctx, x, mean); // get variance ElementwiseComputeEx, DeviceContext, T>( ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor(), &out); row_mean(dev_ctx, out, var); // get x_norm ElementwiseComputeEx, DeviceContext, T>( ctx, &x, mean, /*axis*/ 0, SubFunctor(), &out); ElementwiseComputeEx, DeviceContext, T>( ctx, &out, var, /*axis*/ 0, DivAndSqrtFunctor(static_cast(epsilon)), &out); if (scale) { ElementwiseComputeEx, DeviceContext, T>( ctx, &out, scale, /*axis*/ 1, MulFunctor(), &out); } if (bias) { ElementwiseComputeEx, DeviceContext, T>( ctx, &out, bias, /*axis*/ 1, AddFunctor(), &out); } #else PADDLE_ENFORCE_EQ(mean->numel(), left, platform::errors::InvalidArgument( "mean's length (%d) is not equal with expected (%d).", mean->numel(), left)); PADDLE_ENFORCE_EQ(var->numel(), left, platform::errors::InvalidArgument( "var's length (%d) is not equal with expected (%d).", var->numel(), left)); if (scale) { PADDLE_ENFORCE_EQ( scale->numel(), right, platform::errors::InvalidArgument( "scale's length (%d) is not equal with expected (%d).", scale->numel(), right)); } if (bias) { PADDLE_ENFORCE_EQ( bias->numel(), right, platform::errors::InvalidArgument( "bias's length (%d) is not equal with expected (%d).", bias->numel(), right)); } auto ker = jit::KernelFuncs, platform::CPUPlace>::Cache() .At(right); ker(x.data(), out.data(), mean->data(), var->data(), scale ? scale->data() : nullptr, bias ? 
bias->data() : nullptr, static_cast(left), static_cast(epsilon), right); #endif } }; template class LayerNormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { const float epsilon = ctx.Attr("epsilon"); auto x = *ctx.Input("X"); auto* mean = ctx.Input("Mean"); auto* var = ctx.Input("Variance"); auto* scale = ctx.Input("Scale"); auto d_y = *ctx.Input(framework::GradVarName("Y")); const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); // init output auto* d_x = ctx.Output(framework::GradVarName("X")); auto* d_scale = ctx.Output(framework::GradVarName("Scale")); auto* d_bias = ctx.Output(framework::GradVarName("Bias")); const auto& x_dims = x.dims(); auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis); int left = static_cast(matrix_dim[0]); int right = static_cast(matrix_dim[1]); framework::DDim matrix_shape({left, right}); d_y.Resize(matrix_shape); auto& dev_ctx = ctx.template device_context(); ColwiseSum2D colwise_sum(left, right, ctx.device_context()); Tensor temp; Tensor temp_norm; if (d_scale || d_x) { x.Resize(matrix_shape); temp.mutable_data(matrix_shape, ctx.GetPlace()); temp_norm.mutable_data(matrix_shape, ctx.GetPlace()); // get x_norm ElementwiseComputeEx, DeviceContext, T>( ctx, &x, mean, /*axis*/ 0, SubFunctor(), &temp_norm); ElementwiseComputeEx, DeviceContext, T>( ctx, &temp_norm, var, /*axis*/ 0, DivAndSqrtFunctor(static_cast(epsilon)), &temp_norm); } if (d_bias) { d_bias->mutable_data(ctx.GetPlace()); colwise_sum(dev_ctx, d_y, d_bias); } if (d_scale) { d_scale->mutable_data(ctx.GetPlace()); ElementwiseComputeEx, DeviceContext, T>( ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor(), &temp); colwise_sum(dev_ctx, temp, d_scale); } if (d_x) { framework::DDim vec_shape({left}); d_x->mutable_data(ctx.GetPlace()); auto dx_dim = d_x->dims(); Tensor temp_vec; temp_vec.mutable_data(vec_shape, ctx.GetPlace()); RowwiseMean2D row_mean(left, right, ctx.device_context()); if (d_scale) { 
// dy_dx ElementwiseComputeEx, DeviceContext, T>( ctx, &d_y, scale, /*axis*/ 1, MulFunctor(), &temp); framework::TensorCopy(temp, ctx.GetPlace(), ctx.device_context(), d_x); // dy_dmean_dx row_mean(dev_ctx, temp, &temp_vec); ElementwiseComputeEx, DeviceContext, T>( ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor(), d_x); // dy_var_dx ElementwiseComputeEx, DeviceContext, T>( ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor(), &temp); } else { // dy_dx framework::TensorCopy(d_y, ctx.GetPlace(), ctx.device_context(), d_x); // dy_dmean_dx row_mean(dev_ctx, d_y, &temp_vec); ElementwiseComputeEx, DeviceContext, T>( ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor(), d_x); // dy_var_dx ElementwiseComputeEx, DeviceContext, T>( ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor(), &temp); } // dy_var_dx row_mean(dev_ctx, temp, &temp_vec); ElementwiseComputeEx, DeviceContext, T>( ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor(), &temp); ElementwiseComputeEx, DeviceContext, T>( ctx, d_x, &temp, /*axis*/ 0, SubFunctor(), d_x); ElementwiseComputeEx, DeviceContext, T>( ctx, d_x, var, /*axis*/ 0, DivAndSqrtFunctor(static_cast(epsilon)), d_x); d_x->Resize(dx_dim); } } }; } // namespace operators } // namespace paddle