/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
Q
Qiao Longfei 已提交
16
#include <cfloat>
17 18 19
#include <string>
#include <vector>
#include "cub/cub.cuh"
S
Siddharth Goyal 已提交
20
#include "paddle/fluid/framework/data_layout.h"
21
#include "paddle/fluid/operators/batch_norm_op.h"
Y
Yi Wang 已提交
22 23
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
K
Kexin Zhao 已提交
24
#include "paddle/fluid/platform/float16.h"
Q
Qiao Longfei 已提交
25

26
// Declares the gflags flag (defined in another translation unit) that makes
// the kernels below request CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode when
// cuDNN >= 7.0 is available.
DECLARE_bool(cudnn_batchnorm_spatial_persistent);

namespace paddle {
namespace operators {

// Short names for framework types used throughout this file.
typedef framework::Tensor Tensor;
typedef framework::DataLayout DataLayout;

// cuDNN type traits (data type enum, kOne/kZero scaling constants, ...) for
// element type T.
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

// Element type of the scale/bias/mean/variance parameter tensors that cuDNN
// batch norm uses when the activations are of type T.
template <typename T>
using BatchNormParamType =
    typename platform::CudnnDataType<T>::BatchNormParamType;

template <typename T>
class BatchNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Runs batch normalization forward on the GPU through cuDNN.
  //
  // Reads X, Scale and Bias. When `is_test` or `use_global_stats` is set, it
  // normalizes with the running statistics from the Mean/Variance inputs
  // (cuDNN inference mode). Otherwise it runs cuDNN training mode:
  // MeanOut/VarianceOut are updated with exponential-average factor
  // (1 - momentum), and the saved batch statistics produced by cuDNN are
  // written to SavedMean/SavedVariance. MomentumTensor, if present, overrides
  // the `momentum` attribute in training mode.
  //
  // If X is NHWC but cuDNN's NHWC path is not selected, X is transposed to
  // NCHW, normalized there, and the result is transposed back into Y.
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    float momentum = ctx.Attr<float>("momentum");
    const bool is_test = ctx.Attr<bool>("is_test");
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);

    // Get the size for each dimension.
    // NCHW [batch_size, in_channels, in_height, in_width]
    const auto *x = ctx.Input<Tensor>("X");
    const auto &x_dims = x->dims();
    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
                   "The Input dim size should be between 2 and 5");

    auto *y = ctx.Output<Tensor>("Y");
    y->mutable_data<T>(ctx.GetPlace());

    int N, C, H, W, D;
    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);

    // NHWC is computed natively only for inference, or for fp16 training
    // with the persistent spatial mode enabled; otherwise compute in NCHW.
    auto dtype = platform::CudnnDataType<T>::type;
    const bool fast_nhwc_batch_norm =
        is_test ||
        (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);

    auto compute_format =
        fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
            ? DataLayout::kNHWC
            : DataLayout::kNCHW;

    // transformed_x / transformed_y are what cuDNN reads and writes: separate
    // channel-first tensors when an NHWC -> NCHW transform is needed,
    // otherwise aliases of x / y.
    Tensor transformed_x(x->type());
    Tensor transformed_y(y->type());
    if (data_layout == DataLayout::kNHWC &&
        compute_format == DataLayout::kNCHW && x_dims.size() > 2) {
      VLOG(3) << "Transform input tensor from NHWC to NCHW.";
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                           &transformed_x);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                          &transformed_x);
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, y,
                                                           &transformed_y);
    } else {
      transformed_x.ShareDataWith(*x);
      transformed_y.ShareDataWith(*y);
    }

    // ------------------- cudnn descriptors ---------------------
    cudnnTensorDescriptor_t data_desc_;
    cudnnTensorDescriptor_t bn_param_desc_;
    cudnnBatchNormMode_t mode_;

    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));

    // cuDNN rejects epsilon below CUDNN_BN_MIN_EPSILON, so clamp it.
    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
      LOG(ERROR) << "Provided epsilon is smaller than "
                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
                 << "CUDNN_BN_MIN_EPSILON instead.";
    }
    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
    if (FLAGS_cudnn_batchnorm_spatial_persistent) {
      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
    } else {
      mode_ = CUDNN_BATCHNORM_SPATIAL;
    }
#else
    mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif

    VLOG(3) << "Setting descriptors.";
    // Dims are always listed as N, C, H, W, D; the strides encode whether
    // the underlying memory is channel-first or channel-last.
    std::vector<int> dims;
    std::vector<int> strides;
    if (compute_format == DataLayout::kNCHW) {
      dims = {N, C, H, W, D};
      strides = {C * H * W * D, H * W * D, W * D, D, 1};
    } else {
      dims = {N, C, H, W, D};
      strides = {H * W * D * C, 1, W * D * C, D * C, C};
    }
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
        data_desc_, CudnnDataType<T>::type,
        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
    // Note: PERSISTENT not implemented for inference
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDeriveBNTensorDescriptor(
            bn_param_desc_, data_desc_,
            is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));

    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    auto handle = dev_ctx.cudnn_handle();

    // Now, depending on whether we are running test or not, we have two paths.
    if (is_test || use_global_stats) {
      // only when test we use input to do computation.
      const auto *est_mean = ctx.Input<Tensor>("Mean");
      const auto *est_var = ctx.Input<Tensor>("Variance");
      // Run inference mode.
      PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL);
      PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL);
      PADDLE_ENFORCE_EQ(est_mean->dims()[0], C);
      PADDLE_ENFORCE_EQ(est_var->dims()[0], C);

      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnBatchNormalizationForwardInference(
              handle,
              // Note: PERSISTENT not implemented for inference
              CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
              CudnnDataType<T>::kZero(), data_desc_,
              transformed_x.template data<T>(), data_desc_,
              transformed_y.template mutable_data<T>(ctx.GetPlace()),
              bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
              bias->template data<BatchNormParamType<T>>(),
              est_mean->template data<BatchNormParamType<T>>(),
              est_var->template data<BatchNormParamType<T>>(), epsilon));
    } else {
      // if MomentumTensor is set, use MomentumTensor value, momentum
      // is only used in this training branch
      if (ctx.HasInput("MomentumTensor")) {
        const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
        Tensor mom_cpu;
        TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
        momentum = mom_cpu.data<float>()[0];
      }

      // Run training mode.
      // obtain running mean and running inv var, and see if we need to
      // initialize them.

      auto *mean_out = ctx.Output<Tensor>("MeanOut");
      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
      mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

      auto *saved_mean = ctx.Output<Tensor>("SavedMean");
      auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
      saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
          functor;
      functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
      functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));

      if ((N * H * W * D) == 1) {
        // Only 1 element in normalization dimension,
        // skip the batch norm calculation, let y = x.
        // Copy into transformed_y rather than y: when the input was
        // transposed to NCHW above, transformed_y is transposed back into y
        // at the end of this function and would otherwise overwrite y with
        // its never-written contents. When no transpose happens,
        // transformed_y aliases y, so this still writes y.
        framework::TensorCopy(transformed_x, ctx.GetPlace(), &transformed_y);
      } else {
        double this_factor = 1. - momentum;

        bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
        if (compute_format == DataLayout::kNHWC) {
          called = true;
          size_t workspace_size = 0;
          size_t reserve_space_size = 0;
          void *reserve_space_ptr = nullptr;
          void *workspace_ptr = nullptr;
          Tensor workspace_tensor;
          // Create reserve space and workspace for batch norm.
          // Create tensor for each batchnorm op, it will be used in the
          // backward. Thus this tensor shouldn't be temp.
          auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
          PADDLE_ENFORCE_NOT_NULL(
              reserve_space,
              platform::errors::NotFound(
                  "The argument ReserveSpace of batch_norm op is not found."));

          // --------------- cudnn batchnorm workspace ---------------
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::
                  cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
                      /*handle=*/handle,
                      /*mode=*/mode_,
                      /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                      /*xDesc=*/data_desc_,
                      /*zDesc=*/nullptr,
                      /*yDesc=*/data_desc_,
                      /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
                      /*activationDesc=*/nullptr,
                      /*sizeInBytes=*/&workspace_size));

          // -------------- cudnn batchnorm reserve space --------------
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::
                  cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
                      /*handle=*/handle,
                      /*mode=*/mode_,
                      /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                      /*activationDesc=*/nullptr,
                      /*xDesc=*/data_desc_,
                      /*sizeInBytes=*/&reserve_space_size));

          reserve_space_ptr = reserve_space->mutable_data(
              ctx.GetPlace(), transformed_x.type(), reserve_space_size);
          workspace_ptr = workspace_tensor.mutable_data(
              ctx.GetPlace(), transformed_x.type(), workspace_size);
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
                  handle, mode_, CUDNN_BATCHNORM_OPS_BN,
                  CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
                  data_desc_, transformed_x.template data<T>(), nullptr,
                  nullptr, data_desc_, transformed_y.template data<T>(),
                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
                  bias->template data<BatchNormParamType<T>>(), this_factor,
                  mean_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  variance_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  epsilon,
                  saved_mean->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  saved_variance->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  nullptr, workspace_ptr, workspace_size, reserve_space_ptr,
                  reserve_space_size));
        }
#endif
        if (!called) {
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationForwardTraining(
                  handle, mode_, CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), data_desc_,
                  transformed_x.template data<T>(), data_desc_,
                  transformed_y.template mutable_data<T>(ctx.GetPlace()),
                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
                  bias->template data<BatchNormParamType<T>>(), this_factor,
                  mean_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  variance_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  epsilon,
                  saved_mean->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  saved_variance->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace())));
        }
      }
    }

    // Undo the input transform: write the channel-first result back to y.
    if (data_layout == DataLayout::kNHWC &&
        compute_format == DataLayout::kNCHW && x_dims.size() > 2) {
      VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
      TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
          ctx, &transformed_y, y);
    }
    // clean when exit.
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
  }
};

// Per-channel reductions for the batch-norm backward pass. For each channel c,
// with k ranging over the N * HxW elements of that channel:
//   dbias[c]  = sum_k dy[k]
//   dscale[c] = sum_k dy[k] * (x[k] - mean[c]) / sqrt(variance[c] + epsilon)
//
// Launch: 1-D grid of 1-D blocks of exactly BlockDim threads (BlockDim sizes
// the cub::BlockReduce). Block b handles channels b, b + gridDim.x, ...; all
// threads of a block visit the same channels, so the per-channel barrier is
// reached uniformly.
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
    const T *dy, const T *x, const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *variance, const double epsilon, const int N,
    const int C, const int HxW, BatchNormParamType<T> *dscale,
    BatchNormParamType<T> *dbias) {
  using ParamT = BatchNormParamType<T>;
  using Reducer = cub::BlockReduce<ParamT, BlockDim>;
  __shared__ typename Reducer::TempStorage ds_storage;
  __shared__ typename Reducer::TempStorage db_storage;

  const int elems_per_channel = N * HxW;
  for (int c = blockIdx.x; c < C; c += gridDim.x) {
    const ParamT channel_mean = mean[c];
    const ParamT channel_inv_std = 1.0 / sqrt(variance[c] + epsilon);

    // Thread-local partial sums over this thread's slice of the channel.
    ParamT partial_ds = static_cast<ParamT>(0);
    ParamT partial_db = static_cast<ParamT>(0);
    for (int k = threadIdx.x; k < elems_per_channel; k += blockDim.x) {
      const int offset = (layout == framework::DataLayout::kNCHW)
                             ? (k / HxW * C + c) * HxW + k % HxW
                             : k * C + c;
      const ParamT grad = static_cast<ParamT>(dy[offset]);
      partial_ds += grad * (static_cast<ParamT>(x[offset]) - channel_mean);
      partial_db += grad;
    }

    // Block-wide totals; only valid in thread 0.
    const ParamT ds_total = Reducer(ds_storage).Reduce(partial_ds, cub::Sum());
    const ParamT db_total = Reducer(db_storage).Reduce(partial_db, cub::Sum());
    if (threadIdx.x == 0) {
      dscale[c] = ds_total * channel_inv_std;
      dbias[c] = db_total;
    }
    // TempStorage is reused for the next channel; cub requires a barrier.
    __syncthreads();
  }
}

// Elementwise dx for batch norm when mean/variance are treated as constants:
//   dx[i] = dy[i] * scale[c] / sqrt(variance[c] + epsilon)
// where c is the channel of flat element i under `layout`. `num` is the total
// number of elements; any 1-D launch shape works because the loop strides
// by the full grid size.
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy,
                                        const BatchNormParamType<T> *scale,
                                        const BatchNormParamType<T> *variance,
                                        const double epsilon, const int C,
                                        const int HxW, const int num, T *dx) {
  using ParamT = BatchNormParamType<T>;
  const int step = gridDim.x * blockDim.x;
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < num;
       idx += step) {
    int channel;
    if (layout == framework::DataLayout::kNCHW) {
      channel = (idx / HxW) % C;
    } else {
      channel = idx % C;
    }
    const ParamT inv_std = 1.0 / sqrt(variance[channel] + epsilon);
    dx[idx] = static_cast<T>(static_cast<ParamT>(dy[idx]) * scale[channel] *
                             inv_std);
  }
}

// Batch-norm backward for dx using per-batch statistics. For channel c, with
// means taken over the N * HxW elements of that channel:
//   dx = scale[c] * inv_std[c] *
//        (dy - mean(dy) - (x - mean[c]) * inv_std[c]^2 * mean(dy * (x - mean[c])))
//
// `variance` is consumed directly as an inverse standard deviation (it is
// multiplied, never passed through 1/sqrt), so callers must pass the saved
// inverse variance (e.g. cuDNN's savedInvVariance), not a raw variance.
//
// Launch: 1-D grid of 1-D blocks of exactly BlockDim threads (BlockDim sizes
// the cub::BlockReduce). Block b handles channels b, b + gridDim.x, ...; all
// threads of a block visit the same channels, so every barrier below is
// reached uniformly.
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void BNBackwardData(const T *dy,
                                      const BatchNormParamType<T> *scale,
                                      const BatchNormParamType<T> *mean,
                                      const T *x,
                                      const BatchNormParamType<T> *variance,
                                      const int C, const int N, const int HxW,
                                      T *dx) {
  const int outer_size = C;
  const int inner_size = N * HxW;
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  // Block-wide sums, broadcast from thread 0 to all threads of the block.
  __shared__ BatchNormParamType<T> dy_sum_val;
  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    BatchNormParamType<T> inv_var_i = variance[i];
    BatchNormParamType<T> mean_i = mean[i];
    BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
    BatchNormParamType<T> dy_x_sub_mean_sum =
        static_cast<BatchNormParamType<T>>(0);
    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      BatchNormParamType<T> dy_i =
          static_cast<BatchNormParamType<T>>(dy[index]);
      dy_sum += dy_i;
      dy_x_sub_mean_sum +=
          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_i);
    }

    // Reduce results are only valid in thread 0.
    dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
    dy_x_sub_mean_sum = BlockReduce(dy_x_sub_mean_storage)
                            .Reduce(dy_x_sub_mean_sum, cub::Sum());

    if (threadIdx.x == 0) {
      dy_sum_val = dy_sum;
      dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
    }
    // Make the broadcast sums visible to every thread; this also separates
    // this channel's TempStorage use from the next channel's.
    __syncthreads();

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      dx[index] =
          (static_cast<BatchNormParamType<T>>(dy[index]) -
           dy_sum_val / static_cast<BatchNormParamType<T>>(inner_size) -
           (static_cast<BatchNormParamType<T>>(x[index]) - mean_i) *
               dy_x_sub_mean_sum_val * inv_var_i * inv_var_i / inner_size) *
          scale[i] * inv_var_i;
    }
    // All threads must finish reading dy_sum_val / dy_x_sub_mean_sum_val
    // before thread 0 overwrites them for the next channel. Without this
    // barrier, correctness relied on an undocumented internal barrier inside
    // cub::BlockReduce.
    __syncthreads();
  }
}

template <typename T>
Q
QI JUN 已提交
416
class BatchNormGradKernel<platform::CUDADeviceContext, T>
Q
Qiao Longfei 已提交
417 418 419 420
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
D
dzhwinter 已提交
421
                   "It must use CUDAPlace.");
Q
Qiao Longfei 已提交
422
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
Q
QI JUN 已提交
423
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
424 425
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");

Q
QI JUN 已提交
426 427
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
Q
Qiao Longfei 已提交
428 429 430
    const auto *x = ctx.Input<Tensor>("X");
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto *scale = ctx.Input<Tensor>("Scale");
431 432 433 434 435 436 437
    const bool is_test = ctx.Attr<bool>("is_test");
    PADDLE_ENFORCE_EQ(
        is_test, false,
        platform::errors::InvalidArgument(
            "`is_test = True` CANNOT be used in train program. If "
            "you want to use global status in pre_train model, "
            "please set `use_global_stats = True`"));
Q
Qiao Longfei 已提交
438 439 440

    const auto &x_dims = x->dims();

441 442
    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
                   "The Input dim size should be between 2 and 5");
Q
Qiao Longfei 已提交
443
    int N, C, H, W, D;
Q
QI JUN 已提交
444
    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
Q
Qiao Longfei 已提交
445

446 447 448 449 450 451
    // init output
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    d_x->mutable_data<T>(ctx.GetPlace());
452 453 454
    if (d_scale && d_bias) {
      d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
455
    }
Q
Qiao Longfei 已提交
456 457 458
    PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
    PADDLE_ENFORCE_EQ(scale->dims()[0], C);

459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490
    auto dtype = platform::CudnnDataType<T>::type;
    const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
    const bool fast_nhwc_batch_norm =
        dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
        reserve_space != nullptr;
    auto compute_format =
        fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
            ? DataLayout::kNHWC
            : DataLayout::kNCHW;

    Tensor transformed_x(x->type());
    Tensor transformed_d_y(d_y->type());
    Tensor transformed_d_x(d_x->type());
    if (data_layout == DataLayout::kNHWC &&
        compute_format == DataLayout::kNCHW) {
      VLOG(3) << "Transform input tensor from NHWC to NCHW.";
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                           &transformed_x);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                          &transformed_x);
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
                                                           &transformed_d_y);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
                                                          &transformed_d_y);
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
                                                           &transformed_d_x);
    } else {
      transformed_x.ShareDataWith(*x);
      transformed_d_y.ShareDataWith(*d_y);
      transformed_d_x.ShareDataWith(*d_x);
    }

Z
zchen0211 已提交
491 492
    std::vector<int> dims;
    std::vector<int> strides;
493
    if (compute_format == DataLayout::kNCHW) {
Z
zchen0211 已提交
494 495 496 497 498 499
      dims = {N, C, H, W, D};
      strides = {C * H * W * D, H * W * D, W * D, D, 1};
    } else {
      dims = {N, C, H, W, D};
      strides = {H * W * C * D, 1, W * D * C, D * C, C};
    }
Q
Qiao Longfei 已提交
500

501
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
502
    const int num = transformed_x.numel();
L
lvmengsi 已提交
503 504 505 506 507 508
    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    int grid1 = (num + block - 1) / block;
    int grid2 = std::min(C, max_blocks);

509 510 511 512 513 514 515 516 517 518 519 520 521 522 523
    if (!use_global_stats) {
      if ((N * H * W * D) == 1) {
        framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
        math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
            functor;
        functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
        functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
        return;
      }

      // ------------------- cudnn descriptors ---------------------
      cudnnTensorDescriptor_t data_desc_;
      cudnnTensorDescriptor_t bn_param_desc_;
      cudnnBatchNormMode_t mode_;

524
      PADDLE_ENFORCE_CUDA_SUCCESS(
525
          platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
526
      PADDLE_ENFORCE_CUDA_SUCCESS(
527 528 529 530 531 532 533
          platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
      if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
        LOG(ERROR) << "Provided epsilon is smaller than "
                   << "CUDNN_BN_MIN_EPSILON. Setting it to "
                   << "CUDNN_BN_MIN_EPSILON instead.";
      }
      epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
534
#if CUDNN_VERSION_MIN(7, 0, 0)
W
Wu Yi 已提交
535 536 537 538 539
      if (FLAGS_cudnn_batchnorm_spatial_persistent) {
        mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
      } else {
        mode_ = CUDNN_BATCHNORM_SPATIAL;
      }
540
#else
541
      mode_ = CUDNN_BATCHNORM_SPATIAL;
542
#endif
543

544
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
545 546
          data_desc_, CudnnDataType<T>::type,
          x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
547 548 549
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
                                                           data_desc_, mode_));
550 551 552

      const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
      const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
L
lvmengsi 已提交
553
      const auto *saved_mean_data =
554
          saved_mean->template data<BatchNormParamType<T>>();
L
lvmengsi 已提交
555
      const auto *saved_var_data =
556 557
          saved_var->template data<BatchNormParamType<T>>();

L
lvmengsi 已提交
558
      if (d_scale && d_bias) {
559 560 561 562 563 564 565 566 567
        bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
        if (compute_format == DataLayout::kNHWC) {
          called = true;
          size_t workspace_size = 0;
          void *workspace_ptr = nullptr;
          Tensor workspace_tensor;
          auto reserve_space_size = reserve_space->memory_size();
          // --------------- cudnn batchnorm workspace ---------------
568 569 570 571 572 573 574 575 576 577 578 579 580 581
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::
                  cudnnGetBatchNormalizationBackwardExWorkspaceSize(
                      /*handle=*/dev_ctx.cudnn_handle(),
                      /*mode=*/mode_,
                      /*bnIps=*/CUDNN_BATCHNORM_OPS_BN,
                      /*xDesc=*/data_desc_,
                      /*yDesc=*/data_desc_,
                      /*dyDesc=*/data_desc_,
                      /*dzDesc=*/nullptr,
                      /*dxDesc=*/data_desc_,
                      /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
                      /*activationDesc=*/nullptr,
                      /*sizeInBytes=*/&workspace_size));
582 583 584 585

          workspace_ptr = workspace_tensor.mutable_data(
              ctx.GetPlace(), transformed_x.type(), workspace_size);

586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationBackwardEx(
                  /*handle=*/dev_ctx.cudnn_handle(),
                  /*mode=*/mode_,
                  /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                  /*alphaDataDiff=*/CudnnDataType<T>::kOne(),
                  /*betaDataDiff=*/CudnnDataType<T>::kZero(),
                  /*alphaParamDiff=*/CudnnDataType<T>::kOne(),
                  /*betaParamDiff=*/CudnnDataType<T>::kZero(),
                  /*xDesc=*/data_desc_,
                  /*xData=*/transformed_x.template data<T>(),
                  /*yDesc=*/nullptr,
                  /*yData=*/nullptr,
                  /*dyDesc=*/data_desc_,
                  /*dyData=*/transformed_d_y.template data<T>(),
                  /*dzDesc=*/nullptr,
                  /*dzData=*/nullptr,
                  /*dxDesc=*/data_desc_,
                  /*dxData=*/transformed_d_x.template mutable_data<T>(
605
                      ctx.GetPlace()),
606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623
                  /*dBnScaleBiasDesc=*/bn_param_desc_,
                  /*bnScaleData=*/scale->template data<BatchNormParamType<T>>(),
                  /*bnBiasData=*/nullptr,
                  /*dBnScaleData=*/d_scale
                      ->template mutable_data<BatchNormParamType<T>>(
                          ctx.GetPlace()),
                  /*dBnBiasData=*/d_bias
                      ->template mutable_data<BatchNormParamType<T>>(
                          ctx.GetPlace()),
                  /*epsilon=*/epsilon,
                  /*savedMean=*/saved_mean_data,
                  /*savedInvVariance=*/saved_var_data,
                  /*activationDesc=*/nullptr,
                  /*workspace=*/workspace_ptr,
                  /*workSpaceSizeInBytes=*/workspace_size,
                  /*reserveSpace=*/const_cast<T *>(
                      reserve_space->template data<T>()),
                  /*reserveSpaceSizeInBytes=*/reserve_space_size));
624 625 626
        }
#endif
        if (!called) {
627 628 629 630 631 632 633 634 635 636 637 638 639 640
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationBackward(
                  dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), data_desc_,
                  transformed_x.template data<T>(), data_desc_,
                  transformed_d_y.template data<T>(), data_desc_,
                  transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
                  d_scale->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  d_bias->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  epsilon, saved_mean_data, saved_var_data));
641 642 643 644 645 646 647 648
        }

        if (data_layout == DataLayout::kNHWC &&
            compute_format == DataLayout::kNCHW) {
          VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
          TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
              ctx, &transformed_d_x, d_x);
        }
L
lvmengsi 已提交
649
      } else {
650
        if (compute_format == DataLayout::kNCHW) {
L
lvmengsi 已提交
651 652 653 654 655 656 657 658 659
          if (d_x) {
            BNBackwardData<T, block, framework::DataLayout::kNCHW><<<
                grid2, block, 0, dev_ctx.stream()>>>(
                d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
                saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
                d_x->data<T>());
          }
        } else {
          if (d_x) {
L
Lv Mengsi 已提交
660
            BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
L
lvmengsi 已提交
661 662 663 664 665 666 667
                grid2, block, 0, dev_ctx.stream()>>>(
                d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
                saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
                d_x->data<T>());
          }
        }
      }
668 669

      // clean when exit.
670
      PADDLE_ENFORCE_CUDA_SUCCESS(
671
          platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
672
      PADDLE_ENFORCE_CUDA_SUCCESS(
673 674 675 676 677 678 679 680 681 682
          platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
    } else {
      const auto *running_mean = ctx.Input<Tensor>("Mean");
      const auto *running_var = ctx.Input<Tensor>("Variance");

      const auto *running_mean_data =
          running_mean->template data<BatchNormParamType<T>>();
      const auto *running_var_data =
          running_var->template data<BatchNormParamType<T>>();

683
      if (compute_format == DataLayout::kNCHW) {
684 685 686 687 688 689 690 691 692 693
        if (d_x) {
          KeBNBackwardData<T, framework::DataLayout::kNCHW><<<
              grid1, block, 0, dev_ctx.stream()>>>(
              d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
              running_var_data, epsilon, C, H * W, num, d_x->data<T>());
        }
        if (d_scale && d_bias) {
          KeBNBackwardScaleBias<T, block, framework::DataLayout::kNCHW><<<
              grid2, block, 0, dev_ctx.stream()>>>(
              d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
Q
qingqing01 已提交
694
              epsilon, N, C, H * W * D, d_scale->data<BatchNormParamType<T>>(),
695 696 697 698 699 700 701 702 703 704
              d_bias->data<BatchNormParamType<T>>());
        }
      } else {
        if (d_x) {
          KeBNBackwardData<T, framework::DataLayout::kNHWC><<<
              grid1, block, 0, dev_ctx.stream()>>>(
              d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
              running_var_data, epsilon, C, H * W, num, d_x->data<T>());
        }
        if (d_scale && d_bias) {
Q
qingqing01 已提交
705
          KeBNBackwardScaleBias<T, block, framework::DataLayout::kNHWC><<<
706 707
              grid2, block, 0, dev_ctx.stream()>>>(
              d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
Q
qingqing01 已提交
708
              epsilon, N, C, H * W * D, d_scale->data<BatchNormParamType<T>>(),
709 710 711 712
              d_bias->data<BatchNormParamType<T>>());
        }
      }
    }
Q
Qiao Longfei 已提交
713 714 715 716 717 718 719
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the CUDA forward kernel for the `batch_norm` op, instantiated for
// float, double and float16 element types.
REGISTER_OP_CUDA_KERNEL(
    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormKernel<plat::CUDADeviceContext, double>,
    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);

// Register the CUDA backward kernel for `batch_norm_grad`, covering the same
// element types as the forward registration above.
REGISTER_OP_CUDA_KERNEL(
    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);