batch_norm_kernel.cu 49.1 KB
Newer Older
H
hong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

23
#include "paddle/fluid/operators/layout_utils.h"
H
hong 已提交
24
#include "paddle/phi/backends/gpu/gpu_context.h"
25
#include "paddle/phi/backends/gpu/gpu_dnn.h"
26
#include "paddle/phi/common/layout.h"
27
#include "paddle/phi/core/enforce.h"
28
#include "paddle/phi/core/flags.h"
H
hong 已提交
29 30
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
31
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
H
hong 已提交
32
#include "paddle/phi/kernels/funcs/eigen/common.h"
33
#include "paddle/phi/kernels/funcs/norm_utils.cu.h"
34
#include "paddle/phi/kernels/funcs/norm_utils.h"
35
#include "paddle/phi/kernels/funcs/reduce_function.h"
H
hong 已提交
36 37 38 39 40 41 42 43 44 45 46 47

// LAUNCH_BOUNDS(BlockDim) is placed on kernel declarations to state the block
// size they are compiled for. Under hipcc it expands to
// __launch_bounds__(BlockDim); with any other compiler it expands to nothing.
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif

DECLARE_bool(cudnn_batchnorm_spatial_persistent);

namespace phi {

template <typename T>
48
using CudnnDataType = phi::backends::gpu::CudnnDataType<T>;
H
hong 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
// Element type used by the kernels in this file for the per-channel
// parameters and statistics (scale, bias, mean, variance) that accompany an
// input of element type T. It is whatever CudnnDataType<T> declares as its
// BatchNormParamType.
// NOTE(review): presumably a wider type than T for half-precision inputs --
// confirm against the CudnnDataType specialisations.
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;

// Element-wise batch-norm inference using externally supplied statistics.
//
// For every flat element index idx in [0, N * C * HxW) the kernel writes
//   y[idx] = scale[c] * (x[idx] - mean[c]) / sqrt(variance[c] + epsilon)
//            + bias[c]
// where the channel c is (idx / HxW) % C for the kNCHW instantiation and
// idx % C for every other layout value.
//
// Launch layout: one-dimensional. Each thread starts at its global thread
// index and advances by the total number of launched threads, so any grid
// size covers the whole tensor.
//
// Arithmetic is carried out in BatchNormParamType<T> (the reciprocal square
// root is evaluated in double because epsilon is a double) and converted to
// T only for the final store.
template <typename T, phi::DataLayout layout>
static __global__ void BNForwardInference(const T *x,
                                          const BatchNormParamType<T> *mean,
                                          const BatchNormParamType<T> *variance,
                                          const BatchNormParamType<T> *scale,
                                          const BatchNormParamType<T> *bias,
                                          const int C,
                                          const int N,
                                          const int HxW,
                                          const double epsilon,
                                          T *y) {
  using ParamT = BatchNormParamType<T>;

  const int total = N * C * HxW;
  const int step = blockDim.x * gridDim.x;

  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < total) {
    // Recover the channel that owns this element from the flat index.
    int channel;
    if (layout == phi::DataLayout::kNCHW) {
      channel = idx / HxW % C;
    } else {
      channel = idx % C;
    }

    const ParamT centered = static_cast<ParamT>(x[idx]) - mean[channel];
    const ParamT inv_std = 1 / sqrt(variance[channel] + epsilon);
    y[idx] = static_cast<T>(scale[channel] * centered * inv_std + bias[channel]);

    idx += step;
  }
}

75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
// Converts a per-channel variance vector into 1 / sqrt(variance + epsilon).
//
// Launch layout: one-dimensional, one thread per channel. Threads whose
// global index is C or larger do nothing, so the grid may over-cover C.
//
//   variance      input, read at indices [0, C)
//   epsilon       added to the variance before the square root
//   C             number of entries to convert
//   inv_variance  output, written at indices [0, C)
template <typename T>
static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
                                       const double epsilon,
                                       const int C,
                                       BatchNormParamType<T> *inv_variance) {
  const int channel = threadIdx.x + blockIdx.x * blockDim.x;
  if (channel >= C) {
    return;
  }
  inv_variance[channel] = 1 / sqrt(variance[channel] + epsilon);
}

// Element-wise batch-norm inference that consumes a precomputed per-channel
// multiplier instead of a variance.
//
// For every flat element index idx in [0, N * C * HxW) the kernel writes
//   y[idx] = scale[c] * (x[idx] - mean[c]) * inv_variance[c] + bias[c]
// where the channel c is (idx / HxW) % C for the kNCHW instantiation and
// idx % C for every other layout value.
//
// Launch layout: one-dimensional. Each thread starts at its global thread
// index and advances by the total number of launched threads.
//
// The epsilon parameter is accepted but not read by this kernel.
// NOTE(review): inv_variance is presumably the output of InverseVariance,
// i.e. 1 / sqrt(variance + epsilon) -- confirm at the call site.
template <typename T, phi::DataLayout layout>
static __global__ void BN1DForwardInference(
    const T *x,
    const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *inv_variance,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    const double epsilon,
    T *y) {
  using ParamT = BatchNormParamType<T>;

  const int total = N * C * HxW;
  const int step = blockDim.x * gridDim.x;

  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < total) {
    // Recover the channel that owns this element from the flat index.
    int channel;
    if (layout == phi::DataLayout::kNCHW) {
      channel = idx / HxW % C;
    } else {
      channel = idx % C;
    }

    const ParamT centered = static_cast<ParamT>(x[idx]) - mean[channel];
    y[idx] = static_cast<T>(scale[channel] * centered * inv_variance[channel] +
                            bias[channel]);

    idx += step;
  }
}

H
hong 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
// Single-kernel batch-norm training forward pass: one thread block handles
// one channel at a time.
//
// Launch layout: one-dimensional grid and block. Block b processes channels
// b, b + gridDim.x, ... ; within a channel the threads of the block walk the
// N * HxW elements with a step of blockDim.x. BlockDim must equal the number
// of threads per block because it sizes the cub::BlockReduce instance.
//
// Per channel the kernel
//   1. accumulates sum(x) and sum(x * x) and reduces them across the block;
//   2. in thread 0 derives
//        batch mean     = sum / (N * HxW)
//        batch variance = sum_sq / (N * HxW) - mean * mean
//        inverse std    = 1 / sqrt(variance + epsilon)
//      stores mean and inverse std into save_mean / save_inv_variance when
//      both pointers are non-null, and blends the running statistics as
//        running = (1 - exponentialAverageFactor) * batch
//                  + exponentialAverageFactor * running;
//   3. after a block barrier, writes
//        y = scale * (x - batch mean) * inverse std + bias.
//
// Element addressing: for the kNCHW instantiation element j of channel i is
// at (j / HxW * C + i) * HxW + j % HxW; for every other layout value it is at
// j * C + i.
//
// Shared memory: two cub::BlockReduce temp storages plus three scalars, all
// statically declared.
template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
    const T *x,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    const double epsilon,
    double exponentialAverageFactor,
    T *y,
    BatchNormParamType<T> *mean,
    BatchNormParamType<T> *variance,
    BatchNormParamType<T> *save_mean,
    BatchNormParamType<T> *save_inv_variance) {
  using ParamT = BatchNormParamType<T>;
  using Reducer = cub::BlockReduce<ParamT, BlockDim>;

  // Separate scratch areas so the two reductions can run back to back.
  __shared__ typename Reducer::TempStorage sum_scratch;
  __shared__ typename Reducer::TempStorage sq_sum_scratch;
  // Per-channel statistics published by thread 0 to the rest of the block.
  __shared__ ParamT batch_mean;
  __shared__ ParamT batch_var;
  __shared__ ParamT batch_inv_std;

  const int num_channels = C;
  const int per_channel = N * HxW;

  for (int ch = blockIdx.x; ch < num_channels; ch += gridDim.x) {
    // Pass 1: per-thread partial sums over this channel.
    ParamT partial_sum = static_cast<ParamT>(0);
    ParamT partial_sq_sum = static_cast<ParamT>(0);

    for (int k = threadIdx.x; k < per_channel; k += blockDim.x) {
      int offset;
      if (layout == phi::DataLayout::kNCHW) {
        offset = (k / HxW * C + ch) * HxW + k % HxW;
      } else {
        offset = k * num_channels + ch;
      }
      const ParamT value = static_cast<ParamT>(x[offset]);
      partial_sum += value;
      partial_sq_sum += value * value;
    }

    partial_sum = Reducer(sum_scratch).Reduce(partial_sum, cub::Sum());
    partial_sq_sum =
        Reducer(sq_sum_scratch).Reduce(partial_sq_sum, cub::Sum());

    // Only thread 0 consumes the reduced totals.
    if (threadIdx.x == 0) {
      batch_mean = partial_sum / per_channel;
      batch_var = partial_sq_sum / per_channel - batch_mean * batch_mean;
      batch_inv_std = 1 / sqrt(batch_var + epsilon);

      const bool keep_batch_stats = save_mean && save_inv_variance;
      if (keep_batch_stats) {
        save_mean[ch] = batch_mean;
        save_inv_variance[ch] = batch_inv_std;
      }

      mean[ch] = (1 - exponentialAverageFactor) * batch_mean +
                 exponentialAverageFactor * mean[ch];
      variance[ch] = (1 - exponentialAverageFactor) * batch_var +
                     exponentialAverageFactor * variance[ch];
    }
    // Make batch_mean / batch_inv_std visible to every thread of the block.
    __syncthreads();

    // Pass 2: normalise and write the output for this channel.
    for (int k = threadIdx.x; k < per_channel; k += blockDim.x) {
      int offset;
      if (layout == phi::DataLayout::kNCHW) {
        offset = (k / HxW * C + ch) * HxW + k % HxW;
      } else {
        offset = k * num_channels + ch;
      }
      const ParamT centered = static_cast<ParamT>(x[offset]) - batch_mean;
      y[offset] = scale[ch] * centered * batch_inv_std + bias[ch];
    }
  }
}

175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
// Sums x_sum and x_square_sum across threadIdx.x for every row (fixed
// threadIdx.y) of a two-dimensional thread block, using shared memory as
// scratch.
//
//   x_sum, x_square_sum        this thread's partial values
//   smem_sum, smem_square_sum  shared scratch arrays; each is indexed with
//                              threadIdx.x + threadIdx.y * blockDim.x, so each
//                              must hold at least blockDim.x * blockDim.y
//                              elements
//   x_sum_out, x_square_sum_out
//                              receive the row totals; written only by the
//                              thread with threadIdx.x == 0 of each row, all
//                              other threads leave them untouched
//
// Preconditions:
//   * Must be called by every thread of the block from block-uniform control
//     flow, because it executes __syncthreads().
//   * blockDim.x must be a power of two. The halving loop otherwise skips
//     some lanes (e.g. with blockDim.x == 6 the partial held by lane 2 is
//     never folded into lane 0).
//
// On return the scratch arrays may be reused immediately: the trailing
// barrier guarantees that every read of the last reduction step has finished.
template <typename T>
__device__ __forceinline__ void merge_block_horizonal(
    BatchNormParamType<T> x_sum,
    BatchNormParamType<T> x_square_sum,
    BatchNormParamType<T> *smem_sum,
    BatchNormParamType<T> *smem_square_sum,
    BatchNormParamType<T> *x_sum_out,
    BatchNormParamType<T> *x_square_sum_out) {
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    // Lanes that still take part in this step publish their running totals.
    if (threadIdx.x < offset * 2) {
      smem_sum[tid] = x_sum;
      smem_square_sum[tid] = x_square_sum;
    }
    __syncthreads();
    // The lower half folds in the value of its partner in the upper half.
    // Slots read here ([offset, 2 * offset)) are disjoint from the slots the
    // next step writes ([0, offset)), so no barrier is needed in between.
    if (threadIdx.x < offset) {
      int pair_tid = tid + offset;
      x_sum += smem_sum[pair_tid];
      x_square_sum += smem_square_sum[pair_tid];
    }
  }
  // Without this barrier a thread could leave the function and overwrite its
  // scratch slot (for instance in a following call that reuses the same
  // arrays) while lane 0 of its row is still reading that slot in the final
  // step above.
  __syncthreads();
  if (threadIdx.x == 0) {
    *x_sum_out = x_sum;
    *x_square_sum_out = x_square_sum;
  }
}

// Statistics stage of the two-kernel channel-last batch-norm training forward
// pass. Computes per-channel batch statistics and updates the running
// statistics; it does not write y.
//
// Launch layout: two-dimensional. The x dimension (blockIdx.x, threadIdx.x)
// walks channels with a step of gridDim.x * blockDim.x; the y dimension
// (blockIdx.y, threadIdx.y) walks the N * HxW elements of a channel with a
// step of gridDim.y * blockDim.y. Element j of channel i is read from
// x[j * C + i].
//
// Shared memory: two statically sized arrays of BlockDim elements, handed to
// the reduction helpers as scratch.
// NOTE(review): BlockDim is presumably the total thread count of the block --
// confirm against funcs::BlockReduceByVetical and the launch site.
//
// Per channel:
//   1. each thread accumulates sum(x) and sum(x * x) over its elements;
//   2. funcs::BlockReduceByVetical combines the partials inside the block;
//   3. when gridDim.y > 1, funcs::ReduceSumPost is called with block_data_ptr
//      and flag_ptr and sets is_last_block_done.
//      NOTE(review): presumably it combines the per-block partials across
//      blockIdx.y through those buffers and flags the block that arrives
//      last -- confirm in norm_utils.cu.h;
//   4. the selected threads (threadIdx.y == 0 of the last block, or of block
//      row 0 when gridDim.y == 1) derive
//        batch mean     = sum / (N * HxW)
//        batch variance = sum_sq / (N * HxW) - mean * mean
//        inverse std    = 1 / sqrt(variance + epsilon)
//      store mean / inverse std into save_mean / save_inv_variance when both
//      pointers are non-null, blend the running statistics as
//        running = (1 - exponentialAverageFactor) * batch
//                  + exponentialAverageFactor * running
//      and publish mean / inverse std in compute_mean / compute_inv_var.
//
// The scale, bias and y parameters are accepted but not read or written here.
template <typename T, int BlockDim>
static __global__ void BNForwardTraining2DChannelLastCompStat(
    const T *x,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    const double epsilon,
    double exponentialAverageFactor,
    T *y,
    BatchNormParamType<T> *global_mean,
    BatchNormParamType<T> *global_variance,
    BatchNormParamType<T> *save_mean,
    BatchNormParamType<T> *save_inv_variance,
    BatchNormParamType<T> *compute_mean,
    BatchNormParamType<T> *compute_inv_var,
    BatchNormParamType<T> *block_data_ptr,
    int *flag_ptr) {
  int outer_size = C;
  int inner_size = N * HxW;

  __shared__ BatchNormParamType<T> smem_sum[BlockDim];
  __shared__ BatchNormParamType<T> smem_square_sum[BlockDim];

  int outer_loop_stride = gridDim.x * blockDim.x;
  int inner_loop_stride = gridDim.y * blockDim.y;

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size;
       i += outer_loop_stride) {
    BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
    BatchNormParamType<T> x_square_sum = static_cast<BatchNormParamType<T>>(0);

    // Per-thread partial sums over this thread's share of channel i.
    for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
         j += inner_loop_stride) {
      const int index = j * outer_size + i;
      BatchNormParamType<T> x_i = static_cast<BatchNormParamType<T>>(x[index]);
      x_sum += x_i;
      x_square_sum += x_i * x_i;
    }

    // vertical block sum
    funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(x_sum,
                                                          x_square_sum,
                                                          &smem_sum[0],
                                                          &smem_square_sum[0],
                                                          &x_sum,
                                                          &x_square_sum);

    // Decide whether this thread is the one that finalises channel i.
    bool finalize_channel;
    if (gridDim.y > 1) {
      __shared__ bool is_last_block_done;
      funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
                                                     i,
                                                     &x_sum,
                                                     &x_square_sum,
                                                     &is_last_block_done,
                                                     smem_sum,
                                                     smem_square_sum,
                                                     block_data_ptr,
                                                     flag_ptr);
      finalize_channel = is_last_block_done && threadIdx.y == 0;
    } else {
      // A single block row already holds the complete sums.
      finalize_channel = blockIdx.y == 0 && threadIdx.y == 0;
    }

    // final compute
    if (finalize_channel) {
      BatchNormParamType<T> compute_mean_val = x_sum / inner_size;
      BatchNormParamType<T> variance_val =
          x_square_sum / inner_size - compute_mean_val * compute_mean_val;
      BatchNormParamType<T> compute_inv_var_val =
          1 / sqrt(variance_val + epsilon);

      if (save_mean && save_inv_variance) {
        save_mean[i] = compute_mean_val;
        save_inv_variance[i] = compute_inv_var_val;
      }
      global_mean[i] = (1 - exponentialAverageFactor) * compute_mean_val +
                       exponentialAverageFactor * global_mean[i];
      global_variance[i] = (1 - exponentialAverageFactor) * variance_val +
                           exponentialAverageFactor * global_variance[i];

      compute_mean[i] = compute_mean_val;
      compute_inv_var[i] = compute_inv_var_val;
    }
  }
}

// Output stage of the two-kernel channel-last batch-norm training forward
// pass. Using the per-channel values in compute_mean and compute_inv_var it
// writes
//   y[j * C + i] = scale[i] * (x[j * C + i] - compute_mean[i])
//                  * compute_inv_var[i] + bias[i]
// for every channel i in [0, C) and element j in [0, N * HxW).
//
// Launch layout: two-dimensional. The x dimension walks channels with a step
// of gridDim.x * blockDim.x, the y dimension walks the elements of a channel
// with a step of gridDim.y * blockDim.y.
//
// The four per-channel values are loaded once per channel, before the element
// loop. No shared memory or barriers are used.
template <typename T>
static __global__ void BNForwardTraining2DChannelLastWriteRes(
    const T *x,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    T *y,
    BatchNormParamType<T> *compute_mean,
    BatchNormParamType<T> *compute_inv_var) {
  using ParamT = BatchNormParamType<T>;

  const int num_channels = C;
  const int per_channel = N * HxW;
  const int channel_step = gridDim.x * blockDim.x;
  const int elem_step = gridDim.y * blockDim.y;

  for (int ch = blockIdx.x * blockDim.x + threadIdx.x; ch < num_channels;
       ch += channel_step) {
    const ParamT ch_mean = compute_mean[ch];
    const ParamT ch_inv_std = compute_inv_var[ch];
    const ParamT ch_scale = scale[ch];
    const ParamT ch_bias = bias[ch];

    for (int k = blockIdx.y * blockDim.y + threadIdx.y; k < per_channel;
         k += elem_step) {
      const int offset = k * num_channels + ch;
      const ParamT centered = static_cast<ParamT>(x[offset]) - ch_mean;
      y[offset] = ch_scale * centered * ch_inv_std + ch_bias;
    }
  }
}

// Statistics stage of the two-kernel channel-first batch-norm training
// forward pass. Computes per-channel batch statistics and updates the running
// statistics; it does not write y.
//
// Launch layout: two-dimensional. The y dimension (blockIdx.y, threadIdx.y)
// walks channels with a step of gridDim.y * blockDim.y; the x dimension
// (blockIdx.x, threadIdx.x) walks the N * HxW elements of a channel with a
// step of gridDim.x * blockDim.x. Element j of channel i is read from
// x[(j / HxW * C + i) * HxW + j % HxW].
//
// Shared memory: two statically sized arrays of BlockDim elements used as
// scratch by merge_block_horizonal, plus one bool.
// NOTE(review): BlockDim is presumably blockDim.x * blockDim.y -- confirm at
// the launch site.
//
// Per channel:
//   1. each thread accumulates sum(x) and sum(x * x); merge_block_horizonal
//      folds them along threadIdx.x (result valid in threadIdx.x == 0);
//   2. when gridDim.x > 1 the row totals are staged in block_data_ptr
//      (sums at [i + blockIdx.x * C], squared sums C * gridDim.x elements
//      further on), followed by __threadfence() and a block barrier; thread
//      (0, 0) then increments flag_ptr[blockIdx.y] and the block that
//      observes the old value gridDim.x - 1 re-reads all staged partials and
//      folds them again;
//   3. the finalising threads (threadIdx.x == 0 of that last block, or of
//      block column 0 when gridDim.x == 1) derive
//        batch mean     = sum / (N * HxW)
//        batch variance = sum_sq / (N * HxW) - mean * mean
//        inverse std    = 1 / sqrt(variance + epsilon)
//      store mean / inverse std into save_mean / save_inv_variance when both
//      pointers are non-null, blend the running statistics as
//        running = (1 - exponentialAverageFactor) * batch
//                  + exponentialAverageFactor * running
//      and publish mean / inverse std in compute_mean / compute_inv_var.
//
// NOTE(review): flag_ptr[blockIdx.y] is only ever incremented in this kernel,
// so the "old == gridDim.x - 1" test can match once per launch. This looks
// like it requires flag_ptr to be zeroed by the caller and, when
// gridDim.x > 1, the channel loop to run a single iteration -- confirm
// against the launch configuration.
//
// The scale, bias and y parameters are accepted but not read or written here.
template <typename T, int BlockDim>
static __global__ void BNForwardTraining2DCompStat(
    const T *x,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    const double epsilon,
    double exponentialAverageFactor,
    T *y,
    BatchNormParamType<T> *global_mean,
    BatchNormParamType<T> *global_variance,
    BatchNormParamType<T> *save_mean,
    BatchNormParamType<T> *save_inv_variance,
    BatchNormParamType<T> *compute_mean,
    BatchNormParamType<T> *compute_inv_var,
    BatchNormParamType<T> *block_data_ptr,
    int *flag_ptr) {
  using ParamT = BatchNormParamType<T>;

  __shared__ ParamT smem_sum[BlockDim];
  __shared__ ParamT smem_square_sum[BlockDim];

  const int num_channels = C;
  const int per_channel = N * HxW;
  const int channel_step = gridDim.y * blockDim.y;
  const int elem_step = gridDim.x * blockDim.x;

  for (int ch = blockIdx.y * blockDim.y + threadIdx.y; ch < num_channels;
       ch += channel_step) {
    ParamT acc_sum = static_cast<ParamT>(0);
    ParamT acc_sq_sum = static_cast<ParamT>(0);

    // Per-thread partial sums over this thread's share of channel ch.
    for (int k = blockIdx.x * blockDim.x + threadIdx.x; k < per_channel;
         k += elem_step) {
      const int offset = (k / HxW * C + ch) * HxW + k % HxW;
      const ParamT value = static_cast<ParamT>(x[offset]);
      acc_sum += value;
      acc_sq_sum += value * value;
    }

    // Fold the partials along threadIdx.x inside this block.
    merge_block_horizonal<T>(acc_sum,
                             acc_sq_sum,
                             &smem_sum[0],
                             &smem_square_sum[0],
                             &acc_sum,
                             &acc_sq_sum);

    // Set for the thread that finalises channel ch.
    bool finalize_channel;

    if (gridDim.x > 1) {
      volatile ParamT *staging_sum = block_data_ptr;
      volatile ParamT *staging_square_sum = &block_data_ptr[C * gridDim.x];

      // Stage this block's row totals in global memory.
      if (threadIdx.x == 0) {
        staging_sum[ch + blockIdx.x * C] = acc_sum;
        staging_square_sum[ch + blockIdx.x * C] = acc_sq_sum;
      }

      // The staged values must be visible to other blocks before this block
      // announces completion through the counter below.
      __threadfence();
      __syncthreads();

      __shared__ bool is_last_block_done;
      // One thread per block bumps the arrival counter for this block row.
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        int old = atomicAdd(&flag_ptr[blockIdx.y], 1);
        is_last_block_done = (old == (gridDim.x - 1));
      }

      __syncthreads();

      const bool last_block = is_last_block_done;
      if (last_block) {
        // Re-accumulate over the partials staged by every block column.
        acc_sum = static_cast<ParamT>(0);
        acc_sq_sum = static_cast<ParamT>(0);
        for (int col = threadIdx.x; col < gridDim.x; col += blockDim.x) {
          acc_sum += staging_sum[ch + col * C];
          acc_sq_sum += staging_square_sum[ch + col * C];
        }

        merge_block_horizonal<T>(acc_sum,
                                 acc_sq_sum,
                                 &smem_sum[0],
                                 &smem_square_sum[0],
                                 &acc_sum,
                                 &acc_sq_sum);
      }
      finalize_channel = last_block && threadIdx.x == 0;
    } else {
      // A single block column already holds the complete sums.
      finalize_channel = blockIdx.x == 0 && threadIdx.x == 0;
    }

    if (finalize_channel) {
      const ParamT batch_mean = acc_sum / per_channel;
      const ParamT batch_var =
          acc_sq_sum / per_channel - batch_mean * batch_mean;
      const ParamT batch_inv_std = 1 / sqrt(batch_var + epsilon);

      if (save_mean && save_inv_variance) {
        save_mean[ch] = batch_mean;
        save_inv_variance[ch] = batch_inv_std;
      }
      global_mean[ch] = (1 - exponentialAverageFactor) * batch_mean +
                        exponentialAverageFactor * global_mean[ch];
      global_variance[ch] = (1 - exponentialAverageFactor) * batch_var +
                            exponentialAverageFactor * global_variance[ch];

      compute_mean[ch] = batch_mean;
      compute_inv_var[ch] = batch_inv_std;
    }
  }
}

// Output stage of the two-kernel channel-first batch-norm training forward
// pass. Using the per-channel values in compute_mean and compute_inv_var it
// writes
//   y[idx] = scale[i] * (x[idx] - compute_mean[i]) * compute_inv_var[i]
//            + bias[i]
// with idx = (j / HxW * C + i) * HxW + j % HxW, for every channel i in
// [0, C) and element j in [0, N * HxW).
//
// Launch layout: two-dimensional. The y dimension walks channels with a step
// of gridDim.y * blockDim.y, the x dimension walks the elements of a channel
// with a step of gridDim.x * blockDim.x.
//
// The four per-channel values are loaded once per channel, before the element
// loop. No shared memory or barriers are used.
template <typename T>
static __global__ void BNForwardTraining2DWriteRes(
    const T *x,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    T *y,
    BatchNormParamType<T> *compute_mean,
    BatchNormParamType<T> *compute_inv_var) {
  using ParamT = BatchNormParamType<T>;

  const int num_channels = C;
  const int per_channel = N * HxW;
  const int channel_step = gridDim.y * blockDim.y;
  const int elem_step = gridDim.x * blockDim.x;

  for (int ch = blockIdx.y * blockDim.y + threadIdx.y; ch < num_channels;
       ch += channel_step) {
    const ParamT ch_mean = compute_mean[ch];
    const ParamT ch_inv_std = compute_inv_var[ch];
    const ParamT ch_scale = scale[ch];
    const ParamT ch_bias = bias[ch];

    for (int k = blockIdx.x * blockDim.x + threadIdx.x; k < per_channel;
         k += elem_step) {
      const int offset = (k / HxW * C + ch) * HxW + k % HxW;
      const ParamT centered = static_cast<ParamT>(x[offset]) - ch_mean;
      y[offset] = ch_scale * centered * ch_inv_std + ch_bias;
    }
  }
}

H
hong 已提交
512 513 514 515 516
template <typename T, typename Context>
void BatchNormKernel(const Context &ctx,
                     const DenseTensor &x,
                     const DenseTensor &mean,
                     const DenseTensor &variance,
517 518 519
                     const DenseTensor &scale,
                     const DenseTensor &bias,
                     bool is_test,
H
hong 已提交
520 521 522 523 524 525 526 527 528 529 530 531 532
                     float momentum,
                     float epsilon_f,
                     const std::string &data_layout_str,
                     bool use_global_stats,
                     bool trainable_statistics,
                     DenseTensor *y,
                     DenseTensor *mean_out,
                     DenseTensor *variance_out,
                     DenseTensor *saved_mean,
                     DenseTensor *saved_variance,
                     DenseTensor *reserve_space) {
  double epsilon = epsilon_f;
  const bool trainable_stats = trainable_statistics;
533
  const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
H
hong 已提交
534 535 536 537 538 539 540 541 542 543 544 545 546 547 548
  bool test_mode = is_test && (!trainable_stats);

  // Get the size for each dimension.
  // NCHW [batch_size, in_channels, in_height, in_width]
  const auto &x_dims = x.dims();
  PADDLE_ENFORCE_EQ(
      x_dims.size() >= 2 && x_dims.size() <= 5,
      true,
      phi::errors::InvalidArgument(
          "The size of input's dimensions should be between 2 and 5"
          "But received: the size of input's dimensions is [%d]",
          x_dims.size()));

  ctx.template Alloc<T>(y);
  int N, C, H, W, D;
549
  phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
H
hong 已提交
550

551
  auto dtype = phi::backends::gpu::CudnnDataType<T>::type;
H
hong 已提交
552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722

#ifdef PADDLE_WITH_HIP
  auto compute_format =
      data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;

// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW;
#else
  const bool fast_nhwc_batch_norm =
      test_mode ||
      (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);

  auto compute_format = fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
                            ? DataLayout::kNHWC
                            : DataLayout::kNCHW;
#endif

  DenseTensor transformed_x(x.type());
  DenseTensor transformed_y(y->type());

  if (data_layout == DataLayout::kNHWC && compute_format == DataLayout::kNCHW &&
      x_dims.size() > 2) {
    VLOG(3) << "Transform input tensor from NHWC to NCHW.";
    ResizeToChannelFirst<Context, T>(ctx, &x, &transformed_x);
    TransToChannelFirst<Context, T>(ctx, &x, &transformed_x);
    ResizeToChannelFirst<Context, T>(ctx, y, &transformed_y);
  } else {
    transformed_x.ShareDataWith(x);
    transformed_y.ShareDataWith(*y);
  }

// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// miopenTensorDescriptor_t data_desc_;
// miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_;

// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else
  cudnnTensorDescriptor_t data_desc_;
  cudnnTensorDescriptor_t bn_param_desc_;
  cudnnBatchNormMode_t mode_;

  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
#endif

  if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
    LOG(ERROR) << "Provided epsilon is smaller than "
               << "CUDNN_BN_MIN_EPSILON. Setting it to "
               << "CUDNN_BN_MIN_EPSILON instead.";
  }
  epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);

#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
#elif CUDNN_VERSION_MIN(7, 0, 1)
  if (FLAGS_cudnn_batchnorm_spatial_persistent) {
    mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
  } else if (H == 1 && W == 1) {
    mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
  } else {
    mode_ = CUDNN_BATCHNORM_SPATIAL;
  }
#else
  if (H == 1 && W == 1) {
    mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
  } else {
    mode_ = CUDNN_BATCHNORM_SPATIAL;
  }
#endif  // CUDNN_VERSION_MIN(7, 0, 1)

  VLOG(3) << "Setting descriptors.";
  std::vector<int> dims;
  std::vector<int> strides;
  if (compute_format == DataLayout::kNCHW) {
    dims = {N, C, H, W, D};
    strides = {C * H * W * D, H * W * D, W * D, D, 1};
  } else {
    dims = {N, C, H, W, D};
    strides = {H * W * D * C, 1, W * D * C, D * C, C};
  }

#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
//     data_desc_, CudnnDataType<T>::type,
//     x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
//     const_cast<int *>(strides.data())));
// Note: PERSISTENT not implemented for inference
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenDeriveBNTensorDescriptor(
//         bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnSetTensorNdDescriptor(
          data_desc_,
          CudnnDataType<T>::type,
          x_dims.size() > 3 ? x_dims.size() : 4,
          dims.data(),
          strides.data()));
  // Note: PERSISTENT not implemented for inference
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnDeriveBNTensorDescriptor(
          bn_param_desc_,
          data_desc_,
          test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
#endif

  auto handle = ctx.cudnn_handle();

  // Now, depending on whether we are running test or not, we have two paths.
  // It is training mode when it's not reference AND not using pre-trained
  // model.
  bool training = !test_mode && !use_global_stats;
  if (!training) {
    // only when test we use input to do computation.
    const auto *est_mean = &mean;
    const auto *est_var = &variance;
    // Run inference mode.
    PADDLE_ENFORCE_EQ(
        est_mean->dims().size(),
        1UL,
        phi::errors::InvalidArgument(
            "The size of mean's dimensions must equal to 1."
            "But received: the size of mean's dimensions mean is [%d],"
            "the dimensions of mean is [%s].",
            est_mean->dims().size(),
            est_mean->dims()));
    PADDLE_ENFORCE_EQ(
        est_var->dims().size(),
        1UL,
        phi::errors::InvalidArgument(
            "The size of variance's dimensions must equal to 1."
            "But received: the size of variance's dimensions is [%d],"
            "the dimensions of variance is [%s].",
            est_var->dims().size(),
            est_var->dims()));
    PADDLE_ENFORCE_EQ(
        est_mean->dims()[0],
        C,
        phi::errors::InvalidArgument(
            "The first dimension of mean must equal to the number of "
            "Channels, which is [%d]. But received: the first dimension"
            "of mean is [%d], the dimensions of mean is [%s].",
            C,
            est_mean->dims()[0],
            est_mean->dims()));
    PADDLE_ENFORCE_EQ(
        est_var->dims()[0],
        C,
        phi::errors::InvalidArgument(
            "The first dimension of variance must equal to the number"
            "of Channels, which is [%d]. But received: the first dimension of"
            "variance is [%d], the dimensions of variance is [%s].",
            C,
            est_var->dims()[0],
            est_var->dims()));

#ifdef PADDLE_WITH_HIP
    const int block_size = 256;
    const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
    if (compute_format == DataLayout::kNCHW) {
723 724 725 726 727 728 729 730 731 732 733 734
      BNForwardInference<T, DataLayout::kNCHW>
          <<<grid_size, block_size, 0, ctx.stream()>>>(
              transformed_x.template data<T>(),
              est_mean->template data<BatchNormParamType<T>>(),
              est_var->template data<BatchNormParamType<T>>(),
              scale.template data<BatchNormParamType<T>>(),
              bias.template data<BatchNormParamType<T>>(),
              C,
              N,
              H * W * D,
              epsilon,
              transformed_y.template data<T>());
H
hong 已提交
735
    } else {
736 737 738 739 740 741 742 743 744 745 746 747
      BNForwardInference<T, DataLayout::kNHWC>
          <<<grid_size, block_size, 0, ctx.stream()>>>(
              transformed_x.template data<T>(),
              est_mean->template data<BatchNormParamType<T>>(),
              est_var->template data<BatchNormParamType<T>>(),
              scale.template data<BatchNormParamType<T>>(),
              bias.template data<BatchNormParamType<T>>(),
              C,
              N,
              H * W * D,
              epsilon,
              transformed_y.template data<T>());
H
hong 已提交
748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772
    }
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenBatchNormalizationForwardInference(
//         handle, miopenBNSpatial,
//         const_cast<void *>(
//             static_cast<const void *>(CudnnDataType<T>::kOne())),
//         const_cast<void *>(
//             static_cast<const void *>(CudnnDataType<T>::kZero())),
//         data_desc_,
//         static_cast<const void *>(transformed_x.template data<T>()),
//         data_desc_,
//         static_cast<void *>(
//             transformed_y.template mutable_data<T>(ctx.GetPlace())),
//         bn_param_desc_,
//         const_cast<void *>(static_cast<const void *>(
//             scale->template data<BatchNormParamType<T>>())),
//         const_cast<void *>(static_cast<const void *>(
//             bias->template data<BatchNormParamType<T>>())),
//         const_cast<void *>(static_cast<const void *>(
//             est_mean->template data<BatchNormParamType<T>>())),
//         const_cast<void *>(static_cast<const void *>(
//             est_var->template data<BatchNormParamType<T>>())),
//         epsilon));
#else
773
    const bool use_native_kernel =
774
        (x_dims.size() == 2 ||
Z
zhangkaihuo 已提交
775
         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792
    if (use_native_kernel) {
      const int block_size = 256;
      const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
      if (compute_format == DataLayout::kNCHW) {
        BNForwardInference<T, DataLayout::kNCHW>
            <<<grid_size, block_size, 0, ctx.stream()>>>(
                transformed_x.template data<T>(),
                est_mean->template data<BatchNormParamType<T>>(),
                est_var->template data<BatchNormParamType<T>>(),
                scale.template data<BatchNormParamType<T>>(),
                bias.template data<BatchNormParamType<T>>(),
                C,
                N,
                H * W * D,
                epsilon,
                transformed_y.template data<T>());
      } else {
793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829
        if (x_dims.size() == 2) {
          DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
          auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
          const int threads = 512 > C ? C : 512;
          const int blocks = (C + 511) / 512;
          InverseVariance<T><<<blocks, threads>>>(
              est_var->template data<BatchNormParamType<T>>(),
              epsilon,
              C,
              inv_var_ptr);
          BN1DForwardInference<T, DataLayout::kNHWC>
              <<<grid_size, block_size, 0, ctx.stream()>>>(
                  transformed_x.template data<T>(),
                  est_mean->template data<BatchNormParamType<T>>(),
                  // est_var->template data<BatchNormParamType<T>>(),
                  inv_var_ptr,
                  scale.template data<BatchNormParamType<T>>(),
                  bias.template data<BatchNormParamType<T>>(),
                  C,
                  N,
                  H * W * D,
                  epsilon,
                  transformed_y.template data<T>());
        } else {
          BNForwardInference<T, DataLayout::kNHWC>
              <<<grid_size, block_size, 0, ctx.stream()>>>(
                  transformed_x.template data<T>(),
                  est_mean->template data<BatchNormParamType<T>>(),
                  est_var->template data<BatchNormParamType<T>>(),
                  scale.template data<BatchNormParamType<T>>(),
                  bias.template data<BatchNormParamType<T>>(),
                  C,
                  N,
                  H * W * D,
                  epsilon,
                  transformed_y.template data<T>());
        }
830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849
      }
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(
          paddle::platform::dynload::cudnnBatchNormalizationForwardInference(
              handle,
              // Note: PERSISTENT not implemented for inference
              CUDNN_BATCHNORM_SPATIAL,
              CudnnDataType<T>::kOne(),
              CudnnDataType<T>::kZero(),
              data_desc_,
              transformed_x.template data<T>(),
              data_desc_,
              ctx.template Alloc<T>(&transformed_y),
              bn_param_desc_,
              scale.template data<BatchNormParamType<T>>(),
              bias.template data<BatchNormParamType<T>>(),
              est_mean->template data<BatchNormParamType<T>>(),
              est_var->template data<BatchNormParamType<T>>(),
              epsilon));
    }
H
hong 已提交
850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866
#endif
  } else {
    // if MomentumTensor is set, use MomentumTensor value, momentum
    // is only used in this training branch

    // need to solve here
    // if (ctx.HasInput("MomentumTensor")) {
    //   const auto *mom_tensor = MomentumTensor;
    //   DenseTensor mom_cpu;
    //   paddle::framework::TensorCopySync(*mom_tensor, platform::CPUPlace(),
    //                                     &mom_cpu);
    //   momentum = mom_cpu.data<float>()[0];
    // }

    // Run training mode.
    // obtain running mean and running inv var, and there is no need
    // to initialize them.
H
hong 已提交
867 868
    ctx.template Alloc<BatchNormParamType<T>>(mean_out);
    ctx.template Alloc<BatchNormParamType<T>>(variance_out);
H
hong 已提交
869

H
hong 已提交
870 871
    ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
    ctx.template Alloc<BatchNormParamType<T>>(saved_variance);
H
hong 已提交
872 873 874 875 876 877 878 879

    if ((N * H * W * D) == 1) {
      // Only 1 element in normalization dimension,
      // skip the batch norm calculation, let y = x.
      paddle::framework::TensorCopy(x, ctx.GetPlace(), y);
    } else {
      double this_factor = 1. - momentum;
#ifdef PADDLE_WITH_HIP
880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947
      const int num = transformed_x.numel();
      const int block = 256;
      const int max_threads = ctx.GetMaxPhysicalThreadCount();
      const int max_blocks = std::max(max_threads / block, 1);
      const int grid = std::min(C, max_blocks);
      if (compute_format == DataLayout::kNCHW) {
        BNForwardTraining<T, block, DataLayout::kNCHW>
            <<<grid, block, 0, ctx.stream()>>>(
                transformed_x.template data<T>(),
                scale.template data<BatchNormParamType<T>>(),
                bias.template data<BatchNormParamType<T>>(),
                C,
                N,
                H * W * D,
                epsilon,
                this_factor,
                transformed_y.template data<T>(),
                mean_out->template data<BatchNormParamType<T>>(),
                variance_out->template data<BatchNormParamType<T>>(),
                saved_mean->template data<BatchNormParamType<T>>(),
                saved_variance->template data<BatchNormParamType<T>>());
      } else {
        BNForwardTraining<T, block, DataLayout::kNHWC>
            <<<grid, block, 0, ctx.stream()>>>(
                transformed_x.template data<T>(),
                scale.template data<BatchNormParamType<T>>(),
                bias.template data<BatchNormParamType<T>>(),
                C,
                N,
                H * W * D,
                epsilon,
                this_factor,
                transformed_y.template data<T>(),
                mean_out->template data<BatchNormParamType<T>>(),
                variance_out->template data<BatchNormParamType<T>>(),
                saved_mean->template data<BatchNormParamType<T>>(),
                saved_variance->template data<BatchNormParamType<T>>());
      }
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenBatchNormalizationForwardTraining(
//         handle, mode_, const_cast<void *>(static_cast<const void *>(
//                            CudnnDataType<T>::kOne())),
//         const_cast<void *>(
//             static_cast<const void *>(CudnnDataType<T>::kZero())),
//         data_desc_,
//         static_cast<const void *>(transformed_x.template data<T>()),
//         data_desc_,
//         static_cast<void *>(
//             transformed_y.template mutable_data<T>(ctx.GetPlace())),
//         bn_param_desc_,
//         const_cast<void *>(static_cast<const void *>(
//             scale->template data<BatchNormParamType<T>>())),
//         const_cast<void *>(static_cast<const void *>(
//             bias->template data<BatchNormParamType<T>>())),
//         this_factor,
//         static_cast<void *>(
//             mean_out->template mutable_data<BatchNormParamType<T>>(
//                 ctx.GetPlace())),
//         static_cast<void *>(variance_out->template mutable_data<
//                             BatchNormParamType<T>>(ctx.GetPlace())),
//         epsilon,
//         static_cast<void *>(
//             saved_mean->template mutable_data<BatchNormParamType<T>>(
//                 ctx.GetPlace())),
//         static_cast<void *>(saved_variance->template mutable_data<
//                             BatchNormParamType<T>>(ctx.GetPlace()))));
#else
Z
zhangkaihuo 已提交
948
      // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
949
      const bool use_native_kernel =
950
          ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
Z
zhangkaihuo 已提交
951
           (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
952
      if (use_native_kernel) {
953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003
        dim3 block;
        dim3 grid;
        const int block_size = 512;
        const int MAX_GRID_SIZE = 128;
        const int WARP_SIZE = 32;

        // init intermediate storage
        DenseTensor block_data_tensor;
        DenseTensor flag_tensor;
        DenseTensor compute_mean_tensor =
            phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
        DenseTensor compute_inv_var_tensor =
            phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});

        BatchNormParamType<T> *block_data_ptr = nullptr;
        int *flag_ptr = nullptr;

        if (x_dims.size() != 2 && compute_format == DataLayout::kNCHW) {
          // init block&grid config
          int block_x =
              std::min(phi::funcs::details::GetLastPow2(H * W * D), block_size);
          int block_y = std::min(phi::funcs::details::GetLastPow2(C),
                                 block_size / block_x);

          if (block_x * block_y != block_size) {
            block_x =
                std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
                         block_size / block_y);
          }

          int grid_x =
              std::min((N * H * W * D + block_x * 16 - 1) / (block_x * 16),
                       MAX_GRID_SIZE);
          int grid_y = (C + block_y - 1) / block_y;

          block.x = block_x;
          block.y = block_y;
          grid.x = grid_x;
          grid.y = grid_y;

          if (grid.x > 1) {
            block_data_tensor = phi::Empty<BatchNormParamType<T>, Context>(
                ctx, {2 * C * grid.x});
            flag_tensor = phi::Empty<int, Context>(ctx, {grid.y});

            block_data_ptr = block_data_tensor.data<BatchNormParamType<T>>();
            flag_ptr = flag_tensor.data<int>();
            funcs::SetConstant<Context, int> set_zero;
            set_zero(ctx, &flag_tensor, static_cast<int>(0));
          }
          BNForwardTraining2DCompStat<T, block_size>
1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016
              <<<grid, block, 0, ctx.stream()>>>(
                  transformed_x.template data<T>(),
                  scale.template data<BatchNormParamType<T>>(),
                  bias.template data<BatchNormParamType<T>>(),
                  C,
                  N,
                  H * W * D,
                  epsilon,
                  this_factor,
                  transformed_y.template data<T>(),
                  mean_out->template data<BatchNormParamType<T>>(),
                  variance_out->template data<BatchNormParamType<T>>(),
                  saved_mean->template data<BatchNormParamType<T>>(),
1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032
                  saved_variance->template data<BatchNormParamType<T>>(),
                  compute_mean_tensor.data<BatchNormParamType<T>>(),
                  compute_inv_var_tensor.data<BatchNormParamType<T>>(),
                  block_data_ptr,
                  flag_ptr);

          BNForwardTraining2DWriteRes<T><<<grid, block, 0, ctx.stream()>>>(
              transformed_x.template data<T>(),
              scale.template data<BatchNormParamType<T>>(),
              bias.template data<BatchNormParamType<T>>(),
              C,
              N,
              H * W * D,
              transformed_y.template data<T>(),
              compute_mean_tensor.data<BatchNormParamType<T>>(),
              compute_inv_var_tensor.data<BatchNormParamType<T>>());
H
hong 已提交
1033
        } else {
1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064
          // init block&grid config
          int block_x =
              std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
          int block_y =
              std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
                       block_size / block_x);
          if (block_x * block_y != block_size) {
            block_x = std::min(phi::funcs::details::GetLastPow2(C),
                               block_size / block_y);
          }
          int grid_x = (C + block_x - 1) / block_x;
          int grid_y =
              std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
                       MAX_GRID_SIZE);

          block.x = block_x;
          block.y = block_y;
          grid.x = grid_x;
          grid.y = grid_y;

          if (grid.y > 1) {
            block_data_tensor = phi::Empty<BatchNormParamType<T>, Context>(
                ctx, {2 * C * grid.y});
            flag_tensor = phi::Empty<int, Context>(ctx, {grid.x});

            block_data_ptr = block_data_tensor.data<BatchNormParamType<T>>();
            flag_ptr = flag_tensor.data<int>();
            funcs::SetConstant<Context, int> set_zero;
            set_zero(ctx, &flag_tensor, static_cast<int>(0));
          }
          BNForwardTraining2DChannelLastCompStat<T, block_size>
1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077
              <<<grid, block, 0, ctx.stream()>>>(
                  transformed_x.template data<T>(),
                  scale.template data<BatchNormParamType<T>>(),
                  bias.template data<BatchNormParamType<T>>(),
                  C,
                  N,
                  H * W * D,
                  epsilon,
                  this_factor,
                  transformed_y.template data<T>(),
                  mean_out->template data<BatchNormParamType<T>>(),
                  variance_out->template data<BatchNormParamType<T>>(),
                  saved_mean->template data<BatchNormParamType<T>>(),
1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094
                  saved_variance->template data<BatchNormParamType<T>>(),
                  compute_mean_tensor.data<BatchNormParamType<T>>(),
                  compute_inv_var_tensor.data<BatchNormParamType<T>>(),
                  block_data_ptr,
                  flag_ptr);

          BNForwardTraining2DChannelLastWriteRes<T>
              <<<grid, block, 0, ctx.stream()>>>(
                  transformed_x.template data<T>(),
                  scale.template data<BatchNormParamType<T>>(),
                  bias.template data<BatchNormParamType<T>>(),
                  C,
                  N,
                  H * W * D,
                  transformed_y.template data<T>(),
                  compute_mean_tensor.data<BatchNormParamType<T>>(),
                  compute_inv_var_tensor.data<BatchNormParamType<T>>());
H
hong 已提交
1095
        }
1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106
      } else {
#if CUDNN_VERSION_MIN(7, 4, 1)
        size_t workspace_size = 0;
        size_t reserve_space_size = 0;
        void *reserve_space_ptr = nullptr;
        void *workspace_ptr = nullptr;
        DenseTensor workspace_tensor;
        DenseTensor reserve_space_tensor;
        // Create reserve space and workspace for batch norm.
        // Create tensor for each batchnorm op, it will be used in the
        // backward. Thus this tensor shouldn't be temp.
1107
        // auto *reserve_space = ctx.Output<phi::DenseTensor>("ReserveSpace");
1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172
        if (reserve_space == nullptr) {
          reserve_space = &reserve_space_tensor;
        }
        PADDLE_ENFORCE_NOT_NULL(
            reserve_space,
            phi::errors::NotFound(
                "The argument ReserveSpace of batch_norm op is not found."));
        // --------------- cudnn batchnorm workspace ---------------
        PADDLE_ENFORCE_GPU_SUCCESS(
            paddle::platform::dynload::
                cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
                    /*handle=*/handle,
                    /*mode=*/mode_,
                    /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                    /*xDesc=*/data_desc_,
                    /*zDesc=*/nullptr,
                    /*yDesc=*/data_desc_,
                    /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
                    /*activationDesc=*/nullptr,
                    /*sizeInBytes=*/&workspace_size));

        // -------------- cudnn batchnorm reserve space --------------
        PADDLE_ENFORCE_GPU_SUCCESS(
            paddle::platform::dynload::
                cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
                    /*handle=*/handle,
                    /*mode=*/mode_,
                    /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
                    /*activationDesc=*/nullptr,
                    /*xDesc=*/data_desc_,
                    /*sizeInBytes=*/&reserve_space_size));

        reserve_space->Resize({static_cast<int64_t>(reserve_space_size)});
        reserve_space_ptr =
            static_cast<void *>(ctx.template Alloc<uint8_t>(reserve_space));
        workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
        workspace_ptr =
            static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
        PADDLE_ENFORCE_GPU_SUCCESS(
            paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
                handle,
                mode_,
                CUDNN_BATCHNORM_OPS_BN,
                CudnnDataType<T>::kOne(),
                CudnnDataType<T>::kZero(),
                data_desc_,
                transformed_x.template data<T>(),
                nullptr,
                nullptr,
                data_desc_,
                transformed_y.template data<T>(),
                bn_param_desc_,
                scale.template data<BatchNormParamType<T>>(),
                bias.template data<BatchNormParamType<T>>(),
                this_factor,
                ctx.template Alloc<BatchNormParamType<T>>(mean_out),
                ctx.template Alloc<BatchNormParamType<T>>(variance_out),
                epsilon,
                ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
                ctx.template Alloc<BatchNormParamType<T>>(saved_variance),
                nullptr,
                workspace_ptr,
                workspace_size,
                reserve_space_ptr,
                reserve_space_size));
H
hong 已提交
1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187
#else
        PADDLE_ENFORCE_GPU_SUCCESS(
            paddle::platform::dynload::cudnnBatchNormalizationForwardTraining(
                handle,
                mode_,
                CudnnDataType<T>::kOne(),
                CudnnDataType<T>::kZero(),
                data_desc_,
                transformed_x.template data<T>(),
                data_desc_,
                ctx.template Alloc<T>(&transformed_y),
                bn_param_desc_,
                scale.template data<BatchNormParamType<T>>(),
                bias.template data<BatchNormParamType<T>>(),
                this_factor,
H
hong 已提交
1188 1189
                ctx.template Alloc<BatchNormParamType<T>>(mean_out),
                ctx.template Alloc<BatchNormParamType<T>>(variance_out),
H
hong 已提交
1190
                epsilon,
H
hong 已提交
1191 1192
                ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
                ctx.template Alloc<BatchNormParamType<T>>(saved_variance)));
1193
#endif  // CUDNN_VERSION_MIN(7, 4, 1)
H
hong 已提交
1194
      }
1195
#endif
H
hong 已提交
1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227
    }
  }

  if (data_layout == DataLayout::kNHWC && compute_format == DataLayout::kNCHW &&
      x_dims.size() > 2) {
    VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
    TransToChannelLast<Context, T>(ctx, &transformed_y, y);
  }
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean when exit.
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_GPU_SUCCESS(
//     platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else
  // clean when exit.
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
#endif
}

}  // namespace phi

#ifdef PADDLE_WITH_HIP
// Registers phi::BatchNormKernel as the GPU `batch_norm` kernel for HIP
// builds, instantiated for float and float16 and valid for every layout.
//
// Inputs 1-4 and outputs 1-4 are unconditionally declared as FLOAT32,
// whatever element type the kernel instance is registered for.
// NOTE(review): slots 1-4 are presumably the per-channel statistics and
// affine parameters (BatchNormParamType<T>) -- confirm the exact slot order
// against the BatchNormKernel signature.
PD_REGISTER_KERNEL(batch_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::BatchNormKernel,
                   float,
                   phi::dtype::float16) {
  kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
  kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
  kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
  kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
  kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
#else
// Registers phi::BatchNormKernel as the GPU `batch_norm` kernel for non-HIP
// builds, instantiated for float, double and float16 and valid for every
// layout.
//
// Only the FLOAT16 instance overrides inputs 1-4 and outputs 1-4 to FLOAT32;
// the float and double instances keep the data types the registry assigns
// by default.
PD_REGISTER_KERNEL(batch_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::BatchNormKernel,
                   float,
                   double,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
  }
}

#endif