layer_norm_op.cu 37.2 KB
Newer Older
S
sneaxiy 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

S
sneaxiy 已提交
15
#include <cub/cub.cuh>
P
Pei Yang 已提交
16 17
#include <memory>
#include <vector>
F
furnace 已提交
18

P
Pei Yang 已提交
19
#include "paddle/fluid/framework/ddim.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/operators/layer_norm_op.h"
F
furnace 已提交
21 22
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
C
chengduoZH 已提交
23

S
sneaxiy 已提交
24 25 26
namespace paddle {
namespace operators {

F
furnace 已提交
27 28 29 30 31 32 33
// Short local names for framework types used in this file.
typedef framework::Tensor Tensor;
typedef framework::DataLayout DataLayout;

// Alias for the platform cuDNN type-traits template.
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

// The BatchNormParamType associated with element type T by CudnnDataType.
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

S
sneaxiy 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
// Returns the largest power of two that is <= block_dim, capped at 512 (the
// largest block size handled by FIXED_BLOCK_DIM_CASE).
//
// Uses integer arithmetic only. The previous log2f-based version cast
// log2f(0) == -inf (or NaN for negative input) to int, which is undefined
// behavior; inputs below 1 now map to 1.
// NOTE(review): a result of 1 has no matching label in FIXED_BLOCK_DIM_CASE
// (its smallest label is 2), so a switch on it launches nothing.
inline static int GetDesiredBlockDim(int block_dim) {
  const int kMaxBlockDim = 512;
  if (block_dim >= kMaxBlockDim) return kMaxBlockDim;
  int desired = 1;
  while ((desired << 1) <= block_dim) desired <<= 1;
  return desired;
}

// One `case` label of a switch on a power-of-two block size. It matches
// 2^log2_dim, exposes that value to the statements in __VA_ARGS__ as the
// compile-time constant kBlockDim, runs them, and leaves the switch. The
// trailing `break` deliberately has no semicolon; the caller supplies it.
#define FIXED_BLOCK_DIM_CASE_BASE(log2_dim, ...)  \
  case (1 << (log2_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_dim)); \
    __VA_ARGS__;                                  \
  } break

// Every FIXED_BLOCK_DIM_CASE_BASE label from 2 (2^1) up to 512 (2^9).
// A block size of 1 has no label.
#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__)

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
// Like FIXED_BLOCK_DIM_CASE_BASE, but additionally splits `feature_size`
// columns into consecutive chunks of at most `kMaxBlockNum` columns. It
// expands __VA_ARGS__ once per chunk, with these names in scope:
//   kBlockDim  - the matched block size (compile-time constant),
//   col_offset - index of the chunk's first column,
//   block_num  - number of columns in the chunk.
// The chunk count is an integer ceil-division evaluated once, instead of a
// floating-point std::ceil re-evaluated on every iteration. The loop variable
// keeps the name `i`, as before.
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                          \
    log2_block_dim, feature_size, kMaxBlockNum, ...)                        \
  case (1 << (log2_block_dim)): {                                           \
    const int fixed_block_num_chunks_ =                                     \
        ((feature_size) + (kMaxBlockNum) - 1) / (kMaxBlockNum);             \
    for (int i = 0; i < fixed_block_num_chunks_; i++) {                     \
      int col_offset = i * (kMaxBlockNum);                                  \
      int block_num = std::min((feature_size) - col_offset, (kMaxBlockNum)); \
      constexpr auto kBlockDim = (1 << (log2_block_dim));                   \
      __VA_ARGS__;                                                          \
    }                                                                       \
  } break

// All FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE labels for block sizes 512
// (2^9) down to 2 (2^1); a block size of 1 has no label.
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__)

89 90 91
// Device square root overloaded on precision: single precision uses sqrtf,
// double precision uses sqrt.
static __device__ __forceinline__ float real_sqrt(float x) {
  return sqrtf(x);
}
static __device__ __forceinline__ double real_sqrt(double x) {
  return sqrt(x);
}

92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
// A plain two-value aggregate (device-only) that lets a single
// cub::BlockReduce call reduce two running sums at once.
template <typename T>
struct PairForLayerNorm {
  T first_;
  T second_;

  // Leaves both members uninitialized.
  __device__ __forceinline__ PairForLayerNorm() {}

  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
      : first_(first), second_(second) {}
};

// Reduction operator for PairForLayerNorm: adds the two pairs element-wise.
template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &lhs, const PairForLayerNorm<T> &rhs) {
    const T first_sum = lhs.first_ + rhs.first_;
    const T second_sum = lhs.second_ + rhs.second_;
    return PairForLayerNorm<T>(first_sum, second_sum);
  }
};

L
Leo Chen 已提交
110
// rsqrt_<T>(val) == 1 / sqrt(val).
// The generic version divides by sqrt(). The float and double
// specializations map to the CUDA reciprocal-square-root functions, and half
// maps to hrsqrt when CUDA_ARCH_FP16_SUPPORTED holds for the target arch.
template <typename T>
__inline__ __device__ T rsqrt_(const T val) {
  const T one = static_cast<T>(1);
  return one / sqrt(val);
}

template <>
__inline__ __device__ float rsqrt_(const float val) {
  return rsqrtf(val);
}

template <>
__inline__ __device__ double rsqrt_(const double val) {
  return rsqrt(val);
}

#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__inline__ __device__ half rsqrt_(const half val) {
  return hrsqrt(val);
}
#endif
L
Leo Chen 已提交
131

F
furnace 已提交
132 133 134
// Layer norm forward pass. One thread block handles one row of
// `feature_size` contiguous elements (row index = blockIdx.x). BlockDim is
// both the cub::BlockReduce size and the loop stride, so it must equal
// blockDim.x.
//
// Per row, it writes:
//   mean[row] = E[x]
//   var[row]  = E[x^2] - E[x]^2, clamped at 0
//   y         = (x - mean) * rsqrt(var + epsilon) [* scale[col]] [+ bias[col]]
// A nullptr scale or bias means that term is skipped.
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                 T *y, U *mean, U *var, float epsilon,
                                 int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U mean_share;
  __shared__ U var_share;

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce sum(x) and sum(x^2) to get mean and var
  U mean_val = 0;
  U var_val = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U tmp = static_cast<U>(x[i]);
    mean_val += tmp;
    var_val += (tmp * tmp);
  }
  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<U>(mean_val, var_val),
                          PairForLayerNormAddFunctor<U>());
  // The cub::BlockReduce aggregate is only valid in thread 0. It publishes
  // the statistics to the rest of the block through shared memory.
  if (threadIdx.x == 0) {
    auto tmp = pair.first_ / feature_size;
    U row_var = static_cast<U>(pair.second_ / feature_size - tmp * tmp);
    // E[x^2] - E[x]^2 can come out slightly negative from floating-point
    // cancellation. When |var| > epsilon, rsqrt(var + epsilon) would then be
    // NaN, so clamp the variance at zero.
    if (row_var < static_cast<U>(0)) row_var = static_cast<U>(0);
    mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
    var[blockIdx.x] = var_share = row_var;
  }
  __syncthreads();

  mean_val = mean_share;
  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));

  // Step 2: Calculate y. scale/bias are checked once per block, outside the
  // element loops.
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(
            scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
                              invvar);
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
                              bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
      }
    }
  }
}

// Accumulates one row's contribution into the shared-memory tiles used by
// LayerNormBackwardPartGradGammaBeta. The row is i1_block + thr_load_row_off
// and must be < i1_end. The columns are VPT consecutive ones starting at
// i2_off; columns >= n2 are skipped. For each such element:
//   warp_buf1[tile] += dout
//   warp_buf2[tile] += dout * (input - mean[row]) * rsqrt(var[row] + epsilon)
// The tile slot is thr_load_row_off * row_stride + thr_load_col_off + k.
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
    const T *input, const T *dout, const int i1_end, const int n2,
    const U *__restrict__ mean, const U *__restrict__ var,
    const float epsilon) {
  const int row = i1_block + thr_load_row_off;
  if (row < i1_end) {
    const U row_mean = mean[row];
    const U row_invvar = rsqrt_<U>(var[row] + epsilon);
    const int load_base = row * n2 + i2_off;
    const int write_base = thr_load_row_off * row_stride + thr_load_col_off;
    for (int k = 0; k < VPT; ++k) {
      if (i2_off + k < n2) {
        const U in_val = static_cast<U>(input[load_base + k]);
        const U dout_val = static_cast<U>(dout[load_base + k]);
        warp_buf1[write_base + k] += dout_val;
        warp_buf2[write_base + k] +=
            dout_val * (in_val - row_mean) * row_invvar;
      }
    }
  }
}

// First stage of the scale (gamma) / bias (beta) gradient computation.
//
// Each block covers the BDIMX columns starting at blockIdx.x * BDIMX. It
// walks row groups beginning at blockIdx.y * BDIMY * VPTX, stepping by
// VPTX * BDIMY * gridDim.y. It writes one partial sum per column i2 < n2:
//   part_grad_beta[blockIdx.y * n2 + i2]  = sum of dout over the visited rows
//   part_grad_gamma[blockIdx.y * n2 + i2] = sum of
//       dout * (input - mean) * rsqrt(var + epsilon) over the visited rows
// (the per-element terms come from cuLoadAddStridedInputs).
// NOTE(review): the partials are presumably summed over blockIdx.y by
// LayerNormBackwardSumGradGammaBeta - confirm at the launch site.
//
// Preconditions implied by the index math (not checked):
//   - blockDim.x == BDIMX and blockDim.y == BDIMY (the code mixes the
//     template constants and blockDim).
//   - BDIMX is a power of two: `& (BDIMX - 1)` is used as `% BDIMX`.
//   - BDIMY == VPTX appears to be required. The load rows are
//     (tx*VPTX)/BDIMX + ty*BDIMY, and the reductions read rows ty + k*VPTX;
//     these only cover the same VPTX*BDIMY rows when BDIMY == VPTX.
//   - VPTX >= 2: the final step adds rows 0 and 1.
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
    float epsilon, U *part_grad_gamma, U *part_grad_beta) {
  // VPTX -> values per thread in x, BDIMX -> blockDim.x, BDIMY -> blockDim.y.
  // These are template parameters for compile-time optimization.

  // Tiles are padded by one column, presumably to avoid shared-memory bank
  // conflicts when reading down a column.
  constexpr int row_stride = BDIMX + 1;
  // Each thread loads VPTX consecutive columns of one tile row.
  const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
  const int thr_load_row_off =
      (threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
  const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;

  // Two tiles of VPTX * BDIMY padded rows each: warp_buf1 (dout sums) and
  // warp_buf2 (dout * normalized-x sums).
  constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
                                 ? BDIMX * BDIMY
                                 : 2 * VPTX * BDIMY * row_stride;
  __shared__ U buf[shared_cap];

  U *warp_buf1 = reinterpret_cast<U *>(buf);
  U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;

  // Zero both tiles, because cuLoadAddStridedInputs only accumulates (+=).
  for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
       idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) {
    buf[idx] = U(0);
  }
  __syncthreads();

  // Accumulate every row group assigned to this blockIdx.y.
  for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
       i1_block += VPTX * BDIMY * gridDim.y) {
    cuLoadAddStridedInputs<T, U, VPTX>(
        i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
        warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon);
  }
  __syncthreads();

  // inter-warp reductions
  // sum within each warp: thread (ty, tx) folds tile rows ty + k*VPTX of
  // column tx into row ty.
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < VPTX; ++k) {
    int row1 = threadIdx.y + k * VPTX;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps: tree-halve the rows until two remain (rows 0 and 1).
  for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  // Final step: rows 0 + 1 give this block's partial sum for column i2.
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

// Second stage of the scale (gamma) / bias (beta) gradient computation.
// Sums `part_size` rows of partial gradients (each row has n2 columns)
// into grad_gamma[i2] and grad_beta[i2].
//
// Launch layout: blockDim must be (BDIMX, BDIMY). Each block covers the
// BDIMX columns starting at blockIdx.x * BDIMX.
//   1. Thread row ty sequentially sums its own contiguous slice of
//      part_size / BDIMY partial rows.
//   2. A shared-memory tree reduction combines the BDIMY per-thread-row sums.
// NOTE(review): part_size is assumed to be a multiple of BDIMY; any remainder
// rows are ignored. n1 is unused.
//
// Threads whose column is past n2 still run the reduction loop (with zero
// sums) so that every thread of the block reaches every __syncthreads().
// Previously they skipped it, which is undefined behavior whenever n2 is not
// a multiple of BDIMX. Out-of-range threads only touch their own buf
// columns and never read or write global memory.
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardSumGradGammaBeta(
    const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
    const int n1, const int n2, U *grad_gamma, U *grad_beta) {
  __shared__ U buf[BDIMX * BDIMY];
  const int i2 = blockIdx.x * BDIMX + threadIdx.x;
  const bool in_range = i2 < n2;

  // each thread row does sequential reductions over its own slice
  const int num_warp_reductions = part_size / BDIMY;
  U sum_gamma = U(0);
  U sum_beta = U(0);
  if (in_range) {
    const U *part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U *part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
  }

  // inter-warp reductions; executed by every thread of the block
  constexpr int nbsize3 = BDIMX * BDIMY / 2;
  for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
    // top half write to shared memory
    if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
      const int write_idx = (threadIdx.y - offset) * BDIMX + threadIdx.x;
      buf[write_idx] = sum_gamma;
      buf[write_idx + nbsize3] = sum_beta;
    }
    __syncthreads();
    // bottom half sums
    if (threadIdx.y < offset) {
      const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
      sum_gamma += buf[read_idx];
      sum_beta += buf[read_idx + nbsize3];
    }
    __syncthreads();
  }

  // write out fully summed gradients
  if (threadIdx.y == 0 && in_range) {
    grad_gamma[i2] = sum_gamma;
    grad_beta[i2] = sum_beta;
  }
}

// Computes the layer-norm input gradient for rows of n2 elements. The block
// must be BDIMX x BDIMY threads. Each block handles rows blockIdx.y,
// blockIdx.y + gridDim.y, ... (< n1).
//
// For each row, with xhat = (input - mean) * invvar, invvar =
// rsqrt(var + epsilon), and g = dout * gamma (g = dout if gamma is nullptr):
//   sum_loss1 = sum(g), sum_loss2 = sum(g * xhat)
//   grad_input = invvar / n2 * (n2 * g - sum_loss1 - xhat * sum_loss2)
//
// Reductions: __shfl_xor_sync with a full mask across BDIMX lanes, then
// shared memory across the BDIMY thread rows.
// NOTE(review): the full 0xffffffff mask assumes all 32 lanes of every warp
// are active, i.e. BDIMX * BDIMY is a multiple of 32.
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardComputeGradInput(
    const T *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2, const U *__restrict__ mean, const U *__restrict__ var,
    const float epsilon, const U *gamma, T *grad_input) {
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
    const T *k_input = input + i1 * n2;
    const T *k_dout = dout + i1 * n2;
    constexpr int numx = BDIMX * BDIMY;
    const int thrx = threadIdx.x + threadIdx.y * BDIMX;
    // Each thread takes groups of 4 consecutive elements, strided by 4*numx.
    // The trailing loop picks up the last (< 4) elements of the row.
    if (gamma != nullptr) {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * gamma[l + k];
          sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions (butterfly over the BDIMX lanes of a thread row)
    for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
      sum_loss1 += __shfl_xor_sync(0xffffffff, sum_loss1, mask, warpSize);
      sum_loss2 += __shfl_xor_sync(0xffffffff, sum_loss2, mask, warpSize);
    }
    // inter-warp reductions. BDIMY is a compile-time constant, so every
    // thread takes this branch and the barriers inside are block-uniform.
    if (BDIMY > 1) {
      __shared__ U buf[BDIMX * BDIMY];
      for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * BDIMX + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      // broadcast thread row 0's totals to the other thread rows
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
      // Keep the next row's first buf writes from overwriting the broadcast
      // slots before every thread row has read them (WAR race across i1).
      __syncthreads();
    }
    // all threads now have the two sums over l
    U fH = static_cast<U>(n2);
    U term1 = (U(1) / fH) * c_invvar;
    T *k_grad_input = grad_input + i1 * n2;
    if (gamma != nullptr) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}

// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
//
// One thread block per feature column (column = blockIdx.x + col_offset).
// The BlockDim threads walk the batch rows of that column with stride
// BlockDim rows. The block reduces
//   d_scale[column] = sum(d_y * (x - mean[row]) / sqrt(var[row] + epsilon))
//   d_bias[column]  = sum(d_y)
// When HasDx is set, it also stores d_x = d_y * scale[column] / std. The
// remaining d_x terms are not computed here.
template <typename T, typename U, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                             U *d_scale, U *d_bias, T *d_x,
                                             const U *mean, const U *var,
                                             const U *scale, float epsilon,
                                             int batch_size, int feature_size,
                                             int col_offset) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  const int col = blockIdx.x + col_offset;
  const int stride = BlockDim * feature_size;
  const int end_idx = batch_size * feature_size + col;

  U scale_acc = static_cast<U>(0);
  U bias_acc = static_cast<U>(0);
  for (int i = threadIdx.x * feature_size + col; i < end_idx; i += stride) {
    const int row = i / feature_size;
    const auto std_dev = real_sqrt(static_cast<U>(var[row]) + epsilon);
    const U dy = static_cast<U>(d_y[i]);
    scale_acc += dy * (static_cast<U>(x[i]) - mean[row]) / std_dev;
    bias_acc += dy;
    if (HasDx) {
      d_x[i] = static_cast<T>(dy * scale[col] / std_dev);
    }
  }

  const PairForLayerNorm<U> total =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(scale_acc, bias_acc),
                  PairForLayerNormAddFunctor<U>());

  // Only thread 0 holds the valid block aggregate.
  if (threadIdx.x == 0) {
    d_scale[col] = total.first_;
    d_bias[col] = total.second_;
  }
}

// Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr
// Notice: scale may be nullptr
//
// One thread block per feature column (column = blockIdx.x + col_offset).
// The BlockDim threads walk the batch rows with stride BlockDim rows and
// reduce one of:
//   HasDScale:  d_scale[column] = sum(d_y * (x - mean[row]) / std)
//   otherwise:  d_bias[column]  = sum(d_y)
// where std = sqrt(var[row] + epsilon). When HasDx is set, it also stores
// d_x = d_y * scale[column] / std, or d_y / std when scale is nullptr.
//
// std is computed in U, matching LayerNormBackwardGradientAll. Previously it
// went through a float cast, which dropped precision when U is double.
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
    const U *var, const U *scale, float epsilon, int batch_size,
    int feature_size, int col_offset) {
  using BlockReduce = cub::BlockReduce<U, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
  int stride = BlockDim * feature_size;
  U d_scale_or_d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    U var_val = real_sqrt(var[row_idx] + static_cast<U>(epsilon));
    if (HasDScale) {
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
                                   (static_cast<U>(x[i]) - mean[row_idx]) /
                                   var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                                scale[blockIdx.x + col_offset] / var_val);
      } else {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  // Only thread 0 holds the valid block aggregate.
  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    } else {
      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    }
  }
}

F
furnace 已提交
548
// Finishes the layer-norm input gradient in place. d_x must already hold the
// partial gradient g / std, where g = d_y * scale (or d_y) and
// std = sqrt(var + epsilon); the batch-size-1 path in LayerNormBackward
// launches this after LayerNormBackwardWhenBatchSizeIsOne.
//
// One thread block per row (row = blockIdx.x); BlockDim threads stride over
// the row's feature_size elements. It applies
//   d_x -= mean(d_x) + (x - mean) * sum(d_x * (x - mean))
//                                 / (feature_size * (var + epsilon))
//
// Intermediate values stay in U, and each element is rounded to T once.
// Previously the reductions went through float, losing precision when U is
// double, and half-precision d_x was rounded twice.
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
                                                          const U *mean,
                                                          const U *var,
                                                          float epsilon,
                                                          int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x];
  U block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U d_x_val = static_cast<U>(d_x[i]);
    d_x_mean_partial += d_x_val;
    d_x_var_partial += d_x_val * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  // Thread 0 holds the aggregate; share the two correction factors.
  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = pair.first_ / feature_size;
    d_x_reduce_tmp[1] =
        pair.second_ /
        (feature_size * (block_var + static_cast<U>(epsilon)));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U d_x_val = static_cast<U>(d_x[i]);
    d_x_val -= d_x_mean_partial;
    d_x_val -= (static_cast<U>(x[i]) - block_mean) * d_x_var_partial;
    d_x[i] = static_cast<T>(d_x_val);
  }
}

S
sneaxiy 已提交
592
// Here, we only calculate d_x
//
// One thread block per row (row = blockIdx.x); BlockDim threads stride over
// the row's feature_size elements. With std = sqrt(var + epsilon):
//   pass 1: d_x = d_y * scale[col] / std (d_y / std when scale is nullptr),
//           while block-reducing sum(d_x) and sum(d_x * (x - mean));
//   pass 2: d_x -= mean(d_x) + (x - mean) * sum(d_x * (x - mean))
//                                        / (feature_size * (var + epsilon))
//
// std depends only on the row, so it is computed once, in U. The previous
// per-element float cast lost precision when U is double. Pass 1 reduces
// the unrounded U values, and pass 2 rounds each element to T once.
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                T *d_x, const U *mean,
                                                const U *var, const U *scale,
                                                float epsilon,
                                                int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
  U var_val = real_sqrt(block_var + static_cast<U>(epsilon));
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U d_x_val;
    if (scale != nullptr) {
      int col_idx = i % feature_size;
      d_x_val = static_cast<U>(d_y[i]) * scale[col_idx] / var_val;
    } else {
      d_x_val = static_cast<U>(d_y[i]) / var_val;
    }
    d_x[i] = static_cast<T>(d_x_val);
    d_x_mean_partial += d_x_val;
    d_x_var_partial += d_x_val * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  // Thread 0 holds the aggregate; share the two correction factors.
  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = pair.first_ / feature_size;
    d_x_reduce_tmp[1] =
        pair.second_ /
        (feature_size * (block_var + static_cast<U>(epsilon)));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U d_x_val = static_cast<U>(d_x[i]);
    d_x_val -= d_x_mean_partial;
    d_x_val -= (static_cast<U>(x[i]) - block_mean) * d_x_var_partial;
    d_x[i] = static_cast<T>(d_x_val);
  }
}

F
furnace 已提交
645
// Backward pass for a single row (batch_size == 1). There is one thread per
// feature (idx < feature_size) and a 1-D launch. With std = sqrt(var[0] + epsilon):
//   d_x[idx]     = d_y * scale[idx] / std  (d_y / std when scale is nullptr);
//                  only the partial gradient, finished by
//                  LayerNormBackwardPostProcessToCalculateDX
//   d_scale[idx] = d_y * (x - mean[0]) / std
//   d_bias[idx]  = d_y
// Any of d_x / d_scale / d_bias may be nullptr, in which case it is skipped.
//
// Fixes relative to the previous version:
//   - mean/var hold one value per row, so with one row only mean[0]/var[0]
//     exist. Indexing them by the feature index read out of bounds.
//   - d_x now applies scale whenever scale is present, as the sibling
//     kernels do. It previously ignored scale whenever d_scale was nullptr.
//   - std is computed in U rather than through float.
template <typename T, typename U>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
    const U *var, const U *scale, float epsilon, int feature_size) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < feature_size) {
    const U mean_val = mean[0];
    const U var_val = real_sqrt(var[0] + static_cast<U>(epsilon));
    if (d_x != nullptr) {
      if (scale == nullptr) {
        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
      } else {
        d_x[idx] =
            static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
      }
    }

    if (d_scale != nullptr) {
      d_scale[idx] = static_cast<U>(d_y[idx]) *
                     (static_cast<U>(x[idx]) - mean_val) / var_val;
    }

    if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
  }
}

F
furnace 已提交
671 672 673 674
// Enqueues the CUDA kernels that compute the layer_norm gradients.
//
// x, d_y and d_x are viewed as row-major [batch_size, feature_size] matrices
// (the callers flatten the input at begin_norm_axis). mean and var are the
// statistics saved by the forward pass, presumably one value per row (TODO
// confirm against LayerNormForward). scale, d_scale and d_bias have
// feature_size elements.
//
// Each of d_x, d_scale and d_bias may be nullptr, in which case that gradient
// is skipped; if all three are nullptr nothing is launched.
//
// All work is enqueued asynchronously on the stream of ctx's CUDA device
// context; this function does not synchronize.
template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
                              const U *mean, const U *var, T *d_x, U *d_scale,
                              U *d_bias, float epsilon, int batch_size,
                              int feature_size,
                              const framework::ExecutionContext &ctx) {
  auto &dev_ctx = ctx.cuda_device_context();
  auto stream = dev_ctx.stream();

  const int kMaxBlockDim = 512;
  const int kMaxBlockNum = 128;
  // Bit 2: d_x requested, bit 1: d_scale requested, bit 0: d_bias requested.
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;

  if (batch_size == 1) {
    // One thread per feature element. The d_x produced here is only the first
    // stage and is completed by the post-process kernel below.
    LayerNormBackwardWhenBatchSizeIsOne<
        T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
                0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
                             epsilon, feature_size);

    if (d_x != nullptr) {
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
                             T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>(
            x, d_x, mean, var, epsilon, feature_size));
      }
    }
    return;
  }

  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 7: {  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
      constexpr int VPT = 4;
      constexpr int BDIMX2 = 32;
      constexpr int BDIMY2 = 4;
      dim3 threads2(BDIMX2, BDIMY2, 1);
      constexpr int part_size = BDIMY2 * VPT;
      const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);

      // Scratch buffers for part_size partial sums per feature column. The
      // byte count is computed in size_t so large feature sizes cannot
      // overflow int before the multiplication by sizeof(U).
      const size_t part_bytes = static_cast<size_t>(part_size) *
                                static_cast<size_t>(feature_size) * sizeof(U);
      auto part_grad_gamma_ptr = memory::Alloc(dev_ctx, part_bytes);
      auto part_grad_beta_ptr = memory::Alloc(dev_ctx, part_bytes);
      U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
      U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());

      LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2,
                                         VPT><<<blocks2, threads2, 0, stream>>>(
          d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma,
          part_grad_beta);  // compute part_grad_gamma, beta

      constexpr int BDIMX3 = 32;
      constexpr int BDIMY3 = 8;
      dim3 threads3(BDIMX3, BDIMY3, 1);
      // Size the grid with the x-width of the blocks actually launched here
      // (BDIMX3). This was previously BDIMX2, which only worked because both
      // constants are 32.
      const dim3 blocks3((feature_size + BDIMX3 - 1) / BDIMX3, 1, 1);
      LayerNormBackwardSumGradGammaBeta<
          T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
          part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
          d_scale, d_bias);

      constexpr int BDIMX1 = 32;
      constexpr int BDIMY1 = 4;
      dim3 threads1(BDIMX1, BDIMY1, 1);
      // Rows are spread over the grid's y-dimension, which CUDA limits to
      // 65535. Larger batches are therefore processed in chunks of rows, with
      // the row-major data pointers and per-row statistics offset to the
      // chunk start. For batch_size <= 65535 this is a single launch identical
      // to launching over all rows at once.
      constexpr int kMaxGridDimY = 65535;
      for (int64_t row_begin = 0; row_begin < batch_size;
           row_begin += kMaxGridDimY) {
        const int64_t rows_left = batch_size - row_begin;
        const int rows = static_cast<int>(
            rows_left < kMaxGridDimY ? rows_left : kMaxGridDimY);
        const int64_t elem_offset = row_begin * feature_size;
        const dim3 blocks1(1, rows, 1);
        LayerNormBackwardComputeGradInput<
            T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
            d_y + elem_offset, x + elem_offset, rows, feature_size,
            mean + row_begin, var + row_begin, epsilon, scale,
            d_x + elem_offset);
      }
      break;
    }
    default:
      break;
  }
}

// Runs the layer_norm forward pass on raw device pointers, outside the
// operator framework.
//
// The shape given in input_shape is collapsed at begin_norm_axis into a
// [rows, cols] matrix. LayerNormForward is then launched on `stream` with one
// block per row, and the block width is chosen from cols. If cols is too small
// for any supported block width, an InvalidArgument error is thrown.
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                               const T *input,
                                               std::vector<int> input_shape,
                                               const T *bias, const T *scale,
                                               T *output, T *mean, T *variance,
                                               int begin_norm_axis, float eps) {
  const auto in_dims = framework::make_ddim(input_shape);
  auto flat_dims = framework::flatten_to_2d(in_dims, begin_norm_axis);
  const int rows = static_cast<int>(flat_dims[0]);
  const int cols = static_cast<int>(flat_dims[1]);

  const int block_width = GetDesiredBlockDim(cols);
  switch (block_width) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<rows, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, cols));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

// CUDA forward kernel for the layer_norm operator.
//
// Flattens X at begin_norm_axis into a [batch_size, feature_size] matrix and
// launches LayerNormForward with one block per row. The kernel writes Y, Mean
// and Variance. Scale and Bias are optional inputs and are passed to the
// kernel as nullptr when absent.
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    const Tensor *scale_t = ctx.Input<Tensor>("Scale");
    const Tensor *bias_t = ctx.Input<Tensor>("Bias");
    const Tensor *x_t = ctx.Input<Tensor>("X");

    Tensor *y_t = ctx.Output<Tensor>("Y");
    Tensor *mean_t = ctx.Output<Tensor>("Mean");
    Tensor *var_t = ctx.Output<Tensor>("Variance");
    const int begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    // Resolve device pointers; outputs are allocated on the kernel's place.
    const T *x_data = x_t->data<T>();
    T *y_data = y_t->mutable_data<T>(ctx.GetPlace());
    U *mean_data = mean_t->mutable_data<U>(ctx.GetPlace());
    U *var_data = var_t->mutable_data<U>(ctx.GetPlace());
    const U *scale_data = scale_t ? scale_t->data<U>() : nullptr;
    const U *bias_data = bias_t ? bias_t->data<U>() : nullptr;

    auto flat = framework::flatten_to_2d(x_t->dims(), begin_norm_axis);
    const int batch_size = static_cast<int>(flat[0]);
    const int feature_size = static_cast<int>(flat[1]);

    auto stream = ctx.cuda_device_context().stream();

    const int block_width = GetDesiredBlockDim(feature_size);
    switch (block_width) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<T, U,
                           kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
              x_data, scale_data, bias_data, y_data, mean_data, var_data,
              epsilon, feature_size));
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Product from begin_norm_axis to end must be larger than 1"));
        break;
    }
  }
};

// CUDA backward kernel for the layer_norm operator.
//
// Gathers X, Mean, Variance, the optional Scale and Y@GRAD. It then flattens
// X at begin_norm_axis into [batch_size, feature_size] and forwards
// everything to LayerNormBackward. X@GRAD, Scale@GRAD and Bias@GRAD are each
// optional: an absent output is passed down as nullptr and is not computed.
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");

    // Requested gradients; any of these may be nullptr.
    Tensor *dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
    Tensor *dscale_t = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    Tensor *dbias_t = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // Forward input, saved statistics, optional scale and incoming gradient.
    const Tensor *x_t = ctx.Input<Tensor>("X");
    const Tensor *mean_t = ctx.Input<Tensor>("Mean");
    const Tensor *var_t = ctx.Input<Tensor>("Variance");
    const Tensor *scale_t = ctx.Input<Tensor>("Scale");
    const Tensor *dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));

    const T *x_data = x_t->data<T>();
    const T *dy_data = dy_t->data<T>();
    const U *mean_data = mean_t->data<U>();
    const U *var_data = var_t->data<U>();
    const U *scale_data = scale_t ? scale_t->data<U>() : nullptr;

    U *dscale_data =
        dscale_t ? dscale_t->mutable_data<U>(ctx.GetPlace()) : nullptr;
    U *dbias_data = dbias_t ? dbias_t->mutable_data<U>(ctx.GetPlace()) : nullptr;
    T *dx_data = dx_t ? dx_t->mutable_data<T>(ctx.GetPlace()) : nullptr;

    const int begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto flat = framework::flatten_to_2d(x_t->dims(), begin_norm_axis);
    const int batch_size = static_cast<int>(flat[0]);
    const int feature_size = static_cast<int>(flat[1]);

    LayerNormBackward<T, U>(x_data, dy_data, scale_data, mean_data, var_data,
                            dx_data, dscale_data, dbias_data, epsilon,
                            batch_size, feature_size, ctx);
  }
};

// The direct (framework-free) functor is only instantiated for float.
template class LayerNormDirectCUDAFunctor<float>;

// The block-dim dispatch macros are only meant for this file; undefine them
// so they do not leak past this point.
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward kernels for float, double and float16 inputs.
REGISTER_OP_CUDA_KERNEL(
    layer_norm, ops::LayerNormKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormKernel<plat::CUDADeviceContext, plat::float16>);

// Backward kernels for the same element types.
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad, ops::LayerNormGradKernel<plat::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<plat::CUDADeviceContext, plat::float16>);