gelu_op.cu 13.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
17 18
#include "paddle/fluid/operators/gelu_op.h"

19 20
DECLARE_bool(use_fast_math);

21 22 23
namespace paddle {
namespace operators {

24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
#ifdef __NVCC__
template <bool FastMode>
static __device__ __forceinline__ float FP32FastTanh(float x) {
  // tanh in fp32. On device passes compiled for __CUDA_ARCH__ >= 750 (and not
  // on Windows), FastMode == true uses the hardware tanh.approx.f32
  // instruction; in every other case this is plain tanhf.
#if __CUDA_ARCH__ >= 750 && !defined(_WIN32)
  if (!FastMode) {
    return tanhf(x);
  }
  float approx_tanh;
  asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(approx_tanh) : "f"(x));
  return approx_tanh;
#else
  return tanhf(x);
#endif
}

// Tanh-approximated GELU in fp32:
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2)))
// where 0.79788456 ~= sqrt(2 / pi). FastMode picks the tanh implementation
// used by FP32FastTanh.
template <bool FastMode>
static __device__ __forceinline__ float FP32GeluFwd(float x) {
  const float inner = 0.79788456f * x * (1.0f + 0.044715f * x * x);
  const float t = FP32FastTanh<FastMode>(inner);
  return x * 0.5f * (1.0f + t);
}

// Gradient of the tanh-approximated GELU in fp32, multiplied by the upstream
// gradient y_g. With u = 0.79788456 * x * (1 + 0.044715 * x^2) and t = tanh(u):
//   d gelu / dx = 0.5 * x * (1 - t^2) * du/dx + 0.5 * (1 + t)
//   du/dx       = 0.79788456 + 0.1070322243 * x^2   (0.1070322243 ~= 3 * 0.79788456 * 0.044715)
template <bool FastMode>
static __device__ __forceinline__ float FP32GeluBwd(float x, float y_g) {
  const float t =
      FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
  const float sech_sq = 1.0f - t * t;
  const float du_dx = 0.79788456f + 0.1070322243f * x * x;
  const float local_grad = 0.5f * x * (sech_sq * du_dx) + 0.5f * (1.0f + t);
  return local_grad * y_g;
}

// Elementwise tanh-approximated GELU on fp16 data, VecSize halves at a time.
// Each element is widened to fp32, evaluated with FP32GeluFwd<FastMode>, and
// narrowed back to fp16.
//
// Preconditions (checked by TryLaunchFP16FastGeluFwdVectorizeCUDAKernel):
//   * n % VecSize == 0, and
//   * x and y are aligned to alignof(platform::AlignedVector<__half, VecSize>),
// because every access is a whole-vector load/store with no scalar tail.
// Launch layout: 1-D grid of 1-D blocks. Each thread handles vectors spaced
// blockDim.x * gridDim.x vectors apart, so a grid smaller than n / VecSize
// vectors still covers all n elements.
template <int VecSize, bool FastMode>
static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x, __half* y,
                                                 size_t n) {
  using ArrT = platform::AlignedVector<__half, VecSize>;
  // Widen to size_t BEFORE multiplying: in 32-bit unsigned arithmetic
  // blockIdx.x * blockDim.x and blockDim.x * gridDim.x can wrap for very
  // large n, yielding wrong (repeated or skipped) offsets.
  size_t offset =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x * VecSize;
  for (; offset < n; offset += stride) {
    ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      float tmp = __half2float(in_arr[i]);
      in_arr[i] = __float2half(FP32GeluFwd<FastMode>(tmp));
    }
    *reinterpret_cast<ArrT*>(y + offset) = in_arr;
  }
}

// Elementwise gradient of tanh-approximated GELU on fp16 data:
//   x_g[i] = gelu'(x[i]) * y_g[i]
// processed VecSize halves at a time. Math is done in fp32 via
// FP32GeluBwd<FastMode> and the result is narrowed back to fp16.
//
// Preconditions (checked by TryLaunchFP16FastGeluBwdVectorizeCUDAKernel):
//   * n % VecSize == 0, and
//   * x, y_g and x_g are aligned to
//     alignof(platform::AlignedVector<__half, VecSize>),
// because every access is a whole-vector load/store with no scalar tail.
// Launch layout: 1-D grid of 1-D blocks. Each thread handles vectors spaced
// blockDim.x * gridDim.x vectors apart.
template <int VecSize, bool FastMode>
static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
                                                 const __half* y_g, __half* x_g,
                                                 size_t n) {
  using ArrT = platform::AlignedVector<__half, VecSize>;
  // Widen to size_t BEFORE multiplying to avoid 32-bit wrap-around of
  // blockIdx.x * blockDim.x and blockDim.x * gridDim.x for very large n.
  size_t offset =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x * VecSize;
  for (; offset < n; offset += stride) {
    ArrT x_in_arr = *reinterpret_cast<const ArrT*>(x + offset);
    ArrT y_g_in_arr = *reinterpret_cast<const ArrT*>(y_g + offset);
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      // Pack (x, y_g) into one __half2 so both widen in a single conversion.
      __half2 tmp_fp16_2;
      tmp_fp16_2.x = x_in_arr[i];
      tmp_fp16_2.y = y_g_in_arr[i];
      float2 tmp_fp32_2 = __half22float2(tmp_fp16_2);
      x_in_arr[i] =
          __float2half(FP32GeluBwd<FastMode>(tmp_fp32_2.x, tmp_fp32_2.y));
    }
    *reinterpret_cast<ArrT*>(x_g + offset) = x_in_arr;
  }
}

// Tries to compute y = gelu_tanh(x) for n fp16 elements with the vectorized
// FP16FastGeluFwdCUDAKernel on dev_ctx.stream().
//
// Returns true when the request was handled here: either the kernel was
// launched, or n == 0 and there is nothing to compute. Returns false when n is
// not a multiple of the vector width or x / y are not aligned to the vector
// type; the caller must then fall back to the generic elementwise path.
// FLAGS_use_fast_math selects the tanh.approx-based instantiation. The grid is
// capped at the device's maximum x grid dimension; the kernel's strided loop
// covers any remainder.
static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
    const platform::CUDADeviceContext& dev_ctx, const __half* x, __half* y,
    size_t n) {
  // Empty tensor: nothing to do. Without this guard the grid computed below
  // would have 0 blocks, which CUDA rejects as an invalid launch
  // configuration.
  if (n == 0) {
    return true;
  }

  auto is_aligned = [](const void* p, size_t alignment) {
    return reinterpret_cast<uintptr_t>(p) % alignment == 0;
  };

#define PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(__vec_size, __use_fast_math)      \
  do {                                                                        \
    constexpr auto kAlignment =                                               \
        alignof(platform::AlignedVector<__half, __vec_size>);                 \
    if (n % __vec_size == 0 && is_aligned(x, kAlignment) &&                   \
        is_aligned(y, kAlignment)) {                                          \
      size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
      size_t block = (n / __vec_size + thread - 1) / thread;                  \
      block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x);     \
      VLOG(10) << "Use FP16 fast gelu fwd kernel, block = " << block          \
               << " , thread = " << thread;                                   \
      FP16FastGeluFwdCUDAKernel<                                              \
          __vec_size,                                                         \
          __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y, n);  \
      return true;                                                            \
    }                                                                         \
  } while (0)

  if (FLAGS_use_fast_math) {
    PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, true);
  } else {
    PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, false);
  }

#undef PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL
  return false;
}

// Tries to compute x_g = gelu_tanh'(x) * y_g for n fp16 elements with the
// vectorized FP16FastGeluBwdCUDAKernel on dev_ctx.stream().
//
// Returns true when the request was handled here: either the kernel was
// launched, or n == 0 and there is nothing to compute. Returns false when n is
// not a multiple of the vector width or any of x / y_g / x_g is not aligned to
// the vector type; the caller must then fall back to the generic elementwise
// path. FLAGS_use_fast_math selects the tanh.approx-based instantiation. The
// grid is capped at the device's maximum x grid dimension; the kernel's
// strided loop covers any remainder.
static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
    const platform::CUDADeviceContext& dev_ctx, const __half* x,
    const __half* y_g, __half* x_g, size_t n) {
  // Empty tensor: nothing to do. Without this guard the grid computed below
  // would have 0 blocks, which CUDA rejects as an invalid launch
  // configuration.
  if (n == 0) {
    return true;
  }

  auto is_aligned = [](const void* p, size_t alignment) {
    return reinterpret_cast<uintptr_t>(p) % alignment == 0;
  };

  // Every pointer the kernel touches must be vector-aligned (each checked
  // exactly once).
#define PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(__vec_size, __use_fast_math)      \
  do {                                                                        \
    constexpr auto kAlignment =                                               \
        alignof(platform::AlignedVector<__half, __vec_size>);                 \
    if (n % __vec_size == 0 && is_aligned(x, kAlignment) &&                   \
        is_aligned(y_g, kAlignment) && is_aligned(x_g, kAlignment)) {         \
      size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
      size_t block = (n / __vec_size + thread - 1) / thread;                  \
      block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x);     \
      VLOG(10) << "Use FP16 fast gelu bwd kernel, block = " << block          \
               << " , thread = " << thread;                                   \
      FP16FastGeluBwdCUDAKernel<                                              \
          __vec_size,                                                         \
          __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y_g,    \
                                                                   x_g, n);   \
      return true;                                                            \
    }                                                                         \
  } while (0)

  if (FLAGS_use_fast_math) {
    PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, true);
  } else {
    PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, false);
  }

#undef PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL
  return false;
}
#endif

169 170 171 172 173 174 175
template <typename T>
struct GeluWithApproximateFunctor {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // Tanh approximation of GELU:
  //   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + GELU_CONSTANT * x^2)))
  // The arithmetic is done in MPType (details::MPTypeTrait<T>::Type) and the
  // result is cast back to T.
  inline HOSTDEVICE T operator()(T arg_x) {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType kOne = static_cast<MPType>(1);
    const MPType kHalf = static_cast<MPType>(0.5);
    // (2 / sqrt(pi)) * (1 / sqrt(2)) == sqrt(2 / pi)
    const MPType kSqrt2OverPi = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
    const MPType kCubicCoeff = static_cast<MPType>(GELU_CONSTANT);
    const MPType inner = kSqrt2OverPi * x * (kOne + kCubicCoeff * x * x);
    auto t = tanh(inner);
    const MPType result = x * kHalf * (kOne + t);
    return static_cast<T>(result);
  }
};

template <typename T>
struct GeluWithoutApproximateFunctor {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // Exact GELU: gelu(x) = x * Phi(x), where Phi is the standard normal CDF
  // (normcdf). Evaluated in MPType and cast back to T.
  inline HOSTDEVICE T operator()(T arg_x) {
    const MPType x = static_cast<MPType>(arg_x);
    const auto phi = normcdf(x);
    return static_cast<T>(x * phi);
  }
};

template <typename T>
class GeluKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Computes Out = gelu(X) on the CUDA device.
  //
  // With the "approximate" attribute set, the tanh approximation is used. For
  // float16 tensors in nvcc builds, the vectorized fast kernel is tried first,
  // and the generic elementwise launch is the fallback. Without "approximate",
  // the exact form x * normcdf(x) is used.
  void Compute(const framework::ExecutionContext& context) const override {
    auto* out = context.Output<framework::Tensor>("Out");
    auto* in = context.Input<framework::Tensor>("X");
    auto approximate = context.Attr<bool>("approximate");
    out->mutable_data<T>(in->place());

    std::vector<const framework::Tensor*> ins = {in};
    std::vector<framework::Tensor*> outs = {out};
    const auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();

    // GELU is unary: one input tensor and a one-argument functor, so the
    // elementwise launch is tagged kUnary to match.
    if (approximate) {
#ifdef __NVCC__
      if (std::is_same<T, platform::float16>::value) {
        size_t n = in->numel();
        const auto* in_ptr = reinterpret_cast<const __half*>(in->data<T>());
        auto* out_ptr = reinterpret_cast<__half*>(out->data<T>());
        if (TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(dev_ctx, in_ptr,
                                                        out_ptr, n)) {
          return;
        }
      }
#endif
      paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kUnary,
                                                     T, T>(
          dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
    } else {
      paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kUnary,
                                                     T, T>(
          dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
    }
  }
};

233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
template <typename T>
struct GeluWithApproximateGradFunctor {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // Gradient of the tanh-approximated GELU, scaled by the upstream gradient.
  // With u = sqrt(2/pi) * (x + GELU_CONSTANT * x^3) and t = tanh(u):
  //   d gelu / dx = 0.5 * (1 + t + (1 - t^2) * x * du/dx)
  //   x * du/dx   = sqrt(2/pi) * x + 3 * sqrt(2/pi) * GELU_CONSTANT * x^3
  // Evaluated in MPType; returns dout * d gelu / dx cast back to T.
  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType kOne = static_cast<MPType>(1);
    const MPType kHalf = static_cast<MPType>(0.5);
    const MPType kSqrt2OverPi = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
    const MPType kCubicGrad =
        kSqrt2OverPi * static_cast<MPType>(GELU_CONSTANT) * static_cast<MPType>(3);
    const auto x_cubed = x * x * x;
    const auto t = tanh(
        kSqrt2OverPi * ((static_cast<MPType>(GELU_CONSTANT) * x_cubed) + x));
    const auto sech_sq = kOne - t * t;
    const auto x_du_dx = x * kSqrt2OverPi + kCubicGrad * x_cubed;
    const auto dgelu_dx = kHalf * (kOne + t + sech_sq * x_du_dx);
    return static_cast<T>(dgelu_dx * dout);
  }
};

template <typename T>
struct GeluWithoutApproximateGradFunctor {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // Gradient of exact GELU, scaled by the upstream gradient:
  //   d/dx [x * Phi(x)] = Phi(x) + x * phi(x),
  //   phi(x) = exp(-x^2 / 2) / sqrt(2 * pi)
  // Evaluated in MPType and cast back to T.
  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType dout = static_cast<MPType>(arg_dout);
    // (2 / sqrt(pi)) * (1 / sqrt(2)) * 0.5 == 1 / sqrt(2 * pi)
    constexpr MPType kInvSqrt2Pi =
        M_2_SQRTPI * M_SQRT1_2 * static_cast<MPType>(0.5);
    const MPType normal_cdf = normcdf(x);
    const MPType normal_pdf =
        exp(static_cast<MPType>(-0.5) * x * x) * kInvSqrt2Pi;
    const MPType dgelu_dx = normal_cdf + x * normal_pdf;
    return static_cast<T>(dout * dgelu_dx);
  }
};

template <typename T>
class GeluGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Computes X@GRAD = gelu'(X) * Out@GRAD on the CUDA device.
  //
  // Without the "approximate" attribute, the exact derivative
  // normcdf(x) + x * pdf(x) is used. With it, the tanh-approximation
  // derivative is used. For float16 tensors in nvcc builds, the vectorized
  // fast kernel is tried first, and the generic elementwise launch is the
  // fallback.
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<framework::Tensor>("X");
    auto* dout =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
    const bool approximate = context.Attr<bool>("approximate");
    dx->mutable_data<T>(dout->place());

    const auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    std::vector<const framework::Tensor*> ins = {x, dout};
    std::vector<framework::Tensor*> outs = {dx};

    if (!approximate) {
      paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
                                                     T, T>(
          dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
      return;
    }

#ifdef __NVCC__
    if (std::is_same<T, platform::float16>::value) {
      const size_t numel = x->numel();
      const auto* x_half = reinterpret_cast<const __half*>(x->data<T>());
      const auto* dout_half = reinterpret_cast<const __half*>(dout->data<T>());
      auto* dx_half = reinterpret_cast<__half*>(dx->data<T>());
      if (TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(dev_ctx, x_half,
                                                      dout_half, dx_half,
                                                      numel)) {
        return;
      }
    }
#endif
    paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
                                                   T, T>(
        dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
  }
};

307 308 309
}  // namespace operators
}  // namespace paddle

310 311 312 313 314 315 316 317 318 319 320
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// CUDA registrations of gelu and gelu_grad for float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    gelu, ops::GeluKernel<plat::CUDADeviceContext, float>,
    ops::GeluKernel<plat::CUDADeviceContext, double>,
    ops::GeluKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    gelu_grad, ops::GeluGradKernel<plat::CUDADeviceContext, float>,
    ops::GeluGradKernel<plat::CUDADeviceContext, double>,
    ops::GeluGradKernel<plat::CUDADeviceContext, plat::float16>);