diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h
index 970b2d82e2b15c7c2856db4abb5a8f448f5b0ed5..3972c60e8347b327d7a4587de022534d7b92fa7b 100644
--- a/paddle/fluid/operators/fused/fused_dropout_helper.h
+++ b/paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -250,11 +250,14 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
   }
 
   // out = layernorm(residual + dropout(src + bias))
-  void LayernormResidualDropoutBias(
-      const platform::CUDADeviceContext& ctx, const T* src, const T* residual,
-      const T* bias, const LayerNormParamType<T>* gamma,
-      const LayerNormParamType<T>* beta, T* dropout_out, MaskType* mask, T* out,
-      LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) {
+  template <typename P = LayerNormParamType<T>, bool is_same_type = false>
+  void LayernormResidualDropoutBias(const platform::CUDADeviceContext& ctx,
+                                    const T* src, const T* residual,
+                                    const T* bias, const P* gamma,
+                                    const P* beta, T* dropout_out,
+                                    MaskType* mask, T* out,
+                                    LayerNormParamType<T>* mean,
+                                    LayerNormParamType<T>* variance) {
     using U = LayerNormParamType<T>;
     int vec_size = MAX_CACHE_BYTES / sizeof(T);
     if (this->cols_ % vec_size != 0) {
@@ -263,7 +266,7 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
     int threads = GetDesiredBlockDim(this->cols_ / vec_size);
     int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
     increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
-    LaunchLayernormResidualDropoutBias<T, MaskType>(
+    LaunchLayernormResidualDropoutBias<T, MaskType, U, is_same_type>(
         this->rows_, this->cols_, increment, this->dropout_param_.seed,
         this->dropout_param_.dropout_prob, epsilon_,
         this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test,
@@ -271,17 +274,19 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
         variance, ctx);
   }
 
-  void LayernormResidualDropoutBiasGrad(
-      const platform::CUDADeviceContext& ctx, const T* d_out,
-      const T* layernorm_src, const MaskType* mask,
-      const LayerNormParamType<T>* gamma, const LayerNormParamType<T>* mean,
-      const LayerNormParamType<T>* variance, T* d_layernorm_src,
-      LayerNormParamType<T>* d_scale, LayerNormParamType<T>* d_layernorm_bias,
-      T* d_dropout_src, T* d_bias, T* d_residual) {
+  template <typename P = LayerNormParamType<T>, bool is_same_type = false>
+  void LayernormResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx,
+                                        const T* d_out, const T* layernorm_src,
+                                        const MaskType* mask, const P* gamma,
+                                        const LayerNormParamType<T>* mean,
+                                        const LayerNormParamType<T>* variance,
+                                        T* d_layernorm_src, P* d_scale,
+                                        P* d_layernorm_bias, T* d_dropout_src,
+                                        T* d_bias, T* d_residual) {
     using U = LayerNormParamType<T>;
-    LayerNormBackward<T, U>(layernorm_src, d_out, gamma, mean, variance,
-                            d_layernorm_src, d_scale, d_layernorm_bias,
-                            epsilon_, this->rows_, this->cols_, ctx);
+    LayerNormBackward<T, U, is_same_type>(
+        layernorm_src, d_out, gamma, mean, variance, d_layernorm_src, d_scale,
+        d_layernorm_bias, epsilon_, this->rows_, this->cols_, ctx);
     this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src,
                                   d_residual, d_bias);
   }
diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
index 1827e137c15f183ae97552c795df85556ef0cd1a..b27b70dc9dc0c7c08a6b04889be8eb0bc6b29d24 100644
--- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
+++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
@@ -24,46 +24,57 @@
 using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
+template <typename T, bool ScaleBiasWithSameTypeX>
+using LayerNormScaleBiasT =
+    typename std::conditional<ScaleBiasWithSameTypeX, T,
+                              LayerNormParamType<T>>::type;
+
 /**
  * @brief fused add_bias, dropout, add residual and leyer_norm into one
  * operators. Currently only support forward
  */
-template <typename T, int VecSize>
-__device__ void CalcLayernormY(const LayerNormParamType<T> *scale,
-                               const LayerNormParamType<T> *bias, const T *x,
-                               T *y, const int row_id, const int col_id,
-                               const int cols,
-                               const LayerNormParamType<T> mean_val,
-                               const LayerNormParamType<T> invvar) {
-  using U = LayerNormParamType<T>;
+template <typename T, int VecSize, typename U,
+          bool ScaleBiasWithSameTypeX = false>
+__device__ void CalcLayernormY(
+    const LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX> *scale,
+    const LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX> *bias, const T *x,
+    T *y, const int row_id, const int col_id, const int cols,
+    const LayerNormParamType<T> mean_val, const LayerNormParamType<T> invvar) {
   using LoadT = platform::AlignedVector<T, VecSize>;
   using StoreT = platform::AlignedVector<T, VecSize>;
   using LoadU = platform::AlignedVector<U, VecSize>;
+  using LoadScaleOrBias =
+      platform::AlignedVector<LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX>,
+                              VecSize>;
   for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
-    LoadU scale_vec;
-    LoadU bias_vec;
+    LoadScaleOrBias scale_vec;
+    LoadScaleOrBias bias_vec;
     LoadT x_vec;
 #pragma unroll
     for (int ii = 0; ii < VecSize; ii++) {
-      scale_vec[ii] = static_cast<U>(1);
-      bias_vec[ii] = static_cast<U>(0);
+      scale_vec[ii] =
+          static_cast<LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX>>(1);
+      bias_vec[ii] =
+          static_cast<LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX>>(0);
     }
     // vectorize load data from global
     platform::Load<T, VecSize>(&x[row_id * cols + i], &x_vec);
 
     if (scale != nullptr) {
-      platform::Load<U, VecSize>(&scale[i], &scale_vec);
+      platform::Load<LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX>,
+                     VecSize>(&scale[i], &scale_vec);
     }
     if (bias != nullptr) {
-      platform::Load<U, VecSize>(&bias[i], &bias_vec);
+      platform::Load<LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX>,
+                     VecSize>(&bias[i], &bias_vec);
     }
 
     StoreT y_vec;
     for (int ii = 0; ii < VecSize; ii++) {
-      y_vec[ii] = static_cast<T>(
-          scale_vec[ii] * (static_cast<U>(x_vec[ii]) - mean_val) * invvar +
-          bias_vec[ii]);
+      y_vec[ii] =
+          static_cast<T>(static_cast<U>(scale_vec[ii]) *
+                             (static_cast<U>(x_vec[ii]) - mean_val) * invvar +
+                         static_cast<U>(bias_vec[ii]));
     }
     platform::Store<T, VecSize>(y_vec, &y[row_id * cols + i]);
   }
@@ -85,15 +96,17 @@ __device__ void CalcLayernormY(const LayerNormParamType<T> *scale,
  * means: [rows]: layernorm means
  * vars: [rows]: layernorm vars
  */
-template <typename T, typename MaskType, int VecSize>
+template <typename T, typename MaskType, int VecSize, typename U,
+          bool ScaleBiasWithSameTypeX = false>
 __global__ void FusedLayernormResidualDropoutBias(
     const size_t rows, const size_t cols, uint64_t seed,
     const float dropout_prob, const bool is_upscale_in_train,
     const bool is_test, const uint64_t increment, const float epsilon,
     const T *src, const T *residual, const T *bias,
-    const LayerNormParamType<T> *scale,
-    const LayerNormParamType<T> *layernorm_bias, MaskType *mask, T *dst,
-    T *layernorm_dst, LayerNormParamType<T> *mean, LayerNormParamType<T> *var) {
+    const LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX> *scale,
+    const LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX> *layernorm_bias,
+    MaskType *mask, T *dst, T *layernorm_dst, LayerNormParamType<T> *mean,
+    LayerNormParamType<T> *var) {
   int col_id = threadIdx.x;
   int row_id = blockIdx.x;
   int idx = row_id * cols + col_id;
@@ -101,7 +114,6 @@ __global__ void FusedLayernormResidualDropoutBias(
   curand_init(seed, idx, increment, &state);
 
   T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
-  using U = LayerNormParamType<T>;
   __shared__ U mean_share;
   __shared__ U var_share;
@@ -121,10 +133,12 @@ __global__ void FusedLayernormResidualDropoutBias(
   mean_val = BlockReduceSum<U>(mean_val, shared_mean);
   var_val = BlockReduceSum<U>(var_val, shared_var);
   if (threadIdx.x == 0) {
-    auto scale = static_cast<U>(1.) / static_cast<U>(cols);
-    auto tmp = mean_val * scale;
+    auto scale = static_cast<LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX>>(
+        static_cast<float>(1.) / static_cast<float>(cols));
+    auto tmp = mean_val * static_cast<U>(scale);
     mean[row_id] = mean_share = static_cast<U>(tmp);
-    var_share = static_cast<U>(var_val * scale - mean_share * mean_share);
+    var_share = static_cast<U>(var_val * static_cast<U>(scale) -
+                               mean_share * mean_share);
     var_share = var_share > U(0) ? var_share : U(0);
     var[row_id] = var_share;
   }
@@ -134,8 +148,9 @@ __global__ void FusedLayernormResidualDropoutBias(
   U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
 
   // calculate layernorm_dst
-  CalcLayernormY<T, VecSize>(scale, layernorm_bias, dst, layernorm_dst, row_id,
-                             col_id, cols, mean_val, invvar);
+  CalcLayernormY<T, VecSize, U, ScaleBiasWithSameTypeX>(
+      scale, layernorm_bias, dst, layernorm_dst, row_id, col_id, cols, mean_val,
+      invvar);
 }
 
 /**
@@ -154,16 +169,17 @@ __global__ void FusedLayernormResidualDropoutBias(
  * means: [rows]: layernorm means
  * vars: [rows]: layernorm vars
  */
-template <typename T, typename MaskType>
+template <typename T, typename MaskType, typename U,
+          bool ScaleBiasWithSameTypeX = false>
 void LaunchLayernormResidualDropoutBias(
     const uint32_t rows, const uint32_t cols, const int increment,
     uint64_t seed, const float dropout_prob, const float epsilon,
     const bool is_upscale_in_train, const bool is_test, const T *src,
-    const T *residual, const T *bias, const LayerNormParamType<T> *scale,
-    const LayerNormParamType<T> *layernorm_bias, MaskType *mask_data, T *dst,
-    T *layernorm_dst, LayerNormParamType<T> *mean, LayerNormParamType<T> *var,
-    const platform::CUDADeviceContext &ctx) {
-  using U = LayerNormParamType<T>;
+    const T *residual, const T *bias,
+    const LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX> *scale,
+    const LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX> *layernorm_bias,
+    MaskType *mask_data, T *dst, T *layernorm_dst, LayerNormParamType<T> *mean,
+    LayerNormParamType<T> *var, const platform::CUDADeviceContext &ctx) {
   // dropout_prob == 1.0f
   if (std::abs(dropout_prob - 1.0f) < 1e-5) {
     auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
@@ -175,8 +191,9 @@ void LaunchLayernormResidualDropoutBias(
     // call layernorm forward
     switch (GetDesiredBlockDim(cols)) {
       FIXED_BLOCK_DIM_CASE(
-          LayerNormForward<T, U, kBlockDim><<<rows, kBlockDim, 0,
-                                              ctx.stream()>>>(
+          LayerNormForward<
+              T, U, kBlockDim,
+              ScaleBiasWithSameTypeX><<<rows, kBlockDim, 0, ctx.stream()>>>(
               dst, scale, layernorm_bias, layernorm_dst, mean, var, epsilon,
               cols));
       default:
@@ -184,21 +201,24 @@ void LaunchLayernormResidualDropoutBias(
             "Product from begin_norm_axis to end must be larger than 1"));
         break;
     }
+    return;
   }
 
   const int VecSize = MAX_CACHE_BYTES / sizeof(T);
   if (cols % VecSize != 0) {
     int blockDim = GetDesiredBlockDim(cols);
-    FusedLayernormResidualDropoutBias<T, uint8_t,
-                                      1><<<rows, blockDim, 0, ctx.stream()>>>(
+    FusedLayernormResidualDropoutBias<
+        T, uint8_t, 1, U,
+        ScaleBiasWithSameTypeX><<<rows, blockDim, 0, ctx.stream()>>>(
         rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
         epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
        layernorm_dst, mean, var);
   } else {
     int blockDim = GetDesiredBlockDim(cols / VecSize);
     FusedLayernormResidualDropoutBias<
-        T, uint8_t, VecSize><<<rows, blockDim, 0, ctx.stream()>>>(
+        T, uint8_t, VecSize, U,
+        ScaleBiasWithSameTypeX><<<rows, blockDim, 0, ctx.stream()>>>(
         rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
         epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
        layernorm_dst, mean, var);
diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
index 50e3555b4bcd629c85aca27fd3f2999b6694ecb2..57d3fc94dc88a0699b103c081642757798719332 100644
--- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
+++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
@@ -223,7 +223,7 @@ struct TestFusedLayernormResidualDropoutBias {
       layernorm_bias_ptr = layernorm_bias.data<U>();
     }
 
-    paddle::operators::LaunchLayernormResidualDropoutBias<T, uint8_t>(
+    paddle::operators::LaunchLayernormResidualDropoutBias<T, uint8_t, U>(
         rows, cols, increment, seed, dropout_prob, epsilon, is_upscale_in_train,
        is_test, src.data<T>(), residual.data<T>(), bias_ptr, scale_ptr,
        layernorm_bias_ptr, mask.data<uint8_t>(), out.data<T>(),
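
Note for reviewers: the heart of this patch is the LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX> alias, which decides at compile time whether the layer-norm scale/bias are stored in the input dtype T (e.g. fp16) or in the higher-precision LayerNormParamType<T> (fp32 for half). The standalone sketch below is not part of the patch; half16, ParamTraits, ParamT and ScaleBiasT are illustrative stand-ins used only to show the std::conditional selection that the templates above rely on.

#include <cstdio>
#include <type_traits>

// Illustrative stand-ins: ParamT<T> plays the role of LayerNormParamType<T>
// (float when T is a 16-bit type) and ScaleBiasT mirrors LayerNormScaleBiasT.
// half16 is a dummy 16-bit struct standing in for fp16.
struct half16 {
  unsigned short x;
};
template <typename T>
struct ParamTraits {
  using type = T;
};
template <>
struct ParamTraits<half16> {
  using type = float;
};
template <typename T>
using ParamT = typename ParamTraits<T>::type;

// Same selection rule as LayerNormScaleBiasT<T, ScaleBiasWithSameTypeX>:
// if the flag is true, scale/bias share the input type T; otherwise they
// stay in the higher-precision parameter type.
template <typename T, bool ScaleBiasWithSameTypeX>
using ScaleBiasT =
    typename std::conditional<ScaleBiasWithSameTypeX, T, ParamT<T>>::type;

int main() {
  static_assert(std::is_same<ScaleBiasT<half16, true>, half16>::value,
                "flag=true: scale/bias stored in the X dtype");
  static_assert(std::is_same<ScaleBiasT<half16, false>, float>::value,
                "flag=false: scale/bias stay in fp32");
  std::printf("scale/bias type selection verified at compile time\n");
  return 0;
}

Whichever storage type is selected, the kernels above still promote the loaded scale/bias values with static_cast<U>(...) before the arithmetic (see CalcLayernormY), so only the storage dtype of the parameters changes while the accumulation remains in U.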