layer_norm_kernel.cu.h 70.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

25 26
#include <iostream>

27
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
28 29
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
30
#include "paddle/phi/core/ddim.h"
31
#include "paddle/phi/kernels/funcs/aligned_vector.h"
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113

namespace paddle {
namespace operators {

// Local alias for the framework tensor type.
using Tensor = framework::Tensor;
// Shorthand for platform::CudnnDataType<T> (cuDNN data-type traits of T).
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
// Parameter/statistics type paired with the element type T, taken from
// CudnnDataType<T>::BatchNormParamType.
// NOTE(review): presumably float for float/fp16 inputs — confirm in gpu_dnn.h.
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

// Picks the thread-block size for the generic layer-norm kernels.
// Returns the platform's block-size cap (512 on CUDA, 256 on HIP) when
// `block_dim` (the length of the normalized dimension) reaches that cap;
// otherwise returns a single warp/wavefront (32 on CUDA, 64 on HIP).
inline static int GetDesiredBlockDim(int64_t block_dim) {
#ifdef __HIPCC__
  constexpr int kWarpWidth = 64;
  constexpr int kBlockCap = 256;
#else
  constexpr int kWarpWidth = 32;
  constexpr int kBlockCap = 512;
#endif
  if (block_dim < kBlockCap) {
    return kWarpWidth;
  }
  return kBlockCap;
}

// Sums `val` across the lanes of the calling warp using a shuffle-down tree
// (offsets warpSize/2, warpSize/4, ..., 1). Afterwards lane 0 holds the sum
// over the participating lanes; the other lanes hold partial sums only.
// NOTE(review): assumes every lane of the warp participates — with a partial
// warp the shuffles may read lanes that do not exist.
template <typename U>
static __forceinline__ __device__ U WarpReduceSum(U val) {
  unsigned lane_mask = 0u;
  CREATE_SHFL_MASK(lane_mask, true);
  int delta = warpSize >> 1;
  while (delta > 0) {
    val += paddle::platform::CudaShuffleDownSync(lane_mask, val, delta);
    delta >>= 1;
  }
  return val;
}

// Block-wide sum of `val`.
//   1. every warp reduces its own values (WarpReduceSum);
//   2. lane 0 of each warp stores the warp's partial in shared[warp_id];
//   3. the first blockDim.x / warpSize threads pick those partials up and
//      warp 0 reduces them once more.
// Only warp 0 (in particular thread 0) holds the block total on return.
// `shared` needs one slot per warp of the block.
// NOTE(review): partials are read back only for
// threadIdx.x < blockDim.x / warpSize, so blockDim.x is assumed to be a
// multiple of warpSize — TODO confirm for the small FIXED_BLOCK_DIM_CASE
// sizes (2..16).
template <typename U>
__forceinline__ __device__ U BlockReduceSum(U val, U *shared) {
  const int warp_id = threadIdx.x / warpSize;
  const int lane_id = threadIdx.x % warpSize;

  // Stage 1: intra-warp reduction.
  val = WarpReduceSum(val);

  // Barrier before touching `shared` (presumably so earlier readers of the
  // same buffer are done), then publish one partial per warp.
  __syncthreads();
  if (lane_id == 0) {
    shared[warp_id] = val;
  }
  __syncthreads();

  // Stage 2: gather the per-warp partials; threads without one contribute 0.
  const bool has_partial = threadIdx.x < blockDim.x / warpSize;
  val = has_partial ? shared[lane_id] : static_cast<U>(0);
  if (warp_id == 0) {
    val = WarpReduceSum(val);
  }
  return val;
}

// Expands to one `case` label of a `switch` over a runtime block size.
// Inside the case, `kBlockDim` is a constexpr equal to 1 << log2_block_dim,
// so the statement(s) passed as __VA_ARGS__ can use it as a template
// argument.
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

// Expands to the case labels for block sizes 512, 256, ..., 2 (2^9 down to
// 2^1). Must be used inside a `switch` on the block size; any other value
// matches none of the generated labels.
#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)

// Like FIXED_BLOCK_DIM_CASE_BASE, but also splits `feature_size` columns into
// chunks of at most `kMaxBlockNum` columns and runs __VA_ARGS__ once per
// chunk. Names visible inside __VA_ARGS__:
//   i          - chunk index (int64_t)
//   col_offset - first column of the chunk, i * kMaxBlockNum (int64_t)
//   block_num  - number of columns in this chunk, <= kMaxBlockNum (int)
//   kBlockDim  - constexpr block size, 1 << log2_block_dim
// The chunk count is an exact integer ceil-division computed once, rather
// than a floating-point std::ceil re-evaluated in every loop condition
// (which is also inexact for feature sizes above 2^53).
// A non-positive feature_size runs no chunks.
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                           \
    log2_block_dim, feature_size, kMaxBlockNum, ...)                         \
  case (1 << (log2_block_dim)): {                                            \
    const int64_t ln_chunk_cols_ = static_cast<int64_t>(kMaxBlockNum);       \
    const int64_t ln_num_chunks_ =                                           \
        (static_cast<int64_t>(feature_size) + ln_chunk_cols_ - 1) /          \
        ln_chunk_cols_;                                                      \
    for (int64_t i = 0; i < ln_num_chunks_; i++) {                           \
      int64_t col_offset = i * ln_chunk_cols_;                               \
      int block_num = static_cast<int>(std::min(                             \
          static_cast<int64_t>(feature_size) - col_offset, ln_chunk_cols_)); \
      constexpr auto kBlockDim = (1 << (log2_block_dim));                    \
      __VA_ARGS__;                                                           \
    }                                                                        \
  } break

// Case labels for block sizes 512 down to 2, each iterating over the column
// chunks as described above. Must be used inside a `switch` on the block
// size.
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      9, feature_size, kMaxBlockNum, ##__VA_ARGS__);                          \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      8, feature_size, kMaxBlockNum, ##__VA_ARGS__);                          \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      7, feature_size, kMaxBlockNum, ##__VA_ARGS__);                          \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      6, feature_size, kMaxBlockNum, ##__VA_ARGS__);                          \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      5, feature_size, kMaxBlockNum, ##__VA_ARGS__);                          \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      4, feature_size, kMaxBlockNum, ##__VA_ARGS__);                          \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      3, feature_size, kMaxBlockNum, ##__VA_ARGS__);                          \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      2, feature_size, kMaxBlockNum, ##__VA_ARGS__);                          \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                                  \
      1, feature_size, kMaxBlockNum, ##__VA_ARGS__)

// Device square-root overloads: the float version uses sqrtf, the double
// version uses ::sqrt.
static __device__ __forceinline__ float real_sqrt(float x) {
  const float root = sqrtf(x);
  return root;
}

static __device__ __forceinline__ double real_sqrt(double x) {
  const double root = ::sqrt(x);
  return root;
}

// A plain pair of T values. PairForLayerNormAddFunctor adds two pairs field
// by field.
template <typename T>
struct PairForLayerNorm {
  // Leaves both fields uninitialized.
  __device__ __forceinline__ PairForLayerNorm() {}

  __device__ __forceinline__ PairForLayerNorm(const T &a, const T &b)
      : first_(a), second_(b) {}

  T first_;   // first component
  T second_;  // second component
};

// Binary functor returning the component-wise sum of two PairForLayerNorm
// values.
template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &lhs, const PairForLayerNorm<T> &rhs) {
    const T sum_first = lhs.first_ + rhs.first_;
    const T sum_second = lhs.second_ + rhs.second_;
    return PairForLayerNorm<T>(sum_first, sum_second);
  }
};

// Reciprocal square root, 1 / sqrt(val).
// The generic version divides by sqrt(). float, double and (when the target
// architecture supports fp16 arithmetic) half use the dedicated intrinsics.
template <typename T>
__inline__ __device__ T rsqrt_(const T val) {
  const T one = static_cast<T>(1);
  return one / sqrt(val);
}

template <>
__inline__ __device__ float rsqrt_<float>(const float val) {
  return rsqrtf(val);
}

template <>
__inline__ __device__ double rsqrt_<double>(const double val) {
  return ::rsqrt(val);
}

#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__inline__ __device__ half rsqrt_<half>(const half val) {
  return hrsqrt(val);
}
#endif

#ifdef PADDLE_WITH_CUDA
// Fast LayerNorm forward for rows whose length is the compile-time constant
// ELTS_PER_ROW.
//
// Layout, as implied by the template parameters:
//   * a CTA has WARPS_M * WARPS_N warps. WARPS_N warps cooperate on one row,
//     so a CTA covers ROWS_PER_CTA = WARPS_M rows per iteration;
//   * each thread loads LDGS vectors of VecSize elements per row, strided by
//     THREADS_PER_ROW vectors;
//   * the grid loops over rows with stride gridDim.x * ROWS_PER_CTA.
//
// For every row it writes:
//   * mean_out_ptr[row];
//   * var_out_ptr[row], the biased variance from a two-pass computation;
//   * y = gamma * (x - mean) * rsqrt(var + epsilon) + beta.
// All arithmetic is done in U. gamma/beta are loaded once per thread.
//
// `cols` is not read: the row length is ELTS_PER_ROW, which must equal
// LDGS * THREADS_PER_ROW * VecSize (checked below).
// NOTE(review): with WARPS_N > 1 the row loop contains __syncthreads(), so
// every warp of a CTA must run the same number of iterations (rows a
// multiple of gridDim.x * ROWS_PER_CTA) — TODO confirm at the launch site.
template <typename T,
          typename U,
          typename ScaleT = U,
          int VecSize = 8,
          int WARPS_M = 4,
          int WARPS_N = 1,
          int BYTES_PER_LDG = 16,
          int ELTS_PER_ROW = 1024,
          int THREADS_PER_WARP = 32,
          int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP,
          int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW,
          int ROWS_PER_CTA = WARPS_M,
          int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
          int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
    int rows,
    int cols,
    const float epsilon,
    const T *__restrict__ x_ptr,
    const ScaleT *__restrict__ gamma_ptr,
    const ScaleT *__restrict__ beta_ptr,
    U *__restrict__ mean_out_ptr,
    U *__restrict__ var_out_ptr,
    T *__restrict__ y_ptr) {
  static_assert(ELTS_PER_ROW == THREADS_PER_ROW * LDGS * VecSize, "");

  // Per-warp partial sums (only used when WARPS_N > 1).
  __shared__ U smem[WARPS_M * WARPS_N];
  // Per-row totals. They are kept apart from `smem` because smem[warp_m]
  // aliases the partial-sum slot of another row: writing a row total there
  // could clobber a partial that is still being read.
  __shared__ U smem_row[WARPS_M];

  using Vec = phi::AlignedVector<T, VecSize>;
  using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;

  const int tidx = threadIdx.x;
  const int bidx = blockIdx.x;
  const int lane = tidx % THREADS_PER_WARP;  // 0, 1, ..., 31
  const int warp = tidx / THREADS_PER_WARP;  // 0, 1, 2, 3
  const int warp_n = warp % WARPS_N;         // 0
  const int warp_m = warp / WARPS_N;         // 0, 1, 2, 3

  const int c = warp_n * THREADS_PER_WARP + lane;  // first vector column
  const int r = bidx * ROWS_PER_CTA + warp_m;      // first row id

  // gamma/beta are the same for every row: load them once.
  Vec_scale gamma[LDGS];
  Vec_scale beta[LDGS];
#pragma unroll
  for (int it = 0, col = c; it < LDGS; it++) {
    phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
    phi::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
    col += THREADS_PER_ROW;
  }

  constexpr U rn = 1.f / U(ELTS_PER_ROW);
  for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
    // 64-bit element offset: rows * ELTS_PER_ROW can exceed INT_MAX.
    const int64_t row_offset = static_cast<int64_t>(row) * ELTS_PER_ROW;

    Vec x[LDGS];
#pragma unroll
    for (int it = 0, col = c; it < LDGS; it++) {
      phi::Load<T, VecSize>(x_ptr + row_offset + col * VecSize, &x[it]);
      col += THREADS_PER_ROW;
    }
    U xf[LDGS * VecSize];

    // ---- mean: thread-local sum, then butterfly reduction in the warp ----
    U mu_local = 0.f;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        xf[it * VecSize + jt] = U(x[it][jt]);
        mu_local += xf[it * VecSize + jt];
      }
    }

#pragma unroll
    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
      mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
    }
    if (WARPS_N > 1) {
      // Combine the WARPS_N per-warp partials of this row.
      if (lane == 0) {
        smem[warp_m * WARPS_N + warp_n] = mu_local;
      }
      __syncthreads();
      if (tidx % THREADS_PER_ROW == 0) {
        mu_local = 0.f;
#pragma unroll
        for (int it = 0; it < WARPS_N; ++it) {
          mu_local += smem[warp_m * WARPS_N + it];
        }
        smem_row[warp_m] = mu_local;
      }
      __syncthreads();
      mu_local = smem_row[warp_m];
    }

    mu_local *= rn;
    if (lane == 0) {
      mean_out_ptr[row] = mu_local;
    }

    // ---- variance: sum of squared deviations from the mean ----
    U var_local = 0.f;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        U diff = xf[it * VecSize + jt] - mu_local;
        var_local += diff * diff;
      }
    }

#pragma unroll
    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
      var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
    }
    if (WARPS_N > 1) {
      if (lane == 0) {
        smem[warp_m * WARPS_N + warp_n] = var_local;
      }
      __syncthreads();
      if (tidx % THREADS_PER_ROW == 0) {
        var_local = 0.f;
#pragma unroll
        for (int it = 0; it < WARPS_N; ++it) {
          var_local += smem[warp_m * WARPS_N + it];
        }
        smem_row[warp_m] = var_local;
      }
      __syncthreads();
      var_local = smem_row[warp_m];
    }

    const U var_row = var_local * rn;
    // rsqrt_<U> keeps full precision for U = double (rsqrtf would round to
    // float); for U = float it is rsqrtf.
    const U rsigma = rsqrt_<U>(var_row + static_cast<U>(epsilon));
    if (lane == 0) {
      var_out_ptr[row] = var_row;
    }

    // ---- normalize and apply the affine transform in U, store as T ----
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        U tmp = (rsigma * (static_cast<U>(xf[it * VecSize + jt]) - mu_local));
        x[it][jt] = static_cast<T>(static_cast<U>(gamma[it][jt]) * tmp +
                                   static_cast<U>(beta[it][jt]));
      }
    }

#pragma unroll
    for (int it = 0, col = c; it < LDGS; it++) {
      phi::Store<T, VecSize>(x[it], y_ptr + row_offset + col * VecSize);
      col += THREADS_PER_ROW;
    }
  }
}
#endif

// Element type of the scale/bias tensors: the input type T when
// ScaleBiasWithSameTypeX is true, otherwise the parameter type U.
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
using LayerNormScaleBiasT = typename std::conditional<ScaleBiasWithSameTypeX,
                                                      /* same as x */ T,
                                                      /* param type */ U>::type;

// Generic LayerNorm forward. Each thread block normalizes one row
// (blockIdx.x selects it); its threads stride over the row's `feature_size`
// elements with step BlockDim.
//
// Step 1: accumulate sum(x) and sum(x^2) in U and reduce them over the block.
//         Thread 0 then publishes
//           mean = E[x]
//           var  = max(E[x^2] - E[x]^2, 0)
//         to mean[blockIdx.x] / var[blockIdx.x] and to shared memory. The
//         clamp protects against a slightly negative result from
//         cancellation.
// Step 2: y = scale * (x - mean) * rsqrt(var + epsilon) + bias. `scale` and
//         `bias` are each optional (nullptr skips them). When OutType is
//         int8_t, the value is first cast to T and then quantized by
//         quant_helper using quant_in_scale / quant_round_type /
//         quant_max_bound / quant_min_bound.
//
// dequant_out_scale_data and quant_out_scale_offset are not read here.
// NOTE(review): assumes blockDim.x == BlockDim and that BlockDim is a
// multiple of warpSize with at most 32 warps (BlockReduceSum and the
// 32-entry shared_mean/shared_var buffers depend on it) — TODO confirm.
template <typename T,
          typename U,
          int BlockDim,
          bool ScaleBiasWithSameTypeX = false,
          typename InType = T,
          typename OutType = T>
__global__ void LayerNormForward(
    const InType *x,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *bias,
    OutType *y,
    U *mean,
    U *var,
    float epsilon,
    int64_t feature_size,
    const float *dequant_out_scale_data = nullptr,
    const int quant_out_scale_offset = 0,
    const float quant_in_scale = 1.0,
    const int quant_round_type = 1,
    const float quant_max_bound = 127.0,
    const float quant_min_bound = -127.0) {
  __shared__ U mean_share;
  __shared__ U var_share;
  __shared__ U shared_mean[32];  // threadIdx.x / warpSize <= kMaxBlockDim /
                                 // warpSize <= 1024/32 = 32;
  __shared__ U shared_var[32];

  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int64_t end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce to calculate mean and var
  U mean_val = 0;
  U var_val = 0;
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    U tmp = static_cast<U>(x[i]);
    mean_val += tmp;
    var_val += (tmp * tmp);
  }

  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
  var_val = BlockReduceSum<U>(var_val, shared_var);

  if (threadIdx.x == 0) {
    // 1 / feature_size. Named so it does not shadow the `scale` parameter
    // that is null-checked in step 2.
    auto inv_feature_size = static_cast<U>(static_cast<float>(1.) /
                                           static_cast<float>(feature_size));
    auto tmp = mean_val * inv_feature_size;
    mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
    var_share = static_cast<U>(var_val * inv_feature_size -
                               mean_share * mean_share);
    var_share = var_share > U(0) ? var_share : U(0);
    var[blockIdx.x] = var_share;
  }
  __syncthreads();

  mean_val = mean_share;
  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));

  // Step 2: Calculate y. The four variants below only differ in whether
  // scale and/or bias are applied.
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        if (std::is_same<OutType, int8_t>::value) {
          y[i] = quant_helper(
              static_cast<T>(static_cast<U>(scale[j]) *
                                 (static_cast<U>(x[i]) - mean_val) * invvar +
                             static_cast<U>(bias[j])),
              quant_in_scale,
              quant_round_type,
              quant_max_bound,
              quant_min_bound);
        } else {
          y[i] = static_cast<OutType>(static_cast<U>(scale[j]) *
                                          (static_cast<U>(x[i]) - mean_val) *
                                          invvar +
                                      static_cast<U>(bias[j]));
        }
      }
    } else {
      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        if (std::is_same<OutType, int8_t>::value) {
          y[i] = quant_helper(
              static_cast<T>(static_cast<U>(scale[j]) *
                             (static_cast<U>(x[i]) - mean_val) * invvar),
              quant_in_scale,
              quant_round_type,
              quant_max_bound,
              quant_min_bound);
        } else {
          y[i] =
              static_cast<OutType>(static_cast<U>(scale[j]) *
                                   (static_cast<U>(x[i]) - mean_val) * invvar);
        }
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        if (std::is_same<OutType, int8_t>::value) {
          y[i] = quant_helper(
              static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
                             static_cast<U>(bias[j])),
              quant_in_scale,
              quant_round_type,
              quant_max_bound,
              quant_min_bound);
        } else {
          y[i] =
              static_cast<OutType>((static_cast<U>(x[i]) - mean_val) * invvar +
                                   static_cast<U>(bias[j]));
        }
      }
    } else {
      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        if (std::is_same<OutType, int8_t>::value) {
          y[i] = quant_helper(
              static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar),
              quant_in_scale,
              quant_round_type,
              quant_max_bound,
              quant_min_bound);
        } else {
          y[i] =
              static_cast<OutType>((static_cast<U>(x[i]) - mean_val) * invvar);
        }
      }
    }
  }
}

// Adds one row's contribution into two per-thread scratch buffers:
//   warp_buf1[w] += dout
//   warp_buf2[w] += dout * (input - mean[row]) * rsqrt(var[row] + epsilon)
// The row is i1_block + thr_load_row_off; nothing is done if it is
// >= i1_end. VPT consecutive columns starting at i2_off are visited, and
// columns >= n2 are skipped. For the k-th column the scratch index is
// w = thr_load_row_off * row_stride + thr_load_col_off + k.
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
                                                  const int thr_load_row_off,
                                                  const int thr_load_col_off,
                                                  const int i2_off,
                                                  const int row_stride,
                                                  U *warp_buf1,
                                                  U *warp_buf2,
                                                  const T *input,
                                                  const T *dout,
                                                  const int64_t i1_end,
                                                  const int64_t n2,
                                                  const U *__restrict__ mean,
                                                  const U *__restrict__ var,
                                                  const float epsilon) {
  const int64_t i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    const U row_mean = mean[i1];
    const U row_invvar = rsqrt_<U>(var[i1] + epsilon);
    const int64_t row_base = i1 * n2;
    const int buf_base = thr_load_row_off * row_stride + thr_load_col_off;
    for (int k = 0; k < VPT; ++k) {
      const int i2 = i2_off + k;
      if (i2 >= n2) continue;
      const int64_t load_idx = row_base + i2;
      const U in_val = static_cast<U>(input[load_idx]);
      const U dy = static_cast<U>(dout[load_idx]);
      warp_buf1[buf_base + k] += dy;
      warp_buf2[buf_base + k] += dy * (in_val - row_mean) * row_invvar;
    }
  }
}

#ifdef PADDLE_WITH_CUDA
// Fast LayerNorm backward, optionally fused with the dropout-residual
// backward, for rows of compile-time length ELTS_PER_ROW.
//
// Layout: THREADS_PER_CTA = WARPS_M * WARPS_N * THREADS_PER_WARP threads.
// WARPS_N warps share one row, a CTA covers ROWS_PER_CTA = WARPS_M rows per
// iteration, and the grid loops over rows with stride
// gridDim.x * ROWS_PER_CTA.
//
// Step 1 (per row). Let rstd = rsqrt(var + epsilon), y = (x - mean) * rstd
// and g = gamma * dout. Then
//     dx = rstd * (g - mean_j(g * y) * y - mean_j(g))
// is written to dx_ptr. If isFusedDropoutResidualLn, the kernel also writes
//     d_dropout_src = dx * mask * factor
// to d_dropout_src_ptr; mask_ptr and d_dropout_src_ptr must then be valid.
//
// Step 2. Each thread's running sums of dout * y (dgamma) and dout (dbeta)
// are reduced over the CTA and written as one partial row per CTA:
//     dgamma_temp_ptr / dbeta_temp_ptr : [gridDim.x, ELTS_PER_ROW].
// NOTE(review): the reduction over CTAs happens elsewhere (presumably
// ln_bwd_fast_final_kernel) — verify at the caller.
//
// NOTE(review): with WARPS_N > 1 the row loop contains __syncthreads(), so
// every warp of a CTA must run the same number of iterations (rows a
// multiple of gridDim.x * ROWS_PER_CTA) — TODO confirm at the launch site.
template <bool isFusedDropoutResidualLn,
          typename T,
          typename U,
          typename ScaleT = U,
          typename MaskType = uint8_t,
          int VecSize = 8,
          int WARPS_M = 4,
          int WARPS_N = 1,
          int BYTES_PER_LDG = 16,
          int ELTS_PER_ROW = 1024,
          int THREADS_PER_WARP = 32,
          int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP,
          int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW,
          int ROWS_PER_CTA = WARPS_M,
          int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
          int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
    const int rows,
    float epsilon,
    const T *__restrict__ x_ptr,
    const ScaleT *__restrict__ gamma_ptr,
    const U *__restrict__ mean_ptr,
    const U *__restrict__ var_ptr,
    const T *__restrict__ dout_ptr,
    U *__restrict__ dgamma_temp_ptr,
    U *__restrict__ dbeta_temp_ptr,
    T *__restrict__ dx_ptr,
    const MaskType *mask_ptr = nullptr,
    T factor = static_cast<T>(0),
    T *d_dropout_src_ptr = nullptr) {
  using Vec = phi::AlignedVector<T, VecSize>;
  using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
  using MaskLoadT = phi::AlignedVector<MaskType, VecSize>;

  const int tidx = threadIdx.x;
  const int bidx = blockIdx.x;
  const int lane = tidx % THREADS_PER_WARP;            // 0, 1, ..., 31
  const int warp = tidx / THREADS_PER_WARP;            // 0, 1, 2, 3
  const int warp_m = warp / WARPS_N;                   // 0, 1, 2, 3
  const int warp_n = warp % WARPS_N;                   // 0
  const int tid_r = warp_n * THREADS_PER_WARP + lane;  // 0, 1, ..., 31

  const int r = bidx * ROWS_PER_CTA + warp_m;
  const int c = warp_n * THREADS_PER_WARP + lane;

  static_assert(ELTS_PER_ROW == THREADS_PER_ROW * LDGS * VecSize, "");

  // smem for column reduction
  __shared__ U smem_[ROWS_PER_CTA * ELTS_PER_ROW];

  U dgamma_sum[LDGS * VecSize];
  U dbeta_sum[LDGS * VecSize];

  memset(dgamma_sum, 0, sizeof(U) * LDGS * VecSize);
  memset(dbeta_sum, 0, sizeof(U) * LDGS * VecSize);

  // Row-reduction scratch; only used when WARPS_N > 1.
  __shared__ U smem_sum_loss1[ROWS_PER_CTA * WARPS_N];  // 4
  __shared__ U smem_sum_loss2[ROWS_PER_CTA * WARPS_N];  // 4
  U *sum_loss1_shared = &smem_sum_loss1[warp_m * WARPS_N];
  U *sum_loss2_shared = &smem_sum_loss2[warp_m * WARPS_N];

  // step-1: compute dx and local results of dscale and dbias
  constexpr float rn = 1.f / static_cast<float>(ELTS_PER_ROW);
  Vec_scale gamma[LDGS];
  int col = c;
#pragma unroll
  for (int it = 0; it < LDGS; it++) {
    phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
    col += THREADS_PER_ROW;
  }

#pragma unroll 1
  for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
    // 64-bit element offset: rows * ELTS_PER_ROW can exceed INT_MAX.
    const int64_t row_offset = static_cast<int64_t>(row) * ELTS_PER_ROW;
    const U mean_cur_row = mean_ptr[row];
    // Despite the name, this is rsqrt(var + epsilon), the inverse std-dev.
    const U var_cur_row = rsqrt_<U>(var_ptr[row] + epsilon);
    Vec dout[LDGS], x[LDGS];
    MaskLoadT mask_vec[LDGS];
    int col = c;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
      phi::Load<T, VecSize>(dout_ptr + row_offset + col * VecSize, &dout[it]);
      phi::Load<T, VecSize>(x_ptr + row_offset + col * VecSize, &x[it]);
      if (isFusedDropoutResidualLn) {
        phi::Load<MaskType, VecSize>(mask_ptr + row_offset + col * VecSize,
                                     &mask_vec[it]);
      }
      col += THREADS_PER_ROW;
    }

    // local reductions
    U dy[LDGS * VecSize];
    U y[LDGS * VecSize];

    U sum_loss1 = 0.f;
    U sum_loss2 = 0.f;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        U x_tmp = static_cast<U>(x[it][jt]);
        U y_tmp = var_cur_row * (x_tmp - mean_cur_row);
        U dy_tmp = static_cast<U>(gamma[it][jt]) *
                   static_cast<U>(dout[it][jt]);    // scale * dy
        U dout_tmp = static_cast<U>(dout[it][jt]);  // dy

        // used for get dx (row reduction)
        sum_loss1 += dy_tmp;          // scale * dy, sum_1
        sum_loss2 += dy_tmp * y_tmp;  // scale * dy * y, sum_2

        dy[it * VecSize + jt] = dy_tmp;  // scale * dy
        y[it * VecSize + jt] = y_tmp;    // y

        // used for get dscale and dbias (column reduction)
        dgamma_sum[it * VecSize + jt] += dout_tmp * y_tmp;  // dy * y
        dbeta_sum[it * VecSize + jt] += dout_tmp;           // dy
      }
    }

    // reduction across row for sum_loss1, sum_loss2
    if (WARPS_N == 1) {
      // Butterfly reduction among the THREADS_PER_WARP threads of the row.
#pragma unroll
      for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
        sum_loss1 += __shfl_xor_sync(uint32_t(-1), sum_loss1, it);
        sum_loss2 += __shfl_xor_sync(uint32_t(-1), sum_loss2, it);
      }
      sum_loss1 *= rn;
      sum_loss2 *= rn;
    } else {
      // Tree reduction within each warp; lane 0 ends with the warp's sum.
#pragma unroll
      for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) {
        sum_loss1 += __shfl_down_sync(uint32_t(-1), sum_loss1, it);
        sum_loss2 += __shfl_down_sync(uint32_t(-1), sum_loss2, it);
      }

      if (lane == 0) {
        sum_loss1_shared[warp_n] = sum_loss1;
        sum_loss2_shared[warp_n] = sum_loss2;
      }

      __syncthreads();
      if (warp_n == 0 && lane == 0) {
        sum_loss1 = 0.f;
        sum_loss2 = 0.f;
        for (int it = 0; it < WARPS_N; it++) {
          sum_loss1 += sum_loss1_shared[it];
          sum_loss2 += sum_loss2_shared[it];
        }
        sum_loss1_shared[0] = sum_loss1;
        sum_loss2_shared[0] = sum_loss2;
      }
      __syncthreads();

      sum_loss1 = sum_loss1_shared[0] * rn;
      sum_loss2 = sum_loss2_shared[0] * rn;
      // Every thread of the row must have read the totals before lane 0 of
      // warp_n == 0 overwrites sum_loss*_shared[0] in the next iteration.
      __syncthreads();
    }

#pragma unroll
    for (int it = 0; it < LDGS; it++) {
#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        U dy_tmp = dy[it * VecSize + jt];  // scale * dy
        U y_tmp = y[it * VecSize + jt];    // y
        // dx = rstd * (scale * dy - sum_loss2 * y - sum_loss1)
        U dx_tmp = var_cur_row * (dy_tmp - sum_loss2 * y_tmp - sum_loss1);
        // Note: reuse x and dout vec register to store dx and d_dropout_src.
        x[it][jt] = static_cast<T>(dx_tmp);
        if (isFusedDropoutResidualLn) {
          dout[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor;
        }
      }
    }

    // store dx to global memory
    col = c;
#pragma unroll
    for (int it = 0; it < LDGS; it++) {
      phi::Store<T, VecSize>(x[it], dx_ptr + row_offset + col * VecSize);
      if (isFusedDropoutResidualLn) {
        phi::Store<T, VecSize>(dout[it],
                               d_dropout_src_ptr + row_offset + col * VecSize);
      }
      col += THREADS_PER_ROW;
    }
  }

  // step-2: column reduction of dscale and dbias for each thread block.
  // each block's sum: [ROWS_PER_CTA * ELTS_PER_ROW] -> [1 * ELTS_PER_ROW]
  enum { NUM_RES = ELTS_PER_ROW / THREADS_PER_CTA };  // 1024/128 = 8
  static_assert(NUM_RES * THREADS_PER_CTA == ELTS_PER_ROW, "");

  U *smem_write;

  smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];  // [4 * 1024]
#pragma unroll
  for (int it = 0; it < LDGS; it++) {
#pragma unroll
    for (int jt = 0; jt < VecSize; jt++) {
      smem_write[jt] = dbeta_sum[it * VecSize + jt];
    }
    smem_write += THREADS_PER_ROW * VecSize;  // 32*8
  }
  __syncthreads();
  U cta_dbeta_sum[NUM_RES];
  memset(cta_dbeta_sum, 0, sizeof(U) * NUM_RES);
  // column reduction for elems in smem: 4*1024 -> 1*1024.
  for (int it = 0; it < ROWS_PER_CTA; it++) {
    for (int jt = 0; jt < NUM_RES; jt++) {
      cta_dbeta_sum[jt] +=
          smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
    }
  }
  __syncthreads();

  smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];
#pragma unroll
  for (int it = 0; it < LDGS; it++) {
#pragma unroll
    for (int jt = 0; jt < VecSize; jt++) {
      smem_write[jt] = dgamma_sum[it * VecSize + jt];
    }
    smem_write += THREADS_PER_ROW * VecSize;
  }
  __syncthreads();
  U cta_dgamma_sum[NUM_RES];
  memset(cta_dgamma_sum, 0, sizeof(U) * NUM_RES);
  for (int it = 0; it < ROWS_PER_CTA; it++) {
    for (int jt = 0; jt < NUM_RES; jt++) {
      cta_dgamma_sum[jt] +=
          smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
    }
  }

  // the shape of results:(#blocks, ELTS_PER_ROW)
  const int64_t block_offset = static_cast<int64_t>(bidx) * ELTS_PER_ROW;
  U *dgamma_part = static_cast<U *>(dgamma_temp_ptr) + block_offset + tidx;
  for (int jt = 0; jt < NUM_RES; jt++) {
    *dgamma_part = cta_dgamma_sum[jt];
    dgamma_part += THREADS_PER_CTA;
  }

  U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + block_offset + tidx;
  for (int jt = 0; jt < NUM_RES; jt++) {
    *dbeta_part = cta_dbeta_sum[jt];
    dbeta_part += THREADS_PER_CTA;
  }
}

/* This function carries out a column reduction whose input is
 * [rows, ELTS_PER_ROW] (e.g. [rows, 1024]) and whose output is
 * [1, ELTS_PER_ROW]:
 *   dg_[j] = sum_i dg_part_[i][j],  db_[j] = sum_i db_part_[i][j].
 * Typical launch: #blocks: 32, #threads: 512.
 * Each CTA has WARPS_M warps. Warp m sums rows m, m + WARPS_M, ... for its
 * columns. Warps 1..WARPS_M-1 then hand their partial sums to warp 0
 * through shared memory, and warp 0 writes the final values.
 * Columns are handled as VecSize-wide vectors: `col` is a vector index and
 * the element offset is col * VecSize.
 * NOTE(review): the column loop contains __syncthreads(), so VEC_COLS is
 * assumed to be a multiple of THREADS_PER_ROW (uniform trip count inside a
 * CTA) — TODO confirm at the launch site.
 */
// todo(@limin29): to think if there are better impl strategies
template <typename U,
          typename ScaleT = U,
          int VecSize = 1,
          int WARPS_M = 16,
          int WARPS_N = 1,
          int BYTES_PER_LDG = 4,
          int ELTS_PER_ROW = 1024,
          int THREADS_PER_WARP = 32,
          int THREADS_PER_ROW = WARPS_N *THREADS_PER_WARP,
          int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW,
          int ROWS_PER_CTA = WARPS_M,
          int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
          int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA,
          int VEC_COLS = ELTS_PER_ROW / VecSize>
__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_fast_final_kernel(
    const int rows,
    U *__restrict__ dg_part_,
    U *__restrict__ db_part_,
    ScaleT *__restrict__ dg_,
    ScaleT *__restrict__ db_) {
  using Vec = phi::AlignedVector<U, VecSize>;
  using VecOut = phi::AlignedVector<ScaleT, VecSize>;
  static_assert(VEC_COLS == ELTS_PER_ROW / VecSize, "");

  const int tidx = threadIdx.x;
  const int bidx = blockIdx.x;
  const int lane = tidx % THREADS_PER_WARP;
  const int warp = tidx / THREADS_PER_WARP;
  const int warp_m = warp / WARPS_N;
  const int warp_n = warp % WARPS_N;
  const int tid_c = warp_n * THREADS_PER_WARP + lane;

  const int c = bidx * THREADS_PER_ROW + tid_c;
  const int r = warp_m;

  // Partial sums of warps 1..WARPS_M-1; warp 0 keeps its own in registers.
  __shared__ U smem_space[(WARPS_M - 1) * THREADS_PER_ROW * VecSize];

  for (int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW) {
    const U *dg_part_ptr = (dg_part_) + r * ELTS_PER_ROW + col * VecSize;
    const U *db_part_ptr = (db_part_) + r * ELTS_PER_ROW + col * VecSize;

    U dg_sum[VecSize];
    U db_sum[VecSize];
    memset(dg_sum, 0, sizeof(U) * VecSize);
    memset(db_sum, 0, sizeof(U) * VecSize);
#pragma unroll
    for (int row = r; row < rows; row += ROWS_PER_CTA) {
      Vec dg;
      Vec db;
      phi::Load<U, VecSize>(dg_part_ptr, &dg);
      phi::Load<U, VecSize>(db_part_ptr, &db);
      dg_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
      db_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;

#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        dg_sum[jt] += dg[jt];
        db_sum[jt] += db[jt];
      }
    }

    // Reduction across the warps of the CTA: dgamma first.
    if (warp_m > 0) {
      // Only formed for warp_m > 0, so it never points before smem_space.
      U *smem_write =
          smem_space + (warp_m - 1) * THREADS_PER_ROW * VecSize + tid_c;
#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        *smem_write = dg_sum[jt];
        smem_write += THREADS_PER_ROW;
      }
    }
    __syncthreads();

    if (warp_m == 0) {
      const U *smem_read = smem_space + tid_c;
#pragma unroll
      for (int it = 0; it < WARPS_M - 1; it++) {
#pragma unroll
        for (int jt = 0; jt < VecSize; jt++) {
          dg_sum[jt] += *smem_read;
          smem_read += THREADS_PER_ROW;
        }
      }
    }
    __syncthreads();

    // Then dbeta, reusing the same shared buffer.
    if (warp_m > 0) {
      U *smem_write =
          smem_space + (warp_m - 1) * THREADS_PER_ROW * VecSize + tid_c;
#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        *smem_write = db_sum[jt];
        smem_write += THREADS_PER_ROW;
      }
    }
    __syncthreads();

    if (warp_m == 0) {
      const U *smem_read = smem_space + tid_c;
#pragma unroll
      for (int it = 0; it < WARPS_M - 1; it++) {
#pragma unroll
        for (int jt = 0; jt < VecSize; jt++) {
          db_sum[jt] += *smem_read;
          smem_read += THREADS_PER_ROW;
        }
      }

      VecOut dg_out;
      VecOut db_out;
#pragma unroll
      for (int jt = 0; jt < VecSize; jt++) {
        dg_out[jt] = static_cast<ScaleT>(dg_sum[jt]);
        db_out[jt] = static_cast<ScaleT>(db_sum[jt]);
      }
      // `col` indexes VecSize-wide vectors, so the element offset is
      // col * VecSize and all VecSize results are stored.
      phi::Store<ScaleT, VecSize>(dg_out, dg_ + col * VecSize);
      phi::Store<ScaleT, VecSize>(db_out, db_ + col * VecSize);
    }
    // Keep the next iteration's shared-memory writes (warps 1..WARPS_M-1)
    // from racing with warp 0's reads above.
    __syncthreads();
  }
}

/* This function supports two kinds of computation (only for float and fp16
 * types):
 *
 * Case-1: compute layer_norm_grad for the layernorm op by setting mask_ptr
 * and d_dropout_src_ptr to nullptr. Here, dx_ptr returns the grad of the
 * layernorm input.
 *
 * Case-2: compute layer_norm_grad + residual_grad + dropout_grad for the
 * fused_dropout_residual_layernorm op. Here, dx_ptr returns residual_grad.
 *
 * Only feature sizes (cols) of 1024, 384 and 256 are supported.
 */
template <typename T,
          typename U,
          typename ScaleT = U,
          typename MaskType = uint8_t>
void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx,
                               const int rows,
                               const int cols,
                               float epsilon,
                               const T *x_ptr,
                               const ScaleT *scale_ptr,
                               const U *mean_ptr,
                               const U *var_ptr,
                               const T *dout_ptr,
                               T *dx_ptr,
                               ScaleT *dscale_ptr,
                               ScaleT *dbias_ptr,
                               const MaskType *mask_ptr = nullptr,
                               T factor = static_cast<T>(0),
                               T *d_dropout_src_ptr = nullptr) {
  // Layer-norm backward for a [rows, cols] input using two kernel launches:
  //   step-1 (fused_ln_bwd_fast_kernel): writes dx (and, when mask_ptr is
  //          non-null, the dropout-source grad) plus per-CTA partial sums of
  //          dscale / dbias into [gridx, cols] temporaries;
  //   step-2 (ln_bwd_fast_final_kernel): reduces the gridx partial rows into
  //          dscale_ptr / dbias_ptr.
  // Only cols in {1024, 384, 256} have specialized instantiations.
  //
  // Every precondition is checked before temporaries are allocated or any
  // kernel is enqueued, so a rejected call performs no device work.
  if (cols != 1024 && cols != 384 && cols != 256) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Fast layer_norm kernel is only used when feature_size is 1024, 384 "
        "or 256"));
  }
  // The step-2 reduction does not support 8-byte accumulation (double).
  if (sizeof(U) > 4) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Only support float and fp16 type"));
  }
  if (mask_ptr != nullptr && d_dropout_src_ptr == nullptr) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
        "can't be null"));
  }

  auto stream = dev_ctx.stream();

  // ---- step-1: compute dx and per-CTA partial sums of dscale / dbias ----
  const int WARPS_M = 4;  // how many rows are handled by one CTA.
  const int WARPS_N = 1;  // how many warps cooperate on one row.
  const int BYTES_PER_LDG = 16;
  const int VecSize = BYTES_PER_LDG / sizeof(T);

  const int THREADS_PER_WARP = 32;
  const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP;
  const int THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW;

  // #blocks = 2 * #SM; the temporaries hold one partial-sum row per block.
  const int gridx = 2 * dev_ctx.GetSMCount();

  // Temporary [gridx, cols] buffers for the partial dscale / dbias sums.
  framework::Tensor dscale_temp;
  dscale_temp.Resize({gridx, cols});
  dscale_temp.mutable_data<U>(dev_ctx.GetPlace());
  U *dscale_temp_ptr = dscale_temp.data<U>();

  framework::Tensor dbias_temp;
  dbias_temp.Resize({gridx, cols});
  dbias_temp.mutable_data<U>(dev_ctx.GetPlace());
  U *dbias_temp_ptr = dbias_temp.data<U>();

  if (mask_ptr != nullptr) {
#define LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
  fused_ln_bwd_fast_kernel<true,                                    \
                           T,                                       \
                           U,                                       \
                           ScaleT,                                  \
                           MaskType,                                \
                           vec_size,                                \
                           WARPS_M,                                 \
                           WARPS_N,                                 \
                           BYTES_PER_LDG,                           \
                           ele_per_row>                             \
      <<<gridx, THREADS_PER_CTA, 0, stream>>>(rows,                 \
                                              epsilon,              \
                                              x_ptr,                \
                                              scale_ptr,            \
                                              mean_ptr,             \
                                              var_ptr,              \
                                              dout_ptr,             \
                                              dscale_temp_ptr,      \
                                              dbias_temp_ptr,       \
                                              dx_ptr,               \
                                              mask_ptr,             \
                                              factor,               \
                                              d_dropout_src_ptr);

    switch (cols) {
      case 1024:
        LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
        break;
      case 384:
        // Scalar access; presumably because 384 is not a multiple of
        // THREADS_PER_ROW * VecSize for fp16 (32 * 8) -- keep as-is.
        LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(1, 384);
        break;
      case 256:
        LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(VecSize, 256);
        break;
    }
#undef LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
  } else {
#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
  fused_ln_bwd_fast_kernel<false,                              \
                           T,                                  \
                           U,                                  \
                           ScaleT,                             \
                           MaskType,                           \
                           vec_size,                           \
                           WARPS_M,                            \
                           WARPS_N,                            \
                           BYTES_PER_LDG,                      \
                           ele_per_row>                        \
      <<<gridx, THREADS_PER_CTA, 0, stream>>>(rows,            \
                                              epsilon,         \
                                              x_ptr,           \
                                              scale_ptr,       \
                                              mean_ptr,        \
                                              var_ptr,         \
                                              dout_ptr,        \
                                              dscale_temp_ptr, \
                                              dbias_temp_ptr,  \
                                              dx_ptr);

    switch (cols) {
      case 1024:
        LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
        break;
      case 384:
        LAUNCH_FUSED_LN_BWD_FAST_KERNEL(1, 384);
        break;
      case 256:
        LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 256);
        break;
    }
#undef LAUNCH_FUSED_LN_BWD_FAST_KERNEL
  }

  // ---- step-2: reduce the gridx partial rows into dscale / dbias ----
  const int WARPS_M_2 = 16;
  const int WARPS_N_2 = 1;
  const int BYTES_PER_LDG_2 = 4;
  const int VecSize_2 =
      std::max(1, static_cast<int>(BYTES_PER_LDG_2 / sizeof(U)));  // 1

  const int THREADS_PER_WARP_2 = 32;
  const int THREADS_PER_ROW_2 = WARPS_N_2 * THREADS_PER_WARP_2;  // 32
  const int THREADS_PER_CTA_2 = WARPS_M_2 * THREADS_PER_ROW_2;   // 512

  // One block per THREADS_PER_ROW_2 * vec_size columns (integer ceil-div).
#define LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(vec_size, ele_per_row)   \
  ln_bwd_fast_final_kernel<U,                                    \
                           ScaleT,                               \
                           vec_size,                             \
                           WARPS_M_2,                            \
                           WARPS_N_2,                            \
                           BYTES_PER_LDG_2,                      \
                           ele_per_row>                          \
      <<<(ele_per_row + THREADS_PER_ROW_2 * (vec_size) - 1) /    \
             (THREADS_PER_ROW_2 * (vec_size)),                   \
         THREADS_PER_CTA_2,                                      \
         0,                                                      \
         stream>>>(                                              \
          gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);

  switch (cols) {
    case 1024:
      LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(VecSize_2, 1024);
      break;
    case 384:
      LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(1, 384);
      break;
    case 256:
      LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(VecSize_2, 256);
      break;
  }
#undef LAUNCH_LN_BWD_BETA_GAMMA_KERNEL
}
#endif

1090
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(const T *__restrict__ dout,
                                                   const T *__restrict__ input,
                                                   const int64_t n1,
                                                   const int64_t n2,
                                                   const U *__restrict__ mean,
                                                   const U *__restrict__ var,
                                                   float epsilon,
                                                   U *part_grad_gamma,
                                                   U *part_grad_beta) {
  // First stage of the gamma / beta gradient reduction.
  //
  // Template parameters (compile-time so loops can be unrolled):
  //   BDIMX - blockDim.x; must be a power of two (used as a bit mask below)
  //   BDIMY - blockDim.y
  //   VPTX  - values per thread.x: consecutive columns loaded per thread
  // Launch with blockDim == (BDIMX, BDIMY).
  //
  // Each block owns columns [blockIdx.x * BDIMX, +BDIMX) and walks the rows
  // in tiles of VPTX * BDIMY rows (tile starts blockIdx.y, +gridDim.y, ...).
  // cuLoadAddStridedInputs (defined elsewhere) fills the two shared tiles;
  // per the write-out at the end, warp_buf1 feeds part_grad_beta and
  // warp_buf2 feeds part_grad_gamma. The block's column sums are written to
  // row blockIdx.y of the part_grad_* outputs.
  //
  // The original code mixed up VPTX and BDIMY in the load-row offset, the
  // per-warp fold and the tree reduction; it was only correct when
  // VPTX == BDIMY. The indices below are identical for that case and also
  // correct otherwise.

  // Shared rows are padded by one element to avoid bank conflicts.
  constexpr int row_stride = BDIMX + 1;
  // Rows per tile: every warp row (threadIdx.y) loads VPTX tile rows.
  constexpr int tile_rows = VPTX * BDIMY;

  // Thread (x, y) loads VPTX consecutive columns starting at thr_load_col_off
  // of tile row thr_load_row_off; BDIMX threads cover VPTX full rows.
  const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
  const int thr_load_row_off =
      (threadIdx.x * VPTX) / BDIMX + threadIdx.y * VPTX;
  const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;

  constexpr int shared_cap = (BDIMX * BDIMY > 2 * tile_rows * row_stride)
                                 ? BDIMX * BDIMY
                                 : 2 * tile_rows * row_stride;
  __shared__ U buf[shared_cap];

  U *warp_buf1 = reinterpret_cast<U *>(buf);
  U *warp_buf2 = warp_buf1 + tile_rows * row_stride;

  // Zero both tiles; the strided loads presumably accumulate into them
  // (the helper is named "LoadAdd").
  for (int idx = threadIdx.y * BDIMX + threadIdx.x;
       idx < 2 * tile_rows * row_stride;
       idx += BDIMX * BDIMY) {
    buf[idx] = U(0);
  }
  __syncthreads();

  for (int64_t i1_block = blockIdx.y * tile_rows; i1_block < n1;
       i1_block += tile_rows * gridDim.y) {
    cuLoadAddStridedInputs<T, U, VPTX>(i1_block,
                                       thr_load_row_off,
                                       thr_load_col_off,
                                       i2_off,
                                       row_stride,
                                       warp_buf1,
                                       warp_buf2,
                                       input,
                                       dout,
                                       n1,
                                       n2,
                                       mean,
                                       var,
                                       epsilon);
  }
  __syncthreads();

  // Fold the tile rows y, y + BDIMY, ..., y + (VPTX - 1) * BDIMY into row y.
  // Row y (< BDIMY) is read only by thread row y itself (k == 0), so the
  // in-place write below does not race with other threads' reads.
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < VPTX; ++k) {
    const int idx = (threadIdx.y + k * BDIMY) * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx];
    acc2 += warp_buf2[idx];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();

  // Tree-reduce rows [0, BDIMY) down to rows 0 and 1; the final add is fused
  // into the write-out below to save one barrier.
  for (int offset = BDIMY / 2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      const int idx1 = threadIdx.y * row_stride + threadIdx.x;
      const int idx2 = (threadIdx.y + offset) * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }

  const int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    const int idx = threadIdx.x;
    U sum_beta = warp_buf1[idx];
    U sum_gamma = warp_buf2[idx];
    // With BDIMY == 1, row 1 was never reduced and must not be added.
    if (BDIMY > 1) {
      sum_beta += warp_buf1[idx + row_stride];
      sum_gamma += warp_buf2[idx + row_stride];
    }
    part_grad_beta[blockIdx.y * n2 + i2] = sum_beta;
    part_grad_gamma[blockIdx.y * n2 + i2] = sum_gamma;
  }
}

1180
template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
__global__ void LayerNormBackwardSumGradGammaBeta(
    const U *part_grad_gamma,
    const U *part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_gamma,
    LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *grad_beta) {
  // Second stage of the gamma / beta gradient reduction: sums the
  // [part_size, n2] partial gradients over their rows and writes
  // grad_gamma[i2] / grad_beta[i2]. Launch with blockDim == (BDIMX, BDIMY)
  // and BDIMY a power of two; each block covers BDIMX columns. n1 is unused.
  //
  // Every thread of the block runs the shared-memory reduction loop, so the
  // __syncthreads() barriers are reached uniformly. Previously they sat
  // inside `if (i2 < n2)` and diverged in the last block whenever n2 is not
  // a multiple of BDIMX. Out-of-range columns contribute zeros and are
  // simply not written.
  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX>;
  __shared__ U buf[BDIMX * BDIMY];
  const int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
  const bool in_range = i2 < n2;

  U sum_gamma = U(0);
  U sum_beta = U(0);
  if (in_range) {
    // Each warp row sequentially reduces a contiguous chunk of
    // num_warp_reductions partial rows.
    const int num_warp_reductions = part_size / BDIMY;
    const U *part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U *part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
    // Rows left over when part_size is not a multiple of BDIMY (previously
    // dropped); no-op when it divides evenly.
    if (threadIdx.y == 0) {
      for (int row = num_warp_reductions * BDIMY; row < part_size; ++row) {
        sum_gamma += part_grad_gamma[static_cast<int64_t>(row) * n2 + i2];
        sum_beta += part_grad_beta[static_cast<int64_t>(row) * n2 + i2];
      }
    }
  }

  // Inter-warp reduction: gamma partials live in buf[0, nbsize3), beta
  // partials in buf[nbsize3, 2 * nbsize3).
  constexpr int nbsize3 = BDIMX * BDIMY / 2;
  for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
    // top half writes to shared memory
    if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
      const int write_idx = (threadIdx.y - offset) * BDIMX + threadIdx.x;
      buf[write_idx] = sum_gamma;
      buf[write_idx + nbsize3] = sum_beta;
    }
    __syncthreads();
    // bottom half sums
    if (threadIdx.y < offset) {
      const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
      sum_gamma += buf[read_idx];
      sum_beta += buf[read_idx + nbsize3];
    }
    __syncthreads();
  }

  // write out fully summed gradients
  if (threadIdx.y == 0 && in_range) {
    grad_gamma[i2] = static_cast<ScaleBiasT>(sum_gamma);
    grad_beta[i2] = static_cast<ScaleBiasT>(sum_beta);
  }
}

1234
template <typename T, typename U, int BDIMX, int BDIMY, bool ScaleBiasSameTypeX>
__global__ void LayerNormBackwardComputeGradInput(
    const T *__restrict__ dout,
    const T *__restrict__ input,
    const int n1,
    const int n2,
    const U *__restrict__ mean,
    const U *__restrict__ var,
    const float epsilon,
    const LayerNormScaleBiasT<T, U, ScaleBiasSameTypeX> *gamma,
    T *grad_input) {
  // Layer-norm input gradient for an [n1, n2] tensor. Each block handles rows
  // blockIdx.x, blockIdx.x + gridDim.x, ...; for one row it computes
  //   invvar = rsqrt(var + epsilon)
  //   sum1   = sum_l dy[l] * g[l]
  //   sum2   = sum_l dy[l] * g[l] * (x[l] - mean) * invvar
  //   dx[l]  = invvar / n2 * (n2 * dy[l] * g[l] - sum1
  //                           - (x[l] - mean) * invvar * sum2)
  // where g = gamma, or g = 1 when gamma is null.
  // Launch with blockDim == (BDIMX, BDIMY). BDIMX must be a power of two no
  // larger than the warp size (it drives the xor-shuffle reduction), and
  // BDIMY a power of two.
#ifdef __HIPCC__
  for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
#else
  for (auto i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
#endif
    // 64-bit row offset: i1 * n2 would otherwise be evaluated in 32-bit
    // unsigned arithmetic and wrap for tensors with >= 2^32 elements.
    const int64_t row_offset = static_cast<int64_t>(i1) * n2;
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
    const T *k_input = input + row_offset;
    const T *k_dout = dout + row_offset;
    constexpr int numx = BDIMX * BDIMY;
    const int thrx = threadIdx.x + threadIdx.y * BDIMX;
    if (gamma != nullptr) {
      // Main loop: each thread consumes 4 consecutive elements per step.
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * static_cast<U>(gamma[l + k]);
          sum_loss2 +=
              c_loss * static_cast<U>(gamma[l + k]) * (c_h - c_mean) * c_invvar;
        }
      }
      // Tail (n2 % 4 elements): only the thread whose next 4-element chunk
      // straddles n2 still has l < n2 here.
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * static_cast<U>(gamma[l]);
        sum_loss2 +=
            c_loss * static_cast<U>(gamma[l]) * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions (butterfly within each group of BDIMX lanes)
    for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
#ifdef PADDLE_WITH_HIP
      sum_loss1 += __shfl_xor(sum_loss1, mask, warpSize);
      sum_loss2 += __shfl_xor(sum_loss2, mask, warpSize);
#else
      sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize);
      sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize);
#endif
    }
    // inter-warp reductions
    if (BDIMY > 1) {
      __shared__ U buf[BDIMX * BDIMY];
      for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * BDIMX + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      // broadcast warp 0's totals to every warp
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
      // Barrier before buf is reused by the next row of the loop. Without
      // it, a fast warp could overwrite buf[2 * threadIdx.x] in the next
      // iteration while a slower warp is still reading this row's broadcast.
      __syncthreads();
    }
    // all threads now have the two sums over l
    const U fH = static_cast<U>(n2);
    const U term1 = (U(1) / fH) * c_invvar;
    T *k_grad_input = grad_input + row_offset;
    if (gamma != nullptr) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}

// Preconditions: d_scale != nullptr && d_bias != nullptr. Because d_scale is
// non-null, scale is assumed non-null as well (it is dereferenced when HasDx).
//
// One block per feature column (blockIdx.x + col_offset); the BlockDim
// threads of the block stride down the batch rows of that column, then
// block-reduce the column's d_scale / d_bias sums. When HasDx, each visited
// element also gets the partial d_x = dy * scale / sqrt(var + epsilon).
template <typename T,
          typename U,
          int BlockDim,
          bool HasDx,
          bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientAll(
    const T *x,
    const T *d_y,
    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias,
    T *d_x,
    const U *mean,
    const U *var,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
    float epsilon,
    int64_t batch_size,
    int64_t feature_size,
    int64_t col_offset) {
  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;

  const int64_t col = blockIdx.x + col_offset;
  const int64_t step = BlockDim * feature_size;
  const int64_t last = batch_size * feature_size + col;

  U dscale_acc = static_cast<U>(0);
  U dbias_acc = static_cast<U>(0);

  for (int64_t pos = threadIdx.x * feature_size + col; pos < last;
       pos += step) {
    const int row = pos / feature_size;
    const auto stddev = real_sqrt(static_cast<U>(var[row]) + epsilon);
    dscale_acc += static_cast<U>(d_y[pos]) *
                  (static_cast<U>(x[pos]) - mean[row]) / stddev;
    dbias_acc += static_cast<U>(d_y[pos]);
    if (HasDx) {
      d_x[pos] = static_cast<T>(static_cast<U>(d_y[pos]) *
                                static_cast<U>(scale[col]) / stddev);
    }
  }

  // One slot per warp: BlockDim / warpSize <= 1024 / 32 = 32.
  __shared__ U warp_sums_scale[32];
  __shared__ U warp_sums_bias[32];
  dscale_acc = BlockReduceSum<U>(dscale_acc, warp_sums_scale);
  dbias_acc = BlockReduceSum<U>(dbias_acc, warp_sums_bias);

  if (threadIdx.x == 0) {
    d_scale[col] = static_cast<ScaleBiasT>(dscale_acc);
    d_bias[col] = static_cast<ScaleBiasT>(dbias_acc);
  }
}

// Precondition: exactly one of d_scale / d_bias is non-null (selected at
// compile time by HasDScale). Notice: scale may be nullptr.
//
// One block per feature column (blockIdx.x + col_offset); the BlockDim
// threads stride down the batch rows of that column and block-reduce either
// sum(dy * (x - mean) / std) into d_scale or sum(dy) into d_bias. When HasDx,
// each visited element also gets d_x = dy [* scale] / std.
template <typename T,
          typename U,
          int BlockDim,
          bool HasDx,
          bool HasDScale,
          bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x,
    const T *d_y,
    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias,
    T *d_x,
    const U *mean,
    const U *var,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
    float epsilon,
    int64_t batch_size,
    int64_t feature_size,
    int col_offset) {
  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;
  using BlockReduce = cub::BlockReduce<U, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Column handled by this block.
  const int64_t col = static_cast<int64_t>(blockIdx.x) + col_offset;
  const int64_t beg_idx = threadIdx.x * feature_size + col;
  const int64_t end_idx = batch_size * feature_size + col;
  // Must be 64-bit (as in LayerNormBackwardGradientAll): an int stride
  // overflows once feature_size exceeds INT_MAX / BlockDim.
  const int64_t stride = BlockDim * feature_size;
  U d_scale_or_d_bias_partial = static_cast<U>(0);

  for (int64_t i = beg_idx; i < end_idx; i += stride) {
    const int64_t row_idx = i / feature_size;
    const auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
    if (HasDScale) {
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
                                   (static_cast<U>(x[i]) - mean[row_idx]) /
                                   var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                                static_cast<U>(scale[col]) / var_val);
      } else {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[col] = static_cast<ScaleBiasT>(d_scale_or_d_bias_partial);
    } else {
      d_bias[col] = static_cast<ScaleBiasT>(d_scale_or_d_bias_partial);
    }
  }
}

template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(
    const T *x,
    T *d_x,
    const U *mean,
    const U *var,
    float epsilon,
    int64_t feature_size) {
  // In-place second pass of the d_x computation, one row (blockIdx.x) per
  // block with BlockDim threads. On entry d_x presumably holds the partial
  // dy * scale / std values from an earlier kernel -- confirm at call site.
  // With n = feature_size and xc = x - mean, it applies
  //   d_x -= sum(d_x) / n + xc * sum(d_x * xc) / (n * (var + epsilon)).
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  const int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
  const int64_t end_idx = (blockIdx.x + 1) * feature_size;

  const U block_mean = mean[blockIdx.x];
  const U block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    // Computed in U rather than float so double inputs keep full precision;
    // for U == float this is the same arithmetic as before.
    const U n = static_cast<U>(feature_size);
    d_x_reduce_tmp[0] = static_cast<U>(pair.first_) / n;
    d_x_reduce_tmp[1] =
        static_cast<U>(pair.second_) / (n * (block_var + epsilon));
  }
  __syncthreads();

  const U mean_term = d_x_reduce_tmp[0];
  const U var_term = d_x_reduce_tmp[1];
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    // Apply both corrections in U and round to T once. Two separate `-=` in
    // T rounded twice for low-precision T; identical when T == U.
    const U centered = static_cast<U>(x[i]) - block_mean;
    d_x[i] = static_cast<T>(static_cast<U>(d_x[i]) - mean_term -
                            centered * var_term);
  }
}

// Here, we only calculate d_x (no d_scale / d_bias). Notice: scale may be
// nullptr.
//
// One row (blockIdx.x) per block with BlockDim threads, in two passes:
//   1) d_x = dy [* scale] / std, while block-reducing sum(d_x) and
//      sum(d_x * (x - mean));
//   2) d_x -= sum(d_x) / n + (x - mean) * sum(d_x * (x - mean))
//             / (n * (var + epsilon)),   with n = feature_size.
template <typename T, typename U, int BlockDim, bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardGradientOnlyDX(
    const T *x,
    const T *d_y,
    T *d_x,
    const U *mean,
    const U *var,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
    float epsilon,
    int64_t feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  const int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
  const int64_t end_idx = (blockIdx.x + 1) * feature_size;

  const U block_mean = mean[blockIdx.x];
  const U block_var = var[blockIdx.x];
  // Loop-invariant per-row standard deviation (was recomputed per element).
  const U var_val =
      static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));

  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    if (scale != nullptr) {
      const int64_t col_idx = i % feature_size;
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                              static_cast<U>(scale[col_idx]) / var_val);
    } else {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
    }
    // Accumulate from the value actually stored (rounded to T).
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    // Computed in U rather than float so double inputs keep full precision;
    // for U == float this is the same arithmetic as before.
    const U n = static_cast<U>(feature_size);
    d_x_reduce_tmp[0] = static_cast<U>(pair.first_) / n;
    d_x_reduce_tmp[1] =
        static_cast<U>(pair.second_) / (n * (block_var + epsilon));
  }
  __syncthreads();

  const U mean_term = d_x_reduce_tmp[0];
  const U var_term = d_x_reduce_tmp[1];
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    // Apply both corrections in U and round to T once; identical when T == U.
    const U centered = static_cast<U>(x[i]) - block_mean;
    d_x[i] = static_cast<T>(static_cast<U>(d_x[i]) - mean_term -
                            centered * var_term);
  }
}

1595
/**
 * Layer norm backward kernel for the special case batch_size == 1.
 *
 * Launch layout: 1-D grid of 1-D blocks covering `feature_size` columns; each
 * thread handles column `idx`, and threads with idx >= feature_size exit.
 * There is only one row, so the single statistics pair is mean[0] / var[0].
 *
 * With std = sqrt(var[0] + epsilon), for each non-null output it writes:
 *   d_x[idx]     = d_y[idx] * scale[idx] / std   (or d_y[idx] / std when
 *                                                 `scale` is null)
 *   d_scale[idx] = d_y[idx] * (x[idx] - mean[0]) / std
 *   d_bias[idx]  = d_y[idx]
 *
 * NOTE: LayerNormBackward follows this kernel with
 * LayerNormBackwardPostProcessToCalculateDX whenever d_x is requested, so the
 * d_x written here is presumably an intermediate value completed by that
 * kernel.
 */
template <typename T, typename U, bool ScaleBiasWithSameTypeX>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x,
    const T *d_y,
    T *d_x,
    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias,
    const U *mean,
    const U *var,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
    float epsilon,
    int64_t feature_size) {
  using ScaleBiasT = LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>;

  // Widen before multiplying so the flat index cannot wrap in 32-bit
  // unsigned arithmetic.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= feature_size) return;

  const U var_val =
      static_cast<U>(real_sqrt(static_cast<float>(var[0]) + epsilon));
  const U dy = static_cast<U>(d_y[idx]);

  if (d_x != nullptr) {
    // Whether the input gradient carries the scale factor depends on whether
    // a scale exists (`scale`), not on whether its gradient was requested
    // (`d_scale`). Gating on `d_scale` dropped the factor when the scale's
    // gradient was not needed, and dereferenced a null `scale` when it was.
    d_x[idx] = static_cast<T>(scale != nullptr
                                  ? dy * static_cast<U>(scale[idx]) / var_val
                                  : dy / var_val);
  }

  if (d_scale != nullptr) {
    d_scale[idx] = static_cast<ScaleBiasT>(
        dy * (static_cast<U>(x[idx]) - mean[0]) / var_val);
  }

  if (d_bias != nullptr) {
    d_bias[idx] = static_cast<ScaleBiasT>(d_y[idx]);
  }
}

1633 1634
/**
 * Host-side launcher for the layer norm backward pass.
 *
 * Inputs: `x` and `d_y` hold batch_size x feature_size elements (presumably
 * a row-major [batch_size, feature_size] layout -- confirm against callers),
 * `mean` / `var` hold per-row statistics, and `scale` is the optional affine
 * weight (may be null).
 *
 * Outputs: any of d_x, d_scale, d_bias may be null; only the non-null ones
 * are computed. If all three are null the function returns without
 * launching anything. Kernels launched directly here are enqueued on
 * dev_ctx.stream().
 *
 * Dispatch:
 *  - batch_size == 1: LayerNormBackwardWhenBatchSizeIsOne, followed (when d_x
 *    is requested) by a single-block LayerNormBackwardPostProcessToCalculateDX.
 *  - otherwise the requested outputs form a 3-bit mask
 *    (d_x << 2 | d_scale << 1 | d_bias) and each value selects a
 *    specialised kernel sequence, documented on its case label.
 */
template <typename T, typename U, bool ScaleBiasWithSameTypeX = false>
static void LayerNormBackward(
    const T *x,
    const T *d_y,
    const LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *scale,
    const U *mean,
    const U *var,
    T *d_x,
    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
    LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias,
    float epsilon,
    int64_t batch_size,
    int64_t feature_size,
    const phi::GPUContext &dev_ctx) {
  auto stream = dev_ctx.stream();
#ifdef __HIPCC__
  const int kMaxBlockDim = 256;
#else
  const int kMaxBlockDim = 512;
#endif
  const int kMaxBlockNum = 128;
  // Bit 2: d_x requested, bit 1: d_scale requested, bit 0: d_bias requested.
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;  // No gradient requested.

  if (batch_size == 1) {
    // One thread per feature column of the single row.
    LayerNormBackwardWhenBatchSizeIsOne<T, U, ScaleBiasWithSameTypeX>
        <<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim,
           kMaxBlockDim,
           0,
           stream>>>(x,
                     d_y,
                     d_x,
                     d_scale,
                     d_bias,
                     mean,
                     var,
                     scale,
                     epsilon,
                     feature_size);

    if (d_x != nullptr) {
      // Finish d_x for the single row with one block.
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<T, U, kBlockDim>
            <<<1, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
    }
    return;
  }

  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size,
            kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<T,
                                                 U,
                                                 kBlockDim,
                                                 false,
                                                 false,
                                                 ScaleBiasWithSameTypeX>
            <<<block_num, kBlockDim, 0, stream>>>(x,
                                                  d_y,
                                                  d_scale,
                                                  d_bias,
                                                  d_x,
                                                  mean,
                                                  var,
                                                  scale,
                                                  epsilon,
                                                  batch_size,
                                                  feature_size,
                                                  col_offset));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size,
            kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<T,
                                                 U,
                                                 kBlockDim,
                                                 false,
                                                 true,
                                                 ScaleBiasWithSameTypeX>
            <<<block_num, kBlockDim, 0, stream>>>(x,
                                                  d_y,
                                                  d_scale,
                                                  d_bias,
                                                  d_x,
                                                  mean,
                                                  var,
                                                  scale,
                                                  epsilon,
                                                  batch_size,
                                                  feature_size,
                                                  col_offset));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size,
            kMaxBlockNum,
            LayerNormBackwardGradientAll<T,
                                         U,
                                         kBlockDim,
                                         false,
                                         ScaleBiasWithSameTypeX>
            <<<block_num, kBlockDim, 0, stream>>>(x,
                                                  d_y,
                                                  d_scale,
                                                  d_bias,
                                                  d_x,
                                                  mean,
                                                  var,
                                                  scale,
                                                  epsilon,
                                                  batch_size,
                                                  feature_size,
                                                  col_offset));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      // One block per row.
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<T,
                                            U,
                                            kBlockDim,
                                            ScaleBiasWithSameTypeX>
            <<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size,
            kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<T,
                                                 U,
                                                 kBlockDim,
                                                 true,
                                                 false,
                                                 ScaleBiasWithSameTypeX>
            <<<block_num, kBlockDim, 0, stream>>>(x,
                                                  d_y,
                                                  d_scale,
                                                  d_bias,
                                                  d_x,
                                                  mean,
                                                  var,
                                                  scale,
                                                  epsilon,
                                                  batch_size,
                                                  feature_size,
                                                  col_offset));
      }
      // Second pass finishes d_x, one block per row.
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<T, U, kBlockDim>
            <<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size,
            kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<T,
                                                 U,
                                                 kBlockDim,
                                                 true,
                                                 true,
                                                 ScaleBiasWithSameTypeX>
            <<<block_num, kBlockDim, 0, stream>>>(x,
                                                  d_y,
                                                  d_scale,
                                                  d_bias,
                                                  d_x,
                                                  mean,
                                                  var,
                                                  scale,
                                                  epsilon,
                                                  batch_size,
                                                  feature_size,
                                                  col_offset));
      }
      // Second pass finishes d_x, one block per row.
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<T, U, kBlockDim>
            <<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
    {
#ifdef PADDLE_WITH_CUDA
      // The fast driver is only used for these feature sizes and for element
      // types of at most 4 bytes (the sizeof check already excludes double).
      const bool can_call_fast_kernel =
          (feature_size == 1024 || feature_size == 384 ||
           feature_size == 256) &&
          sizeof(T) <= 4;

      VLOG(6) << "can_call_fast_kernel = " << can_call_fast_kernel;
      if (can_call_fast_kernel) {
        ln_bwd_fast_kernel_driver<
            T,
            U,
            LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(dev_ctx,
                                                               batch_size,
                                                               feature_size,
                                                               epsilon,
                                                               x,
                                                               scale,
                                                               mean,
                                                               var,
                                                               d_y,
                                                               d_x,
                                                               d_scale,
                                                               d_bias);
      } else {
#endif
        // Generic three-kernel path:
        //   1. partial sums for d_scale / d_bias into scratch buffers,
        //   2. reduce the partial sums into d_scale / d_bias,
        //   3. compute d_x.
        constexpr int VPT = 4;
        constexpr int BDIMX2 = 32;
        constexpr int BDIMY2 = 4;
        dim3 threads2(BDIMX2, BDIMY2, 1);
        constexpr int part_size = BDIMY2 * VPT;
        const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);

        // Scratch: part_size rows of feature_size partial sums each.
        auto part_grad_gamma_ptr = memory::Alloc(
            dev_ctx.GetPlace(),
            part_size * feature_size * sizeof(U),
            phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
        auto part_grad_beta_ptr = memory::Alloc(
            dev_ctx.GetPlace(),
            part_size * feature_size * sizeof(U),
            phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
        U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
        U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());

        LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2, VPT>
            <<<blocks2, threads2, 0, stream>>>(
                d_y,
                x,
                batch_size,
                feature_size,
                mean,
                var,
                epsilon,
                part_grad_gamma,
                part_grad_beta);  // compute part_grad_gamma, beta

        constexpr int BDIMX3 = 32;
        constexpr int BDIMY3 = 8;
        dim3 threads3(BDIMX3, BDIMY3, 1);
        // Size the grid with this launch's own x-dimension (BDIMX3), not
        // BDIMX2 from the previous launch.
        const dim3 blocks3((feature_size + BDIMX3 - 1) / BDIMX3, 1, 1);
        LayerNormBackwardSumGradGammaBeta<T,
                                          U,
                                          BDIMX3,
                                          BDIMY3,
                                          ScaleBiasWithSameTypeX>
            <<<blocks3, threads3, 0, stream>>>(part_grad_gamma,
                                               part_grad_beta,
                                               part_size,
                                               batch_size,
                                               feature_size,
                                               d_scale,
                                               d_bias);

        constexpr int BDIMX1 = 32;
        constexpr int BDIMY1 = 4;
        dim3 threads1(BDIMX1, BDIMY1, 1);
        // One block per row.
        LayerNormBackwardComputeGradInput<T,
                                          U,
                                          BDIMX1,
                                          BDIMY1,
                                          ScaleBiasWithSameTypeX>
            <<<batch_size, threads1, 0, stream>>>(d_y,
                                                  x,
                                                  batch_size,
                                                  feature_size,
                                                  mean,
                                                  var,
                                                  epsilon,
                                                  scale,
                                                  d_x);
#ifdef PADDLE_WITH_CUDA
      }
#endif

      break;
    }
    default:
      break;
  }
}

}  // namespace operators
}  // namespace paddle