/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cub/cub.cuh>
#include <memory>
#include <vector>

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

inline static int GetDesiredBlockDim(int block_dim) {
  const int kMaxBlockDim = 512;
  return block_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << (static_cast<int>(std::log2f(block_dim))));
}
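// Rounds block_dim down to the nearest power of two, capped at 512.
// Illustrative values (not from the original source): GetDesiredBlockDim(300)
// yields 256, GetDesiredBlockDim(2048) yields kMaxBlockDim = 512.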

#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...)  \
  case (1 << (log2_block_dim)): {                       \
    constexpr auto kBlockDim = (1 << (log2_block_dim)); \
    __VA_ARGS__;                                        \
  } break

#define FIXED_BLOCK_DIM_CASE(...)              \
  FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)

#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                             \
    log2_block_dim, feature_size, kMaxBlockNum, ...)                           \
  case (1 << (log2_block_dim)): {                                              \
    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
      int col_offset = i * kMaxBlockNum;                                       \
      int block_num = std::min(feature_size - col_offset, kMaxBlockNum);       \
      constexpr auto kBlockDim = (1 << (log2_block_dim));                      \
      __VA_ARGS__;                                                             \
    }                                                                          \
  } break

#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__);                   \
  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
                                            ##__VA_ARGS__)

static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }

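// PairForLayerNorm lets one cub::BlockReduce pass accumulate two sums at
// once, e.g. sum(x) and sum(x * x) in the forward kernel, or the d_scale and
// d_bias partials in the backward kernels.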
template <typename T>
struct PairForLayerNorm {
  __device__ __forceinline__ PairForLayerNorm() {}
  __device__ __forceinline__ PairForLayerNorm(const T &first, const T &second)
      : first_(first), second_(second) {}

  T first_;
  T second_;
};

template <typename T>
struct PairForLayerNormAddFunctor {
  __device__ __forceinline__ PairForLayerNorm<T> operator()(
      const PairForLayerNorm<T> &p1, const PairForLayerNorm<T> &p2) {
    return PairForLayerNorm<T>(p1.first_ + p2.first_, p1.second_ + p2.second_);
  }
};

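// One block per row of the flattened [batch_size, feature_size] matrix;
// BlockDim threads cooperatively reduce the row. Writing N = feature_size
// (a notation sketch, not from the original source), the kernel computes
//   mean = (1/N) * sum_i x[i]
//   var  = (1/N) * sum_i x[i]^2 - mean^2
//   y[i] = scale[i] * (x[i] - mean) / sqrt(var + epsilon) + bias[i]
// The reduction accumulates in double regardless of T to limit rounding
// error; the results are stored back in the parameter type U.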
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
                                 T *y, U *mean, U *var, float epsilon,
                                 int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  // Step 1: Reduce to calculate mean and var
  double mean_val = 0;
  double var_val = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    U tmp = static_cast<U>(x[i]);
    mean_val += tmp;
    var_val += (tmp * tmp);
  }
  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<double>(mean_val, var_val),
                          PairForLayerNormAddFunctor<double>());
  if (threadIdx.x == 0) {
    auto tmp = pair.first_ / feature_size;
    mean[blockIdx.x] = static_cast<U>(tmp);
    var[blockIdx.x] = static_cast<U>(pair.second_ / feature_size - tmp * tmp);
  }
  __syncthreads();
  mean_val = mean[blockIdx.x];
  var_val = static_cast<U>(real_sqrt(var[blockIdx.x] + epsilon));

  // Step 2: Calculate y
  if (scale != nullptr) {
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(
            scale[j] * (static_cast<U>(x[i]) - mean_val) / var_val + bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) /
                              var_val);
      }
    }
  } else {  // scale == nullptr
    if (bias != nullptr) {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) / var_val +
                              bias[j]);
      }
    } else {
      for (int i = beg_idx, j = threadIdx.x; i < end_idx;
           i += BlockDim, j += BlockDim) {
        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) / var_val);
      }
    }
  }
}

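// Each block owns one column j = blockIdx.x + col_offset and its BlockDim
// threads stride over the rows i, accumulating (sketched notation, not from
// the original source):
//   d_scale[j] = sum_i d_y[i][j] * (x[i][j] - mean[i]) / sqrt(var[i] + eps)
//   d_bias[j]  = sum_i d_y[i][j]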
// The caller must guarantee that d_scale != nullptr && d_bias != nullptr.
// Since d_scale != nullptr, scale cannot be nullptr either.
template <typename T, typename U, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                             U *d_scale, U *d_bias, T *d_x,
                                             const U *mean, const U *var,
                                             const U *scale, float epsilon,
                                             int batch_size, int feature_size,
                                             int col_offset) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
  int stride = BlockDim * feature_size;

  U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
    d_scale_partial += static_cast<U>(d_y[i]) *
                       (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
    d_bias_partial += static_cast<U>(d_y[i]);
    if (HasDx) {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                              scale[blockIdx.x + col_offset] / var_val);
    }
  }

  auto pair = BlockReduce(temp_storage)
                  .Reduce(PairForLayerNorm<U>(d_scale_partial, d_bias_partial),
                          PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_scale[blockIdx.x + col_offset] = pair.first_;
    d_bias[blockIdx.x + col_offset] = pair.second_;
  }
}

// The caller must guarantee that exactly one of d_scale != nullptr and
// d_bias != nullptr holds.
// Notice: scale may be nullptr.
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
    const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
    const U *var, const U *scale, float epsilon, int batch_size,
    int feature_size, int col_offset) {
  using BlockReduce = cub::BlockReduce<U, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
  int stride = BlockDim * feature_size;
  U d_scale_or_d_bias_partial = static_cast<U>(0);

  for (int i = beg_idx; i < end_idx; i += stride) {
    int row_idx = i / feature_size;
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
    if (HasDScale) {
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
                                   (static_cast<U>(x[i]) - mean[row_idx]) /
                                   var_val;
    } else {  // d_bias != nullptr
      d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
    }

    if (HasDx) {
      if (scale != nullptr) {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
                                scale[blockIdx.x + col_offset] / var_val);
      } else {
        d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
      }
    }
  }

  d_scale_or_d_bias_partial =
      BlockReduce(temp_storage).Reduce(d_scale_or_d_bias_partial, cub::Sum());

  if (threadIdx.x == 0) {
    if (HasDScale) {
      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    } else {
      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
    }
  }
}

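// Completes the chain rule through the per-row statistics: subtracts from
// each d_x element the row mean of d_x and the projection of d_x onto
// (x - mean), the latter scaled by 1 / (feature_size * (var + eps)).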
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
                                                          const U *mean,
                                                          const U *var,
                                                          float epsilon,
                                                          int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x];
  U block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
    d_x_reduce_tmp[1] =
        static_cast<float>(pair.second_) /
        (feature_size * (static_cast<float>(block_var) + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

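// Fuses the d_x computation with the same row-wise correction as
// LayerNormBackwardPostProcessToCalculateDX above, so the d_x-only case
// needs no separate post-processing launch.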
// Here, we only calculate d_x
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
                                                T *d_x, const U *mean,
                                                const U *var, const U *scale,
                                                float epsilon,
                                                int feature_size) {
  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ U d_x_reduce_tmp[2];

  int beg_idx = blockIdx.x * feature_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * feature_size;

  U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
  U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
    if (scale != nullptr) {
      int col_idx = i % feature_size;
      d_x[i] =
          static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
    } else {
      d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
    }
    d_x_mean_partial += static_cast<U>(d_x[i]);
    d_x_var_partial +=
        static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
  }

  auto pair =
      BlockReduce(temp_storage)
          .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
                  PairForLayerNormAddFunctor<U>());

  if (threadIdx.x == 0) {
    d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
    d_x_reduce_tmp[1] =
        static_cast<float>(pair.second_) /
        (feature_size * (static_cast<float>(block_var) + epsilon));
  }
  __syncthreads();

  d_x_mean_partial = d_x_reduce_tmp[0];
  d_x_var_partial = d_x_reduce_tmp[1];
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    d_x[i] -= static_cast<T>(d_x_mean_partial);
    d_x[i] -=
        static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
  }
}

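// With batch_size == 1 every reduction over rows degenerates to a single
// element, so one thread per feature element suffices. d_x still needs the
// row-wise correction, which the dispatcher below launches separately.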
template <typename T, typename U>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
    const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
    const U *var, const U *scale, float epsilon, int feature_size) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < feature_size) {
    auto var_val =
        static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
    if (d_x != nullptr) {
      if (d_scale == nullptr) {
        d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
      } else {
        d_x[idx] =
            static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
      }
    }

    if (d_scale != nullptr) {
      d_scale[idx] = static_cast<U>(d_y[idx]) *
                     (static_cast<U>(x[idx]) - mean[idx]) / var_val;
    }

    if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
  }
}

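// Dispatches to the specialized kernels above. gradient_flag packs the
// requested outputs into three bits (bit 2 = d_x, bit 1 = d_scale,
// bit 0 = d_bias), giving the eight cases of the switch below.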
template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
                              const U *mean, const U *var, T *d_x, U *d_scale,
                              U *d_bias, float epsilon, int batch_size,
                              int feature_size, cudaStream_t stream) {
  const int kMaxBlockDim = 512;
  const int kMaxBlockNum = 128;
  int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                      ((d_scale != nullptr ? 1 : 0) << 1) |
                      ((d_bias != nullptr ? 1 : 0));
  if (gradient_flag == 0) return;

  if (batch_size == 1) {
    LayerNormBackwardWhenBatchSizeIsOne<
        T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
                0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
                             epsilon, feature_size);

    if (d_x != nullptr) {
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
                             T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>(
            x, d_x, mean, var, epsilon, feature_size));
      }
    }
    return;
  }

  auto block_dim = GetDesiredBlockDim(batch_size);
  switch (gradient_flag) {
    case 1:  // d_x == nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, false,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 3:  // d_x == nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      break;
    case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardGradientOnlyDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_y, d_x, mean, var, scale, epsilon, feature_size));
      }
      break;
    case 5:  // d_x != nullptr, d_scale == nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                false><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientScaleOrBias<
                T, U, kBlockDim, true,
                true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
      switch (block_dim) {
        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
            feature_size, kMaxBlockNum,
            LayerNormBackwardGradientAll<
                T, U, kBlockDim, true><<<block_num, kBlockDim, 0, stream>>>(
                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
                batch_size, feature_size, col_offset));
      }
      switch (GetDesiredBlockDim(feature_size)) {
        FIXED_BLOCK_DIM_CASE(
            LayerNormBackwardPostProcessToCalculateDX<
                T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
                x, d_x, mean, var, epsilon, feature_size));
      }
      break;
    default:
      break;
  }
}

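// Thin launcher for callers that hold raw device pointers rather than
// framework tensors (e.g. inference plugins); note it instantiates
// LayerNormForward<T, T, ...>, i.e. statistics are computed in T itself.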
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                               const T *input,
                                               std::vector<int> input_shape,
                                               const T *bias, const T *scale,
                                               T *output, T *mean, T *variance,
                                               int begin_norm_axis, float eps) {
  const auto x_dims = framework::make_ddim(input_shape);
  auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
  int batch_size = static_cast<int>(matrix_dim[0]);
  int feature_size = static_cast<int>(matrix_dim[1]);
  switch (GetDesiredBlockDim(feature_size)) {
    FIXED_BLOCK_DIM_CASE(
        LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, feature_size));
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

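// OpKernel entry point for the forward pass. Mean/Variance are stored as
// LayerNormParamType<T> (float when T is float16), so half-precision inputs
// keep single-precision statistics.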
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *bias = ctx.Input<Tensor>("Bias");
    auto *x = ctx.Input<Tensor>("X");

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean = ctx.Output<Tensor>("Mean");
    auto *var = ctx.Output<Tensor>("Variance");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");

    const auto x_dims = x->dims();
    auto *x_data = x->data<T>();
    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
    auto *mean_data = mean->mutable_data<LayerNormParamType<T>>(ctx.GetPlace());
    auto *var_data = var->mutable_data<LayerNormParamType<T>>(ctx.GetPlace());
    auto *scale_data =
        (scale == nullptr ? nullptr : scale->data<LayerNormParamType<T>>());
    auto *bias_data =
        (bias == nullptr ? nullptr : bias->data<LayerNormParamType<T>>());

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

    switch (GetDesiredBlockDim(feature_size)) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<T, LayerNormParamType<T>,
                           kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
              x_data, scale_data, bias_data, y_data, mean_data, var_data,
              epsilon, feature_size));
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Product from begin_norm_axis to end must be larger than 1"));
        break;
    }
  }
};

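// OpKernel entry point for the backward pass; all gradient math runs in
// U = LayerNormParamType<T>, with d_scale/d_bias produced in U and d_x in T.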
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const float epsilon = ctx.Attr<float>("epsilon");
    // d_x, d_scale, d_bias may be nullptr
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto *x = ctx.Input<Tensor>("X");
    auto *mean = ctx.Input<Tensor>("Mean");
    auto *var = ctx.Input<Tensor>("Variance");
    auto *scale = ctx.Input<Tensor>("Scale");
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));

    auto *x_data = x->data<T>();
    auto *d_y_data = d_y->data<T>();
    auto *mean_data = mean->data<U>();
    auto *var_data = var->data<U>();

    auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
    auto *d_scale_data =
        (d_scale == nullptr ? nullptr
                            : d_scale->mutable_data<U>(ctx.GetPlace()));
    auto *d_bias_data =
        (d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
    auto *d_x_data =
        (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));

    const auto &x_dims = x->dims();
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int batch_size = static_cast<int>(matrix_dim[0]);
    int feature_size = static_cast<int>(matrix_dim[1]);

    auto stream = ctx.cuda_device_context().stream();

    LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
                            d_x_data, d_scale_data, d_bias_data, epsilon,
                            batch_size, feature_size, stream);
  }
};

template class LayerNormDirectCUDAFunctor<float>;

#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
                             plat::float16>);