diff --git a/paddle/fluid/operators/gelu_op.cu b/paddle/fluid/operators/gelu_op.cu
index 8151d21fa676d2f4715b5f446fcd7b6152dbb29d..2f72fbff2668b7b9f8a78c086d88270164d04b79 100644
--- a/paddle/fluid/operators/gelu_op.cu
+++ b/paddle/fluid/operators/gelu_op.cu
@@ -16,9 +16,156 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/gelu_op.h"
 
+DECLARE_bool(use_fast_math);
+
 namespace paddle {
 namespace operators {
 
+#ifdef __NVCC__
+template <bool FastMode>
+static __device__ __forceinline__ float FP32FastTanh(float x) {
+#if __CUDA_ARCH__ >= 750 && !defined(_WIN32)
+  if (FastMode) {
+    float y;
+    asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(y) : "f"(x));
+    return y;
+  }
+#endif
+  return tanhf(x);
+}
+
+template <bool FastMode>
+static __device__ __forceinline__ float FP32GeluFwd(float x) {
+  auto tanh_out =
+      FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
+  return x * 0.5f * (1.0f + tanh_out);
+}
+
+template <bool FastMode>
+static __device__ __forceinline__ float FP32GeluBwd(float x, float y_g) {
+  auto tanh_out =
+      FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
+  auto tmp = 0.5f * x * ((1.0f - tanh_out * tanh_out) *
+                         (0.79788456f + 0.1070322243f * x * x)) +
+             0.5f * (1.0f + tanh_out);
+  return tmp * y_g;
+}
+
+template <int VecSize, bool FastMode>
+static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x, __half* y,
+                                                 size_t n) {
+  size_t offset =
+      static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
+  size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
+  for (; offset < n; offset += stride) {
+    using ArrT = platform::AlignedVector<__half, VecSize>;
+    ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      float tmp = __half2float(in_arr[i]);
+      in_arr[i] = __float2half(FP32GeluFwd<FastMode>(tmp));
+    }
+    *reinterpret_cast<ArrT*>(y + offset) = in_arr;
+  }
+}
+
+template <int VecSize, bool FastMode>
+static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
+                                                 const __half* y_g, __half* x_g,
+                                                 size_t n) {
+  size_t offset =
+      static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
+  size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
+  for (; offset < n; offset += stride) {
+    using ArrT = platform::AlignedVector<__half, VecSize>;
+    ArrT x_in_arr = *reinterpret_cast<const ArrT*>(x + offset);
+    ArrT y_g_in_arr = *reinterpret_cast<const ArrT*>(y_g + offset);
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      __half2 tmp_fp16_2;
+      tmp_fp16_2.x = x_in_arr[i];
+      tmp_fp16_2.y = y_g_in_arr[i];
+      float2 tmp_fp32_2 = __half22float2(tmp_fp16_2);
+      x_in_arr[i] =
+          __float2half(FP32GeluBwd<FastMode>(tmp_fp32_2.x, tmp_fp32_2.y));
+    }
+    *reinterpret_cast<ArrT*>(x_g + offset) = x_in_arr;
+  }
+}
+
+static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
+    const platform::CUDADeviceContext& dev_ctx, const __half* x, __half* y,
+    size_t n) {
+  auto is_aligned = [](const void* p, size_t alignment) {
+    return reinterpret_cast<uintptr_t>(p) % alignment == 0;
+  };
+
+#define PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(__vec_size, __use_fast_math)      \
+  do {                                                                        \
+    constexpr auto kAlignment =                                               \
+        alignof(platform::AlignedVector<__half, __vec_size>);                 \
+    if (n % __vec_size == 0 && is_aligned(x, kAlignment) &&                   \
+        is_aligned(y, kAlignment)) {                                          \
+      size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
+      size_t block = (n / __vec_size + thread - 1) / thread;                  \
+      block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x);     \
+      VLOG(10) << "Use FP16 fast gelu fwd kernel, block = " << block          \
+               << " , thread = " << thread;                                   \
+      FP16FastGeluFwdCUDAKernel<                                              \
+          __vec_size,                                                         \
+          __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y, n);  \
+      return true;                                                            \
+    }                                                                         \
+  } while (0)
+
+  if (FLAGS_use_fast_math) {
+    PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, true);
+  } else {
+    PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, false);
+  }
+
+#undef PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL
+  return false;
+}
+
+static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
+    const platform::CUDADeviceContext& dev_ctx, const __half* x,
+    const __half* y_g, __half* x_g, size_t n) {
+  auto is_aligned = [](const void* p, size_t alignment) {
+    return reinterpret_cast<uintptr_t>(p) % alignment == 0;
+  };
+
+#define PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(__vec_size, __use_fast_math)      \
+  do {                                                                        \
+    constexpr auto kAlignment =                                               \
+        alignof(platform::AlignedVector<__half, __vec_size>);                 \
+    if (n % __vec_size == 0 && is_aligned(x, kAlignment) &&                   \
+        is_aligned(x, kAlignment) && is_aligned(y_g, kAlignment) &&           \
+        is_aligned(x_g, kAlignment)) {                                        \
+      size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
+      size_t block = (n / __vec_size + thread - 1) / thread;                  \
+      block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x);     \
+      VLOG(10) << "Use FP16 fast gelu bwd kernel, block = " << block          \
+               << " , thread = " << thread;                                   \
+      FP16FastGeluBwdCUDAKernel<                                              \
+          __vec_size,                                                         \
+          __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y_g,    \
+                                                                   x_g, n);   \
+      return true;                                                            \
+    }                                                                         \
+  } while (0)
+
+  if (FLAGS_use_fast_math) {
+    PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, true);
+  } else {
+    PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, false);
+  }
+
+#undef PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL
+  return false;
+}
+#endif
+
 template <typename T>
 struct GeluWithApproximateFunctor {
   using MPType = typename details::MPTypeTrait<T>::Type;
@@ -59,7 +206,19 @@ class GeluKernel<platform::CUDADeviceContext, T>
     std::vector<framework::Tensor*> outs = {out};
     const auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
+
     if (approximate) {
+#ifdef __NVCC__
+      if (std::is_same<T, platform::float16>::value) {
+        size_t n = in->numel();
+        const auto* in_ptr = reinterpret_cast<const __half*>(in->data<T>());
+        auto* out_ptr = reinterpret_cast<__half*>(out->data<T>());
+        if (TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(dev_ctx, in_ptr,
+                                                        out_ptr, n)) {
+          return;
+        }
+      }
+#endif
       LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
           dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
     } else {
@@ -120,6 +279,18 @@ class GeluGradKernel<platform::CUDADeviceContext, T>
     const auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     if (approximate) {
+#ifdef __NVCC__
+      if (std::is_same<T, platform::float16>::value) {
+        size_t n = x->numel();
+        const auto* x_ptr = reinterpret_cast<const __half*>(x->data<T>());
+        const auto* y_g_ptr = reinterpret_cast<const __half*>(dout->data<T>());
+        auto* x_g_ptr = reinterpret_cast<__half*>(dx->data<T>());
+        if (TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(dev_ctx, x_ptr, y_g_ptr,
+                                                        x_g_ptr, n)) {
+          return;
+        }
+      }
+#endif
       LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
           dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
     } else {
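The constants hard-coded in FP32GeluFwd and FP32GeluBwd come from the tanh form of GELU, gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))): 0.79788456 is sqrt(2/pi), and 0.1070322243 is 3 * 0.044715 * sqrt(2/pi), which falls out of differentiating the tanh argument. A minimal NumPy sketch, not part of the patch, that mirrors the two device functions and checks the backward formula against a central finite difference:

    import numpy as np

    C = np.sqrt(2.0 / np.pi)   # 0.79788456..., the sqrt(2/pi) constant
    K = 3.0 * 0.044715 * C     # 0.1070322243..., from differentiating the cubic term

    def gelu_fwd(x):
        # Mirrors FP32GeluFwd
        t = np.tanh(C * x * (1.0 + 0.044715 * x * x))
        return 0.5 * x * (1.0 + t)

    def gelu_bwd(x, y_g):
        # Mirrors FP32GeluBwd: (d gelu / dx) scaled by the upstream gradient y_g
        t = np.tanh(C * x * (1.0 + 0.044715 * x * x))
        return (0.5 * x * (1.0 - t * t) * (C + K * x * x) + 0.5 * (1.0 + t)) * y_g

    x = np.linspace(-3.0, 3.0, 101)
    eps = 1e-6
    num_grad = (gelu_fwd(x + eps) - gelu_fwd(x - eps)) / (2.0 * eps)
    assert np.allclose(gelu_bwd(x, np.ones_like(x)), num_grad, atol=1e-5)

FastMode only swaps the tanhf call for the tanh.approx.f32 PTX instruction, and only on sm_75+ non-Windows builds; all other arithmetic is identical between the two modes.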
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index 8b117a5a8292d598d6e2ac7504277e13799f0c6c..44bd4eaa29b807214912f45401b962d089d02348 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -652,6 +652,9 @@ PADDLE_DEFINE_EXPORTED_bool(
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PADDLE_DEFINE_EXPORTED_bool(conv2d_disable_cudnn, false,
                             "Disable cudnn in conv2d");
+
+PADDLE_DEFINE_EXPORTED_bool(use_fast_math, false,
+                            "Whether to use fast math GPU functions.");
 #endif
 
 /**
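Because the flag is defined with PADDLE_DEFINE_EXPORTED_bool, it is visible from Python and can be toggled at runtime. A usage sketch, not part of the patch, showing how a caller opts in (Paddle flags can typically also be set through a FLAGS_use_fast_math environment variable before startup):

    import paddle

    # Off by default; when enabled on sm_75+ GPUs, FP16 gelu(approximate=True)
    # takes the tanh.approx.f32 fast path instead of calling tanhf.
    paddle.set_flags({'FLAGS_use_fast_math': True})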
diff --git a/python/paddle/fluid/tests/unittests/test_gelu_op.py b/python/paddle/fluid/tests/unittests/test_gelu_op.py
index 13174edad8f17412ac19befbc984e7c699ef6f8d..de34b63c9398e70ea97adf61fbee0183d8dc6468 100644
--- a/python/paddle/fluid/tests/unittests/test_gelu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gelu_op.py
@@ -19,6 +19,8 @@ import numpy as np
 from scipy.special import erf
 import paddle.fluid as fluid
 import paddle.fluid.dygraph as dg
+import paddle
+import paddle.nn.functional as F
 
 
 def gelu(x, approximate):
@@ -59,6 +61,36 @@ class TestGeluOp(unittest.TestCase):
         if fluid.is_compiled_with_cuda():
             self._test_case1_gpu(approximate)
 
+    def test_fast_math(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        def use_fast_math(enabled):
+            paddle.set_flags({'FLAGS_use_fast_math': enabled})
+
+        shape = [11, 17, 8]
+        x_np = np.random.uniform(-1, 1, size=shape).astype(np.float16)
+        y_g_np = np.random.uniform(-1, 1, size=shape).astype(np.float16)
+
+        def run_gelu_op(approximate):
+            with dg.guard():
+                x = paddle.to_tensor(x_np)
+                x.stop_gradient = False
+                y = F.gelu(x, approximate=approximate)
+                x_grad = paddle.grad([y], [x], [paddle.to_tensor(y_g_np)])[0]
+                return y.numpy(), x_grad.numpy()
+
+        use_fast_math(True)
+        y_fast_math, x_g_fast_math = run_gelu_op(True)
+        use_fast_math(False)
+
+        y_ref, x_g_ref = run_gelu_op(True)
+        self.assertTrue(np.allclose(y_ref, y_fast_math, rtol=1e-5, atol=5e-4))
+
+        self.assertTrue(
+            np.allclose(
+                x_g_ref, x_g_fast_math, rtol=1e-5, atol=5e-4))
+
 
 if __name__ == '__main__':
     unittest.main()
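The test toggles the flag around two runs of the same approximate GELU and compares them; the loose atol=5e-4 absorbs the FP16-level difference between tanh.approx.f32 and tanhf while still catching real regressions. The shape [11, 17, 8] keeps numel (1496) divisible by the vector size 8, so the vectorized kernels are actually exercised. A standalone sketch, not part of the patch and assuming a CUDA build with an sm_75+ GPU, that drives the same code path outside the unit test:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.set_flags({'FLAGS_use_fast_math': True})

    x = paddle.to_tensor(
        np.random.uniform(-1, 1, size=[11, 17, 8]).astype(np.float16))
    x.stop_gradient = False
    y = F.gelu(x, approximate=True)  # FP16 + approximate => vectorized fwd kernel
    x_grad = paddle.grad([y], [x], [paddle.ones_like(y)])[0]  # bwd kernel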