/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/lrn_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;

template <typename T>
Q
QI JUN 已提交
23
struct LRNFunctor<platform::CPUDeviceContext, T> {
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& input, framework::Tensor* out,
                  framework::Tensor* mid, int N, int C, int H, int W, int n,
                  T k, T alpha, T beta) {
    auto x_v = framework::EigenVector<T>::Flatten(input);

    const int start = -(n - 1) / 2;
    const int end = start + n;

    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
    e_mid = e_mid.constant(k);

    auto e_x = framework::EigenTensor<T, 4>::From(input);
    for (int m = 0; m < N; m++) {
      for (int i = 0; i < C; i++) {
        for (int c = start; c <= end; c++) {
          int ch = i + c;
          if (ch >= 0 && ch < C) {
            auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                 Eigen::array<int, 4>({{1, 1, H, W}}));

            auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                               Eigen::array<int, 4>({{1, 1, H, W}}));

            s += alpha * r.square();
          }
        }
      }
    }

    auto out_e = framework::EigenVector<T>::Flatten(*out);
    out_e = x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
  }
};
Q
QI JUN 已提交
58 59
template struct LRNFunctor<platform::CPUDeviceContext, float>;
template struct LRNFunctor<platform::CPUDeviceContext, double>;
60 61

template <typename T>
Q
QI JUN 已提交
62
struct LRNGradFunctor<platform::CPUDeviceContext, T> {
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& x, const framework::Tensor& out,
                  const framework::Tensor& mid, framework::Tensor* x_g,
                  const framework::Tensor& out_g, int N, int C, int H, int W,
                  int n, T alpha, T beta) {
    T ratio = -2 * alpha * beta;
    auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
    x_g_e = x_g_e.constant(0.0);

    auto e_x = framework::EigenTensor<T, 4>::From(x);
    auto e_x_g = framework::EigenTensor<T, 4>::From(*x_g);
    auto e_out = framework::EigenTensor<T, 4>::From(out);
    auto e_out_g = framework::EigenTensor<T, 4>::From(out_g);
    auto e_mid = framework::EigenTensor<T, 4>::From(mid);

    const int start = -(n - 1) / 2;
    const int end = start + n;
    for (int m = 0; m < N; m++) {
      for (int i = 0; i < C; i++) {
        auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                             Eigen::array<int, 4>({{1, 1, H, W}}));

        auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                 Eigen::array<int, 4>({{1, 1, H, W}}));

        auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                     Eigen::array<int, 4>({{1, 1, H, W}}));

        auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                 Eigen::array<int, 4>({{1, 1, H, W}}));

        i_x_g = i_mid.pow(-beta) * i_out_g;
        for (int c = start; c <= end; c++) {
          int ch = i + c;
          if (ch < 0 || ch >= C) {
            continue;
          }

          auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                                   Eigen::array<int, 4>({{1, 1, H, W}}));

          auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                                   Eigen::array<int, 4>({{1, 1, H, W}}));

          auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                                       Eigen::array<int, 4>({{1, 1, H, W}}));

          i_x_g += ratio * c_out_g * c_out * i_x / c_mid;
        }
      }
    }
  }
};
Q
QI JUN 已提交
116 117
template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
118

G
gongweibao 已提交
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
// Forward LRN operator definition (shape inference only; kernels are
// registered separately).
class LRNOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Requires input X and outputs Out and MidOut, and requires X to be rank 4.
  // Out and MidOut both take X's shape; X's LoD is shared with Out.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LRNOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of LRNOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("MidOut"),
                   "Output(MidOut) of LRNOp should not be null.");

    auto x_dim = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(x_dim.size(), 4,
                      "Input(X)'s rank of LRNOp should be 4.");

    ctx->SetOutputDim("Out", x_dim);
    ctx->SetOutputDim("MidOut", x_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

// Declares the inputs, outputs, attributes and user documentation of the
// `lrn` operator. T is the type of the floating-point attributes k, alpha
// and beta.
template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LRNOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor) The input of LRN operator. "
             "It must be a 4D tensor with NCHW format.");
    AddOutput("Out",
              "(Tensor) The output of LRN operator, which is also the 4D "
              "tensor with NCHW format.");
    AddOutput("MidOut",
              "(Tensor) Middle result of LRN operator. It's computed in "
              "forward process and also used in backward process.");

    AddAttr<int>("n",
                 "(int, default 5) "
                 "n is the number of adjacent channels (feature maps) at the "
                 "same spatial position that are summed over.")
        .SetDefault(5)
        .GreaterThan(0);

    AddAttr<T>("k",
               "(float, default 2.0) "
               "k is the bias.")
        .SetDefault(2.0)
        .GreaterThan(0.0);

    AddAttr<T>("alpha",
               "(float, default 0.0001) "
               "alpha is the scale number.")
        .SetDefault(0.0001)
        .GreaterThan(0.0);

    AddAttr<T>("beta",
               "(float, default 0.75) "
               "beta is the power number.")
        .SetDefault(0.75)
        .GreaterThan(0.0);

    AddComment(R"DOC(
Local Response Normalization Operator.

This operator comes from the paper:
<<ImageNet Classification with Deep Convolutional Neural Networks>>.

The original formula is:

$$
Output(i, x, y) = Input(i, x, y) / \left(
k + \alpha \sum\limits^{\min(C - 1, i + n/2)}_{j = \max(0, i - n/2)}
(Input(j, x, y))^2
\right)^{\beta}
$$

Function implementation:

Inputs and outputs are in NCHW format, while input.shape.ndims() equals 4.
And dimensions 0 ~ 3 represent batch size, feature maps, rows,
and columns, respectively.

Input and Output in the formula above are for each map(i) of one image, and
Input(i, x, y), Output(i, x, y) represents an element in an image.

C is the number of feature maps of one image. n is a hyper-parameter
configured when operator is initialized. The sum in the denominator
is the sum of the same positions in the neighboring maps.

)DOC");
  }
};

// Gradient operator of `lrn`: shape inference only.
class LRNOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // The backward pass consumes the forward input X, the cached MidOut and the
  // gradient of Out. X@GRAD is given the same dims as X.
  void InferShape(framework::InferShapeContext* ctx) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("MidOut"), "Input(MidOut) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(out_grad_name),
                   "Input(Out@GRAD) should not be null");

    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(lrn, ops::LRNOp, ops::LRNOpMaker<float>, lrn_grad, ops::LRNOpGrad);
Q
QI JUN 已提交
233 234 235 236
REGISTER_OP_CPU_KERNEL(
    lrn, ops::LRNKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    lrn_grad, ops::LRNGradKernel<paddle::platform::CPUDeviceContext, float>);