/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cub/cub.cuh>
#include <memory>
#include <vector>

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

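// Round a requested thread count down to the nearest power of two, capped at
// 512 threads per block (e.g. 300 -> 256, 1500 -> 512), so that kernels can be
// instantiated with a compile-time block dimension.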
inline static int GetDesiredBlockDim(int block_dim) {
  const int kMaxBlockDim = 512;
  return block_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << (static_cast<int>(std::log2f(block_dim))));
}

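// Expands to one switch case per power-of-two block dimension (2..512),
// binding the chosen value to the compile-time constant kBlockDim before
// executing the kernel-launch expression passed in __VA_ARGS__.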
#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)

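// Same idea as above, but for the column-wise backward kernels: besides fixing
// kBlockDim, each case tiles the feature_size columns into chunks of at most
// kMaxBlockNum blocks, exposing block_num and col_offset to __VA_ARGS__.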
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                             \
    log2_block_dim, feature_size, kMaxBlockNum, ...)                           \
  case (1 << (log2_block_dim)): {                                              \
    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
      int col_offset = i * kMaxBlockNum;                                       \
      int block_num = std::min(feature_size - col_offset, kMaxBlockNum);       \
      constexpr auto kBlockDim = (1 << (log2_block_dim));                      \
      __VA_ARGS__;                                                             \
    }                                                                          \
  } break

#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__)

static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }

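// Packs two partial sums into one value so a single cub::BlockReduce pass can
// reduce both at once (e.g. sum(x) and sum(x*x) in the forward kernel, or the
// d_scale / d_bias partials in the backward kernels).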
template <typename T>
struct PairForLayerNorm {
  __device__ __forceinline__ PairForLayerNorm() {}
  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
      : first_(first), second_(second) {}

  T first_;
  T second_;
};

template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
  }
};

template <typename T>
__inline__ __device__ T rsqrt(const T val) {
  return ::rsqrt(val);
}

template <>
__inline__ __device__ float rsqrt(const float val) {
  return rsqrtf(val);
}

template <>
__inline__ __device__ half rsqrt(const half val) {
  return hrsqrt(val);
}

template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                 T *y, U *mean, U *var, float epsilon,
                                 int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U mean_share;
  __shared__ U var_share;

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce to calculate mean and var
  U mean_val = 0;
  U var_val = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U tmp = static_cast<U>(x[i]);
    mean_val += tmp;
    var_val += (tmp * tmp);
  }
  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<U>(mean_val, var_val),
                          PairForLayerNormAddFunctor<U>());
  if (threadIdx.x == 0) {
    auto tmp = pair.first_ / feature_size;
    mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
    var[blockIdx.x] = var_share =
        static_cast<U>(pair.second_ / feature_size - tmp * tmp);
  }
  __syncthreads();

  mean_val = mean_share;
  U invvar = rsqrt<U>(var_share + static_cast<U>(epsilon));

  // Step 2: Calculate y
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(
            scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
                              invvar);
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
                              bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
      }
    }
  }
}

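// Helper for the part-gradient kernel below: each thread loads VPT consecutive
// elements of dout and input for one row and accumulates dout and
// dout * (x - mean) * invvar into the shared-memory staging buffers
// warp_buf1 / warp_buf2 (later reduced into the d_bias / d_scale partials).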
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
    const T *input, const T *dout, const int i1_end, const int n2,
    const U *__restrict__ mean, const U *__restrict__ var,
    const float epsilon) {
  const int i1 = i1_block + thr_load_row_off;
  if (i1 >= i1_end) return;
  U curr_mean = mean[i1];
  U curr_invvar = rsqrt<U>(var[i1] + epsilon);
  for (int k = 0; k < VPT; ++k) {
    const int i2 = i2_off + k;
    const int load_idx = i1 * n2 + i2;
    const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
    if (i2 < n2) {
      U curr_input = static_cast<U>(input[load_idx]);
      U curr_dout = static_cast<U>(dout[load_idx]);
      warp_buf1[write_idx] += curr_dout;
      warp_buf2[write_idx] +=
          curr_dout * (curr_input - curr_mean) * curr_invvar;
    }
  }
}

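// Stage 1 of the d_scale / d_bias reduction used when all gradients are
// requested: every block accumulates column-wise partial sums of dout (for
// beta) and dout * (x - mean) * invvar (for gamma) over a tile of rows,
// writing one row of part_grad_beta / part_grad_gamma per blockIdx.y.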
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
    float epsilon, U *part_grad_gamma, U *part_grad_beta) {
  // VPTX -> values per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y;
  // template parameters enable compile-time optimizations

  constexpr int row_stride = BDIMX + 1;
  const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
  const int thr_load_row_off =
      (threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
  const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;

  constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
                                 ? BDIMX * BDIMY
                                 : 2 * VPTX * BDIMY * row_stride;
  __shared__ U buf[shared_cap];

  U *warp_buf1 = reinterpret_cast<U *>(buf);
  U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;

  for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
       idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) {
    buf[idx] = U(0);
  }
  __syncthreads();

  for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
       i1_block += VPTX * BDIMY * gridDim.y) {
    cuLoadAddStridedInputs<T, U, VPTX>(
        i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
        warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon);
  }
  __syncthreads();

  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < VPTX; ++k) {
    int row1 = threadIdx.y + k * VPTX;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

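// Stage 2: reduce the part_size x n2 partial-gradient matrices produced above
// down to the final grad_gamma (d_scale) and grad_beta (d_bias) vectors.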
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardSumGradGammaBeta(
    const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
    // const int n1, const int n2, T* grad_gamma, T* grad_beta) {
    const int n1, const int n2, U *grad_gamma, U *grad_beta) {
  // sum partial gradients for gamma and beta
  __shared__ U buf[BDIMX * BDIMY];
  int i2 = blockIdx.x * BDIMX + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / BDIMY;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U *part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U *part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
    // inter-warp reductions
    constexpr int nbsize3 = BDIMX * BDIMY / 2;
    for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx + nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx + nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}

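// Computes d_x for one row per blockIdx.y using the usual layer_norm backward
// form: with invvar = 1 / sqrt(var + epsilon),
//   d_x = (invvar / n2) * (n2 * dy * gamma - sum(dy * gamma)
//         - (x - mean) * invvar * sum(dy * gamma * (x - mean) * invvar)),
// where gamma is taken as 1 when it is not provided.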
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardComputeGradInput(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2,
    // const U* __restrict__ mean, const U* __restrict__ var, const float
    // epsilon, const T* gamma,
    const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
    const U *gamma, T *grad_input) {
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = rsqrt<U>(var[i1] + epsilon);
    const T *k_input = input + i1 * n2;
    const T *k_dout = dout + i1 * n2;
    constexpr int numx = BDIMX * BDIMY;
    const int thrx = threadIdx.x + threadIdx.y * BDIMX;
    if (gamma != NULL) {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * gamma[l + k];
          sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
      sum_loss1 +=
          __shfl_xor_sync(0xffffffff, sum_loss1, mask,
                          warpSize);  // WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 +=
          __shfl_xor_sync(0xffffffff, sum_loss2, mask,
                          warpSize);  // WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (BDIMY > 1) {
      __shared__ U buf[BDIMX * BDIMY];
      for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T *k_grad_input = grad_input + i1 * n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}

// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
template <typename T, typename U, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                             U *d_scale, U *d_bias, T *d_x,
                                             const U *mean, const U *var,
                                             const U *scale, float epsilon,
                                             int batch_size, int feature_size,
                                             int col_offset) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
  int stride = BlockDim * feature_size;

  U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
    d_scale_partial += static_cast<U>(d_y[i]) *
                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
    d_bias_partial += static_cast<U>(d_y[i]);
    if (HasDx) {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                              scale[blockIdx.x + col_offset] / var_val);
    }
  }

  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<U>(d_scale_partial, d_bias_partial),
                          PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_scale[blockIdx.x + col_offset] = pair.first_;
    d_bias[blockIdx.x + col_offset] = pair.second_;
  }
}

// Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr
// Notice: scale may be nullptr
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
    const U *var, const U *scale, float epsilon, int batch_size,
    int feature_size, int col_offset) {
  using BlockReduce = cub::BlockReduce<U, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
  int stride = BlockDim * feature_size;
  U d_scale_or_d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
    if (HasDScale) {
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
                                   (static_cast<U>(x[i]) - mean[row_idx]) /
                                   var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                                scale[blockIdx.x + col_offset] / var_val);
      } else {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    } else {
      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    }
  }
}

template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
                                                          const U *mean,
                                                          const U *var,
                                                          float epsilon,
                                                          int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x];
  U block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
    d_x_reduce_tmp[1] =
        static_cast<float>(pair.second_) /
        (feature_size * (static_cast<float>(block_var) + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

// Here, we only calculate d_x
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                T *d_x, const U *mean,
                                                const U *var, const U *scale,
                                                float epsilon,
                                                int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
    if (scale != nullptr) {
      int col_idx = i % feature_size;
      d_x[i] =
          static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
    } else {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
    }
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
    d_x_reduce_tmp[1] =
        static_cast<float>(pair.second_) /
        (feature_size * (static_cast<float>(block_var) + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

template <typename T, typename U>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
    const U *var, const U *scale, float epsilon, int feature_size) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < feature_size) {
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
    if (d_x != nullptr) {
      if (d_scale == nullptr) {
        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
      } else {
        d_x[idx] =
            static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
      }
    }

    if (d_scale != nullptr) {
      d_scale[idx] = static_cast<U>(d_y[idx]) *
                     (static_cast<U>(x[idx]) - mean[idx]) / var_val;
    }

    if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
  }
}

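// Dispatcher for the backward pass. gradient_flag encodes which outputs are
// requested: bit 2 = d_x, bit 1 = d_scale, bit 0 = d_bias. batch_size == 1 is
// handled by a dedicated elementwise kernel; cases 1-6 use the column-wise
// kernels above; case 7 (all three gradients) uses the three-kernel
// part/sum/grad-input pipeline.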
template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
                              const U *mean, const U *var, T *d_x, U *d_scale,
                              U *d_bias, float epsilon, int batch_size,
                              int feature_size,
                              const framework::ExecutionContext &ctx) {
  auto &dev_ctx = ctx.cuda_device_context();
  auto stream = dev_ctx.stream();

  const int kMaxBlockDim = 512;
  const int kMaxBlockNum = 128;
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;

  if (batch_size == 1) {
    LayerNormBackwardWhenBatchSizeIsOne<
        T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
                0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
                             epsilon, feature_size);

    if (d_x != nullptr) {
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
                             T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>(
            x, d_x, mean, var, epsilon, feature_size));
      }
    }
    return;
  }

  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
    {
      constexpr int VPT = 4;
      constexpr int BDIMX2 = 32;
      constexpr int BDIMY2 = 4;
      dim3 threads2(BDIMX2, BDIMY2, 1);
      constexpr int part_size = BDIMY2 * VPT;
      const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);

      auto part_grad_gamma_ptr =
          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
      auto part_grad_beta_ptr =
          memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
      U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
      U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());

      LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2,
                                         VPT><<<blocks2, threads2, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma,
          part_grad_beta);  // compute part_grad_gamma, beta

      constexpr int BDIMX3 = 32;
      constexpr int BDIMY3 = 8;
      dim3 threads3(BDIMX3, BDIMY3, 1);
      const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
      LayerNormBackwardSumGradGammaBeta<
          T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
          part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
          d_scale, d_bias);

      constexpr int BDIMX1 = 32;
      constexpr int BDIMY1 = 4;
      dim3 threads1(BDIMX1, BDIMY1, 1);
      const dim3 blocks1(1, batch_size, 1);
      LayerNormBackwardComputeGradInput<
          T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
      break;
    }
    default:
      break;
  }
}

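// Launches the forward kernel directly from raw device pointers and a raw CUDA
// stream (no ExecutionContext), using T for both the data and the statistics
// type; the block size is chosen from feature_size as in the op kernel.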
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                               const T *input,
                                               std::vector<int> input_shape,
                                               const T *bias, const T *scale,
                                               T *output, T *mean, T *variance,
                                               int begin_norm_axis, float eps) {
  const auto x_dims = framework::make_ddim(input_shape);
  auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
  int batch_size = static_cast<int>(matrix_dim[0]);
  int feature_size = static_cast<int>(matrix_dim[1]);
  switch (GetDesiredBlockDim(feature_size)) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, feature_size));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<LayerNormParamType<T>>(ctx.GetPlace());
    auto *var_data = var->mutable_data<LayerNormParamType<T>>(ctx.GetPlace());
    auto *scale_data =
        (scale == nullptr ? nullptr : scale->data<LayerNormParamType<T>>());
    auto *bias_data =
        (bias == nullptr ? nullptr : bias->data<LayerNormParamType<T>>());

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

    switch (GetDesiredBlockDim(feature_size)) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<T, LayerNormParamType<T>,
                           kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
              x_data, scale_data, bias_data, y_data, mean_data, var_data,
              epsilon, feature_size));
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Product from begin_norm_axis to end must be larger than 1"));
        break;
    }
  }
};

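// Backward op kernel: d_x, d_scale and d_bias are all optional; whichever are
// present are forwarded to LayerNormBackward, which picks the matching kernel
// combination.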
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();
    auto *mean_data = mean->data<U>();
    auto *var_data = var->data<U>();

    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
    auto *d_scale_data =
        (d_scale == nullptr ? nullptr
                            : d_scale->mutable_data<U>(ctx.GetPlace()));
    auto *d_bias_data =
        (d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
    auto *d_x_data =
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
                            d_x_data, d_scale_data, d_bias_data, epsilon,
                            batch_size, feature_size, ctx);
  }
};

template class LayerNormDirectCUDAFunctor<float>;

#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
                             plat::float16>);