diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h
index e3e19d9ea6ebcbea48b83a54b0edb817cbec4f8c..553fb8d7be604289a15cad14528512146150a7c8 100644
--- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h
+++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h
@@ -18,13 +18,11 @@ limitations under the License. */
 #endif
 
 #include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
+#include "paddle/phi/kernels/gpu/gelu_funcs.h"
 
 namespace paddle {
 namespace operators {
 
-/**
- *@brief the gelu functor
- */
 template <typename T>
 struct GeluFunctor {
   inline __host__ __device__ T operator()(const T x) const {
@@ -36,6 +34,13 @@ struct GeluFunctor {
   }
 };
 
+template <typename T>
+struct FastGeluFunctor {
+  inline __device__ T operator()(const T x) const {
+    return phi::GeluFwd<T, true>(x);
+  }
+};
+
 /**
  *@brief the gelu grad functor
  */
@@ -131,6 +136,49 @@ __global__ void FusedDropoutActBias(
   }
 }
 
+template <typename T,
+          int VecSize,
+          typename Functor,
+          typename InType = T,
+          typename OutType = T>
+__global__ void FusedActBias(Functor act,
+                             const uint64_t elem_cnt,
+                             const uint64_t cols,
+                             const InType *__restrict__ src,
+                             const T *__restrict__ bias,
+                             OutType *dst) {
+  const int32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  using LoadT = phi::AlignedVector<T, VecSize>;
+  using LoadInType = phi::AlignedVector<InType, VecSize>;
+  using LoadFloat = phi::AlignedVector<float, VecSize>;
+  using StoreOutType = phi::AlignedVector<OutType, VecSize>;
+
+  LoadInType src_vec;
+  LoadT bias_vec;
+  StoreOutType out_vec;
+  for (int32_t idx = global_thread_idx * VecSize,
+               step = blockDim.x * gridDim.x * VecSize;
+       idx < elem_cnt;
+       idx += step) {
+    const int32_t col_idx = idx % cols;
+    phi::Load<InType, VecSize>(&src[idx], &src_vec);
+    if (bias) {
+      phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
+    }
+#pragma unroll
+    for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
+      if (bias) {
+        out_vec[unroll_idx] = static_cast<OutType>(
+            act(static_cast<T>(src_vec[unroll_idx]) + bias_vec[unroll_idx]));
+      } else {
+        out_vec[unroll_idx] =
+            static_cast<OutType>(act(static_cast<T>(src_vec[unroll_idx])));
+      }
+    }
+    phi::Store<OutType, VecSize>(out_vec, &dst[idx]);
+  }
+}
+
 /**
  * @brief dst = dropout(activation(src + bias));
  */
@@ -170,24 +218,37 @@ void LaunchDropoutActBias(Functor act_functor,
   const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
   const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
   if (cols % VecSize == 0) {
-    FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
-        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
-            act_functor,
-            seed,
-            rows,
-            cols,
-            increment,
-            dropout_prob,
-            is_upscale_in_train,
-            is_test,
-            src,
-            bias,
-            dst,
-            mask_data,
-            quant_last_in_scale,
-            dequant_out_scale_data,
-            quant_out_scale_offset,
-            quant_next_in_scale);
+    if (is_test && (dequant_out_scale_data == nullptr)) {
+      const int32_t elem_cnt = rows * cols;
+      const int32_t pack_num = elem_cnt / VecSize;
+      const int32_t tmp_cols = cols / VecSize;
+      int block_size =
+          std::max(static_cast<int32_t>(32), std::min(tmp_cols, 128));
+      const int grid_size = std::max(static_cast<int32_t>(1),
+                                     (pack_num + block_size - 1) / block_size);
+      FusedActBias<T, VecSize, Functor, InType, OutType>
+          <<<grid_size, block_size, 0, ctx.stream()>>>(
+              act_functor, elem_cnt, cols, src, bias, dst);
+    } else {
+      FusedDropoutActBias<T, MaskType, VecSize, Functor, InType, OutType>
+          <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+              act_functor,
+              seed,
+              rows,
+              cols,
+              increment,
+              dropout_prob,
+              is_upscale_in_train,
+              is_test,
+              src,
+              bias,
+              dst,
+              mask_data,
+              quant_last_in_scale,
+              dequant_out_scale_data,
+              quant_out_scale_offset,
+              quant_next_in_scale);
+    }
   } else {
     FusedDropoutActBias<T, MaskType, 1, Functor, InType, OutType>
         <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h
index 3230854284062dc528664bf8752acc358e2c1f3c..46c5f7c0e5f94138cfd3a8cdf04758786d2ce517 100644
--- a/paddle/fluid/operators/fused/fused_dropout_helper.h
+++ b/paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
 #include "paddle/phi/kernels/funcs/functors.h"
 
+DECLARE_bool(use_fast_math);
+
 namespace paddle {
 namespace operators {
 
@@ -216,28 +218,53 @@ class FusedDropoutHelper {
                          const float quant_min_bound = -127.0) {
     auto increment = GetIncrement(ctx);
     if (act_method == "gelu") {
-      GeluFunctor<T> gelu;
-      LaunchDropoutActBias<T, MaskType, GeluFunctor<T>, InType, OutType>(
-          gelu,
-          dropout_param_.seed,
-          rows_,
-          cols_,
-          dropout_param_.increment,
-          dropout_param_.dropout_prob,
-          dropout_param_.is_upscale_in_train,
-          dropout_param_.is_test,
-          src,
-          bias,
-          out,
-          mask,
-          ctx,
-          quant_last_in_scale,
-          dequant_out_scale_data,
-          quant_out_scale_offset,
-          quant_next_in_scale,
-          quant_round_type,
-          quant_max_bound,
-          quant_min_bound);
+      if (FLAGS_use_fast_math) {
+        FastGeluFunctor<T> fast_gelu;
+        LaunchDropoutActBias<T, MaskType, FastGeluFunctor<T>, InType, OutType>(
+            fast_gelu,
+            dropout_param_.seed,
+            rows_,
+            cols_,
+            dropout_param_.increment,
+            dropout_param_.dropout_prob,
+            dropout_param_.is_upscale_in_train,
+            dropout_param_.is_test,
+            src,
+            bias,
+            out,
+            mask,
+            ctx,
+            quant_last_in_scale,
+            dequant_out_scale_data,
+            quant_out_scale_offset,
+            quant_next_in_scale,
+            quant_round_type,
+            quant_max_bound,
+            quant_min_bound);
+      } else {
+        GeluFunctor<T> gelu;
+        LaunchDropoutActBias<T, MaskType, GeluFunctor<T>, InType, OutType>(
+            gelu,
+            dropout_param_.seed,
+            rows_,
+            cols_,
+            dropout_param_.increment,
+            dropout_param_.dropout_prob,
+            dropout_param_.is_upscale_in_train,
+            dropout_param_.is_test,
+            src,
+            bias,
+            out,
+            mask,
+            ctx,
+            quant_last_in_scale,
+            dequant_out_scale_data,
+            quant_out_scale_offset,
+            quant_next_in_scale,
+            quant_round_type,
+            quant_max_bound,
+            quant_min_bound);
+      }
     } else if (act_method == "relu") {
       phi::funcs::ReluFunctor<T> relu;
       LaunchDropoutActBias<T,
diff --git a/paddle/phi/kernels/gpu/gelu_funcs.h b/paddle/phi/kernels/gpu/gelu_funcs.h
--- a/paddle/phi/kernels/gpu/gelu_funcs.h
+++ b/paddle/phi/kernels/gpu/gelu_funcs.h
@@ -37,11 +37,12 @@ static __device__ __forceinline__ float FP32FastTanh(float x) {
   return tanhf(x);
 }
 
-template <bool FastMode>
-static __device__ __forceinline__ float FP32GeluFwd(float x) {
-  auto tanh_out =
-      FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
-  return x * 0.5f * (1.0f + tanh_out);
+template <typename T, bool FastMode>
+static __device__ __forceinline__ T GeluFwd(T x) {
+  const float cast_x = static_cast<float>(x);
+  auto tanh_out = FP32FastTanh<FastMode>(0.79788456f * cast_x *
+                                         (1.0f + 0.044715f * cast_x * cast_x));
+  return static_cast<T>(cast_x * 0.5f * (1.0f + tanh_out));
 }
 
 template <bool FastMode>
@@ -67,8 +68,7 @@ static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x,
     ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
 #pragma unroll
     for (int i = 0; i < VecSize; ++i) {
-      float tmp = __half2float(in_arr[i]);
-      in_arr[i] = __float2half(FP32GeluFwd<FastMode>(tmp));
+      in_arr[i] = GeluFwd<__half, FastMode>(in_arr[i]);
     }
     *reinterpret_cast<ArrT*>(y + offset) = in_arr;
 }
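
Note: the fast path above evaluates the tanh-based GELU approximation,
gelu(x) ~= 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3))), with the tanh
computed by the hardware tanh.approx.f32 instruction when FastMode is enabled.
The standalone CUDA sketch below (illustrative only, not part of the patch; the
names FastTanh and GeluFwdApprox are made up) shows the same approximation side
by side with the erf-based reference:

    // Minimal sketch of the tanh-approximation GELU used by GeluFwd.
    #include <cmath>
    #include <cstdio>
    #include <cuda_runtime.h>

    template <bool FastMode>
    __device__ __forceinline__ float FastTanh(float x) {
    #if __CUDA_ARCH__ >= 750
      if (FastMode) {
        float y;
        // Hardware tanh approximation (PTX tanh.approx.f32, Turing and newer).
        asm("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
        return y;
      }
    #endif
      return tanhf(x);  // exact fallback on older architectures
    }

    template <bool FastMode>
    __global__ void GeluFwdApprox(const float* x, float* y, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        float v = x[i];
        // 0.79788456f ~= sqrt(2/pi), the same constant as GeluFwd above.
        float t = FastTanh<FastMode>(0.79788456f * v * (1.0f + 0.044715f * v * v));
        y[i] = v * 0.5f * (1.0f + t);
      }
    }

    int main() {
      const int n = 8;
      float host_x[n] = {-3.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 3.f};
      float host_y[n];
      float *dev_x, *dev_y;
      cudaMalloc(&dev_x, sizeof(host_x));
      cudaMalloc(&dev_y, sizeof(host_y));
      cudaMemcpy(dev_x, host_x, sizeof(host_x), cudaMemcpyHostToDevice);
      GeluFwdApprox<true><<<1, n>>>(dev_x, dev_y, n);
      cudaMemcpy(host_y, dev_y, sizeof(host_y), cudaMemcpyDeviceToHost);
      for (int i = 0; i < n; ++i) {
        // erf-based GELU reference: 0.5 * x * (1 + erf(x / sqrt(2))).
        float ref = 0.5f * host_x[i] * (1.0f + erff(host_x[i] * 0.70710678f));
        printf("x=% .2f approx=% .6f ref=% .6f\n", host_x[i], host_y[i], ref);
      }
      cudaFree(dev_x);
      cudaFree(dev_y);
      return 0;
    }

The approximation trades a small amount of accuracy for throughput, which is
why the patch gates FastGeluFunctor behind FLAGS_use_fast_math instead of
making it the default.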
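
The inference fast path also sizes its launch from packs of VecSize elements
rather than from the 2-D row/column grid used by the dropout kernel. The sketch
below (a simplified stand-in, assuming cols % VecSize == 0 as in the launcher's
fast path; AlignedVector here is a hand-rolled substitute for
phi::AlignedVector and BiasAddVec is a made-up kernel) shows the vectorized
grid-stride pattern and the same block/grid heuristic:

    #include <algorithm>
    #include <cstdio>
    #include <cuda_runtime.h>

    template <typename T, int VecSize>
    struct alignas(sizeof(T) * VecSize) AlignedVector {
      T val[VecSize];
      __device__ T& operator[](int i) { return val[i]; }
    };

    template <int VecSize>
    __global__ void BiasAddVec(const float* src, const float* bias, float* dst,
                               int elem_cnt, int cols) {
      using Vec = AlignedVector<float, VecSize>;
      const int tid = blockIdx.x * blockDim.x + threadIdx.x;
      // Each thread consumes VecSize contiguous elements per grid-stride step.
      for (int idx = tid * VecSize, step = blockDim.x * gridDim.x * VecSize;
           idx < elem_cnt; idx += step) {
        Vec in = *reinterpret_cast<const Vec*>(src + idx);
        Vec b = *reinterpret_cast<const Vec*>(bias + idx % cols);
    #pragma unroll
        for (int i = 0; i < VecSize; ++i) in[i] += b[i];
        *reinterpret_cast<Vec*>(dst + idx) = in;
      }
    }

    int main() {
      constexpr int kVecSize = 4;
      const int rows = 8, cols = 256, elem_cnt = rows * cols;
      // Same sizing heuristic as the is_test branch of LaunchDropoutActBias:
      // clamp the block to [32, 128] packs, then cover all packs with the grid.
      const int pack_num = elem_cnt / kVecSize;
      const int tmp_cols = cols / kVecSize;
      const int block_size = std::max(32, std::min(tmp_cols, 128));
      const int grid_size = std::max(1, (pack_num + block_size - 1) / block_size);

      float *x, *b, *y;
      cudaMallocManaged(&x, elem_cnt * sizeof(float));
      cudaMallocManaged(&b, cols * sizeof(float));
      cudaMallocManaged(&y, elem_cnt * sizeof(float));
      for (int i = 0; i < elem_cnt; ++i) x[i] = 1.0f;
      for (int i = 0; i < cols; ++i) b[i] = 0.5f;

      BiasAddVec<kVecSize><<<grid_size, block_size>>>(x, b, y, elem_cnt, cols);
      cudaDeviceSynchronize();
      printf("y[0] = %f (expect 1.5)\n", y[0]);
      cudaFree(x); cudaFree(b); cudaFree(y);
      return 0;
    }

The [32, 128] clamp keeps small rows from launching under-filled blocks while
bounding per-block work; the real launcher takes this path only when is_test is
true and no dequant scale is supplied, since dropout masks and quantization
still require the full FusedDropoutActBias kernel.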