// gelu_funcs.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/platform/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

DECLARE_bool(use_fast_math);

namespace phi {

#ifdef __NVCC__
template <bool FastMode>
static __device__ __forceinline__ float FP32FastTanh(float x) {
  // The PTX `tanh.approx.f32` instruction is only used when compiling device
  // code for SM75+ with CUDA 11.0+ and FastMode is requested; every other
  // configuration uses the regular tanhf.
#if __CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000
  if (!FastMode) {
    return tanhf(x);
  }
  float approx;
  asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(approx) : "f"(x));
  return approx;
#else
  return tanhf(x);
#endif
}

// GELU forward, tanh approximation:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
// 0.79788456f is sqrt(2/pi). FastMode is forwarded to FP32FastTanh.
template <bool FastMode>
static __device__ __forceinline__ float FP32GeluFwd(float x) {
  const float inner = 0.79788456f * x * (1.0f + 0.044715f * x * x);
  const float half_x = x * 0.5f;
  return half_x * (1.0f + FP32FastTanh<FastMode>(inner));
}

// Gradient of the tanh-approximated GELU w.r.t. x, multiplied by the upstream
// gradient y_g. With u = sqrt(2/pi) * x * (1 + 0.044715 * x^2), t = tanh(u):
//   d gelu / dx = 0.5 * x * (1 - t^2) * du/dx + 0.5 * (1 + t)
//   du/dx       = sqrt(2/pi) + 3 * 0.044715 * sqrt(2/pi) * x^2
// 0.79788456f is sqrt(2/pi); 0.1070322243f is 3 * 0.044715 * sqrt(2/pi).
template <bool FastMode>
static __device__ __forceinline__ float FP32GeluBwd(float x, float y_g) {
  const float t =
      FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
  const float sech2 = 1.0f - t * t;
  const float du_dx = 0.79788456f + 0.1070322243f * x * x;
  const float dgelu_dx = 0.5f * x * (sech2 * du_dx) + 0.5f * (1.0f + t);
  return dgelu_dx * y_g;
}

// Elementwise FP16 GELU forward, computed in FP32. Each thread processes
// VecSize contiguous halves per iteration of a grid-stride loop.
// Preconditions (checked by the host-side TryLaunch* helper, not here):
// n % VecSize == 0, and x / y aligned to
// alignof(phi::AlignedVector<__half, VecSize>).
template <int VecSize, bool FastMode>
static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x,
                                                 __half* y,
                                                 size_t n) {
  using ArrT = phi::AlignedVector<__half, VecSize>;
  // Widen to size_t BEFORE multiplying: blockIdx.x * blockDim.x and
  // gridDim.x * blockDim.x are otherwise evaluated in 32-bit unsigned
  // arithmetic and wrap for very large n (the grid is only clamped to the
  // device's maximum grid dimension), silently skipping elements.
  size_t offset =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  const size_t stride =
      static_cast<size_t>(gridDim.x) * blockDim.x * VecSize;
  for (; offset < n; offset += stride) {
    ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      float tmp = __half2float(in_arr[i]);
      in_arr[i] = __float2half(FP32GeluFwd<FastMode>(tmp));
    }
    *reinterpret_cast<ArrT*>(y + offset) = in_arr;
  }
}

// Elementwise FP16 GELU backward, computed in FP32: x_g[i] = dgelu(x[i]) *
// y_g[i]. Each thread processes VecSize contiguous halves per iteration of a
// grid-stride loop.
// Preconditions (checked by the host-side TryLaunch* helper, not here):
// n % VecSize == 0, and x / y_g / x_g aligned to
// alignof(phi::AlignedVector<__half, VecSize>).
template <int VecSize, bool FastMode>
static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
                                                 const __half* y_g,
                                                 __half* x_g,
                                                 size_t n) {
  using ArrT = phi::AlignedVector<__half, VecSize>;
  // Widen to size_t BEFORE multiplying: blockIdx.x * blockDim.x and
  // gridDim.x * blockDim.x are otherwise evaluated in 32-bit unsigned
  // arithmetic and wrap for very large n (the grid is only clamped to the
  // device's maximum grid dimension), silently skipping elements.
  size_t offset =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  const size_t stride =
      static_cast<size_t>(gridDim.x) * blockDim.x * VecSize;
  for (; offset < n; offset += stride) {
    ArrT x_in_arr = *reinterpret_cast<const ArrT*>(x + offset);
    ArrT y_g_in_arr = *reinterpret_cast<const ArrT*>(y_g + offset);
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      // Pack (x, y_g) into one __half2 so both are converted to float with a
      // single __half22float2 call.
      __half2 tmp_fp16_2;
      tmp_fp16_2.x = x_in_arr[i];
      tmp_fp16_2.y = y_g_in_arr[i];
      float2 tmp_fp32_2 = __half22float2(tmp_fp16_2);
      x_in_arr[i] =
          __float2half(FP32GeluBwd<FastMode>(tmp_fp32_2.x, tmp_fp32_2.y));
    }
    *reinterpret_cast<ArrT*>(x_g + offset) = x_in_arr;
  }
}

// Tries to launch FP16FastGeluFwdCUDAKernel (vector width 8) on dev_ctx's
// stream, with FastMode taken from FLAGS_use_fast_math.
// Returns true if the work is handled: a kernel was launched, or n == 0 and
// there is nothing to compute. Returns false without launching when n is not
// a multiple of 8 or x / y are not aligned for AlignedVector<__half, 8>, so
// the caller can fall back to another implementation.
static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
    const GPUContext& dev_ctx, const __half* x, __half* y, size_t n) {
  // An empty tensor needs no work. Without this guard the launch below would
  // use a 0-block grid, which is an invalid launch configuration.
  if (n == 0) {
    return true;
  }

  auto is_aligned = [](const void* p, size_t alignment) {
    return reinterpret_cast<uintptr_t>(p) % alignment == 0;
  };

#define PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(__vec_size, __use_fast_math)      \
  do {                                                                        \
    constexpr auto kAlignment =                                               \
        alignof(phi::AlignedVector<__half, __vec_size>);                      \
    if (n % __vec_size == 0 && is_aligned(x, kAlignment) &&                   \
        is_aligned(y, kAlignment)) {                                          \
      size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
      size_t block = (n / __vec_size + thread - 1) / thread;                  \
      block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize()[0]);    \
      VLOG(10) << "Use FP16 fast gelu fwd kernel, block = " << block          \
               << " , thread = " << thread;                                   \
      FP16FastGeluFwdCUDAKernel<                                              \
          __vec_size,                                                         \
          __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y, n);  \
      return true;                                                            \
    }                                                                         \
  } while (0)

  if (FLAGS_use_fast_math) {
    PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, true);
  } else {
    PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, false);
  }

#undef PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL
  return false;
}

// Tries to launch FP16FastGeluBwdCUDAKernel (vector width 8) on dev_ctx's
// stream, with FastMode taken from FLAGS_use_fast_math.
// Returns true if the work is handled: a kernel was launched, or n == 0 and
// there is nothing to compute. Returns false without launching when n is not
// a multiple of 8 or any of x / y_g / x_g is not aligned for
// AlignedVector<__half, 8>, so the caller can fall back to another
// implementation.
static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
    const GPUContext& dev_ctx,
    const __half* x,
    const __half* y_g,
    __half* x_g,
    size_t n) {
  // An empty tensor needs no work. Without this guard the launch below would
  // use a 0-block grid, which is an invalid launch configuration.
  if (n == 0) {
    return true;
  }

  auto is_aligned = [](const void* p, size_t alignment) {
    return reinterpret_cast<uintptr_t>(p) % alignment == 0;
  };

  // Each pointer is checked exactly once (the previous version tested x
  // twice).
#define PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(__vec_size, __use_fast_math)      \
  do {                                                                        \
    constexpr auto kAlignment =                                               \
        alignof(phi::AlignedVector<__half, __vec_size>);                      \
    if (n % __vec_size == 0 && is_aligned(x, kAlignment) &&                   \
        is_aligned(y_g, kAlignment) && is_aligned(x_g, kAlignment)) {         \
      size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
      size_t block = (n / __vec_size + thread - 1) / thread;                  \
      block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize()[0]);    \
      VLOG(10) << "Use FP16 fast gelu bwd kernel, block = " << block          \
               << " , thread = " << thread;                                   \
      FP16FastGeluBwdCUDAKernel<                                              \
          __vec_size,                                                         \
          __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(           \
          x, y_g, x_g, n);                                                    \
      return true;                                                            \
    }                                                                         \
  } while (0)

  if (FLAGS_use_fast_math) {
    PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, true);
  } else {
    PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, false);
  }

#undef PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL
  return false;
}
#endif

}  // namespace phi