/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/batch_norm_op.h"
Q
qingqing01 已提交
16
#include <memory>
S
Siddharth Goyal 已提交
17
#include <string>
Q
qingqing01 已提交
18
#include <unordered_map>
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/data_layout.h"
20 21 22
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
Q
Qiao Longfei 已提交
23 24 25 26

namespace paddle {
namespace operators {

Q
qingqing01 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
// Validates the inputs/outputs of the batch_norm op and infers output shapes.
//
// Checks, in order:
//   * X, Scale, Bias, Mean, Variance and Y are present; when `is_test` is
//     false, MeanOut, VarianceOut, SavedMean and SavedVariance as well.
//   * Mean/MeanOut and Variance/VarianceOut refer to the same variable name.
//   * At runtime, MomentumTensor (if given) is a single variable.
//   * X has between 2 and 5 dimensions.
//   * Scale and Bias are 1-D with C elements, where C is taken from X
//     according to `data_layout` (dim 1 for NCHW, last dim otherwise).
//
// Outputs: Y gets the shape (and LoD) of X; the four statistics outputs get
// shape {C}.
void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"),
                 "Input(X) of BatchNormOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Scale"),
                 "Input(Scale) of BatchNormOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Bias"),
                 "Input(Bias) of BatchNormOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Mean"),
                 "Input(Mean) of BatchNormOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Variance"),
                 "Input(Variance) of BatchNormOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Y"),
                 "Output(Y) of BatchNormOp should not be null.");
  bool is_test = ctx->Attrs().Get<bool>("is_test");
  if (!is_test) {
    PADDLE_ENFORCE(ctx->HasOutput("MeanOut"),
                   "Output(MeanOut) of BatchNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("VarianceOut"),
                   "Output(VarianceOut) of BatchNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("SavedMean"),
                   "Output(SavedMean) of BatchNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"),
                   "Output(SavedVariance) of BatchNormOp should not be null.");
  }

  // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
  // NOTE(review): this indexes Outputs("MeanOut")[0] even when is_test is
  // true, i.e. without the presence check above - confirm callers always
  // provide these outputs.
  PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0],
                    "Mean and MeanOut should share the same memory");
  PADDLE_ENFORCE_EQ(ctx->Inputs("Variance")[0], ctx->Outputs("VarianceOut")[0],
                    "Variance and VarianceOut should share the same memory");

  const auto x_dims = ctx->GetInputDim("X");
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));

  if (ctx->IsRuntime() && ctx->HasInput("MomentumTensor")) {
    auto mom = ctx->Inputs("MomentumTensor");
    PADDLE_ENFORCE_EQ(mom.size(), 1,
                      platform::errors::InvalidArgument(
                          "Input(MomentumTensor) size must be 1"));
  }

  PADDLE_ENFORCE_GE(
      x_dims.size(), 2,
      "ShapeError: the dimension of input X must greater than or equal to 2. "
      "But received: the shape of input X = [%s], the dimension of input X = "
      "[%d]",
      x_dims, x_dims.size());
  PADDLE_ENFORCE_LE(
      x_dims.size(), 5,
      "ShapeError: the dimension of input X must smaller than or equal to 5. "
      "But received: the shape of input X = [%s], the dimension of input X = "
      "[%d]",
      x_dims, x_dims.size());

  // Channel count: dim 1 for NCHW, last dim for every other layout.
  const int64_t C =
      (data_layout == DataLayout::kNCHW ? x_dims[1]
                                        : x_dims[x_dims.size() - 1]);

  auto scale_dim = ctx->GetInputDim("Scale");
  auto bias_dim = ctx->GetInputDim("Bias");

  PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL,
                    "ShapeError: the dimension of scale must equal to 1. "
                    "But received: the shape of scale is [%s], the dimension "
                    "of scale is [%d]",
                    scale_dim, scale_dim.size());
  PADDLE_ENFORCE_EQ(
      bias_dim.size(), 1UL,
      "ShapeError: the dimension of bias must equal to 1. "
      "But received: the shape of bias is [%s], the dimension of bias is [%d]",
      bias_dim, bias_dim.size());

  // Outside of runtime a non-positive element count in Scale/Bias skips the
  // size comparison below instead of failing it.
  bool check = true;
  if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
                              framework::product(bias_dim) <= 0)) {
    check = false;
  }

  if (check) {
    PADDLE_ENFORCE_EQ(scale_dim[0], C,
                      "ShapeError: the shape of scale must equal to [%d]. "
                      "But received: the shape of scale is [%d]",
                      C, scale_dim[0]);
    PADDLE_ENFORCE_EQ(bias_dim[0], C,
                      "ShapeError: the shape of bias must equal to [%d]. "
                      "But received: the shape of bias is [%d]",
                      C, bias_dim[0]);
  }
  ctx->SetOutputDim("Y", x_dims);
  ctx->SetOutputDim("MeanOut", {C});
  ctx->SetOutputDim("VarianceOut", {C});
  ctx->SetOutputDim("SavedMean", {C});
  ctx->SetOutputDim("SavedVariance", {C});
  ctx->ShareLoD("X", "Y");
}

// Selects the kernel type for the forward op.
//
// The kernel data type is the data type of X. Scale, Bias, Mean and Variance
// must all be FP64 when X is FP64 and FP32 otherwise; a mismatch is reported
// through PADDLE_ENFORCE_EQ. When built with PADDLE_WITH_MKLDNN and
// platform::CanMKLDNNBeUsed(ctx) holds, the MKLDNN library and layout are
// chosen; otherwise the plain library with kAnyLayout is used.
framework::OpKernelType BatchNormOp::GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const {
  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
  // By default, the scale, bias, mean and var tensors should all be float
  // (for a float or float16 input tensor) or double (for a double input
  // tensor).
  auto bn_param_type = framework::proto::VarType::FP32;
  if (input_data_type == framework::proto::VarType::FP64) {
    bn_param_type = framework::proto::VarType::FP64;
  }
  PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
                    "Scale input should be of float type");
  PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
                    "Bias input should be of float type");
  PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
                    "Mean input should be of float type");
  PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
                    "Variance input should be of float type");

  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
  if (library == framework::LibraryType::kPlain &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif

  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
                                 library);
}

void BatchNormOpMaker::Make() {
  AddAttr<bool>("is_test",
                "(bool, default false) Set to true for inference only, false "
                "for training. Some layers may run faster when this is true.")
      .SetDefault(false);
  AddAttr<float>("momentum", "").SetDefault(0.9);
  AddAttr<float>("epsilon", "")
      .SetDefault(1e-5)
      .AddCustomChecker([](const float &epsilon) {
        PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
                       "'epsilon' should be between 0.0 and 0.001.");
      });
  AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
  AddInput("X", "The input tensor");
  AddInput("Scale",
           "Scale is a 1-dimensional tensor of size C "
           "that is applied to the output");
  AddInput("Bias",
           "Bias is a 1-dimensional tensor of size C "
           "that is applied to the output");
  AddInput("Mean",
           "The global mean (for training) or "
           "estimated mean (for testing)");
  AddInput("Variance",
           "The global variance (for training) "
           "or estimated Variance (for testing)");
183 184 185 186 187
  AddInput("MomentumTensor",
           "(Tensor<float32>, optional) If provided, batch_norm will "
           "use this as momentum, this has a higher priority than "
           "attr(momentum), the shape of this tensor MUST BE [1].")
      .AsDispensable();
Q
qingqing01 已提交
188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
  AddOutput("Y", "result after normalization");
  AddOutput("MeanOut",
            "Share memory with Mean. "
            "Store the global mean when training");
  AddOutput("VarianceOut",
            "Share memory with Variance. "
            "Store the global Variance when training");
  AddOutput("SavedMean",
            "Mean of the current mini batch, "
            "will apply to output when training")
      .AsIntermediate();
  AddOutput("SavedVariance",
            "Variance of the current mini batch, "
            "will apply to output when training")
      .AsIntermediate();
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
  AddAttr<bool>("fuse_with_relu",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
  AddAttr<bool>("use_global_stats",
                "(bool, default false) Whether to use global mean and "
                "variance. In inference or test mode, set use_global_stats "
                "to true or is_test true. the behavior is equivalent. "
                "In train mode, when setting use_global_stats True, the "
                "global mean and variance are also used during train time, "
                "the BN acts as scaling and shiffting.")
      .SetDefault(false);
  AddComment(R"DOC(
218
Batch Normalization.
Q
Qiao Longfei 已提交
219

220 221 222 223 224 225
Batch Norm has been implemented as discussed in the paper:
https://arxiv.org/pdf/1502.03167.pdf
Can be used as a normalizer function for conv2d and fully_connected operations.
The required data format for this layer is one of the following:
1. NHWC `[batch, in_height, in_width, in_channels]`
2. NCHW `[batch, in_channels, in_height, in_width]`
Q
Qiao Longfei 已提交
226 227

)DOC");
Q
qingqing01 已提交
228
}
C
chengduo 已提交
229

Q
Qiao Longfei 已提交
230
template <typename T>
Q
QI JUN 已提交
231 232
class BatchNormKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
Q
Qiao Longfei 已提交
233 234 235
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
236
    float momentum = ctx.Attr<float>("momentum");
Q
Qiao Longfei 已提交
237
    const bool is_test = ctx.Attr<bool>("is_test");
238 239 240 241
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");

    bool global_stats = is_test || use_global_stats;

Q
QI JUN 已提交
242 243 244
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
Q
Qiao Longfei 已提交
245 246 247

    const auto *x = ctx.Input<Tensor>("X");
    const auto &x_dims = x->dims();
248 249
    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
                   "The Input dim size should be between 2 and 5");
Q
Qiao Longfei 已提交
250 251
    const int N = x_dims[0];
    const int C =
Q
QI JUN 已提交
252 253
        (data_layout == DataLayout::kNCHW ? x_dims[1]
                                          : x_dims[x_dims.size() - 1]);
Q
Qiao Longfei 已提交
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
    const int sample_size = x->numel() / N / C;

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean_out = ctx.Output<Tensor>("MeanOut");
    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
    auto *saved_mean = ctx.Output<Tensor>("SavedMean");
    auto *saved_variance = ctx.Output<Tensor>("SavedVariance");

    // alloc memory
    y->mutable_data<T>(ctx.GetPlace());
    mean_out->mutable_data<T>(ctx.GetPlace());
    variance_out->mutable_data<T>(ctx.GetPlace());
    saved_mean->mutable_data<T>(ctx.GetPlace());
    saved_variance->mutable_data<T>(ctx.GetPlace());

269
    if (!global_stats) {
Q
Qiao Longfei 已提交
270 271 272 273 274 275 276 277
      // saved_xx is use just in this batch of data
      EigenVectorArrayMap<T> saved_mean_e(
          saved_mean->mutable_data<T>(ctx.GetPlace()), C);
      EigenVectorArrayMap<T> saved_variance_e(
          saved_variance->mutable_data<T>(ctx.GetPlace()), C);
      saved_mean_e.setZero();
      saved_variance_e.setZero();

278 279 280 281 282 283
      EigenVectorArrayMap<T> running_mean_arr(
          mean_out->mutable_data<T>(ctx.GetPlace()), C);
      EigenVectorArrayMap<T> running_var_arr(
          variance_out->mutable_data<T>(ctx.GetPlace()), C);

      if ((N * sample_size) == 1) {
284 285
        // Only 1 element in normalization dimension,
        // we skip the batch norm calculation, let y = x.
286
        framework::TensorCopy(*x, ctx.GetPlace(), y);
287 288 289
        return;
      }

Q
QI JUN 已提交
290 291
      switch (data_layout) {
        case DataLayout::kNCHW: {
Q
Qiao Longfei 已提交
292 293 294 295 296 297 298 299 300 301 302 303
          ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
          for (int nc = 0; nc < N * C; ++nc) {
            saved_mean_e(nc % C) += x_arr.col(nc).sum();
          }
          saved_mean_e /= N * sample_size;
          for (int nc = 0; nc < N * C; ++nc) {
            saved_variance_e(nc % C) +=
                (x_arr.col(nc) - saved_mean_e(nc % C)).matrix().squaredNorm();
          }
          saved_variance_e /= N * sample_size;
          break;
        }
Q
QI JUN 已提交
304
        case DataLayout::kNHWC: {
Q
Qiao Longfei 已提交
305 306 307 308 309 310 311 312 313 314 315 316 317
          ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
          for (int i = 0; i < N * sample_size; ++i) {
            saved_mean_e += x_arr.col(i);
          }
          saved_mean_e /= N * sample_size;
          for (int i = 0; i < N * sample_size; ++i) {
            saved_variance_e +=
                (x_arr.col(i) - saved_mean_e) * (x_arr.col(i) - saved_mean_e);
          }
          saved_variance_e /= N * sample_size;
          break;
        }
        default:
Q
QI JUN 已提交
318
          PADDLE_THROW("Unknown storage order: %s", data_layout_str);
Q
Qiao Longfei 已提交
319 320
      }

321 322 323 324 325 326 327
      // if MomentumTensor is set, use MomentumTensor value, momentum
      // is only used in this training branch
      if (ctx.HasInput("MomentumTensor")) {
        const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
        momentum = mom_tensor->data<float>()[0];
      }

Q
Qiao Longfei 已提交
328 329 330 331 332 333 334 335
      running_mean_arr =
          running_mean_arr * momentum + saved_mean_e * (1. - momentum);
      running_var_arr =
          running_var_arr * momentum + saved_variance_e * (1. - momentum);
    }

    // use SavedMean and SavedVariance to do normalize
    Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
336
    if (global_stats) {
Q
Qiao Longfei 已提交
337 338 339 340 341 342 343 344 345 346 347
      ConstEigenVectorArrayMap<T> var_arr(
          ctx.Input<Tensor>("Variance")->data<T>(), C);
      inv_std = (var_arr + epsilon).sqrt().inverse();
    } else {
      EigenVectorArrayMap<T> saved_inv_std(
          ctx.Output<Tensor>("SavedVariance")->data<T>(), C);
      // inverse SavedVariance first, gradient will use it too.
      saved_inv_std = (saved_inv_std + epsilon).inverse().sqrt();
      inv_std = saved_inv_std;
    }
    ConstEigenVectorArrayMap<T> mean_arr(
348 349
        global_stats ? ctx.Input<Tensor>("Mean")->data<T>()
                     : ctx.Output<Tensor>("SavedMean")->data<T>(),
Q
Qiao Longfei 已提交
350 351 352 353 354 355 356 357 358 359 360 361 362
        C);

    //   ((x - est_mean) * (inv_var) * scale + bias
    //   formula transform ====>
    //   (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");
    ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
    ConstEigenVectorArrayMap<T> bias_arr(bias->data<T>(), C);
    Eigen::Array<T, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr;
    Eigen::Array<T, Eigen::Dynamic, 1> new_bias =
        bias_arr - mean_arr * inv_std * scale_arr;

Q
QI JUN 已提交
363 364
    switch (data_layout) {
      case DataLayout::kNCHW: {
Q
Qiao Longfei 已提交
365 366 367 368 369 370 371 372
        EigenArrayMap<T> y_arr(y->mutable_data<T>(ctx.GetPlace()), sample_size,
                               N * C);
        ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
        for (int nc = 0; nc < N * C; ++nc) {
          y_arr.col(nc) = x_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
        }
        break;
      }
Q
QI JUN 已提交
373
      case DataLayout::kNHWC: {
Q
Qiao Longfei 已提交
374 375 376 377 378 379 380 381 382
        EigenArrayMap<T>(y->mutable_data<T>(ctx.GetPlace()), C,
                         N * sample_size) =
            (ConstEigenArrayMap<T>(x->data<T>(), C, N * sample_size).colwise() *
             new_scale)
                .colwise() +
            new_bias;
        break;
      }
      default:
Q
QI JUN 已提交
383
        PADDLE_THROW("Unknown storage order: %d", data_layout);
Q
Qiao Longfei 已提交
384 385 386 387
    }
  }
};

Q
qingqing01 已提交
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411
// Validates the inputs/outputs of batch_norm_grad and infers output shapes.
//
// Requires X, Scale, Y@GRAD, SavedMean and SavedVariance as inputs and X@GRAD
// as output. If Scale@GRAD is an output, Bias@GRAD must be one as well.
// `use_global_stats` together with `use_mkldnn` is rejected. X@GRAD gets the
// shape of X; Scale@GRAD and Bias@GRAD (when present) get shape {C}, with C
// read from X according to `data_layout`.
void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
  // check input
  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Scale"), "Input(Scale) should not be null.");
  PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                 "Input(Y@GRAD) should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("SavedMean"),
                 "Input(SavedMean) should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("SavedVariance"),
                 "Input(SavedVariance) should not be null");

  // check output
  PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                 "Output(X@GRAD) should not be null.");
  if (ctx->HasOutput(framework::GradVarName("Scale"))) {
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
                   "Output(Scale@GRAD) and Output(Bias@GRAD) should not be "
                   "null at same time");
  }
  const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
  if (use_global_stats) {
    PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"),
                   "Using global stats during training is not supported "
                   "in gradient op kernel of batch_norm_mkldnn_op now.");
  }

  const auto x_dims = ctx->GetInputDim("X");
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));
  // int64_t, matching the element type used by the forward InferShape.
  const int64_t C =
      (data_layout == DataLayout::kNCHW ? x_dims[1]
                                        : x_dims[x_dims.size() - 1]);

  ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
  if (ctx->HasOutput(framework::GradVarName("Scale"))) {
    ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
    ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
  }
}
Q
Qiao Longfei 已提交
425

Q
qingqing01 已提交
426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
// Selects the kernel type for the gradient op.
//
// Fails if the Y@GRAD variable is missing or holds neither a Tensor nor a
// LoDTensor. The kernel data type is the data type of X. When built with
// PADDLE_WITH_MKLDNN and platform::CanMKLDNNBeUsed(ctx) holds, the MKLDNN
// library and layout are chosen; otherwise the plain library with kAnyLayout.
framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const {
  const auto *var = ctx.InputVar(framework::GradVarName("Y"));
  if (var == nullptr) {
    PADDLE_THROW("can't find Y@GRAD");
  }
  // `t` is only used to verify that the variable holds a supported type.
  const Tensor *t = nullptr;
  if (var->IsType<Tensor>()) {
    t = &var->Get<Tensor>();
  } else if (var->IsType<LoDTensor>()) {
    t = &var->Get<LoDTensor>();
  }
  if (t == nullptr) {
    // The variable exists here; it just holds an unsupported type.
    PADDLE_THROW("Y@GRAD is neither a Tensor nor a LoDTensor");
  }

  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;

#ifdef PADDLE_WITH_MKLDNN
  if (library == framework::LibraryType::kPlain &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif

  return framework::OpKernelType(
      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
      library);
}
Q
Qiao Longfei 已提交
458 459

template <typename T>
Q
QI JUN 已提交
460
class BatchNormGradKernel<platform::CPUDeviceContext, T>
Q
Qiao Longfei 已提交
461 462 463 464 465 466 467 468 469
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
    // SavedVariance have been reverted in forward operator
    const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
Q
QI JUN 已提交
470
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
471 472
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
    const float epsilon = ctx.Attr<float>("epsilon");
Q
QI JUN 已提交
473 474
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
Q
Qiao Longfei 已提交
475 476 477 478

    // Get the size for each dimension.
    // NCHW [batch_size, in_channels, in_height, in_width]
    const auto &x_dims = x->dims();
479 480
    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
                   "The Input dim size should be between 2 and 5");
Q
Qiao Longfei 已提交
481 482
    const int N = x_dims[0];
    const int C =
Q
QI JUN 已提交
483 484
        (data_layout == DataLayout::kNCHW ? x_dims[1]
                                          : x_dims[x_dims.size() - 1]);
Q
Qiao Longfei 已提交
485 486 487 488 489 490 491 492
    const int sample_size = x->numel() / N / C;

    // init output
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    d_x->mutable_data<T>(ctx.GetPlace());
493 494 495 496 497 498 499 500

    const T *mean_data = saved_mean->data<T>();
    const T *inv_var_data = saved_inv_variance->data<T>();
    Tensor inv_var_tensor;
    if (use_global_stats) {
      const auto *running_mean = ctx.Input<Tensor>("Mean");
      const auto *running_variance = ctx.Input<Tensor>("Variance");
      mean_data = running_mean->data<T>();
Z
Zeng Jinle 已提交
501
      inv_var_tensor.Resize({C});
502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521
      T *running_inv_var_data = inv_var_tensor.mutable_data<T>(ctx.GetPlace());
      EigenVectorArrayMap<T> inv_var_tmp(running_inv_var_data, C);
      ConstEigenVectorArrayMap<T> var_arr(running_variance->data<T>(), C);

      inv_var_tmp = (var_arr + epsilon).sqrt().inverse().eval();
      inv_var_data = running_inv_var_data;
    }

    ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
    ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
    ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);

    T *d_bias_data = nullptr;
    T *d_scale_data = nullptr;
    if (d_scale && d_bias) {
      d_scale->mutable_data<T>(ctx.GetPlace());
      d_bias->mutable_data<T>(ctx.GetPlace());
      d_bias_data = d_bias->mutable_data<T>(ctx.GetPlace());
      d_scale_data = d_scale->mutable_data<T>(ctx.GetPlace());
    }
Q
Qiao Longfei 已提交
522 523 524 525 526

    // d_bias = np.sum(d_y, axis=0)
    // d_scale = np.sum((X - mean) / inv_std * dy, axis=0)
    // d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
    //   - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
527 528
    EigenVectorArrayMap<T> d_bias_arr(d_bias_data, C);
    EigenVectorArrayMap<T> d_scale_arr(d_scale_data, C);
Q
Qiao Longfei 已提交
529

530 531 532 533
    if (d_scale && d_bias) {
      d_bias_arr.setZero();
      d_scale_arr.setZero();
    }
Q
Qiao Longfei 已提交
534

535 536
    if ((N * sample_size) == 1 && !use_global_stats) {
      framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
537 538 539
      return;
    }

540 541
    int scale_coefff = use_global_stats ? 1 : N * sample_size;
    const auto scale_inv_var_nhw = scale_arr * inv_var_arr / scale_coefff;
Q
Qiao Longfei 已提交
542

L
lvmengsi 已提交
543 544 545 546 547 548 549 550 551 552 553 554 555 556 557
    Tensor dy_sum;
    dy_sum.Resize({C});
    dy_sum.mutable_data<T>(ctx.GetPlace());
    EigenVectorArrayMap<T> dy_sum_arr(dy_sum.mutable_data<T>(ctx.GetPlace()),
                                      C);

    Tensor dy_mul_x_sub_mean_mul_invstd_sum;
    dy_mul_x_sub_mean_mul_invstd_sum.Resize({C});
    dy_mul_x_sub_mean_mul_invstd_sum.mutable_data<T>(ctx.GetPlace());
    EigenVectorArrayMap<T> dy_mul_x_sub_mean_mul_invstd_sum_arr(
        dy_mul_x_sub_mean_mul_invstd_sum.mutable_data<T>(ctx.GetPlace()), C);

    dy_sum_arr.setZero();
    dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero();

Q
QI JUN 已提交
558 559
    switch (data_layout) {
      case DataLayout::kNCHW: {
Q
Qiao Longfei 已提交
560 561 562 563 564 565
        ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
        ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C);
        EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
                                 sample_size, N * C);
        d_x_arr.setZero();

L
lvmengsi 已提交
566 567 568 569 570 571 572 573
        for (int nc = 0; nc < N * C; ++nc) {
          int c = nc % C;
          dy_sum_arr(c) += d_y_arr.col(nc).sum();
          dy_mul_x_sub_mean_mul_invstd_sum_arr(c) +=
              ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * d_y_arr.col(nc))
                  .sum();
        }

574
        if (d_scale && d_bias) {
L
lvmengsi 已提交
575 576
          d_bias_arr = dy_sum_arr;
          d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
Q
Qiao Longfei 已提交
577
        }
L
lvmengsi 已提交
578

579 580 581 582 583
        if (!use_global_stats) {
          for (int nc = 0; nc < N * C; ++nc) {
            int c = nc % C;
            d_x_arr.col(nc) +=
                scale_inv_var_nhw(c) *
L
lvmengsi 已提交
584 585 586
                (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
                 (x_arr.col(nc) - mean_arr[c]) *
                     dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c));
587 588 589 590 591 592
          }
        } else {
          for (int nc = 0; nc < N * C; ++nc) {
            int c = nc % C;
            d_x_arr.col(nc) += scale_inv_var_nhw(c) * d_y_arr.col(nc);
          }
Q
Qiao Longfei 已提交
593 594 595
        }
        break;
      }
Q
QI JUN 已提交
596
      case DataLayout::kNHWC: {
Q
Qiao Longfei 已提交
597 598 599 600 601 602
        ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
        ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size);
        EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
                                 N * sample_size);
        d_x_arr.setZero();

L
lvmengsi 已提交
603 604 605 606 607
        for (int nhw = 0; nhw < N * sample_size; ++nhw) {
          dy_sum_arr += d_y_arr.col(nhw);
          dy_mul_x_sub_mean_mul_invstd_sum_arr +=
              (x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw);
        }
608 609

        if (d_scale && d_bias) {
L
lvmengsi 已提交
610 611
          d_bias_arr = dy_sum_arr;
          d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
612 613 614 615 616 617
        }

        if (!use_global_stats) {
          for (int nhw = 0; nhw < N * sample_size; ++nhw) {
            d_x_arr.col(nhw) +=
                scale_inv_var_nhw *
L
lvmengsi 已提交
618 619 620
                (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr -
                 (x_arr.col(nhw) - mean_arr) *
                     dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr);
621 622 623 624 625
          }
        } else {
          for (int nhw = 0; nhw < N * sample_size; ++nhw) {
            d_x_arr.col(nhw) += scale_inv_var_nhw * d_y_arr.col(nhw);
          }
Q
Qiao Longfei 已提交
626 627 628 629
        }
        break;
      }
      default:
Q
QI JUN 已提交
630
        PADDLE_THROW("Unknown storage order: %s", data_layout_str);
Q
Qiao Longfei 已提交
631 632 633 634
    }
  }
};

H
hong 已提交
635
template <typename T>
636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651
std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
  auto *op = new T();
  op->SetType(this->ForwardOpType() + "_grad");
  op->SetInput("X", this->Input("X"));
  op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));

  op->SetInput("Scale", this->Input("Scale"));
  op->SetInput("Bias", this->Input("Bias"));
  op->SetInput("SavedMean", this->Output("SavedMean"));
  op->SetInput("SavedVariance", this->Output("SavedVariance"));

  // used when setting use_global_stats True during training
  if (boost::get<bool>(this->GetAttr("use_global_stats"))) {
    op->SetInput("Mean", this->Output("MeanOut"));
    op->SetInput("Variance", this->Output("VarianceOut"));
  }
652

653
  op->SetAttrMap(this->Attrs());
Y
Yu Yang 已提交
654

655 656 657
  op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
  op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
Y
Yu Yang 已提交
658

659 660
  return std::unique_ptr<T>(op);
}
Y
Yu Yang 已提交
661

Q
Qiao Longfei 已提交
662 663 664 665
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yu Yang 已提交
666
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
H
hong 已提交
667 668 669
                  ops::BatchNormOpInferVarType,
                  ops::BatchNormGradMaker<paddle::framework::OpDesc>,
                  ops::BatchNormGradMaker<paddle::imperative::OpBase>);
670
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
Y
Yu Yang 已提交
671

Q
QI JUN 已提交
672
REGISTER_OP_CPU_KERNEL(
D
dzhwinter 已提交
673 674
    batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, double>);
Q
Qiao Longfei 已提交
675 676
REGISTER_OP_CPU_KERNEL(
    batch_norm_grad,
D
dzhwinter 已提交
677 678
    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);