Unverified commit 85baa3c0, authored by Li Min, committed by GitHub

Extend forward fast layer_norm kernel to support more dimensions. (#43118)

* extend forward fast_ln_kernel to support more column values.
Parent 8c7cb3d6
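For orientation before the diff: the sketch below mirrors the arithmetic that the launch macros in this patch use to pick a warp layout per column count (the helper names are illustrative and not part of the patch). Every configuration uses four warps (128 threads) per CTA; the extension only changes how those warps are split between rows and columns.

// Illustrative sketch, not from the patch: the WARPS_M / WARPS_N split
// implied by the launch macros below for each supported column count.
constexpr int FastLnWarpsN(int cols) { return cols < 1024 ? 1 : cols / 1024; }
constexpr int FastLnWarpsM(int cols) { return 4 / FastLnWarpsN(cols); }
// cols 768, 1024, 1280, 1536, 1792 -> WARPS_N = 1, WARPS_M = 4 (4 rows per CTA)
// cols 2048                        -> WARPS_N = 2, WARPS_M = 2 (2 rows per CTA)
// cols 4096                        -> WARPS_N = 4, WARPS_M = 1 (1 row per CTA)
static_assert(FastLnWarpsM(2048) == 2, "sanity check of the sketch");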
@@ -478,11 +478,15 @@ void LaunchLayernormResidualDropoutBias(
#define LAUNCH_FUSED_FAST_LN_KERNEL \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1280); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1536); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1792); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(2048); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)
bool can_call_fast_ln_kernel = false;
if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr &&
layernorm_bias != nullptr) {
if (((cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096) &&
scale != nullptr && layernorm_bias != nullptr) {
can_call_fast_ln_kernel = true;
}
VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;
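As a hedged restatement of the extended dispatch check above (the function and parameter names here are illustrative, not from the patch):

// Sketch only: which column counts the extended fast LN forward path accepts,
// given that both scale and bias are provided.
inline bool CanUseFastLnForward(int cols, bool has_scale, bool has_bias) {
  const bool supported_cols =
      (cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096;
  return supported_cols && has_scale && has_bias;
}
// Accepted: 768, 1024, 1280, 1536, 1792, 2048, 4096.
// Rejected: 512, 2304, 3072 (and any size when scale or bias is missing).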
@@ -36,8 +36,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
#define LN_NUM_COLS 1024
inline static int GetDesiredBlockDim(int64_t block_dim) {
#ifdef __HIPCC__
const int kMaxBlockDim = 256;
@@ -183,11 +181,12 @@ template <typename T, typename U, typename ScaleT = U, int VecSize = 8,
int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
__global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
int rows, int cols, const float epsilon, const T *__restrict__ x_ptr,
const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
T *__restrict__ y_ptr) {
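// Shared memory used to combine per-warp partial sums (first for the mean,
// then reused for the variance) when a row spans more than one warp
// (WARPS_N > 1).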
__shared__ U smem[WARPS_M * WARPS_N];
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
@@ -210,12 +209,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
col += THREADS_PER_ROW;
}
constexpr U rn = 1.f / U(LN_NUM_COLS);
constexpr U rn = 1.f / U(ELTS_PER_ROW);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
col += THREADS_PER_ROW;
}
U xf[LDGS * VecSize];
@@ -235,6 +234,23 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
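// When a row is split across several warps (WARPS_N > 1), each warp now holds
// only the sum over its own slice of the row; combine the per-warp partials
// through shared memory before scaling by rn.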
if (WARPS_N > 1) {
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = mu_local;
}
__syncthreads();
if (tidx == 0) {
mu_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
mu_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = mu_local;
}
__syncthreads();
mu_local = smem[warp_m];
}
mu_local *= rn;
if (lane == 0) {
mean_out_ptr[row] = mu_local;
@@ -254,6 +270,24 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
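// Same cross-warp combination through shared memory for the variance partials.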
if (WARPS_N > 1) {
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = var_local;
}
__syncthreads();
if (tidx == 0) {
var_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
var_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = var_local;
}
__syncthreads();
var_local = smem[warp_m];
}
// Note: verify that this computation is still correct when U is double.
U rsigma = rsqrtf(var_local * rn + epsilon);
if (lane == 0) {
@@ -277,7 +311,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
col += THREADS_PER_ROW;
}
}
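To keep the surrounding kernel code readable, here is a plain scalar reference of what the threads of one row jointly compute in the forward kernel above. This is an illustrative sketch, not a Paddle function, and the helper name is made up.

#include <cmath>

// Scalar per-row layer_norm: mean, (biased) variance, then scale and shift.
void LayerNormRowReference(const float *x, const float *gamma, const float *beta,
                           int cols, float epsilon, float *y, float *mean_out,
                           float *var_out) {
  float mu = 0.f;
  for (int i = 0; i < cols; ++i) mu += x[i];
  mu /= cols;  // corresponds to mu_local *= rn in the kernel
  float var = 0.f;
  for (int i = 0; i < cols; ++i) var += (x[i] - mu) * (x[i] - mu);
  var /= cols;
  const float rsigma = 1.f / std::sqrt(var + epsilon);  // rsqrtf(var * rn + epsilon)
  for (int i = 0; i < cols; ++i) {
    y[i] = gamma[i] * (x[i] - mu) * rsigma + beta[i];
  }
  *mean_out = mu;
  *var_out = var;
}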
@@ -416,10 +450,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
const int r = bidx * ROWS_PER_CTA + warp_m;
const int c = warp_n * THREADS_PER_WARP + lane;
static_assert(LN_NUM_COLS == THREADS_PER_ROW * LDGS * VecSize, "");
static_assert(ELTS_PER_ROW == THREADS_PER_ROW * LDGS * VecSize, "");
// smem for column reduction
__shared__ U smem_[ROWS_PER_CTA * LN_NUM_COLS];
__shared__ U smem_[ROWS_PER_CTA * ELTS_PER_ROW];
U dgamma_sum[LDGS * VecSize];
U dbeta_sum[LDGS * VecSize];
@@ -434,7 +468,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
U *sum_loss2_shared = &smem_sum_loss2[warp_m * WARPS_N];
// step-1: compute dx and local results of dscale and dbias
constexpr float rn = 1.f / static_cast<float>(LN_NUM_COLS);
constexpr float rn = 1.f / static_cast<float>(ELTS_PER_ROW);
Vec_scale gamma[LDGS];
int col = c;
#pragma unroll
@@ -452,12 +486,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
int col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
phi::Load<T, VecSize>(dout_ptr + row * LN_NUM_COLS + col * VecSize,
phi::Load<T, VecSize>(dout_ptr + row * ELTS_PER_ROW + col * VecSize,
&dout[it]);
phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
if (isFusedDropoutResidualLn) {
phi::Load<MaskType, VecSize>(
mask_ptr + row * LN_NUM_COLS + col * VecSize, &mask_vec[it]);
mask_ptr + row * ELTS_PER_ROW + col * VecSize, &mask_vec[it]);
}
col += THREADS_PER_ROW;
@@ -551,10 +585,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
col = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
phi::Store<T, VecSize>(x[it], dx_ptr + row * LN_NUM_COLS + col * VecSize);
phi::Store<T, VecSize>(x[it],
dx_ptr + row * ELTS_PER_ROW + col * VecSize);
if (isFusedDropoutResidualLn) {
phi::Store<T, VecSize>(
dout[it], d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize);
dout[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
}
col += THREADS_PER_ROW;
}
@@ -562,12 +597,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
// step-2: column reduction of dscale and dbias for each thread block.
// each block's sum: [4 * 1024] -> [1 * 1024]
enum { NUM_RES = LN_NUM_COLS / THREADS_PER_CTA }; // 1024/128 = 8
static_assert(NUM_RES * THREADS_PER_CTA == LN_NUM_COLS, "");
enum { NUM_RES = ELTS_PER_ROW / THREADS_PER_CTA };  // e.g. 1024 / 128 = 8 when ELTS_PER_ROW == 1024
static_assert(NUM_RES * THREADS_PER_CTA == ELTS_PER_ROW, "");
U *smem_write;
smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize]; // [4 * 1024]
smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize]; // [4 * 1024]
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
@@ -583,12 +618,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dbeta_sum[jt] +=
smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
}
}
__syncthreads();
smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];
smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
@@ -603,19 +638,19 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
for (int it = 0; it < ROWS_PER_CTA; it++) {
for (int jt = 0; jt < NUM_RES; jt++) {
cta_dgamma_sum[jt] +=
smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
}
}
// the shape of results:(#blocks, 1024)
U *dgamma_part =
static_cast<U *>(dgamma_temp_ptr) + bidx * LN_NUM_COLS + tidx;
static_cast<U *>(dgamma_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dgamma_part = cta_dgamma_sum[jt];
dgamma_part += THREADS_PER_CTA;
}
U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * LN_NUM_COLS + tidx;
U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
for (int jt = 0; jt < NUM_RES; jt++) {
*dbeta_part = cta_dbeta_sum[jt];
dbeta_part += THREADS_PER_CTA;
@@ -640,7 +675,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
using Vec = phi::AlignedVector<U, VecSize>;
static_assert(VEC_COLS == LN_NUM_COLS / VecSize, "");
static_assert(VEC_COLS == ELTS_PER_ROW / VecSize, "");
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
@@ -656,8 +691,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
__shared__ U smem_space[(WARPS_M - 1) * THREADS_PER_ROW * VecSize];
for (int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW) {
const U *dg_part_ptr = (dg_part_) + r * LN_NUM_COLS + col * VecSize;
const U *db_part_ptr = (db_part_) + r * LN_NUM_COLS + col * VecSize;
const U *dg_part_ptr = (dg_part_) + r * ELTS_PER_ROW + col * VecSize;
const U *db_part_ptr = (db_part_) + r * ELTS_PER_ROW + col * VecSize;
U dg_sum[VecSize];
U db_sum[VecSize];
@@ -669,8 +704,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
Vec db;
phi::Load<U, VecSize>(dg_part_ptr, &dg);
phi::Load<U, VecSize>(db_part_ptr, &db);
dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
dg_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
db_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
@@ -84,7 +84,7 @@ void LayerNormKernel(const Context &dev_ctx,
PADDLE_ENFORCE_EQ(
scale->dtype(),
bias->dtype(),
phi::errors::InvalidArgument("Thie Scale and Bias of layer_norm op "
phi::errors::InvalidArgument("This Scale and Bias of layer_norm op "
"should have the same data type."));
}
} else {
@@ -131,59 +131,75 @@ void LayerNormKernel(const Context &dev_ctx,
} \
} while (0)
#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, feature_size) \
case (feature_size): { \
constexpr int WARPS_N = feature_size < 1024 ? 1 : (feature_size / 1024); \
constexpr int WARPS_M = 4 / WARPS_N; \
const int THREADS_PER_WARP = 32; \
const int BYTES_PER_LDG = 16; \
const int VecSize = BYTES_PER_LDG / sizeof(T); \
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \
const int ROWS_PER_CTA = WARPS_M; \
const int grid = static_cast<int>( \
std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA))); \
paddle::operators::fast_ln_fwd_kernel< \
T, \
U, \
ScaleT, \
VecSize, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>( \
batch_size, \
feature_size, \
epsilon, \
x_data, \
static_cast<const ScaleT *>(void_scale_data), \
static_cast<const ScaleT *>(void_bias_data), \
mean_data, \
var_data, \
y_data); \
} break
#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD(ScaleT) \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 768); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1024); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1280); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1536); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1792); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 2048); \
PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 4096)
#ifdef PADDLE_WITH_CUDA
bool can_call_1024_kernel = false;
if (feature_size == 1024 && scale != nullptr && bias != nullptr) {
can_call_1024_kernel = true;
bool can_call_fast_kernel = false;
if (((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0) ||
feature_size == 4096) &&
scale != nullptr && bias != nullptr) {
can_call_fast_kernel = true;
}
if (can_call_1024_kernel) {
const int WARPS_M = 4;
const int WARPS_N = 1;
const int THREADS_PER_WARP = 32;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
const int ROWS_PER_CTA = WARPS_M;
const int grid = static_cast<int>(
std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));
if (can_call_fast_kernel) {
if (is_scale_bias_same_dtype_with_x) {
paddle::operators::ln_fwd_1024_kernel<
T,
U,
T,
VecSize,
WARPS_M,
WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size,
feature_size,
epsilon,
x_data,
static_cast<const T *>(void_scale_data),
static_cast<const T *>(void_bias_data),
mean_data,
var_data,
y_data);
switch (feature_size) {
PADDLE_LAUNCH_FAST_LAYERNORM_FWD(T);
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Only when feature_size is from 256 to 4096 and is diviaible by "
"256 is supported "
"now"));
break;
}
} else {
paddle::operators::ln_fwd_1024_kernel<
T,
U,
U,
VecSize,
WARPS_M,
WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size,
feature_size,
epsilon,
x_data,
static_cast<const U *>(void_scale_data),
static_cast<const U *>(void_bias_data),
mean_data,
var_data,
y_data);
switch (feature_size) {
PADDLE_LAUNCH_FAST_LAYERNORM_FWD(U);
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Only when feature_size is from 256 to 4096 and is diviaible by "
"is supported "
"now"));
break;
}
}
} else {
#endif
@@ -197,6 +213,7 @@ void LayerNormKernel(const Context &dev_ctx,
#endif
#undef PADDLE_LAUNCH_LAYERNORM_FWD
#undef PADDLE_LAUNCH_FAST_LAYERNORM_FWD
}
} // namespace phi
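As a hedged reading of the PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE macro added in this patch, this is roughly what one instantiation expands to, assuming T is a 16-bit floating type so that VecSize = 16 / sizeof(T) = 8 (the expansion below is illustrative, not code from the patch):

// PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(T, 2048) expands, roughly, to:
//   WARPS_N = 2048 / 1024 = 2;          WARPS_M = 4 / 2 = 2;
//   THREADS_PER_CTA = 2 * 32 * 2 = 128; ROWS_PER_CTA = 2;
//   grid = ceil(batch_size / 2.0f);
//   fast_ln_fwd_kernel<T, U, T, /*VecSize=*/8, /*WARPS_M=*/2, /*WARPS_N=*/2,
//                      /*BYTES_PER_LDG=*/16><<<grid, 128, 0, stream>>>(
//       batch_size, 2048, epsilon, x_data,
//       static_cast<const T *>(void_scale_data),
//       static_cast<const T *>(void_bias_data), mean_data, var_data, y_data);

Every supported feature_size therefore launches 128-thread blocks; only the split of those four warps between rows (WARPS_M) and columns (WARPS_N) changes with the column count.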