Unverified commit 7cb49539, authored by Zhang Zheng, committed by GitHub

Support more scenes for fused_fast_ln (#42282)

* Support more scenes for fused_fast_ln

* fix
Parent 687219fe
@@ -156,9 +156,9 @@ __global__ void FusedLayernormResidualDropoutBias(
 }
 
 /*
  * @brief layernorm(residual + dropout(x));
  * Conditions:
- * (1) The number of cols is 1024;
+ * (1) The number of cols is 768/1024/4096;
  * (2) layer_norm scale and bias is not null;
  * (3) linear bias is null;
  * @param
@@ -166,6 +166,7 @@ __global__ void FusedLayernormResidualDropoutBias(
  * cols: 1024
  * x_: [rows, cols], inputs
  * residual_:[rows, cols]
+ * bias_: [cols], linear bias, can be null
  * gamma_: [cols]: layernorm scale, not null
  * beta_: [cols], layernorm bias, not null
  * mask_out_: [rows, cols], dropout result
@@ -173,7 +174,7 @@ __global__ void FusedLayernormResidualDropoutBias(
  * y_: [rows, cols], layernorm result
  * mean_out_: [rows]: layernorm means
  * var_out_: [rows]: layernorm vars
  */
 template <
     typename T, typename U, typename ScaleT = U, typename MaskType = uint8_t,
     int VecSize = 8, int WARPS_M = 4, int WARPS_N = 1, int BYTES_PER_LDG = 16,
@@ -182,14 +183,16 @@ template <
     int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
     int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
     int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
-__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
     int rows, int cols, uint64_t seed, const float dropout_prob,
     const bool is_upscale_in_train, const bool is_test,
     const uint64_t increment, const float epsilon, const T *__restrict__ x_ptr,
-    const T *__restrict__ residual_ptr, const ScaleT *__restrict__ gamma_ptr,
-    const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr,
-    U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
-    T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) {
+    const T *__restrict__ residual_ptr, const T *__restrict__ bias_ptr,
+    const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
+    MaskType *__restrict__ mask_out_ptr, U *__restrict__ mean_out_ptr,
+    U *__restrict__ var_out_ptr, T *__restrict__ residual_out_ptr,
+    T *__restrict__ y_ptr) {
+  __shared__ U smem[WARPS_M * WARPS_N];
   using Vec = phi::AlignedVector<T, VecSize>;
   using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
   using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
@@ -204,12 +207,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
   const int c = warp_n * THREADS_PER_WARP + lane;  // lane
   const int r = bidx * ROWS_PER_CTA + warp_m;      // row id
-  int idx = r * LN_NUM_COLS + c;
+  int idx = r * ELTS_PER_ROW + c;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, idx, increment, &state);
 
   T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
 
+  // bias
+  Vec bias[LDGS];
+  if (bias_ptr != nullptr) {
+#pragma unroll
+    for (int it = 0, col = c; it < LDGS; it++) {
+      phi::Load<T, VecSize>(bias_ptr + col * VecSize, &bias[it]);
+      col += THREADS_PER_ROW;
+    }
+  }
+
   Vec_scale gamma[LDGS];
   Vec_scale beta[LDGS];
 #pragma unroll
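The bias preload above uses the same strided, vectorized walk over a row that every later per-row loop uses: each thread owns LDGS chunks of VecSize contiguous elements, spaced THREADS_PER_ROW chunks apart. A minimal standalone sketch of the offsets a single thread touches, assuming the cols = 1024 configuration (THREADS_PER_ROW = 32, VecSize = 8, LDGS = 4); names and values here are illustrative, not taken from the kernel:

    // Prints the element offsets one thread covers inside a row
    // (illustration of the col = c + it * THREADS_PER_ROW indexing).
    #include <cstdio>

    int main() {
      const int kThreadsPerRow = 32, kVecSize = 8, kLdgs = 4;
      const int c = 5;  // example column id of this thread within the row
      for (int it = 0, col = c; it < kLdgs; ++it, col += kThreadsPerRow) {
        // Each iteration loads kVecSize consecutive elements starting here.
        std::printf("it=%d -> elements [%d, %d)\n", it, col * kVecSize,
                    col * kVecSize + kVecSize);
      }
      return 0;
    }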
@@ -219,14 +232,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
     col += THREADS_PER_ROW;
   }
 
-  constexpr U rn = 1.f / U(LN_NUM_COLS);
+  constexpr U rn = 1.f / U(ELTS_PER_ROW);
   for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
     Vec x[LDGS];
     Vec residual[LDGS];
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
-      phi::Load<T, VecSize>(residual_ptr + row * LN_NUM_COLS + col * VecSize,
+      phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
+      phi::Load<T, VecSize>(residual_ptr + row * ELTS_PER_ROW + col * VecSize,
                             &residual[it]);
       col += THREADS_PER_ROW;
     }
@@ -255,14 +268,28 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
     // 4 * 8
     U xf[LDGS * VecSize];
+    if (bias_ptr != nullptr) {
 #pragma unroll
-    for (int it = 0; it < LDGS; it++) {
+      for (int it = 0; it < LDGS; it++) {
 #pragma unroll
-      for (int jt = 0; jt < VecSize; jt++) {
-        // dropout(x) + residual
-        x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
-                    residual[it][jt];
-        xf[it * VecSize + jt] = U(x[it][jt]);
+        for (int jt = 0; jt < VecSize; jt++) {
+          // dropout(x) + residual
+          x[it][jt] = (x[it][jt] + bias[it][jt]) *
+                          static_cast<T>(mask_vec[it][jt]) * factor +
+                      residual[it][jt];
+          xf[it * VecSize + jt] = U(x[it][jt]);
+        }
+      }
+    } else {
+#pragma unroll
+      for (int it = 0; it < LDGS; it++) {
+#pragma unroll
+        for (int jt = 0; jt < VecSize; jt++) {
+          // dropout(x) + residual
+          x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
+                      residual[it][jt];
+          xf[it * VecSize + jt] = U(x[it][jt]);
+        }
       }
     }
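Both branches above compute the same fused elementwise step, with the linear bias folded in only when bias_ptr is non-null. A scalar reference of the per-element math, written as a hypothetical helper for clarity (keep is the dropout mask value, factor the upscale factor returned by GetFactor):

    // Scalar reference for the fused step: out = residual + dropout(x [+ bias]).
    // The layernorm statistics are then taken over `out` along the row.
    template <typename T>
    T FusedDropoutResidualRef(T x, T bias, T residual, bool has_bias,
                              unsigned char keep, T factor) {
      T v = has_bias ? (x + bias) : x;        // optional linear bias
      v = v * static_cast<T>(keep) * factor;  // dropout (scaled in train, identity in test)
      return v + residual;                    // residual add feeding layernorm
    }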
@@ -270,9 +297,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
       phi::Store<T, VecSize>(
-          x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize);
+          x[it], residual_out_ptr + row * ELTS_PER_ROW + col * VecSize);
       phi::Store<MaskType, VecSize>(
-          mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize);
+          mask_vec[it], mask_out_ptr + row * ELTS_PER_ROW + col * VecSize);
       col += THREADS_PER_ROW;
     }
@@ -289,6 +316,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
     }
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = mu_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        mu_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          mu_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = mu_local;
+      }
+      __syncthreads();
+      mu_local = smem[warp_m];
+    }
     mu_local *= rn;
     if (lane == 0) {
       mean_out_ptr[row] = mu_local;
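The new WARPS_N > 1 path is what lets one row span several warps (the cols = 4096 case): each warp first reduces its partial sum with butterfly shuffles, then lane 0 of every warp publishes that partial to shared memory, where the per-row total is folded. A stripped-down sketch of the same two-stage pattern, as a hypothetical standalone kernel with one row per block; names only loosely mirror the real kernel:

    // Two-stage row reduction: warp shuffle, then shared memory across warps.
    template <int WARPS_N>
    __global__ void RowSumSketch(const float* row, int cols, float* out) {
      __shared__ float smem[WARPS_N];
      const int lane = threadIdx.x % 32;
      const int warp_n = threadIdx.x / 32;

      // Stage 0: strided per-thread partial sum over the row.
      float partial = 0.f;
      for (int i = threadIdx.x; i < cols; i += WARPS_N * 32) partial += row[i];

      // Stage 1: butterfly reduction inside each warp.
      for (int it = 1; it < 32; it *= 2) {
        partial += __shfl_xor_sync(0xffffffffu, partial, it);
      }

      // Stage 2: one value per warp goes through shared memory.
      if (lane == 0) smem[warp_n] = partial;
      __syncthreads();
      if (threadIdx.x == 0) {
        float total = 0.f;
        for (int it = 0; it < WARPS_N; ++it) total += smem[it];
        *out = total;
      }
    }

Launching, for example, RowSumSketch<4><<<1, 4 * 32>>>(d_row, 4096, d_out) mirrors the cols = 4096 configuration in which four warps share a row.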
@@ -308,6 +351,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
     }
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = var_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        var_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          var_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = var_local;
+      }
+      __syncthreads();
+      var_local = smem[warp_m];
+    }
     U rsigma = rsqrtf(var_local * rn + epsilon);
     if (lane == 0) {
       // Note: the stored var is different for paddle(ln) and apex (fast ln).
@@ -332,7 +391,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
+      phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
       col += THREADS_PER_ROW;
     }
   }
@@ -390,12 +449,37 @@ void LaunchLayernormResidualDropoutBias(
     return;
   }
 
-  bool can_call_1024_kernel = false;
-  if (cols == 1024 && scale != nullptr && layernorm_bias != nullptr &&
-      bias == nullptr) {
-    can_call_1024_kernel = true;
+#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols)                                 \
+  case (cols): {                                                               \
+    constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024);                   \
+    constexpr int WARPS_M = 4 / WARPS_N;                                       \
+    const int THREADS_PER_WARP = 32;                                           \
+    const int BYTES_PER_LDG = 16;                                              \
+    const int VecSize = BYTES_PER_LDG / sizeof(T);                             \
+    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;          \
+    const int ROWS_PER_CTA = WARPS_M;                                          \
+    const int grid =                                                           \
+        static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA)));  \
+    fused_fast_ln_fwd_kernel<                                                  \
+        T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, uint8_t,      \
+        VecSize, WARPS_M, WARPS_N, BYTES_PER_LDG,                              \
+        cols><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(                     \
+        rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,          \
+        increment, epsilon, src, residual, bias, scale, layernorm_bias,        \
+        mask_data, mean, var, dst, layernorm_dst);                             \
+  } break
+
+#define LAUNCH_FUSED_FAST_LN_KERNEL          \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768);     \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024);    \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)
+
+  bool can_call_fast_ln_kernel = false;
+  if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr &&
+      layernorm_bias != nullptr) {
+    can_call_fast_ln_kernel = true;
   }
-  VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
+  VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;
 
   const int VecSize = MAX_CACHE_BYTES / sizeof(T);
   if (cols % VecSize != 0) {
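The macro above derives the launch geometry from the column count: one warp per row for 768 and 1024 columns, four warps per row for 4096, with WARPS_M chosen so a CTA always holds four warps. A host-side sketch of how WARPS_N, WARPS_M and the per-thread load count fall out for each supported width, assuming 16-byte vector loads of fp16 (VecSize = 8); this is an illustration, not code from the patch:

    // Standalone check of the launch geometry chosen per column count.
    #include <cstdio>

    int main() {
      const int kVecSize = 8;  // 16-byte loads of fp16
      const int kThreadsPerWarp = 32;
      const int kWidths[] = {768, 1024, 4096};
      for (int cols : kWidths) {
        int warps_n = cols < 1024 ? 1 : cols / 1024;  // warps cooperating on a row
        int warps_m = 4 / warps_n;                    // rows handled per CTA
        int threads_per_row = warps_n * kThreadsPerWarp;
        int ldgs = cols / (threads_per_row * kVecSize);  // vector loads per thread per row
        std::printf("cols=%4d  WARPS_N=%d  WARPS_M=%d  LDGS=%d\n", cols, warps_n,
                    warps_m, ldgs);
      }
      return 0;
    }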
@@ -407,26 +491,15 @@ void LaunchLayernormResidualDropoutBias(
         epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
         layernorm_dst, mean, var);
   } else {
-    if (can_call_1024_kernel) {
-      const int WARPS_M = 4;
-      const int WARPS_N = 1;
-      const int THREADS_PER_WARP = 32;
-      const int BYTES_PER_LDG = 16;
-      const int VecSize = BYTES_PER_LDG / sizeof(T);
-
-      const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
-      const int ROWS_PER_CTA = WARPS_M;
-
-      // Note: the grid can not exceed max_grid of the gpu.
-      const int grid =
-          static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA)));
-      fused_ln_fwd_1024_kernel<
-          T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, uint8_t,
-          VecSize, WARPS_M, WARPS_N,
-          BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(
-          rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
-          increment, epsilon, src, residual, scale, layernorm_bias, mask_data,
-          mean, var, dst, layernorm_dst);
+    if (can_call_fast_ln_kernel) {
+      switch (cols) {
+        LAUNCH_FUSED_FAST_LN_KERNEL;
+        default:
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "Only when column is equal to 768/1024/4096 is supported for "
+              "now"));
+          break;
+      }
     } else {
       int blockDim = GetDesiredBlockDim(cols / VecSize);
       FusedLayernormResidualDropoutBias<