diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
index 01c5b79fff11569f7b43b7024ca95f8da1739b84..0042ed6bac7b4082066c4b3ecc070eda19ddf5b4 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -114,9 +114,11 @@ template <typename T, int Dh> struct Qk_vec_ {};
 template <> struct Qk_vec_<float, 32> { using Type = float; };
 template <> struct Qk_vec_<float, 64> { using Type = float2; };
 template <> struct Qk_vec_<float, 128> { using Type = float4; };
+template <> struct Qk_vec_<float, 256> { using Type = float4; };
 template <> struct Qk_vec_<float16, 32> { using Type = uint32_t; };
 template <> struct Qk_vec_<float16, 64> { using Type = uint32_t; };
 template <> struct Qk_vec_<float16, 128> { using Type = uint2; };
+template <> struct Qk_vec_<float16, 256> { using Type = uint4; };
 
 template <typename T, int THREADS_PER_KEY> struct K_vec_ {};
 template <> struct K_vec_<float, 4> { using Type = float; };
@@ -532,11 +534,11 @@ inline __device__ void zero(T &dst) {  // NOLINT
 template <typename T, int Dh, int Dh_MAX, int THREADS_PER_KEY,
           int THREADS_PER_VALUE, int THREADS_PER_BLOCK>
 __global__ void masked_multihead_attention_kernel(
-    Masked_multihead_attention_params<T> params, int pad_active_groups) {
+    Masked_multihead_attention_params<T> params) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
 
-  static_assert(Dh % THREADS_PER_KEY == 0, "");
-  static_assert(Dh % THREADS_PER_VALUE == 0, "");
+  static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
+  static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
 
   constexpr int WARP_SIZE = 32;
   constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
@@ -552,7 +554,8 @@ __global__ void masked_multihead_attention_kernel(
   T *out_smem = reinterpret_cast<T *>(smem_);
 
   __shared__ float red_smem[WARPS_PER_BLOCK * 2];
-  __shared__ T q_smem[Dh];
+  using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+  __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
 
   const int bi = blockIdx.y;
   const int hi = blockIdx.x;
@@ -565,10 +568,11 @@ __global__ void masked_multihead_attention_kernel(
   // qkv [B, S=1, 3, num_head, head_dim]
   int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
 
-  using Qk_vec = typename Qk_vec_<T, Dh>::Type;
   constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
-  static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, "");
-  constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;
+  static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
+  // Use block reduction if needed
+  // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
+  constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
 
   // cache_k, [B, num_head, head_dim / x, max_seq_len, x]
   // x == 4/8 for FP32/FP16, 128bit, 16Byte
@@ -584,13 +588,29 @@ __global__ void masked_multihead_attention_kernel(
     int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
     int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
 
-    Qk_vec q = *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset]);
-    Qk_vec k = *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset]);
-
-    Qk_vec q_bias =
-        *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset]);
-    Qk_vec k_bias =
-        *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset]);
+    Qk_vec q;
+    zero(q);
+    q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
+            ? *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset])
+            : q;
+    Qk_vec k;
+    zero(k);
+    k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
+            ? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
+            : k;
+
+    Qk_vec q_bias;
+    zero(q_bias);
+    q_bias =
+        (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
+            ? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
+            : q_bias;
+    Qk_vec k_bias;
+    zero(k_bias);
+    k_bias =
+        (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
+            ? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
+            : k_bias;
 
     q = add(q, q_bias);
     // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
@@ -604,24 +624,33 @@ __global__ void masked_multihead_attention_kernel(
     int offset = bhi * params.max_seq_length * Dh +
                  co * params.max_seq_length * QK_ELTS_IN_16B +
                  params.timestep * QK_ELTS_IN_16B + ci;
-    *reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
+    if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
+      *reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
+    }
 
     qk = dot<Qk_vec, Qk_vec>(q, k);
-  }
-  if (tid < WARP_SIZE) {
-    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-      qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
-    }
-    if (tid == 0) {
-      // NOTE(wangxi): mask must be 0.0
-      // T mask = params.attn_mask[
-      //    bi * (params.timestep + 1) + params.timestep];
-      // qk += static_cast<float>(mask);
-      qk *= params.inv_sqrt_dh;
-      qk_max = qk;
-      qk_smem[params.timestep] = qk;
+
+    if (QK_VECS_PER_WARP <= WARP_SIZE) {
+#pragma unroll
+      for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
+        qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
+      }
     }
   }
+  if (QK_VECS_PER_WARP > WARP_SIZE) {
+    constexpr int WARPS_PER_RED =
+        (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
+    qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
+  }
+  if (tid == 0) {
+    // NOTE(wangxi): mask must be 0.0
+    // T mask = params.attn_mask[
+    //    bi * (params.timestep + 1) + params.timestep];
+    // qk += static_cast<float>(mask);
+    qk *= params.inv_sqrt_dh;
+    qk_max = qk;
+    qk_smem[params.timestep] = qk;
+  }
   __syncthreads();
 
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
@@ -635,13 +664,15 @@ __global__ void masked_multihead_attention_kernel(
 
   using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
   constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
-  static_assert(Dh % K_VEC_SIZE == 0, "");
-  constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY;
+  static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
+  constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
   constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
 
   int ko = tid / THREADS_PER_KEY;
   int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE;
 
+  static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, "");
+
   K_vec q[K_VECS_PER_THREAD];
 #pragma unroll
   for (int i = 0; i < K_VECS_PER_THREAD; ++i) {
@@ -657,11 +688,17 @@ __global__ void masked_multihead_attention_kernel(
 
   for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
     K_vec k[K_VECS_PER_THREAD];
+    K_vec k_vec_zero;
+    zero(k_vec_zero);
 #pragma unroll
     for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
       int jj = ii * params.max_seq_length + ti;
       if (ti < params.timestep) {
-        k[ii] = *reinterpret_cast<const K_vec *>(&k_cache[jj * QK_ELTS_IN_16B]);
+        k[ii] =
+            (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
+                ? *reinterpret_cast<const K_vec *>(
+                      &k_cache[jj * QK_ELTS_IN_16B])
+                : k_vec_zero;
       }
     }
 
@@ -727,7 +764,7 @@ __global__ void masked_multihead_attention_kernel(
   }
   __syncthreads();
 
-  constexpr int V_VEC_SIZE = Dh / THREADS_PER_VALUE;
+  constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
   using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
 
   int vo = tid / THREADS_PER_VALUE;
@@ -747,7 +784,7 @@ __global__ void masked_multihead_attention_kernel(
   zero(out);
 
   constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
-  if (vo < V_PER_ITER) {
+  if (Dh == Dh_MAX || vi < Dh) {
     for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
       V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
@@ -770,10 +807,12 @@ __global__ void masked_multihead_attention_kernel(
   __syncthreads();
 #endif
 
-  if (vo == (params.timestep % V_PER_ITER)) {
+  V_vec v_bias;
+  zero(v_bias);
+  if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
     V_vec v = *reinterpret_cast<const V_vec *>(
         &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
-    V_vec v_bias = *reinterpret_cast<const V_vec *>(
+    v_bias = *reinterpret_cast<const V_vec *>(
         &params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
     v = add(v, v_bias);
     *reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
@@ -787,31 +826,31 @@ __global__ void masked_multihead_attention_kernel(
 
   __syncthreads();
 
-  if (vo < pad_active_groups / 2) {
-    zero(*reinterpret_cast<V_vec *>(&out_smem[vo * Dh + vi]));
-  }
+  if (Dh == Dh_MAX || vi < Dh) {
 #pragma unroll
-  for (int active_groups = pad_active_groups; active_groups >= 2;
-       active_groups /= 2) {
-    int midpoint = active_groups / 2;
+    for (int active_groups = V_PER_ITER; active_groups >= 2;
+         active_groups /= 2) {
+      int midpoint = active_groups / 2;
 
-    if (vo >= midpoint && vo < active_groups) {
+      if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
 #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
-      convert_from_float(
-          *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
-          out);
+        convert_from_float(
+            *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
+            out);
 #else
-      *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
+        *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
 #endif
+      }
+      __syncthreads();
+      if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
+        out =
+            add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
+      }
+      __syncthreads();
     }
-    __syncthreads();
-    if (vo < midpoint) {
-      out = add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
-    }
-    __syncthreads();
   }
 
-  if (vo == 0) {
+  if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
 #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
     convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]),
                        out);
@@ -837,7 +876,7 @@ __global__ void masked_multihead_attention_kernel(
 template <typename T>
 inline size_t smem_size_in_bytes(
     const Masked_multihead_attention_params<T> &params, int dim_head,
-    int threads_per_value, int threads_per_block, int pad_active_groups) {
+    int threads_per_value, int threads_per_block) {
   size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
 
   size_t logits_sz = 0;
@@ -848,27 +887,25 @@ inline size_t smem_size_in_bytes(
 #endif
   size_t softmax_sz = qk_sz + logits_sz;
 
-  int rows_per_red = pad_active_groups;
+  int rows_per_red = threads_per_block / threads_per_value;
   size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;
 
   return max(softmax_sz, red_sz);
 }
 
-#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,        \
-                           THDS_PER_BLOCK, stream)                             \
-  int pad_active_groups =                                                      \
-      1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
-  size_t smem_sz = smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE,           \
-                                         THDS_PER_BLOCK, pad_active_groups);   \
-  dim3 grid(params.num_head, params.batch_size);                               \
-  masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY,               \
-                                    THDS_PER_VALUE, THDS_PER_BLOCK>            \
-      <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params, pad_active_groups)
+#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,  \
+                           THDS_PER_BLOCK, stream)                       \
+  size_t smem_sz =                                                       \
+      smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
+  dim3 grid(params.num_head, params.batch_size);                         \
+  masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY,         \
+                                    THDS_PER_VALUE, THDS_PER_BLOCK>      \
+      <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
 
 template <typename T, int Dh, int Dh_MAX>
 void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
                         const cudaStream_t &stream) {
-  constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
+  constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
   if (params.timestep < 32) {
     MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
   } else if (params.timestep < 2048) {
@@ -898,6 +935,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
   params.inv_sqrt_dh = inv_sqrt_dh;
 
   switch (dim_head) {
+    case 10:
+      fmha_launch_kernel<T, 10, 32>(params, dev_ctx.stream());
+      break;
+    case 26:
+      fmha_launch_kernel<T, 26, 32>(params, dev_ctx.stream());
+      break;
     case 32:
       fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream());
      break;
@@ -910,11 +953,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
     case 128:
       fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream());
       break;
+    case 192:
+      fmha_launch_kernel<T, 192, 256>(params, dev_ctx.stream());
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
-          "dim_head = %d is unsupport, only support "
-          "dim_head = 32, 64, 96 or 128 for now.",
-          dim_head));
+          "Dim_head = %d is unsupported!", dim_head));
   }
 }
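The recurring (Dh == Dh_MAX || index < Dh) guards above are the core of this change: the kernel is always instantiated for a padded Dh_MAX (32, 64, 128 or 256), so the vector widths, THREADS_PER_VALUE and the reduction shapes stay legal, while loads and stores for the real head size Dh (now also 10, 26 and 192) are masked and the padded lanes contribute zeros. The standalone sketch below only illustrates that padding idea; the toy_qk_dot kernel name, the float-only element type, the scalar (QK_VEC_SIZE == 1) loads, the single-warp launch and the full-warp shuffle mask are illustrative assumptions and not part of the patch (the real kernel vectorizes through Qk_vec, reduces with shfl_mask()/block_sum(), and also writes the key into the cache).

// Standalone sketch of the Dh vs. Dh_MAX padding trick (hypothetical example,
// not from fused_multi_transformer_op.cu).
#include <cstdio>

#include <cuda_runtime.h>

template <int Dh, int Dh_MAX>
__global__ void toy_qk_dot(const float *q, const float *k, float *out) {
  constexpr int WARP_SIZE = 32;
  static_assert(Dh <= Dh_MAX && Dh_MAX <= WARP_SIZE, "single-warp sketch only");

  const int tid = threadIdx.x;
  float qk = 0.f;
  if (tid < Dh_MAX) {  // analogous to tid < QK_VECS_PER_WARP with QK_VEC_SIZE == 1
    // Guarded loads: lanes past Dh contribute zeros, mirroring zero(q)/zero(k)
    // plus the (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) ternaries in the diff.
    float qv = (Dh == Dh_MAX || tid < Dh) ? q[tid] : 0.f;
    float kv = (Dh == Dh_MAX || tid < Dh) ? k[tid] : 0.f;
    qk = qv * kv;
  }
  // Full-warp reduction: padded lanes hold zero, so the result equals the
  // length-Dh dot product even though the reduction shape is fixed at Dh_MAX.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(0xffffffffu, qk, mask);
  }
  if (tid == 0) *out = qk;
}

int main() {
  constexpr int Dh = 10;      // e.g. the new dim_head = 10 case
  constexpr int Dh_MAX = 32;  // padded, matches fmha_launch_kernel<T, 10, 32>
  float hq[Dh], hk[Dh];
  for (int i = 0; i < Dh; ++i) {
    hq[i] = 1.f;
    hk[i] = 2.f;
  }
  float *dq, *dk, *dout;
  cudaMalloc(&dq, Dh * sizeof(float));
  cudaMalloc(&dk, Dh * sizeof(float));
  cudaMalloc(&dout, sizeof(float));
  cudaMemcpy(dq, hq, Dh * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dk, hk, Dh * sizeof(float), cudaMemcpyHostToDevice);
  toy_qk_dot<Dh, Dh_MAX><<<1, 32>>>(dq, dk, dout);
  float result = 0.f;
  cudaMemcpy(&result, dout, sizeof(float), cudaMemcpyDeviceToHost);
  printf("qk = %.1f (expected %.1f)\n", result, 2.f * Dh);
  cudaFree(dq);
  cudaFree(dk);
  cudaFree(dout);
  return 0;
}

Built with nvcc (e.g. nvcc toy_qk_dot.cu, a hypothetical file name), the sketch prints qk = 20.0 for Dh = 10, the same value an unpadded length-10 reduction would give, which is why padding the head size up to Dh_MAX does not change the attention scores.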