lrn_op.cc 14.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
G
gongweibao 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
G
gongweibao 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
G
gongweibao 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
G
gongweibao 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/lrn_op.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
G
gongweibao 已提交
24 25 26 27 28

namespace paddle {
namespace operators {

using framework::Tensor;
29
using DataLayout = framework::DataLayout;
G
gongweibao 已提交
30

31
template <typename T>
Q
QI JUN 已提交
32
struct LRNFunctor<platform::CPUDeviceContext, T> {
33 34 35
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& input, framework::Tensor* out,
                  framework::Tensor* mid, int N, int C, int H, int W, int n,
36
                  T k, T alpha, T beta, const DataLayout data_layout) {
37 38
    auto place = ctx.GetPlace();
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
    math::Transpose<platform::CPUDeviceContext, T, 4> transpose;
    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
    Tensor in_transpose, mid_transpose, out_transpose;
    // if channel_last, transpose to channel_first
    if (data_layout == DataLayout::kNHWC) {
      auto in_dims = input.dims();
      std::vector<int64_t> shape(
          {in_dims[0], in_dims[3], in_dims[1], in_dims[2]});
      in_transpose.mutable_data<T>(framework::make_ddim(shape), place);
      mid_transpose.mutable_data<T>(framework::make_ddim(shape), place);
      out_transpose.mutable_data<T>(framework::make_ddim(shape), place);
      std::vector<int> axis = {0, 3, 1, 2};
      transpose(dev_ctx, input, &in_transpose, axis);
    } else {
      in_transpose = input;
      mid_transpose = *mid;
      out_transpose = *out;
      mid_transpose.mutable_data<T>(mid->dims(), place);
      out_transpose.mutable_data<T>(out->dims(), place);
    }

    const T* idata = in_transpose.data<T>();
    T* odata = out_transpose.data<T>();
    T* mdata = mid_transpose.data<T>();

64 65 66 67 68 69 70 71 72 73 74
    Tensor squared;
    T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
    std::memset(sdata, 0, sizeof(T) * squared.numel());
    for (int i = 0; i < mid->numel(); ++i) {
      mdata[i] = k;
    }
    int img_size = H * W;
    int fea_size = C * img_size;
    int pre_pad = (n - 1) / 2;
    // compute batches one by one
    for (int i = 0; i < N; ++i) {
T
tensor-tang 已提交
75
      blas.VSQUARE(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
      // init the first channel of mid
      for (int c = 0; c < n; ++c) {
        blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);
      }
      for (int c = 1; c < C; ++c) {
        // copy previous scale
        int mid_offset = i * fea_size + c * img_size;
        std::memcpy(mdata + mid_offset, mdata + mid_offset - img_size,
                    img_size * sizeof(T));
        // add last
        blas.AXPY(img_size, alpha, sdata + (c + n - 1) * img_size,
                  mdata + mid_offset);
        // sub rest
        blas.AXPY(img_size, -alpha, sdata + (c - 1) * img_size,
                  mdata + mid_offset);
91 92
      }
    }
93 94 95
    // compute the final output
    blas.VPOW(mid->numel(), mdata, -beta, odata);
    blas.VMUL(mid->numel(), odata, idata, odata);
96 97 98 99 100 101 102

    // if channel_last, transpose the output(NCHW) to channel_last
    if (data_layout == DataLayout::kNHWC) {
      std::vector<int> axis = {0, 2, 3, 1};
      transpose(dev_ctx, mid_transpose, mid, axis);
      transpose(dev_ctx, out_transpose, out, axis);
    }
103 104
  }
};
Q
QI JUN 已提交
105 106
template struct LRNFunctor<platform::CPUDeviceContext, float>;
template struct LRNFunctor<platform::CPUDeviceContext, double>;
107 108

template <typename T>
Q
QI JUN 已提交
109
struct LRNGradFunctor<platform::CPUDeviceContext, T> {
110 111 112 113
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& x, const framework::Tensor& out,
                  const framework::Tensor& mid, framework::Tensor* x_g,
                  const framework::Tensor& out_g, int N, int C, int H, int W,
114
                  int n, T alpha, T beta, const DataLayout data_layout) {
115 116 117 118 119 120 121 122 123 124 125 126 127 128
    T ratio = -2 * alpha * beta;
    auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
    x_g_e = x_g_e.constant(0.0);

    auto e_x = framework::EigenTensor<T, 4>::From(x);
    auto e_x_g = framework::EigenTensor<T, 4>::From(*x_g);
    auto e_out = framework::EigenTensor<T, 4>::From(out);
    auto e_out_g = framework::EigenTensor<T, 4>::From(out_g);
    auto e_mid = framework::EigenTensor<T, 4>::From(mid);

    const int start = -(n - 1) / 2;
    const int end = start + n;
    for (int m = 0; m < N; m++) {
      for (int i = 0; i < C; i++) {
129 130 131 132 133 134
        auto offsets = Eigen::array<int, 4>({{m, i, 0, 0}});
        auto extents = Eigen::array<int, 4>({{1, 1, H, W}});
        if (data_layout == DataLayout::kNHWC) {
          offsets = Eigen::array<int, 4>({{m, 0, 0, i}});
          extents = Eigen::array<int, 4>({{1, H, W, 1}});
        }
135

136 137 138 139
        auto i_x = e_x.slice(offsets, extents);
        auto i_x_g = e_x_g.slice(offsets, extents);
        auto i_out_g = e_out_g.slice(offsets, extents);
        auto i_mid = e_mid.slice(offsets, extents);
140 141

        i_x_g = i_mid.pow(-beta) * i_out_g;
Q
qingqing01 已提交
142
        for (int c = start; c < end; c++) {
143 144 145 146 147
          int ch = i + c;
          if (ch < 0 || ch >= C) {
            continue;
          }

148 149 150 151 152 153 154 155
          if (data_layout != DataLayout::kNHWC) {
            offsets = Eigen::array<int, 4>({{m, ch, 0, 0}});
          } else {
            offsets = Eigen::array<int, 4>({{m, 0, 0, ch}});
          }
          auto c_out = e_out.slice(offsets, extents);
          auto c_mid = e_mid.slice(offsets, extents);
          auto c_out_g = e_out_g.slice(offsets, extents);
156 157 158 159 160 161 162

          i_x_g += ratio * c_out_g * c_out * i_x / c_mid;
        }
      }
    }
  }
};
Q
QI JUN 已提交
163 164
template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
165

G
gongweibao 已提交
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
class LRNOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LRNOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of LRNOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("MidOut"),
                   "MidOut(Out) of LRNOp should not be null.");

    auto x_dim = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");

181 182 183
    int n = ctx->Attrs().Get<int>("n");
    PADDLE_ENFORCE(n > 0 && n % 2 == 1, "n should be positive odd value");

G
gongweibao 已提交
184 185
    ctx->SetOutputDim("Out", x_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
186
    ctx->SetOutputDim("MidOut", x_dim);
G
gongweibao 已提交
187
  }
T
Tomasz Patejko 已提交
188 189

  framework::OpKernelType GetExpectedKernelType(
190
      const framework::ExecutionContext& ctx) const override {
191 192
    framework::LibraryType library_{framework::LibraryType::kPlain};
    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
193
    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
194 195 196 197 198 199 200 201 202 203
#ifdef PADDLE_WITH_MKLDNN
    if (library_ == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library_ = framework::LibraryType::kMKLDNN;
      layout_ = framework::DataLayout::kMKLDNN;
    }
#endif
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
        layout_, library_);
T
Tomasz Patejko 已提交
204
  }
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
      auto attrs = Attrs();
      auto ar = paddle::framework::AttrReader(attrs);
      const std::string data_format = ar.Get<std::string>("data_format");
      auto dl = framework::StringToDataLayout(data_format);
      // Some models may have intentionally set "AnyLayout" for pool
      // op. Treat this as NCHW (default data_format value)
      if (dl != framework::DataLayout::kAnyLayout) {
        return framework::OpKernelType(expected_kernel_type.data_type_,
                                       tensor.place(), dl);
      }
    }
#endif
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
G
gongweibao 已提交
227 228 229 230 231
};

template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the inputs, outputs and attributes of the lrn operator.
  // T is the type of the floating-point attributes k, alpha and beta.
  void Make() override {
    AddInput("X",
             "(Tensor) The input of LRN operator. "
             "It must be a 4D tensor in NCHW or NHWC format, as given by "
             "data_format.");
    AddOutput("Out",
              "(Tensor) The output of LRN operator, a 4D tensor with the "
              "same shape and format as Input(X).");
    AddOutput("MidOut",
              "(Tensor) Middle result of LRN operator. It's computed in "
              "forward process and also used in backward process.");

    AddAttr<int>("n",
                 "(int default 5) "
                 "n is the \"adjacent\" kernel that maps "
                 "at the same spatial position.")
        .SetDefault(5)
        .GreaterThan(0);

    AddAttr<T>("k",
               "(float, default 2.0) "
               "k is the bias.")
        .SetDefault(2.0)
        .GreaterThan(0.0);

    AddAttr<T>("alpha",
               "(float, default 0.0001) "
               "alpha is the scale number.")
        .SetDefault(0.0001)
        .GreaterThan(0.0);

    AddAttr<T>("beta",
               "(float, default 0.75) "
               "beta is the power number.")
        .SetDefault(0.75)
        .GreaterThan(0.0);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<std::string>(
        "data_format",
        "(string, default \"AnyLayout\") An optional string from: "
        "\"NHWC\", \"NCHW\", \"AnyLayout\". Specify the data format of the "
        "input and output data; \"AnyLayout\" is handled as \"NCHW\".")
        .SetDefault("AnyLayout");
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);

    AddComment(R"DOC(
Local Response Normalization Operator.

This operator comes from the paper:
<<ImageNet Classification with Deep Convolutional Neural Networks>>.

The original formula is:

$$
Output(i, x, y) = Input(i, x, y) / \left(
k + \alpha \sum\limits^{\min(C-1, i + n/2)}_{j = \max(0, i - n/2)}
(Input(j, x, y))^2
\right)^{\beta}
$$

Function implementation:

Inputs and outputs are in NCHW or NHWC format, while input.shape.ndims() equals 4.
If NCHW, the dimensions 0 ~ 3 represent batch size, feature maps, rows,
and columns, respectively. If NHWC, they represent batch size, rows, columns,
and feature maps, respectively.

Input and Output in the formula above is for each map(i) of one image, and
Input(i, x, y), Output(i, x, y) represents an element in an image.

C is the number of feature maps of one image. n is a hyper-parameter
configured when operator is initialized. The sum in the denominator
is the sum of the same positions in the neighboring maps.

)DOC");
  }
};

class LRNOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
321
    PADDLE_ENFORCE(ctx->HasInput("MidOut"), "Input(MidOut) should not be null");
G
gongweibao 已提交
322 323 324 325 326 327 328
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");

    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
  }

T
Tomasz Patejko 已提交
329
  framework::OpKernelType GetExpectedKernelType(
330
      const framework::ExecutionContext& ctx) const override {
331 332
    framework::LibraryType library_{framework::LibraryType::kPlain};
    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
333
    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
334 335 336 337 338 339 340 341 342 343
#ifdef PADDLE_WITH_MKLDNN
    if (library_ == framework::LibraryType::kPlain &&
        platform::CanMKLDNNBeUsed(ctx)) {
      library_ = framework::LibraryType::kMKLDNN;
      layout_ = framework::DataLayout::kMKLDNN;
    }
#endif
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
        layout_, library_);
T
Tomasz Patejko 已提交
344
  }
345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
      auto attrs = Attrs();
      auto ar = paddle::framework::AttrReader(attrs);
      const std::string data_format = ar.Get<std::string>("data_format");
      auto dl = framework::StringToDataLayout(data_format);
      // Some models may have intentionally set "AnyLayout" for lrn
      // op. Treat this as NCHW (default data_format value)
      if (dl != framework::DataLayout::kAnyLayout) {
        return framework::OpKernelType(expected_kernel_type.data_type_,
                                       tensor.place(), dl);
      }
    }
#endif
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
T
Tomasz Patejko 已提交
367
};
368 369 370 371 372

template <typename T>
class LRNGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  // Describes the "<forward type>_grad" op. It reads the forward input X, the
  // forward outputs Out and MidOut, and the upstream gradient Out@GRAD. It
  // writes X@GRAD and copies every forward attribute.
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");

    // Forward tensors reused by the backward computation.
    op->SetInput("X", this->Input("X"));
    op->SetInput("Out", this->Output("Out"));
    op->SetInput("MidOut", this->Output("MidOut"));

    // Gradient flowing in, gradient flowing out.
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");
    op->SetInput(out_grad_name, this->OutputGrad("Out"));
    op->SetOutput(x_grad_name, this->InputGrad("X"));

    op->SetAttrMap(this->Attrs());
  }
};

G
gongweibao 已提交
384 385 386 387
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
388 389 390
// Forward op "lrn": LRNOp with the float-attribute proto maker. Its gradient
// op description is produced by LRNGradOpMaker, instantiated for both
// paddle::framework::OpDesc and paddle::imperative::OpBase.
REGISTER_OPERATOR(lrn, ops::LRNOp, ops::LRNOpMaker<float>,
                  ops::LRNGradOpMaker<paddle::framework::OpDesc>,
                  ops::LRNGradOpMaker<paddle::imperative::OpBase>);
H
hong 已提交
391

392
// Backward op "lrn_grad"; it is registered with only its operator class.
REGISTER_OPERATOR(lrn_grad, ops::LRNOpGrad);
Q
QI JUN 已提交
393 394 395 396
// CPU kernels for lrn / lrn_grad. Only float is registered here, although
// the CPU LRNFunctor / LRNGradFunctor are also instantiated for double in
// this file.
REGISTER_OP_CPU_KERNEL(
    lrn, ops::LRNKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    lrn_grad, ops::LRNGradKernel<paddle::platform::CPUDeviceContext, float>);