layer_norm_kernel.cu 27.8 KB
Newer Older
H
hong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/layer_norm_kernel.h"
16
#include "gflags/gflags.h"
H
hong 已提交
17 18
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
19
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
H
hong 已提交
20 21
#include "paddle/phi/kernels/funcs/layer_norm_util.h"

22 23
DECLARE_bool(use_fast_math);

H
hong 已提交
24 25
namespace phi {

26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454
#ifdef PADDLE_WITH_CUDA
template <typename U>
__device__ inline void WelfordOnline(U val, U *mean, U *square, U *count) {
  *count += 1;
  U delta1 = val - *mean;
  *mean += delta1 / (*count);
  U delta2 = val - *mean;
  *square += delta1 * delta2;
}

template <typename U>
__device__ inline void WelfordOnline(
    U b_mean, U b_square, U b_cnt, U *mean, U *square, U *count) {
  if (b_cnt == 0) {
    return;
  }

  U new_cnt = *count + b_cnt;
  U nb_n = b_cnt / new_cnt;
  U delta = b_mean - *mean;
  *mean += delta * nb_n;
  *square += b_square + delta * delta * (*count) * nb_n;
  *count = new_cnt;
}

template <typename U>
__device__ inline void WelfordWarpAllReduce(U *mean, U *square, U *count) {
  constexpr int kWarpSize = 32;
#pragma unroll
  for (int mask = 1; mask < kWarpSize; mask *= 2) {
    U b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
    U b_square = __shfl_down_sync(0xffffffff, *square, mask);
    U b_cnt = __shfl_down_sync(0xffffffff, *count, mask);
    WelfordOnline<U>(b_mean, b_square, b_cnt, mean, square, count);
  }

  *mean = __shfl_sync(0xffffffff, *mean, 0, kWarpSize);
  *square = __shfl_sync(0xffffffff, *square, 0, kWarpSize);
  *count = __shfl_sync(0xffffffff, *count, 0, kWarpSize);
}

template <int VecSize>
struct ThreadAssigner {
  __device__ __forceinline__ int operator()(const int cols,
                                            const int cols_per_thread,
                                            int32_t *last_tid_idx) {
    return cols_per_thread;
  }
};

template <>
struct ThreadAssigner<1> {
  __device__ inline int operator()(const int cols,
                                   const int cols_per_thread,
                                   int *last_tid_idx) {
    int cols_this_thread = cols_per_thread;
    int last_tid = (cols / cols_per_thread);
    *last_tid_idx = last_tid;
    if (threadIdx.x == last_tid) {
      cols_this_thread = cols - cols_per_thread * last_tid;
    } else if (threadIdx.x > last_tid) {
      cols_this_thread = 0;
    }
    return cols_this_thread;
  }
};

template <typename T, typename U, int VecSize>
struct LayerNormDataReader {
  __device__ inline void operator()(const T *__restrict__ row_src,
                                    U *buffer,
                                    const int last_tid_idx,
                                    const int read_times,
                                    const int cols_this_thread) {
    using VecT = phi::AlignedVector<T, VecSize>;
    const VecT *__restrict__ v_src =
        reinterpret_cast<const VecT *__restrict__>(row_src);

    for (int i = 0; i < read_times; ++i) {
      VecT temp_src = v_src[threadIdx.x + i * blockDim.x];
#pragma unroll
      for (int j = 0; j < VecSize; ++j) {
        buffer[i * VecSize + j] = static_cast<U>(temp_src[j]);
      }
    }
  }
};

template <typename T, typename U>
struct LayerNormDataReader<T, U, 1> {
  __device__ inline void operator()(const T *__restrict__ row_src,
                                    U *buffer,
                                    const int last_tid_idx,
                                    const int read_times,
                                    const int cols_this_thread) {
    // read_time is just cols_per_thread while VecSize is 1.
    if (threadIdx.x < last_tid_idx) {
      for (int i = 0; i < cols_this_thread; ++i) {
        buffer[i] = static_cast<U>(row_src[threadIdx.x + last_tid_idx * i]);
      }
    } else {
      for (int i = 0; i < cols_this_thread; ++i) {
        buffer[i] = static_cast<U>(row_src[i + read_times * last_tid_idx]);
      }
    }
  }
};

template <typename T, typename U, bool IsSameType, int VecSize>
struct LayerNormDataWritter {
  __device__ inline void operator()(
      T *__restrict__ row_dst,
      const U *__restrict__ buffer,
      const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ scale,
      const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ bias,
      const U row_mean,
      const U row_inv_var,
      const int write_times,
      const int cols_this_thread,
      const int last_tid_idx,
      const bool valid_scale,
      const bool valid_bias) {
    using VecT = phi::AlignedVector<T, VecSize>;
    using ScaleT = funcs::LayerNormScaleBiasT<T, U, IsSameType>;
    using VecScaleT = phi::AlignedVector<ScaleT, VecSize>;
    VecT *v_dst = reinterpret_cast<VecT *>(row_dst);

    // cols_this_thread is just cols_per_thread
    if ((!valid_scale) && (!valid_bias)) {
      for (int i = 0; i < write_times; ++i) {
        VecT temp_dst;
#pragma unroll
        for (int j = 0; j < VecSize; ++j) {
          temp_dst[j] = static_cast<T>((buffer[i * VecSize + j] - row_mean) *
                                       row_inv_var);
        }
        v_dst[threadIdx.x + blockDim.x * i] = temp_dst;
      }
    } else {
      const VecScaleT *__restrict__ v_scale =
          reinterpret_cast<const VecScaleT *__restrict__>(scale);
      const VecScaleT *__restrict__ v_bias =
          reinterpret_cast<const VecScaleT *__restrict__>(bias);
      if (valid_scale && valid_bias) {
        for (int i = 0; i < write_times; ++i) {
          int idx = threadIdx.x + blockDim.x * i;
          VecT temp_dst;
          VecScaleT temp_v_scale = v_scale[idx];
          VecScaleT temp_v_bias = v_bias[idx];
#pragma unroll
          for (int j = 0; j < VecSize; ++j) {
            temp_dst[j] = static_cast<T>(
                static_cast<U>(temp_v_scale[j]) *
                    (buffer[i * VecSize + j] - row_mean) * row_inv_var +
                static_cast<U>(temp_v_bias[j]));
          }
          v_dst[idx] = temp_dst;
        }
      } else {
        if (valid_scale) {
          for (int i = 0; i < write_times; ++i) {
            int idx = threadIdx.x + blockDim.x * i;
            VecT temp_dst;
            VecScaleT temp_v_scale = v_scale[idx];
#pragma unroll
            for (int j = 0; j < VecSize; ++j) {
              temp_dst[j] = static_cast<T>(
                  static_cast<U>(temp_v_scale[j]) *
                  (buffer[i * VecSize + j] - row_mean) * row_inv_var);
            }
            v_dst[idx] = temp_dst;
          }
        } else {
          for (int i = 0; i < write_times; ++i) {
            int idx = threadIdx.x + blockDim.x * i;
            VecT temp_dst;
            VecScaleT temp_v_bias = v_bias[idx];
#pragma unroll
            for (int j = 0; j < VecSize; ++j) {
              temp_dst[j] = static_cast<T>(
                  (buffer[i * VecSize + j] - row_mean) * row_inv_var +
                  static_cast<U>(temp_v_bias[j]));
            }
            v_dst[idx] = temp_dst;
          }
        }
      }
    }
  }
};

template <typename T, typename U, bool IsSameType>
struct LayerNormDataWritter<T, U, IsSameType, 1> {
  __device__ __forceinline__ void operator()(
      T *__restrict__ row_dst,
      U *__restrict__ buffer,
      const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ scale,
      const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ bias,
      const U row_mean,
      const U row_inv_var,
      const int write_times,
      const int cols_this_thread,
      const int last_tid_idx,
      const bool valid_scale,
      const bool valid_bias) {
    // write_times is just col_per_thread.
    if ((!valid_scale) && (!valid_bias)) {
      if (threadIdx.x < last_tid_idx) {
        for (int i = 0; i < cols_this_thread; ++i) {
          row_dst[threadIdx.x + last_tid_idx * i] =
              (buffer[i] - row_mean) * row_inv_var;
        }
      } else {
        for (int i = 0; i < cols_this_thread; ++i) {
          row_dst[last_tid_idx * write_times + i] =
              (buffer[i] - row_mean) * row_inv_var;
        }
      }
    } else if (valid_scale && valid_bias) {
      if (threadIdx.x < last_tid_idx) {
        for (int i = 0; i < cols_this_thread; ++i) {
          int idx = threadIdx.x + last_tid_idx * i;
          row_dst[idx] =
              static_cast<T>(static_cast<U>(scale[idx]) *
                                 (buffer[i] - row_mean) * row_inv_var +
                             static_cast<U>(bias[idx]));
        }
      } else {
        for (int i = 0; i < cols_this_thread; ++i) {
          int idx = last_tid_idx * write_times + i;
          row_dst[idx] =
              static_cast<T>(static_cast<U>(scale[idx]) *
                                 (buffer[i] - row_mean) * row_inv_var +
                             static_cast<U>(bias[idx]));
        }
      }
    } else {
      if (valid_scale) {
        if (threadIdx.x < last_tid_idx) {
          for (int i = 0; i < cols_this_thread; ++i) {
            int idx = threadIdx.x + last_tid_idx * i;
            row_dst[idx] = static_cast<T>(static_cast<U>(scale[idx]) *
                                          (buffer[i] - row_mean) * row_inv_var);
          }
        } else {
          for (int i = 0; i < cols_this_thread; ++i) {
            int idx = last_tid_idx * write_times + i;
            row_dst[idx] = static_cast<T>(static_cast<U>(scale[idx]) *
                                          (buffer[i] - row_mean) * row_inv_var);
          }
        }
      } else {
        if (threadIdx.x < last_tid_idx) {
          for (int i = 0; i < cols_this_thread; ++i) {
            int idx = threadIdx.x + last_tid_idx * i;
            row_dst[idx] = static_cast<T>((buffer[i] - row_mean) * row_inv_var +
                                          static_cast<U>(bias[idx]));
          }
        } else {
          for (int i = 0; i < cols_this_thread; ++i) {
            int idx = last_tid_idx * write_times + i;
            row_dst[idx] = static_cast<T>((buffer[i] - row_mean) * row_inv_var +
                                          static_cast<U>(bias[idx]));
          }
        }
      }
    }
  }
};

template <typename IndexT, typename T, typename U, bool IsSameType, int VecSize>
__global__ void LayerNormFwdWithWelford(
    const T *__restrict__ src_data,
    T *dst_data,
    const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ scale,
    const funcs::LayerNormScaleBiasT<T, U, IsSameType> *__restrict__ bias,
    U *mean,
    U *var,
    const U epsilon,
    const IndexT rows,
    const int32_t cols,
    const int32_t cols_per_thread,
    const bool valid_scale,
    const bool valid_bias) {
  constexpr int kWarpSize = 32;
  int last_tid_idx = 0;  // For condition once vecSize is 1.
  IndexT row_offset = blockIdx.x * blockDim.y + threadIdx.y;
  int cols_this_thread =
      ThreadAssigner<VecSize>()(cols, cols_per_thread, &last_tid_idx);
  int read_times = cols_per_thread / VecSize;

  if (row_offset < rows) {
    U buffer[kWarpSize];
    U tid_cnt = static_cast<U>(0);
    U tid_mean = static_cast<U>(0);
    U tid_square = static_cast<U>(0);

    const T *__restrict__ row_src = src_data + row_offset * cols;
    T *row_dst = dst_data + row_offset * cols;
    LayerNormDataReader<T, U, VecSize>()(
        row_src, buffer, last_tid_idx, read_times, cols_this_thread);

    for (int i = 0; i < cols_this_thread; i++) {
      WelfordOnline<U>(buffer[i], &tid_mean, &tid_square, &tid_cnt);
    }

    U warp_cnt = tid_cnt;
    U warp_mean = tid_mean;
    U warp_square = tid_square;
    WelfordWarpAllReduce<U>(&warp_mean, &warp_square, &warp_cnt);

    U row_variance = max(warp_square / warp_cnt, 0.f);
    U row_inv_var = funcs::rsqrt_(row_variance + epsilon);

    // TODO(limingshu): make code below vectorization.
    if (threadIdx.x == 0) {
      // warp_mean is just row_mean here.
      mean[row_offset] = warp_mean;
      var[row_offset] = row_variance;
    }
    LayerNormDataWritter<T, U, IsSameType, VecSize>()(row_dst,
                                                      buffer,
                                                      scale,
                                                      bias,
                                                      warp_mean,
                                                      row_inv_var,
                                                      read_times,
                                                      cols_this_thread,
                                                      last_tid_idx,
                                                      valid_scale,
                                                      valid_bias);
  }
}

template <typename Context, typename T, typename U>
void LaunchLayerNormKernel(const Context &dev_ctx,
                           const T *x_data,
                           T *y_data,
                           const void *void_scale_data,
                           const void *void_bias_data,
                           U *mean_data,
                           U *var_data,
                           float epsilon,
                           const int64_t rows,
                           const int cols,
                           const bool valid_scale,
                           const bool valid_bias,
                           const bool is_same_type) {
  constexpr int WarpSize = 32;
  constexpr int RowPerBlock = 4;
  int64_t block_size = (rows + (RowPerBlock - 1)) / RowPerBlock;
  dim3 threads(WarpSize, RowPerBlock, 1);

  int vec_size = 1;
  int cols_per_thread = (cols + (WarpSize - 1)) / WarpSize;
  if (cols_per_thread > 1 && (cols % WarpSize == 0)) {
    int data_vec_size = 0;
    uint64_t addr = (reinterpret_cast<uint64_t>(x_data) |
                     reinterpret_cast<uint64_t>(y_data));
    if (valid_bias || valid_scale) {
      if (is_same_type) {
        addr = valid_scale
                   ? (addr | reinterpret_cast<uint64_t>(void_scale_data))
                   : addr;
        addr = valid_bias ? (addr | reinterpret_cast<uint64_t>(void_bias_data))
                          : addr;
        data_vec_size = phi::GetVectorizedSize<T>(reinterpret_cast<T *>(addr));
      } else {
        uint64_t bias_addr = reinterpret_cast<uint64_t>(void_bias_data);
        uint64_t attr_addr = valid_scale
                                 ? reinterpret_cast<uint64_t>(void_scale_data)
                                 : bias_addr;
        attr_addr = valid_bias
                        ? (valid_scale ? (attr_addr | bias_addr) : attr_addr)
                        : attr_addr;
        data_vec_size = std::min(
            phi::GetVectorizedSize<T>(reinterpret_cast<T *>(addr)),
            phi::GetVectorizedSize<U>(reinterpret_cast<U *>(attr_addr)));
      }
    }
    for (int size = data_vec_size; size > 0; size /= 2) {
      if (cols_per_thread % size == 0) {
        vec_size = size;
        break;
      }
    }
  }

#define IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, vec_size_) \
  case (vec_size_): {                                                       \
    LayerNormFwdWithWelford<index_t, T, U, is_same_, vec_size_>             \
        <<<block_size, threads, 0, dev_ctx.stream()>>>(                     \
            x_data,                                                         \
            y_data,                                                         \
            static_cast<const scale_t *>(void_scale_data),                  \
            static_cast<const scale_t *>(void_bias_data),                   \
            mean_data,                                                      \
            var_data,                                                       \
            static_cast<const U>(epsilon),                                  \
            rows,                                                           \
            cols,                                                           \
            cols_per_thread,                                                \
            valid_scale,                                                    \
            valid_bias);                                                    \
  } break

#define IMPL_LAYER_NORM_WELFORD(index_t, scale_t, is_same_)    \
  IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 4); \
  IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 2); \
  IMPL_LAYER_NORM_WELFORD_CASE(index_t, scale_t, is_same_, 1);

  if (rows < std::numeric_limits<int32_t>::max()) {
    if (is_same_type) {
      switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int32_t, T, true); }
    } else {
      switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int32_t, U, false); }
    }
  } else {
    if (is_same_type) {
      switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int64_t, T, true); }
    } else {
      switch (vec_size) { IMPL_LAYER_NORM_WELFORD(int64_t, U, false); }
    }
  }
#undef IMPL_LAYER_NORM_WELFORD_CASE
#undef IMPL_LAYER_NORM_WELFORD
}
#endif  // PADDLE_WITH_CUDA

455 456 457 458 459 460 461 462 463 464 465
template <typename T, typename U>
void LayerNormDirectCUDAFunctor<T, U>::operator()(gpuStream_t stream,
                                                  const T *input,
                                                  std::vector<int> input_shape,
                                                  const U *bias,
                                                  const U *scale,
                                                  T *output,
                                                  U *mean,
                                                  U *variance,
                                                  int begin_norm_axis,
                                                  float eps) {
H
hong 已提交
466 467 468 469
  const auto x_dims = phi::make_ddim(input_shape);
  auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
  int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
  int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
470
  switch (phi::funcs::GetDesiredBlockDim(feature_size)) {
471
    FIXED_BLOCK_DIM_CASE(
472
        phi::funcs::LayerNormForward<T, U, kBlockDim>
473 474
        <<<batch_size, kBlockDim, 0, stream>>>(
            input, scale, bias, output, mean, variance, eps, feature_size));
H
hong 已提交
475 476 477 478 479 480 481 482
    default:
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Product from begin_norm_axis to end in layer_norm must be larger "
          "than 1"));
      break;
  }
}

483
template class LayerNormDirectCUDAFunctor<float, float>;
MarDino's avatar
MarDino 已提交
484
template class LayerNormDirectCUDAFunctor<double, double>;
485 486 487
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
template class LayerNormDirectCUDAFunctor<half, float>;
#endif
H
hong 已提交
488 489 490 491

template <typename T, typename Context>
void LayerNormKernel(const Context &dev_ctx,
                     const DenseTensor &x,
492 493
                     const paddle::optional<DenseTensor> &scale_opt,
                     const paddle::optional<DenseTensor> &bias_opt,
H
hong 已提交
494 495 496 497 498
                     float epsilon,
                     int begin_norm_axis,
                     DenseTensor *y,
                     DenseTensor *mean,
                     DenseTensor *var) {
499
  using U = phi::funcs::LayerNormParamType<T>;
H
hong 已提交
500 501 502 503 504 505 506 507 508
  auto *scale = scale_opt.get_ptr();
  auto *bias = bias_opt.get_ptr();

  const auto x_dims = x.dims();
  auto *x_data = x.data<T>();
  auto *y_data = dev_ctx.template Alloc<T>(y);
  auto *mean_data = dev_ctx.template Alloc<U>(mean);
  auto *var_data = dev_ctx.template Alloc<U>(var);

509 510 511 512
  bool valid_scale = (scale != nullptr);
  bool valid_bias = (bias != nullptr);
  auto *void_scale_data = valid_scale ? scale->data() : nullptr;
  auto *void_bias_data = valid_bias ? bias->data() : nullptr;
H
hong 已提交
513 514 515

  auto x_dtype = x.dtype();
  phi::DataType scale_bias_dtype;
516
  if (valid_scale) {
H
hong 已提交
517
    scale_bias_dtype = scale->dtype();
518
    if (valid_bias) {
H
hong 已提交
519 520 521
      PADDLE_ENFORCE_EQ(
          scale->dtype(),
          bias->dtype(),
522
          phi::errors::InvalidArgument("This Scale and Bias of layer_norm op "
H
hong 已提交
523 524 525
                                       "should have the same data type."));
    }
  } else {
526
    scale_bias_dtype = valid_bias ? bias->dtype() : x_dtype;
H
hong 已提交
527 528 529 530 531
  }

  bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
  if (!is_scale_bias_same_dtype_with_x) {
    PADDLE_ENFORCE_EQ(scale_bias_dtype,
532
                      phi::CppTypeToDataType<U>::Type(),
H
hong 已提交
533 534 535 536 537 538 539 540 541 542 543
                      phi::errors::InvalidArgument(
                          "Unsupported data type of Scale and Bias"));
  }

  auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
  int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
  int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
  auto stream = dev_ctx.stream();

#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
  do {                                                                     \
544
    switch (phi::funcs::GetDesiredBlockDim(feature_size)) {                \
545
      FIXED_BLOCK_DIM_CASE(                                                \
546
          phi::funcs::                                                     \
547 548 549 550 551 552 553 554 555 556
              LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX> \
          <<<batch_size, kBlockDim, 0, stream>>>(                          \
              x_data,                                                      \
              static_cast<const ScaleBiasT *>(void_scale_data),            \
              static_cast<const ScaleBiasT *>(void_bias_data),             \
              y_data,                                                      \
              mean_data,                                                   \
              var_data,                                                    \
              epsilon,                                                     \
              feature_size));                                              \
H
hong 已提交
557 558 559 560 561 562 563
      default:                                                             \
        PADDLE_THROW(phi::errors::InvalidArgument(                         \
            "Product from begin_norm_axis to end must be larger than 1")); \
        break;                                                             \
    }                                                                      \
  } while (0)

564 565 566 567 568 569 570 571 572 573 574
#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, feature_size)          \
  case (feature_size): {                                                     \
    constexpr int WARPS_N = feature_size < 1024 ? 1 : (feature_size / 1024); \
    constexpr int WARPS_M = 4 / WARPS_N;                                     \
    const int THREADS_PER_WARP = 32;                                         \
    const int BYTES_PER_LDG = 16;                                            \
    const int VecSize = BYTES_PER_LDG / sizeof(T);                           \
    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;        \
    const int ROWS_PER_CTA = WARPS_M;                                        \
    const int grid = static_cast<int>(                                       \
        std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));           \
575 576 577 578 579 580 581
    phi::funcs::fast_ln_fwd_kernel<T,                                        \
                                   U,                                        \
                                   ScaleT,                                   \
                                   VecSize,                                  \
                                   WARPS_M,                                  \
                                   WARPS_N,                                  \
                                   BYTES_PER_LDG>                            \
582 583 584 585 586 587 588 589 590 591
        <<<grid, THREADS_PER_CTA, 0, stream>>>(                              \
            batch_size,                                                      \
            feature_size,                                                    \
            epsilon,                                                         \
            x_data,                                                          \
            static_cast<const ScaleT *>(void_scale_data),                    \
            static_cast<const ScaleT *>(void_bias_data),                     \
            mean_data,                                                       \
            var_data,                                                        \
            y_data);                                                         \
592 593 594 595 596 597 598 599 600 601 602
  } break

#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD(ScaleT)       \
  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 768);  \
  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1024); \
  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1280); \
  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1536); \
  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1792); \
  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 2048); \
  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 4096)

H
hong 已提交
603
#ifdef PADDLE_WITH_CUDA
604 605 606 607 608 609
  bool can_call_fast_kernel = false;
  if ((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0 ||
       feature_size == 4096) &&
      scale != nullptr && bias != nullptr) {
    // can_call_fast_kernel = true;
    can_call_fast_kernel = false;
H
hong 已提交
610
  }
611 612

  if (can_call_fast_kernel) {
H
hong 已提交
613
    if (is_scale_bias_same_dtype_with_x) {
614 615 616 617 618 619 620 621 622
      switch (feature_size) {
        PADDLE_LAUNCH_FAST_LAYERNORM_FWD(T);
        default:
          PADDLE_THROW(phi::errors::InvalidArgument(
              "Only when feature_size is from 256 to 4096 and is diviaible by "
              "256 is supported "
              "now"));
          break;
      }
H
hong 已提交
623
    } else {
624 625 626 627 628 629 630 631 632
      switch (feature_size) {
        PADDLE_LAUNCH_FAST_LAYERNORM_FWD(U);
        default:
          PADDLE_THROW(phi::errors::InvalidArgument(
              "Only when feature_size is from 256 to 4096 and is diviaible by "
              "is supported "
              "now"));
          break;
      }
H
hong 已提交
633 634
    }
  } else {
635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650
    // WarpShuffle intrinsics is involved in LaunchLayerNormKernel.
    if (FLAGS_use_fast_math && feature_size <= 1024 &&
        (!std::is_same<T, int8_t>::value)) {
      LaunchLayerNormKernel<Context, T, U>(dev_ctx,
                                           x_data,
                                           y_data,
                                           void_scale_data,
                                           void_bias_data,
                                           mean_data,
                                           var_data,
                                           epsilon,
                                           batch_size,
                                           feature_size,
                                           valid_scale,
                                           valid_bias,
                                           is_scale_bias_same_dtype_with_x);
H
hong 已提交
651
    } else {
652 653 654 655 656 657
#endif
      if (is_scale_bias_same_dtype_with_x) {
        PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
      } else {
        PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
      }
H
hong 已提交
658
#ifdef PADDLE_WITH_CUDA
659
    }
H
hong 已提交
660 661 662 663
  }
#endif

#undef PADDLE_LAUNCH_LAYERNORM_FWD
664
#undef PADDLE_LAUNCH_FAST_LAYERNORM_FWD
H
hong 已提交
665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694
}

}  // namespace phi

#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
PD_REGISTER_KERNEL(layer_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::LayerNormKernel,
                   float,
                   phi::dtype::float16) {}
#elif CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(layer_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::LayerNormKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(layer_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::LayerNormKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
#endif