layer_norm_op.cu
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <memory>
#include <vector>

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

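// Round the launch block size down to a power of two (capped at 512 threads)
// so it can be matched against the compile-time block-dim switch cases below.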
inline static int GetDesiredBlockDim(int block_dim) {
  const int kMaxBlockDim = 512;
  return block_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << (static_cast<int>(std::log2f(block_dim))));
}

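// FIXED_BLOCK_DIM_CASE(_BASE) expands into switch cases for every power-of-two
// block size from 2 to 512, exposing kBlockDim as a compile-time constant to
// the kernel launch expression passed in __VA_ARGS__.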
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)

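// Same idea as above, but the launch expression is additionally run once per
// tile of at most kMaxBlockNum columns, with block_num and col_offset made
// available to __VA_ARGS__.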
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                             \
    log2_block_dim, feature_size, kMaxBlockNum, ...)                           \
  case (1 << (log2_block_dim)): {                                              \
    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
      int col_offset = i * kMaxBlockNum;                                       \
      int block_num = std::min(feature_size - col_offset, kMaxBlockNum);       \
      constexpr auto kBlockDim = (1 << (log2_block_dim));                      \
      __VA_ARGS__;                                                             \
    }                                                                          \
  } break

#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__)

static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }

template <typename T>
struct PairForLayerNorm {
  __device__ __forceinline__ PairForLayerNorm() {}
  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
      : first_(first), second_(second) {}

  T first_;
  T second_;
};

template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
  }
};

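// rsqrt_ selects the fastest reciprocal-square-root intrinsic available for
// each precision (rsqrtf / rsqrt / hrsqrt); the generic version falls back to
// 1 / sqrt(val).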
template <typename T>
__inline__ __device__ T rsqrt_(const T val) {
  return static_cast<T>(1) / sqrt(val);
}

template <>
__inline__ __device__ float rsqrt_(const float val) {
  return rsqrtf(val);
}

template <>
__inline__ __device__ double rsqrt_(const double val) {
  return rsqrt(val);
}

#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__inline__ __device__ half rsqrt_(const half val) {
  return hrsqrt(val);
}
#endif

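// Forward kernel: one block normalizes one row of the flattened
// [batch_size, feature_size] input. Step 1 block-reduces sum(x) and sum(x*x)
// to obtain the row mean and (biased) variance; step 2 writes
//   y = scale * (x - mean) * rsqrt(var + epsilon) + bias,
// with separate loops for the nullptr scale/bias combinations.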
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                 T *y, U *mean, U *var, float epsilon,
                                 int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U mean_share;
  __shared__ U var_share;

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce to calculate mean and var
  U mean_val = 0;
  U var_val = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U tmp = static_cast<U>(x[i]);
    mean_val += tmp;
    var_val += (tmp * tmp);
  }
  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<U>(mean_val, var_val),
                          PairForLayerNormAddFunctor<U>());
  if (threadIdx.x == 0) {
    auto tmp = pair.first_ / feature_size;
    mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
    var[blockIdx.x] = var_share =
        static_cast<U>(pair.second_ / feature_size - tmp * tmp);
  }
  __syncthreads();

  mean_val = mean_share;
  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));

  // Step 2: Calculate y
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(
            scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
                              invvar);
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
                              bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
      }
    }
  }
}

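// Accumulates, for VPT strided columns per thread, the partial sums of d_y and
// d_y * (x - mean) * invvar into the shared-memory warp buffers used by the
// gamma/beta gradient reduction below.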
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
    const T *input, const T *dout, const int i1_end, const int n2,
    const U *__restrict__ mean, const U *__restrict__ var,
    const float epsilon) {
  const int i1 = i1_block + thr_load_row_off;
  if (i1 >= i1_end) return;
  U curr_mean = mean[i1];
  U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
  for (int k = 0; k < VPT; ++k) {
    const int i2 = i2_off + k;
    const int load_idx = i1 * n2 + i2;
    const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
    if (i2 < n2) {
      U curr_input = static_cast<U>(input[load_idx]);
      U curr_dout = static_cast<U>(dout[load_idx]);
      warp_buf1[write_idx] += curr_dout;
      warp_buf2[write_idx] +=
          curr_dout * (curr_input - curr_mean) * curr_invvar;
    }
  }
}

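// Stage 1 of the d_scale/d_bias computation: each block reduces a tile of rows
// into per-block partial gradients, written to part_grad_gamma/part_grad_beta
// (one row of partials per blockIdx.y).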
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
    float epsilon, U *part_grad_gamma, U *part_grad_beta) {
  // VPTX -> values per thread along x, BDIMX -> blockDim.x, BDIMY -> blockDim.y
  // (all template parameters, to enable compile-time optimizations)

  constexpr int row_stride = BDIMX + 1;
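  // The +1 padding on row_stride is presumably there to avoid shared-memory
  // bank conflicts in the column-wise reductions below.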
  const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
  const int thr_load_row_off =
      (threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
  const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;

  constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
                                 ? BDIMX * BDIMY
                                 : 2 * VPTX * BDIMY * row_stride;
  __shared__ U buf[shared_cap];

  U *warp_buf1 = reinterpret_cast<U *>(buf);
  U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;

  for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
       idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) {
    buf[idx] = U(0);
  }
  __syncthreads();

  for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
       i1_block += VPTX * BDIMY * gridDim.y) {
    cuLoadAddStridedInputs<T, U, VPTX>(
        i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
        warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon);
  }
  __syncthreads();

  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < VPTX; ++k) {
    int row1 = threadIdx.y + k * VPTX;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

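// Stage 2: reduce the part_size rows of partial gradients produced above into
// the final grad_gamma / grad_beta vectors of length n2.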
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardSumGradGammaBeta(
    const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
    // const int n1, const int n2, T* grad_gamma, T* grad_beta) {
    const int n1, const int n2, U *grad_gamma, U *grad_beta) {
  // sum partial gradients for gamma and beta
  __shared__ U buf[BDIMX * BDIMY];
  int i2 = blockIdx.x * BDIMX + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / BDIMY;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U *part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U *part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
    // inter-warp reductions
    constexpr int nbsize3 = BDIMX * BDIMY / 2;
    for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx + nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx + nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}

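// Computes d_x row by row (one row per blockIdx.y). With invvar =
// rsqrt(var + epsilon) and gamma treated as 1 when NULL, each element is
//   d_x = (n2 * d_y * gamma - sum(d_y * gamma)
//          - (x - mean) * invvar * sum(d_y * gamma * (x - mean) * invvar))
//         * invvar / n2.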
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardComputeGradInput(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2,
    // const U* __restrict__ mean, const U* __restrict__ var, const float
    // epsilon, const T* gamma,
    const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
    const U *gamma, T *grad_input) {
#ifdef __HIPCC__
  for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) {
#else
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
#endif
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
    const T *k_input = input + i1 * n2;
    const T *k_dout = dout + i1 * n2;
    constexpr int numx = BDIMX * BDIMY;
    const int thrx = threadIdx.x + threadIdx.y * BDIMX;
    if (gamma != NULL) {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * gamma[l + k];
          sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
#ifdef PADDLE_WITH_HIP
      sum_loss1 += __shfl_xor(sum_loss1, mask,
                              warpSize);  // WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += __shfl_xor(sum_loss2, mask,
                              warpSize);  // WARP_SHFL_XOR(sum_loss2, mask);
#else
      sum_loss1 +=
          __shfl_xor_sync(0xffffffff, sum_loss1, mask,
                          warpSize);  // WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 +=
          __shfl_xor_sync(0xffffffff, sum_loss2, mask,
                          warpSize);  // WARP_SHFL_XOR(sum_loss2, mask);
#endif
    }
    // inter-warp reductions
    if (BDIMY > 1) {
      __shared__ U buf[BDIMX * BDIMY];
      for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T *k_grad_input = grad_input + i1 * n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}

// Make sure that d_scale != nullptr && d_bias != nullptr.
// Since d_scale != nullptr, scale cannot be nullptr either.
template <typename T, typename U, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                             U *d_scale, U *d_bias, T *d_x,
                                             const U *mean, const U *var,
                                             const U *scale, float epsilon,
                                             int batch_size, int feature_size,
                                             int col_offset) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
  int stride = BlockDim * feature_size;

  U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
    d_scale_partial += static_cast<U>(d_y[i]) *
                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
    d_bias_partial += static_cast<U>(d_y[i]);
    if (HasDx) {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                              scale[blockIdx.x + col_offset] / var_val);
    }
  }

  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<U>(d_scale_partial, d_bias_partial),
                          PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_scale[blockIdx.x + col_offset] = pair.first_;
    d_bias[blockIdx.x + col_offset] = pair.second_;
  }
}

// Make sure that exactly one of d_scale != nullptr and d_bias != nullptr
// holds.
// Notice: scale may be nullptr
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
    const U *var, const U *scale, float epsilon, int batch_size,
    int feature_size, int col_offset) {
  using BlockReduce = cub::BlockReduce<U, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
  int stride = BlockDim * feature_size;
  U d_scale_or_d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
    if (HasDScale) {
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
                                   (static_cast<U>(x[i]) - mean[row_idx]) /
                                   var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                                scale[blockIdx.x + col_offset] / var_val);
      } else {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    } else {
      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    }
  }
}

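// Post-processing step: d_x already holds d_y * scale / sqrt(var + epsilon);
// this kernel subtracts the per-row mean and variance correction terms so d_x
// matches the full layer_norm gradient.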
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
                                                          const U *mean,
                                                          const U *var,
                                                          float epsilon,
                                                          int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x];
  U block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
    d_x_reduce_tmp[1] =
        static_cast<float>(pair.second_) /
        (feature_size * (static_cast<float>(block_var) + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

// Here, we only calculate d_x
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                T *d_x, const U *mean,
                                                const U *var, const U *scale,
                                                float epsilon,
                                                int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
    if (scale != nullptr) {
      int col_idx = i % feature_size;
      d_x[i] =
          static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
    } else {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
    }
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
    d_x_reduce_tmp[1] =
        static_cast<float>(pair.second_) /
        (feature_size * (static_cast<float>(block_var) + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

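// Special case batch_size == 1: every reduction spans a single row, so one
// thread per feature element can write d_x, d_scale and d_bias directly.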
template <typename T, typename U>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
    const U *var, const U *scale, float epsilon, int feature_size) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < feature_size) {
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
    if (d_x != nullptr) {
      if (d_scale == nullptr) {
        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
      } else {
        d_x[idx] =
            static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
      }
    }

    if (d_scale != nullptr) {
      d_scale[idx] = static_cast<U>(d_y[idx]) *
                     (static_cast<U>(x[idx]) - mean[idx]) / var_val;
    }

    if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
  }
}

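// Host-side dispatcher for the backward pass. gradient_flag packs the
// presence of d_x / d_scale / d_bias into bits 2 / 1 / 0 and picks the kernel
// combination for each of the seven non-trivial cases below.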
template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
                              const U *mean, const U *var, T *d_x, U *d_scale,
                              U *d_bias, float epsilon, int batch_size,
                              int feature_size,
                              const framework::ExecutionContext &ctx) {
  auto &dev_ctx = ctx.cuda_device_context();
  auto stream = dev_ctx.stream();

  const int kMaxBlockDim = 512;
  const int kMaxBlockNum = 128;
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;

  if (batch_size == 1) {
    LayerNormBackwardWhenBatchSizeIsOne<
        T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
                0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
                             epsilon, feature_size);

    if (d_x != nullptr) {
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
                             T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>(
            x, d_x, mean, var, epsilon, feature_size));
      }
    }
    return;
  }

  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
    {
      constexpr int VPT = 4;
      constexpr int BDIMX2 = 32;
      constexpr int BDIMY2 = 4;
      dim3 threads2(BDIMX2, BDIMY2, 1);
      constexpr int part_size = BDIMY2 * VPT;
      const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);

      auto part_grad_gamma_ptr =
          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
      auto part_grad_beta_ptr =
          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
      U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
      U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());

      LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2,
                                         VPT><<<blocks2, threads2, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma,
          part_grad_beta);  // compute part_grad_gamma, beta

      constexpr int BDIMX3 = 32;
      constexpr int BDIMY3 = 8;
      dim3 threads3(BDIMX3, BDIMY3, 1);
      const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
      LayerNormBackwardSumGradGammaBeta<
          T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
          part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
          d_scale, d_bias);

      constexpr int BDIMX1 = 32;
      constexpr int BDIMY1 = 4;
      dim3 threads1(BDIMX1, BDIMY1, 1);
      const dim3 blocks1(1, batch_size, 1);
      LayerNormBackwardComputeGradInput<
          T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
      break;
    }
    default:
      break;
  }
}

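// Standalone functor that launches LayerNormForward directly from raw device
// pointers and a stream, flattening the input to [batch_size, feature_size]
// at begin_norm_axis.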
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
                                               const T *input,
                                               std::vector<int> input_shape,
                                               const T *bias, const T *scale,
                                               T *output, T *mean, T *variance,
                                               int begin_norm_axis, float eps) {
  const auto x_dims = framework::make_ddim(input_shape);
  auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
  int batch_size = static_cast<int>(matrix_dim[0]);
  int feature_size = static_cast<int>(matrix_dim[1]);
  switch (GetDesiredBlockDim(feature_size)) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, feature_size));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

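// Forward op kernel: mean/variance are stored in LayerNormParamType<T>
// (float when T is float16) and the row-wise kernel is dispatched on
// feature_size.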
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
    auto *var_data = var->mutable_data<U>(ctx.GetPlace());
    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
    auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

    switch (GetDesiredBlockDim(feature_size)) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<T, U,
                           kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
              x_data, scale_data, bias_data, y_data, mean_data, var_data,
              epsilon, feature_size));
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Product from begin_norm_axis to end must be larger than 1"));
        break;
    }
  }
};

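// Backward op kernel: collects the (possibly absent) gradient outputs and
// forwards everything to the LayerNormBackward dispatcher above.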
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();
    auto *mean_data = mean->data<U>();
    auto *var_data = var->data<U>();

    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
    auto *d_scale_data =
        (d_scale == nullptr ? nullptr
                            : d_scale->mutable_data<U>(ctx.GetPlace()));
    auto *d_bias_data =
        (d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
    auto *d_x_data =
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
                            d_x_data, d_scale_data, d_bias_data, epsilon,
                            batch_size, feature_size, ctx);
  }
};

template class LayerNormDirectCUDAFunctor<float>;

#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
                             plat::float16>);
#else
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
                             plat::float16>);
#endif