/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/layer_norm_op.h"

#include <cmath>

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;

// Row-major Eigen views over raw tensor buffers. The kernels below view X as a
// [left, right] matrix (framework::flatten_to_2d at begin_norm_axis) and
// normalize every row independently.
template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
    Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

class LayerNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Scale"),
                   "Input(Scale) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Bias"),
                   "Input(Bias) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Y"),
                   "Output(Y) of LayerNormOp should not be null.");
    // The forward kernel always writes Mean and Variance.
    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
                   "Output(Mean) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
                   "Output(Variance) of LayerNormOp should not be null.");

    auto x_dim = ctx->GetInputDim("X");
    auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
    PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
                      "'begin_norm_axis' must be less than the rank of X");

    // `left` rows are normalized independently, each over `right` elements.
    auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);

    // Scale and Bias are applied element-wise along each normalized row (the
    // kernels map them as 1 x right), so they must hold `right` elements.
    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);

    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
    // One mean / variance per normalized row.
    ctx->SetOutputDim("Mean", {left});
    ctx->SetOutputDim("Variance", {left});
    ctx->ShareLoD("X", "Y");
  }
};

class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input tensor");
    AddInput("Scale",
             "Scale is a 1-dimensional tensor of size H, where H is the "
             "product of X's dimensions from begin_norm_axis to the last "
             "one; it multiplies the normalized output element-wise");
    AddInput("Bias",
             "Bias is a 1-dimensional tensor of size H, where H is the "
             "product of X's dimensions from begin_norm_axis to the last "
             "one; it is added to the scaled normalized output");
    AddOutput("Y", "result after normalization");
    AddOutput("Mean",
              "Mean of each row of X flattened to 2-D at begin_norm_axis.");
    AddOutput("Variance",
              "Variance plus epsilon of each row of X flattened to 2-D at "
              "begin_norm_axis.");
    AddAttr<float>("epsilon",
                   "(float, default 1e-5) constant added to the variance "
                   "for numerical stability")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &epsilon) {
          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
                         "'epsilon' should be between 0.0 and 0.001.");
        });
    AddAttr<int>("begin_norm_axis",
                 "(int default:1), the "
                 "axis of `begin_norm_axis ... Rank(X) - 1` will be normalized")
        .SetDefault(1)
        .AddCustomChecker([](const int &begin_norm_axis) {
          PADDLE_ENFORCE_GT(begin_norm_axis, 0,
                            "'begin_norm_axis' should be greater than zero.");
        });

    AddComment(R"DOC(
Layer Normalization.

Layer Norm has been implemented as discussed in the paper:
https://arxiv.org/abs/1607.06450

X is flattened to a 2-D matrix at begin_norm_axis and every row is normalized:
    Y = Scale * (X - Mean) / sqrt(Variance + epsilon) + Bias
...
)DOC");
  }
};

// CPU specialization of LayerNormKernel. The primary template is presumably
// declared in layer_norm_op.h as <DeviceContext, T>; confirm against it.
template <typename T>
class LayerNormKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");
    const auto *x = ctx.Input<Tensor>("X");
    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    auto *output = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    output->mutable_data<T>(ctx.GetPlace());
    mean->mutable_data<T>(ctx.GetPlace());
    var->mutable_data<T>(ctx.GetPlace());

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);

    auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
    auto scale_map =
        ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
    auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
    auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
    auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
    auto output_map =
        EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);

    auto square = [](T ele) { return ele * ele; };
    auto add_epsilon = [epsilon](T ele) { return ele + epsilon; };

    mean_map = input_map.rowwise().mean();
    // The stored Variance already includes epsilon; the backward kernel reads
    // it back and uses it directly as (variance + epsilon).
    var_map = (input_map - mean_map.replicate(1, right))
                  .unaryExpr(square)
                  .rowwise()
                  .mean()
                  .unaryExpr(add_epsilon);

    auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
    // TODO(zcd): Some thinking about output_map, is it appropriate that
    // `output_map` and `input_map` point to the same memory.
    auto inv_std_scale = var_map.unaryExpr(inv_std_func);
    // Y = (X - mean) * inv_std * scale + bias
    output_map = (input_map - mean_map.replicate(1, right))
                     .cwiseProduct(inv_std_scale.replicate(1, right))
                     .cwiseProduct(scale_map.replicate(left, 1)) +
                 bias_map.replicate(left, 1);
  }
};

class LayerNormGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    // check input
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of LayerNormGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
    PADDLE_ENFORCE(ctx->HasInput("Mean"), "");
    PADDLE_ENFORCE(ctx->HasInput("Variance"), "");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");

    // check output
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
    if (ctx->HasOutput(framework::GradVarName("Scale"))) {
      ctx->SetOutputDim(framework::GradVarName("Scale"),
                        ctx->GetInputDim("Scale"));
    }
    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Bias"));
    }
  }

 protected:
  // The kernel data type follows the incoming Y@GRAD tensor.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto *var = ctx.InputVar(framework::GradVarName("Y"));
    if (var == nullptr) {
      PADDLE_THROW("can't find Y@GRAD");
    }
    const Tensor *t = nullptr;
    if (var->IsType<Tensor>()) {
      t = &var->Get<Tensor>();
    } else if (var->IsType<LoDTensor>()) {
      t = &var->Get<LoDTensor>();
    }
    if (t == nullptr) {
      PADDLE_THROW("can't find Y@GRAD");
    }
    return framework::OpKernelType(framework::ToDataType(t->type()),
                                   ctx.GetPlace());
  }
};

// CPU specialization of LayerNormGradKernel (primary template presumably
// declared in layer_norm_op.h as <DeviceContext, T>; confirm against it).
//
// Per row, with s = 1 / sqrt(Variance) (Variance already includes epsilon),
// x_hat = (x - mean) * s, g = scale * dY and N = right:
//   dBias  = sum over rows of dY
//   dScale = sum over rows of x_hat * dY
//   dX     = s * g - s * mean(g) - (x - mean) * s^3 * mean(g * (x - mean))
template <typename T>
class LayerNormGradKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    const auto *mean = ctx.Input<Tensor>("Mean");
    const auto *var = ctx.Input<Tensor>("Variance");
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);

    // init output
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto scale_map =
        ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
    auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
    auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
    auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
    auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);

    auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };

    if (d_bias) {
      d_bias->mutable_data<T>(ctx.GetPlace());
      auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
      // Bias is shared by every row, so its gradient accumulates over rows.
      d_bias_map = d_y_map.colwise().sum();
    }
    if (d_scale) {
      d_scale->mutable_data<T>(ctx.GetPlace());
      auto d_scale_map =
          EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
      // There are two equations to compute d_scale: one uses "Y" and the
      // other does not. This one recomputes x_hat from X, Mean and Variance.
      // Scale is shared by every row, so its gradient accumulates over rows.
      d_scale_map =
          ((x_map - mean_map.replicate(1, right))
               .cwiseProduct(
                   var_map.unaryExpr(inv_std_func).replicate(1, right))
               .cwiseProduct(d_y_map))
              .colwise()
              .sum();
    }
    if (d_x) {
      d_x->mutable_data<T>(ctx.GetPlace());
      auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
      auto triple_product_func = [](T ele) { return ele * ele * ele; };

      // dy_dx: direct term s * g.
      auto dx_end = var_map.unaryExpr(inv_std_func)
                        .replicate(1, right)
                        .cwiseProduct(d_y_map)
                        .cwiseProduct(scale_map.replicate(left, 1));
      // dy_dmean_dx: -(1/N) * sum_k(s * g_k), broadcast along the row.
      auto dx_mean =
          (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
      // dy_var_dx: the scale varies along the row, so it belongs inside the
      // row sum: -(1/N) * (x_j - mean) * s^3 * sum_k(scale_k dy_k (x_k - mean)).
      auto dvar_end_part = (x_map - mean_map.replicate(1, right))
                               .cwiseProduct(scale_map.replicate(left, 1))
                               .cwiseProduct(d_y_map)
                               .rowwise()
                               .sum();
      auto dvar_end = var_map.unaryExpr(inv_std_func)
                          .unaryExpr(triple_product_func)
                          .cwiseProduct(dvar_end_part)
                          .replicate(1, right);
      auto dx_var =
          (T(-1.0) / right) *
          (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);

      d_x_map = dx_end + dx_mean + dx_var;
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
            layer_norm_grad, ops::LayerNormGradOp);
REGISTER_OP_CPU_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>);