/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
Q
Qiao Longfei 已提交
16
#include <cfloat>
17 18 19
#include <string>
#include <vector>
#include "cub/cub.cuh"
S
Siddharth Goyal 已提交
20
#include "paddle/fluid/framework/data_layout.h"
21
#include "paddle/fluid/operators/batch_norm_op.h"
Y
Yi Wang 已提交
22
#include "paddle/fluid/operators/math/math_function.h"
23
#include "paddle/fluid/operators/norm_utils.cu.h"
Y
Yi Wang 已提交
24
#include "paddle/fluid/platform/cudnn_helper.h"
K
Kexin Zhao 已提交
25
#include "paddle/fluid/platform/float16.h"
Q
Qiao Longfei 已提交
26

27
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
W
Wu Yi 已提交
28

Q
Qiao Longfei 已提交
29 30 31 32
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;

// cuDNN type traits for element type T (the cudnnDataType_t `type`, the
// kOne()/kZero() scaling constants and the BatchNormParamType alias).
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

// Element type of the batch-norm parameter and statistics tensors (Scale,
// Bias, Mean, Variance, SavedMean, SavedVariance) for input element type T.
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
Q
Qiao Longfei 已提交
38 39

template <typename T>
class BatchNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Runs the batch_norm forward pass on the GPU through cuDNN.
  //
  // Two paths:
  //  * (is_test && !trainable_statistics) or use_global_stats: normalize X
  //    with the provided Mean / Variance via
  //    cudnnBatchNormalizationForwardInference.
  //  * otherwise (training): normalize X with batch statistics, pass
  //    1 - momentum to cuDNN as the exponential-average factor for
  //    MeanOut / VarianceOut, and fill SavedMean / SavedVariance (plus
  //    ReserveSpace on the cuDNN >= 7.4.1 NHWC path). If the optional
  //    MomentumTensor input is present, its first value replaces the
  //    `momentum` attribute.
  //
  // NHWC inputs of rank > 2 are transposed to NCHW for cuDNN (and the result
  // transposed back into Y) unless the fast NHWC path is enabled: test mode,
  // or fp16 with FLAGS_cudnn_batchnorm_spatial_persistent.
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    float momentum = ctx.Attr<float>("momentum");
    const bool is_test = ctx.Attr<bool>("is_test");
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
    const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);

    // Inference uses the provided statistics unless they are trainable.
    bool test_mode = is_test && (!trainable_stats);

    // Get the size for each dimension.
    // NCHW [batch_size, in_channels, in_height, in_width]
    const auto *x = ctx.Input<Tensor>("X");
    const auto &x_dims = x->dims();
    PADDLE_ENFORCE_EQ(
        x_dims.size() >= 2 && x_dims.size() <= 5, true,
        platform::errors::InvalidArgument(
            "The size of input's dimensions should be between 2 and 5"
            "But received: the size of input's dimensions is [%d]",
            x_dims.size()));

    auto *y = ctx.Output<Tensor>("Y");
    y->mutable_data<T>(ctx.GetPlace());

    int N, C, H, W, D;
    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);

    auto dtype = platform::CudnnDataType<T>::type;
    const bool fast_nhwc_batch_norm =
        test_mode ||
        (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);

    auto compute_format =
        fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
            ? DataLayout::kNHWC
            : DataLayout::kNCHW;

    // An NHWC input computed in NCHW is transposed on the way in, and the
    // NCHW result is transposed back into y on the way out.
    const bool transpose_to_nchw = data_layout == DataLayout::kNHWC &&
                                   compute_format == DataLayout::kNCHW &&
                                   x_dims.size() > 2;

    Tensor transformed_x(x->type());
    Tensor transformed_y(y->type());
    if (transpose_to_nchw) {
      VLOG(3) << "Transform input tensor from NHWC to NCHW.";
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                           &transformed_x);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                          &transformed_x);
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, y,
                                                           &transformed_y);
    } else {
      transformed_x.ShareDataWith(*x);
      transformed_y.ShareDataWith(*y);
    }

    // Set when y is filled directly instead of through transformed_y; the
    // final NCHW -> NHWC transpose must then be skipped, otherwise it would
    // overwrite y with the (never written) contents of transformed_y.
    bool y_written_directly = false;

    // ------------------- cudnn descriptors ---------------------
    cudnnTensorDescriptor_t data_desc_;
    cudnnTensorDescriptor_t bn_param_desc_;
    cudnnBatchNormMode_t mode_;

    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));

    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
      LOG(ERROR) << "Provided epsilon is smaller than "
                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
                 << "CUDNN_BN_MIN_EPSILON instead.";
    }
    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#if CUDNN_VERSION_MIN(7, 0, 0)
    if (FLAGS_cudnn_batchnorm_spatial_persistent) {
      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
    } else {
      mode_ = CUDNN_BATCHNORM_SPATIAL;
    }
#else
    mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif

    VLOG(3) << "Setting descriptors.";
    // The logical dims are always {N, C, H, W, D}; only the strides differ
    // between the channel-first and channel-last memory layouts.
    std::vector<int> dims = {N, C, H, W, D};
    std::vector<int> strides;
    if (compute_format == DataLayout::kNCHW) {
      strides = {C * H * W * D, H * W * D, W * D, D, 1};
    } else {
      strides = {H * W * D * C, 1, W * D * C, D * C, C};
    }
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
        data_desc_, CudnnDataType<T>::type,
        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
    // Note: PERSISTENT not implemented for inference
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDeriveBNTensorDescriptor(
            bn_param_desc_, data_desc_,
            test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));

    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    auto handle = dev_ctx.cudnn_handle();

    // Now, depending on whether we are running test or not, we have two paths.
    if (test_mode || use_global_stats) {
      // only when test we use input to do computation.
      const auto *est_mean = ctx.Input<Tensor>("Mean");
      const auto *est_var = ctx.Input<Tensor>("Variance");
      // Run inference mode.
      PADDLE_ENFORCE_EQ(
          est_mean->dims().size(), 1UL,
          platform::errors::InvalidArgument(
              "The size of mean's dimensions must equal to 1."
              "But received: the size of mean's dimensions mean is [%d],"
              "the dimensions of mean is [%s].",
              est_mean->dims().size(), est_mean->dims()));
      PADDLE_ENFORCE_EQ(
          est_var->dims().size(), 1UL,
          platform::errors::InvalidArgument(
              "The size of variance's dimensions must equal to 1."
              "But received: the size of variance's dimensions is [%d],"
              "the dimensions of variance is [%s].",
              est_var->dims().size(), est_var->dims()));
      PADDLE_ENFORCE_EQ(
          est_mean->dims()[0], C,
          platform::errors::InvalidArgument(
              "The first dimension of mean must equal to the number of "
              "Channels, which is [%d]. But received: the first dimension"
              "of mean is [%d], the dimensions of mean is [%s].",
              C, est_mean->dims()[0], est_mean->dims()));
      PADDLE_ENFORCE_EQ(
          est_var->dims()[0], C,
          platform::errors::InvalidArgument(
              "The first dimension of variance must equal to the number"
              "of Channels, which is [%d]. But received: the first dimension of"
              "variance is [%d], the dimensions of variance is [%s].",
              C, est_var->dims()[0], est_var->dims()));

      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnBatchNormalizationForwardInference(
              handle,
              // Note: PERSISTENT not implemented for inference
              CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
              CudnnDataType<T>::kZero(), data_desc_,
              transformed_x.template data<T>(), data_desc_,
              transformed_y.template mutable_data<T>(ctx.GetPlace()),
              bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
              bias->template data<BatchNormParamType<T>>(),
              est_mean->template data<BatchNormParamType<T>>(),
              est_var->template data<BatchNormParamType<T>>(), epsilon));
    } else {
      // if MomentumTensor is set, use MomentumTensor value, momentum
      // is only used in this training branch
      if (ctx.HasInput("MomentumTensor")) {
        const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
        Tensor mom_cpu;
        TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
        momentum = mom_cpu.data<float>()[0];
      }

      // Run training mode.
      // obtain running mean and running inv var, and see if we need to
      // initialize them.

      auto *mean_out = ctx.Output<Tensor>("MeanOut");
      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
      mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

      auto *saved_mean = ctx.Output<Tensor>("SavedMean");
      auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
      saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
          functor;
      functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
      functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));

      if ((N * H * W * D) == 1) {
        // Only 1 element in normalization dimension,
        // skip the batch norm calculation, let y = x.
        // y is written directly here (in the caller's layout), so the
        // transpose of transformed_y below must not run.
        framework::TensorCopy(*x, ctx.GetPlace(), y);
        y_written_directly = true;
      } else {
        double this_factor = 1. - momentum;

        bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
        if (compute_format == DataLayout::kNHWC) {
          called = true;
          size_t workspace_size = 0;
          size_t reserve_space_size = 0;
          void *reserve_space_ptr = nullptr;
          void *workspace_ptr = nullptr;
          Tensor workspace_tensor;
          // Create reserve space and workspace for batch norm.
          // Create tensor for each batchnorm op, it will be used in the
          // backward. Thus this tensor shouldn't be temp.
          auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
          PADDLE_ENFORCE_NOT_NULL(
              reserve_space,
              platform::errors::NotFound(
                  "The argument ReserveSpace of batch_norm op is not found."));

          // --------------- cudnn batchnorm workspace ---------------
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::
                  cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
                      /*handle=*/handle,
                      /*mode=*/mode_,
                      /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                      /*xDesc=*/data_desc_,
                      /*zDesc=*/nullptr,
                      /*yDesc=*/data_desc_,
                      /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
                      /*activationDesc=*/nullptr,
                      /*sizeInBytes=*/&workspace_size));

          // -------------- cudnn batchnorm reserve space --------------
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::
                  cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
                      /*handle=*/handle,
                      /*mode=*/mode_,
                      /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                      /*activationDesc=*/nullptr,
                      /*xDesc=*/data_desc_,
                      /*sizeInBytes=*/&reserve_space_size));

          reserve_space_ptr = reserve_space->mutable_data(
              ctx.GetPlace(), transformed_x.type(), reserve_space_size);
          workspace_ptr = workspace_tensor.mutable_data(
              ctx.GetPlace(), transformed_x.type(), workspace_size);
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
                  handle, mode_, CUDNN_BATCHNORM_OPS_BN,
                  CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
                  data_desc_, transformed_x.template data<T>(), nullptr,
                  nullptr, data_desc_, transformed_y.template data<T>(),
                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
                  bias->template data<BatchNormParamType<T>>(), this_factor,
                  mean_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  variance_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  epsilon,
                  saved_mean->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  saved_variance->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  nullptr, workspace_ptr, workspace_size, reserve_space_ptr,
                  reserve_space_size));
        }
#endif
        if (!called) {
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationForwardTraining(
                  handle, mode_, CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), data_desc_,
                  transformed_x.template data<T>(), data_desc_,
                  transformed_y.template mutable_data<T>(ctx.GetPlace()),
                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
                  bias->template data<BatchNormParamType<T>>(), this_factor,
                  mean_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  variance_out->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  epsilon,
                  saved_mean->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  saved_variance->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace())));
        }
      }
    }

    if (transpose_to_nchw && !y_written_directly) {
      VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
      TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
          ctx, &transformed_y, y);
    }
    // clean when exit.
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
  }
};

339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374
// Per-channel gradients of scale and bias:
//   dscale[c] = sum(dy * (x - mean[c])) / sqrt(variance[c] + epsilon)
//   dbias[c]  = sum(dy)
// where the sums run over the N * HxW elements of channel c.
//
// Blocks stride over channels by gridDim.x; threads of a block stride over
// the channel's elements and are combined with cub::BlockReduce, whose
// thread count is the template parameter BlockDim (so it should match
// blockDim.x). `layout` selects channel-first (kNCHW) or channel-last
// indexing of dy / x. Only thread 0 of each block writes dscale / dbias.
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
    const T *dy, const T *x, const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *variance, const double epsilon, const int N,
    const int C, const int HxW, BatchNormParamType<T> *dscale,
    BatchNormParamType<T> *dbias) {
  using ParamT = BatchNormParamType<T>;
  using BlockReduce = cub::BlockReduce<ParamT, BlockDim>;
  // Separate temp storage so the two reductions need no barrier in between.
  __shared__ typename BlockReduce::TempStorage ds_storage;
  __shared__ typename BlockReduce::TempStorage db_storage;

  const int per_channel = N * HxW;
  for (int c = blockIdx.x; c < C; c += gridDim.x) {
    const ParamT mu = mean[c];
    const ParamT inv_std = 1.0 / sqrt(variance[c] + epsilon);

    ParamT ds_partial = static_cast<ParamT>(0);
    ParamT db_partial = static_cast<ParamT>(0);
    for (int k = threadIdx.x; k < per_channel; k += blockDim.x) {
      const int offset = (layout == framework::DataLayout::kNCHW)
                             ? (k / HxW * C + c) * HxW + k % HxW
                             : k * C + c;
      const ParamT g = static_cast<ParamT>(dy[offset]);
      ds_partial += g * (static_cast<ParamT>(x[offset]) - mu);
      db_partial += g;
    }

    // The block-wide results are only valid in thread 0.
    const ParamT ds_total =
        BlockReduce(ds_storage).Reduce(ds_partial, cub::Sum());
    const ParamT db_total =
        BlockReduce(db_storage).Reduce(db_partial, cub::Sum());
    if (threadIdx.x == 0) {
      dscale[c] = ds_total * inv_std;
      dbias[c] = db_total;
    }
    // Required before the temp storage is reused by the next channel.
    __syncthreads();
  }
}

Q
qingqing01 已提交
375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390
// Element-wise data gradient with the statistics treated as constants:
//   dx[i] = dy[i] * scale[c] / sqrt(variance[c] + epsilon)
// where c is the channel of element i under `layout`.
// NOTE(review): presumably used for the use_global_stats backward path —
// confirm at the call site.
//
// Threads cover [0, num) with a stride of blockDim.x * gridDim.x, so any
// launch configuration covers all elements.
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy,
                                        const BatchNormParamType<T> *scale,
                                        const BatchNormParamType<T> *variance,
                                        const double epsilon, const int C,
                                        const int HxW, const int num, T *dx) {
  using ParamT = BatchNormParamType<T>;
  const int step = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
       idx += step) {
    int c;
    if (layout == framework::DataLayout::kNCHW) {
      c = idx / HxW % C;
    } else {
      c = idx % C;
    }
    const ParamT inv_std = 1.0 / sqrt(variance[c] + epsilon);
    dx[idx] =
        static_cast<T>(static_cast<ParamT>(dy[idx]) * scale[c] * inv_std);
  }
}

K
Kaipeng Deng 已提交
391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425
// Inverts the batch-norm affine transform element-wise to recover x from y:
//   x[i] = (y[i] - bias[c]) / scale[c] / variance[c] + mean[c]
// with c = (i / M) % C for kNCHW, or i % C otherwise (M is the per-channel
// spatial size).
// `variance` is divided directly, so it is expected to hold the inverse
// standard deviation (e.g. cuDNN's saved inverse variance). NOTE(review):
// confirm this at every call site. `epsilon` is accepted but unused here.
// x and y may alias (in-place restore): each element is read and then
// written by the same thread.
template <typename T>
static __global__ void KeBNRestoreData(const framework::DataLayout layout, T *x,
                                       const BatchNormParamType<T> *scale,
                                       const BatchNormParamType<T> *bias,
                                       const BatchNormParamType<T> *mean,
                                       const BatchNormParamType<T> *variance,
                                       double epsilon, int C, int M,
                                       const int num, const T *y) {
  const bool channel_first = layout == framework::DataLayout::kNCHW;
  const int step = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
       idx += step) {
    const int c = channel_first ? (idx / M) % C : idx % C;
    const auto y_val = static_cast<BatchNormParamType<T>>(y[idx]);
    const auto shifted = y_val - bias[c];
    const auto x_val = shifted / scale[c] / variance[c] + mean[c];
    x[idx] = static_cast<T>(x_val);
  }
}

// Functor that restores the batch-norm input in place from its output by
// launching KeBNRestoreData<T> on `stream` with `grid2` blocks of `block`
// threads. The arguments are forwarded unchanged to the kernel.
template <typename T>
class InplaceHelper {
 public:
  // Throws InvalidArgument unless x and y point to the same buffer.
  void operator()(const framework::DataLayout layout, T *x,
                  const BatchNormParamType<T> *scale,
                  const BatchNormParamType<T> *bias,
                  const BatchNormParamType<T> *mean,
                  const BatchNormParamType<T> *variance, double epsilon, int C,
                  int M, const int num, const T *y, int grid2, const int block,
                  const cudaStream_t &stream) {
    PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
                                "X and Y should be inplaced in inplace mode"));
    const dim3 grid_dim(grid2);
    const dim3 block_dim(block);
    KeBNRestoreData<T><<<grid_dim, block_dim, 0, stream>>>(
        layout, x, scale, bias, mean, variance, epsilon, C, M, num, y);
  }
};

L
lvmengsi 已提交
426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482
// Full batch-norm data gradient for each channel i (batch statistics):
//   dx = scale[i] * inv_std * (dy - sum(dy) / n
//                               - (x - mean[i]) * inv_std^2
//                                 * sum(dy * (x - mean[i])) / n)
// where n = N * HxW and inv_std = variance[i].
// NOTE(review): `variance` is used directly as the inverse standard deviation
// (it is squared and multiplied, never square-rooted), so callers presumably
// pass cuDNN's saved inverse variance — confirm at the call site.
//
// Blocks stride over channels by gridDim.x; BlockDim is the thread count
// given to cub::BlockReduce, so it should match blockDim.x. `layout` selects
// channel-first (kNCHW) or channel-last indexing of dy / x / dx.
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void BNBackwardData(const T *dy,
                                      const BatchNormParamType<T> *scale,
                                      const BatchNormParamType<T> *mean,
                                      const T *x,
                                      const BatchNormParamType<T> *variance,
                                      const int C, const int N, const int HxW,
                                      T *dx) {
  const int outer_size = C;
  const int inner_size = N * HxW;
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  // Block-wide sums broadcast from thread 0 to the whole block.
  __shared__ BatchNormParamType<T> dy_sum_val;
  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;

  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
    BatchNormParamType<T> inv_var_i = variance[i];
    BatchNormParamType<T> mean_i = mean[i];
    BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
    BatchNormParamType<T> dy_x_sub_mean_sum =
        static_cast<BatchNormParamType<T>>(0);
    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      BatchNormParamType<T> dy_i =
          static_cast<BatchNormParamType<T>>(dy[index]);
      dy_sum += dy_i;
      dy_x_sub_mean_sum +=
          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_i);
    }

    dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
    dy_x_sub_mean_sum = BlockReduce(dy_x_sub_mean_storage)
                            .Reduce(dy_x_sub_mean_sum, cub::Sum());

    if (threadIdx.x == 0) {
      dy_sum_val = dy_sum;
      dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
    }
    // Publish the sums to every thread; also serves as the barrier required
    // before the cub temp storage is reused.
    __syncthreads();

    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
      const int index = layout == framework::DataLayout::kNCHW
                            ? (j / HxW * C + i) * HxW + j % HxW
                            : j * outer_size + i;
      dx[index] =
          (static_cast<BatchNormParamType<T>>(dy[index]) -
           dy_sum_val / static_cast<BatchNormParamType<T>>(inner_size) -
           (static_cast<BatchNormParamType<T>>(x[index]) - mean_i) *
               dy_x_sub_mean_sum_val * inv_var_i * inv_var_i / inner_size) *
          scale[i] * inv_var_i;
    }
    // All threads must finish reading dy_sum_val / dy_x_sub_mean_sum_val
    // before thread 0 overwrites them for the next channel. Without this,
    // ordering relied only on the barrier inside cub::BlockReduce.
    __syncthreads();
  }
}

Q
Qiao Longfei 已提交
483
template <typename T>
Q
QI JUN 已提交
484
class BatchNormGradKernel<platform::CUDADeviceContext, T>
Q
Qiao Longfei 已提交
485 486 487
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
K
Kaipeng Deng 已提交
488 489 490
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
Q
Qiao Longfei 已提交
491
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
Q
QI JUN 已提交
492
    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
493 494
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");

Q
QI JUN 已提交
495 496
    const DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
Q
Qiao Longfei 已提交
497 498
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto *scale = ctx.Input<Tensor>("Scale");
K
Kaipeng Deng 已提交
499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524
    const auto *bias = ctx.Input<Tensor>("Bias");

    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // batch_norm with inplace as false will take X as grad input, which
    // is same as cuDNN batch_norm backward calculation, batch_norm
    // with inplace as true only take Y as input and X should be calculate
    // by inverse operation of batch_norm on Y
    const Tensor *x;
    bool is_inplace;
    if (ctx.HasInput("Y")) {
      x = ctx.Input<Tensor>("Y");
      is_inplace = true;
      PADDLE_ENFORCE_EQ(d_x, d_y,
                        platform::errors::InvalidArgument(
                            "X@GRAD and Y@GRAD not inplace in inplace mode"));
    } else {
      x = ctx.Input<Tensor>("X");
      is_inplace = false;
      PADDLE_ENFORCE_NE(d_x, d_y,
                        platform::errors::InvalidArgument(
                            "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
    }

525 526 527 528 529 530 531
    const bool is_test = ctx.Attr<bool>("is_test");
    PADDLE_ENFORCE_EQ(
        is_test, false,
        platform::errors::InvalidArgument(
            "`is_test = True` CANNOT be used in train program. If "
            "you want to use global status in pre_train model, "
            "please set `use_global_stats = True`"));
Q
Qiao Longfei 已提交
532 533 534

    const auto &x_dims = x->dims();

C
ceci3 已提交
535 536 537 538 539 540 541
    PADDLE_ENFORCE_EQ(
        x_dims.size() >= 2 && x_dims.size() <= 5, true,
        platform::errors::InvalidArgument(
            "The size of input's dimensions should be between 2 and 5."
            "But received: the size of input's dimensions is [%d],"
            "the dimensions of input is [%s]",
            x_dims.size(), x_dims));
Q
Qiao Longfei 已提交
542
    int N, C, H, W, D;
Q
QI JUN 已提交
543
    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
Q
Qiao Longfei 已提交
544

545 546
    // init output
    d_x->mutable_data<T>(ctx.GetPlace());
K
Kaipeng Deng 已提交
547

548 549 550
    if (d_scale && d_bias) {
      d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
      d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
551
    }
C
ceci3 已提交
552 553 554 555 556 557 558 559 560 561 562 563 564
    PADDLE_ENFORCE_EQ(
        scale->dims().size(), 1UL,
        platform::errors::InvalidArgument(
            "The size of scale's dimensions must equal to 1. But received: "
            "the size of scale's dimensions is [%d], the dimensions of scale "
            "is [%s].",
            scale->dims().size(), scale->dims()));
    PADDLE_ENFORCE_EQ(
        scale->dims()[0], C,
        platform::errors::InvalidArgument(
            "The first dimension of scale must equal to Channels[%d]. But "
            "received: the first dimension of scale is [%d]",
            C, scale->dims()[0]));
Q
Qiao Longfei 已提交
565

566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597
    auto dtype = platform::CudnnDataType<T>::type;
    const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
    const bool fast_nhwc_batch_norm =
        dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
        reserve_space != nullptr;
    auto compute_format =
        fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
            ? DataLayout::kNHWC
            : DataLayout::kNCHW;

    Tensor transformed_x(x->type());
    Tensor transformed_d_y(d_y->type());
    Tensor transformed_d_x(d_x->type());
    if (data_layout == DataLayout::kNHWC &&
        compute_format == DataLayout::kNCHW) {
      VLOG(3) << "Transform input tensor from NHWC to NCHW.";
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                           &transformed_x);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
                                                          &transformed_x);
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
                                                           &transformed_d_y);
      TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
                                                          &transformed_d_y);
      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
                                                           &transformed_d_x);
    } else {
      transformed_x.ShareDataWith(*x);
      transformed_d_y.ShareDataWith(*d_y);
      transformed_d_x.ShareDataWith(*d_x);
    }

Z
zchen0211 已提交
598 599
    std::vector<int> dims;
    std::vector<int> strides;
600
    if (compute_format == DataLayout::kNCHW) {
Z
zchen0211 已提交
601 602 603 604 605 606
      dims = {N, C, H, W, D};
      strides = {C * H * W * D, H * W * D, W * D, D, 1};
    } else {
      dims = {N, C, H, W, D};
      strides = {H * W * C * D, 1, W * D * C, D * C, C};
    }
Q
Qiao Longfei 已提交
607

608
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
609
    const int num = transformed_x.numel();
L
lvmengsi 已提交
610 611 612 613 614
    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    int grid1 = (num + block - 1) / block;
    int grid2 = std::min(C, max_blocks);
K
Kaipeng Deng 已提交
615 616
    auto stream = dev_ctx.stream();
    InplaceHelper<T> inplace_functor;
L
lvmengsi 已提交
617

618 619 620 621 622 623 624 625 626 627 628 629 630 631 632
    if (!use_global_stats) {
      if ((N * H * W * D) == 1) {
        framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
        math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
            functor;
        functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
        functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
        return;
      }

      // ------------------- cudnn descriptors ---------------------
      cudnnTensorDescriptor_t data_desc_;
      cudnnTensorDescriptor_t bn_param_desc_;
      cudnnBatchNormMode_t mode_;

633
      PADDLE_ENFORCE_CUDA_SUCCESS(
634
          platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
635
      PADDLE_ENFORCE_CUDA_SUCCESS(
636 637 638 639 640 641 642
          platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
      if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
        LOG(ERROR) << "Provided epsilon is smaller than "
                   << "CUDNN_BN_MIN_EPSILON. Setting it to "
                   << "CUDNN_BN_MIN_EPSILON instead.";
      }
      epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
643
#if CUDNN_VERSION_MIN(7, 0, 0)
W
Wu Yi 已提交
644 645 646 647 648
      if (FLAGS_cudnn_batchnorm_spatial_persistent) {
        mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
      } else {
        mode_ = CUDNN_BATCHNORM_SPATIAL;
      }
649
#else
650
      mode_ = CUDNN_BATCHNORM_SPATIAL;
651
#endif
652

653
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
654 655
          data_desc_, CudnnDataType<T>::type,
          x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
656 657 658
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
                                                           data_desc_, mode_));
659 660 661

      const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
      const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
L
lvmengsi 已提交
662
      const auto *saved_mean_data =
663
          saved_mean->template data<BatchNormParamType<T>>();
L
lvmengsi 已提交
664
      const auto *saved_var_data =
665 666
          saved_var->template data<BatchNormParamType<T>>();

K
Kaipeng Deng 已提交
667 668 669 670 671 672 673 674
      if (is_inplace) {
        inplace_functor(compute_format, transformed_x.data<T>(),
                        scale->template data<BatchNormParamType<T>>(),
                        bias->template data<BatchNormParamType<T>>(),
                        saved_mean_data, saved_var_data, epsilon, C, H * W * D,
                        num, transformed_x.data<T>(), grid2, block, stream);
      }

L
lvmengsi 已提交
675
      if (d_scale && d_bias) {
676 677 678 679 680 681 682 683 684
        bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
        if (compute_format == DataLayout::kNHWC) {
          called = true;
          size_t workspace_size = 0;
          void *workspace_ptr = nullptr;
          Tensor workspace_tensor;
          auto reserve_space_size = reserve_space->memory_size();
          // --------------- cudnn batchnorm workspace ---------------
685 686 687 688 689 690 691 692 693 694 695 696 697 698
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::
                  cudnnGetBatchNormalizationBackwardExWorkspaceSize(
                      /*handle=*/dev_ctx.cudnn_handle(),
                      /*mode=*/mode_,
                      /*bnIps=*/CUDNN_BATCHNORM_OPS_BN,
                      /*xDesc=*/data_desc_,
                      /*yDesc=*/data_desc_,
                      /*dyDesc=*/data_desc_,
                      /*dzDesc=*/nullptr,
                      /*dxDesc=*/data_desc_,
                      /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
                      /*activationDesc=*/nullptr,
                      /*sizeInBytes=*/&workspace_size));
699 700 701 702

          workspace_ptr = workspace_tensor.mutable_data(
              ctx.GetPlace(), transformed_x.type(), workspace_size);

703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationBackwardEx(
                  /*handle=*/dev_ctx.cudnn_handle(),
                  /*mode=*/mode_,
                  /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                  /*alphaDataDiff=*/CudnnDataType<T>::kOne(),
                  /*betaDataDiff=*/CudnnDataType<T>::kZero(),
                  /*alphaParamDiff=*/CudnnDataType<T>::kOne(),
                  /*betaParamDiff=*/CudnnDataType<T>::kZero(),
                  /*xDesc=*/data_desc_,
                  /*xData=*/transformed_x.template data<T>(),
                  /*yDesc=*/nullptr,
                  /*yData=*/nullptr,
                  /*dyDesc=*/data_desc_,
                  /*dyData=*/transformed_d_y.template data<T>(),
                  /*dzDesc=*/nullptr,
                  /*dzData=*/nullptr,
                  /*dxDesc=*/data_desc_,
                  /*dxData=*/transformed_d_x.template mutable_data<T>(
722
                      ctx.GetPlace()),
723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740
                  /*dBnScaleBiasDesc=*/bn_param_desc_,
                  /*bnScaleData=*/scale->template data<BatchNormParamType<T>>(),
                  /*bnBiasData=*/nullptr,
                  /*dBnScaleData=*/d_scale
                      ->template mutable_data<BatchNormParamType<T>>(
                          ctx.GetPlace()),
                  /*dBnBiasData=*/d_bias
                      ->template mutable_data<BatchNormParamType<T>>(
                          ctx.GetPlace()),
                  /*epsilon=*/epsilon,
                  /*savedMean=*/saved_mean_data,
                  /*savedInvVariance=*/saved_var_data,
                  /*activationDesc=*/nullptr,
                  /*workspace=*/workspace_ptr,
                  /*workSpaceSizeInBytes=*/workspace_size,
                  /*reserveSpace=*/const_cast<T *>(
                      reserve_space->template data<T>()),
                  /*reserveSpaceSizeInBytes=*/reserve_space_size));
741 742 743
        }
#endif
        if (!called) {
744 745 746 747 748 749 750 751 752 753 754 755 756 757
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnBatchNormalizationBackward(
                  dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
                  CudnnDataType<T>::kZero(), data_desc_,
                  transformed_x.template data<T>(), data_desc_,
                  transformed_d_y.template data<T>(), data_desc_,
                  transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
                  bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
                  d_scale->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  d_bias->template mutable_data<BatchNormParamType<T>>(
                      ctx.GetPlace()),
                  epsilon, saved_mean_data, saved_var_data));
758 759 760 761 762 763 764 765
        }

        if (data_layout == DataLayout::kNHWC &&
            compute_format == DataLayout::kNCHW) {
          VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
          TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
              ctx, &transformed_d_x, d_x);
        }
L
lvmengsi 已提交
766
      } else {
767
        if (compute_format == DataLayout::kNCHW) {
L
lvmengsi 已提交
768 769 770 771 772 773 774 775 776
          if (d_x) {
            BNBackwardData<T, block, framework::DataLayout::kNCHW><<<
                grid2, block, 0, dev_ctx.stream()>>>(
                d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
                saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
                d_x->data<T>());
          }
        } else {
          if (d_x) {
L
Lv Mengsi 已提交
777
            BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
L
lvmengsi 已提交
778 779 780 781 782 783 784
                grid2, block, 0, dev_ctx.stream()>>>(
                d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
                saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
                d_x->data<T>());
          }
        }
      }
785 786

      // clean when exit.
787
      PADDLE_ENFORCE_CUDA_SUCCESS(
788
          platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
789
      PADDLE_ENFORCE_CUDA_SUCCESS(
790 791 792 793 794 795 796 797 798 799
          platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
    } else {
      const auto *running_mean = ctx.Input<Tensor>("Mean");
      const auto *running_var = ctx.Input<Tensor>("Variance");

      const auto *running_mean_data =
          running_mean->template data<BatchNormParamType<T>>();
      const auto *running_var_data =
          running_var->template data<BatchNormParamType<T>>();

K
Kaipeng Deng 已提交
800 801 802 803 804 805 806 807 808
      if (is_inplace) {
        auto px = *x;
        inplace_functor(data_layout, px.mutable_data<T>(ctx.GetPlace()),
                        scale->template data<BatchNormParamType<T>>(),
                        bias->template data<BatchNormParamType<T>>(),
                        running_mean_data, running_var_data, epsilon, C,
                        H * W * D, num, x->data<T>(), grid2, block, stream);
      }

809
      if (compute_format == DataLayout::kNCHW) {
810
        if (d_x) {
K
Kaipeng Deng 已提交
811 812
          KeBNBackwardData<
              T, framework::DataLayout::kNCHW><<<grid1, block, 0, stream>>>(
813 814 815 816
              d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
              running_var_data, epsilon, C, H * W, num, d_x->data<T>());
        }
        if (d_scale && d_bias) {
K
Kaipeng Deng 已提交
817 818 819
          KeBNBackwardScaleBias<
              T, block,
              framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
820
              d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
Q
qingqing01 已提交
821
              epsilon, N, C, H * W * D, d_scale->data<BatchNormParamType<T>>(),
822 823 824 825
              d_bias->data<BatchNormParamType<T>>());
        }
      } else {
        if (d_x) {
K
Kaipeng Deng 已提交
826 827
          KeBNBackwardData<
              T, framework::DataLayout::kNHWC><<<grid1, block, 0, stream>>>(
828 829 830 831
              d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
              running_var_data, epsilon, C, H * W, num, d_x->data<T>());
        }
        if (d_scale && d_bias) {
K
Kaipeng Deng 已提交
832 833 834
          KeBNBackwardScaleBias<
              T, block,
              framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
835
              d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
Q
qingqing01 已提交
836
              epsilon, N, C, H * W * D, d_scale->data<BatchNormParamType<T>>(),
837 838 839 840
              d_bias->data<BatchNormParamType<T>>());
        }
      }
    }
Q
Qiao Longfei 已提交
841 842 843
  }
};

844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882
template <typename T>
class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // CUDA kernel for the second-order gradient of batch_norm.
  //
  // The kernel itself does no math. It gathers the op's inputs, outputs and
  // attributes, then forwards everything to NormDoubleGradFunctor.
  //   Inputs : X, Scale, DY, SavedMean, SavedVariance, DDX, DDScale, DDBias
  //   Outputs: DX, DScale, DDY (passed to the functor as mutable pointers)
  //   Attrs  : epsilon (float, widened to double), use_global_stats,
  //            is_test (must be false), data_layout (string)
  // Throws InvalidArgument when is_test is true.
  void Compute(const framework::ExecutionContext &ctx) const override {
    // Forward-pass input, scale, first-order output grad and saved statistics.
    const Tensor *x = ctx.Input<Tensor>("X");
    const Tensor *scale = ctx.Input<Tensor>("Scale");
    const Tensor *d_y = ctx.Input<Tensor>("DY");
    const Tensor *saved_mean = ctx.Input<Tensor>("SavedMean");
    const Tensor *saved_variance = ctx.Input<Tensor>("SavedVariance");

    // The attribute is stored as float; the functor takes a double.
    const double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
    const bool is_test = ctx.Attr<bool>("is_test");

    // This gradient kernel refuses to run in test mode.
    PADDLE_ENFORCE_EQ(
        is_test, false,
        platform::errors::InvalidArgument(
            "`is_test = True` CANNOT be used in train program. If "
            "you want to use global status in pre_train model, "
            "please set `use_global_stats = True`"));

    const DataLayout layout =
        framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));

    // Second-order input tensors.
    const Tensor *dd_x = ctx.Input<Tensor>("DDX");
    const Tensor *dd_scale = ctx.Input<Tensor>("DDScale");
    const Tensor *dd_bias = ctx.Input<Tensor>("DDBias");

    // Outputs filled in by the functor.
    Tensor *d_x = ctx.Output<Tensor>("DX");
    Tensor *d_scale = ctx.Output<Tensor>("DScale");
    Tensor *dd_y = ctx.Output<Tensor>("DDY");

    NormDoubleGradFunctor<platform::CUDADeviceContext, T>(
        ctx, layout, x, scale, d_y, saved_mean, saved_variance, eps,
        use_global_stats, dd_x, dd_scale, dd_bias, d_x, d_scale, dd_y);
  }
};

Q
Qiao Longfei 已提交
883 884 885 886
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
K
Kexin Zhao 已提交
887
namespace plat = paddle::platform;
Q
QI JUN 已提交
888
REGISTER_OP_CUDA_KERNEL(
K
Kexin Zhao 已提交
889
    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
D
dzhwinter 已提交
890
    ops::BatchNormKernel<plat::CUDADeviceContext, double>,
K
Kexin Zhao 已提交
891
    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
Q
QI JUN 已提交
892
REGISTER_OP_CUDA_KERNEL(
D
dzhwinter 已提交
893
    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
C
chengduo 已提交
894 895
    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
896 897 898 899
REGISTER_OP_CUDA_KERNEL(
    batch_norm_grad_grad,
    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);