diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 7512cf6133fb8c9d3b9af523da067d3c0c9d4244..ce4ed90aef0ee6074bf3fc89c422a16955e4a172 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1607,14 +1607,14 @@
   backward : margin_cross_entropy_grad
 
 - op : masked_multihead_attention_
-  args : (Tensor x, Tensor cache_kv, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
+  args : (Tensor x, Tensor cache_kv, Tensor bias, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, str compute_dtype = "default", float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
   output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)
   infer_meta :
     func : MaskedMultiheadAttentionInferMeta
   kernel :
     func : masked_multihead_attention
-    data_type : cache_kv
-  optional : src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
+    data_type : x
+  optional : bias, src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
   inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)
 
 - op : masked_select
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 86a70ee2bbdc8f89296d5cb2b70a098246d983e9..7f4411f2c807ce67cba4d4748d937270533e4e41 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -4094,6 +4094,7 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
 
 void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        const MetaTensor& cache_kv,
+                                       const MetaTensor& bias,
                                        const MetaTensor& src_mask,
                                        const MetaTensor& cum_offsets,
                                        const MetaTensor& sequence_lengths,
@@ -4105,6 +4106,7 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        int seq_len,
                                        int rotary_emb_dims,
                                        const bool use_neox_rotary_style,
+                                       const std::string& compute_dtype,
                                        const float out_scale,
                                        const int quant_round_type,
                                        const float quant_max_bound,
@@ -4113,7 +4115,6 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        MetaTensor* cache_kv_out,
                                        MetaTensor* beam_cache_offset_out) {
   int bsz = x.dims()[0];
-  auto x_dtype = x.dtype();
   auto cache_kv_dims = cache_kv.dims();
   int num_head = cache_kv.dims()[2];
   int dim_head = cache_kv.dims()[4];
@@ -4141,10 +4142,86 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
 
   out->set_dims({bsz, num_head * dim_head});
 
-  if (out_scale > 0) {
-    out->set_dtype(DataType::INT8);
+  auto FBADtypeCheck = [](const MetaTensor& check_tensor,
+                          const std::string& tensor_name,
+                          const std::string& compute_dtype) {
+    if (compute_dtype == "bf16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::BFLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp32") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT32,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    }
+  };
+
+  // In the case of quantization enabled, the dtype for computation is
+  // determined based on compute_dtype.
+  if (x.dtype() == phi::DataType::INT32) {
+    PADDLE_ENFORCE_NE(
+        compute_dtype,
+        "default",
+        phi::errors::InvalidArgument(
+            "If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
+
+    if (bias) {
+      FBADtypeCheck(bias, "bias", compute_dtype);
+    }
+
+    if (out_scale > 0) {
+      out->set_dtype(phi::DataType::INT8);
+    } else {
+      if (compute_dtype == "bf16") {
+        out->set_dtype(phi::DataType::BFLOAT16);
+      } else if (compute_dtype == "fp16") {
+        out->set_dtype(phi::DataType::FLOAT16);
+      } else if (compute_dtype == "fp32") {
+        out->set_dtype(phi::DataType::FLOAT32);
+      } else {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "In the case of quantization enabled with Input(x) INT32, "
+            "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
+            "but get compute_dtype (%s)",
+            compute_dtype));
+      }
+    }
   } else {
-    out->set_dtype(x_dtype);
+    if (bias) {
+      if (compute_dtype != "default") {
+        FBADtypeCheck(bias, "bias", compute_dtype);
+        FBADtypeCheck(x, "x", compute_dtype);
+      } else {
+        PADDLE_ENFORCE_EQ(
+            x.dtype(),
+            bias.dtype(),
+            phi::errors::InvalidArgument("Input(x) and Input(bias) must be the "
+                                         "same dtype in this situation"));
+      }
+    } else {
+      // bias not exist
+      if (compute_dtype != "default") {
+        FBADtypeCheck(x, "x", compute_dtype);
+      }
+    }
+    if (out_scale > 0) {
+      out->set_dtype(phi::DataType::INT8);
+    } else {
+      out->set_dtype(x.dtype());
+    }
   }
 
   cache_kv_out->set_dims(cache_kv_dims);
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index c427f7e8fcc2985c40caf6eb94add4ef5c36320c..2eb5e2551646fae251695daf432832bd900974f4 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -801,6 +801,7 @@ void FusedRopeInferMeta(const MetaTensor& q,
 
 void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        const MetaTensor& cache_kv,
+                                       const MetaTensor& bias,
                                        const MetaTensor& src_mask,
                                        const MetaTensor& cum_offsets,
                                        const MetaTensor& sequence_lengths,
@@ -812,6 +813,7 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        int seq_len,
                                        int rotary_emb_dims,
                                        const bool use_neox_rotary_style,
+                                       const std::string& compute_dtype,
                                        const float out_scale,
                                        const int quant_round_type,
                                        const float quant_max_bound,
diff --git a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu
index d95336528c2243742613a43a6dcf3c558acce0d0..312f81ae31a6188b8a408e67a4a598f8d4e6f83d 100644
--- a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu
+++ b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu
@@ -12,36 +12,1301 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
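Editorial note on the InferMeta change above, before the kernel diff continues below: the added branching reduces to one rule. When x is INT32 the op is consuming quantized QKV activations, so Attr(compute_dtype) must name the real compute dtype and bias (if present) must match it; otherwise the compute dtype is taken from x itself, with bias required to agree. In both paths the output becomes INT8 only when out_scale > 0. A minimal standalone C++ sketch of that rule, with illustrative names (Dtype, ResolveMMHAOutDtype) rather than the phi API:

#include <stdexcept>
#include <string>

// Illustrative stand-in for phi::DataType, not the Paddle enum.
enum class Dtype { kInt8, kInt32, kFloat16, kBFloat16, kFloat32 };

// Mirrors the dtype resolution added to MaskedMultiheadAttentionInferMeta:
// an INT32 x requires compute_dtype; the output is INT8 only when
// out_scale > 0, otherwise it follows the resolved compute dtype.
Dtype ResolveMMHAOutDtype(Dtype x_dtype,
                          bool has_bias,
                          Dtype bias_dtype,
                          const std::string& compute_dtype,
                          float out_scale) {
  auto from_string = [](const std::string& s) {
    if (s == "bf16") return Dtype::kBFloat16;
    if (s == "fp16") return Dtype::kFloat16;
    if (s == "fp32") return Dtype::kFloat32;
    throw std::invalid_argument("compute_dtype must be bf16/fp16/fp32");
  };

  if (x_dtype == Dtype::kInt32) {
    // Quantized path: compute_dtype is mandatory and bias must match it.
    if (compute_dtype == "default") {
      throw std::invalid_argument("INT32 x requires compute_dtype");
    }
    Dtype compute = from_string(compute_dtype);
    if (has_bias && bias_dtype != compute) {
      throw std::invalid_argument("bias dtype must match compute_dtype");
    }
    return out_scale > 0 ? Dtype::kInt8 : compute;
  }

  // Non-quantized path: x (and bias, when present) carry the compute dtype.
  if (compute_dtype != "default" && x_dtype != from_string(compute_dtype)) {
    throw std::invalid_argument("x dtype must match compute_dtype");
  }
  if (has_bias && bias_dtype != x_dtype) {
    throw std::invalid_argument("x and bias must have the same dtype");
  }
  return out_scale > 0 ? Dtype::kInt8 : x_dtype;
}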
-#include "paddle/phi/kernels/fusion/gpu/masked_multihead_attention.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h" namespace phi { namespace fusion { -template -void MMHAKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& cache_kv, - const paddle::optional& src_mask, - const paddle::optional& cum_offsets, - const paddle::optional& sequence_lengths, - const paddle::optional& rotary_tensor, - const paddle::optional& beam_cache_offset, - const paddle::optional& qkv_out_scale, - const paddle::optional& out_shift, - const paddle::optional& out_smooth, - int seq_len, - int rotary_emb_dims, - const bool use_neox_rotary_style, - const float out_scale, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - DenseTensor* out, - DenseTensor* cache_kv_out, - DenseTensor* beam_cache_offset_out) { #ifndef PADDLE_WITH_HIP - const auto& x_dims = x.dims(); + +#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#define MMHA_USE_FP32_ACUM_FOR_OUT +#define MMHA_USE_FP32_ACUM_FOR_FMA + +template +__device__ __inline__ T ClipFunc(const T v, const T min, const T max) { + if (v > max) return max; + if (v < min) return min; + return v; +} + +constexpr unsigned int str2int(const char *str, int h = 0) { + return !str[h] ? 5381 : (str2int(str, h + 1) * 33) ^ str[h]; +} + +template +__forceinline__ __device__ OutType QuantHelperFunc(const InType input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * scale * input; + + if (round_type == 0) { + quant_value = static_cast(rint(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + return static_cast( + ClipFunc(quant_value, min_bound, max_bound)); +} + +template +struct Masked_multihead_attention_params { + // output buffer, [B, 1(seq_len), num_head * dim_head] + T *out; + // qkv_out, [B, 1(seq_len), 3, num_head * dim_head] + const T *qkv; + // bias, [3, num_head, dim_head] + T *qkv_bias; + // [bsz, seq_len] + const int *cum_offsets; + // TODO(wangxi): optimize with input_lengths and max_input_len? 
+ // [bsz, 1, 1, time_step(cache_seq_length)+1] + const T *attn_mask; + int mask_length; + // whether to broadcast num_heads(2nd) dimension for attn_mask + // in MMHA, if false, attn_mask shape should be + // [bsz, num_heads, 1, time_step(cache_seq_length)+1] + bool mask_broadcast_num_heads; + + // [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head] + // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first + // v [B, num_head, max_seq_len, dim_head] + T *cache_kv; + // [B, max_seq_len] + const int *beam_cache_offset = nullptr; + + const int *sequence_lengths{nullptr}; + + // The RoPE embedding, [2, B, rotary_seq_len, 1, dim_head] + // rotary_emb_dims = 1 if pos_ids_extra is null else 2 + const float *rotary_emb; + int rotary_emb_dims; + int rotary_seq_len = 1; + + int batch_size; // batch * beam + int beam_width; + int cache_batch_size; + int num_head; + int timestep; // cache_seq_length + int seq_len; + int max_seq_length; + + // 1.f / sqrt(Dh) + float inv_sqrt_dh; + + bool add_qkv_bias; + bool neox_rotary_style; +}; + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct K_vec_acum_fp32_ {}; + +template <> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +#endif + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ {}; +// template <> struct V_vec_acum_fp32_ { using Type = float; }; +// template <> struct V_vec_acum_fp32_ { using Type = float2; }; +template <> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +// template <> struct V_vec_acum_fp32_ { using Type = float2; }; +// template <> struct V_vec_acum_fp32_ { using Type = Float4_; }; +template <> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; + +#ifdef ENABLE_BF16 +template <> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template <> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template <> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 + +#endif + +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], + const K_vec (&k)[N], + float inv_sqrt_dh) { + K_vec inv_q = mul(q[0], inv_sqrt_dh); + K_vec qk_vec = mul(inv_q, k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); + } + + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) { + float4 c; + float zero = 0.f; + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum inv_q = mul(q[0], inv_sqrt_dh); + K_vec_acum qk_vec = mul(inv_q, k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA 
+ uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], + const K_vec (&k)[N], + float inv_sqrt_dh) { + return qk_dot_(q, k, inv_sqrt_dh); + } +}; + +template <> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 + return qk_hmma_dot_(q, k, inv_sqrt_dh); +#else + return qk_dot_<4>(q, k, inv_sqrt_dh); +#endif + } +}; + +template +inline __device__ float block_sum(float *red_smem, float sum) { + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + if (lane == 0) { + red_smem[warp] = sum; + } + __syncthreads(); + + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + return __shfl_sync(uint32_t(-1), sum, 0); +} + +inline __device__ void convert_from_float(float &dst, float src) { // NOLINT + dst = src; +} + +inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT + dst = src; +} + +inline __device__ void convert_from_float(phi::float16 &dst, // NOLINT + float src) { + dst = static_cast(src); +} + +inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16 &dst, // NOLINT + float src) { // NOLINT + dst = __float2bfloat16(src); +} + +inline __device__ void convert_from_float(__nv_bfloat162 &dst, // NOLINT + float2 src) { // NOLINT +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} + +inline __device__ void convert_from_float(bf16_4_t &dst, // NOLINT + Float4_ src) { // NOLINT +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_4_t &dst, // NOLINT + float4 src) { // NOLINT + convert_from_float( + dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +inline __device__ void convert_from_float(bf16_8_t &dst, // NOLINT + Float8_ src) { // NOLINT +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = 
__floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT + +template +inline __device__ void zero(T &dst) { // NOLINT + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +template +__global__ void masked_multihead_attention_kernel( + Masked_multihead_attention_params params, + LoadFunc load_func, + StoreFunc store_func) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + const int bi = blockIdx.y; + if (params.sequence_lengths && params.sequence_lengths[bi] == 0) { + return; + } + + typedef PDDataTypeTraits traits_; + typedef typename traits_::DataType DataType_; + + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + constexpr int WARP_SIZE = 32; + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + + char *logits_smem_ = smem_; + // fp32 accum for logits + float *logits_smem = reinterpret_cast(logits_smem_); + + T *out_smem = reinterpret_cast(smem_); + + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + using Qk_vec = typename Qk_vec_::Type; + using Qk_vec_RoPE = typename Qk_vec_RoPE_::Type; + __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; + + // beam id + const int beami = bi % params.beam_width; + // real batch id + const int bbi = bi / params.beam_width; + const int hi = blockIdx.x; + const int bhi = bi * params.num_head + hi; + const int bbhi = bbi * params.beam_width * params.num_head + hi; + const int ti = + params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1; + const int thi = params.cum_offsets ? ti * params.num_head + hi : -1; + const int tid = threadIdx.x; + + const int bi_seq_len_offset = bi * params.max_seq_length; + + float qk_max = -FLT_MAX; + float qk = 0; + + int act_time_step = params.sequence_lengths == nullptr + ? params.timestep + : params.sequence_lengths[bi]; + + // qkv [B, S=1, 3, num_head, head_dim] + int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; + + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // Use block reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // cache_k, [B, num_head, head_dim / x, max_seq_len, x] + // x == 4/8 for FP32/FP16, 128bit, 16Byte + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); + + // const T *q_base = params.qkv; + // const T *k_base = params.qkv + params.num_head * Dh; + T *q_bias_base = nullptr; + T *k_bias_base = nullptr; + + if (params.add_qkv_bias) { + q_bias_base = params.qkv_bias; + k_bias_base = params.qkv_bias + params.num_head * Dh; + } + + if (tid < QK_VECS_PER_WARP) { + int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; + int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; + + Qk_vec q; + zero(q); + // q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&q_base[qk_offset]) + // : q; + if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { + load_func.template load(q, qk_offset); + } + + Qk_vec k; + zero(k); + // k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + // ? 
*reinterpret_cast(&k_base[qk_offset]) + // : k; + if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { + load_func.template load(k, params.num_head * Dh + qk_offset); + } + + if (params.add_qkv_bias) { + Qk_vec q_bias; + zero(q_bias); + Qk_vec k_bias; + zero(k_bias); + + q_bias = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) + : q_bias; + k_bias = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) + : k_bias; + + q = add(q, q_bias); + // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 + // we may not require k_bias. + k = add(k, k_bias); + } + + if (!params.neox_rotary_style) { + if (params.rotary_emb_dims != 0) { + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const float *cos_base = params.rotary_emb; + const float *sin_base = params.rotary_emb + params.batch_size * Dh; + Qk_vec_RoPE cos_emb, sin_emb; + zero(cos_emb); + zero(sin_emb); + cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &cos_base[rotary_offset]) + : cos_emb; + sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &sin_base[rotary_offset]) + : sin_emb; + apply_rotary_embedding(q, k, cos_emb, sin_emb); + } + } else { + /* old rotary pos emb */ + if (params.rotary_emb_dims != 0) { + int last_dim = Dh / params.rotary_emb_dims; + int half_lastdim = last_dim / 2; + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const float *cos_base = params.rotary_emb; + const float *sin_base = params.rotary_emb + params.batch_size * Dh; + int stride = half_lastdim / QK_VEC_SIZE; + int stride_all_lastdim = 2 * stride; + int right_id = tid / stride_all_lastdim * stride_all_lastdim + + (tid + stride) % (stride_all_lastdim); + int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; + int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; + Qk_vec q_right; + zero(q_right); + // q_right = + // (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&q_base[qk_right_offset]) + // : q_right; + if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { + load_func.template load(q_right, qk_right_offset); + } + Qk_vec k_right; + zero(k_right); + // k_right = + // (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + // ? *reinterpret_cast(&k_base[qk_right_offset]) + // : k_right; + if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { + load_func.template load( + k_right, params.num_head * Dh + qk_right_offset); + } + + if (params.add_qkv_bias) { + Qk_vec q_right_bias; + zero(q_right_bias); + q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &q_bias_base[qk_right_bias_offset]) + : q_right_bias; + Qk_vec k_right_bias; + zero(k_right_bias); + k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &k_bias_base[qk_right_bias_offset]) + : k_right_bias; + + q_right = add(q_right, q_right_bias); + k_right = add(k_right, k_right_bias); + } + + Qk_vec_RoPE cos_emb; + zero(cos_emb); + cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &cos_base[rotary_offset]) + : cos_emb; + + Qk_vec_RoPE sin_emb; + zero(sin_emb); + sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &sin_base[rotary_offset]) + : sin_emb; + float alpha = (tid % stride_all_lastdim) < stride + ? 
static_cast(-1) + : static_cast(1); + q = apply_rotary_emb( + q, q_right, cos_emb, sin_emb, alpha); + k = apply_rotary_emb( + k, k_right, cos_emb, sin_emb, alpha); + } + } + + *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; + + int co = tid / QK_VECS_IN_16B; + int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; + int offset = bhi * params.max_seq_length * Dh + + co * params.max_seq_length * QK_ELTS_IN_16B + + act_time_step * QK_ELTS_IN_16B + ci; + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.cache_kv[offset]) = k; + } + + qk = dot(q, k); + + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = + (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + if (tid == 0) { + // NOTE(wangxi): mask must be 0.0 + // T mask = params.attn_mask[ + // bi * (params.timestep + 1) + params.timestep]; + // qk += static_cast(mask); + qk *= params.inv_sqrt_dh; + qk_max = qk; + qk_smem[act_time_step] = qk; + } + __syncthreads(); + + using K_vec = typename K_vec_::Type; + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + int ko = tid / THREADS_PER_KEY; + int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, ""); + + K_vec q[K_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < K_VECS_PER_THREAD; ++i) { + q[i] = *reinterpret_cast( + &q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]); + } + + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; + T *k_cache_batch = ¶ms.cache_kv[bbhi * params.max_seq_length * Dh + ki]; + int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; + + const int *beam_offsets = params.beam_cache_offset + ? ¶ms.beam_cache_offset[bi_seq_len_offset] + : nullptr; + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head * + params.max_seq_length * Dh + : 0; + K_vec k[K_VECS_PER_THREAD]; + K_vec k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_seq_length + ti; + if (ti < act_time_step) { + if (beam_offset) { + k[ii] = + (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) + ? *reinterpret_cast( + &k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]) + : k_vec_zero; + } else { + k[ii] = + (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) + ? *reinterpret_cast( + &k_cache[jj * QK_ELTS_IN_16B]) + : k_vec_zero; + } + } + } + + // NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k) + // may overflow with FP16 in large model. + float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); + + // bool is_mask = false; + if (ti < act_time_step && tid % THREADS_PER_KEY == 0) { + // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + auto mask_bhi = params.mask_broadcast_num_heads ? 
bi : bhi; + // T mask = params.attn_mask[mask_bhi * (params.timestep + 1) + ti]; + if (params.attn_mask) { + T mask = params.attn_mask[mask_bhi * params.mask_length + ti]; + qk += static_cast(mask); + } + qk_max = fmaxf(qk_max, qk); + + qk_smem[ti] = qk; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + const int warp = tid / WARP_SIZE; + const int lane = tid % WARP_SIZE; + + if (lane == 0) { + red_smem[warp] = qk_max; + } + + __syncthreads(); + + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + float sum = 0.f; + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { + // bool is_mask = false; + // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max); + float logit = __expf(qk_smem[ti] - qk_max); + sum += logit; + qk_smem[ti] = logit; + } + + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // FIXME(wangxi): need add 1.e-6f? + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { + convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); + } + __syncthreads(); + + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + using V_vec = typename V_vec_::Type; + + int vo = tid / THREADS_PER_VALUE; + int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; + + T *v_cache = ¶ms.cache_kv[params.cache_batch_size * params.num_head * + params.max_seq_length * Dh + + bhi * params.max_seq_length * Dh + vi]; + T *v_cache_batch = ¶ms.cache_kv[params.batch_size * params.num_head * + params.max_seq_length * Dh + + bbhi * params.max_seq_length * Dh + vi]; + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + + V_vec_acum out; + zero(out); + + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) { + const int beam_offset = + beam_offsets + ? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh + : 0; + V_vec v; + if (beam_offset) { + v = *reinterpret_cast( + &v_cache_batch[beam_offset + ti * Dh]); + } else { + v = *reinterpret_cast(&v_cache[ti * Dh]); + } +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti]; + out = fma(logit, cast_to_float(v), out); +#else + DataType_ logit = static_cast(logits_smem[ti]); + // Update the partial sums. 
+ out = fma(logit, v, out); +#endif + } + } + + V_vec v_bias; + zero(v_bias); + if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { + // V_vec v = *reinterpret_cast( + // ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); + V_vec v; + load_func.template load( + v, 2 * params.num_head * Dh + qkv_base_offset + vi); + if (params.add_qkv_bias) { + v_bias = *reinterpret_cast( + ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); + v = add(v, v_bias); + } + + *reinterpret_cast(&v_cache[act_time_step * Dh]) = v; + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + out = fma(logits_smem[act_time_step], cast_to_float(v), out); +#else + out = fma(logits_smem[act_time_step], v, out); +#endif + } + + __syncthreads(); + + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + int midpoint = active_groups / 2; + + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float( + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), + out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = + add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + } + + if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + // convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + + // vi]), + // out); + V_vec tmp_out; + convert_from_float(tmp_out, out); + store_func.template store(tmp_out, + thi != -1 ? thi * Dh + vi : bhi * Dh + vi); +#else + // *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; + store_func.template store(out, + thi != -1 ? thi * Dh + vi : bhi * Dh + vi); +#endif + } + +#else + assert(false); +#endif +} + +template +inline size_t smem_size_in_bytes( + const Masked_multihead_attention_params ¶ms, + int dim_head, + int threads_per_value, + int threads_per_block) { + size_t qk_sz = div_up(params.timestep + 1, 4) * 16; + size_t logits_sz = 0; + +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS // NOLINT + if (sizeof(T) != 4) { + logits_sz = div_up(params.max_seq_length, 4) * 4 * sizeof(T); + } +#endif // NOLINT + size_t softmax_sz = qk_sz + logits_sz; + + int rows_per_red = threads_per_block / threads_per_value; + size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2; + + return max(softmax_sz, red_sz); +} + +#define MMHA_LAUNCH_KERNEL(T, \ + Dh, \ + Dh_MAX, \ + THDS_PER_KEY, \ + THDS_PER_VALUE, \ + THDS_PER_BLOCK, \ + stream, \ + load_func, \ + store_func) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + constexpr auto kernel_fn = \ + masked_multihead_attention_kernel; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + dim3 grid(params.num_head, params.batch_size); \ + kernel_fn<<>>( \ + params, load_func, store_func) + +template +void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, + const cudaStream_t &stream, + LoadFunc load_func, + StoreFunc store_func) { + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + if (params.timestep < 32) { + MMHA_LAUNCH_KERNEL( + T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream, load_func, store_func); + } else if (params.timestep < 2048) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 4, + THREADS_PER_VALUE, + 
256, + stream, + load_func, + store_func); +#else + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 2, + THREADS_PER_VALUE, + 128, + stream, + load_func, + store_func); +#endif + } else { + MMHA_LAUNCH_KERNEL(T, + Dh, + Dh_MAX, + 1, + THREADS_PER_VALUE, + 256, + stream, + load_func, + store_func); + } +} + +template +void fmha_impl(const phi::GPUContext &dev_ctx, + const Masked_multihead_attention_params ¶ms, + int dim_head, + LoadFunc load_func, + StoreFunc store_func) { + switch (dim_head) { + case 10: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 26: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 32: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 64: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 96: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 128: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + case 192: + fmha_launch_kernel( + params, dev_ctx.stream(), load_func, store_func); + break; + default: + PADDLE_THROW( + phi::errors::Unimplemented("Dim_head = %d is unsupport!", dim_head)); + } +} + +template +struct MMHALoad { + explicit MMHALoad(const LoadT *src) : src_(src) {} + + template + __device__ void load(Vec &dst, int idx) { + dst = *reinterpret_cast(src_ + idx); + } + + const LoadT *src_; +}; + +template +struct MMHAStore { + explicit MMHAStore(StoreT *dst) : dst_(dst) {} + + template + __device__ void store(Vec &src, int idx) { + *reinterpret_cast(dst_ + idx) = src; + } + + StoreT *dst_; +}; + +template +struct MMHAStore { + MMHAStore(T *dst, const T *shift, const T *smooth, const int cols) + : dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {} + + template + __device__ void store(Vec &src, int idx) { + constexpr int VecSize = sizeof(Vec) / sizeof(T); + using TVec = phi::AlignedVector; + TVec src_vec; + TVec shift_vec; + TVec smooth_vec; + + *reinterpret_cast(&src_vec) = src; + phi::Load(shift_ + idx % cols_, &shift_vec); + phi::Load(smooth_ + idx % cols_, &smooth_vec); + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src_vec[i] = (src_vec[i] + shift_vec[i]) * smooth_vec[i]; + } + + phi::Store(src_vec, dst_ + idx); + } + + T *dst_; + const T *shift_; + const T *smooth_; + const int cols_; +}; + +template +struct MMHALoad { + MMHALoad(const int32_t *src, const float *dequant_scales, const int cols) + : src_(src), dequant_scales_(dequant_scales), cols_(cols) {} + + template + __device__ void load(Vec &dst, int idx) { + constexpr int VecSize = sizeof(Vec) / sizeof(T); + using SrcVec = phi::AlignedVector; + using DstVec = phi::AlignedVector; + using ScaleVec = phi::AlignedVector; + + SrcVec src_vec; + DstVec dst_vec; + ScaleVec scale_vec; + + phi::Load(src_ + idx, &src_vec); + phi::Load(dequant_scales_ + idx % cols_, &scale_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dst_vec[i] = + static_cast(static_cast(src_vec[i]) * scale_vec[i]); + } + dst = *reinterpret_cast(&dst_vec); + } + + const int32_t *src_; + const float *dequant_scales_; + const int cols_; +}; + +template +struct MMHAStore { + MMHAStore(int8_t *dst, + const int quant_round_type, + const float quant_scale, + const float quant_max_bound, + const float quant_min_bound) + : dst_(dst), + quant_round_type_(quant_round_type), + quant_scale_(quant_scale), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound) {} + + template + 
__device__ void store(Vec &src, int idx) { // NOLINT + constexpr int VecSize = sizeof(Vec) / sizeof(T); + using SrcVec = phi::AlignedVector; + using DstVec = phi::AlignedVector; + + SrcVec src_vec; + *reinterpret_cast(&src_vec) = src; + + DstVec dst_vec; +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dst_vec[i] = + QuantHelperFunc(static_cast(src_vec[i]), + quant_scale_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_); + } + + phi::Store(dst_vec, dst_ + idx); + } + + int8_t *dst_; + const int quant_round_type_; + const float quant_scale_; + const float quant_max_bound_; + const float quant_min_bound_; +}; + +template +struct MMHAStore { + MMHAStore(int8_t *dst, + const T *shift, + const T *smooth, + const int cols, + const int quant_round_type, + const float quant_scale, + const float quant_max_bound, + const float quant_min_bound) + : dst_(dst), + quant_round_type_(quant_round_type), + quant_scale_(quant_scale), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound), + shift_(shift), + smooth_(smooth), + cols_(cols) {} + + template + __device__ void store(Vec &src, int idx) { // NOLINT + constexpr int VecSize = sizeof(Vec) / sizeof(T); + using SrcVec = phi::AlignedVector; + using DstVec = phi::AlignedVector; + + SrcVec src_vec; + DstVec dst_vec; + SrcVec shift_vec; + SrcVec smooth_vec; + + *reinterpret_cast(&src_vec) = src; + phi::Load(shift_ + idx % cols_, &shift_vec); + phi::Load(smooth_ + idx % cols_, &smooth_vec); + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src_vec[i] = (src_vec[i] + shift_vec[i]) * smooth_vec[i]; + dst_vec[i] = + QuantHelperFunc(static_cast(src_vec[i]), + quant_scale_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_); + } + + phi::Store(dst_vec, dst_ + idx); + } + + int8_t *dst_; + const T *shift_; + const T *smooth_; + const int cols_; + const int quant_round_type_; + const float quant_scale_; + const float quant_max_bound_; + const float quant_min_bound_; +}; + +template +void DispatchFMHA(const phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv_tensor, + const Masked_multihead_attention_params ¶ms, + int num_head, + int dim_head, + phi::DenseTensor *out_tensor, + const phi::DenseTensor *dequant_qkv_scales = nullptr, + const float quant_fmha_out_scale = -1, + const int quant_round_type = 1, + const float quant_max_bound = 127.0f, + const float quant_min_bound = -127.0f) { + if (dequant_qkv_scales != nullptr && quant_fmha_out_scale > 0) { + MMHALoad load_func(qkv_tensor.data(), + dequant_qkv_scales->data(), + 3 * num_head * dim_head); + MMHAStore store_func(out_tensor->data(), + quant_round_type, + quant_fmha_out_scale, + quant_max_bound, + quant_min_bound); + fmha_impl(dev_ctx, params, dim_head, load_func, store_func); + } else if (dequant_qkv_scales == nullptr && quant_fmha_out_scale > 0) { + MMHALoad load_func(qkv_tensor.data()); + MMHAStore store_func(out_tensor->data(), + quant_round_type, + quant_fmha_out_scale, + quant_max_bound, + quant_min_bound); + fmha_impl(dev_ctx, params, dim_head, load_func, store_func); + } else if (dequant_qkv_scales != nullptr && quant_fmha_out_scale <= 0) { + MMHALoad load_func(qkv_tensor.data(), + dequant_qkv_scales->data(), + 3 * num_head * dim_head); + MMHAStore store_func(out_tensor->data()); + fmha_impl(dev_ctx, params, dim_head, load_func, store_func); + } else { + MMHALoad load_func(qkv_tensor.data()); + MMHAStore store_func(out_tensor->data()); + fmha_impl(dev_ctx, params, dim_head, load_func, store_func); + } +} + +template +void DispatchFMHA(const 
phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv_tensor, + const phi::DenseTensor &shift, + const phi::DenseTensor &smooth, + const Masked_multihead_attention_params ¶ms, + int num_head, + int dim_head, + phi::DenseTensor *out_tensor, + const phi::DenseTensor *dequant_qkv_scales = nullptr, + const float quant_fmha_out_scale = -1, + const int quant_round_type = 1, + const float quant_max_bound = 127.0f, + const float quant_min_bound = -127.0f) { + if (dequant_qkv_scales != nullptr && quant_fmha_out_scale > 0) { + MMHALoad load_func(qkv_tensor.data(), + dequant_qkv_scales->data(), + 3 * num_head * dim_head); + MMHAStore store_func(out_tensor->data(), + shift.data(), + smooth.data(), + num_head * dim_head, + quant_round_type, + quant_fmha_out_scale, + quant_max_bound, + quant_min_bound); + fmha_impl(dev_ctx, params, dim_head, load_func, store_func); + } else if (dequant_qkv_scales == nullptr && quant_fmha_out_scale > 0) { + MMHALoad load_func(qkv_tensor.data()); + MMHAStore store_func(out_tensor->data(), + shift.data(), + smooth.data(), + num_head * dim_head, + quant_round_type, + quant_fmha_out_scale, + quant_max_bound, + quant_min_bound); + fmha_impl(dev_ctx, params, dim_head, load_func, store_func); + } else if (dequant_qkv_scales != nullptr && quant_fmha_out_scale <= 0) { + MMHALoad load_func(qkv_tensor.data(), + dequant_qkv_scales->data(), + 3 * num_head * dim_head); + MMHAStore store_func(out_tensor->data(), + shift.data(), + smooth.data(), + num_head * dim_head); + fmha_impl(dev_ctx, params, dim_head, load_func, store_func); + } else { + MMHALoad load_func(qkv_tensor.data()); + MMHAStore store_func(out_tensor->data(), + shift.data(), + smooth.data(), + num_head * dim_head); + fmha_impl(dev_ctx, params, dim_head, load_func, store_func); + } +} + +struct NormalVersion {}; +struct UnusedVersion {}; + +template +struct DispatchDtypeTrait { + using FuncVersion = NormalVersion; +}; + +template <> +struct DispatchDtypeTrait { + using FuncVersion = UnusedVersion; +}; + +template +void DispatchWithDtype(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &cache_kv, + const paddle::optional &bias, + const paddle::optional &src_mask, + const paddle::optional &cum_offsets, + const paddle::optional &sequence_lengths, + const paddle::optional &rotary_tensor, + const paddle::optional &beam_cache_offset, + const paddle::optional &qkv_out_scale, + const paddle::optional &out_shift, + const paddle::optional &out_smooth, + int seq_len, + int rotary_emb_dims, + const bool use_neox_rotary_style, + const float out_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + DenseTensor *out, + DenseTensor *cache_kv_out, + DenseTensor *beam_cache_offset_out, + NormalVersion) { + const auto &x_dims = x.dims(); int bsz = x_dims[0]; int cache_bsz = cache_kv.dims()[1]; int num_head = cache_kv.dims()[2]; @@ -53,6 +1318,12 @@ void MMHAKernel(const Context& dev_ctx, Masked_multihead_attention_params params; bool mask_broadcast_num_heads = true; + params.add_qkv_bias = false; + if (bias) { + params.add_qkv_bias = true; + params.qkv_bias = const_cast(bias->data()); + } + if (src_mask) { if (src_mask->dims()[1] == 1) { mask_broadcast_num_heads = true; @@ -99,9 +1370,8 @@ void MMHAKernel(const Context& dev_ctx, } params.mask_broadcast_num_heads = mask_broadcast_num_heads; - params.cache_kv = const_cast(cache_kv_out->data()); + params.cache_kv = const_cast(cache_kv_out->data()); params.neox_rotary_style = use_neox_rotary_style; - params.add_qkv_bias = false; 
params.batch_size = bsz; params.cache_batch_size = cache_bsz; params.num_head = num_head; @@ -138,6 +1408,175 @@ void MMHAKernel(const Context& dev_ctx, quant_max_bound, quant_min_bound); } +} + +template +void DispatchWithDtype(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &cache_kv, + const paddle::optional &bias, + const paddle::optional &src_mask, + const paddle::optional &cum_offsets, + const paddle::optional &sequence_lengths, + const paddle::optional &rotary_tensor, + const paddle::optional &beam_cache_offset, + const paddle::optional &qkv_out_scale, + const paddle::optional &out_shift, + const paddle::optional &out_smooth, + int seq_len, + int rotary_emb_dims, + const bool use_neox_rotary_style, + const float out_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + DenseTensor *out, + DenseTensor *cache_kv_out, + DenseTensor *beam_cache_offset_out, + UnusedVersion) {} + +#endif // PADDLE_WITH_HIP + +template +void MMHAKernel(const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &cache_kv, + const paddle::optional &bias, + const paddle::optional &src_mask, + const paddle::optional &cum_offsets, + const paddle::optional &sequence_lengths, + const paddle::optional &rotary_tensor, + const paddle::optional &beam_cache_offset, + const paddle::optional &qkv_out_scale, + const paddle::optional &out_shift, + const paddle::optional &out_smooth, + int seq_len, + int rotary_emb_dims, + const bool use_neox_rotary_style, + const std::string &compute_dtype, + const float out_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + DenseTensor *out, + DenseTensor *cache_kv_out, + DenseTensor *beam_cache_offset_out) { +#ifndef PADDLE_WITH_HIP + if (x.dtype() == phi::DataType::INT32) { + switch (str2int(compute_dtype.c_str())) { + case str2int("fp16"): + DispatchWithDtype( + dev_ctx, + x, + cache_kv, + bias, + src_mask, + cum_offsets, + sequence_lengths, + rotary_tensor, + beam_cache_offset, + qkv_out_scale, + out_shift, + out_smooth, + seq_len, + rotary_emb_dims, + use_neox_rotary_style, + out_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out, + cache_kv_out, + beam_cache_offset_out, + typename DispatchDtypeTrait::FuncVersion{}); + break; +#if CUDA_VERSION >= 11000 + case str2int("bf16"): + DispatchWithDtype( + dev_ctx, + x, + cache_kv, + bias, + src_mask, + cum_offsets, + sequence_lengths, + rotary_tensor, + beam_cache_offset, + qkv_out_scale, + out_shift, + out_smooth, + seq_len, + rotary_emb_dims, + use_neox_rotary_style, + out_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out, + cache_kv_out, + beam_cache_offset_out, + typename DispatchDtypeTrait::FuncVersion{}); + break; +#endif + case str2int("fp32"): + DispatchWithDtype( + dev_ctx, + x, + cache_kv, + bias, + src_mask, + cum_offsets, + sequence_lengths, + rotary_tensor, + beam_cache_offset, + qkv_out_scale, + out_shift, + out_smooth, + seq_len, + rotary_emb_dims, + use_neox_rotary_style, + out_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out, + cache_kv_out, + beam_cache_offset_out, + typename DispatchDtypeTrait::FuncVersion{}); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "In the case of quantization enabled with Input(x) INT32, " + "Attr(compute_dtype) must be set in (bf16, fp16, fp32), " + "but get compute_dtype (%s)", + compute_dtype)); + } + } else { + DispatchWithDtype( + dev_ctx, + x, + cache_kv, + bias, + src_mask, + 
cum_offsets, + sequence_lengths, + rotary_tensor, + beam_cache_offset, + qkv_out_scale, + out_shift, + out_smooth, + seq_len, + rotary_emb_dims, + use_neox_rotary_style, + out_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out, + cache_kv_out, + beam_cache_offset_out, + typename DispatchDtypeTrait::FuncVersion{}); + } #endif // PADDLE_WITH_HIP } @@ -151,12 +1590,14 @@ PD_REGISTER_KERNEL(masked_multihead_attention, phi::fusion::MMHAKernel, float, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + int32_t) {} #else PD_REGISTER_KERNEL(masked_multihead_attention, GPU, ALL_LAYOUT, phi::fusion::MMHAKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t) {} #endif diff --git a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.h b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.h deleted file mode 100644 index 861c1f6e5ac1e3edf41d04c63d617e7b0e45a9d4..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h" - -namespace phi { -namespace fusion { - -template -void MMHAKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& cache_kv, - const paddle::optional& src_mask, - const paddle::optional& cum_offsets, - const paddle::optional& sequence_lengths, - const paddle::optional& rotary_tensor, - const paddle::optional& beam_cache_offset, - const paddle::optional& qkv_out_scale, - const paddle::optional& out_shift, - const paddle::optional& out_smooth, - int seq_len, - int rotary_emb_dims, - const bool use_neox_rotary_style, - const float out_scale, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - DenseTensor* out, - DenseTensor* cache_kv_out, - DenseTensor* beam_cache_offset_out); - -} // namespace fusion -} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h deleted file mode 100644 index 4e715032a26b1d17366b6d654377ddbea7310d3e..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h +++ /dev/null @@ -1,1306 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for mmha kernel. -*/ - -#ifndef PADDLE_WITH_HIP -#pragma once - -#include "glog/logging.h" -#include "paddle/phi/kernels/funcs/aligned_vector.h" -#include "paddle/phi/kernels/funcs/mmha_util.cu.h" - -namespace phi { -namespace fusion { - -#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#define MMHA_USE_FP32_ACUM_FOR_OUT -#define MMHA_USE_FP32_ACUM_FOR_FMA - -template -__device__ __inline__ T ClipFunc(const T v, const T min, const T max) { - if (v > max) return max; - if (v < min) return min; - return v; -} - -template -__forceinline__ __device__ OutType QuantHelperFunc(const InType input, - const float scale, - const int round_type, - const float max_bound, - const float min_bound) { - float quant_value = max_bound * scale * input; - - if (round_type == 0) { - quant_value = static_cast(rint(quant_value)); - } else { - quant_value = static_cast(round(quant_value)); - } - return static_cast( - ClipFunc(quant_value, min_bound, max_bound)); -} - -template -struct Masked_multihead_attention_params { - // output buffer, [B, 1(seq_len), num_head * dim_head] - T *out; - // qkv_out, [B, 1(seq_len), 3, num_head * dim_head] - const T *qkv; - // bias, [3, num_head, dim_head] - T *qkv_bias; - // [bsz, seq_len] - const int *cum_offsets; - // TODO(wangxi): optimize with input_lengths and max_input_len? 
- // [bsz, 1, 1, time_step(cache_seq_length)+1] - const T *attn_mask; - int mask_length; - // whether to broadcast num_heads(2nd) dimension for attn_mask - // in MMHA, if false, attn_mask shape should be - // [bsz, num_heads, 1, time_step(cache_seq_length)+1] - bool mask_broadcast_num_heads; - - // [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head] - // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first - // v [B, num_head, max_seq_len, dim_head] - T *cache_kv; - // [B, max_seq_len] - const int *beam_cache_offset = nullptr; - - const int *sequence_lengths{nullptr}; - - // The RoPE embedding, [2, B, rotary_seq_len, 1, dim_head] - // rotary_emb_dims = 1 if pos_ids_extra is null else 2 - const float *rotary_emb; - int rotary_emb_dims; - int rotary_seq_len = 1; - - int batch_size; // batch * beam - int beam_width; - int cache_batch_size; - int num_head; - int timestep; // cache_seq_length - int seq_len; - int max_seq_length; - - // 1.f / sqrt(Dh) - float inv_sqrt_dh; - - bool add_qkv_bias; - bool neox_rotary_style; -}; - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct K_vec_acum_fp32_ {}; - -template <> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -#endif - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template -struct V_vec_acum_fp32_ {}; -// template <> struct V_vec_acum_fp32_ { using Type = float; }; -// template <> struct V_vec_acum_fp32_ { using Type = float2; }; -template <> -struct V_vec_acum_fp32_ { - using Type = float4; -}; -// template <> struct V_vec_acum_fp32_ { using Type = float2; }; -// template <> struct V_vec_acum_fp32_ { using Type = Float4_; }; -template <> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; - -#ifdef ENABLE_BF16 -template <> -struct V_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template <> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template <> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif // ENABLE_BF16 - -#endif - -// clang-format on - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], - const K_vec (&k)[N], - float inv_sqrt_dh) { - K_vec inv_q = mul(q[0], inv_sqrt_dh); - K_vec qk_vec = mul(inv_q, k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - inv_q = mul(q[ii], inv_sqrt_dh); - qk_vec = fma(inv_q, k[ii], qk_vec); - } - - float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) { - float4 c; - float zero = 0.f; - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], - const uint32_t (&k)[N], - float inv_sqrt_dh) { -#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum inv_q = mul(q[0], inv_sqrt_dh); - K_vec_acum qk_vec = mul(inv_q, k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - inv_q = mul(q[ii], inv_sqrt_dh); - qk_vec = fma(inv_q, k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA 
- uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], - const K_vec (&k)[N], - float inv_sqrt_dh) { - return qk_dot_(q, k, inv_sqrt_dh); - } -}; - -template <> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], - const uint32_t (&k)[N], - float inv_sqrt_dh) { -#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 750 - return qk_hmma_dot_(q, k, inv_sqrt_dh); -#else - return qk_dot_<4>(q, k, inv_sqrt_dh); -#endif - } -}; - -template -inline __device__ float block_sum(float *red_smem, float sum) { - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - if (lane == 0) { - red_smem[warp] = sum; - } - __syncthreads(); - - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - return __shfl_sync(uint32_t(-1), sum, 0); -} - -inline __device__ void convert_from_float(float &dst, float src) { // NOLINT - dst = src; -} - -inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT - dst = src; -} - -inline __device__ void convert_from_float(phi::float16 &dst, // NOLINT - float src) { - dst = static_cast(src); -} - -inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16 &dst, // NOLINT - float src) { // NOLINT - dst = __float2bfloat16(src); -} - -inline __device__ void convert_from_float(__nv_bfloat162 &dst, // NOLINT - float2 src) { // NOLINT -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} - -inline __device__ void convert_from_float(bf16_4_t &dst, // NOLINT - Float4_ src) { // NOLINT -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_4_t &dst, // NOLINT - float4 src) { // NOLINT - convert_from_float( - dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -inline __device__ void convert_from_float(bf16_8_t &dst, // NOLINT - Float8_ src) { // NOLINT -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = 
__floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT - -template -inline __device__ void zero(T &dst) { // NOLINT - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - -template -__global__ void masked_multihead_attention_kernel( - Masked_multihead_attention_params params, - LoadFunc load_func, - StoreFunc store_func) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - const int bi = blockIdx.y; - if (params.sequence_lengths && params.sequence_lengths[bi] == 0) { - return; - } - - typedef PDDataTypeTraits traits_; - typedef typename traits_::DataType DataType_; - - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - constexpr int WARP_SIZE = 32; - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - extern __shared__ char smem_[]; - - float *qk_smem = reinterpret_cast(smem_); - - char *logits_smem_ = smem_; - // fp32 accum for logits - float *logits_smem = reinterpret_cast(logits_smem_); - - T *out_smem = reinterpret_cast(smem_); - - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - using Qk_vec = typename Qk_vec_::Type; - using Qk_vec_RoPE = typename Qk_vec_RoPE_::Type; - __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - - // beam id - const int beami = bi % params.beam_width; - // real batch id - const int bbi = bi / params.beam_width; - const int hi = blockIdx.x; - const int bhi = bi * params.num_head + hi; - const int bbhi = bbi * params.beam_width * params.num_head + hi; - const int ti = - params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1; - const int thi = params.cum_offsets ? ti * params.num_head + hi : -1; - const int tid = threadIdx.x; - - const int bi_seq_len_offset = bi * params.max_seq_length; - - float qk_max = -FLT_MAX; - float qk = 0; - - int act_time_step = params.sequence_lengths == nullptr - ? params.timestep - : params.sequence_lengths[bi]; - - // qkv [B, S=1, 3, num_head, head_dim] - int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; - - constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); - static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // Use block reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // cache_k, [B, num_head, head_dim / x, max_seq_len, x] - // x == 4/8 for FP32/FP16, 128bit, 16Byte - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - - // const T *q_base = params.qkv; - // const T *k_base = params.qkv + params.num_head * Dh; - T *q_bias_base = nullptr; - T *k_bias_base = nullptr; - - if (params.add_qkv_bias) { - q_bias_base = params.qkv_bias; - k_bias_base = params.qkv_bias + params.num_head * Dh; - } - - if (tid < QK_VECS_PER_WARP) { - int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; - int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; - - Qk_vec q; - zero(q); - // q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - // ? *reinterpret_cast(&q_base[qk_offset]) - // : q; - if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { - load_func.template load(q, qk_offset); - } - - Qk_vec k; - zero(k); - // k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - // ? 
*reinterpret_cast(&k_base[qk_offset]) - // : k; - if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { - load_func.template load(k, params.num_head * Dh + qk_offset); - } - - if (params.add_qkv_bias) { - Qk_vec q_bias; - zero(q_bias); - Qk_vec k_bias; - zero(k_bias); - - q_bias = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&q_bias_base[qk_bias_offset]) - : q_bias; - k_bias = - (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast(&k_bias_base[qk_bias_offset]) - : k_bias; - - q = add(q, q_bias); - // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 - // we may not require k_bias. - k = add(k, k_bias); - } - - if (!params.neox_rotary_style) { - if (params.rotary_emb_dims != 0) { - int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; - const float *cos_base = params.rotary_emb; - const float *sin_base = params.rotary_emb + params.batch_size * Dh; - Qk_vec_RoPE cos_emb, sin_emb; - zero(cos_emb); - zero(sin_emb); - cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &cos_base[rotary_offset]) - : cos_emb; - sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &sin_base[rotary_offset]) - : sin_emb; - apply_rotary_embedding(q, k, cos_emb, sin_emb); - } - } else { - /* old rotary pos emb */ - if (params.rotary_emb_dims != 0) { - int last_dim = Dh / params.rotary_emb_dims; - int half_lastdim = last_dim / 2; - int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; - const float *cos_base = params.rotary_emb; - const float *sin_base = params.rotary_emb + params.batch_size * Dh; - int stride = half_lastdim / QK_VEC_SIZE; - int stride_all_lastdim = 2 * stride; - int right_id = tid / stride_all_lastdim * stride_all_lastdim + - (tid + stride) % (stride_all_lastdim); - int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; - int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; - Qk_vec q_right; - zero(q_right); - // q_right = - // (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - // ? *reinterpret_cast(&q_base[qk_right_offset]) - // : q_right; - if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { - load_func.template load(q_right, qk_right_offset); - } - Qk_vec k_right; - zero(k_right); - // k_right = - // (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - // ? *reinterpret_cast(&k_base[qk_right_offset]) - // : k_right; - if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) { - load_func.template load( - k_right, params.num_head * Dh + qk_right_offset); - } - - if (params.add_qkv_bias) { - Qk_vec q_right_bias; - zero(q_right_bias); - q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &q_bias_base[qk_right_bias_offset]) - : q_right_bias; - Qk_vec k_right_bias; - zero(k_right_bias); - k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &k_bias_base[qk_right_bias_offset]) - : k_right_bias; - - q_right = add(q_right, q_right_bias); - k_right = add(k_right, k_right_bias); - } - - Qk_vec_RoPE cos_emb; - zero(cos_emb); - cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &cos_base[rotary_offset]) - : cos_emb; - - Qk_vec_RoPE sin_emb; - zero(sin_emb); - sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) - ? *reinterpret_cast( - &sin_base[rotary_offset]) - : sin_emb; - float alpha = (tid % stride_all_lastdim) < stride - ? 
static_cast(-1) - : static_cast(1); - q = apply_rotary_emb( - q, q_right, cos_emb, sin_emb, alpha); - k = apply_rotary_emb( - k, k_right, cos_emb, sin_emb, alpha); - } - } - - *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; - - int co = tid / QK_VECS_IN_16B; - int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; - int offset = bhi * params.max_seq_length * Dh + - co * params.max_seq_length * QK_ELTS_IN_16B + - act_time_step * QK_ELTS_IN_16B + ci; - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.cache_kv[offset]) = k; - } - - qk = dot(q, k); - - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = - (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - if (tid == 0) { - // NOTE(wangxi): mask must be 0.0 - // T mask = params.attn_mask[ - // bi * (params.timestep + 1) + params.timestep]; - // qk += static_cast(mask); - qk *= params.inv_sqrt_dh; - qk_max = qk; - qk_smem[act_time_step] = qk; - } - __syncthreads(); - - using K_vec = typename K_vec_::Type; - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); - static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - int ko = tid / THREADS_PER_KEY; - int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, ""); - - K_vec q[K_VECS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < K_VECS_PER_THREAD; ++i) { - q[i] = *reinterpret_cast( - &q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]); - } - - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; - T *k_cache_batch = ¶ms.cache_kv[bbhi * params.max_seq_length * Dh + ki]; - int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; - - const int *beam_offsets = params.beam_cache_offset - ? ¶ms.beam_cache_offset[bi_seq_len_offset] - : nullptr; - for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { - const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head * - params.max_seq_length * Dh - : 0; - K_vec k[K_VECS_PER_THREAD]; - K_vec k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_seq_length + ti; - if (ti < act_time_step) { - if (beam_offset) { - k[ii] = - (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) - ? *reinterpret_cast( - &k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]) - : k_vec_zero; - } else { - k[ii] = - (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) - ? *reinterpret_cast( - &k_cache[jj * QK_ELTS_IN_16B]) - : k_vec_zero; - } - } - } - - // NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k) - // may overflow with FP16 in large model. - float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); - - // bool is_mask = false; - if (ti < act_time_step && tid % THREADS_PER_KEY == 0) { - // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - auto mask_bhi = params.mask_broadcast_num_heads ? 
bi : bhi; - // T mask = params.attn_mask[mask_bhi * (params.timestep + 1) + ti]; - if (params.attn_mask) { - T mask = params.attn_mask[mask_bhi * params.mask_length + ti]; - qk += static_cast(mask); - } - qk_max = fmaxf(qk_max, qk); - - qk_smem[ti] = qk; - } - } - -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - const int warp = tid / WARP_SIZE; - const int lane = tid % WARP_SIZE; - - if (lane == 0) { - red_smem[warp] = qk_max; - } - - __syncthreads(); - - qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - float sum = 0.f; - for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { - // bool is_mask = false; - // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max); - float logit = __expf(qk_smem[ti] - qk_max); - sum += logit; - qk_smem[ti] = logit; - } - - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // FIXME(wangxi): need add 1.e-6f? - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - - for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { - convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); - } - __syncthreads(); - - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - using V_vec = typename V_vec_::Type; - - int vo = tid / THREADS_PER_VALUE; - int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; - - T *v_cache = ¶ms.cache_kv[params.cache_batch_size * params.num_head * - params.max_seq_length * Dh + - bhi * params.max_seq_length * Dh + vi]; - T *v_cache_batch = ¶ms.cache_kv[params.batch_size * params.num_head * - params.max_seq_length * Dh + - bbhi * params.max_seq_length * Dh + vi]; - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - using V_vec_acum = typename V_vec_acum_fp32_::Type; -#else - using V_vec_acum = V_vec; -#endif - - V_vec_acum out; - zero(out); - - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) { - const int beam_offset = - beam_offsets - ? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh - : 0; - V_vec v; - if (beam_offset) { - v = *reinterpret_cast( - &v_cache_batch[beam_offset + ti * Dh]); - } else { - v = *reinterpret_cast(&v_cache[ti * Dh]); - } -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti]; - out = fma(logit, cast_to_float(v), out); -#else - DataType_ logit = static_cast(logits_smem[ti]); - // Update the partial sums. 
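Taken together, the scoring loop, the max-subtracted softmax (note the 1e-6 added to the denominator before the reciprocal), and the value accumulation around this point compute one decoding step of attention for a single (batch, head) pair. The NumPy reference below sketches that math end to end; the function name, the flat [t, dim_head] cache layout, and the omission of rotary, beam, and quantization handling are simplifications assumed for the sketch.

```python
import numpy as np

def mmha_reference_step(q, k_cache, v_cache, k_new, v_new, mask, inv_sqrt_dh):
    """One masked-attention decoding step for a single (batch, head) pair.

    q:       [dim_head]      query of the current token
    k_cache: [t, dim_head]   previously cached keys (t = timestep)
    v_cache: [t, dim_head]   previously cached values
    k_new:   [dim_head]      key of the current token (appended to the cache)
    v_new:   [dim_head]      value of the current token
    mask:    [t + 1]         additive mask (0 for visible positions)
    """
    keys = np.concatenate([k_cache, k_new[None, :]], axis=0)    # [t+1, dim_head]
    values = np.concatenate([v_cache, v_new[None, :]], axis=0)  # [t+1, dim_head]

    # Scale q first (as the kernel does), then score it against every key.
    logits = (q * inv_sqrt_dh) @ keys.T + mask                  # [t+1]

    # Max-subtracted softmax; the kernel adds 1e-6 to the denominator as well.
    logits = logits - logits.max()
    probs = np.exp(logits)
    probs = probs / (probs.sum() + 1e-6)

    return probs @ values                                       # [dim_head]

# Tiny smoke test with made-up shapes.
rng = np.random.default_rng(0)
t, dh = 5, 8
out = mmha_reference_step(
    rng.standard_normal(dh),
    rng.standard_normal((t, dh)),
    rng.standard_normal((t, dh)),
    rng.standard_normal(dh),
    rng.standard_normal(dh),
    np.zeros(t + 1),
    1.0 / np.sqrt(dh),
)
print(out.shape)  # (8,)
```

The kernel produces the same result but keeps keys in the blocked [num_head, dim_head/x, max_seq_len, x] cache layout and spreads the dot products and the value reduction over THREADS_PER_KEY- and THREADS_PER_VALUE-sized thread groups.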
- out = fma(logit, v, out); -#endif - } - } - - V_vec v_bias; - zero(v_bias); - if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { - // V_vec v = *reinterpret_cast( - // ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); - V_vec v; - load_func.template load( - v, 2 * params.num_head * Dh + qkv_base_offset + vi); - if (params.add_qkv_bias) { - v_bias = *reinterpret_cast( - ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); - v = add(v, v_bias); - } - - *reinterpret_cast(&v_cache[act_time_step * Dh]) = v; - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - out = fma(logits_smem[act_time_step], cast_to_float(v), out); -#else - out = fma(logits_smem[act_time_step], v, out); -#endif - } - - __syncthreads(); - - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; - active_groups /= 2) { - int midpoint = active_groups / 2; - - if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float( - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), - out); -#else - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; -#endif - } - __syncthreads(); - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = - add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); - } - __syncthreads(); - } - } - - if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - // convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + - // vi]), - // out); - V_vec tmp_out; - convert_from_float(tmp_out, out); - store_func.template store(tmp_out, - thi != -1 ? thi * Dh + vi : bhi * Dh + vi); -#else - // *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; - store_func.template store(out, - thi != -1 ? thi * Dh + vi : bhi * Dh + vi); -#endif - } - -#else - assert(false); -#endif -} - -template -inline size_t smem_size_in_bytes( - const Masked_multihead_attention_params ¶ms, - int dim_head, - int threads_per_value, - int threads_per_block) { - size_t qk_sz = div_up(params.timestep + 1, 4) * 16; - size_t logits_sz = 0; - -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS // NOLINT - if (sizeof(T) != 4) { - logits_sz = div_up(params.max_seq_length, 4) * 4 * sizeof(T); - } -#endif // NOLINT - size_t softmax_sz = qk_sz + logits_sz; - - int rows_per_red = threads_per_block / threads_per_value; - size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2; - - return max(softmax_sz, red_sz); -} - -#define MMHA_LAUNCH_KERNEL(T, \ - Dh, \ - Dh_MAX, \ - THDS_PER_KEY, \ - THDS_PER_VALUE, \ - THDS_PER_BLOCK, \ - stream, \ - load_func, \ - store_func) \ - size_t smem_sz = \ - smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ - constexpr auto kernel_fn = \ - masked_multihead_attention_kernel; \ - if (smem_sz > 0xc000) { \ - cudaFuncSetAttribute( \ - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - } \ - dim3 grid(params.num_head, params.batch_size); \ - kernel_fn<<>>( \ - params, load_func, store_func) - -template -void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, - const cudaStream_t &stream, - LoadFunc load_func, - StoreFunc store_func) { - constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - if (params.timestep < 32) { - MMHA_LAUNCH_KERNEL( - T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream, load_func, store_func); - } else if (params.timestep < 2048) { -#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 750 - MMHA_LAUNCH_KERNEL(T, - Dh, - Dh_MAX, - 4, - THREADS_PER_VALUE, - 
256, - stream, - load_func, - store_func); -#else - MMHA_LAUNCH_KERNEL(T, - Dh, - Dh_MAX, - 2, - THREADS_PER_VALUE, - 128, - stream, - load_func, - store_func); -#endif - } else { - MMHA_LAUNCH_KERNEL(T, - Dh, - Dh_MAX, - 1, - THREADS_PER_VALUE, - 256, - stream, - load_func, - store_func); - } -} - -template -void fmha_impl(const phi::GPUContext &dev_ctx, - const Masked_multihead_attention_params ¶ms, - int dim_head, - LoadFunc load_func, - StoreFunc store_func) { - switch (dim_head) { - case 10: - fmha_launch_kernel( - params, dev_ctx.stream(), load_func, store_func); - break; - case 26: - fmha_launch_kernel( - params, dev_ctx.stream(), load_func, store_func); - break; - case 32: - fmha_launch_kernel( - params, dev_ctx.stream(), load_func, store_func); - break; - case 64: - fmha_launch_kernel( - params, dev_ctx.stream(), load_func, store_func); - break; - case 96: - fmha_launch_kernel( - params, dev_ctx.stream(), load_func, store_func); - break; - case 128: - fmha_launch_kernel( - params, dev_ctx.stream(), load_func, store_func); - break; - case 192: - fmha_launch_kernel( - params, dev_ctx.stream(), load_func, store_func); - break; - default: - PADDLE_THROW( - phi::errors::Unimplemented("Dim_head = %d is unsupport!", dim_head)); - } -} - -template -struct MMHALoad { - explicit MMHALoad(const LoadT *src) : src_(src) {} - - template - __device__ void load(Vec &dst, int idx) { - dst = *reinterpret_cast(src_ + idx); - } - - const LoadT *src_; -}; - -template -struct MMHAStore { - explicit MMHAStore(StoreT *dst) : dst_(dst) {} - - template - __device__ void store(Vec &src, int idx) { - *reinterpret_cast(dst_ + idx) = src; - } - - StoreT *dst_; -}; - -template -struct MMHAStore { - MMHAStore(T *dst, const T *shift, const T *smooth, const int cols) - : dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {} - - template - __device__ void store(Vec &src, int idx) { - constexpr int VecSize = sizeof(Vec) / sizeof(T); - using TVec = phi::AlignedVector; - TVec src_vec; - TVec shift_vec; - TVec smooth_vec; - - *reinterpret_cast(&src_vec) = src; - phi::Load(shift_ + idx % cols_, &shift_vec); - phi::Load(smooth_ + idx % cols_, &smooth_vec); - -#pragma unroll - for (int i = 0; i < VecSize; i++) { - src_vec[i] = (src_vec[i] + shift_vec[i]) * smooth_vec[i]; - } - - phi::Store(src_vec, dst_ + idx); - } - - T *dst_; - const T *shift_; - const T *smooth_; - const int cols_; -}; - -template -struct MMHALoad { - MMHALoad(const int32_t *src, const float *dequant_scales, const int cols) - : src_(src), dequant_scales_(dequant_scales), cols_(cols) {} - - template - __device__ void load(Vec &dst, int idx) { - constexpr int VecSize = sizeof(Vec) / sizeof(T); - using SrcVec = phi::AlignedVector; - using DstVec = phi::AlignedVector; - using ScaleVec = phi::AlignedVector; - - SrcVec src_vec; - DstVec dst_vec; - ScaleVec scale_vec; - - phi::Load(src_ + idx, &src_vec); - phi::Load(dequant_scales_ + idx % cols_, &scale_vec); -#pragma unroll - for (int i = 0; i < VecSize; i++) { - dst_vec[i] = - static_cast(static_cast(src_vec[i]) * scale_vec[i]); - } - dst = *reinterpret_cast(&dst_vec); - } - - const int32_t *src_; - const float *dequant_scales_; - const int cols_; -}; - -template -struct MMHAStore { - MMHAStore(int8_t *dst, - const int quant_round_type, - const float quant_scale, - const float quant_max_bound, - const float quant_min_bound) - : dst_(dst), - quant_round_type_(quant_round_type), - quant_scale_(quant_scale), - quant_max_bound_(quant_max_bound), - quant_min_bound_(quant_min_bound) {} - - template - 
__device__ void store(Vec &src, int idx) { // NOLINT - constexpr int VecSize = sizeof(Vec) / sizeof(T); - using SrcVec = phi::AlignedVector; - using DstVec = phi::AlignedVector; - - SrcVec src_vec; - *reinterpret_cast(&src_vec) = src; - - DstVec dst_vec; -#pragma unroll - for (int i = 0; i < VecSize; i++) { - dst_vec[i] = - QuantHelperFunc(static_cast(src_vec[i]), - quant_scale_, - quant_round_type_, - quant_max_bound_, - quant_min_bound_); - } - - phi::Store(dst_vec, dst_ + idx); - } - - int8_t *dst_; - const int quant_round_type_; - const float quant_scale_; - const float quant_max_bound_; - const float quant_min_bound_; -}; - -template -struct MMHAStore { - MMHAStore(int8_t *dst, - const T *shift, - const T *smooth, - const int cols, - const int quant_round_type, - const float quant_scale, - const float quant_max_bound, - const float quant_min_bound) - : dst_(dst), - quant_round_type_(quant_round_type), - quant_scale_(quant_scale), - quant_max_bound_(quant_max_bound), - quant_min_bound_(quant_min_bound), - shift_(shift), - smooth_(smooth), - cols_(cols) {} - - template - __device__ void store(Vec &src, int idx) { // NOLINT - constexpr int VecSize = sizeof(Vec) / sizeof(T); - using SrcVec = phi::AlignedVector; - using DstVec = phi::AlignedVector; - - SrcVec src_vec; - DstVec dst_vec; - SrcVec shift_vec; - SrcVec smooth_vec; - - *reinterpret_cast(&src_vec) = src; - phi::Load(shift_ + idx % cols_, &shift_vec); - phi::Load(smooth_ + idx % cols_, &smooth_vec); - -#pragma unroll - for (int i = 0; i < VecSize; i++) { - src_vec[i] = (src_vec[i] + shift_vec[i]) * smooth_vec[i]; - dst_vec[i] = - QuantHelperFunc(static_cast(src_vec[i]), - quant_scale_, - quant_round_type_, - quant_max_bound_, - quant_min_bound_); - } - - phi::Store(dst_vec, dst_ + idx); - } - - int8_t *dst_; - const T *shift_; - const T *smooth_; - const int cols_; - const int quant_round_type_; - const float quant_scale_; - const float quant_max_bound_; - const float quant_min_bound_; -}; - -template -void DispatchFMHA(const phi::GPUContext &dev_ctx, - const phi::DenseTensor &qkv_tensor, - const Masked_multihead_attention_params ¶ms, - int num_head, - int dim_head, - phi::DenseTensor *out_tensor, - const phi::DenseTensor *dequant_qkv_scales = nullptr, - const float quant_fmha_out_scale = -1, - const int quant_round_type = 1, - const float quant_max_bound = 127.0f, - const float quant_min_bound = -127.0f) { - if (dequant_qkv_scales != nullptr && quant_fmha_out_scale > 0) { - MMHALoad load_func(qkv_tensor.data(), - dequant_qkv_scales->data(), - 3 * num_head * dim_head); - MMHAStore store_func(out_tensor->data(), - quant_round_type, - quant_fmha_out_scale, - quant_max_bound, - quant_min_bound); - fmha_impl(dev_ctx, params, dim_head, load_func, store_func); - } else if (dequant_qkv_scales == nullptr && quant_fmha_out_scale > 0) { - MMHALoad load_func(qkv_tensor.data()); - MMHAStore store_func(out_tensor->data(), - quant_round_type, - quant_fmha_out_scale, - quant_max_bound, - quant_min_bound); - fmha_impl(dev_ctx, params, dim_head, load_func, store_func); - } else if (dequant_qkv_scales != nullptr && quant_fmha_out_scale <= 0) { - MMHALoad load_func(qkv_tensor.data(), - dequant_qkv_scales->data(), - 3 * num_head * dim_head); - MMHAStore store_func(out_tensor->data()); - fmha_impl(dev_ctx, params, dim_head, load_func, store_func); - } else { - MMHALoad load_func(qkv_tensor.data()); - MMHAStore store_func(out_tensor->data()); - fmha_impl(dev_ctx, params, dim_head, load_func, store_func); - } -} - -template -void DispatchFMHA(const 
phi::GPUContext &dev_ctx, - const phi::DenseTensor &qkv_tensor, - const phi::DenseTensor &shift, - const phi::DenseTensor &smooth, - const Masked_multihead_attention_params ¶ms, - int num_head, - int dim_head, - phi::DenseTensor *out_tensor, - const phi::DenseTensor *dequant_qkv_scales = nullptr, - const float quant_fmha_out_scale = -1, - const int quant_round_type = 1, - const float quant_max_bound = 127.0f, - const float quant_min_bound = -127.0f) { - if (dequant_qkv_scales != nullptr && quant_fmha_out_scale > 0) { - MMHALoad load_func(qkv_tensor.data(), - dequant_qkv_scales->data(), - 3 * num_head * dim_head); - MMHAStore store_func(out_tensor->data(), - shift.data(), - smooth.data(), - num_head * dim_head, - quant_round_type, - quant_fmha_out_scale, - quant_max_bound, - quant_min_bound); - fmha_impl(dev_ctx, params, dim_head, load_func, store_func); - } else if (dequant_qkv_scales == nullptr && quant_fmha_out_scale > 0) { - MMHALoad load_func(qkv_tensor.data()); - MMHAStore store_func(out_tensor->data(), - shift.data(), - smooth.data(), - num_head * dim_head, - quant_round_type, - quant_fmha_out_scale, - quant_max_bound, - quant_min_bound); - fmha_impl(dev_ctx, params, dim_head, load_func, store_func); - } else if (dequant_qkv_scales != nullptr && quant_fmha_out_scale <= 0) { - MMHALoad load_func(qkv_tensor.data(), - dequant_qkv_scales->data(), - 3 * num_head * dim_head); - MMHAStore store_func(out_tensor->data(), - shift.data(), - smooth.data(), - num_head * dim_head); - fmha_impl(dev_ctx, params, dim_head, load_func, store_func); - } else { - MMHALoad load_func(qkv_tensor.data()); - MMHAStore store_func(out_tensor->data(), - shift.data(), - smooth.data(), - num_head * dim_head); - fmha_impl(dev_ctx, params, dim_head, load_func, store_func); - } -} - -} // namespace fusion -} // namespace phi - -#endif // PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/funcs/mmha_util.cu.h b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h similarity index 99% rename from paddle/phi/kernels/funcs/mmha_util.cu.h rename to paddle/phi/kernels/fusion/gpu/mmha_util.cu.h index 94657b96b63e0e86c202e95bffa9b81300368397..ed311e520681f05c5d32c7eba34e7fa75d945b44 100644 --- a/paddle/phi/kernels/funcs/mmha_util.cu.h +++ b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h @@ -48,7 +48,6 @@ */ #ifndef PADDLE_WITH_HIP - #pragma once #if defined(__CUDACC__) && CUDA_VERSION >= 11000 @@ -66,8 +65,6 @@ namespace phi { namespace fusion { -namespace { // NOLINT - struct Float8_ { float2 x; float2 y; @@ -1712,8 +1709,6 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, // NOLINT } #endif // ENABLE_BF16 -} // namespace - } // namespace fusion } // namespace phi diff --git a/python/paddle/incubate/nn/functional/masked_multihead_attention.py b/python/paddle/incubate/nn/functional/masked_multihead_attention.py index 55dd20c8089115124b4548c9d4a9d8f883fa088a..93b9b1419855b72f3493916fe6bf38ce445482e3 100644 --- a/python/paddle/incubate/nn/functional/masked_multihead_attention.py +++ b/python/paddle/incubate/nn/functional/masked_multihead_attention.py @@ -19,6 +19,7 @@ from paddle.framework import LayerHelper, in_dynamic_mode def masked_multihead_attention( x, cache_kv=None, + bias=None, src_mask=None, cum_offsets=None, sequence_lengths=None, @@ -30,6 +31,7 @@ def masked_multihead_attention( seq_len=1, rotary_emb_dims=0, use_neox_rotary_style=False, + compute_dtype='default', out_scale=-1, quant_round_type=1, quant_max_bound=127.0, @@ -43,6 +45,7 @@ def masked_multihead_attention( Args: x (Tensor): The input tensor could be 2-D 
tensor. Its shape is [batch_size, 3 * num_head * head_dim]. cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim]. + bias (Tensor, optional): The bias tensor. Its shape is [3, num_head, head_dim]. src_mask (Tensor, optional): The src_mask tensor. Its shape is [batch_size, 1, 1, sequence_length]. sequence_lengths (Tensor, optional): The sequence_lengths tensor, used to index input. Its shape is [batch_size, 1]. rotary_tensor (Tensor, optional): The rotary_tensor tensor. The dtype must be float. Its shape is [batch_size, 1, 1, sequence_length, head_dim]. @@ -53,6 +56,7 @@ def masked_multihead_attention( seq_len (int, optional): The seq_len, used to get input length. Default 1. rotary_emb_dims (int, optional): The rotary_emb_dims. Default 1. use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False. + compute_dtype (string): A compute dtype, used to represent the input data type. out_scale (float, optional): The out_scale, used in quant. quant_round_type (int, optional): The quant_round_type, used in quant. Default 1. quant_max_bound (float, optional): The quant_max_bound, used in quant. Default 127.0. @@ -89,6 +93,7 @@ def masked_multihead_attention( return _C_ops.masked_multihead_attention_( x, cache_kv, + bias, src_mask, cum_offsets, sequence_lengths, @@ -100,6 +105,7 @@ def masked_multihead_attention( seq_len, rotary_emb_dims, use_neox_rotary_style, + compute_dtype, out_scale, quant_round_type, quant_max_bound, @@ -107,11 +113,22 @@ def masked_multihead_attention( ) helper = LayerHelper('masked_multihead_attention', **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) + if x.dtype == "int32": + if compute_dtype == "bf16": + dtype = "uint16" + elif compute_dtype == "fp16": + dtype = "float16" + elif compute_dtype == "fp32": + dtype = "float32" + out = helper.create_variable_for_type_inference(dtype=dtype) + else: + out = helper.create_variable_for_type_inference(dtype=x.dtype) inputs = {} inputs['x'] = x inputs['cache_kv'] = cache_kv + if bias is not None: + inputs['bias'] = bias if src_mask is not None: inputs['src_mask'] = src_mask if cum_offsets is not None: @@ -148,6 +165,7 @@ def masked_multihead_attention( 'seq_len': seq_len, 'rotary_emb_dims': rotary_emb_dims, 'use_neox_rotary_style': use_neox_rotary_style, + 'compute_dtype': compute_dtype, 'out_scale': out_scale, 'quant_round_type': quant_round_type, 'quant_max_bound': quant_max_bound, diff --git a/test/legacy_test/test_masked_multihead_attention_op.py b/test/legacy_test/test_masked_multihead_attention_op.py index 546b4e8353349be08c8428afacdfea8ceabe69cb..6a7709cc69cad79426881f444e8e99536ed7c7f1 100644 --- a/test/legacy_test/test_masked_multihead_attention_op.py +++ b/test/legacy_test/test_masked_multihead_attention_op.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
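For the legacy static-graph path changed above, the output variable's dtype can no longer be taken from `x.dtype` when the qkv input `x` is INT32 (a quantized input), so it is derived from `compute_dtype` instead; `bf16` maps to `uint16` because that is how bfloat16 variables are represented in the legacy static graph. A minimal standalone sketch of that mapping follows; the helper name and the ValueError guard for an unmapped combination are this sketch's own additions, not part of the patch.

```python
def static_out_dtype(x_dtype: str, compute_dtype: str) -> str:
    """Sketch of the output-variable dtype selection used in the static-graph branch."""
    if x_dtype != "int32":
        # Non-quantized input: the output variable simply follows x.
        return x_dtype
    mapping = {"bf16": "uint16", "fp16": "float16", "fp32": "float32"}
    if compute_dtype not in mapping:
        # Guard added for this sketch; the branch in the diff only handles these
        # three compute_dtype values when x is int32.
        raise ValueError(f"cannot infer an output dtype for compute_dtype={compute_dtype!r}")
    return mapping[compute_dtype]


assert static_out_dtype("float16", "default") == "float16"
assert static_out_dtype("int32", "bf16") == "uint16"
assert static_out_dtype("int32", "fp32") == "float32"
```

In dynamic mode the same information is simply forwarded as the `compute_dtype` argument of `_C_ops.masked_multihead_attention_`, so no variable-dtype mapping is needed there.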
- import unittest import numpy as np @@ -43,6 +42,10 @@ class TestMMHAOp(unittest.TestCase): 2, 10, size=(self.bsz, 3, self.num_head, self.dim_head) ).astype("int") + self.bias = np.random.uniform( + -0.05, 0.05, [3, self.num_head, self.dim_head] + ) + self.src_mask = np.zeros([self.bsz, 1, 1, self.sequence_length + 1]) self.cum_offsets = None @@ -77,7 +80,7 @@ class TestMMHAOp(unittest.TestCase): self.seq_len = 1 self.rotary_emb_dims = 0 self.use_neox_rotary_style = False - + self.compute_dtype = "default" self.out_scale = 10 self.quant_round_type = 1 self.quant_max_bound = 126 @@ -100,6 +103,7 @@ class TestMMHAOp(unittest.TestCase): self, x, cache_kv_out, + bias, src_mask, qkv_out_scale, seq_len, @@ -110,9 +114,9 @@ class TestMMHAOp(unittest.TestCase): bsz, ): if qkv_out_scale is not None: - x = x.cast(cache_kv_out.dtype) * qkv_out_scale + x = x.cast(cache_kv_out.dtype) * qkv_out_scale + bias else: - x = x + x = x + bias x = paddle.transpose( x, [0, 2, 1, 3] @@ -145,6 +149,7 @@ class TestMMHAOp(unittest.TestCase): x, cache_kv_out, cache_kv_mmha_out, + bias, src_mask, qkv_out_scale, out_scale, @@ -157,12 +162,14 @@ class TestMMHAOp(unittest.TestCase): else: x = paddle.to_tensor(x).cast(dtype) src_mask = paddle.to_tensor(src_mask).cast(dtype) + bias = paddle.to_tensor(bias).cast(dtype) cache_kv_out = paddle.to_tensor(cache_kv_out).cast(dtype) cache_kv_mmha_out = paddle.to_tensor(cache_kv_mmha_out).cast(dtype) paddle_naive_mmha_out = 0 paddle_naive_mmha_out = self.mmha_naive( x, cache_kv_out, + bias, src_mask, qkv_out_scale, self.seq_len, @@ -174,9 +181,14 @@ class TestMMHAOp(unittest.TestCase): ) x = x.reshape([self.bsz, -1]) + if x.dtype == paddle.float16: + dtype = self.compute_dtype + else: + dtype = "fp16" paddle_mmha_out = masked_multihead_attention( x, cache_kv_mmha_out, + bias, src_mask, None, None, @@ -188,6 +200,7 @@ class TestMMHAOp(unittest.TestCase): self.seq_len, self.rotary_emb_dims, self.use_neox_rotary_style, + dtype, out_scale, self.quant_round_type, self.quant_max_bound, @@ -204,6 +217,7 @@ class TestMMHAOp(unittest.TestCase): self.x, self.cache_kv_out, self.cache_kv_mmha_out, + self.bias, self.src_mask, None, -1, @@ -224,6 +238,7 @@ class TestMMHAOp(unittest.TestCase): self.x_int, self.cache_kv_out, self.cache_kv_mmha_out, + self.bias, self.src_mask, self.qkv_out_scale, -1, @@ -244,6 +259,7 @@ class TestMMHAOp(unittest.TestCase): self.x, self.cache_kv_out, self.cache_kv_mmha_out, + self.bias, self.src_mask, None, self.out_scale, @@ -274,6 +290,9 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): self.x = np.random.uniform( -0.05, 0.05, [self.bsz, 3, self.num_head, self.dim_head] ) + self.bias = np.random.uniform( + -0.05, 0.05, [3, self.num_head, self.dim_head] + ) self.src_mask = np.zeros([self.bsz, 1, 1, self.sequence_length + 1]) self.cum_offsets = None @@ -317,6 +336,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): self, x, cache_kv_out, + bias, src_mask, qkv_out_scale, seq_len, @@ -327,7 +347,9 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): bsz, ): if qkv_out_scale is not None: - x = x.cast(cache_kv_out.dtype) * qkv_out_scale + x = x.cast(cache_kv_out.dtype) * qkv_out_scale + bias + else: + x = x + bias x = paddle.transpose( x, [0, 2, 1, 3] @@ -351,6 +373,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): def check_main( self, x, + bias, src_mask, cache_kv_out, cache_kv_mmha_out, @@ -361,11 +384,13 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): paddle.disable_static() x_tensor = paddle.to_tensor(x).cast(dtype) src_mask_tensor = 
paddle.to_tensor(src_mask).cast(dtype) + bias_tensor = paddle.to_tensor(bias).cast(dtype) cache_kv_out = paddle.to_tensor(cache_kv_out).cast(dtype) paddle_naive_mmha_out = self.mmha_naive( x_tensor, cache_kv_out, + bias_tensor, src_mask_tensor, None, self.seq_len, @@ -383,6 +408,11 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): shape=[self.bsz, 3 * self.num_head * self.dim_head], dtype=dtype, ) + bias_static = paddle.static.data( + name="bias_static", + shape=[3, self.num_head, self.dim_head], + dtype=dtype, + ) src_mask_static = paddle.static.data( name="src_mask_static", shape=[self.bsz, 1, 1, self.sequence_length + 1], @@ -403,6 +433,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): outs = masked_multihead_attention( x_static, cache_kv_mmha_out_static, + bias_static, src_mask_static, None, None, @@ -414,6 +445,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): 32, 0, False, + "fp16", -1, 1, 127.0, @@ -424,6 +456,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): feed={ "x_static": x.reshape(self.bsz, -1).astype(dtype), "cache_kv_mmha_out_static": cache_kv_mmha_out.astype(dtype), + "bias_static": bias.astype(dtype), "src_mask_static": src_mask.astype(dtype), }, fetch_list=[outs], @@ -437,6 +470,7 @@ class TestLayerNormStaticInt8Op(unittest.TestCase): paddle_naive_mmha_out, paddle_mmha_out = self.check_main( self.x, + self.bias, self.src_mask, self.cache_kv_out, self.cache_kv_mmha_out,