/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cub/cub.cuh>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/operators/layer_norm_op.h"

namespace paddle {
namespace operators {

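// Round the requested size down to the nearest power of two, capped at 512
// threads per block.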
inline static int GetDesiredBlockDim(int block_dim) {
  const int kMaxBlockDim = 512;
  return block_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << (static_cast<int>(std::log2f(block_dim))));
}

#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)

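// As FIXED_BLOCK_DIM_CASE, but the feature (column) dimension is additionally
// tiled into chunks of at most kMaxBlockNum blocks; each iteration defines
// block_num and col_offset for the launched kernel to use.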
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                             \
    log2_block_dim, feature_size, kMaxBlockNum, ...)                           \
  case (1 << (log2_block_dim)): {                                              \
    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
      int col_offset = i * kMaxBlockNum;                                       \
      int block_num = std::min(feature_size - col_offset, kMaxBlockNum);       \
      constexpr auto kBlockDim = (1 << (log2_block_dim));                      \
      __VA_ARGS__;                                                             \
    }                                                                          \
  } break

#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__)

static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }

template <typename T>
struct PairForLayerNorm {
  __device__ __forceinline__ PairForLayerNorm() {}
  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
      : first_(first), second_(second) {}

  T first_;
  T second_;
};

template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
  }
};

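// Forward pass: one block normalizes one row of the flattened
// [batch_size, feature_size] input and writes y, mean and var for that row.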
template <typename T, int BlockDim>
__global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
                                 T *y, T *mean, T *var, float epsilon,
                                 int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce to calculate mean and var
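  // Each thread accumulates partial sums of x and x * x in double precision;
  // the block reduce combines them, so var = E[x^2] - (E[x])^2.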
  double mean_val = 0;
  double var_val = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T tmp = x[i];
    mean_val += tmp;
    var_val += (tmp * tmp);
  }
  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<double>(mean_val, var_val),
                          PairForLayerNormAddFunctor<double>());
  if (threadIdx.x == 0) {
    auto tmp = pair.first_ / feature_size;
    mean[blockIdx.x] = static_cast<T>(tmp);
    var[blockIdx.x] = static_cast<T>(pair.second_ / feature_size - tmp * tmp);
  }
  __syncthreads();
  mean_val = mean[blockIdx.x];
  var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));

  // Step 2: Calculate y
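  // y[i] = scale * (x[i] - mean) / sqrt(var + epsilon) + bias; the four
  // branches below hoist the scale/bias nullptr checks out of the inner loop.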
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = scale[j] * (x[i] - mean_val) / var_val + bias[j];
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = scale[j] * (x[i] - mean_val) / var_val;
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = (x[i] - mean_val) / var_val + bias[j];
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = (x[i] - mean_val) / var_val;
      }
    }
  }
}

// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
template <typename T, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                             T *d_scale, T *d_bias, T *d_x,
                                             const T *mean, const T *var,
                                             const T *scale, float epsilon,
                                             int batch_size, int feature_size,
                                             int col_offset) {
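  // One block per column: blockIdx.x + col_offset selects the feature index,
  // and threads stride over the batch dimension to accumulate the per-column
  // d_scale and d_bias partial sums.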
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
  int stride = BlockDim * feature_size;

  T d_scale_partial = 0, d_bias_partial = 0;

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
    d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
    d_bias_partial += d_y[i];
    if (HasDx) {
      d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val;
    }
  }

  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial),
                          PairForLayerNormAddFunctor<T>());

  if (threadIdx.x == 0) {
    d_scale[blockIdx.x + col_offset] = pair.first_;
    d_bias[blockIdx.x + col_offset] = pair.second_;
  }
}

// Make sure that exactly one of d_scale != nullptr and d_bias != nullptr
// holds
// Notice: scale may be nullptr
template <typename T, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x, const T *d_y, T *d_scale, T *d_bias, T *d_x, const T *mean,
    const T *var, const T *scale, float epsilon, int batch_size,
    int feature_size, int col_offset) {
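  // Same column-per-block layout as LayerNormBackwardGradientAll, but only
  // one of d_scale / d_bias is reduced, selected by HasDScale.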
  using BlockReduce = cub::BlockReduce<T, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
  int stride = BlockDim * feature_size;
  T d_scale_or_d_bias_partial = 0;

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
    if (HasDScale) {
      d_scale_or_d_bias_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += d_y[i];
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val;
      } else {
        d_x[i] = d_y[i] / var_val;
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    } else {
      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    }
  }
}

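// Second pass for d_x: given the raw gradient written by the previous kernel,
// subtract the per-row mean of d_x and the term proportional to (x - mean)
// that arise from the dependence of the row statistics on x.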
template <typename T, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
                                                          const T *mean,
                                                          const T *var,
                                                          float epsilon,
                                                          int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  T block_mean = mean[blockIdx.x];
  T block_var = var[blockIdx.x];
  T d_x_mean_partial = 0, d_x_var_partial = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x_mean_partial += d_x[i];
    d_x_var_partial += d_x[i] * (x[i] - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<T>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = pair.first_ / feature_size;
    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= d_x_mean_partial;
    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
  }
}

// Here, we only calculate d_x
template <typename T, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                T *d_x, const T *mean,
                                                const T *var, const T *scale,
                                                float epsilon,
                                                int feature_size) {
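  // One block per row: compute the raw d_x, then apply the same mean/variance
  // correction as LayerNormBackwardPostProcessToCalculateDX in a single pass.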
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
  T d_x_mean_partial = 0, d_x_var_partial = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    auto var_val = static_cast<T>(real_sqrt(block_var + epsilon));
    if (scale != nullptr) {
      int col_idx = i % feature_size;
      d_x[i] = d_y[i] * scale[col_idx] / var_val;
    } else {
      d_x[i] = d_y[i] / var_val;
    }
    d_x_mean_partial += d_x[i];
    d_x_var_partial += d_x[i] * (x[i] - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<T>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = pair.first_ / feature_size;
    d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= d_x_mean_partial;
    d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
  }
}

template <typename T>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x, const T *d_y, T *d_x, T *d_scale, T *d_bias, const T *mean,
    const T *var, const T *scale, float epsilon, int feature_size) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < feature_size) {
    auto var_val = static_cast<T>(real_sqrt(var[idx] + epsilon));
    if (d_x != nullptr) {
      if (d_scale == nullptr) {
        d_x[idx] = d_y[idx] / var_val;
      } else {
        d_x[idx] = d_y[idx] * scale[idx] / var_val;
      }
    }

    if (d_scale != nullptr) {
      d_scale[idx] = d_y[idx] * (x[idx] - mean[idx]) / var_val;
    }

    if (d_bias != nullptr) d_bias[idx] = d_y[idx];
  }
}

template <typename T>
static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
                              const T *mean, const T *var, T *d_x, T *d_scale,
                              T *d_bias, float epsilon, int batch_size,
                              int feature_size, cudaStream_t stream) {
  const int kMaxBlockDim = 512;
  const int kMaxBlockNum = 128;
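  // Pack which gradients are requested into a 3-bit mask:
  // bit 2 -> d_x, bit 1 -> d_scale, bit 0 -> d_bias.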
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;

  if (batch_size == 1) {
    LayerNormBackwardWhenBatchSizeIsOne<
        T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
             stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
                       feature_size);

    if (d_x != nullptr) {
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
                             T, kBlockDim><<<1, kBlockDim, 0, stream>>>(
            x, d_x, mean, var, epsilon, feature_size));
      }
    }
    return;
  }

  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, kBlockDim, false,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, kBlockDim, false, true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<
                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, kBlockDim, true, false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, kBlockDim, true, true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, kBlockDim, true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    default:
      break;
  }
}

template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                               const T *input,
                                               std::vector<int> input_shape,
                                               const T *bias, const T *scale,
                                               T *output, T *mean, T *variance,
                                               int begin_norm_axis, float eps) {
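  // Flatten the input to a [batch_size, feature_size] matrix at
  // begin_norm_axis and launch one block per row.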
  const auto x_dims = framework::make_ddim(input_shape);
  auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
  int batch_size = static_cast<int>(matrix_dim[0]);
  int feature_size = static_cast<int>(matrix_dim[1]);
  switch (GetDesiredBlockDim(feature_size)) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, feature_size));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<T>(ctx.GetPlace());
    auto *var_data = var->mutable_data<T>(ctx.GetPlace());
    auto *scale_data = (scale == nullptr ? nullptr : scale->data<T>());
    auto *bias_data = (bias == nullptr ? nullptr : bias->data<T>());

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

    switch (GetDesiredBlockDim(feature_size)) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
              x_data, scale_data, bias_data, y_data, mean_data, var_data,
              epsilon, feature_size));
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Product from begin_norm_axis to end must be larger than 1"));
        break;
    }
  }
};

template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();
    auto *mean_data = mean->data<T>();
    auto *var_data = var->data<T>();
    auto *scale_data = (scale == nullptr ? nullptr : scale->data<T>());
    auto *d_scale_data =
        (d_scale == nullptr ? nullptr
                            : d_scale->mutable_data<T>(ctx.GetPlace()));
    auto *d_bias_data =
        (d_bias == nullptr ? nullptr : d_bias->mutable_data<T>(ctx.GetPlace()));
    auto *d_x_data =
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

    LayerNormBackward<T>(x_data, d_y_data, scale_data, mean_data, var_data,
                         d_x_data, d_scale_data, d_bias_data, epsilon,
                         batch_size, feature_size, stream);
  }
};
template class LayerNormDirectCUDAFunctor<float>;
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>);