// sync_batch_norm_op.cu.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <string>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#include "paddle/fluid/platform/miopen_helper.h"
#endif
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/nccl_helper.h"

namespace paddle {
namespace operators {

// Short local names for the framework tensor and data-layout types.
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
// Per-element-type traits, forwarded from platform::CudnnDataType.
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
// Type used in this file for statistics and affine parameters (mean,
// variance, scale, bias) that accompany data of element type T. It is
// whatever CudnnDataType<T>::BatchNormParamType names.
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;

// Computes per-channel first and second moments of `x` on the local device.
//
// Block b handles channels b, b + gridDim.x, ...; inside a channel the threads
// step over its N * M elements in increments of BlockDim, and the block-wide
// reduction is sized by BlockDim as well, so the launch must use exactly
// BlockDim threads per block.
//
// `mean_var` receives 2 * C + 1 values:
//   [0, C)    sum(x) / (N * M) for each channel
//   [C, 2C)   sum(x * x) / (N * M) for each channel
//   [2C]      the constant 1 (the host code sums this slot across devices to
//             obtain the device count)
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void KeLocalStats(const T *x, int N, int M, int C,
                             BatchNormParamType<T> *mean_var) {
  using ParamT = BatchNormParamType<T>;
  using Reducer = cub::BlockReduce<ParamT, BlockDim>;
  __shared__ typename Reducer::TempStorage reduce_smem;

  const int per_channel = N * M;
  for (int ch = blockIdx.x; ch < C; ch += gridDim.x) {
    ParamT sum_x = 0.;
    ParamT sum_x2 = 0.;
    for (int e = threadIdx.x; e < per_channel; e += BlockDim) {
      // Translate (channel, element-within-channel) to a flat offset.
      int offset;
      if (layout == framework::DataLayout::kNCHW) {
        offset = (e / M) * C * M + ch * M + e % M;
      } else {
        offset = e * C + ch;
      }
      const auto v = static_cast<ParamT>(x[offset]);
      sum_x += v;
      sum_x2 += v * v;
    }

    // Barriers separate successive uses of the shared reduction storage.
    __syncthreads();
    auto total = Reducer(reduce_smem).Reduce(sum_x, cub::Sum());
    __syncthreads();
    if (threadIdx.x == 0) {
      mean_var[ch] = total / per_channel;
    }
    total = Reducer(reduce_smem).Reduce(sum_x2, cub::Sum());
    __syncthreads();
    if (threadIdx.x == 0) {
      mean_var[ch + C] = total / per_channel;
    }
  }

  if (blockIdx.x == 0 && threadIdx.x == 0) {
    mean_var[2 * C] = static_cast<ParamT>(1.0);
  }
}

// Turns the (possibly all-reduced) per-channel moments into the statistics
// used by the rest of the op.
//
// On entry `means[c]` and `variances[c]` hold sums of per-device E[x] and
// E[x^2]; both are divided by `*num_dev`. For every channel c < C the kernel
//   - writes the mean to sv_mean_data[c],
//   - writes 1 / sqrt(var + epsilon) to sv_inv_var_data[c],
//   - overwrites variances[c] with var = E[x^2] - mean^2,
//   - blends mean and var into moving_means / moving_variances as
//     moving * momentum + current * (1 - momentum).
template <typename T>
__global__ void KeSyncAndMovingStats(
    BatchNormParamType<T> *means, BatchNormParamType<T> *variances,
    BatchNormParamType<T> *num_dev, const int C,
    const BatchNormParamType<T> momentum, const double epsilon,
    BatchNormParamType<T> *sv_mean_data, BatchNormParamType<T> *sv_inv_var_data,
    BatchNormParamType<T> *moving_means,
    BatchNormParamType<T> *moving_variances) {
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int step = blockDim.x * gridDim.x;

  for (int ch = first; ch < C; ch += step) {
    // Average the summed moments over the participating devices.
    const auto mu = means[ch] / (*num_dev);
    const auto ex2 = variances[ch] / (*num_dev);
    const auto sigma2 = ex2 - mu * mu;

    // Statistics of the current batch.
    sv_mean_data[ch] = mu;
    sv_inv_var_data[ch] = 1.0 / sqrt(sigma2 + epsilon);
    variances[ch] = sigma2;

    // Running statistics.
    moving_means[ch] = moving_means[ch] * momentum + mu * (1. - momentum);
    moving_variances[ch] =
        moving_variances[ch] * momentum + sigma2 * (1. - momentum);
  }
}

// Element-wise normalization followed by the affine transform:
//   y[i] = (x[i] - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c]
// where c is the channel of flat index i: (i / M) % C for kNCHW, i % C
// otherwise. Threads cover [0, num) with a stride of blockDim.x * gridDim.x,
// so any launch shape processes every element.
template <typename T, framework::DataLayout layout>
static __global__ void KeNormAffine(const T *x,
                                    const BatchNormParamType<T> *scale,
                                    const BatchNormParamType<T> *bias,
                                    const BatchNormParamType<T> *mean,
                                    const BatchNormParamType<T> *variance,
                                    const double epsilon, const int C,
                                    const int M, const int num, T *y) {
  using ParamT = BatchNormParamType<T>;
  const int step = blockDim.x * gridDim.x;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < num) {
    int ch;
    if (layout == framework::DataLayout::kNCHW) {
      ch = (idx / M) % C;
    } else {
      ch = idx % C;
    }
    const auto centered = static_cast<ParamT>(x[idx]) - mean[ch];
    const auto denom = sqrt(variance[ch] + epsilon);
    const auto result = centered / denom * scale[ch] + bias[ch];
    y[idx] = static_cast<T>(result);
    idx += step;
  }
}

// Forward pass of synchronized batch normalization.
//
// Validates that `x` has rank 2..5, reads the "Scale" and "Bias" inputs from
// `ctx`, allocates `y`, and enqueues all work on the CUDA stream of
// ctx.cuda_device_context().
//
// is_test == true:
//   `x` is normalized with the provided `mean` / `variance` tensors.
// is_test == false:
//   1. KeLocalStats writes per-channel E[x], E[x^2] and a trailing 1 into a
//      temporary buffer of 2 * C + 1 values.
//   2. When built with NCCL/RCCL and the device context has a communicator,
//      that buffer is sum-all-reduced in place, so the trailing slot becomes
//      the number of participating devices.
//   3. KeSyncAndMovingStats fills `saved_mean`, `saved_variance` (which stores
//      1 / sqrt(var + epsilon)) and updates `mean_out` / `variance_out`.
//   4. `x` is normalized with the synchronized mean and variance.
//
// `use_global_stats` is accepted for interface compatibility but is not read
// by this function.
template <typename DeviceContext, typename T>
void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
                          const DataLayout layout, const framework::Tensor *x,
                          framework::Tensor *y, const framework::Tensor *mean,
                          const framework::Tensor *variance,
                          framework::Tensor *mean_out,
                          framework::Tensor *variance_out,
                          framework::Tensor *saved_mean,
                          framework::Tensor *saved_variance, double epsilon,
                          const float momentum, const bool is_test,
                          const bool use_global_stats) {
  const auto &x_dims = x->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The Input dim size should be larger than 1."));
  PADDLE_ENFORCE_LE(x_dims.size(), 5,
                    platform::errors::InvalidArgument(
                        "The Input dim size should be less than 6."));
  int N, C, H, W, D;
  ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
  int x_numel = x->numel();

  const T *x_d = x->data<T>();
  const auto *s_d = ctx.Input<Tensor>("Scale")->data<BatchNormParamType<T>>();
  const auto *b_d = ctx.Input<Tensor>("Bias")->data<BatchNormParamType<T>>();

  T *y_d = y->mutable_data<T>(ctx.GetPlace());

  // Statistics consumed by the final normalization kernel; which buffers they
  // point at depends on the mode selected below.
  const BatchNormParamType<T> *mean_data = nullptr;
  const BatchNormParamType<T> *var_data = nullptr;

  auto &dev_ctx = ctx.cuda_device_context();
  auto stream = dev_ctx.stream();
  const int block = 512;
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();

  // Declared at function scope so the temporary statistics buffer is still
  // owned when KeNormAffine (which reads var_data from it) is enqueued.
  paddle::memory::AllocationPtr alloc_ptr{nullptr};

  if (is_test) {
    mean_data = mean->data<BatchNormParamType<T>>();
    var_data = variance->data<BatchNormParamType<T>>();
  } else {
    // x, x^2, 1, here 1 is used to calc device num
    // device num also can be got from platform::DeviceContextPool
    const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType<T>);
    alloc_ptr = memory::Alloc(dev_ctx, bytes);

    auto *stats = reinterpret_cast<BatchNormParamType<T> *>(alloc_ptr->ptr());
    // KeLocalStats requires blockDim.x to equal its BlockDim template
    // argument, hence `threads` is used for both.
    const int threads = 256;
    int grid = std::min(C, (max_threads + threads - 1) / threads);
    if (layout == framework::DataLayout::kNCHW) {
      KeLocalStats<T, threads,
                   framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
          x_d, N, H * W * D, C, stats);
    } else {
      KeLocalStats<T, threads,
                   framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
          x_d, N, H * W * D, C, stats);
    }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    auto *comm = dev_ctx.nccl_comm();
    if (comm) {
      int dtype = platform::ToNCCLDataType(mean_out->type());
      // In-place operation
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
          stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
          comm, stream));
    }
#endif

    auto *est_mean_data =
        mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    auto *est_var_data =
        variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

    auto *sv_mean_data =
        saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    auto *sv_inv_var_data =
        saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

    // Note, Input('Mean')/Input('Variance') share variable with
    // Output('MeanOut')/Output('VarianceOut')
    KeSyncAndMovingStats<T><<<(C + block - 1) / block, block, 0, stream>>>(
        stats, stats + C, stats + 2 * C, C, momentum, epsilon, sv_mean_data,
        sv_inv_var_data, est_mean_data, est_var_data);

    // KeSyncAndMovingStats overwrote stats[C, 2C) with the synchronized
    // variance, which is what KeNormAffine expects.
    mean_data = sv_mean_data;
    var_data = stats + C;
  }

  int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
  if (layout == framework::DataLayout::kNCHW) {
    KeNormAffine<T, framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
        x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel,
        y_d);
  } else {
    KeNormAffine<T, framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
        x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel,
        y_d);
  }
}

// Computes, per channel, the two local reductions needed by the backward
// data kernel and stores them in `sum_dy_prod` (2 * C + 1 values):
//   [0, C)    sum(dy)
//   [C, 2C)   sum(dy * (x - means[c]))
//   [2C]      the constant 1 (the host code sums this slot across devices to
//             obtain the device count)
//
// Block b handles channels b, b + gridDim.x, ...; threads step over the N * M
// elements of a channel with stride blockDim.x. The block reduction is sized
// by the BlockDim template argument.
template <typename T, const int BlockDim, framework::DataLayout layout>
__global__ void KeBackwardLocalStats(const T *dy, const T *x,
                                     const BatchNormParamType<T> *means, int N,
                                     int M, int C,
                                     BatchNormParamType<T> *sum_dy_prod) {
  using ParamT = BatchNormParamType<T>;
  using Reducer = cub::BlockReduce<ParamT, BlockDim>;
  __shared__ typename Reducer::TempStorage reduce_smem;

  const int per_channel = N * M;
  for (int ch = blockIdx.x; ch < C; ch += gridDim.x) {
    ParamT acc_dy = 0.;
    ParamT acc_dy_xc = 0.;
    const auto ch_mean = means[ch];
    for (int e = threadIdx.x; e < per_channel; e += blockDim.x) {
      // Translate (channel, element-within-channel) to a flat offset.
      int offset;
      if (layout == framework::DataLayout::kNCHW) {
        offset = (e / M) * C * M + ch * M + e % M;
      } else {
        offset = e * C + ch;
      }
      const auto grad = static_cast<ParamT>(dy[offset]);
      acc_dy += grad;
      const auto value = static_cast<ParamT>(x[offset]);
      acc_dy_xc += grad * (value - ch_mean);
    }

    // Barriers separate successive uses of the shared reduction storage.
    __syncthreads();
    auto total = Reducer(reduce_smem).Reduce(acc_dy, cub::Sum());
    __syncthreads();
    if (threadIdx.x == 0) {
      sum_dy_prod[ch] = total;
    }
    total = Reducer(reduce_smem).Reduce(acc_dy_xc, cub::Sum());
    __syncthreads();
    if (threadIdx.x == 0) {
      sum_dy_prod[ch + C] = total;
    }
  }

  if (blockIdx.x == 0 && threadIdx.x == 0) {
    sum_dy_prod[2 * C] = 1.0;
  }
}

// Computes the gradients of the affine parameters for every channel c:
//   dscale[c] = sum(dy * (x - mean[c])) * inv_variance[c]
//   dbias[c]  = sum(dy)
// where the sums run over the N * HxW elements of channel c.
//
// Block b handles channels b, b + gridDim.x, ...; threads step over a channel
// with stride blockDim.x and the block reduction is sized by BlockDim.
// `epsilon` is part of the signature but is not read by the kernel.
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
    const T *dy, const T *x, const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *inv_variance, const double epsilon,
    const int N, const int C, const int HxW, BatchNormParamType<T> *dscale,
    BatchNormParamType<T> *dbias) {
  using ParamT = BatchNormParamType<T>;
  using Reducer = cub::BlockReduce<ParamT, BlockDim>;
  __shared__ typename Reducer::TempStorage reduce_smem;

  const int per_channel = N * HxW;
  for (int ch = blockIdx.x; ch < C; ch += gridDim.x) {
    ParamT acc_scale = 0.;
    ParamT acc_bias = 0.;

    const auto ch_inv_std = inv_variance[ch];
    const auto ch_mean = mean[ch];
    for (int e = threadIdx.x; e < per_channel; e += blockDim.x) {
      // Translate (channel, element-within-channel) to a flat offset.
      int offset;
      if (layout == framework::DataLayout::kNCHW) {
        offset = ((e / HxW) * C + ch) * HxW + (e % HxW);
      } else {
        offset = e * C + ch;
      }
      const auto value = static_cast<ParamT>(x[offset]);
      const auto grad = static_cast<ParamT>(dy[offset]);
      acc_scale += grad * (value - ch_mean);
      acc_bias += grad;
    }

    // Barriers separate successive uses of the shared reduction storage.
    __syncthreads();
    const auto scale_total = Reducer(reduce_smem).Reduce(acc_scale, cub::Sum());
    __syncthreads();
    const auto bias_total = Reducer(reduce_smem).Reduce(acc_bias, cub::Sum());
    __syncthreads();
    if (threadIdx.x == 0) {
      dscale[ch] = scale_total * ch_inv_std;
      dbias[ch] = bias_total;
    }
    __syncthreads();
  }
}

// Inverts the normalization + affine transform element-wise:
//   x[i] = (y[i] - bias[c]) / scale[c] / sv_inv[c] + mean[c]
// where c is the channel of flat index i: (i / M) % C for kNCHW, i % C
// otherwise. Each thread reads y[i] and writes x[i] at the same index, and
// threads cover [0, num) with a stride of blockDim.x * gridDim.x.
// `epsilon` is part of the signature but is not read by the kernel.
template <typename T, framework::DataLayout layout>
static __global__ void KeBNRestoreData(T *x, const BatchNormParamType<T> *scale,
                                       const BatchNormParamType<T> *bias,
                                       const BatchNormParamType<T> *mean,
                                       const BatchNormParamType<T> *sv_inv,
                                       const double epsilon, int C, int M,
                                       int num, const T *y) {
  using ParamT = BatchNormParamType<T>;
  const int step = blockDim.x * gridDim.x;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < num) {
    int ch;
    if (layout == framework::DataLayout::kNCHW) {
      ch = (idx / M) % C;
    } else {
      ch = idx % C;
    }
    const auto out_val = static_cast<ParamT>(y[idx]);
    const auto restored = (out_val - bias[ch]) / scale[ch] / sv_inv[ch] + mean[ch];
    x[idx] = static_cast<T>(restored);
    idx += step;
  }
}

// Computes the input gradient element-wise from the per-channel reductions.
//
// With r = C / num, s = gamma[c], v = inv_variance[c] and n = num_dev[0]:
//   gmean = -(g_sum_dy[c] / n) * s * v
//   gvar  = -(g_sum_dy_prod[c] / n) * s * v * (v * v)
//   dx[i] = dy[i] * s * v + gmean * r + gvar * r * (x[i] - mean[c])
// where c is the channel of flat index i: i / HxW % C for kNCHW, i % C
// otherwise. Threads cover [0, num) with a stride of blockDim.x * gridDim.x.
// `epsilon` is part of the signature but is not read by the kernel.
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(
    const T *dy, const T *x, const BatchNormParamType<T> *gamma,
    const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *inv_variance,
    const BatchNormParamType<T> *g_sum_dy,
    const BatchNormParamType<T> *g_sum_dy_prod,
    const BatchNormParamType<T> *num_dev, const double epsilon, const int C,
    const int HxW, const int num, T *dx) {
  using ParamT = BatchNormParamType<T>;
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int step = blockDim.x * gridDim.x;
  // C / num; this is 1 / (elements per channel) if num == N * C * HxW.
  const auto ratio = static_cast<ParamT>(C) / num;
  const auto device_count = num_dev[0];

  for (int idx = first; idx < num; idx += step) {
    int ch;
    if (layout == framework::DataLayout::kNCHW) {
      ch = idx / HxW % C;
    } else {
      ch = idx % C;
    }
    const auto inv_std = inv_variance[ch];
    const auto weight = gamma[ch];
    const auto gvar = -(g_sum_dy_prod[ch] / device_count) * weight * inv_std *
                      (inv_std * inv_std);
    const auto gmean = -(g_sum_dy[ch] / device_count) * weight * inv_std;

    const auto value = static_cast<ParamT>(x[idx]);
    const auto grad = static_cast<ParamT>(dy[idx]);
    const auto result = grad * weight * inv_std + gmean * ratio +
                        gvar * ratio * (value - mean[ch]);
    dx[idx] = static_cast<T>(result);
  }
}

// Backward pass of synchronized batch normalization.
//
// Inputs:
//   scale, bias  affine parameters; `scale` must be rank 1 with C entries.
//   d_y          gradient with respect to the forward output.
//   mean         saved per-channel mean from the forward pass.
//   variance     saved per-channel 1 / sqrt(var + epsilon) from the forward
//                pass (it is passed to the kernels as the inverse std).
// Outputs (each may be null, in which case it is skipped):
//   d_x              gradient with respect to the input.
//   d_scale, d_bias  parameter gradients; computed only when both are given.
//
// The forward input is taken from ctx input "Y" when present (in-place mode,
// in which case X is first reconstructed from Y into the same buffer) and
// from ctx input "X" otherwise. All kernels are enqueued on the CUDA stream of
// ctx.cuda_device_context(). When built with NCCL/RCCL and the device context
// has a communicator, the per-channel sums used for d_x are sum-all-reduced
// across devices; d_scale / d_bias are computed from local data only.
template <typename DeviceContext, typename T>
void SyncBatchNormGradFunctor(
    const framework::ExecutionContext &ctx, const DataLayout layout,
    const framework::Tensor *scale, const framework::Tensor *bias,
    framework::Tensor *d_x, const framework::Tensor *d_y,
    framework::Tensor *d_scale, framework::Tensor *d_bias,
    const framework::Tensor *mean, const framework::Tensor *variance,
    const double epsilon) {
  // sync_batch_norm with inplace as false will take X as grad input, which
  // is same as cuDNN batch_norm backward calculation, batch_norm
  // with inplace as true only take Y as input and X should be calculate
  // by inverse operation of batch_norm on Y
  const Tensor *x;
  bool is_inplace;
  if (ctx.HasInput("Y")) {
    x = ctx.Input<Tensor>("Y");
    is_inplace = true;
  } else {
    x = ctx.Input<Tensor>("X");
    is_inplace = false;
  }

  const auto &x_dims = x->dims();

  PADDLE_ENFORCE_GE(x_dims.size(), 2,
                    platform::errors::InvalidArgument(
                        "The Input X dim size should be larger than 1."));
  PADDLE_ENFORCE_LE(x_dims.size(), 5,
                    platform::errors::InvalidArgument(
                        "The Input X dim size should be less than 6."));

  int N, C, H, W, D;
  ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);

  // Validate the rank of `scale` before its first dimension is indexed.
  PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL,
                    platform::errors::InvalidArgument(
                        "Expected rank for input parameter(scale) of "
                        "OP(sync_batch_norm) be (1), but given (%d).",
                        scale->dims().size()));
  PADDLE_ENFORCE_EQ(scale->dims()[0], C,
                    platform::errors::InvalidArgument(
                        "Expected first dim for input parameter(scale) of "
                        "OP(sync_batch_norm) be (%d), but given (%d).",
                        C, scale->dims()[0]));

  // d_x is optional (it is null-checked again before its kernel is launched),
  // so only allocate it when the caller asked for it.
  if (d_x) {
    d_x->mutable_data<T>(ctx.GetPlace());
  }
  if (d_scale && d_bias) {
    d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
    d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
  }

  const T *x_d = x->data<T>();
  const T *dy_d = d_y->data<T>();

  auto &dev_ctx = ctx.cuda_device_context();
  auto stream = dev_ctx.stream();

  const auto *saved_mean = mean->data<BatchNormParamType<T>>();
  const auto *saved_inv_var = variance->data<BatchNormParamType<T>>();
  // Temporary buffer: sum(dy), sum(dy * (x - mean)) per channel, plus one
  // trailing slot that becomes the device count after the all-reduce.
  const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType<T>);
  auto alloc_ptr = memory::Alloc(dev_ctx, bytes);
  auto *stats = reinterpret_cast<BatchNormParamType<T> *>(alloc_ptr->ptr());

  const int block = 512;
  // The reduction kernels are instantiated with BlockDim == threads and must
  // be launched with the same block size.
  const int threads = 256;
  int x_numel = x->numel();
  int fsize = H * W * D;
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  int grid = std::min(C, (max_threads + threads - 1) / threads);
  int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;

  if (is_inplace) {
    // `x` currently holds Y. Reconstruct X through a tensor copy `px` so the
    // const input can be written; the kernels below then read it via x_d.
    // NOTE(review): this relies on the Tensor copy sharing x's buffer.
    auto px = *x;
    if (layout == framework::DataLayout::kNCHW) {
      KeBNRestoreData<
          T, framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
          px.mutable_data<T>(ctx.GetPlace()),
          scale->data<BatchNormParamType<T>>(),
          bias->data<BatchNormParamType<T>>(), saved_mean, saved_inv_var,
          epsilon, C, fsize, x_numel, x->data<T>());
    } else {
      KeBNRestoreData<
          T, framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
          px.mutable_data<T>(ctx.GetPlace()),
          scale->data<BatchNormParamType<T>>(),
          bias->data<BatchNormParamType<T>>(), saved_mean, saved_inv_var,
          epsilon, C, fsize, x_numel, x->data<T>());
    }
  }

  if (layout == framework::DataLayout::kNCHW) {
    KeBackwardLocalStats<
        T, threads, framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
        dy_d, x_d, saved_mean, N, fsize, C, stats);
  } else {
    KeBackwardLocalStats<
        T, threads, framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
        dy_d, x_d, saved_mean, N, fsize, C, stats);
  }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto *comm = dev_ctx.nccl_comm();
  if (comm) {
    int dtype = platform::ToNCCLDataType(scale->type());
    // In-place operation
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
        stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
        comm, stream));
  }
#endif

  if (layout == framework::DataLayout::kNCHW) {
    if (d_scale && d_bias) {
      KeBNBackwardScaleBias<
          T, threads,
          framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
          dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
          d_scale->data<BatchNormParamType<T>>(),
          d_bias->data<BatchNormParamType<T>>());
    }
    if (d_x) {
      KeBNBackwardData<
          T, framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
          dy_d, x_d, scale->data<BatchNormParamType<T>>(), saved_mean,
          saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C, fsize,
          x->numel(), d_x->data<T>());
    }
  } else {
    if (d_scale && d_bias) {
      KeBNBackwardScaleBias<
          T, threads,
          framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
          dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
          d_scale->data<BatchNormParamType<T>>(),
          d_bias->data<BatchNormParamType<T>>());
    }
    if (d_x) {
      KeBNBackwardData<
          T, framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
          dy_d, x_d, scale->data<BatchNormParamType<T>>(), saved_mean,
          saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C, fsize,
          x->numel(), d_x->data<T>());
    }
  }
}

// Op kernel class for sync_batch_norm, templated on the device context and
// element type. Only Compute() is declared here; its definition is not
// visible in this header.
// NOTE(review): presumably Compute() forwards to SyncBatchNormFunctor --
// confirm against the defining translation unit.
template <typename DeviceContext, typename T>
class SyncBatchNormKernel : public framework::OpKernel<T> {
 public:
  // Runs the kernel for the given execution context.
  void Compute(const framework::ExecutionContext &ctx) const override;
};

// Op kernel class for the gradient of sync_batch_norm, templated on the
// device context and element type. Only Compute() is declared here; its
// definition is not visible in this header.
// NOTE(review): presumably Compute() forwards to SyncBatchNormGradFunctor --
// confirm against the defining translation unit.
//
// Background on the math:
// Deriving the Gradient for the Backward Pass of Batch Normalization
// https://kevinzakka.github.io/2016/09/14/batch_normalization/
template <typename DeviceContext, typename T>
class SyncBatchNormGradKernel : public framework::OpKernel<T> {
 public:
  // Runs the gradient kernel for the given execution context.
  void Compute(const framework::ExecutionContext &ctx) const override;
};

}  // namespace operators
}  // namespace paddle