diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
index fc044e0bafa310b8004803acc7cbed0bdafd9ae6..8c551db1f8bca0a968ee148a434e94cab986f47e 100644
--- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
+++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
@@ -541,7 +541,7 @@ void LaunchLayernormResidualDropoutGrad(
   if (!is_upscale_in_train) {
     factor = static_cast<T>(1.0f);
   }
-  ln_bwd_1024_kernel_driver<
+  ln_bwd_fast_kernel_driver<
      T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, MaskType>(
      dev_ctx, rows, cols, epsilon, layernorm_src, scale, mean, var, d_out,
      d_residual, d_scale, d_layernorm_bias, mask_data, factor, d_dropout_src);
diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h
index ac20a5962f394e7f016c2b9db190a660c1ee430f..3519a07539182f8eb7542fde6085e01f261f9358 100644
--- a/paddle/fluid/operators/layer_norm_kernel.cu.h
+++ b/paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -22,6 +22,8 @@ limitations under the License. */
 namespace cub = hipcub;
 #endif
 
+#include <iostream>
+
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/phi/core/ddim.h"
@@ -428,7 +430,7 @@ template <
     int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
     int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
     int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
-__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
     const int rows, float epsilon, const T *__restrict__ x_ptr,
     const ScaleT *__restrict__ gamma_ptr, const U *__restrict__ mean_ptr,
     const U *__restrict__ var_ptr, const T *__restrict__ dout_ptr,
@@ -671,7 +673,7 @@ template <
     int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
     int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA,
     int VEC_COLS = ELTS_PER_ROW / VecSize>
-__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_fast_final_kernel(
     const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
     ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
   using Vec = phi::AlignedVector<U, VecSize>;
@@ -795,7 +797,7 @@
  */
 template <typename T, typename U, typename ScaleT = U,
           typename MaskType = uint8_t>
-void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
+void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
                                const int cols, float epsilon, const T *x_ptr,
                                const ScaleT *scale_ptr, const U *mean_ptr,
                                const U *var_ptr, const T *dout_ptr, T *dx_ptr,
@@ -804,10 +806,10 @@
                                T factor = static_cast<T>(0),
                                T *d_dropout_src_ptr = nullptr) {
   auto stream = dev_ctx.stream();
-  if (cols == 1024) {
+  if (cols == 1024 || cols == 384 || cols == 256) {
     // step-1: compute dx and reduced part results of dscale and dbias.
-    const int WARPS_M = 4;
-    const int WARPS_N = 1;
+    const int WARPS_M = 4;  // how many rows are dealt with in one CTA.
+    const int WARPS_N = 1;  // how many warps deal with one row.
     const int BYTES_PER_LDG = 16;
     const int VecSize = BYTES_PER_LDG / sizeof(T);
     const int THREADS_PER_WARP = 32;
@@ -839,20 +841,52 @@
             "To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
             "can't be null"));
       }
-      fused_ln_bwd_1024_kernel<T, U, ScaleT, MaskType, VecSize, WARPS_M,
-                               WARPS_N, BYTES_PER_LDG>
-          <<<gridx, THREADS_PER_CTA, 0, stream>>>(
-              rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
-              dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor,
-              d_dropout_src_ptr);
+#define LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row)     \
+  fused_ln_bwd_fast_kernel<T, U, ScaleT, MaskType, vec_size, WARPS_M,   \
+                           WARPS_N, BYTES_PER_LDG, ele_per_row>         \
+      <<<gridx, THREADS_PER_CTA, 0, stream>>>(                          \
+          rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr, \
+          dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor,    \
+          d_dropout_src_ptr);
+
+      if (cols == 1024) {
+        LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(VecSize, 256);
+            break;
+        }
+      }
+#undef LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
     } else {
-      fused_ln_bwd_1024_kernel<T, U, ScaleT, MaskType, VecSize, WARPS_M,
-                               WARPS_N, BYTES_PER_LDG>
-          <<<gridx, THREADS_PER_CTA, 0, stream>>>(
-              rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr,
-              dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
+#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row)          \
+  fused_ln_bwd_fast_kernel<T, U, ScaleT, MaskType, vec_size, WARPS_M,   \
+                           WARPS_N, BYTES_PER_LDG, ele_per_row>         \
+      <<<gridx, THREADS_PER_CTA, 0, stream>>>(                          \
+          rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr, \
+          dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
+
+      if (cols == 1024) {
+        LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_FUSED_LN_BWD_FAST_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_FUSED_LN_BWD_FAST_KERNEL(VecSize, 256);
+            break;
+        }
+      }
+
+#undef LAUNCH_FUSED_LN_BWD_FAST_KERNEL
     }
+
 
     const int WARPS_M_2 = 16;
     const int WARPS_N_2 = 1;
     const int BYTES_PER_LDG_2 = 4;
@@ -865,18 +899,36 @@
         WARPS_M_2 * THREADS_PER_ROW_2;  // 16 * 32 = 512
     const int ROWS_PER_CTA_2 = WARPS_M_2;  // 16
 
-    const int gridx_2 = static_cast<int>(
-        std::ceil(1024 / static_cast<float>(THREADS_PER_ROW_2 * VecSize_2)));
     // #blocks: 32,#threads_per_block: 512
     // Note: it is not supported for double type.
     if (sizeof(U) > 4) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Only support float and fp16 type"));
     } else {
-      ln_bwd_1024_final_kernel<U, ScaleT, VecSize_2, WARPS_M_2, WARPS_N_2,
-                               BYTES_PER_LDG_2>
-          <<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>(
-              gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
+      int gridx_2 = 0;
+
+#define LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(vec_size, ele_per_row)          \
+  gridx_2 = static_cast<int>(std::ceil(                                 \
+      ele_per_row / static_cast<float>(THREADS_PER_ROW_2 * vec_size))); \
+  ln_bwd_fast_final_kernel<U, ScaleT, vec_size, WARPS_M_2, WARPS_N_2,   \
+                           BYTES_PER_LDG_2, ele_per_row>                \
+      <<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>(                      \
+          gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
+
+      if (cols == 1024) {
+        LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(VecSize_2, 1024);
+      } else {
+        switch (cols) {
+          case 384:
+            LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(1, 384);
+            break;
+          case 256:
+            LAUNCH_LN_BWD_BETA_GAMMA_KERNEL(VecSize_2, 256);
+            break;
+        }
+      }
+
+#undef LAUNCH_LN_BWD_BETA_GAMMA_KERNEL
     }
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
@@ -1484,15 +1536,17 @@ static void LayerNormBackward(
     case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
     {
 #ifdef PADDLE_WITH_CUDA
-      bool can_call_1024_kernel = false;
+      bool can_call_fast_kernel = false;
       // todo: rule out double type.
-      if (feature_size == 1024 && sizeof(T) <= 4) {
-        can_call_1024_kernel = true;
+      if ((feature_size == 1024 || feature_size == 384 ||
+           feature_size == 256) &&
+          sizeof(T) <= 4) {
+        can_call_fast_kernel = true;
       }
 
-      VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
-      if (can_call_1024_kernel) {
-        ln_bwd_1024_kernel_driver<
+      VLOG(6) << "can_call_fast_kernel = " << can_call_fast_kernel;
+      if (can_call_fast_kernel) {
+        ln_bwd_fast_kernel_driver<
             T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>>(
             dev_ctx, batch_size, feature_size, epsilon, x, scale, mean, var,
             d_y, d_x, d_scale, d_bias);
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
index 1cc2906731bd81edfff32a3cdf32cfe30329ac63..2ee1a1ba76f7b00c825e1d32cc2522c22932d2fe 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
@@ -247,7 +247,6 @@ class TestLayerNormOp(unittest.TestCase):
 
     def test_check_forward_backward_with_scale_and_bias(self):
         self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
-        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
         self.check_forward_backward(shape=[2, 3, 4, 5],
                                     begin_norm_axis=1,
                                     has_scale=False,
@@ -288,6 +287,14 @@ class TestLayerNormOp(unittest.TestCase):
                                     begin_norm_axis=1,
                                     has_scale=True,
                                     has_bias=True)
+        self.check_forward_backward(shape=[1, 128, 256, 256],
+                                    begin_norm_axis=3,
+                                    has_scale=True,
+                                    has_bias=True)
+        self.check_forward_backward(shape=[1, 256, 384],
+                                    begin_norm_axis=2,
+                                    has_scale=True,
+                                    has_bias=True)
 
 
 class TestLayerNormAPI(unittest.TestCase):
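
A note on the vec_size arguments in the launch macros above: each thread of the backward kernels reads its row in LDGS = ELTS_PER_ROW / (THREADS_PER_ROW * VecSize) vectorized loads, so that quotient must be a whole number. With BYTES_PER_LDG = 16, fp16 gives VecSize = 8, and 384 is not a multiple of 32 * 8 = 256, which is why the cols == 384 branches fall back to vec_size = 1 while 1024 and 256 keep the full vector width. The sketch below is illustrative only, not part of the patch; ChooseVecSize and the k* constants are invented names standing in for the driver's template parameters.

#include <cstdio>

// Constants mirroring the driver above (illustrative names, not Paddle's).
constexpr int kThreadsPerWarp = 32;
constexpr int kWarpsN = 1;                  // warps cooperating on one row
constexpr int kThreadsPerRow = kWarpsN * kThreadsPerWarp;
constexpr int kBytesPerLdg = 16;            // one 128-bit load per access

// A column width keeps the vectorized path only if every thread gets a whole
// number of loads; otherwise fall back to scalar loads (vec size 1).
int ChooseVecSize(int cols, int elt_bytes) {
  const int vec = kBytesPerLdg / elt_bytes;  // 8 for fp16, 4 for fp32
  return (cols % (kThreadsPerRow * vec) == 0) ? vec : 1;
}

int main() {
  // fp16 (elt_bytes = 2): 1024 -> 8 and 256 -> 8, but 384 -> 1 (scalar
  // fallback), matching the case 384 branches in the diff.
  std::printf("%d %d %d\n", ChooseVecSize(1024, 2), ChooseVecSize(256, 2),
              ChooseVecSize(384, 2));
  return 0;
}

For fp32 (VecSize = 4), 384 would in fact divide evenly; pinning cols == 384 to vec_size = 1 for both element types presumably keeps a single template instantiation per column width.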