From 990c5e7f15da02d0484359e811305f6e8e0dd682 Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Thu, 2 Jun 2022 14:36:00 +0800
Subject: [PATCH] Support head_dim = 96 in fused_multi_transformer for PLATO-XL (#43120)

* Support head_dim = 96 in fused_multi_transformer in PLATO-XL

* add notes
---
 .../fused/fused_multi_transformer_op.cu       | 87 +++++++++++--------
 1 file changed, 50 insertions(+), 37 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
index fe93d323c5..c13c287f4a 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -529,10 +529,10 @@ inline __device__ void zero(T &dst) {  // NOLINT
   dst = tmp.raw;
 }
 
-template <typename T, int Dh, int THREADS_PER_KEY, int THREADS_PER_VALUE,
-          int THREADS_PER_BLOCK>
+template <typename T, int Dh, int Dh_MAX, int THREADS_PER_KEY,
+          int THREADS_PER_VALUE, int THREADS_PER_BLOCK>
 __global__ void masked_multihead_attention_kernel(
-    Masked_multihead_attention_params<T> params) {
+    Masked_multihead_attention_params<T> params, int pad_active_groups) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
 
   static_assert(Dh % THREADS_PER_KEY == 0, "");
@@ -560,11 +560,12 @@ __global__ void masked_multihead_attention_kernel(
   const int tid = threadIdx.x;
 
   float qk_max = -FLT_MAX;
+  float qk = 0;
 
   // qkv [B, S=1, 3, num_head, head_dim]
   int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
 
-  using Qk_vec = typename Qk_vec_<T, Dh>::Type;
+  using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
   constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
   static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, "");
   constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;
@@ -605,18 +606,18 @@ __global__ void masked_multihead_attention_kernel(
                  params.timestep * QK_ELTS_IN_16B + ci;
     *reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
 
-    float qk = dot<Qk_vec, Qk_vec>(q, k);
-#pragma unroll
-    for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
-      qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
+    qk = dot<Qk_vec, Qk_vec>(q, k);
+  }
+  if (tid < WARP_SIZE) {
+    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+      qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
     }
-
-    qk *= params.inv_sqrt_dh;
     if (tid == 0) {
       // NOTE(wangxi): mask must be 0.0
       // T mask = params.attn_mask[
       //     bi * (params.timestep + 1) + params.timestep];
       // qk += static_cast<float>(mask);
+      qk *= params.inv_sqrt_dh;
       qk_max = qk;
       qk_smem[params.timestep] = qk;
     }
@@ -746,16 +747,18 @@ __global__ void masked_multihead_attention_kernel(
   zero(out);
 
   constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
-  for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
-    V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
+  if (vo < V_PER_ITER) {
+    for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
+      V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
-    float logit = logits_smem[ti];
-    out = fma(logit, cast_to_float(v), out);
+      float logit = logits_smem[ti];
+      out = fma(logit, cast_to_float(v), out);
 #else
-    T logit = logits_smem[ti];
-    // Update the partial sums.
-    out = fma(logit, v, out);
+      T logit = logits_smem[ti];
+      // Update the partial sums.
+      out = fma(logit, v, out);
 #endif
+    }
   }
 
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
@@ -784,8 +787,12 @@ __global__ void masked_multihead_attention_kernel(
 
   __syncthreads();
 
+  if (vo < pad_active_groups / 2) {
+    zero(*reinterpret_cast<V_vec *>(&out_smem[vo * Dh + vi]));
+  }
 #pragma unroll
-  for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
+  for (int active_groups = pad_active_groups; active_groups >= 2;
+       active_groups /= 2) {
     int midpoint = active_groups / 2;
 
     if (vo >= midpoint && vo < active_groups) {
@@ -830,7 +837,7 @@ __global__ void masked_multihead_attention_kernel(
 template <typename T>
 inline size_t smem_size_in_bytes(
     const Masked_multihead_attention_params<T> &params, int dim_head,
-    int threads_per_value, int threads_per_block) {
+    int threads_per_value, int threads_per_block, int pad_active_groups) {
   size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
   size_t logits_sz = 0;
 
@@ -841,31 +848,34 @@ inline size_t smem_size_in_bytes(
 #endif
   size_t softmax_sz = qk_sz + logits_sz;
 
-  int rows_per_red = threads_per_block / threads_per_value;
+  int rows_per_red = pad_active_groups;
   size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;
 
   return max(softmax_sz, red_sz);
 }
 
-#define MMHA_LAUNCH_KERNEL(T, Dh, THDS_PER_KEY, THDS_PER_VALUE,          \
-                           THDS_PER_BLOCK, stream)                       \
-  size_t smem_sz =                                                       \
-      smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
-  dim3 grid(params.num_head, params.batch_size);                         \
-  masked_multihead_attention_kernel<                                     \
-      T, Dh, THDS_PER_KEY, THDS_PER_VALUE,                               \
-      THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
-
-template <typename T, int Dh>
+#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,        \
+                           THDS_PER_BLOCK, stream)                             \
+  int pad_active_groups =                                                      \
+      1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
+  size_t smem_sz = smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE,           \
+                                         THDS_PER_BLOCK, pad_active_groups);   \
+  dim3 grid(params.num_head, params.batch_size);                               \
+  masked_multihead_attention_kernel<                                           \
+      T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,                             \
+      THDS_PER_BLOCK><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(              \
+      params, pad_active_groups)
+
+template <typename T, int Dh, int Dh_MAX>
 void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
                         const cudaStream_t &stream) {
   constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
   if (params.timestep < 32) {
-    MMHA_LAUNCH_KERNEL(T, Dh, 4, THREADS_PER_VALUE, 64, stream);
+    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
   } else if (params.timestep < 2048) {
-    MMHA_LAUNCH_KERNEL(T, Dh, 2, THREADS_PER_VALUE, 128, stream);
+    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
   } else {
-    MMHA_LAUNCH_KERNEL(T, Dh, 1, THREADS_PER_VALUE, 256, stream);
+    MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
   }
 }
 
@@ -890,18 +900,21 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
 
   switch (dim_head) {
     case 32:
-      fmha_launch_kernel<T, 32>(params, dev_ctx.stream());
+      fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream());
       break;
     case 64:
-      fmha_launch_kernel<T, 64>(params, dev_ctx.stream());
+      fmha_launch_kernel<T, 64, 64>(params, dev_ctx.stream());
+      break;
+    case 96:
+      fmha_launch_kernel<T, 96, 128>(params, dev_ctx.stream());
       break;
     case 128:
-      fmha_launch_kernel<T, 128>(params, dev_ctx.stream());
+      fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream());
       break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
           "dim_head = %d is unsupport, only support "
-          "dim_head = 32, 64 or 128 for now.",
+          "dim_head = 32, 64, 96 or 128 for now.",
          dim_head));
   }
 }
--
GitLab
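Editor's note, not part of the patch: the key change above is that head_dim = 96 makes THREADS_PER_BLOCK / THREADS_PER_VALUE a non-power-of-two, so the binary-tree reduction over out_smem would skip groups unless the group count is rounded up (with the padding rows zero-filled, as the kernel now does via pad_active_groups). The standalone host-side C++ sketch below reproduces only that rounding arithmetic; it assumes fp16 elements (sizeof(T) == 2) and the three block sizes used by fmha_launch_kernel, and its names mirror the macro but are otherwise local to the example.

// Standalone illustration only -- not part of the patch above.
// Shows how pad_active_groups is derived for head_dim = 96 with fp16.
#include <cmath>
#include <cstdio>

int main() {
  const int head_dim = 96;
  const int elem_size = 2;  // sizeof(T) for fp16
  // Each "value" thread handles one 16-byte (128-bit) chunk of a head.
  const int threads_per_value = head_dim * elem_size / 16;  // = 12
  const int block_sizes[] = {64, 128, 256};
  for (int threads_per_block : block_sizes) {
    const int v_per_iter = threads_per_block / threads_per_value;
    // Same rounding as the new MMHA_LAUNCH_KERNEL macro: the next power of
    // two >= v_per_iter, so the halving loop `for (g = pad; g >= 2; g /= 2)`
    // visits every group; the extra rows are zeroed in shared memory.
    const int pad_active_groups =
        1 << static_cast<int>(std::ceil(std::log2(v_per_iter)));
    std::printf("block=%3d  V_PER_ITER=%2d  pad_active_groups=%2d\n",
                threads_per_block, v_per_iter, pad_active_groups);
  }
  return 0;
}

For head_dim = 96 with fp16 this prints group counts of 5, 10 and 21 padded to 8, 16 and 32. For head_dim 32, 64 and 128 the group count is already a power of two, so padding is a no-op there; the separate Dh_MAX = 128 argument for the 96 case appears to exist so the 96-wide head can reuse the vector types (Qk_vec_) defined for 128, since no 96-element specialization is provided.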