/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/layer_norm_op.h"
#include <memory>

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;

class LayerNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Y"),
                   "Output(Y) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
                   "Output(Mean) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
                   "Output(Variance) of LayerNormOp should not be null.");

    auto x_dim = ctx->GetInputDim("X");
    auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
    PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
                      "'begin_norm_axis' must be less than the rank of X.");

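    // Flatten X into a 2-D matrix [left, right]: `left` is the product of the
    // dimensions before `begin_norm_axis` (the batch part) and `right` is the
    // product of the remaining dimensions (the feature size H that is
    // normalized).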
    auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);
    if (ctx->HasInput("Scale")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1);
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
    }
    if (ctx->HasInput("Bias")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1);
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
    }

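    // Y keeps X's shape; Mean and Variance hold one statistic per row of the
    // flattened [left, right] matrix, hence shape {left}.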
    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
    ctx->SetOutputDim("Mean", {left});
    ctx->SetOutputDim("Variance", {left});
    ctx->ShareLoD("X", "Y");
  }
};

class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input tensor.");
    AddInput("Scale",
             "(optional) Scale is a 1-dimensional tensor of size "
             "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
             "It is applied to the output.")
        .AsDispensable();
    AddInput("Bias",
             "(optional) Bias is a 1-dimensional tensor of size "
             "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
             "It is applied to the output.")
        .AsDispensable();
    AddOutput("Y", "Result after normalization.");
    AddOutput("Mean", "Mean of the current mini batch.").AsIntermediate();
    AddOutput("Variance", "Variance of the current mini batch.")
        .AsIntermediate();

    AddAttr<float>("epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &epsilon) {
          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
                         "'epsilon' should be between 0.0 and 0.001.");
        });
    AddAttr<int>("begin_norm_axis",
                 "the axis of `begin_norm_axis ... Rank(X) - 1` will be "
C
chengduoZH 已提交
91
                 "normalized. `begin_norm_axis` splits the tensor(`X`) to a "
Y
yuyang18 已提交
92
                 "matrix [N,H]. [default 1].")
        .SetDefault(1)
        .AddCustomChecker([](const int &begin_norm_axis) {
          PADDLE_ENFORCE_GT(begin_norm_axis, 0,
                            "'begin_norm_axis' should be greater than zero.");
        });

    AddComment(R"DOC(
Assume feature vectors exist on dimensions
:attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics
along these dimensions for each feature vector :math:`a` with size
:math:`H`, then normalize each feature vector using the corresponding
statistics. After that, apply learnable gain and bias on the normalized
tensor to scale and shift if :attr:`scale` and :attr:`shift` are set.
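
Formally, writing :math:`\mu` and :math:`\sigma^2` for the per-vector mean and
variance, :math:`g` and :math:`b` for the optional scale and bias, and
:math:`\epsilon` for the :attr:`epsilon` attribute, this follows the usual
layer-normalization form:

.. math::

    \mu = \frac{1}{H}\sum_{i=1}^{H} a_i, \qquad
    \sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (a_i - \mu)^2, \qquad
    y = g \odot \frac{a - \mu}{\sqrt{\sigma^2 + \epsilon}} + b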

Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
)DOC");
  }
};

class LayerNormGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    // check input
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Mean"),
                   "Input(Mean) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Variance"),
                   "Input(Variance) of LayerNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "Input(Y@GRAD) of LayerNormOp should not be null.");

    // check output
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
    if (ctx->HasOutput(framework::GradVarName("Scale"))) {
      ctx->SetOutputDim(framework::GradVarName("Scale"),
                        ctx->GetInputDim("Scale"));
    }
    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
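      // The grad op does not receive Bias as an input (see the grad op desc
      // maker below), so the Bias gradient's dim is taken from Scale, which
      // the forward op enforces to have the same 1-D size.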
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Scale"));
    }
  }

 protected:
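  // Pick the gradient kernel's data type from the Y@GRAD tensor, which may be
  // held either as a plain Tensor or as a LoDTensor.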
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto *var = ctx.InputVar(framework::GradVarName("Y"));
    if (var == nullptr) {
      PADDLE_THROW("can't find Y@GRAD");
    }
    const Tensor *t = nullptr;
    if (var->IsType<Tensor>()) {
      t = &var->Get<Tensor>();
    } else if (var->IsType<LoDTensor>()) {
      t = &var->Get<LoDTensor>();
    }
    if (t == nullptr) {
      PADDLE_THROW("can't find Y@GRAD");
    }
    return framework::OpKernelType(t->type(), ctx.GetPlace());
  }
};

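// Describes the backward op: layer_norm_grad takes X, Mean, Variance and
// Y@GRAD as inputs, and requests gradients only for the inputs the forward op
// actually had (Scale and Bias are optional).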
class LayerNormGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("layer_norm_grad");
    op->SetInput("X", Input("X"));
    op->SetInput("Mean", Output("Mean"));
    op->SetInput("Variance", Output("Variance"));
    if (ForwardOp().Inputs().count("Scale") > 0) {
      op->SetInput("Scale", Input("Scale"));
      op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
    }

    if (ForwardOp().Inputs().count("Bias") > 0) {
      op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
    }

    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
                  ops::LayerNormGradOpDescMaker);
REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp);
REGISTER_OP_CPU_KERNEL(
    layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);