未验证 提交 8c20d668 编写于 作者: S sneaxiy 提交者: GitHub

Speedup FP16 Gelu op using fast math and vectorized 8 kernel (#38980)

* speedup gelu using fast math

* add bwd part
上级 55e9087f
......@@ -16,9 +16,156 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gelu_op.h"
DECLARE_bool(use_fast_math);
namespace paddle {
namespace operators {
#ifdef __NVCC__
template <bool FastMode>
static __device__ __forceinline__ float FP32FastTanh(float x) {
#if __CUDA_ARCH__ >= 750 && !defined(_WIN32)
if (FastMode) {
float y;
asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(y) : "f"(x));
return y;
}
#endif
return tanhf(x);
}
template <bool FastMode>
static __device__ __forceinline__ float FP32GeluFwd(float x) {
auto tanh_out =
FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
return x * 0.5f * (1.0f + tanh_out);
}
template <bool FastMode>
static __device__ __forceinline__ float FP32GeluBwd(float x, float y_g) {
auto tanh_out =
FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
auto tmp = 0.5f * x * ((1.0f - tanh_out * tanh_out) *
(0.79788456f + 0.1070322243f * x * x)) +
0.5f * (1.0f + tanh_out);
return tmp * y_g;
}
template <int VecSize, bool FastMode>
static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x, __half* y,
size_t n) {
size_t offset =
static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
for (; offset < n; offset += stride) {
using ArrT = platform::AlignedVector<__half, VecSize>;
ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
float tmp = __half2float(in_arr[i]);
in_arr[i] = __float2half(FP32GeluFwd<FastMode>(tmp));
}
*reinterpret_cast<ArrT*>(y + offset) = in_arr;
}
}
template <int VecSize, bool FastMode>
static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
const __half* y_g, __half* x_g,
size_t n) {
size_t offset =
static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
for (; offset < n; offset += stride) {
using ArrT = platform::AlignedVector<__half, VecSize>;
ArrT x_in_arr = *reinterpret_cast<const ArrT*>(x + offset);
ArrT y_g_in_arr = *reinterpret_cast<const ArrT*>(y_g + offset);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
__half2 tmp_fp16_2;
tmp_fp16_2.x = x_in_arr[i];
tmp_fp16_2.y = y_g_in_arr[i];
float2 tmp_fp32_2 = __half22float2(tmp_fp16_2);
x_in_arr[i] =
__float2half(FP32GeluBwd<FastMode>(tmp_fp32_2.x, tmp_fp32_2.y));
}
*reinterpret_cast<ArrT*>(x_g + offset) = x_in_arr;
}
}
static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
const platform::CUDADeviceContext& dev_ctx, const __half* x, __half* y,
size_t n) {
auto is_aligned = [](const void* p, size_t alignment) {
return reinterpret_cast<uintptr_t>(p) % alignment == 0;
};
#define PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(__vec_size, __use_fast_math) \
do { \
constexpr auto kAlignment = \
alignof(platform::AlignedVector<__half, __vec_size>); \
if (n % __vec_size == 0 && is_aligned(x, kAlignment) && \
is_aligned(y, kAlignment)) { \
size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
size_t block = (n / __vec_size + thread - 1) / thread; \
block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x); \
VLOG(10) << "Use FP16 fast gelu fwd kernel, block = " << block \
<< " , thread = " << thread; \
FP16FastGeluFwdCUDAKernel< \
__vec_size, \
__use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y, n); \
return true; \
} \
} while (0)
if (FLAGS_use_fast_math) {
PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, true);
} else {
PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, false);
}
#undef PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL
return false;
}
static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
const platform::CUDADeviceContext& dev_ctx, const __half* x,
const __half* y_g, __half* x_g, size_t n) {
auto is_aligned = [](const void* p, size_t alignment) {
return reinterpret_cast<uintptr_t>(p) % alignment == 0;
};
#define PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(__vec_size, __use_fast_math) \
do { \
constexpr auto kAlignment = \
alignof(platform::AlignedVector<__half, __vec_size>); \
if (n % __vec_size == 0 && is_aligned(x, kAlignment) && \
is_aligned(x, kAlignment) && is_aligned(y_g, kAlignment) && \
is_aligned(x_g, kAlignment)) { \
size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
size_t block = (n / __vec_size + thread - 1) / thread; \
block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x); \
VLOG(10) << "Use FP16 fast gelu bwd kernel, block = " << block \
<< " , thread = " << thread; \
FP16FastGeluBwdCUDAKernel< \
__vec_size, \
__use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y_g, \
x_g, n); \
return true; \
} \
} while (0)
if (FLAGS_use_fast_math) {
PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, true);
} else {
PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, false);
}
#undef PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL
return false;
}
#endif
template <typename T>
struct GeluWithApproximateFunctor {
using MPType = typename details::MPTypeTrait<T>::Type;
......@@ -59,7 +206,19 @@ class GeluKernel<platform::CUDADeviceContext, T>
std::vector<framework::Tensor*> outs = {out};
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
if (approximate) {
#ifdef __NVCC__
if (std::is_same<T, platform::float16>::value) {
size_t n = in->numel();
const auto* in_ptr = reinterpret_cast<const __half*>(in->data<T>());
auto* out_ptr = reinterpret_cast<__half*>(out->data<T>());
if (TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(dev_ctx, in_ptr,
out_ptr, n)) {
return;
}
}
#endif
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
} else {
......@@ -120,6 +279,18 @@ class GeluGradKernel<platform::CUDADeviceContext, T>
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
if (approximate) {
#ifdef __NVCC__
if (std::is_same<T, platform::float16>::value) {
size_t n = x->numel();
const auto* x_ptr = reinterpret_cast<const __half*>(x->data<T>());
const auto* y_g_ptr = reinterpret_cast<const __half*>(dout->data<T>());
auto* x_g_ptr = reinterpret_cast<__half*>(dx->data<T>());
if (TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(dev_ctx, x_ptr, y_g_ptr,
x_g_ptr, n)) {
return;
}
}
#endif
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
} else {
......
......@@ -652,6 +652,9 @@ PADDLE_DEFINE_EXPORTED_bool(
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_DEFINE_EXPORTED_bool(conv2d_disable_cudnn, false,
"Disable cudnn in conv2d");
PADDLE_DEFINE_EXPORTED_bool(use_fast_math, false,
"Whether to use fast math GPU functions.");
#endif
/**
......
......@@ -19,6 +19,8 @@ import numpy as np
from scipy.special import erf
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle
import paddle.nn.functional as F
def gelu(x, approximate):
......@@ -59,6 +61,36 @@ class TestGeluOp(unittest.TestCase):
if fluid.is_compiled_with_cuda():
self._test_case1_gpu(approximate)
def test_fast_math(self):
if not paddle.is_compiled_with_cuda():
return
def use_fast_math(enabled):
paddle.set_flags({'FLAGS_use_fast_math': enabled})
shape = [11, 17, 8]
x_np = np.random.uniform(-1, 1, size=shape).astype(np.float16)
y_g_np = np.random.uniform(-1, 1, size=shape).astype(np.float16)
def run_gelu_op(approximate):
with dg.guard():
x = paddle.to_tensor(x_np)
x.stop_gradient = False
y = F.gelu(x, approximate=approximate)
x_grad = paddle.grad([y], [x], [paddle.to_tensor(y_g_np)])[0]
return y.numpy(), x_grad.numpy()
use_fast_math(True)
y_fast_math, x_g_fast_math = run_gelu_op(True)
use_fast_math(False)
y_ref, x_g_ref = run_gelu_op(True)
self.assertTrue(np.allclose(y_ref, y_fast_math, rtol=1e-5, atol=5e-4))
self.assertTrue(
np.allclose(
x_g_ref, x_g_fast_math, rtol=1e-5, atol=5e-4))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册