layer_norm_op.cu 38.9 KB
Newer Older
S
sneaxiy 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16 17 18 19 20 21
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
P
Pei Yang 已提交
22 23
#include <memory>
#include <vector>
F
furnace 已提交
24

P
Pei Yang 已提交
25
#include "paddle/fluid/framework/ddim.h"
Y
Yi Wang 已提交
26
#include "paddle/fluid/operators/layer_norm_op.h"
F
furnace 已提交
27
#include "paddle/fluid/platform/float16.h"
28 29 30 31 32 33
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
C
chengduoZH 已提交
34

S
sneaxiy 已提交
35 36 37
namespace paddle {
namespace operators {

F
furnace 已提交
38 39 40 41 42 43 44
// Short local aliases for framework types used by the kernels below.
using DataLayout = framework::DataLayout;
using Tensor = framework::Tensor;

// cuDNN type traits for element type T, and the BatchNormParamType those
// traits associate with T.
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

Z
zhiboniu 已提交
45
// Picks the thread-block size used with the FIXED_BLOCK_DIM_* dispatch
// macros: the maximum block size when there are at least that many elements
// to process, otherwise a single warp.
//   CUDA: 512 threads max, warp of 32.
//   HIP:  256 threads max, wavefront of 64.
inline static int GetDesiredBlockDim(int64_t block_dim) {
#ifdef __HIPCC__
  const int kMaxBlockDim = 256;
  const int kWarpSize = 64;
#else
  const int kMaxBlockDim = 512;
  const int kWarpSize = 32;
#endif
  return block_dim >= kMaxBlockDim ? kMaxBlockDim : kWarpSize;
}

// Warp-level sum via shuffle-down: each step folds the upper half of the
// remaining lanes onto the lower half, so after the loop lane 0 holds the sum
// over the whole warp (other lanes hold partial sums).
template <typename U>
static __forceinline__ __device__ U WarpReduceSum(U val) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int delta = warpSize >> 1; delta >= 1; delta >>= 1) {
    const U other = paddle::platform::CudaShuffleDownSync(mask, val, delta);
    val += other;
  }
  return val;
}

// Block-wide sum of `val`. Every thread of the block must call it (it contains
// __syncthreads()). The complete sum is only valid in thread 0; other threads
// receive partial or zero values.
// Supports at most 32 warps per block (size of the shared scratch array).
template <typename U>
__forceinline__ __device__ U BlockReduceSum(U val) {
  static __shared__ U shared[32];
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = WarpReduceSum(val);  // Each warp performs partial reduction

  if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory

  __syncthreads();  // Wait for all partial reductions

  // read from shared memory only if that warp existed
  val =
      (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<U>(0);

  // `shared` is a single static buffer reused by every call in the kernel.
  // Without this barrier, a back-to-back call (e.g. reducing mean and then
  // var) lets lane 0 of another warp overwrite shared[wid] before warp 0 has
  // read it above.
  __syncthreads();

  if (wid == 0) val = WarpReduceSum(val);  // Final reduce within first warp

  return val;
}

// Expands to a single `case` label for a switch on the block dimension. When
// the switched value equals 2^log2_block_dim, it defines the compile-time
// constant kBlockDim = 2^log2_block_dim and then evaluates __VA_ARGS__, which
// may use kBlockDim (e.g. as a kernel template argument).
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

// Expands to the case labels for block dimensions 512, 256, ..., 2 (2^9 down
// to 2^1). Use inside `switch (block_dim) { ... }`. A value outside that set
// matches none of these labels.
#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)

Z
zhiboniu 已提交
104 105 106 107 108 109 110 111 112 113 114
// Like FIXED_BLOCK_DIM_CASE_BASE, but also tiles the feature dimension into
// chunks of at most kMaxBlockNum columns. For every chunk it defines
//   i          - the chunk index,
//   col_offset - index of the chunk's first column,
//   block_num  - number of columns in the chunk (<= kMaxBlockNum),
//   kBlockDim  - the compile-time block dimension 2^log2_block_dim,
// and then evaluates __VA_ARGS__ (e.g. a kernel launch of block_num blocks
// covering columns [col_offset, col_offset + block_num)).
// The chunk count uses integer ceil-division, computed once, rather than a
// floating-point std::ceil re-evaluated on every loop iteration.
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                       \
    log2_block_dim, feature_size, kMaxBlockNum, ...)                     \
  case (1 << (log2_block_dim)): {                                        \
    const int64_t num_col_chunks =                                       \
        (static_cast<int64_t>(feature_size) +                            \
         static_cast<int64_t>(kMaxBlockNum) - 1) /                       \
        static_cast<int64_t>(kMaxBlockNum);                              \
    for (int64_t i = 0; i < num_col_chunks; i++) {                       \
      int64_t col_offset = i * static_cast<int64_t>(kMaxBlockNum);       \
      int block_num = static_cast<int>(                                  \
          std::min(static_cast<int64_t>(feature_size) - col_offset,      \
                   static_cast<int64_t>(kMaxBlockNum)));                 \
      constexpr auto kBlockDim = (1 << (log2_block_dim));                \
      __VA_ARGS__;                                                       \
    }                                                                    \
  } break

// Expands to the chunked case labels for block dimensions 512 down to 2.
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__)

137 138 139
// Device-side square root, overloaded so float stays in single precision
// (sqrtf) and double uses the double-precision sqrt.
static __device__ __forceinline__ float real_sqrt(float x) {
  return sqrtf(x);
}
static __device__ __forceinline__ double real_sqrt(double x) {
  return sqrt(x);
}

140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
// Two independent running sums carried through a single cub::BlockReduce in
// the backward kernels.
template <typename T>
struct PairForLayerNorm {
  // Leaves both members uninitialized.
  __device__ __forceinline__ PairForLayerNorm() {}
  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
      : first_(first), second_(second) {}

  T first_;
  T second_;
};

// Reduction operator: component-wise sum of two PairForLayerNorm values.
template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
    const T summed_first = p1.first_ + p2.first_;
    const T summed_second = p1.second_ + p2.second_;
    return PairForLayerNorm<T>(summed_first, summed_second);
  }
};

L
Leo Chen 已提交
158
// Reciprocal square root 1 / sqrt(val). The generic version divides
// explicitly; float, double and (on FP16-capable architectures) half are
// specialized to the corresponding device rsqrt intrinsic.
template <typename T>
__inline__ __device__ T rsqrt_(const T val) {
  return static_cast<T>(1) / sqrt(val);
}

template <>
__inline__ __device__ float rsqrt_(const float val) {
  return rsqrtf(val);
}

template <>
__inline__ __device__ double rsqrt_(const double val) {
  return rsqrt(val);
}

#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__inline__ __device__ half rsqrt_(const half val) {
  return hrsqrt(val);
}
#endif
L
Leo Chen 已提交
179

F
furnace 已提交
180 181 182
// Layer-norm forward pass. Block blockIdx.x normalizes the row
// x[blockIdx.x * feature_size, (blockIdx.x + 1) * feature_size); its BlockDim
// threads stride across that row (requires blockDim.x == BlockDim).
//   mean[row] = E[x]
//   var[row]  = max(E[x^2] - E[x]^2, 0)
//   y = (x - mean) * rsqrt(var + epsilon) * scale + bias
// scale and/or bias may be nullptr, in which case that term is omitted.
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                 T *y, U *mean, U *var, float epsilon,
                                 int64_t feature_size) {
  __shared__ U mean_share;
  __shared__ U var_share;

  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int64_t end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce to calculate mean and var
  U mean_val = 0;
  U var_val = 0;
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    U tmp = static_cast<U>(x[i]);
    mean_val += tmp;
    var_val += (tmp * tmp);
  }

  // Block-wide sums; the complete result is only valid in thread 0.
  mean_val = BlockReduceSum<U>(mean_val);
  var_val = BlockReduceSum<U>(var_val);

  if (threadIdx.x == 0) {
    // 1/N is computed in U (not float) so double inputs keep full precision.
    // The distinct name also avoids shadowing the `scale` parameter.
    const U inv_feature_size =
        static_cast<U>(1) / static_cast<U>(feature_size);
    mean_share = mean_val * inv_feature_size;
    mean[blockIdx.x] = mean_share;
    // E[x^2] - E[x]^2 can round slightly below zero; clamp at 0.
    const U row_var = var_val * inv_feature_size - mean_share * mean_share;
    var_share = row_var > U(0) ? row_var : U(0);
    var[blockIdx.x] = var_share;
  }
  __syncthreads();

  mean_val = mean_share;
  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));

  // Step 2: Calculate y
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(
            scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
      }
    } else {
      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
                              invvar);
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
                              bias[j]);
      }
    } else {
      for (int64_t i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
      }
    }
  }
}

// For row i1 = i1_block + thr_load_row_off, accumulates VPT consecutive
// columns starting at i2_off into the caller's buffers:
//   warp_buf1[w] += dout
//   warp_buf2[w] += dout * (input - mean[i1]) * rsqrt(var[i1] + epsilon)
// where w = thr_load_row_off * row_stride + thr_load_col_off + k.
// Rows at or beyond i1_end and columns at or beyond n2 are skipped.
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
    const int64_t i1_block, const int thr_load_row_off,
    const int thr_load_col_off, const int i2_off, const int row_stride,
    U *warp_buf1, U *warp_buf2, const T *input, const T *dout,
    const int64_t i1_end, const int64_t n2, const U *__restrict__ mean,
    const U *__restrict__ var, const float epsilon) {
  const int64_t i1 = i1_block + thr_load_row_off;
  if (i1 >= i1_end) return;
  U curr_mean = mean[i1];
  U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
  for (int k = 0; k < VPT; ++k) {
    const int i2 = i2_off + k;
    const int64_t load_idx = i1 * n2 + i2;
    const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
    if (i2 < n2) {
      U curr_input = static_cast<U>(input[load_idx]);
      U curr_dout = static_cast<U>(dout[load_idx]);
      warp_buf1[write_idx] += curr_dout;
      warp_buf2[write_idx] +=
          curr_dout * (curr_input - curr_mean) * curr_invvar;
    }
  }
}

// First pass of the gamma/beta gradient reduction.
// Each block covers BDIMX columns (selected by blockIdx.x) and walks rows in
// steps of VPTX * BDIMY * gridDim.y (starting at blockIdx.y). It writes one
// partial sum per (blockIdx.y, column i2):
//   part_grad_beta [blockIdx.y * n2 + i2] = sum of dout over the visited rows
//   part_grad_gamma[blockIdx.y * n2 + i2] = sum of dout * x_hat
// with x_hat = (input - mean) * rsqrt(var + epsilon).
// NOTE(review): the indexing mixes VPTX and BDIMY: rows are read as
// threadIdx.y + k * VPTX, load rows are offset by threadIdx.y * BDIMY, and the
// final step only adds rows 0 and 1. That looks correct only when
// VPTX == BDIMY and blockDim == (BDIMX, BDIMY); confirm at the launch site.
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
    const T *__restrict__ dout, const T *__restrict__ input, const int64_t n1,
    const int64_t n2, const U *__restrict__ mean, const U *__restrict__ var,
    float epsilon, U *part_grad_gamma, U *part_grad_beta) {
  // VPTX -> values per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y;
  // template parameters for compile-time optimization.

  // Row pitch padded by one element (presumably to avoid shared-memory bank
  // conflicts when columns are read).
  constexpr int row_stride = BDIMX + 1;
  const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
  const int thr_load_row_off =
      (threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
  const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;

  constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
                                 ? BDIMX * BDIMY
                                 : 2 * VPTX * BDIMY * row_stride;
  __shared__ U buf[shared_cap];

  // Two tiles: warp_buf1 accumulates dout, warp_buf2 accumulates dout * x_hat.
  U *warp_buf1 = reinterpret_cast<U *>(buf);
  U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;

  for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
       idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) {
    buf[idx] = U(0);
  }
  __syncthreads();

  for (int64_t i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
       i1_block += VPTX * BDIMY * gridDim.y) {
    cuLoadAddStridedInputs<T, U, VPTX>(
        i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
        warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon);
  }
  __syncthreads();

  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < VPTX; ++k) {
    int row1 = threadIdx.y + k * VPTX;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps: tree-reduce rows down to rows 0 and 1
  for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int64_t i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

// Second pass of the gamma/beta gradient reduction: sums part_size rows of
// partial gradients (laid out as [part_size, n2]) into grad_gamma[i2] and
// grad_beta[i2].
// Each block handles BDIMX columns. Each of its BDIMY thread rows first
// sequentially sums part_size / BDIMY partial rows, then the BDIMY partials
// are tree-reduced through shared memory.
// NOTE(review): assumes part_size is a multiple of BDIMY (remainder rows are
// not summed) and blockDim == (BDIMX, BDIMY); confirm at the launch site.
// Every thread executes every __syncthreads(). Only the global loads and the
// final store are guarded by the column bounds check, so a partial last
// block (n2 not a multiple of BDIMX) no longer hits a divergent barrier.
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardSumGradGammaBeta(
    const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
    const int n1, const int n2, U *grad_gamma, U *grad_beta) {
  __shared__ U buf[BDIMX * BDIMY];
  const int64_t i2 = blockIdx.x * BDIMX + threadIdx.x;
  const bool in_range = i2 < n2;

  U sum_gamma = U(0);
  U sum_beta = U(0);
  if (in_range) {
    // each warp does sequential reductions until reduced part_size is num_warps
    const int num_warp_reductions = part_size / BDIMY;
    const U *part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U *part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
  }

  // inter-warp reductions (executed by all threads so barriers are uniform;
  // out-of-range columns only ever touch their own buffer slots)
  constexpr int nbsize3 = BDIMX * BDIMY / 2;
  for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
    // top half write to shared memory
    if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
      const int write_idx = (threadIdx.y - offset) * BDIMX + threadIdx.x;
      buf[write_idx] = sum_gamma;
      buf[write_idx + nbsize3] = sum_beta;
    }
    __syncthreads();
    // bottom half sums
    if (threadIdx.y < offset) {
      const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
      sum_gamma += buf[read_idx];
      sum_beta += buf[read_idx + nbsize3];
    }
    __syncthreads();
  }

  // write out fully summed gradients
  if (in_range && threadIdx.y == 0) {
    grad_gamma[i2] = sum_gamma;
    grad_beta[i2] = sum_beta;
  }
}

// Computes the layer-norm input gradient. Rows i1 are distributed over
// blockIdx.y, striding by gridDim.y. All BDIMX * BDIMY threads of a block
// cooperate on one row of n2 elements:
//   sum_loss1 = sum_l dout[l] * gamma[l]
//   sum_loss2 = sum_l dout[l] * gamma[l] * x_hat[l]
//   grad_input[l] = (n2 * dout[l] * gamma[l] - sum_loss1
//                    - x_hat[l] * sum_loss2) * invvar / n2
// with invvar = rsqrt(var + epsilon) and x_hat = (input - mean) * invvar.
// gamma is treated as all ones when it is nullptr.
// NOTE(review): the intra-warp step shuffles with a full-warp mask over xor
// offsets < BDIMX, which assumes BDIMX is a power of two no larger than the
// warp size; confirm at the launch site.
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardComputeGradInput(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
    const float epsilon, const U *gamma, T *grad_input) {
#ifdef __HIPCC__
  for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) {
#else
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
#endif
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
    // i1 is an unsigned 32-bit block index; widen before multiplying so the
    // row offset cannot wrap for tensors with more than 2^32 elements.
    const int64_t row_offset = static_cast<int64_t>(i1) * n2;
    const T *k_input = input + row_offset;
    const T *k_dout = dout + row_offset;
    constexpr int numx = BDIMX * BDIMY;
    const int thrx = threadIdx.x + threadIdx.y * BDIMX;
    if (gamma != nullptr) {
      // Each thread takes 4 consecutive elements per step; the tail loop
      // handles the last (< 4) elements.
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * gamma[l + k];
          sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions (butterfly: every lane ends with the group sum)
    for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
#ifdef PADDLE_WITH_HIP
      sum_loss1 += __shfl_xor(sum_loss1, mask, warpSize);
      sum_loss2 += __shfl_xor(sum_loss2, mask, warpSize);
#else
      sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize);
      sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize);
#endif
    }
    // inter-warp reductions
    if (BDIMY > 1) {
      __shared__ U buf[BDIMX * BDIMY];
      for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      // broadcast row 0's totals to the other thread rows
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
      // buf is reused by the next row this block processes. Without this
      // barrier, an upper-half warp of the next iteration could overwrite
      // buf[2 * threadIdx.x] before a slower warp has read it above.
      __syncthreads();
    }
    // all threads now have the two sums over l
    U fH = static_cast<U>(n2);
    U term1 = (U(1) / fH) * c_invvar;
    T *k_grad_input = grad_input + row_offset;
    if (gamma != nullptr) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}

// Computes d_scale and d_bias for columns [col_offset, col_offset + gridDim.x).
// When HasDx, it also writes d_x = d_y * scale / sqrt(var + epsilon) for
// those columns (presumably finished later by a post-processing kernel;
// confirm at the call site).
// Block b handles column b + col_offset. Its BlockDim threads stride over the
// batch_size rows, and thread 0 stores the block-reduced sums.
// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
template <typename T, typename U, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(
    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
    const U *var, const U *scale, float epsilon, int64_t batch_size,
    int64_t feature_size, int64_t col_offset) {
  const int64_t col_idx = static_cast<int64_t>(blockIdx.x) + col_offset;
  int64_t beg_idx = threadIdx.x * feature_size + col_idx;
  int64_t end_idx = batch_size * feature_size + col_idx;
  int64_t stride = BlockDim * feature_size;

  U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);

  for (int64_t i = beg_idx; i < end_idx; i += stride) {
    // 64-bit row index: batch_size is int64_t and may exceed INT_MAX.
    const int64_t row_idx = i / feature_size;
    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
    d_scale_partial += static_cast<U>(d_y[i]) *
                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
    d_bias_partial += static_cast<U>(d_y[i]);
    if (HasDx) {
      d_x[i] =
          static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
    }
  }

  // Block-wide sums; the complete result is only valid in thread 0.
  d_scale_partial = BlockReduceSum<U>(d_scale_partial);
  d_bias_partial = BlockReduceSum<U>(d_bias_partial);

  if (threadIdx.x == 0) {
    d_scale[col_idx] = d_scale_partial;
    d_bias[col_idx] = d_bias_partial;
  }
}

// Computes either d_scale (HasDScale) or d_bias (!HasDScale) for columns
// [col_offset, col_offset + gridDim.x). When HasDx, it also writes
// d_x = d_y * scale / sqrt(var + epsilon) for those columns, omitting scale
// if it is null.
// Block b reduces column b + col_offset over all batch_size rows with
// cub::BlockReduce; thread 0 writes the result.
// Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr
// Notice: scale may be nullptr
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
    const U *var, const U *scale, float epsilon, int64_t batch_size,
    int64_t feature_size, int col_offset) {
  using BlockReduce = cub::BlockReduce<U, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  const int64_t col_idx = static_cast<int64_t>(blockIdx.x) + col_offset;
  int64_t beg_idx = threadIdx.x * feature_size + col_idx;
  int64_t end_idx = batch_size * feature_size + col_idx;
  // Must be 64-bit: BlockDim * feature_size overflows int once feature_size
  // exceeds INT_MAX / BlockDim (~4M columns for BlockDim == 512).
  int64_t stride = BlockDim * feature_size;
  U d_scale_or_d_bias_partial = static_cast<U>(0);

  for (int64_t i = beg_idx; i < end_idx; i += stride) {
    const int64_t row_idx = i / feature_size;
    // Computed in U (as in LayerNormBackwardGradientAll) so double inputs are
    // not truncated to float.
    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
    if (HasDScale) {
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
                                   (static_cast<U>(x[i]) - mean[row_idx]) /
                                   var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] =
            static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
      } else {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[col_idx] = d_scale_or_d_bias_partial;
    } else {
      d_bias[col_idx] = d_scale_or_d_bias_partial;
    }
  }
}

F
furnace 已提交
602
// Finishes d_x for one row per block (blockIdx.x). It assumes d_x already
// holds g = d_y * scale / sqrt(var + epsilon) for that row (TODO confirm at
// each call site), then applies
//   d_x -= mean(g) + (x - mean) * sum(g * (x - mean))
//                    / (feature_size * (var + epsilon))
// Both block sums are reduced together as one PairForLayerNorm.
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(
    const T *x, T *d_x, const U *mean, const U *var, float epsilon,
    int64_t feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int64_t end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x];
  U block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    // Kept in U: forcing these through float lost precision for U == double.
    const U n = static_cast<U>(feature_size);
    d_x_reduce_tmp[0] = pair.first_ / n;
    d_x_reduce_tmp[1] =
        pair.second_ / (n * (block_var + static_cast<U>(epsilon)));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

S
sneaxiy 已提交
644
// Here, we only calculate d_x.
// One block per row (blockIdx.x); BlockDim threads stride over the row's
// feature_size columns. It first writes
// g = d_y * scale / sqrt(var + epsilon) (scale omitted if null), then applies
//   d_x = g - mean(g) - (x - mean) * sum(g * (x - mean))
//                       / (feature_size * (var + epsilon))
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                T *d_x, const U *mean,
                                                const U *var, const U *scale,
                                                float epsilon,
                                                int64_t feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int64_t end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
  // Loop-invariant per row: hoisted out of the element loop, and computed in U
  // so double inputs are not truncated to float.
  const U var_val = real_sqrt(block_var + static_cast<U>(epsilon));
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    if (scale != nullptr) {
      const int64_t col_idx = i % feature_size;
      d_x[i] =
          static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
    } else {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
    }
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    const U n = static_cast<U>(feature_size);
    d_x_reduce_tmp[0] = pair.first_ / n;
    d_x_reduce_tmp[1] =
        pair.second_ / (n * (block_var + static_cast<U>(epsilon)));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int64_t i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

F
furnace 已提交
697
// Backward pass for a single row (batch_size == 1), one thread per feature
// index idx. For each non-null output it writes:
//   d_x[idx]     = d_y[idx] * scale[idx] / sqrt(var + epsilon)
//                  (the scale factor is omitted when scale is null)
//   d_scale[idx] = d_y[idx] * (x[idx] - mean) / sqrt(var + epsilon)
//   d_bias[idx]  = d_y[idx]
// mean/var are per-row statistics (LayerNormForward writes one value per
// row at mean[blockIdx.x]), so with a single row they hold one element and
// must be read at index 0, not at the feature index.
template <typename T, typename U>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
    const U *var, const U *scale, float epsilon, int64_t feature_size) {
  int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < feature_size) {
    const U row_mean = mean[0];
    const U var_val = real_sqrt(static_cast<U>(var[0]) + epsilon);
    if (d_x != nullptr) {
      // Decide on scale itself (as the other backward kernels do). A scale
      // without a requested d_scale must still contribute to d_x.
      if (scale != nullptr) {
        d_x[idx] =
            static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
      } else {
        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
      }
    }

    if (d_scale != nullptr) {
      d_scale[idx] =
          static_cast<U>(d_y[idx]) * (static_cast<U>(x[idx]) - row_mean) /
          var_val;
    }

    if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
  }
}

F
furnace 已提交
723 724 725
template <typename T, typename U>
// Enqueues the CUDA kernels for the layer_norm backward pass.
//
// x and d_y are treated as batch_size rows of feature_size elements. Any of
// d_x, d_scale and d_bias may be nullptr, meaning that gradient is not
// requested. The requested combination is encoded in a 3-bit flag
// (bit 2: d_x, bit 1: d_scale, bit 0: d_bias) that selects the kernels below.
// If nothing is requested, no kernel is launched. All kernels are enqueued on
// the device context's stream; this function does not synchronize.
//
//   x, d_y                forward input and upstream gradient
//   scale                 optional scale parameter (may be nullptr)
//   mean, var             statistics produced by the forward pass
//   d_x, d_scale, d_bias  gradient outputs (each may be nullptr)
//   epsilon               epsilon used by the forward normalization
//   batch_size, feature_size  flattened 2-D shape of x
//   ctx                   provides the CUDA device context / stream
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
                              const U *mean, const U *var, T *d_x, U *d_scale,
                              U *d_bias, float epsilon, int64_t batch_size,
                              int64_t feature_size,
                              const framework::ExecutionContext &ctx) {
  auto &dev_ctx = ctx.cuda_device_context();
  auto stream = dev_ctx.stream();
#ifdef __HIPCC__
  const int kMaxBlockDim = 256;
#else
  const int kMaxBlockDim = 512;
#endif
  const int kMaxBlockNum = 128;
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;

  if (batch_size == 1) {
    // Single row: an elementwise kernel computes every requested gradient.
    // For d_x, a one-block post-process then subtracts the mean/variance
    // correction terms.
    LayerNormBackwardWhenBatchSizeIsOne<
        T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
                0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
                             epsilon, feature_size);

    if (d_x != nullptr) {
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
                             T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>(
            x, d_x, mean, var, epsilon, feature_size));
      }
    }
    return;
  }

  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 7: {  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
      constexpr int VPT = 4;
      constexpr int BDIMX2 = 32;
      constexpr int BDIMY2 = 4;
      dim3 threads2(BDIMX2, BDIMY2, 1);
      constexpr int part_size = BDIMY2 * VPT;
      const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);

      // Scratch buffers: part_size partial values per feature for the scale
      // and bias gradients.
      auto part_grad_gamma_ptr =
          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
      auto part_grad_beta_ptr =
          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
      U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
      U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());

      // Step 1: compute the partial scale/bias gradients.
      LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2,
                                         VPT><<<blocks2, threads2, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma,
          part_grad_beta);

      // Step 2: reduce the partials into d_scale / d_bias. The grid is sized
      // by this launch's own block width (BDIMX3), not by step 1's BDIMX2.
      constexpr int BDIMX3 = 32;
      constexpr int BDIMY3 = 8;
      dim3 threads3(BDIMX3, BDIMY3, 1);
      const dim3 blocks3((feature_size + BDIMX3 - 1) / BDIMX3, 1, 1);
      LayerNormBackwardSumGradGammaBeta<
          T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
          part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
          d_scale, d_bias);

      // Step 3: input gradient, one grid row per batch row.
      // NOTE(review): gridDim.y is limited to 65535 on CUDA, so this launch
      // fails for batch_size > 65535; lifting that requires a kernel change.
      constexpr int BDIMX1 = 32;
      constexpr int BDIMY1 = 4;
      dim3 threads1(BDIMX1, BDIMY1, 1);
      const dim3 blocks1(1, batch_size, 1);
      LayerNormBackwardComputeGradInput<
          T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
      break;
    }
    default:
      break;
  }
}

P
Pei Yang 已提交
878
template <typename T>
// Runs the layer_norm forward pass directly on a caller-supplied stream.
// The input shape is collapsed to a [rows, cols] matrix at begin_norm_axis.
// LayerNormForward is launched with one block per row, and its block width is
// chosen from cols. An unsupported block width raises InvalidArgument.
void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                               const T *input,
                                               std::vector<int> input_shape,
                                               const T *bias, const T *scale,
                                               T *output, T *mean, T *variance,
                                               int begin_norm_axis, float eps) {
  const auto flat = framework::flatten_to_2d(framework::make_ddim(input_shape),
                                             begin_norm_axis);
  const auto rows = static_cast<int64_t>(flat[0]);
  const auto cols = static_cast<int64_t>(flat[1]);

  const int threads = GetDesiredBlockDim(cols);
  switch (threads) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<rows, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, cols));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
  }
}

S
sneaxiy 已提交
901 902 903 904 905
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // GPU forward layer_norm. X is flattened to [rows, cols] at begin_norm_axis,
  // and LayerNormForward is launched with one block per row. It writes Y and
  // the Mean / Variance outputs; Scale and Bias are optional.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;

    const auto *in_x = ctx.Input<Tensor>("X");
    const auto *in_scale = ctx.Input<Tensor>("Scale");
    const auto *in_bias = ctx.Input<Tensor>("Bias");
    auto *out_y = ctx.Output<Tensor>("Y");
    auto *out_mean = ctx.Output<Tensor>("Mean");
    auto *out_var = ctx.Output<Tensor>("Variance");

    const float eps = ctx.Attr<float>("epsilon");
    const int axis = ctx.Attr<int>("begin_norm_axis");
    const auto place = ctx.GetPlace();

    const T *x_ptr = in_x->data<T>();
    T *y_ptr = out_y->mutable_data<T>(place);
    U *mean_ptr = out_mean->mutable_data<U>(place);
    U *var_ptr = out_var->mutable_data<U>(place);
    const U *scale_ptr = in_scale ? in_scale->data<U>() : nullptr;
    const U *bias_ptr = in_bias ? in_bias->data<U>() : nullptr;

    const auto flat = framework::flatten_to_2d(in_x->dims(), axis);
    const auto rows = static_cast<int64_t>(flat[0]);
    const auto cols = static_cast<int64_t>(flat[1]);
    auto stream = ctx.cuda_device_context().stream();

    const int threads = GetDesiredBlockDim(cols);
    switch (threads) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<T, U, kBlockDim><<<rows, kBlockDim, 0, stream>>>(
              x_ptr, scale_ptr, bias_ptr, y_ptr, mean_ptr, var_ptr, eps,
              cols));
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Product from begin_norm_axis to end must be larger than 1"));
    }
  }
};

template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // GPU backward layer_norm. Any of X@GRAD, Scale@GRAD and Bias@GRAD may be
  // absent. Absent outputs are forwarded to LayerNormBackward as nullptr,
  // which skips that gradient.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;

    // Requested gradients; each may be nullptr.
    auto *grad_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *grad_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *grad_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // Forward inputs, saved statistics and the upstream gradient.
    const auto *in_x = ctx.Input<Tensor>("X");
    const auto *saved_mean = ctx.Input<Tensor>("Mean");
    const auto *saved_var = ctx.Input<Tensor>("Variance");
    const auto *in_scale = ctx.Input<Tensor>("Scale");
    const auto *grad_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const float eps = ctx.Attr<float>("epsilon");
    const int axis = ctx.Attr<int>("begin_norm_axis");
    const auto place = ctx.GetPlace();

    const T *x_ptr = in_x->data<T>();
    const T *grad_y_ptr = grad_y->data<T>();
    const U *mean_ptr = saved_mean->data<U>();
    const U *var_ptr = saved_var->data<U>();
    const U *scale_ptr = in_scale ? in_scale->data<U>() : nullptr;

    U *grad_scale_ptr =
        grad_scale ? grad_scale->mutable_data<U>(place) : nullptr;
    U *grad_bias_ptr = grad_bias ? grad_bias->mutable_data<U>(place) : nullptr;
    T *grad_x_ptr = grad_x ? grad_x->mutable_data<T>(place) : nullptr;

    const auto flat = framework::flatten_to_2d(in_x->dims(), axis);
    const auto rows = static_cast<int64_t>(flat[0]);
    const auto cols = static_cast<int64_t>(flat[1]);

    LayerNormBackward<T, U>(x_ptr, grad_y_ptr, scale_ptr, mean_ptr, var_ptr,
                            grad_x_ptr, grad_scale_ptr, grad_bias_ptr, eps,
                            rows, cols, ctx);
  }
};
F
furnace 已提交
988

P
Pei Yang 已提交
989
// Explicitly instantiate the direct CUDA functor; only the float
// specialization is emitted from this translation unit.
template class LayerNormDirectCUDAFunctor<float>;

// The block-dimension dispatch macros are private to this file; retire them
// here so they cannot be picked up by code that follows.
#undef FIXED_BLOCK_DIM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
}  // namespace operators
}  // namespace paddle

C
chengduoZH 已提交
998
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the forward and backward GPU kernels.
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double, so the HIP build registers float and
// float16 only.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);
#else
// CUDA build: float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);
#endif