Unverified commit 990c5e7f, authored by Zhang Zheng, committed by GitHub

Support head_dim = 96 in fused_multi_transformer for PLATO-XL (#43120)

* Support head_dim = 96 in fused_multi_transformer for PLATO-XL

* add notes
Parent 041000c2
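PLATO-XL needs head_dim = 96, which is not a power of two, so the kernel's 16-byte vector types and shared-memory layout can no longer be derived from Dh alone; the patch therefore threads a padded compile-time dimension Dh_MAX through the kernel (128 when Dh = 96). A minimal host-side sketch of that padding rule, where next_pow2 is an illustrative helper and not code from this patch:

#include <cstdio>

// Illustrative helper (not from the Paddle source): round a head dimension
// up to the next power of two, which is what the new Dh_MAX template
// argument encodes (32 -> 32, 64 -> 64, 96 -> 128, 128 -> 128).
static int next_pow2(int x) {
  int p = 1;
  while (p < x) p <<= 1;
  return p;
}

int main() {
  const int head_dims[] = {32, 64, 96, 128};
  for (int dh : head_dims) {
    std::printf("dim_head = %3d -> Dh_MAX = %3d\n", dh, next_pow2(dh));
  }
  return 0;
}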
@@ -529,10 +529,10 @@ inline __device__ void zero(T &dst) { // NOLINT
dst = tmp.raw;
}
template <typename T, int Dh, int THREADS_PER_KEY, int THREADS_PER_VALUE,
int THREADS_PER_BLOCK>
template <typename T, int Dh, int Dh_MAX, int THREADS_PER_KEY,
int THREADS_PER_VALUE, int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params) {
Masked_multihead_attention_params<T> params, int pad_active_groups) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert(Dh % THREADS_PER_KEY == 0, "");
@@ -560,11 +560,12 @@ __global__ void masked_multihead_attention_kernel(
const int tid = threadIdx.x;
float qk_max = -FLT_MAX;
float qk = 0;
// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
using Qk_vec = typename Qk_vec_<T, Dh>::Type;
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;
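Taking Qk_vec from Dh_MAX rather than Dh keeps the load width a power-of-two, 16-byte-friendly type even when Dh = 96. Assuming the fp16 specialization for Dh_MAX = 128 is a 16-byte vector (8 half values), as in the FasterTransformer-style traits this kernel is based on, the static_assert above still holds for Dh = 96; a quick host-side check of the constants:

#include <cstdio>

int main() {
  // Assumed values for T = half, Dh = 96, Dh_MAX = 128: Qk_vec is taken to be
  // a 16-byte vector, i.e. QK_VEC_SIZE = 8 half elements per load (assumption).
  const int Dh = 96;
  const int sizeof_T = 2;                            // half
  const int sizeof_Qk_vec = 16;                      // assumed, see lead-in
  const int QK_VEC_SIZE = sizeof_Qk_vec / sizeof_T;  // 8
  const int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;     // 12, not a power of two

  std::printf("Dh %% QK_VEC_SIZE = %d (static_assert needs 0)\n", Dh % QK_VEC_SIZE);
  std::printf("QK_VECS_PER_WARP = %d (static_assert needs <= 32)\n", QK_VECS_PER_WARP);
  return 0;
}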
@@ -605,18 +606,18 @@ __global__ void masked_multihead_attention_kernel(
params.timestep * QK_ELTS_IN_16B + ci;
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
float qk = dot<Qk_vec, Qk_vec>(q, k);
#pragma unroll
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
qk = dot<Qk_vec, Qk_vec>(q, k);
}
if (tid < WARP_SIZE) {
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
qk *= params.inv_sqrt_dh;
if (tid == 0) {
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
}
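Because QK_VECS_PER_WARP is 12 rather than a power of two in the Dh = 96, fp16 case, the old butterfly reduction over exactly QK_VECS_PER_WARP lanes no longer fits; the rewritten code computes the per-lane q*k partial sum under the guard and then lets the first warp reduce across all 32 lanes with the all-lanes mask uint32_t(-1), with inactive lanes contributing zero. A self-contained CUDA sketch of that full-warp pattern (an illustration, not the Paddle kernel itself):

#include <cstdio>
#include <cuda_runtime.h>

// Each of the first `active` lanes contributes a partial sum; the remaining
// lanes contribute 0. Reducing over all 32 lanes with mask 0xffffffff works
// whether or not `active` is a power of two.
__global__ void full_warp_sum(const float *in, float *out, int active) {
  const int tid = threadIdx.x;
  float v = (tid < active) ? in[tid] : 0.f;
  for (int mask = 32 / 2; mask >= 1; mask /= 2) {
    v += __shfl_xor_sync(0xffffffffu, v, mask);
  }
  if (tid == 0) *out = v;
}

int main() {
  float h_in[32], h_out = 0.f, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;
  cudaMalloc(reinterpret_cast<void **>(&d_in), sizeof(h_in));
  cudaMalloc(reinterpret_cast<void **>(&d_out), sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  // 12 = QK_VECS_PER_WARP for Dh = 96 with fp16 (see the check above).
  full_warp_sum<<<1, 32>>>(d_in, d_out, /*active=*/12);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum over 12 active lanes = %.1f\n", h_out);  // expected 12.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}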
@@ -746,16 +747,18 @@ __global__ void masked_multihead_attention_kernel(
zero(out);
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
if (vo < V_PER_ITER) {
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
#else
T logit = logits_smem[ti];
// Update the partial sums.
out = fma(logit, v, out);
T logit = logits_smem[ti];
// Update the partial sums.
out = fma(logit, v, out);
#endif
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
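The new if (vo < V_PER_ITER) guard around the value accumulation matters precisely for Dh = 96: in fp16, THREADS_PER_VALUE = 96 * 2 / 16 = 12, so a 64-thread block only holds V_PER_ITER = 64 / 12 = 5 complete value groups, and the last four threads (vo = 5) own no full row of V and must skip the loop. A small host-side check of that arithmetic (a sketch, not the Paddle source):

#include <cstdio>

int main() {
  const int Dh = 96, sizeof_T = 2;                  // fp16
  const int THREADS_PER_VALUE = Dh * sizeof_T / 16; // 12 threads cooperate on one row of V
  const int THREADS_PER_BLOCK = 64;                 // timestep < 32 branch
  const int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;  // 5 complete groups

  for (int tid = 0; tid < THREADS_PER_BLOCK; ++tid) {
    int vo = tid / THREADS_PER_VALUE;               // value-group index of this thread
    if (vo >= V_PER_ITER) {
      std::printf("tid %d (vo = %d) is skipped by the new guard\n", tid, vo);
    }
  }
  return 0;
}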
@@ -784,8 +787,12 @@ __global__ void masked_multihead_attention_kernel(
__syncthreads();
if (vo < pad_active_groups / 2) {
zero(*reinterpret_cast<V_vec *>(&out_smem[vo * Dh + vi]));
}
#pragma unroll
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
for (int active_groups = pad_active_groups; active_groups >= 2;
active_groups /= 2) {
int midpoint = active_groups / 2;
if (vo >= midpoint && vo < active_groups) {
@@ -830,7 +837,7 @@ __global__ void masked_multihead_attention_kernel(
template <typename T>
inline size_t smem_size_in_bytes(
const Masked_multihead_attention_params<T> &params, int dim_head,
int threads_per_value, int threads_per_block) {
int threads_per_value, int threads_per_block, int pad_active_groups) {
size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
size_t logits_sz = 0;
@@ -841,31 +848,34 @@ inline size_t smem_size_in_bytes(
#endif
size_t softmax_sz = qk_sz + logits_sz;
int rows_per_red = threads_per_block / threads_per_value;
int rows_per_red = pad_active_groups;
size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;
return max(softmax_sz, red_sz);
}
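With the padded group count, the reduction scratch is now sized as pad_active_groups rows of dim_head elements instead of threads_per_block / threads_per_value rows. Plugging in assumed numbers for the Dh = 96, fp16, 64-thread, timestep = 31 case gives the following sketch of the same arithmetic:

#include <algorithm>
#include <cstdio>

// Reproduces the smem_size_in_bytes arithmetic for one concrete case
// (assumed values: fp16, dim_head = 96, timestep = 31, 64 threads per block,
// pad_active_groups = 8, and the code path where logits_sz stays 0).
int main() {
  const int timestep = 31;                           // assumed decode step
  const size_t qk_sz = (timestep + 1 + 3) / 4 * 16;  // div_up(timestep + 1, 4) * 16 = 128
  const size_t logits_sz = 0;
  const size_t softmax_sz = qk_sz + logits_sz;

  const int pad_active_groups = 8;                   // padded up from V_PER_ITER = 5
  const int dim_head = 96;
  const size_t sizeof_T = 2;                         // fp16
  const size_t red_sz = pad_active_groups * dim_head * sizeof_T / 2;  // 768 bytes

  std::printf("softmax_sz = %zu bytes, red_sz = %zu bytes, smem = %zu bytes\n",
              softmax_sz, red_sz, std::max(softmax_sz, red_sz));
  return 0;
}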
#define MMHA_LAUNCH_KERNEL(T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, stream) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel< \
T, Dh, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
template <typename T, int Dh>
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, stream) \
int pad_active_groups = \
1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
size_t smem_sz = smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, \
THDS_PER_BLOCK, pad_active_groups); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel< \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>( \
params, pad_active_groups)
template <typename T, int Dh, int Dh_MAX>
void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
const cudaStream_t &stream) {
constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
if (params.timestep < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, 4, THREADS_PER_VALUE, 64, stream);
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
} else if (params.timestep < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, 2, THREADS_PER_VALUE, 128, stream);
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
} else {
MMHA_LAUNCH_KERNEL(T, Dh, 1, THREADS_PER_VALUE, 256, stream);
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
}
}
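The launch macro now rounds the value-group count up to a power of two before sizing shared memory and launching, because the in-kernel tree reduction halves active_groups every step, and the first pad_active_groups / 2 rows of out_smem are zeroed beforehand so that rows no real group fully writes contribute zero. For the Dh = 96, fp16, 64-thread branch the numbers come out as below (host-side sketch of the same expression the macro uses):

#include <cmath>
#include <cstdio>

int main() {
  const int Dh = 96, sizeof_T = 2;                   // fp16
  const int THREADS_PER_VALUE = Dh * sizeof_T / 16;  // 12
  const int THREADS_PER_BLOCK = 64;                  // timestep < 32 branch

  // Same expression as the macro: next power of two of the integer group count.
  const int pad_active_groups = 1 << static_cast<int>(
      std::ceil(std::log2(THREADS_PER_BLOCK / THREADS_PER_VALUE)));

  std::printf("V_PER_ITER = %d, pad_active_groups = %d\n",
              THREADS_PER_BLOCK / THREADS_PER_VALUE, pad_active_groups);  // 5 -> 8
  return 0;
}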
@@ -890,18 +900,21 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
switch (dim_head) {
case 32:
fmha_launch_kernel<T, 32>(params, dev_ctx.stream());
fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream());
break;
case 64:
fmha_launch_kernel<T, 64>(params, dev_ctx.stream());
fmha_launch_kernel<T, 64, 64>(params, dev_ctx.stream());
break;
case 96:
fmha_launch_kernel<T, 96, 128>(params, dev_ctx.stream());
break;
case 128:
fmha_launch_kernel<T, 128>(params, dev_ctx.stream());
fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream());
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"dim_head = %d is unsupport, only support "
"dim_head = 32, 64 or 128 for now.",
"dim_head = 32, 64, 96 or 128 for now.",
dim_head));
}
}