From ca0177190f75a4f39482b8fe1c8e929ab8e1a381 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 22 Jan 2018 15:18:47 +0800 Subject: [PATCH] add layer_norm --- paddle/operators/layer_norm_op.cc | 283 ++++++++++++++++++ paddle/operators/layer_norm_op.h | 35 +++ .../v2/fluid/tests/test_layer_norm_op.py | 81 +++++ 3 files changed, 399 insertions(+) create mode 100644 paddle/operators/layer_norm_op.cc create mode 100644 paddle/operators/layer_norm_op.h create mode 100644 python/paddle/v2/fluid/tests/test_layer_norm_op.py diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc new file mode 100644 index 0000000000..f1ddcd8210 --- /dev/null +++ b/paddle/operators/layer_norm_op.cc @@ -0,0 +1,283 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/layer_norm_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DataLayout = framework::DataLayout; + +template +using EigenMatrixMapRowMajor = Eigen::Map< + Eigen::Matrix>; +template +using ConstEigenMatrixMapRowMajor = Eigen::Map< + const Eigen::Matrix>; + +class LayerNormOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), ""); + PADDLE_ENFORCE(ctx->HasInput("Scale"), ""); + PADDLE_ENFORCE(ctx->HasInput("Bias"), ""); + PADDLE_ENFORCE(ctx->HasOutput("Y"), ""); + + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], 1); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], 1); + + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); + ctx->SetOutputDim("Mean", {ctx->GetInputDim("X")[0]}); + ctx->SetOutputDim("Variance", {ctx->GetInputDim("X")[0]}); + + ctx->ShareLoD("X", "Y"); + } +}; + +class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensor"); + AddInput("Scale", + "Scale is a 1-dimensional tensor of size 1 " + "that is applied to the output"); + AddInput("Bias", + "Bias is a 1-dimensional tensor of size 1 " + "that is applied to the output"); + AddOutput("Y", "result after normalization"); + AddOutput("Mean", "Mean of the current mini batch."); + AddOutput("Variance", "Variance of the current mini batch."); + + AddAttr("epsilon", "") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, + "'epsilon' should be between 0.0 and 0.001."); + }); + AddAttr>("axis", + "(vector default:{1, 1, 1}), the " + "axis to normalize.") + .SetDefault({1, 2, 3}); // todo(zcd) : who to set axis + + AddComment(R"DOC( +Layer Normalization. + +Layer Norm has been implemented as discussed in the paper: +https://arxiv.org/abs/1607.06450 +... +)DOC"); + } +}; + +template +class LayerNormKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const float epsilon = ctx.Attr("epsilon"); + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + const auto *x = ctx.Input("X"); + const auto &x_dims = x->dims(); + + const int N = x_dims[0]; + const int sample_size = x->numel() / N; + + auto scale_data = scale->data()[0]; + auto bias_data = bias->data()[0]; + + auto *output = ctx.Output("Y"); + auto *mean = ctx.Output("Mean"); + auto *var = ctx.Output("Variance"); + output->mutable_data(ctx.GetPlace()); + mean->mutable_data(ctx.GetPlace()); + var->mutable_data(ctx.GetPlace()); + + int left = N, right = sample_size; + auto input_map = ConstEigenMatrixMapRowMajor(x->data(), left, right); + auto mean_map = EigenMatrixMapRowMajor(mean->data(), left, 1); + auto var_map = EigenMatrixMapRowMajor(var->data(), left, 1); + auto output_map = EigenMatrixMapRowMajor(output->data(), left, right); + + auto squre = [](T ele) { return ele * ele; }; + auto add_epslion = [epsilon](T ele) { return ele + epsilon; }; + + mean_map = input_map.rowwise().mean(); + var_map = (input_map - mean_map.replicate(1, right)) + .unaryExpr(squre) + .rowwise() + .mean() + .unaryExpr(add_epslion); + + auto scale_inv_std = [scale_data](T ele) { + return std::sqrt(1 / ele) * scale_data; + }; + auto sub_bias = [bias_data](T ele) { return bias_data - ele; }; + + output_map = (var_map.unaryExpr(scale_inv_std).replicate(1, right)) + .cwiseProduct(input_map) + + var_map.unaryExpr(scale_inv_std) + .cwiseProduct(mean_map) + .unaryExpr(sub_bias) + .replicate(1, right); + } +}; + +class LayerNormGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + // check input + PADDLE_ENFORCE(ctx->HasInput("X")); + PADDLE_ENFORCE(ctx->HasInput("Scale"), ""); + PADDLE_ENFORCE(ctx->HasInput("Mean"), ""); + PADDLE_ENFORCE(ctx->HasInput("Variance"), ""); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), ""); + + const auto x_dims = ctx->GetInputDim("X"); + + // check output + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + if (ctx->HasOutput(framework::GradVarName("Scale"))) { + ctx->SetOutputDim(framework::GradVarName("Scale"), {1}); + } + if (ctx->HasOutput(framework::GradVarName("Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Bias"), {1}); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + const auto *var = ctx.InputVar(framework::GradVarName("Y")); + if (var == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + const Tensor *t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } + if (t == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + return framework::OpKernelType(framework::ToDataType(t->type()), + ctx.GetPlace()); + } +}; + +template +class LayerNormGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *x = ctx.Input("X"); + const auto *mean = ctx.Input("Mean"); + const auto *var = ctx.Input("Variance"); + const auto *scale = ctx.Input("Scale"); + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + + const auto &x_dims = x->dims(); + const int N = x_dims[0]; + const int sample_size = x->numel() / N; + int left = N, right = sample_size; + + auto scale_data = scale->data()[0]; + + // init output + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + auto x_map = ConstEigenMatrixMapRowMajor(x->data(), left, right); + auto d_y_map = ConstEigenMatrixMapRowMajor(d_y->data(), left, right); + auto mean_map = ConstEigenMatrixMapRowMajor(mean->data(), left, 1); + auto var_map = ConstEigenMatrixMapRowMajor(var->data(), left, 1); + + if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + d_bias->data()[0] = d_y_map.sum(); + } + if (d_scale) { + d_scale->mutable_data(ctx.GetPlace()); + auto inv_std = [](T ele) { return std::sqrt(1 / ele); }; + d_scale->data()[0] = + ((x_map - mean_map.replicate(1, right)) + .cwiseProduct(var_map.unaryExpr(inv_std).replicate(1, right)) + .cwiseProduct(d_y_map)) + .sum(); // also can use `y` to get d_scale_map + } + + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + auto d_x_map = EigenMatrixMapRowMajor(d_x->data(), left, right); + auto triple_product = [](T ele) { return ele * ele * ele; }; + auto neg_inv_std = [](T ele) { return T(-1.0) * std::sqrt(1 / ele); }; + auto inv_std_scale_func = [scale_data](T ele) { + return std::sqrt(1 / ele) * scale_data; + }; + auto neg_inv_std_scale_func = [scale_data](T ele) { + return T(-1.0) * std::sqrt(1 / ele) * scale_data; + }; + // dy_dx + auto dx_end = var_map.unaryExpr(inv_std_scale_func) + .replicate(1, right) + .cwiseProduct(d_y_map); + // dy_dmean_dx + auto dmean_end = var_map.unaryExpr(neg_inv_std_scale_func) + .replicate(1, right) + .cwiseProduct(d_y_map) + .rowwise() + .sum(); + auto dx_mean = (T(1.0) / right) * dmean_end.replicate(1, right); + // dy_var_dx + auto dvar_end_0 = (x_map - mean_map.replicate(1, right)) + .cwiseProduct(d_y_map) + .rowwise() + .sum(); + auto dvar_end = var_map.unaryExpr(neg_inv_std) + .unaryExpr(triple_product) + .cwiseProduct(dvar_end_0); + auto dx_var = (1.0f / right) * + (x_map - mean_map.replicate(1, right)) + .cwiseProduct(dvar_end.replicate(1, right)); + + d_x_map = dx_end + dx_mean + dx_var; + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker, + layer_norm_grad, ops::LayerNormGradOp); +REGISTER_OP_CPU_KERNEL( + layer_norm, + ops::LayerNormKernel); +REGISTER_OP_CPU_KERNEL( + layer_norm_grad, + ops::LayerNormGradKernel); diff --git a/paddle/operators/layer_norm_op.h b/paddle/operators/layer_norm_op.h new file mode 100644 index 0000000000..bca35b91e6 --- /dev/null +++ b/paddle/operators/layer_norm_op.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class LayerNormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override; +}; + +template +class LayerNormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_layer_norm_op.py b/python/paddle/v2/fluid/tests/test_layer_norm_op.py new file mode 100644 index 0000000000..73450c599d --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_layer_norm_op.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest + + +def layer_norm_naive(x, scale, beta, epsilon): + n, c, h, w = x.shape + mean = np.mean(x, axis=(1, 2, 3)) + var = np.var(x, axis=(1, 2, 3)) + epsilon + output = scale * np.divide((x - mean.reshape([n, 1, 1, 1])), + (np.sqrt(var)).reshape([n, 1, 1, 1])) + beta + return output, mean, var + + +class TestLayerNormdOp(OpTest): + def setUp(self): + self.init_test_case() + + input = np.random.random(self.input_size).astype("float32") + self.inputs = { + 'X': input, + 'Scale': np.array([self.scale]).astype("float32"), + 'Bias': np.array([self.bias]).astype("float32") + } + output, mean, var = layer_norm_naive(input, self.scale, self.bias, + self.epsilon) + self.outputs = {'Y': output, 'Mean': mean, 'Variance': var} + + def test_check_output(self): + self.check_output() + + # def test_check_grad(self): + # self.check_grad( + # ['Scale', 'Bias', 'X'], ['Y', 'Mean', 'Variance'], + # max_relative_error=0.02) + + def test_check_grad_no_x(self): + self.check_grad( + ['Scale', 'Bias'], ['Y', 'Mean', 'Variance'], + max_relative_error=0.02, + no_grad_set=set(['X'])) + + # def test_check_grad_no_scale(self): + # self.check_grad( + # ['Bias','X'], + # 'Y', + # max_relative_error=0.02, + # no_grad_set=set(['Scale'])) + # + # def test_check_grad_no_bias(self): + # self.check_grad( + # ['Scale','X'], + # 'Y', + # max_relative_error=0.02, + # no_grad_set=set(['Bias'])) + + def init_test_case(self): + self.op_type = "layer_norm" + self.input_size = [2, 3, 4, 5] + self.scale = 0.21 + self.bias = 0.1 + self.epsilon = 0.00001 + + +if __name__ == '__main__': + unittest.main() -- GitLab