Unverified commit 85baa3c0, authored by Li Min, committed by GitHub

Extend forward fast layer_norm kernel to support more dimensions. (#43118)

* extend forward fast_ln_kernel to support more column values.

Parent: 8c7cb3d6
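For orientation, the dispatch rule introduced by this change can be summarized as a small standalone predicate. The following is a minimal sketch, not code from the commit, and the helper name is hypothetical:

```cuda
// Hypothetical helper restating the new dispatch rule: the fast forward
// LayerNorm path now covers 768, 1024, 1280, 1536, 1792 and 2048 columns
// (multiples of 256 in [768, 2048]) plus 4096, provided both scale and
// bias tensors are present.
inline bool CanUseFastLnForward(int cols, bool has_scale, bool has_bias) {
  const bool supported_width =
      (cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096;
  return supported_width && has_scale && has_bias;
}
```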
@@ -478,11 +478,15 @@ void LaunchLayernormResidualDropoutBias(
 #define LAUNCH_FUSED_FAST_LN_KERNEL       \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768);  \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1280); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1536); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1792); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(2048); \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)

   bool can_call_fast_ln_kernel = false;
-  if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr &&
-      layernorm_bias != nullptr) {
+  if (((cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096) &&
+      scale != nullptr && layernorm_bias != nullptr) {
     can_call_fast_ln_kernel = true;
   }
   VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;
...
@@ -36,8 +36,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

-#define LN_NUM_COLS 1024
-
 inline static int GetDesiredBlockDim(int64_t block_dim) {
 #ifdef __HIPCC__
   const int kMaxBlockDim = 256;
@@ -183,11 +181,12 @@ template <typename T, typename U, typename ScaleT = U, int VecSize = 8,
           int ROWS_PER_CTA = WARPS_M,
           int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
           int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
-__global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
     int rows, int cols, const float epsilon, const T *__restrict__ x_ptr,
     const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
     U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
     T *__restrict__ y_ptr) {
+  __shared__ U smem[WARPS_M * WARPS_N];
   using Vec = phi::AlignedVector<T, VecSize>;
   using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
@@ -210,12 +209,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     col += THREADS_PER_ROW;
   }

-  constexpr U rn = 1.f / U(LN_NUM_COLS);
+  constexpr U rn = 1.f / U(ELTS_PER_ROW);
   for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
     Vec x[LDGS];
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
+      phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
       col += THREADS_PER_ROW;
     }
     U xf[LDGS * VecSize];
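With LN_NUM_COLS removed, the row width is carried by the ELTS_PER_ROW template parameter, and the per-thread vector-load count LDGS follows from it. A minimal sketch of that arithmetic, assuming THREADS_PER_ROW = WARPS_N * 32 and the 16-byte loads used by the launch code (VecSize = 8 for fp16); this is an illustration, not code from the commit:

```cuda
// Shape arithmetic behind ELTS_PER_ROW / LDGS (illustrative only).
// Example: 2048 columns with WARPS_N = 2 gives 64 threads per row and
// 64 * 8 = 512 elements per pass, so LDGS = 2048 / 512 = 4 vector loads
// per thread per row.
constexpr int LdgsFor(int elts_per_row, int warps_n, int vec_size) {
  return elts_per_row / (warps_n * 32 * vec_size);
}
static_assert(LdgsFor(1024, 1, 8) == 4, "");
static_assert(LdgsFor(1280, 1, 8) == 5, "");
static_assert(LdgsFor(2048, 2, 8) == 4, "");
static_assert(LdgsFor(4096, 4, 8) == 4, "");
```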
@@ -235,6 +234,23 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
     }
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = mu_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        mu_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          mu_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = mu_local;
+      }
+      __syncthreads();
+      mu_local = smem[warp_m];
+    }
+
     mu_local *= rn;
     if (lane == 0) {
       mean_out_ptr[row] = mu_local;
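The block added above makes the row reduction two-level: a butterfly __shfl_xor_sync reduction inside each warp, then, only when WARPS_N > 1, a shared-memory exchange across the warps that share a row. A stripped-down sketch of the same pattern follows; it is illustrative only and folds the per-row bookkeeping into a smem_row pointer instead of mirroring the kernel's indexing exactly:

```cuda
// Two-level row sum: intra-warp shuffle, then a cross-warp pass through
// shared memory. smem_row points at the WARPS_N-element slice of shared
// memory owned by this thread's row.
template <int WARPS_N>
__device__ float ReduceRowSum(float val, float *smem_row, int lane,
                              int warp_n, int tid_in_row) {
  // 1) Butterfly reduction: every lane ends up with its warp's sum.
  for (int it = 1; it < 32; it *= 2) {
    val += __shfl_xor_sync(uint32_t(-1), val, it);
  }
  if (WARPS_N > 1) {
    // 2) One partial sum per warp goes to shared memory.
    if (lane == 0) smem_row[warp_n] = val;
    __syncthreads();
    // 3) The first thread of the row adds the partials and publishes the total.
    if (tid_in_row == 0) {
      val = 0.f;
      for (int it = 0; it < WARPS_N; ++it) val += smem_row[it];
      smem_row[0] = val;
    }
    __syncthreads();
    val = smem_row[0];
  }
  return val;
}
```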
@@ -254,6 +270,24 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
     }
+
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = var_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        var_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          var_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = var_local;
+      }
+      __syncthreads();
+      var_local = smem[warp_m];
+    }
+
     // Note: to assure if it is right for double
     U rsigma = rsqrtf(var_local * rn + epsilon);
     if (lane == 0) {
@@ -277,7 +311,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
+      phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
       col += THREADS_PER_ROW;
     }
   }
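For reference, the quantities the forward kernel produces per row are the usual LayerNorm statistics and affine output. A scalar host-side equivalent, as a minimal sketch that is not part of the commit:

```cuda
#include <cmath>
#include <vector>

// Scalar reference for one row: mean, biased variance, then
// y[i] = gamma[i] * (x[i] - mean) * rsqrt(var + epsilon) + beta[i].
void LayerNormRowReference(const std::vector<float> &x,
                           const std::vector<float> &gamma,
                           const std::vector<float> &beta, float epsilon,
                           std::vector<float> *y, float *mean, float *var) {
  const int n = static_cast<int>(x.size());
  float mu = 0.f;
  for (int i = 0; i < n; ++i) mu += x[i];
  mu /= n;
  float v = 0.f;
  for (int i = 0; i < n; ++i) v += (x[i] - mu) * (x[i] - mu);
  v /= n;
  const float rsigma = 1.f / std::sqrt(v + epsilon);
  y->resize(n);
  for (int i = 0; i < n; ++i) {
    (*y)[i] = gamma[i] * (x[i] - mu) * rsigma + beta[i];
  }
  *mean = mu;
  *var = v;
}
```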
@@ -416,10 +450,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   const int r = bidx * ROWS_PER_CTA + warp_m;
   const int c = warp_n * THREADS_PER_WARP + lane;

-  static_assert(LN_NUM_COLS == THREADS_PER_ROW * LDGS * VecSize, "");
+  static_assert(ELTS_PER_ROW == THREADS_PER_ROW * LDGS * VecSize, "");

   // smem for column reduction
-  __shared__ U smem_[ROWS_PER_CTA * LN_NUM_COLS];
+  __shared__ U smem_[ROWS_PER_CTA * ELTS_PER_ROW];

   U dgamma_sum[LDGS * VecSize];
   U dbeta_sum[LDGS * VecSize];
@@ -434,7 +468,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   U *sum_loss2_shared = &smem_sum_loss2[warp_m * WARPS_N];

   // step-1: compute dx and local results of dscale and dbias
-  constexpr float rn = 1.f / static_cast<float>(LN_NUM_COLS);
+  constexpr float rn = 1.f / static_cast<float>(ELTS_PER_ROW);
   Vec_scale gamma[LDGS];
   int col = c;
 #pragma unroll
@@ -452,12 +486,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
     int col = c;
 #pragma unroll
     for (int it = 0; it < LDGS; it++) {
-      phi::Load<T, VecSize>(dout_ptr + row * LN_NUM_COLS + col * VecSize,
+      phi::Load<T, VecSize>(dout_ptr + row * ELTS_PER_ROW + col * VecSize,
                             &dout[it]);
-      phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
+      phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
       if (isFusedDropoutResidualLn) {
         phi::Load<MaskType, VecSize>(
-            mask_ptr + row * LN_NUM_COLS + col * VecSize, &mask_vec[it]);
+            mask_ptr + row * ELTS_PER_ROW + col * VecSize, &mask_vec[it]);
       }

       col += THREADS_PER_ROW;
@@ -551,10 +585,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
     col = c;
 #pragma unroll
     for (int it = 0; it < LDGS; it++) {
-      phi::Store<T, VecSize>(x[it], dx_ptr + row * LN_NUM_COLS + col * VecSize);
+      phi::Store<T, VecSize>(x[it],
+                             dx_ptr + row * ELTS_PER_ROW + col * VecSize);
       if (isFusedDropoutResidualLn) {
         phi::Store<T, VecSize>(
-            dout[it], d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize);
+            dout[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
       }
       col += THREADS_PER_ROW;
     }
@@ -562,12 +597,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(

   // step-2: column reduction of dscale and dbias for each thread block.
   // each block's sum: [4 * 1024] -> [1 * 1024]
-  enum { NUM_RES = LN_NUM_COLS / THREADS_PER_CTA };  // 1024/128 = 8
-  static_assert(NUM_RES * THREADS_PER_CTA == LN_NUM_COLS, "");
+  enum { NUM_RES = ELTS_PER_ROW / THREADS_PER_CTA };  // 1024/128 = 8
+  static_assert(NUM_RES * THREADS_PER_CTA == ELTS_PER_ROW, "");

   U *smem_write;

-  smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];  // [4 * 1024]
+  smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];  // [4 * 1024]
 #pragma unroll
   for (int it = 0; it < LDGS; it++) {
 #pragma unroll
@@ -583,12 +618,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   for (int it = 0; it < ROWS_PER_CTA; it++) {
     for (int jt = 0; jt < NUM_RES; jt++) {
       cta_dbeta_sum[jt] +=
-          smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
+          smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
     }
   }
   __syncthreads();

-  smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];
+  smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];
 #pragma unroll
   for (int it = 0; it < LDGS; it++) {
 #pragma unroll
@@ -603,19 +638,19 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   for (int it = 0; it < ROWS_PER_CTA; it++) {
     for (int jt = 0; jt < NUM_RES; jt++) {
       cta_dgamma_sum[jt] +=
-          smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
+          smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
     }
   }

   // the shape of results:(#blocks, 1024)
   U *dgamma_part =
-      static_cast<U *>(dgamma_temp_ptr) + bidx * LN_NUM_COLS + tidx;
+      static_cast<U *>(dgamma_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
   for (int jt = 0; jt < NUM_RES; jt++) {
     *dgamma_part = cta_dgamma_sum[jt];
     dgamma_part += THREADS_PER_CTA;
   }

-  U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * LN_NUM_COLS + tidx;
+  U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
   for (int jt = 0; jt < NUM_RES; jt++) {
     *dbeta_part = cta_dbeta_sum[jt];
     dbeta_part += THREADS_PER_CTA;
@@ -640,7 +675,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
     const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
     ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
   using Vec = phi::AlignedVector<U, VecSize>;
-  static_assert(VEC_COLS == LN_NUM_COLS / VecSize, "");
+  static_assert(VEC_COLS == ELTS_PER_ROW / VecSize, "");

   const int tidx = threadIdx.x;
   const int bidx = blockIdx.x;
@@ -656,8 +691,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
   __shared__ U smem_space[(WARPS_M - 1) * THREADS_PER_ROW * VecSize];

   for (int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW) {
-    const U *dg_part_ptr = (dg_part_) + r * LN_NUM_COLS + col * VecSize;
-    const U *db_part_ptr = (db_part_) + r * LN_NUM_COLS + col * VecSize;
+    const U *dg_part_ptr = (dg_part_) + r * ELTS_PER_ROW + col * VecSize;
+    const U *db_part_ptr = (db_part_) + r * ELTS_PER_ROW + col * VecSize;

     U dg_sum[VecSize];
     U db_sum[VecSize];
@@ -669,8 +704,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
       Vec db;
       phi::Load<U, VecSize>(dg_part_ptr, &dg);
       phi::Load<U, VecSize>(db_part_ptr, &db);
-      dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
-      db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
+      dg_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
+      db_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
 #pragma unroll
       for (int jt = 0; jt < VecSize; jt++) {
...
@@ -84,7 +84,7 @@ void LayerNormKernel(const Context &dev_ctx,
       PADDLE_ENFORCE_EQ(
           scale->dtype(),
           bias->dtype(),
-          phi::errors::InvalidArgument("Thie Scale and Bias of layer_norm op "
+          phi::errors::InvalidArgument("This Scale and Bias of layer_norm op "
                                        "should have the same data type."));
     }
   } else {
@@ -131,59 +131,75 @@ void LayerNormKernel(const Context &dev_ctx,
     }                                                                     \
   } while (0)

+#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, feature_size)          \
+  case (feature_size): {                                                     \
+    constexpr int WARPS_N = feature_size < 1024 ? 1 : (feature_size / 1024); \
+    constexpr int WARPS_M = 4 / WARPS_N;                                     \
+    const int THREADS_PER_WARP = 32;                                         \
+    const int BYTES_PER_LDG = 16;                                            \
+    const int VecSize = BYTES_PER_LDG / sizeof(T);                           \
+    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;        \
+    const int ROWS_PER_CTA = WARPS_M;                                        \
+    const int grid = static_cast<int>(                                       \
+        std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));           \
+    paddle::operators::fast_ln_fwd_kernel<                                   \
+        T,                                                                   \
+        U,                                                                   \
+        ScaleT,                                                              \
+        VecSize,                                                             \
+        WARPS_M,                                                             \
+        WARPS_N,                                                             \
+        BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(                \
+        batch_size,                                                          \
+        feature_size,                                                        \
+        epsilon,                                                             \
+        x_data,                                                              \
+        static_cast<const ScaleT *>(void_scale_data),                        \
+        static_cast<const ScaleT *>(void_bias_data),                         \
+        mean_data,                                                           \
+        var_data,                                                            \
+        y_data);                                                             \
+  } break

+#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD(ScaleT)       \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 768);  \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1024); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1280); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1536); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1792); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 2048); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 4096)

 #ifdef PADDLE_WITH_CUDA
-  bool can_call_1024_kernel = false;
-  if (feature_size == 1024 && scale != nullptr && bias != nullptr) {
-    can_call_1024_kernel = true;
+  bool can_call_fast_kernel = false;
+  if ((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0 ||
+       feature_size == 4096) &&
+      scale != nullptr && bias != nullptr) {
+    // can_call_fast_kernel = true;
+    can_call_fast_kernel = false;
   }

-  if (can_call_1024_kernel) {
-    const int WARPS_M = 4;
-    const int WARPS_N = 1;
-    const int THREADS_PER_WARP = 32;
-    const int BYTES_PER_LDG = 16;
-    const int VecSize = BYTES_PER_LDG / sizeof(T);
-    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
-    const int ROWS_PER_CTA = WARPS_M;
-    const int grid = static_cast<int>(
-        std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));
+  if (can_call_fast_kernel) {
     if (is_scale_bias_same_dtype_with_x) {
-      paddle::operators::ln_fwd_1024_kernel<
-          T,
-          U,
-          T,
-          VecSize,
-          WARPS_M,
-          WARPS_N,
-          BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
-          batch_size,
-          feature_size,
-          epsilon,
-          x_data,
-          static_cast<const T *>(void_scale_data),
-          static_cast<const T *>(void_bias_data),
-          mean_data,
-          var_data,
-          y_data);
+      switch (feature_size) {
+        PADDLE_LAUNCH_FAST_LAYERNORM_FWD(T);
+        default:
+          PADDLE_THROW(phi::errors::InvalidArgument(
+              "Only when feature_size is from 256 to 4096 and is divisible by "
+              "256 is supported "
+              "now"));
+          break;
+      }
     } else {
-      paddle::operators::ln_fwd_1024_kernel<
-          T,
-          U,
-          U,
-          VecSize,
-          WARPS_M,
-          WARPS_N,
-          BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
-          batch_size,
-          feature_size,
-          epsilon,
-          x_data,
-          static_cast<const U *>(void_scale_data),
-          static_cast<const U *>(void_bias_data),
-          mean_data,
-          var_data,
-          y_data);
+      switch (feature_size) {
+        PADDLE_LAUNCH_FAST_LAYERNORM_FWD(U);
+        default:
+          PADDLE_THROW(phi::errors::InvalidArgument(
+              "Only when feature_size is from 256 to 4096 and is divisible by "
+              "256 is supported "
+              "now"));
+          break;
+      }
     }
   } else {
 #endif
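The launch macro above derives the whole CTA shape from feature_size: WARPS_N = feature_size < 1024 ? 1 : feature_size / 1024, WARPS_M = 4 / WARPS_N, 32-thread warps and 16-byte vector loads, so every supported size ends up with 128 threads per block. A minimal sketch of that computation with a worked example; the struct and function names are illustrative, not from the commit:

```cuda
#include <cmath>

// Launch configuration chosen by PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE
// (illustrative restatement; names are hypothetical).
struct FastLnLaunchConfig {
  int warps_n, warps_m, threads_per_cta, rows_per_cta, grid;
};

inline FastLnLaunchConfig MakeFastLnLaunchConfig(int feature_size,
                                                 int batch_size) {
  FastLnLaunchConfig cfg;
  cfg.warps_n = feature_size < 1024 ? 1 : feature_size / 1024;
  cfg.warps_m = 4 / cfg.warps_n;
  cfg.threads_per_cta = cfg.warps_n * 32 * cfg.warps_m;  // always 128 here
  cfg.rows_per_cta = cfg.warps_m;
  cfg.grid = static_cast<int>(
      std::ceil(batch_size / static_cast<float>(cfg.rows_per_cta)));
  return cfg;
}

// Example: feature_size = 2048 and batch_size = 4096 give
// {warps_n = 2, warps_m = 2, threads_per_cta = 128, rows_per_cta = 2,
//  grid = 2048}: each block handles two rows, with two warps per row.
```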
@@ -197,6 +213,7 @@ void LayerNormKernel(const Context &dev_ctx,
 #endif

 #undef PADDLE_LAUNCH_LAYERNORM_FWD
+#undef PADDLE_LAUNCH_FAST_LAYERNORM_FWD
 }

 }  // namespace phi
...