未验证 提交 b0e28540 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

Optimize FusedBiasAddGelu Kernel (#47679)

* Add quick gelu and fused bias add kernel

* fix annotation

* remove useless code

* add fast gelu option and set it in multi transformer op

* add flag to restrict if use fast gelu approximate

* fix flags conflict

* fix use tanh function instead

* add cudart version limit

* use phi fast tanh func

* fix comment
上级 27ee6e71
......@@ -18,13 +18,11 @@ limitations under the License. */
#endif
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
namespace paddle {
namespace operators {
/**
*@brief the gelu functor
*/
template <typename T>
struct GeluFunctor {
inline __host__ __device__ T operator()(const T x) const {
......@@ -36,6 +34,13 @@ struct GeluFunctor {
}
};
template <typename T>
struct FastGeluFunctor {
inline __device__ T operator()(const T x) const {
return phi::GeluFwd<T, true>(x);
}
};
/**
*@brief the gelu grad functor
*/
......@@ -131,6 +136,49 @@ __global__ void FusedDropoutActBias(
}
}
template <typename T,
int VecSize,
typename Functor,
typename InType = T,
typename OutType = T>
__global__ void FusedActBias(Functor act,
const uint64_t elem_cnt,
const uint64_t cols,
const InType *__restrict__ src,
const T *__restrict__ bias,
OutType *dst) {
const int32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = phi::AlignedVector<T, VecSize>;
using LoadInType = phi::AlignedVector<InType, VecSize>;
using LoadFloat = phi::AlignedVector<float, VecSize>;
using StoreOutType = phi::AlignedVector<OutType, VecSize>;
LoadInType src_vec;
LoadT bias_vec;
StoreOutType out_vec;
for (int32_t idx = global_thread_idx * VecSize,
step = blockDim.x * gridDim.x * VecSize;
idx < elem_cnt;
idx += step) {
const int32_t col_idx = idx % cols;
phi::Load<InType, VecSize>(&src[idx], &src_vec);
if (bias) {
phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
}
#pragma unroll
for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
if (bias) {
out_vec[unroll_idx] = static_cast<OutType>(
act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx]));
} else {
out_vec[unroll_idx] =
static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx])));
}
}
phi::Store<OutType, VecSize>(out_vec, &dst[idx]);
}
}
/**
* @brief dst = dropout(activation(src + bias));
*/
......@@ -170,6 +218,18 @@ void LaunchDropoutActBias(Functor act_functor,
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
if (is_test && (dequant_out_scale_data == nullptr)) {
const int32_t elem_cnt = rows * cols;
const int32_t pack_num = elem_cnt / VecSize;
const int32_t tmp_cols = cols / VecSize;
int block_size =
std::max(static_cast<int32_t>(32), std::min(tmp_cols, 128));
const int grid_size = std::max(static_cast<int32_t>(1),
(pack_num + block_size - 1) / block_size);
FusedActBias<T, VecSize, Functor, InType, OutType>
<<<grid_size, block_size, 0, ctx.stream()>>>(
act_functor, elem_cnt, cols, src, bias, dst);
} else {
FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor,
......@@ -188,6 +248,7 @@ void LaunchDropoutActBias(Functor act_functor,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale);
}
} else {
FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
......
......@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/funcs/functors.h"
DECLARE_bool(use_fast_math);
namespace paddle {
namespace operators {
......@@ -216,6 +218,30 @@ class FusedDropoutHelper {
const float quant_min_bound = -127.0) {
auto increment = GetIncrement(ctx);
if (act_method == "gelu") {
if (FLAGS_use_fast_math) {
FastGeluFunctor<T> fast_gelu;
LaunchDropoutActBias<T, MaskType, FastGeluFunctor<T>, InType, OutType>(
fast_gelu,
dropout_param_.seed,
rows_,
cols_,
dropout_param_.increment,
dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train,
dropout_param_.is_test,
src,
bias,
out,
mask,
ctx,
quant_last_in_scale,
dequant_out_scale_data,
quant_out_scale_offset,
quant_next_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
GeluFunctor<T> gelu;
LaunchDropoutActBias<T, MaskType, GeluFunctor<T>, InType, OutType>(
gelu,
......@@ -238,6 +264,7 @@ class FusedDropoutHelper {
quant_round_type,
quant_max_bound,
quant_min_bound);
}
} else if (act_method == "relu") {
phi::funcs::ReluFunctor<T> relu;
LaunchDropoutActBias<T,
......
......@@ -37,11 +37,12 @@ static __device__ __forceinline__ float FP32FastTanh(float x) {
return tanhf(x);
}
template <bool FastMode>
static __device__ __forceinline__ float FP32GeluFwd(float x) {
auto tanh_out =
FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
return x * 0.5f * (1.0f + tanh_out);
template <typename T, bool FastMode>
static __device__ __forceinline__ T GeluFwd(T x) {
const float cast_x = static_cast<float>(x);
auto tanh_out = FP32FastTanh<FastMode>(0.79788456f * cast_x *
(1.0f + 0.044715f * cast_x * cast_x));
return static_cast<T>(cast_x * 0.5f * (1.0f + tanh_out));
}
template <bool FastMode>
......@@ -67,8 +68,7 @@ static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x,
ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
float tmp = __half2float(in_arr[i]);
in_arr[i] = __float2half(FP32GeluFwd<FastMode>(tmp));
in_arr[i] = GeluFwd<half, FastMode>(in_arr[i]);
}
*reinterpret_cast<ArrT*>(y + offset) = in_arr;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册