Unverified commit 990c5e7f, authored by Zhang Zheng, committed by GitHub

Support head_dim = 96 in fused_multi_transformer for PLATO-XL (#43120)

* Support head_dim = 96 in fused_multi_transformer in PLATO-XL

* add notes
Parent 041000c2
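For context on the numbers involved: head_dim = 96 is the first supported size that is not a power of two, so the kernel is now instantiated with the real size Dh = 96 plus a padded Dh_MAX = 128, and the launch macro rounds the number of value-reduction groups up to the next power of two (pad_active_groups). The standalone sketch below is not part of the commit; it only mirrors the arithmetic in MMHA_LAUNCH_KERNEL and fmha_launch_kernel, assuming fp16 and the THREADS_PER_BLOCK = 128 branch, to show why only dim_head = 96 needs the padding.

#include <cmath>
#include <cstdio>

int main() {
  const int sizeof_half = 2;          // sizeof(T) for fp16
  const int threads_per_block = 128;  // the params.timestep < 2048 branch
  const int head_dims[] = {32, 64, 96, 128};
  for (int dh : head_dims) {
    // THREADS_PER_VALUE = Dh * sizeof(T) / 16: one 16-byte V_vec per thread.
    int threads_per_value = dh * sizeof_half / 16;
    int v_per_iter = threads_per_block / threads_per_value;
    // pad_active_groups rounds the group count up to a power of two so the
    // binary-tree reduction over out_smem stays balanced.
    int pad_active_groups =
        1 << static_cast<int>(std::ceil(std::log2(v_per_iter)));
    std::printf("Dh=%3d THREADS_PER_VALUE=%2d V_PER_ITER=%2d pad_active_groups=%2d\n",
                dh, threads_per_value, v_per_iter, pad_active_groups);
  }
  return 0;
}

For Dh = 96 this prints V_PER_ITER = 10 but pad_active_groups = 16; for 32, 64 and 128 the two values match, which is why the padding and the out_smem zero-fill in the diff only come into play for the new head size.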
@@ -529,10 +529,10 @@ inline __device__ void zero(T &dst) {  // NOLINT
   dst = tmp.raw;
 }
 
-template <typename T, int Dh, int THREADS_PER_KEY, int THREADS_PER_VALUE,
-          int THREADS_PER_BLOCK>
+template <typename T, int Dh, int Dh_MAX, int THREADS_PER_KEY,
+          int THREADS_PER_VALUE, int THREADS_PER_BLOCK>
 __global__ void masked_multihead_attention_kernel(
-    Masked_multihead_attention_params<T> params) {
+    Masked_multihead_attention_params<T> params, int pad_active_groups) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
 
   static_assert(Dh % THREADS_PER_KEY == 0, "");
@@ -560,11 +560,12 @@ __global__ void masked_multihead_attention_kernel(
   const int tid = threadIdx.x;
 
   float qk_max = -FLT_MAX;
+  float qk = 0;
 
   // qkv [B, S=1, 3, num_head, head_dim]
   int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
 
-  using Qk_vec = typename Qk_vec_<T, Dh>::Type;
+  using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
   constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
   static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, "");
   constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;
@@ -605,18 +606,18 @@ __global__ void masked_multihead_attention_kernel(
                  params.timestep * QK_ELTS_IN_16B + ci;
     *reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
 
-    float qk = dot<Qk_vec, Qk_vec>(q, k);
-#pragma unroll
-    for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
-      qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
+    qk = dot<Qk_vec, Qk_vec>(q, k);
+  }
+  if (tid < WARP_SIZE) {
+    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+      qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
     }
-    qk *= params.inv_sqrt_dh;
 
     if (tid == 0) {
       // NOTE(wangxi): mask must be 0.0
       // T mask = params.attn_mask[
       //    bi * (params.timestep + 1) + params.timestep];
       // qk += static_cast<float>(mask);
+      qk *= params.inv_sqrt_dh;
       qk_max = qk;
       qk_smem[params.timestep] = qk;
     }
@@ -746,16 +747,18 @@ __global__ void masked_multihead_attention_kernel(
   zero(out);
 
   constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
-  for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
-    V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
+  if (vo < V_PER_ITER) {
+    for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
+      V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
       float logit = logits_smem[ti];
       out = fma(logit, cast_to_float(v), out);
 #else
       T logit = logits_smem[ti];
       // Update the partial sums.
       out = fma(logit, v, out);
 #endif
+    }
   }
 
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
@@ -784,8 +787,12 @@ __global__ void masked_multihead_attention_kernel(
   __syncthreads();
 
+  if (vo < pad_active_groups / 2) {
+    zero(*reinterpret_cast<V_vec *>(&out_smem[vo * Dh + vi]));
+  }
 #pragma unroll
-  for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
+  for (int active_groups = pad_active_groups; active_groups >= 2;
+       active_groups /= 2) {
     int midpoint = active_groups / 2;
 
     if (vo >= midpoint && vo < active_groups) {
@@ -830,7 +837,7 @@ __global__ void masked_multihead_attention_kernel(
 
 template <typename T>
 inline size_t smem_size_in_bytes(
     const Masked_multihead_attention_params<T> &params, int dim_head,
-    int threads_per_value, int threads_per_block) {
+    int threads_per_value, int threads_per_block, int pad_active_groups) {
   size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
   size_t logits_sz = 0;
@@ -841,31 +848,34 @@ inline size_t smem_size_in_bytes(
 #endif
   size_t softmax_sz = qk_sz + logits_sz;
 
-  int rows_per_red = threads_per_block / threads_per_value;
+  int rows_per_red = pad_active_groups;
   size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;
 
   return max(softmax_sz, red_sz);
 }
 
-#define MMHA_LAUNCH_KERNEL(T, Dh, THDS_PER_KEY, THDS_PER_VALUE,              \
-                           THDS_PER_BLOCK, stream)                           \
-  size_t smem_sz =                                                           \
-      smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK);     \
-  dim3 grid(params.num_head, params.batch_size);                             \
-  masked_multihead_attention_kernel<                                         \
-      T, Dh, THDS_PER_KEY, THDS_PER_VALUE,                                   \
-      THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,        \
+                           THDS_PER_BLOCK, stream)                             \
+  int pad_active_groups =                                                      \
+      1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
+  size_t smem_sz = smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE,           \
+                                         THDS_PER_BLOCK, pad_active_groups);   \
+  dim3 grid(params.num_head, params.batch_size);                               \
+  masked_multihead_attention_kernel<                                           \
+      T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,                             \
+      THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(              \
+      params, pad_active_groups)
 
-template <typename T, int Dh>
+template <typename T, int Dh, int Dh_MAX>
 void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
                         const cudaStream_t &stream) {
   constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
   if (params.timestep < 32) {
-    MMHA_LAUNCH_KERNEL(T, Dh, 4, THREADS_PER_VALUE, 64, stream);
+    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
  } else if (params.timestep < 2048) {
-    MMHA_LAUNCH_KERNEL(T, Dh, 2, THREADS_PER_VALUE, 128, stream);
+    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
  } else {
-    MMHA_LAUNCH_KERNEL(T, Dh, 1, THREADS_PER_VALUE, 256, stream);
+    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
  }
 }
@@ -890,18 +900,21 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
 
   switch (dim_head) {
     case 32:
-      fmha_launch_kernel<T, 32>(params, dev_ctx.stream());
+      fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream());
       break;
    case 64:
-      fmha_launch_kernel<T, 64>(params, dev_ctx.stream());
+      fmha_launch_kernel<T, 64, 64>(params, dev_ctx.stream());
+      break;
+    case 96:
+      fmha_launch_kernel<T, 96, 128>(params, dev_ctx.stream());
       break;
    case 128:
-      fmha_launch_kernel<T, 128>(params, dev_ctx.stream());
+      fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream());
       break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "dim_head = %d is unsupport, only support "
-          "dim_head = 32, 64 or 128 for now.",
+          "dim_head = 32, 64, 96 or 128 for now.",
          dim_head));
   }
 }
......
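A similar sketch (again not part of the commit) for the Q*K^T side, assuming fp16 and fp32 with the 16-byte Qk_vec selected for Dh_MAX = 128: the per-warp vector count QK_VECS_PER_WARP = Dh / QK_VEC_SIZE stops being a power of two when Dh = 96, which is why the hunk above initializes qk = 0 on every thread and replaces the shfl_mask(QK_VECS_PER_WARP) butterfly with a full-warp __shfl_xor_sync over mask uint32_t(-1).

#include <cstdio>

int main() {
  const int dh = 96;                // real head size (Dh)
  const int qk_vec_bytes = 16;      // assumed size of Qk_vec_<T, 128>
  const int elem_bytes[] = {2, 4};  // fp16, fp32
  for (int eb : elem_bytes) {
    int qk_vec_size = qk_vec_bytes / eb;        // elements per Qk_vec
    int qk_vecs_per_warp = dh / qk_vec_size;    // lanes holding a partial dot
    bool pow2 = (qk_vecs_per_warp & (qk_vecs_per_warp - 1)) == 0;
    std::printf("elem_bytes=%d QK_VEC_SIZE=%d QK_VECS_PER_WARP=%d power_of_two=%s\n",
                eb, qk_vec_size, qk_vecs_per_warp, pow2 ? "yes" : "no");
  }
  return 0;
}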