batch_norm_op.cu 58.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
Q
Qiao Longfei 已提交
16
#include <cfloat>
17 18
#include <string>
#include <vector>
19
#ifdef __NVCC__
20
#include "cub/cub.cuh"
21 22 23 24 25
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
S
Siddharth Goyal 已提交
26
#include "paddle/fluid/framework/data_layout.h"
27
#include "paddle/fluid/operators/batch_norm_op.h"
Y
Yi Wang 已提交
28
#include "paddle/fluid/operators/math/math_function.h"
29
#include "paddle/fluid/operators/norm_utils.cu.h"
K
Kexin Zhao 已提交
30
#include "paddle/fluid/platform/float16.h"
Q
Qiao Longfei 已提交
31

32
// Declares the boolean flag FLAGS_cudnn_batchnorm_spatial_persistent, which is
// defined in another translation unit. In the forward kernel below, when set it
// selects CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode (cuDNN >= 7.0.1) and, for
// CUDNN_DATA_HALF input, allows the NHWC compute format outside test mode.
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
W
Wu Yi 已提交
33

Q
Qiao Longfei 已提交
34 35 36 37
namespace paddle {
namespace operators {

// Short local names for framework types used throughout this file.
using DataLayout = framework::DataLayout;
using Tensor = framework::Tensor;

// cuDNN type traits for the element type T.
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

// Type used for the per-channel scale / bias / mean / variance parameters
// that accompany an input of element type T.
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
Q
Qiao Longfei 已提交
43

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
// Inference-time batch normalization, element-wise over all N * C * HxW
// entries of `x`:
//   y = scale[c] * (x - mean[c]) / sqrt(variance[c] + epsilon) + bias[c]
//
// Launch: any 1-D grid / 1-D block. Each thread starts at its global index
// and advances by blockDim.x * gridDim.x, so the grid does not need to cover
// every element. The channel of flat index `idx` is (idx / HxW) % C for NCHW
// and idx % C for any other layout (channel-last).
template <typename T, framework::DataLayout layout>
static __global__ void BNForwardInference(
    const T *x, const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *variance, const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias, const int C, const int N, const int HxW,
    const double epsilon, T *y) {
  using ParamT = BatchNormParamType<T>;
  const int total = N * C * HxW;
  const int step = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total;
       idx += step) {
    int channel;
    if (layout == framework::DataLayout::kNCHW) {
      channel = idx / HxW % C;
    } else {
      channel = idx % C;
    }
    const ParamT centered = static_cast<ParamT>(x[idx]) - mean[channel];
    const ParamT inv_std = 1 / sqrt(variance[channel] + epsilon);
    y[idx] =
        static_cast<T>(scale[channel] * centered * inv_std + bias[channel]);
  }
}

// Training-time batch normalization (used by the HIP build).
//
// Block b processes channels b, b + gridDim.x, ... For each channel it:
//   1. reduces sum(x) and sum(x^2) over the N * HxW elements of the channel
//      with cub::BlockReduce (so blockDim.x must equal BlockDim);
//   2. on thread 0, computes the batch mean, biased variance and
//      1 / sqrt(var + epsilon); stores mean and inverse std into
//      save_mean / save_inv_variance when both are non-null; and blends the
//      batch statistics into the running `mean` / `variance` as
//        running = (1 - factor) * batch + factor * running;
//   3. writes y = scale * (x - mean) * inv_std + bias for the channel.
// Element index: NCHW uses (j / HxW * C + c) * HxW + j % HxW, any other
// layout uses j * C + c.
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
    const T *x, const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias, const int C, const int N, const int HxW,
    const double epsilon, double exponentialAverageFactor, T *y,
    BatchNormParamType<T> *mean, BatchNormParamType<T> *variance,
    BatchNormParamType<T> *save_mean,
    BatchNormParamType<T> *save_inv_variance) {
  int outer_size = C;
  int inner_size = N * HxW;
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage mean_storage;
  __shared__ typename BlockReduce::TempStorage variance_storage;
  // Per-channel results computed by thread 0 and broadcast to the block.
  __shared__ BatchNormParamType<T> mean_val;
  __shared__ BatchNormParamType<T> variance_val;
  __shared__ BatchNormParamType<T> inv_var_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
    BatchNormParamType<T> x_square_sum = static_cast<BatchNormParamType<T>>(0);

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      BatchNormParamType<T> x_i = static_cast<BatchNormParamType<T>>(x[index]);
      x_sum += x_i;
      x_square_sum += x_i * x_i;
    }
    x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
    x_square_sum =
        BlockReduce(variance_storage).Reduce(x_square_sum, cub::Sum());
    if (threadIdx.x == 0) {
      mean_val = x_sum / inner_size;
      variance_val = x_square_sum / inner_size - mean_val * mean_val;
      // E[x^2] - E[x]^2 can come out slightly negative through floating-point
      // cancellation (large mean, small spread). Clamp it so sqrt() below
      // never sees a negative argument (NaN) and the running variance stays
      // non-negative.
      if (variance_val < static_cast<BatchNormParamType<T>>(0)) {
        variance_val = static_cast<BatchNormParamType<T>>(0);
      }
      inv_var_val = 1 / sqrt(variance_val + epsilon);

      if (save_mean && save_inv_variance) {
        save_mean[i] = mean_val;
        save_inv_variance[i] = inv_var_val;
      }
      mean[i] = (1 - exponentialAverageFactor) * mean_val +
                exponentialAverageFactor * mean[i];
      variance[i] = (1 - exponentialAverageFactor) * variance_val +
                    exponentialAverageFactor * variance[i];
    }
    // Publish mean_val / inv_var_val to every thread of the block.
    __syncthreads();

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      BatchNormParamType<T> x_sub_mean =
          static_cast<BatchNormParamType<T>>(x[index]) - mean_val;
      y[index] =
          static_cast<T>(scale[i] * x_sub_mean * inv_var_val + bias[i]);
    }
    // Do not let thread 0 overwrite the shared per-channel scalars for the
    // next channel while other threads may still be reading them above.
    // Every thread runs the same number of outer iterations, so this barrier
    // is reached uniformly.
    __syncthreads();
  }
}

Q
Qiao Longfei 已提交
121
template <typename T>
class BatchNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Forward batch normalization on the GPU.
  //
  // Reads the attributes epsilon, momentum, is_test, use_global_stats,
  // trainable_statistics and data_layout, and takes one of two paths:
  //  * inference (test mode without trainable statistics, or
  //    use_global_stats): normalizes X with the provided Mean / Variance;
  //  * training: normalizes X with batch statistics and writes MeanOut,
  //    VarianceOut, SavedMean, SavedVariance (and ReserveSpace when the
  //    cuDNN >= 7.4.1 "Ex" API is used). An optional MomentumTensor input
  //    overrides the `momentum` attribute.
  // When X is NHWC but the chosen compute format is NCHW, X is transposed to
  // NCHW first and the result is transposed back into Y at the end.
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    float momentum = ctx.Attr<float>("momentum");
    const bool is_test = ctx.Attr<bool>("is_test");
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
    const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);

    bool test_mode = is_test && (!trainable_stats);
    // It is training mode when it's not inference AND not using pre-trained
    // (global) statistics.
    bool training = !test_mode && !use_global_stats;

    // Get the size for each dimension.
    // NCHW [batch_size, in_channels, in_height, in_width]
    const auto *x = ctx.Input<Tensor>("X");
    const auto &x_dims = x->dims();
    PADDLE_ENFORCE_EQ(
        x_dims.size() >= 2 && x_dims.size() <= 5, true,
        platform::errors::InvalidArgument(
            "The size of input's dimensions should be between 2 and 5. "
            "But received: the size of input's dimensions is [%d]",
            x_dims.size()));

    auto *y = ctx.Output<Tensor>("Y");
    y->mutable_data<T>(ctx.GetPlace());

    int N, C, H, W, D;
    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);

    // Outside training the provided running statistics are used. Validate
    // them before any cuDNN descriptor is created, so a shape error thrown
    // here cannot leak those descriptors.
    const Tensor *est_mean = nullptr;
    const Tensor *est_var = nullptr;
    if (!training) {
      est_mean = ctx.Input<Tensor>("Mean");
      est_var = ctx.Input<Tensor>("Variance");
      PADDLE_ENFORCE_EQ(
          est_mean->dims().size(), 1UL,
          platform::errors::InvalidArgument(
              "The size of mean's dimensions must equal to 1. "
              "But received: the size of mean's dimensions is [%d], "
              "the dimensions of mean is [%s].",
              est_mean->dims().size(), est_mean->dims()));
      PADDLE_ENFORCE_EQ(
          est_var->dims().size(), 1UL,
          platform::errors::InvalidArgument(
              "The size of variance's dimensions must equal to 1. "
              "But received: the size of variance's dimensions is [%d], "
              "the dimensions of variance is [%s].",
              est_var->dims().size(), est_var->dims()));
      PADDLE_ENFORCE_EQ(
          est_mean->dims()[0], C,
          platform::errors::InvalidArgument(
              "The first dimension of mean must equal to the number of "
              "Channels, which is [%d]. But received: the first dimension "
              "of mean is [%d], the dimensions of mean is [%s].",
              C, est_mean->dims()[0], est_mean->dims()));
      PADDLE_ENFORCE_EQ(
          est_var->dims()[0], C,
          platform::errors::InvalidArgument(
              "The first dimension of variance must equal to the number "
              "of Channels, which is [%d]. But received: the first dimension "
              "of variance is [%d], the dimensions of variance is [%s].",
              C, est_var->dims()[0], est_var->dims()));
    }

#ifdef PADDLE_WITH_HIP
    auto compute_format = data_layout == DataLayout::kNHWC ? DataLayout::kNHWC
                                                           : DataLayout::kNCHW;

// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW;
#else
    auto dtype = platform::CudnnDataType<T>::type;
    const bool fast_nhwc_batch_norm =
        test_mode ||
        (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);

    auto compute_format =
        fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
            ? DataLayout::kNHWC
            : DataLayout::kNCHW;
#endif

    Tensor transformed_x(x->type());
    Tensor transformed_y(y->type());
    if (data_layout == DataLayout::kNHWC &&
        compute_format == DataLayout::kNCHW && x_dims.size() > 2) {
      VLOG(3) << "Transform input tensor from NHWC to NCHW.";
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                           &transformed_x);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                          &transformed_x);
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, y,
                                                           &transformed_y);
    } else {
      transformed_x.ShareDataWith(*x);
      transformed_y.ShareDataWith(*y);
    }

// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// miopenTensorDescriptor_t data_desc_;
// miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_;

// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else
    cudnnTensorDescriptor_t data_desc_;
    cudnnTensorDescriptor_t bn_param_desc_;
    cudnnBatchNormMode_t mode_;

    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
#endif

    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
      LOG(ERROR) << "Provided epsilon is smaller than "
                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
                 << "CUDNN_BN_MIN_EPSILON instead.";
    }
    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
#elif CUDNN_VERSION_MIN(7, 0, 1)
    if (FLAGS_cudnn_batchnorm_spatial_persistent) {
      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
    } else if (H == 1 && W == 1) {
      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
    } else {
      mode_ = CUDNN_BATCHNORM_SPATIAL;
    }
#else
    if (H == 1 && W == 1) {
      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
    } else {
      mode_ = CUDNN_BATCHNORM_SPATIAL;
    }
#endif  // CUDNN_VERSION_MIN(7, 0, 1)

    VLOG(3) << "Setting descriptors.";
    // Dims are always described as {N, C, H, W, D}; only the strides differ
    // between the channel-first and channel-last memory orders.
    std::vector<int> dims;
    std::vector<int> strides;
    if (compute_format == DataLayout::kNCHW) {
      dims = {N, C, H, W, D};
      strides = {C * H * W * D, H * W * D, W * D, D, 1};
    } else {
      dims = {N, C, H, W, D};
      strides = {H * W * D * C, 1, W * D * C, D * C, C};
    }

#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
//     data_desc_, CudnnDataType<T>::type,
//     x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
//     const_cast<int *>(strides.data())));
// Note: PERSISTENT not implemented for inference
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenDeriveBNTensorDescriptor(
//         bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
        data_desc_, CudnnDataType<T>::type,
        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
    // Note: PERSISTENT not implemented for inference
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDeriveBNTensorDescriptor(
            bn_param_desc_, data_desc_,
            test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
#endif

    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    auto handle = dev_ctx.cudnn_handle();

    // Now, depending on whether we are running test or not, we have two paths.
    if (!training) {
      // Run inference mode with the (already validated) est_mean / est_var.
#ifdef PADDLE_WITH_HIP
      const int block_size = 256;
      const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
      if (compute_format == DataLayout::kNCHW) {
        BNForwardInference<
            T,
            DataLayout::kNCHW><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
            transformed_x.template data<T>(),
            est_mean->template data<BatchNormParamType<T>>(),
            est_var->template data<BatchNormParamType<T>>(),
            scale->template data<BatchNormParamType<T>>(),
            bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
            epsilon, transformed_y.template data<T>());
      } else {
        BNForwardInference<
            T,
            DataLayout::kNHWC><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
            transformed_x.template data<T>(),
            est_mean->template data<BatchNormParamType<T>>(),
            est_var->template data<BatchNormParamType<T>>(),
            scale->template data<BatchNormParamType<T>>(),
            bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
            epsilon, transformed_y.template data<T>());
      }

// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenBatchNormalizationForwardInference(
//         handle, miopenBNSpatial,
//         const_cast<void *>(
//             static_cast<const void *>(CudnnDataType<T>::kOne())),
//         const_cast<void *>(
//             static_cast<const void *>(CudnnDataType<T>::kZero())),
//         data_desc_,
//         static_cast<const void *>(transformed_x.template data<T>()),
//         data_desc_,
//         static_cast<void *>(
//             transformed_y.template mutable_data<T>(ctx.GetPlace())),
//         bn_param_desc_,
//         const_cast<void *>(static_cast<const void *>(
//             scale->template data<BatchNormParamType<T>>())),
//         const_cast<void *>(static_cast<const void *>(
//             bias->template data<BatchNormParamType<T>>())),
//         const_cast<void *>(static_cast<const void *>(
//             est_mean->template data<BatchNormParamType<T>>())),
//         const_cast<void *>(static_cast<const void *>(
//             est_var->template data<BatchNormParamType<T>>())),
//         epsilon));
#else
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnBatchNormalizationForwardInference(
              handle,
              // Note: PERSISTENT not implemented for inference
              CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
              CudnnDataType<T>::kZero(), data_desc_,
              transformed_x.template data<T>(), data_desc_,
              transformed_y.template mutable_data<T>(ctx.GetPlace()),
              bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
              bias->template data<BatchNormParamType<T>>(),
              est_mean->template data<BatchNormParamType<T>>(),
              est_var->template data<BatchNormParamType<T>>(), epsilon));
#endif
    } else {
      // if MomentumTensor is set, use MomentumTensor value, momentum
      // is only used in this training branch
      if (ctx.HasInput("MomentumTensor")) {
        const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
        Tensor mom_cpu;
        TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
        momentum = mom_cpu.data<float>()[0];
      }

      // Run training mode.
      // obtain running mean and running inv var, and there is no need
      // to initialize them.

      auto *mean_out = ctx.Output<Tensor>("MeanOut");
      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
      mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

      auto *saved_mean = ctx.Output<Tensor>("SavedMean");
      auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
      saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

      if ((N * H * W * D) == 1) {
        // Only 1 element in normalization dimension,
        // skip the batch norm calculation, let y = x.
        // Copy into transformed_y rather than y: when X was transposed to
        // NCHW above, transformed_y is transposed back into y after this
        // branch, and copying straight into y would let that transpose
        // overwrite it with never-written memory. Otherwise transformed_y
        // shares y's buffer and this is equivalent to copying into y.
        framework::TensorCopy(transformed_x, ctx.GetPlace(), &transformed_y);
      } else {
        double this_factor = 1. - momentum;

        bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
        called = true;
        size_t workspace_size = 0;
        size_t reserve_space_size = 0;
        void *reserve_space_ptr = nullptr;
        void *workspace_ptr = nullptr;
        Tensor workspace_tensor;
        // Create reserve space and workspace for batch norm.
        // Create tensor for each batchnorm op, it will be used in the
        // backward. Thus this tensor shouldn't be temp.
        auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
        PADDLE_ENFORCE_NOT_NULL(
            reserve_space,
            platform::errors::NotFound(
                "The argument ReserveSpace of batch_norm op is not found."));

        // --------------- cudnn batchnorm workspace ---------------
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::
                cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
                    /*handle=*/handle,
                    /*mode=*/mode_,
                    /*bnIps=*/CUDNN_BATCHNORM_OPS_BN,
                    /*xDesc=*/data_desc_,
                    /*zDesc=*/nullptr,
                    /*yDesc=*/data_desc_,
                    /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
                    /*activationDesc=*/nullptr,
                    /*sizeInBytes=*/&workspace_size));

        // -------------- cudnn batchnorm reserve space --------------
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::
                cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
                    /*handle=*/handle,
                    /*mode=*/mode_,
                    /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                    /*activationDesc=*/nullptr,
                    /*xDesc=*/data_desc_,
                    /*sizeInBytes=*/&reserve_space_size));

        reserve_space_ptr = reserve_space->mutable_data(
            ctx.GetPlace(), transformed_x.type(), reserve_space_size);
        workspace_ptr = workspace_tensor.mutable_data(
            ctx.GetPlace(), transformed_x.type(), workspace_size);
        PADDLE_ENFORCE_CUDA_SUCCESS(
            platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
                handle, mode_, CUDNN_BATCHNORM_OPS_BN, CudnnDataType<T>::kOne(),
                CudnnDataType<T>::kZero(), data_desc_,
                transformed_x.template data<T>(), nullptr, nullptr, data_desc_,
                transformed_y.template data<T>(), bn_param_desc_,
                scale->template data<BatchNormParamType<T>>(),
                bias->template data<BatchNormParamType<T>>(), this_factor,
                mean_out->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace()),
                variance_out->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace()),
                epsilon,
                saved_mean->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace()),
                saved_variance->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace()),
                nullptr, workspace_ptr, workspace_size, reserve_space_ptr,
                reserve_space_size));
#endif  // CUDNN_VERSION_MIN(7, 4, 1)
        if (!called) {
#ifdef PADDLE_WITH_HIP
          const int block = 256;
          const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
          const int max_blocks = std::max(max_threads / block, 1);
          const int grid = std::min(C, max_blocks);
          if (compute_format == DataLayout::kNCHW) {
            BNForwardTraining<
                T, block,
                DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
                transformed_x.template data<T>(),
                scale->template data<BatchNormParamType<T>>(),
                bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
                epsilon, this_factor, transformed_y.template data<T>(),
                mean_out->template data<BatchNormParamType<T>>(),
                variance_out->template data<BatchNormParamType<T>>(),
                saved_mean->template data<BatchNormParamType<T>>(),
                saved_variance->template data<BatchNormParamType<T>>());
          } else {
            BNForwardTraining<
                T, block,
                DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
                transformed_x.template data<T>(),
                scale->template data<BatchNormParamType<T>>(),
                bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
                epsilon, this_factor, transformed_y.template data<T>(),
                mean_out->template data<BatchNormParamType<T>>(),
                variance_out->template data<BatchNormParamType<T>>(),
                saved_mean->template data<BatchNormParamType<T>>(),
                saved_variance->template data<BatchNormParamType<T>>());
          }

// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenBatchNormalizationForwardTraining(
//         handle, mode_, const_cast<void *>(static_cast<const void *>(
//                            CudnnDataType<T>::kOne())),
//         const_cast<void *>(
//             static_cast<const void *>(CudnnDataType<T>::kZero())),
//         data_desc_,
//         static_cast<const void *>(transformed_x.template data<T>()),
//         data_desc_,
//         static_cast<void *>(
//             transformed_y.template mutable_data<T>(ctx.GetPlace())),
//         bn_param_desc_,
//         const_cast<void *>(static_cast<const void *>(
//             scale->template data<BatchNormParamType<T>>())),
//         const_cast<void *>(static_cast<const void *>(
//             bias->template data<BatchNormParamType<T>>())),
//         this_factor,
//         static_cast<void *>(
//             mean_out->template mutable_data<BatchNormParamType<T>>(
//                 ctx.GetPlace())),
//         static_cast<void *>(variance_out->template mutable_data<
//                             BatchNormParamType<T>>(ctx.GetPlace())),
//         epsilon,
//         static_cast<void *>(
//             saved_mean->template mutable_data<BatchNormParamType<T>>(
//                 ctx.GetPlace())),
//         static_cast<void *>(saved_variance->template mutable_data<
//                             BatchNormParamType<T>>(ctx.GetPlace()))));
#else
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationForwardTraining(
                  handle, mode_, CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), data_desc_,
                  transformed_x.template data<T>(), data_desc_,
                  transformed_y.template mutable_data<T>(ctx.GetPlace()),
                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
                  bias->template data<BatchNormParamType<T>>(), this_factor,
                  mean_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  variance_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  epsilon,
                  saved_mean->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  saved_variance->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace())));
#endif
        }
      }
    }

    if (data_layout == DataLayout::kNHWC &&
        compute_format == DataLayout::kNCHW && x_dims.size() > 2) {
      VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
      TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
          ctx, &transformed_y, y);
    }
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean when exit.
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else
    // clean when exit.
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
#endif
  }
};

585
// Per-channel reduction for the batch-norm backward pass:
//   dscale[c] = sum_j dy[j] * (x[j] - mean[c]) * (1 / sqrt(variance[c] + eps))
//   dbias[c]  = sum_j dy[j]
// where j runs over the N * HxW elements belonging to channel c.
//
// Launch: 1-D blocks of BlockDim threads (cub::BlockReduce is instantiated
// for BlockDim); block b handles channels b, b + gridDim.x, ... Element
// index: NCHW uses (j / HxW * C + c) * HxW + j % HxW, any other layout uses
// j * C + c.
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void KeBNBackwardScaleBias(
    const T *dy, const T *x, const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *variance, const double epsilon, const int N,
    const int C, const int HxW, BatchNormParamType<T> *dscale,
    BatchNormParamType<T> *dbias) {
  using ParamT = BatchNormParamType<T>;
  using BlockReduce = cub::BlockReduce<ParamT, BlockDim>;
  __shared__ typename BlockReduce::TempStorage ds_storage;
  __shared__ typename BlockReduce::TempStorage db_storage;

  const int elems_per_channel = N * HxW;
  for (int ch = blockIdx.x; ch < C; ch += gridDim.x) {
    const ParamT ch_inv_std = 1.0 / sqrt(variance[ch] + epsilon);
    const ParamT ch_mean = mean[ch];

    ParamT dy_centered_sum = static_cast<ParamT>(0);
    ParamT dy_sum = static_cast<ParamT>(0);
    for (int j = threadIdx.x; j < elems_per_channel; j += blockDim.x) {
      const int offset = layout == framework::DataLayout::kNCHW
                             ? (j / HxW * C + ch) * HxW + j % HxW
                             : j * C + ch;
      dy_centered_sum += static_cast<ParamT>(dy[offset]) *
                         (static_cast<ParamT>(x[offset]) - ch_mean);
      dy_sum += static_cast<ParamT>(dy[offset]);
    }

    // Block-wide totals; only thread 0 receives the valid result.
    dy_centered_sum =
        BlockReduce(ds_storage).Reduce(dy_centered_sum, cub::Sum());
    dy_sum = BlockReduce(db_storage).Reduce(dy_sum, cub::Sum());
    if (threadIdx.x == 0) {
      dscale[ch] = dy_centered_sum * ch_inv_std;
      dbias[ch] = dy_sum;
    }
    // The shared reduction storage is reused for the next channel.
    __syncthreads();
  }
}

Q
qingqing01 已提交
621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636
// Element-wise input gradient using only the direct path through the affine
// normalization:
//   dx = dy * scale[c] * (1 / sqrt(variance[c] + epsilon))
// No gradient terms through the mean or variance are included. This is
// presumably intended for the fixed-statistics case; confirm at the call site.
//
// Launch: any 1-D grid / 1-D block. Each thread starts at its global index
// and advances by blockDim.x * gridDim.x over the `num` elements. The channel
// of flat index `idx` is (idx / HxW) % C for NCHW and idx % C otherwise.
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy,
                                        const BatchNormParamType<T> *scale,
                                        const BatchNormParamType<T> *variance,
                                        const double epsilon, const int C,
                                        const int HxW, const int num, T *dx) {
  using ParamT = BatchNormParamType<T>;
  const int step = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
       idx += step) {
    const int ch =
        (layout == framework::DataLayout::kNCHW) ? (idx / HxW) % C : idx % C;
    const ParamT inv_std = 1.0 / sqrt(variance[ch] + epsilon);
    const ParamT grad = static_cast<ParamT>(dy[idx]);
    dx[idx] = static_cast<T>(grad * scale[ch] * inv_std);
  }
}

K
Kaipeng Deng 已提交
637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663
// Inverts the affine normalization element-wise, writing the result to `x`:
//   x = (y - bias[c]) / scale[c] / variance[c] + mean[c]
// with c = (idx / M) % C for NCHW and idx % C for any other layout.
//
// NOTE(review): `variance` is used directly as a divisor and `epsilon` is
// never read. The result therefore equals the original pre-normalization
// input only if `variance` holds 1 / sqrt(var + eps) (an inverse standard
// deviation) rather than a raw variance. Confirm that every caller passes
// such values.
//
// Launch: any 1-D grid / 1-D block; each thread advances by
// blockDim.x * gridDim.x over the `num` elements. Each element is read from
// `y` before being written to `x` by the same thread, so x == y is allowed.
template <typename T>
static __global__ void KeBNRestoreData(const framework::DataLayout layout, T *x,
                                       const BatchNormParamType<T> *scale,
                                       const BatchNormParamType<T> *bias,
                                       const BatchNormParamType<T> *mean,
                                       const BatchNormParamType<T> *variance,
                                       double epsilon, int C, int M,
                                       const int num, const T *y) {
  const int step = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
       idx += step) {
    int ch = idx % C;
    if (layout == framework::DataLayout::kNCHW) {
      ch = (idx / M) % C;
    }
    const auto y_val = static_cast<BatchNormParamType<T>>(y[idx]);
    const auto x_val =
        (y_val - bias[ch]) / scale[ch] / variance[ch] + mean[ch];
    x[idx] = static_cast<T>(x_val);
  }
}

// Host-side functor used in inplace mode: batch_norm overwrote X with Y, so
// before computing gradients X is reconstructed from Y by launching
// KeBNRestoreData on `stream` with `grid2` blocks of `block` threads.
// `x` and `y` must be the same buffer; otherwise InvalidArgument is raised.
// `M` is the per-(sample, channel) plane size and `num` the element count.
template <typename T>
class InplaceHelper {
 public:
  void operator()(const framework::DataLayout layout, T *x,
                  const BatchNormParamType<T> *scale,
                  const BatchNormParamType<T> *bias,
                  const BatchNormParamType<T> *mean,
                  const BatchNormParamType<T> *variance, double epsilon, int C,
                  int M, const int num, const T *y, int grid2, const int block,
                  const gpuStream_t &stream) {
    PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
                                "X and Y should be inplaced in inplace mode"));
    KeBNRestoreData<T><<<grid2, block, 0, stream>>>(
        layout, x, scale, bias, mean, variance, epsilon, C, M, num, y);
  }
};

L
lvmengsi 已提交
672
// Fused batch-norm backward pass producing dx, dscale and dbias.
//
// Each block processes channels i = blockIdx.x, blockIdx.x + gridDim.x, ...
// (< C); its threads stride over the N * HxW elements of that channel. The
// kernel is launched with BlockDim threads per block, since BlockDim sizes the
// cub::BlockReduce used for the per-channel sums.
//
// If both saved_mean and saved_inv_variance are non-null they are used as the
// per-channel mean and inverse standard deviation; otherwise both are
// recomputed from x as 1 / sqrt(E[x^2] - E[x]^2 + epsilon).
//
// Per channel, with M = N * HxW:
//   dbias  = sum(dy)
//   dscale = sum(dy * (x - mean)) * inv_std
//   dx     = scale * inv_std * (dy - dbias / M - (x - mean) * inv_std * dscale / M)
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
    const T *dy, const T *x, const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *saved_mean,
    const BatchNormParamType<T> *saved_inv_variance, const int C, const int N,
    const int HxW, const double epsilon, T *dx, BatchNormParamType<T> *dscale,
    BatchNormParamType<T> *dbias) {
  const int outer_size = C;
  const int inner_size = N * HxW;
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ds_storage;
  __shared__ typename BlockReduce::TempStorage db_storage;
  __shared__ typename BlockReduce::TempStorage mean_storage;
  __shared__ typename BlockReduce::TempStorage variance_storage;
  // Per-channel values broadcast from thread 0 to the whole block.
  __shared__ BatchNormParamType<T> inv_var_val;
  __shared__ BatchNormParamType<T> mean_val;
  __shared__ BatchNormParamType<T> dscale_val;
  __shared__ BatchNormParamType<T> dbias_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
    BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);

    if (saved_mean && saved_inv_variance) {
      if (threadIdx.x == 0) {
        inv_var_val = saved_inv_variance[i];
        mean_val = saved_mean[i];
      }
    } else {
      BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
      BatchNormParamType<T> x_square_sum =
          static_cast<BatchNormParamType<T>>(0);

      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
        const int index = layout == framework::DataLayout::kNCHW
                              ? (j / HxW * C + i) * HxW + j % HxW
                              : j * outer_size + i;
        BatchNormParamType<T> x_i =
            static_cast<BatchNormParamType<T>>(x[index]);
        x_sum += x_i;
        x_square_sum += x_i * x_i;
      }
      x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
      x_square_sum =
          BlockReduce(variance_storage).Reduce(x_square_sum, cub::Sum());
      if (threadIdx.x == 0) {
        mean_val = x_sum / inner_size;
        inv_var_val =
            1 / sqrt(x_square_sum / inner_size - mean_val * mean_val + epsilon);
      }
    }
    // Publish mean_val / inv_var_val to all threads.
    __syncthreads();

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      BatchNormParamType<T> dy_i =
          static_cast<BatchNormParamType<T>>(dy[index]);
      ds_sum +=
          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
      db_sum += dy_i;
    }
    ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
    db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
    if (threadIdx.x == 0) {
      dscale_val = ds_sum * inv_var_val;
      dbias_val = db_sum;
      dscale[i] = dscale_val;
      dbias[i] = dbias_val;
    }
    // Publish dscale_val / dbias_val to all threads.
    __syncthreads();

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      dx[index] = scale[i] * inv_var_val *
                  (static_cast<BatchNormParamType<T>>(dy[index]) -
                   dbias_val / static_cast<BatchNormParamType<T>>(inner_size) -
                   (static_cast<BatchNormParamType<T>>(x[index]) - mean_val) *
                       inv_var_val * dscale_val / inner_size);
    }
    // All threads must finish reading the shared per-channel values above
    // before thread 0 overwrites inv_var_val / mean_val for the next channel
    // (in the saved-statistics path no other barrier separates the two).
    __syncthreads();
  }
}

// Batch-norm input gradient (dx only) for training-mode statistics.
//
// Each block processes channels i = blockIdx.x, blockIdx.x + gridDim.x, ...
// (< C); its threads stride over the N * HxW elements of that channel. The
// kernel is launched with BlockDim threads per block, since BlockDim sizes the
// cub::BlockReduce used for the per-channel sums.
//   mean:     per-channel mean.
//   variance: per-channel value used directly as the inverse standard
//             deviation (inv_var_i below), not as a raw variance.
// With M = N * HxW:
//   dx = (dy - sum(dy) / M - (x - mean) * sum(dy * (x - mean)) * inv_std^2 / M)
//        * scale * inv_std
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
    const T *dy, const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *mean, const T *x,
    const BatchNormParamType<T> *variance, const int C, const int N,
    const int HxW, T *dx) {
  const int outer_size = C;
  const int inner_size = N * HxW;
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  // Block-wide sums broadcast from thread 0.
  __shared__ BatchNormParamType<T> dy_sum_val;
  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    BatchNormParamType<T> inv_var_i = variance[i];
    BatchNormParamType<T> mean_i = mean[i];
    BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
    BatchNormParamType<T> dy_x_sub_mean_sum =
        static_cast<BatchNormParamType<T>>(0);
    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      BatchNormParamType<T> dy_i =
          static_cast<BatchNormParamType<T>>(dy[index]);
      dy_sum += dy_i;
      dy_x_sub_mean_sum +=
          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_i);
    }

    dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
    dy_x_sub_mean_sum = BlockReduce(dy_x_sub_mean_storage)
                            .Reduce(dy_x_sub_mean_sum, cub::Sum());

    if (threadIdx.x == 0) {
      dy_sum_val = dy_sum;
      dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
    }
    __syncthreads();
    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      dx[index] =
          (static_cast<BatchNormParamType<T>>(dy[index]) -
           dy_sum_val / static_cast<BatchNormParamType<T>>(inner_size) -
           (static_cast<BatchNormParamType<T>>(x[index]) - mean_i) *
               dy_x_sub_mean_sum_val * inv_var_i * inv_var_i / inner_size) *
          scale[i] * inv_var_i;
    }
    // Explicit barrier so every thread is done reading dy_sum_val and
    // dy_x_sub_mean_sum_val before thread 0 rewrites them for the next
    // channel, instead of relying on barriers internal to cub::BlockReduce.
    __syncthreads();
  }
}

Q
Qiao Longfei 已提交
812
template <typename T>
class BatchNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Computes the gradients of batch_norm on GPU.
  //
  // Inputs: Y@GRAD, Scale, Bias, SavedMean / SavedVariance (training-mode
  // statistics; SavedVariance is handed to cuDNN as savedInvVariance, i.e. it
  // holds an inverse standard deviation), Mean / Variance (running statistics,
  // used when use_global_stats or is_test is set), optional ReserveSpace, and
  // X — or Y instead of X in inplace mode.
  // Outputs (each optional): X@GRAD, Scale@GRAD, Bias@GRAD. Scale@GRAD and
  // Bias@GRAD are only produced when both are requested.
  // Raises InvalidArgument for a non-GPU place, bad input ranks, a mismatched
  // Scale shape, or inconsistent inplace aliasing of X@GRAD / Y@GRAD.
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    bool use_global_stats = ctx.Attr<bool>("use_global_stats");

    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");

    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // batch_norm with inplace=false takes X as grad input, which matches the
    // cuDNN batch_norm backward calculation. batch_norm with inplace=true only
    // takes Y as input, and X has to be recovered by inverting batch_norm on Y.
    const Tensor *x;
    bool is_inplace;
    if (ctx.HasInput("Y")) {
      x = ctx.Input<Tensor>("Y");
      is_inplace = true;
      if (d_x) {
        PADDLE_ENFORCE_EQ(d_x, d_y,
                          platform::errors::InvalidArgument(
                              "X@GRAD and Y@GRAD not inplace in inplace mode"));
      }
    } else {
      x = ctx.Input<Tensor>("X");
      is_inplace = false;
      if (d_x) {
        PADDLE_ENFORCE_NE(
            d_x, d_y, platform::errors::InvalidArgument(
                          "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
      }
    }

    const bool is_test = ctx.Attr<bool>("is_test");
    use_global_stats = is_test || use_global_stats;

    const auto &x_dims = x->dims();

    PADDLE_ENFORCE_EQ(
        x_dims.size() >= 2 && x_dims.size() <= 5, true,
        platform::errors::InvalidArgument(
            "The size of input's dimensions should be between 2 and 5."
            "But received: the size of input's dimensions is [%d],"
            "the dimensions of input is [%s]",
            x_dims.size(), x_dims));
    int N, C, H, W, D;
    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
    // Number of elements in one (sample, channel) plane.
    const int spatial_size = H * W * D;

    // init output
    if (d_x) {
      d_x->mutable_data<T>(ctx.GetPlace());
    }

    if (d_scale && d_bias) {
      d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    }
    PADDLE_ENFORCE_EQ(
        scale->dims().size(), 1UL,
        platform::errors::InvalidArgument(
            "The size of scale's dimensions must equal to 1. But received: "
            "the size of scale's dimensions is [%d], the dimensions of scale "
            "is [%s].",
            scale->dims().size(), scale->dims()));
    PADDLE_ENFORCE_EQ(
        scale->dims()[0], C,
        platform::errors::InvalidArgument(
            "The first dimension of scale must equal to Channels[%d]. But "
            "received: the first dimension of scale is [%d]",
            C, scale->dims()[0]));

    auto dtype = platform::CudnnDataType<T>::type;
    const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
#ifdef PADDLE_WITH_HIP
    auto compute_format = data_layout == DataLayout::kNHWC ? DataLayout::kNHWC
                                                           : DataLayout::kNCHW;

// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW;
#else
    const bool fast_nhwc_batch_norm =
        dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
        reserve_space != nullptr;
    auto compute_format =
        fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
            ? DataLayout::kNHWC
            : DataLayout::kNCHW;
#endif

    // transformed_* hold the tensors in compute_format: either transposed
    // NHWC -> NCHW copies, or views sharing the original buffers.
    Tensor transformed_x(x->type());
    Tensor transformed_d_y(d_y->type());
    Tensor transformed_d_x;
    if (data_layout == DataLayout::kNHWC &&
        compute_format == DataLayout::kNCHW) {
      VLOG(3) << "Transform input tensor from NHWC to NCHW.";
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                           &transformed_x);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                          &transformed_x);
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
                                                           &transformed_d_y);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
                                                          &transformed_d_y);
      if (d_x) {
        ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
                                                             &transformed_d_x);
      }
    } else {
      transformed_x.ShareDataWith(*x);
      transformed_d_y.ShareDataWith(*d_y);
      if (d_x) {
        transformed_d_x.ShareDataWith(*d_x);
      }
    }

    std::vector<int> dims;
    std::vector<int> strides;
    if (compute_format == DataLayout::kNCHW) {
      dims = {N, C, H, W, D};
      strides = {C * H * W * D, H * W * D, W * D, D, 1};
    } else {
      dims = {N, C, H, W, D};
      strides = {H * W * C * D, 1, W * D * C, D * C, C};
    }

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    const int num = transformed_x.numel();
    // Test the compiler-defined __HIPCC__ macro, the same guard this file uses
    // for its cub/hipcub include.
#ifdef __HIPCC__
    const int block = 256;
#else
    const int block = 512;
#endif
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    // grid1: one thread per element; grid2: at most one block per channel.
    int grid1 = (num + block - 1) / block;
    int grid2 = std::min(C, max_blocks);
    auto stream = dev_ctx.stream();
    InplaceHelper<T> inplace_functor;

    if (!use_global_stats) {
      if ((N * H * W * D) == 1) {
        if (d_x) {
          framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
        }
        // Scale/bias gradients are only allocated (above) when both were
        // requested, so only zero them in that case.
        if (d_scale && d_bias) {
          math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
              functor;
          functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
          functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
        }
        return;
      }

// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// miopenTensorDescriptor_t data_desc_;
// miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_;

// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else
      cudnnTensorDescriptor_t data_desc_;
      cudnnTensorDescriptor_t bn_param_desc_;
      cudnnBatchNormMode_t mode_;

      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
#endif
      if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
        LOG(ERROR) << "Provided epsilon is smaller than "
                   << "CUDNN_BN_MIN_EPSILON. Setting it to "
                   << "CUDNN_BN_MIN_EPSILON instead.";
      }
      epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
#elif CUDNN_VERSION_MIN(7, 0, 1)
      if (FLAGS_cudnn_batchnorm_spatial_persistent) {
        mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
      } else if (H == 1 && W == 1) {
        mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
      } else {
        mode_ = CUDNN_BATCHNORM_SPATIAL;
      }
#else
      if (H == 1 && W == 1) {
        mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
      } else {
        mode_ = CUDNN_BATCHNORM_SPATIAL;
      }
#endif  // CUDNN_VERSION_MIN(7, 0, 1)

#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
//     data_desc_, CudnnDataType<T>::type,
//     x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
//     const_cast<int *>(strides.data())));
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
//                                                       data_desc_, mode_));
#else
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
          data_desc_, CudnnDataType<T>::type,
          x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
                                                           data_desc_, mode_));
#endif

      const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
      const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
      const auto *saved_mean_data =
          saved_mean->template data<BatchNormParamType<T>>();
      // Inverse standard deviation per channel (see savedInvVariance below).
      const auto *saved_var_data =
          saved_var->template data<BatchNormParamType<T>>();

      if (is_inplace) {
        inplace_functor(compute_format, transformed_x.data<T>(),
                        scale->template data<BatchNormParamType<T>>(),
                        bias->template data<BatchNormParamType<T>>(),
                        saved_mean_data, saved_var_data, epsilon, C,
                        spatial_size, num, transformed_x.data<T>(), grid2,
                        block, stream);
      }

      if (d_x && d_scale && d_bias) {
        // This branch calls cuDNN (or, on HIP, the BNBackward kernel).
        bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
        // The Ex API consumes the reserve space produced by the forward pass;
        // without one, fall back to cudnnBatchNormalizationBackward below.
        if (reserve_space != nullptr) {
          called = true;
          size_t workspace_size = 0;
          void *workspace_ptr = nullptr;
          Tensor workspace_tensor;
          auto reserve_space_size = reserve_space->memory_size();
          // --------------- cudnn batchnorm workspace ---------------
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::
                  cudnnGetBatchNormalizationBackwardExWorkspaceSize(
                      /*handle=*/dev_ctx.cudnn_handle(),
                      /*mode=*/mode_,
                      /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                      /*xDesc=*/data_desc_,
                      /*yDesc=*/data_desc_,
                      /*dyDesc=*/data_desc_,
                      /*dzDesc=*/nullptr,
                      /*dxDesc=*/data_desc_,
                      /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
                      /*activationDesc=*/nullptr,
                      /*sizeInBytes=*/&workspace_size));

          workspace_ptr = workspace_tensor.mutable_data(
              ctx.GetPlace(), transformed_x.type(), workspace_size);

          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationBackwardEx(
                  /*handle=*/dev_ctx.cudnn_handle(),
                  /*mode=*/mode_,
                  /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                  /*alphaDataDiff=*/CudnnDataType<T>::kOne(),
                  /*betaDataDiff=*/CudnnDataType<T>::kZero(),
                  /*alphaParamDiff=*/CudnnDataType<T>::kOne(),
                  /*betaParamDiff=*/CudnnDataType<T>::kZero(),
                  /*xDesc=*/data_desc_,
                  /*xData=*/transformed_x.template data<T>(),
                  /*yDesc=*/nullptr,
                  /*yData=*/nullptr,
                  /*dyDesc=*/data_desc_,
                  /*dyData=*/transformed_d_y.template data<T>(),
                  /*dzDesc=*/nullptr,
                  /*dzData=*/nullptr,
                  /*dxDesc=*/data_desc_,
                  /*dxData=*/transformed_d_x.template mutable_data<T>(
                      ctx.GetPlace()),
                  /*dBnScaleBiasDesc=*/bn_param_desc_,
                  /*bnScaleData=*/scale->template data<BatchNormParamType<T>>(),
                  /*bnBiasData=*/nullptr,
                  /*dBnScaleData=*/d_scale
                      ->template mutable_data<BatchNormParamType<T>>(
                          ctx.GetPlace()),
                  /*dBnBiasData=*/d_bias
                      ->template mutable_data<BatchNormParamType<T>>(
                          ctx.GetPlace()),
                  /*epsilon=*/epsilon,
                  /*savedMean=*/saved_mean_data,
                  /*savedInvVariance=*/saved_var_data,
                  /*activationDesc=*/nullptr,
                  /*workspace=*/workspace_ptr,
                  /*workSpaceSizeInBytes=*/workspace_size,
                  /*reserveSpace=*/const_cast<T *>(
                      reserve_space->template data<T>()),
                  /*reserveSpaceSizeInBytes=*/reserve_space_size));
        }
#endif  // CUDNN_VERSION_MIN(7, 4, 1)
        if (!called) {
#ifdef PADDLE_WITH_HIP
          if (compute_format == DataLayout::kNCHW) {
            BNBackward<
                T, block,
                DataLayout::kNCHW><<<grid2, block, 0, dev_ctx.stream()>>>(
                transformed_d_y.template data<T>(),
                transformed_x.template data<T>(),
                scale->template data<BatchNormParamType<T>>(), saved_mean_data,
                saved_var_data, C, N, spatial_size, epsilon,
                transformed_d_x.template data<T>(),
                d_scale->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace()),
                d_bias->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace()));
          } else {
            BNBackward<
                T, block,
                DataLayout::kNHWC><<<grid2, block, 0, dev_ctx.stream()>>>(
                transformed_d_y.template data<T>(),
                transformed_x.template data<T>(),
                scale->template data<BatchNormParamType<T>>(), saved_mean_data,
                saved_var_data, C, N, spatial_size, epsilon,
                transformed_d_x.template data<T>(),
                d_scale->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace()),
                d_bias->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace()));
          }

// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenBatchNormalizationBackward(
//         dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
//         CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
//         CudnnDataType<T>::kZero(), data_desc_,
//         transformed_x.template data<T>(), data_desc_,
//         transformed_d_y.template data<T>(), data_desc_,
//         transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
//         bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
//         d_scale->template mutable_data<BatchNormParamType<T>>(
//             ctx.GetPlace()),
//         d_bias->template mutable_data<BatchNormParamType<T>>(
//             ctx.GetPlace()),
//         epsilon, saved_mean_data, saved_var_data));
#else
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationBackward(
                  dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), data_desc_,
                  transformed_x.template data<T>(), data_desc_,
                  transformed_d_y.template data<T>(), data_desc_,
                  transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
                  d_scale->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  d_bias->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  epsilon, saved_mean_data, saved_var_data));
#endif
        }
      } else {
        // This branch calls CUDA kernels. Like the cuDNN branch, it works on
        // the compute_format tensors, so the kernel layout matches the data
        // (NHWC input may have been transposed to NCHW above) and the X
        // restored in inplace mode is the one that is read.
        const T *d_y_data = transformed_d_y.template data<T>();
        const T *x_data = transformed_x.template data<T>();
        const auto *scale_data = scale->template data<BatchNormParamType<T>>();
        if (d_x) {
          T *d_x_data =
              transformed_d_x.template mutable_data<T>(ctx.GetPlace());
          if (compute_format == DataLayout::kNCHW) {
            BNBackwardData<T, block, DataLayout::kNCHW><<<
                grid2, block, 0, stream>>>(
                d_y_data, scale_data, saved_mean_data, x_data, saved_var_data,
                C, N, spatial_size, d_x_data);
          } else {
            BNBackwardData<T, block, DataLayout::kNHWC><<<
                grid2, block, 0, stream>>>(
                d_y_data, scale_data, saved_mean_data, x_data, saved_var_data,
                C, N, spatial_size, d_x_data);
          }
        }
        if (d_scale && d_bias) {
          // d_x is null here (otherwise the cuDNN branch would have run).
          // saved_var_data is an inverse standard deviation, which BNBackward
          // consumes directly; KeBNBackwardScaleBias expects a raw variance
          // and would apply 1 / sqrt(. + epsilon) a second time. BNBackward
          // always writes dx, so it gets a throw-away scratch buffer.
          Tensor scratch_d_x;
          scratch_d_x.Resize(transformed_x.dims());
          T *scratch_d_x_data = scratch_d_x.mutable_data<T>(ctx.GetPlace());
          auto *d_scale_data =
              d_scale->template mutable_data<BatchNormParamType<T>>(
                  ctx.GetPlace());
          auto *d_bias_data =
              d_bias->template mutable_data<BatchNormParamType<T>>(
                  ctx.GetPlace());
          if (compute_format == DataLayout::kNCHW) {
            BNBackward<T, block,
                       DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
                d_y_data, x_data, scale_data, saved_mean_data, saved_var_data,
                C, N, spatial_size, epsilon, scratch_d_x_data, d_scale_data,
                d_bias_data);
          } else {
            BNBackward<T, block,
                       DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
                d_y_data, x_data, scale_data, saved_mean_data, saved_var_data,
                C, N, spatial_size, epsilon, scratch_d_x_data, d_scale_data,
                d_bias_data);
          }
        }
      }

      // Both branches above wrote dx in compute_format.
      if (d_x && data_layout == DataLayout::kNHWC &&
          compute_format == DataLayout::kNCHW) {
        VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
        TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
            ctx, &transformed_d_x, d_x);
      }

#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean when exit.
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
//     platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else
      // clean when exit.
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
#endif
    } else {
      const auto *running_mean = ctx.Input<Tensor>("Mean");
      const auto *running_var = ctx.Input<Tensor>("Variance");

      const auto *running_mean_data =
          running_mean->template data<BatchNormParamType<T>>();
      const auto *running_var_data =
          running_var->template data<BatchNormParamType<T>>();

      if (is_inplace) {
        // NOTE(review): KeBNRestoreData divides by its `variance` argument as
        // if it were an inverse standard deviation (as SavedVariance is in the
        // branch above), but here it receives the running variance — confirm
        // the restored X is correct.
        auto px = *x;
        inplace_functor(data_layout, px.mutable_data<T>(ctx.GetPlace()),
                        scale->template data<BatchNormParamType<T>>(),
                        bias->template data<BatchNormParamType<T>>(),
                        running_mean_data, running_var_data, epsilon, C,
                        spatial_size, num, x->data<T>(), grid2, block, stream);
      }

      // These kernels read d_y / x and write d_x directly, i.e. in
      // data_layout, so dispatch on data_layout rather than compute_format.
      // The scale/bias kernel is launched first: in inplace mode d_x aliases
      // d_y, and the data kernel overwrites dy with dx.
      const T *d_y_data = d_y->data<T>();
      const T *x_data = x->data<T>();
      const auto *scale_data = scale->data<BatchNormParamType<T>>();
      if (data_layout == DataLayout::kNCHW) {
        if (d_scale && d_bias) {
          KeBNBackwardScaleBias<
              T, block,
              framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
              d_y_data, x_data, running_mean_data, running_var_data, epsilon,
              N, C, spatial_size, d_scale->data<BatchNormParamType<T>>(),
              d_bias->data<BatchNormParamType<T>>());
        }
        if (d_x) {
          // HxW must be the full plane size (H * W * D) for NCHW indexing.
          KeBNBackwardData<
              T, framework::DataLayout::kNCHW><<<grid1, block, 0, stream>>>(
              d_y_data, scale_data, running_var_data, epsilon, C,
              spatial_size, num, d_x->data<T>());
        }
      } else {
        if (d_scale && d_bias) {
          KeBNBackwardScaleBias<
              T, block,
              framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
              d_y_data, x_data, running_mean_data, running_var_data, epsilon,
              N, C, spatial_size, d_scale->data<BatchNormParamType<T>>(),
              d_bias->data<BatchNormParamType<T>>());
        }
        if (d_x) {
          KeBNBackwardData<
              T, framework::DataLayout::kNHWC><<<grid1, block, 0, stream>>>(
              d_y_data, scale_data, running_var_data, epsilon, C,
              spatial_size, num, d_x->data<T>());
        }
      }
    }
  }
};

1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336
template <typename T>
class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // CUDA kernel body for `batch_norm_grad_grad`.
  //
  // Collects the op's inputs, attributes and outputs from `ctx` and forwards
  // them unchanged to NormDoubleGradFunctor. Before doing so it rejects
  // `is_test = true` with an InvalidArgument error. It also converts the
  // string `data_layout` attribute into a DataLayout value.
  // NOTE(review): which of the optional second-order inputs/outputs may be
  // null is handled inside NormDoubleGradFunctor -- confirm there.
  void Compute(const framework::ExecutionContext &ctx) const override {
    // Forward input, scale, upstream gradient and saved batch statistics.
    const Tensor *x = ctx.Input<Tensor>("X");
    const Tensor *scale = ctx.Input<Tensor>("Scale");
    const Tensor *dy = ctx.Input<Tensor>("DY");
    const Tensor *saved_mean = ctx.Input<Tensor>("SavedMean");
    const Tensor *saved_variance = ctx.Input<Tensor>("SavedVariance");

    // `epsilon` is stored as a float attribute; the functor takes a double.
    const double eps = ctx.Attr<float>("epsilon");
    const bool global_stats = ctx.Attr<bool>("use_global_stats");
    const bool in_test_mode = ctx.Attr<bool>("is_test");

    PADDLE_ENFORCE_EQ(
        in_test_mode, false,
        platform::errors::InvalidArgument(
            "`is_test = True` CANNOT be used in train program. If "
            "you want to use global status in pre_train model, "
            "please set `use_global_stats = True`"));

    const DataLayout layout =
        framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));

    // Second-order gradient inputs.
    const Tensor *ddx = ctx.Input<Tensor>("DDX");
    const Tensor *dd_scale = ctx.Input<Tensor>("DDScale");
    const Tensor *dd_bias = ctx.Input<Tensor>("DDBias");

    // Output tensors handed to the functor.
    Tensor *dx = ctx.Output<Tensor>("DX");
    Tensor *d_scale = ctx.Output<Tensor>("DScale");
    Tensor *ddy = ctx.Output<Tensor>("DDY");

    NormDoubleGradFunctor<platform::CUDADeviceContext, T>(
        ctx, layout, x, scale, dy, saved_mean, saved_variance, eps,
        global_stats, ddx, dd_scale, dd_bias, dx, d_scale, ddy);
  }
};

Q
Qiao Longfei 已提交
1337 1338 1339 1340
}  // namespace operators
}  // namespace paddle

// Short alias for paddle::operators, used by the kernel registrations below.
namespace ops = paddle::operators;
K
Kexin Zhao 已提交
1341
// Short alias for paddle::platform (CUDADeviceContext, float16) used by the
// kernel registrations below.
namespace plat = paddle::platform;
1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353
#ifdef PADDLE_WITH_HIP
// HIP/ROCm build: MIOpen does not support double, so batch_norm and
// batch_norm_grad are registered for float and float16 only, and
// batch_norm_grad_grad for float only.
REGISTER_OP_CUDA_KERNEL(
    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    batch_norm_grad_grad,
    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>);
#else
Q
QI JUN 已提交
1354
// Non-HIP (CUDA) build: batch_norm and batch_norm_grad are registered for
// float, double and float16; batch_norm_grad_grad for float and double.
REGISTER_OP_CUDA_KERNEL(
K
Kexin Zhao 已提交
1355
    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
D
dzhwinter 已提交
1356
    ops::BatchNormKernel<plat::CUDADeviceContext, double>,
K
Kexin Zhao 已提交
1357
    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
Q
QI JUN 已提交
1358
REGISTER_OP_CUDA_KERNEL(
D
dzhwinter 已提交
1359
    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
C
chengduo 已提交
1360 1361
    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
1362 1363 1364 1365
REGISTER_OP_CUDA_KERNEL(
    batch_norm_grad_grad,
    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);
1366
#endif  // PADDLE_WITH_HIP