From 85baa3c0272d6b56a5e8fb0d59d4ed4222f4abe2 Mon Sep 17 00:00:00 2001
From: Li Min <11663212+limin2021@users.noreply.github.com>
Date: Thu, 2 Jun 2022 10:11:58 +0800
Subject: [PATCH] Extend forward fast layer_norm kernel to support more
 dimensions. (#43118)

* extend forward fast_ln_kernel to support more column values.
---
 .../fused_layernorm_residual_dropout_bias.h   |   8 +-
 paddle/fluid/operators/layer_norm_kernel.cu.h |  89 +++++++++----
 paddle/phi/kernels/gpu/layer_norm_kernel.cu   | 117 ++++++++++--------
 3 files changed, 135 insertions(+), 79 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
index 866de8e04a9..9d7d34ebdc9 100644
--- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
+++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
@@ -478,11 +478,15 @@ void LaunchLayernormResidualDropoutBias(
 #define LAUNCH_FUSED_FAST_LN_KERNEL       \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768);  \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1280); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1536); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1792); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(2048); \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)
 
   bool can_call_fast_ln_kernel = false;
-  if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr &&
-      layernorm_bias != nullptr) {
+  if (((cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096) &&
+      scale != nullptr && layernorm_bias != nullptr) {
     can_call_fast_ln_kernel = true;
   }
   VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;
diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h
index 5b5ddddaafb..0c5946b4ae4 100644
--- a/paddle/fluid/operators/layer_norm_kernel.cu.h
+++ b/paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -36,8 +36,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 
-#define LN_NUM_COLS 1024
-
 inline static int GetDesiredBlockDim(int64_t block_dim) {
 #ifdef __HIPCC__
   const int kMaxBlockDim = 256;
@@ -183,11 +181,12 @@ template
-__global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
     int rows, int cols, const float epsilon, const T *__restrict__ x_ptr,
     const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
     U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
     T *__restrict__ y_ptr) {
+  __shared__ U smem[WARPS_M * WARPS_N];
   using Vec = phi::AlignedVector<T, VecSize>;
   using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
@@ -210,12 +209,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     col += THREADS_PER_ROW;
   }
 
-  constexpr U rn = 1.f / U(LN_NUM_COLS);
+  constexpr U rn = 1.f / U(ELTS_PER_ROW);
   for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
     Vec x[LDGS];
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
+      phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
       col += THREADS_PER_ROW;
     }
     U xf[LDGS * VecSize];
@@ -235,6 +234,23 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
     }
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = mu_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        mu_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          mu_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = mu_local;
+      }
+      __syncthreads();
+      mu_local = smem[warp_m];
+    }
+
     mu_local *= rn;
     if (lane == 0) {
       mean_out_ptr[row] = mu_local;
@@ -254,6 +270,24 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
     }
+
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = var_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        var_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          var_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = var_local;
+      }
+      __syncthreads();
+      var_local = smem[warp_m];
+    }
+
     // Note: to assure if it is right for double
     U rsigma = rsqrtf(var_local * rn + epsilon);
     if (lane == 0) {
@@ -277,7 +311,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
+      phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
       col += THREADS_PER_ROW;
     }
   }
 }
@@ -416,10 +450,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   const int r = bidx * ROWS_PER_CTA + warp_m;
   const int c = warp_n * THREADS_PER_WARP + lane;
 
-  static_assert(LN_NUM_COLS == THREADS_PER_ROW * LDGS * VecSize, "");
+  static_assert(ELTS_PER_ROW == THREADS_PER_ROW * LDGS * VecSize, "");
 
   // smem for column reduction
-  __shared__ U smem_[ROWS_PER_CTA * LN_NUM_COLS];
+  __shared__ U smem_[ROWS_PER_CTA * ELTS_PER_ROW];
 
   U dgamma_sum[LDGS * VecSize];
   U dbeta_sum[LDGS * VecSize];
@@ -434,7 +468,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   U *sum_loss2_shared = &smem_sum_loss2[warp_m * WARPS_N];
 
   // step-1: compute dx and local results of dscale and dbias
-  constexpr float rn = 1.f / static_cast<float>(LN_NUM_COLS);
+  constexpr float rn = 1.f / static_cast<float>(ELTS_PER_ROW);
   Vec_scale gamma[LDGS];
   int col = c;
 #pragma unroll
   for (int it = 0; it < LDGS; it++) {
@@ -452,12 +486,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
     int col = c;
 #pragma unroll
     for (int it = 0; it < LDGS; it++) {
-      phi::Load<T, VecSize>(dout_ptr + row * LN_NUM_COLS + col * VecSize,
+      phi::Load<T, VecSize>(dout_ptr + row * ELTS_PER_ROW + col * VecSize,
                             &dout[it]);
-      phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
+      phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
       if (isFusedDropoutResidualLn) {
         phi::Load<MaskType, VecSize>(
-            mask_ptr + row * LN_NUM_COLS + col * VecSize, &mask_vec[it]);
+            mask_ptr + row * ELTS_PER_ROW + col * VecSize, &mask_vec[it]);
       }
 
       col += THREADS_PER_ROW;
@@ -551,10 +585,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
     col = c;
 #pragma unroll
     for (int it = 0; it < LDGS; it++) {
-      phi::Store<T, VecSize>(x[it], dx_ptr + row * LN_NUM_COLS + col * VecSize);
+      phi::Store<T, VecSize>(x[it],
+                             dx_ptr + row * ELTS_PER_ROW + col * VecSize);
      if (isFusedDropoutResidualLn) {
         phi::Store<T, VecSize>(
-            dout[it], d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize);
+            dout[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
       }
       col += THREADS_PER_ROW;
     }
@@ -562,12 +597,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
 
   // step-2: column reduction of dscale and dbias for each thread block.
   // each block's sum: [4 * 1024] -> [1 * 1024]
-  enum { NUM_RES = LN_NUM_COLS / THREADS_PER_CTA };  // 1024/128 = 8
-  static_assert(NUM_RES * THREADS_PER_CTA == LN_NUM_COLS, "");
+  enum { NUM_RES = ELTS_PER_ROW / THREADS_PER_CTA };  // 1024/128 = 8
+  static_assert(NUM_RES * THREADS_PER_CTA == ELTS_PER_ROW, "");
 
   U *smem_write;
 
-  smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];  // [4 * 1024]
+  smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];  // [4 * 1024]
 #pragma unroll
   for (int it = 0; it < LDGS; it++) {
 #pragma unroll
@@ -583,12 +618,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   for (int it = 0; it < ROWS_PER_CTA; it++) {
     for (int jt = 0; jt < NUM_RES; jt++) {
       cta_dbeta_sum[jt] +=
-          smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
+          smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
     }
   }
   __syncthreads();
 
-  smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];
+  smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];
 #pragma unroll
   for (int it = 0; it < LDGS; it++) {
 #pragma unroll
@@ -603,19 +638,19 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   for (int it = 0; it < ROWS_PER_CTA; it++) {
     for (int jt = 0; jt < NUM_RES; jt++) {
       cta_dgamma_sum[jt] +=
-          smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
+          smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
     }
   }
   // the shape of results:(#blocks, 1024)
   U *dgamma_part =
-      static_cast<U *>(dgamma_temp_ptr) + bidx * LN_NUM_COLS + tidx;
+      static_cast<U *>(dgamma_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
   for (int jt = 0; jt < NUM_RES; jt++) {
     *dgamma_part = cta_dgamma_sum[jt];
     dgamma_part += THREADS_PER_CTA;
   }
 
-  U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * LN_NUM_COLS + tidx;
+  U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
   for (int jt = 0; jt < NUM_RES; jt++) {
     *dbeta_part = cta_dbeta_sum[jt];
     dbeta_part += THREADS_PER_CTA;
   }
@@ -640,7 +675,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
     const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
     ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
   using Vec = phi::AlignedVector<U, VecSize>;
-  static_assert(VEC_COLS == LN_NUM_COLS / VecSize, "");
+  static_assert(VEC_COLS == ELTS_PER_ROW / VecSize, "");
 
   const int tidx = threadIdx.x;
   const int bidx = blockIdx.x;
@@ -656,8 +691,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
   __shared__ U smem_space[(WARPS_M - 1) * THREADS_PER_ROW * VecSize];
 
   for (int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW) {
-    const U *dg_part_ptr = (dg_part_) + r * LN_NUM_COLS + col * VecSize;
-    const U *db_part_ptr = (db_part_) + r * LN_NUM_COLS + col * VecSize;
+    const U *dg_part_ptr = (dg_part_) + r * ELTS_PER_ROW + col * VecSize;
+    const U *db_part_ptr = (db_part_) + r * ELTS_PER_ROW + col * VecSize;
 
     U dg_sum[VecSize];
     U db_sum[VecSize];
@@ -669,8 +704,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
       Vec db;
       phi::Load<U, VecSize>(dg_part_ptr, &dg);
       phi::Load<U, VecSize>(db_part_ptr, &db);
-      dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
-      db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
+      dg_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
+      db_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
 
 #pragma unroll
       for (int jt = 0; jt < VecSize; jt++) {
diff --git a/paddle/phi/kernels/gpu/layer_norm_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_kernel.cu
index 72127042c16..665913893e0 100644
--- a/paddle/phi/kernels/gpu/layer_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/layer_norm_kernel.cu
@@ -84,7 +84,7 @@ void LayerNormKernel(const Context &dev_ctx,
       PADDLE_ENFORCE_EQ(
           scale->dtype(),
           bias->dtype(),
-          phi::errors::InvalidArgument("Thie Scale and Bias of layer_norm op "
+          phi::errors::InvalidArgument("This Scale and Bias of layer_norm op "
                                        "should have the same data type."));
     }
   } else {
@@ -131,59 +131,75 @@ void LayerNormKernel(const Context &dev_ctx,
     }                                                                     \
   } while (0)
 
+#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, feature_size)          \
+  case (feature_size): {                                                     \
+    constexpr int WARPS_N = feature_size < 1024 ? 1 : (feature_size / 1024); \
+    constexpr int WARPS_M = 4 / WARPS_N;                                     \
+    const int THREADS_PER_WARP = 32;                                         \
+    const int BYTES_PER_LDG = 16;                                            \
+    const int VecSize = BYTES_PER_LDG / sizeof(T);                           \
+    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;        \
+    const int ROWS_PER_CTA = WARPS_M;                                        \
+    const int grid = static_cast<int>(                                       \
+        std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));           \
+    paddle::operators::fast_ln_fwd_kernel<                                   \
+        T,                                                                   \
+        U,                                                                   \
+        ScaleT,                                                              \
+        VecSize,                                                             \
+        WARPS_M,                                                             \
+        WARPS_N,                                                             \
+        BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(                \
+        batch_size,                                                          \
+        feature_size,                                                        \
+        epsilon,                                                             \
+        x_data,                                                              \
+        static_cast<const ScaleT *>(void_scale_data),                        \
+        static_cast<const ScaleT *>(void_bias_data),                         \
+        mean_data,                                                           \
+        var_data,                                                            \
+        y_data);                                                             \
+  } break
+
+#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD(ScaleT)        \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 768);   \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1024);  \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1280);  \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1536);  \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1792);  \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 2048);  \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 4096)
+
 #ifdef PADDLE_WITH_CUDA
-  bool can_call_1024_kernel = false;
-  if (feature_size == 1024 && scale != nullptr && bias != nullptr) {
-    can_call_1024_kernel = true;
+  bool can_call_fast_kernel = false;
+  if (((feature_size >= 768 && feature_size <= 2048 &&
+        feature_size % 256 == 0) ||
+       feature_size == 4096) &&
+      scale != nullptr && bias != nullptr) {
+    can_call_fast_kernel = true;
   }
-  if (can_call_1024_kernel) {
-    const int WARPS_M = 4;
-    const int WARPS_N = 1;
-    const int THREADS_PER_WARP = 32;
-    const int BYTES_PER_LDG = 16;
-    const int VecSize = BYTES_PER_LDG / sizeof(T);
-
-    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
-    const int ROWS_PER_CTA = WARPS_M;
-
-    const int grid = static_cast<int>(
-        std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));
+
+  if (can_call_fast_kernel) {
     if (is_scale_bias_same_dtype_with_x) {
-      paddle::operators::ln_fwd_1024_kernel<
-          T,
-          U,
-          T,
-          VecSize,
-          WARPS_M,
-          WARPS_N,
-          BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
-          batch_size,
-          feature_size,
-          epsilon,
-          x_data,
-          static_cast<const T *>(void_scale_data),
-          static_cast<const T *>(void_bias_data),
-          mean_data,
-          var_data,
-          y_data);
+      switch (feature_size) {
+        PADDLE_LAUNCH_FAST_LAYERNORM_FWD(T);
+        default:
+          PADDLE_THROW(phi::errors::InvalidArgument(
+              "Only a feature_size that is a multiple of 256 between 768 "
+              "and 2048, or exactly 4096, "
+              "is supported now."));
+          break;
+      }
     } else {
-      paddle::operators::ln_fwd_1024_kernel<
-          T,
-          U,
-          U,
-          VecSize,
-          WARPS_M,
-          WARPS_N,
-          BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
-          batch_size,
-          feature_size,
-          epsilon,
-          x_data,
-          static_cast<const U *>(void_scale_data),
-          static_cast<const U *>(void_bias_data),
-          mean_data,
-          var_data,
-          y_data);
+      switch (feature_size) {
+        PADDLE_LAUNCH_FAST_LAYERNORM_FWD(U);
+        default:
+          PADDLE_THROW(phi::errors::InvalidArgument(
+              "Only a feature_size that is a multiple of 256 between 768 "
+              "and 2048, or exactly 4096, "
+              "is supported now."));
+          break;
+      }
     }
   } else {
 #endif
@@ -197,6 +213,7 @@ void LayerNormKernel(const Context &dev_ctx,
 #endif
 
 #undef PADDLE_LAUNCH_LAYERNORM_FWD
+#undef PADDLE_LAUNCH_FAST_LAYERNORM_FWD
 }
 
 }  // namespace phi
-- 
GitLab
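
Below is a minimal, standalone CUDA sketch (not part of the patch) of the reduction pattern that the forward kernel gains for feature sizes larger than 1024: each row is covered by WARPS_N warps, partial sums are first reduced inside each warp with __shfl_xor_sync, and when WARPS_N > 1 the per-warp results are combined through a small shared-memory buffer, mirroring the "if (WARPS_N > 1)" blocks added to fast_ln_fwd_kernel. The names here (row_mean_kernel, kWarpSize, etc.) are invented for illustration, only a per-row mean is computed (no variance, scale, or bias), and the cross-warp sum is done by each row group's leader warp rather than by a single CTA thread as in the patch.

// Standalone sketch (assumed names; not the Paddle implementation): per-row
// mean with the same two-level warp-shuffle + shared-memory reduction.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <int WARPS_M, int WARPS_N>
__global__ void row_mean_kernel(const float *__restrict__ x,
                                float *__restrict__ mean, int rows, int cols) {
  constexpr int kWarpSize = 32;
  constexpr int kThreadsPerRow = WARPS_N * kWarpSize;
  __shared__ float smem[WARPS_M * WARPS_N];

  const int tidx = threadIdx.x;
  const int lane = tidx % kWarpSize;
  const int warp = tidx / kWarpSize;
  const int warp_m = warp / WARPS_N;  // which of the CTA's rows this warp works on
  const int warp_n = warp % WARPS_N;  // position of this warp along the row
  const int col0 = warp_n * kWarpSize + lane;

  // Uniform loop bound so every thread reaches the __syncthreads() below.
  for (int row_base = blockIdx.x * WARPS_M; row_base < rows;
       row_base += gridDim.x * WARPS_M) {
    const int row = row_base + warp_m;
    const bool active = row < rows;

    // Strided per-thread partial sum over the row.
    float sum = 0.f;
    if (active) {
      for (int col = col0; col < cols; col += kThreadsPerRow) {
        sum += x[row * cols + col];
      }
    }
    // Intra-warp butterfly reduction, as in the kernel's THREADS_PER_WARP loop.
    for (int it = 1; it < kWarpSize; it *= 2) {
      sum += __shfl_xor_sync(0xffffffffu, sum, it);
    }
    // Cross-warp combination through shared memory (the WARPS_N > 1 path).
    if (WARPS_N > 1) {
      if (lane == 0) {
        smem[warp_m * WARPS_N + warp_n] = sum;
      }
      __syncthreads();
      if (warp_n == 0 && lane == 0) {
        float total = 0.f;
        for (int it = 0; it < WARPS_N; ++it) {
          total += smem[warp_m * WARPS_N + it];
        }
        smem[warp_m * WARPS_N] = total;
      }
      __syncthreads();
      sum = smem[warp_m * WARPS_N];
      __syncthreads();  // keep smem stable until everyone has read it
    }
    if (active && warp_n == 0 && lane == 0) {
      mean[row] = sum / static_cast<float>(cols);
    }
  }
}

int main() {
  // cols = 2048 corresponds to the WARPS_M = 2, WARPS_N = 2 configuration that
  // the patch's launch macro derives for that feature size.
  const int rows = 6, cols = 2048;
  std::vector<float> h_x(static_cast<size_t>(rows) * cols);
  for (size_t i = 0; i < h_x.size(); ++i) h_x[i] = static_cast<float>(i % 7);

  float *d_x = nullptr, *d_mean = nullptr;
  cudaMalloc(&d_x, h_x.size() * sizeof(float));
  cudaMalloc(&d_mean, rows * sizeof(float));
  cudaMemcpy(d_x, h_x.data(), h_x.size() * sizeof(float), cudaMemcpyHostToDevice);

  constexpr int WARPS_M = 2, WARPS_N = 2;
  const int grid = (rows + WARPS_M - 1) / WARPS_M;
  row_mean_kernel<WARPS_M, WARPS_N>
      <<<grid, WARPS_M * WARPS_N * 32>>>(d_x, d_mean, rows, cols);

  std::vector<float> h_mean(rows);
  cudaMemcpy(h_mean.data(), d_mean, rows * sizeof(float), cudaMemcpyDeviceToHost);
  for (int r = 0; r < rows; ++r) printf("row %d mean = %f\n", r, h_mean[r]);
  cudaFree(d_x);
  cudaFree(d_mean);
  return 0;
}

Note that for WARPS_N == 1 the shared-memory path compiles away entirely, and the sketch degenerates to the plain single-warp shuffle reduction that the original 1024-column kernel already used.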