Unverified commit 989c5e87, authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Add masked multihead attention kernel and export API. (#55344)

* support_mmha
* add_python_api
* add_api_doc
* fix_doc_error
* fix_infermeta
* add_infermeta
* add_bf16_cuda_check
* add_bf16_check
* fix_ci_windows
* fix_ci_windows_kernel_register
* fix_test_mmha
* add_cumoffsets
* remove_bias
* delete_mmha_reshape_input_output
* rename_delete_hfile
* remove_fluid

---------
Co-authored-by: yangjianfengo1 <yangjianfeng01@baidu.com>
Parent 2a5d1d54
@@ -931,9 +931,8 @@ __global__ void masked_multihead_attention_kernel(
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
   if (bi == 0 && hi == 0 && tid == 0) {
-    printf("=======q_out=======\n");
-    for (int i = 0; i < Dh; ++i) printf("%f ", static_cast<float>(q_smem[i]));
-    printf("\n");
+    VLOG(0) << "=======q_out=======\n";
+    for (int i = 0; i < Dh; ++i) VLOG(0) << static_cast<float>(q_smem[i]);
   }
   __syncthreads();
 #endif
......
@@ -1616,6 +1616,17 @@
    data_type : logits
  backward : margin_cross_entropy_grad

- op : masked_multihead_attention_
  args : (Tensor x, Tensor cache_kv, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
  output : Tensor(out), Tensor(cache_kv_out), Tensor(beam_cache_offset_out)
  infer_meta :
    func : MaskedMultiheadAttentionInferMeta
  kernel :
    func : masked_multihead_attention
    data_type : cache_kv
  optional : src_mask, cum_offsets, sequence_lengths, rotary_tensor, beam_cache_offset, qkv_out_scale, out_shift, out_smooth
  inplace : (cache_kv -> cache_kv_out), (beam_cache_offset -> beam_cache_offset_out)

- op : masked_select
  args : (Tensor x, Tensor mask)
  output : Tensor (out)
......
@@ -3983,5 +3983,69 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
  out->set_dtype(x.dtype());
}
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& src_mask,
const MetaTensor& cum_offsets,
const MetaTensor& sequence_lengths,
const MetaTensor& rotary_tensor,
const MetaTensor& beam_cache_offset,
const MetaTensor& qkv_out_scale,
const MetaTensor& out_shift,
const MetaTensor& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* cache_kv_out,
MetaTensor* beam_cache_offset_out) {
int bsz = x.dims()[0];
auto x_dtype = x.dtype();
auto cache_kv_dims = cache_kv.dims();
int num_head = cache_kv.dims()[2];
int dim_head = cache_kv.dims()[4];
PADDLE_ENFORCE_EQ(
cache_kv_dims.size(),
5,
errors::InvalidArgument("The cache_kv must be 5 dims, but got %d",
cache_kv_dims.size()));
PADDLE_ENFORCE_EQ(
cache_kv_dims[0],
2,
errors::InvalidArgument("The first dim of cache_kv must be 2, but got %d",
cache_kv_dims[0]));
if (rotary_tensor) {
PADDLE_ENFORCE_EQ(
rotary_tensor.dtype(),
DataType::FLOAT32,
errors::InvalidArgument(
"The dtype of rotary_tensor must be float32, but got %d",
rotary_tensor.dtype()));
}
out->set_dims({bsz, num_head * dim_head});
if (out_scale > 0) {
out->set_dtype(DataType::INT8);
} else {
out->set_dtype(x_dtype);
}
cache_kv_out->set_dims(cache_kv_dims);
cache_kv_out->set_dtype(cache_kv.dtype());
if (beam_cache_offset) {
beam_cache_offset_out->set_dims(beam_cache_offset.dims());
beam_cache_offset_out->set_dtype(beam_cache_offset.dtype());
}
}
}  // namespace phi

PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
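The shape contract encoded by MaskedMultiheadAttentionInferMeta can be summarized with a minimal standalone sketch (illustrative sizes only, not Paddle code): cache_kv must be a 5-D tensor [2, batch, num_head, max_seq_len, dim_head] whose leading dim is 2, and the attention output takes shape [bsz, num_head * dim_head], switching to INT8 only when out_scale > 0.

// Standalone illustration of the shape rules in MaskedMultiheadAttentionInferMeta
// (hypothetical sizes; plain C++11, independent of Paddle).
#include <array>
#include <cassert>
#include <cstdio>

int main() {
  // cache_kv: [2, batch, num_head, max_seq_len, dim_head]
  const std::array<int, 5> cache_kv_dims = {2, 8, 32, 1024, 128};
  assert(cache_kv_dims[0] == 2);          // mirrors the PADDLE_ENFORCE_EQ check
  const int bsz = 8;                      // x.dims()[0]
  const int num_head = cache_kv_dims[2];
  const int dim_head = cache_kv_dims[4];
  const float out_scale = -1.0f;          // > 0 would switch the output dtype to INT8
  // out->set_dims({bsz, num_head * dim_head}); cache_kv_out keeps cache_kv's dims.
  std::printf("out: [%d, %d], dtype: %s\n",
              bsz, num_head * dim_head, out_scale > 0 ? "int8" : "same as x");
  return 0;
}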
@@ -773,4 +773,25 @@ void FusedRopeInferMeta(const MetaTensor& q,
                        MetaTensor* out_k,
                        MetaTensor* out_v);
void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
const MetaTensor& cache_kv,
const MetaTensor& src_mask,
const MetaTensor& cum_offsets,
const MetaTensor& sequence_lengths,
const MetaTensor& rotary_tensor,
const MetaTensor& beam_cache_offset,
const MetaTensor& qkv_out_scale,
const MetaTensor& out_shift,
const MetaTensor& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
MetaTensor* out,
MetaTensor* cache_kv_out,
MetaTensor* beam_cache_offset_out);
}  // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor used by mmha kernel.
*/
#ifndef PADDLE_WITH_HIP
#pragma once
#if defined(__CUDACC__) && CUDA_VERSION >= 11000
#define ENABLE_BF16
#include <cuda_bf16.h>
#endif
#include <cuda_fp16.h>
#include <float.h>
#include <cub/cub.cuh>
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace fusion {
namespace { // NOLINT
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
struct Float4_ {
float2 x;
float2 y;
};
#ifdef ENABLE_BF16
struct bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
struct bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
#endif
template <typename T, int Dh>
struct Qk_vec_ {};
template <>
struct Qk_vec_<float, 32> {
using Type = float;
};
template <>
struct Qk_vec_<float, 64> {
using Type = float2;
};
template <>
struct Qk_vec_<float, 128> {
using Type = float4;
};
template <>
struct Qk_vec_<float, 256> {
using Type = float4;
};
template <>
struct Qk_vec_<float16, 32> {
using Type = uint32_t;
};
template <>
struct Qk_vec_<float16, 64> {
using Type = uint32_t;
};
template <>
struct Qk_vec_<float16, 128> {
using Type = uint2;
};
template <>
struct Qk_vec_<float16, 256> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template <>
struct Qk_vec_<bfloat16, 32> {
using Type = __nv_bfloat162;
};
template <>
struct Qk_vec_<bfloat16, 64> {
using Type = __nv_bfloat162;
};
template <>
struct Qk_vec_<bfloat16, 128> {
using Type = bf16_4_t;
};
template <>
struct Qk_vec_<bfloat16, 256> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
// RoPE Type
template <typename T1, typename T2, int Dh>
struct Qk_vec_RoPE_ {};
template <>
struct Qk_vec_RoPE_<float16, float, 32> {
using Type = float2;
};
template <>
struct Qk_vec_RoPE_<float16, float, 64> {
using Type = float2;
};
template <>
struct Qk_vec_RoPE_<float16, float, 128> {
using Type = float4;
};
template <>
struct Qk_vec_RoPE_<float16, float, 256> {
using Type = Float8_;
};
template <>
struct Qk_vec_RoPE_<float, float, 32> {
using Type = float;
};
template <>
struct Qk_vec_RoPE_<float, float, 64> {
using Type = float2;
};
template <>
struct Qk_vec_RoPE_<float, float, 128> {
using Type = float4;
};
template <>
struct Qk_vec_RoPE_<float, float, 256> {
using Type = float4;
};
#ifdef ENABLE_BF16
template <>
struct Qk_vec_RoPE_<bfloat16, float, 32> {
using Type = float2;
};
template <>
struct Qk_vec_RoPE_<bfloat16, float, 64> {
using Type = float2;
};
template <>
struct Qk_vec_RoPE_<bfloat16, float, 128> {
using Type = float4;
};
template <>
struct Qk_vec_RoPE_<bfloat16, float, 256> {
using Type = Float8_;
};
#endif
//------------------------------------
template <typename T, int THREADS_PER_KEY>
struct K_vec_ {};
template <>
struct K_vec_<float, 4> {
using Type = float;
};
template <>
struct K_vec_<float, 2> {
using Type = float2;
};
template <>
struct K_vec_<float, 1> {
using Type = float4;
};
template <>
struct K_vec_<float16, 4> {
using Type = uint32_t;
};
template <>
struct K_vec_<float16, 2> {
using Type = uint2;
};
template <>
struct K_vec_<float16, 1> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template <>
struct K_vec_<bfloat16, 4> {
using Type = __nv_bfloat162;
};
template <>
struct K_vec_<bfloat16, 2> {
using Type = bf16_4_t;
};
template <>
struct K_vec_<bfloat16, 1> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
template <typename T, int V_VEC_SIZE>
struct V_vec_ {};
template <>
struct V_vec_<float, 1> {
using Type = float;
};
template <>
struct V_vec_<float, 2> {
using Type = float2;
};
template <>
struct V_vec_<float, 4> {
using Type = float4;
};
template <>
struct V_vec_<float16, 2> {
using Type = uint32_t;
};
template <>
struct V_vec_<float16, 4> {
using Type = uint2;
};
template <>
struct V_vec_<float16, 8> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template <>
struct V_vec_<bfloat16, 2> {
using Type = __nv_bfloat162;
};
template <>
struct V_vec_<bfloat16, 4> {
using Type = bf16_4_t;
};
template <>
struct V_vec_<bfloat16, 8> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x,
const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
return __hmul2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x,
const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
return __hmul(x, y);
#endif
}
#endif // ENABLE_BF16
inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
inline __device__ float2 half2_to_float2(uint32_t v) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
}
inline __device__ uint32_t float2_to_half2(float2 f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
}
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x,
const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
return __hadd2(x, y);
#endif
}
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
__nv_bfloat162 val_;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
val_ = __float22bfloat162_rn(val);
#else
val_.x = __float2bfloat16_rn(val.x);
val_.y = __float2bfloat16_rn(val.y);
#endif
return val_;
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x,
const __nv_bfloat162 y,
const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
fzl = __low2float(z);
fzh = __high2float(z);
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
return __hfma2(x, y, z);
#endif
}
#endif // ENABLE_BF16
inline __device__ float add(float a, float b) { return a + b; }
inline __device__ float2 add(float2 a, float2 b) {
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ float4 add(float4 a, float4 b) {
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
inline __device__ uint2 add(uint2 a, uint2 b) {
uint2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ uint4 add(uint4 a, uint4 b) {
uint4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(uint32_t a, float2 fb) {
float2 fa = half2_to_float2(a);
return add(fa, fb);
}
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
return a + b;
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
return bf16hadd2(a, b);
}
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float add(float a, __nv_bfloat16 b) {
return a + __bfloat162float(b);
}
inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
float2 fa = bf1622float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
#endif // ENABLE_BF16
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);
template <>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
template <>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
template <>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
template <>
inline __device__ uint32_t mul(uint32_t a, float b) {
float2 tmp = half2_to_float2(a);
float2 tmp_res;
tmp_res.x = tmp.x * b;
tmp_res.y = tmp.y * b;
uint32_t res = float2_to_half2(tmp_res);
return res;
}
template <>
inline __device__ uint32_t mul(uint32_t a, float2 b) {
float2 tmp = half2_to_float2(a);
float2 tmp_res;
tmp_res.x = tmp.x * b.x;
tmp_res.y = tmp.y * b.y;
uint32_t res = float2_to_half2(tmp_res);
return res;
}
template <>
inline __device__ float2 mul(uint32_t a, float b) {
float2 tmp = half2_to_float2(a);
float2 res;
res.x = tmp.x * b;
res.y = tmp.y * b;
return res;
}
template <>
inline __device__ uint2 mul(uint2 a, float b) {
uint2 res;
res.x = mul<uint32_t, uint32_t, float>(a.x, b);
res.y = mul<uint32_t, uint32_t, float>(a.y, b);
return res;
}
template <>
inline __device__ uint2 mul(uint2 a, float4 b) {
Float4_& b_ = *reinterpret_cast<Float4_*>(&b);
uint2 res;
res.x = mul<uint32_t, uint32_t, float2>(a.x, b_.x);
res.y = mul<uint32_t, uint32_t, float2>(a.y, b_.y);
return res;
}
template <>
inline __device__ uint4 mul(uint4 a, float b) {
uint4 res;
res.x = mul<uint32_t, uint32_t, float>(a.x, b);
res.y = mul<uint32_t, uint32_t, float>(a.y, b);
res.z = mul<uint32_t, uint32_t, float>(a.z, b);
res.w = mul<uint32_t, uint32_t, float>(a.w, b);
return res;
}
template <>
inline __device__ uint4 mul(uint4 a, Float8_ b) {
uint4 res;
res.x = mul<uint32_t, uint32_t, float2>(a.x, b.x);
res.y = mul<uint32_t, uint32_t, float2>(a.y, b.y);
res.z = mul<uint32_t, uint32_t, float2>(a.z, b.z);
res.w = mul<uint32_t, uint32_t, float2>(a.w, b.w);
return res;
}
template <>
inline __device__ float2 mul(float2 a, float b) {
float2 res;
res.x = a.x * b;
res.y = a.y * b;
return res;
}
template <>
inline __device__ float2 mul(float2 a, uint32_t b) {
float2 tmp_b = half2_to_float2(b);
float2 res;
res.x = a.x * tmp_b.x;
res.y = a.y * tmp_b.y;
return res;
}
template <>
inline __device__ float4 mul(float4 a, float b) {
float4 res;
res.x = a.x * b;
res.y = a.y * b;
res.z = a.z * b;
res.w = a.w * b;
return res;
}
#ifdef ENABLE_BF16
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmul(a, b);
#else
return bf16hmul(a, b);
#endif
}
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
return bf16hmul2(a, b);
}
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = static_cast<float>(a);
float fb = static_cast<float>(b);
return fa * fb;
}
template <>
inline __device__ float mul(__nv_bfloat16 a, float b) {
return __bfloat162float(a) * b;
}
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, float b) {
__nv_bfloat162 res;
__nv_bfloat162 _bf16 = __float2bfloat162_rn(b);
res = bf16hmul2(a, _bf16);
return res;
}
template <>
inline __device__ __nv_bfloat162 mul(float2 a, float2 b) {
float2 res = mul<float2, float2, float2>(a, b);
__nv_bfloat162 bf16_res = float22bf162(res);
return bf16_res;
}
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, float2 b) {
float2 a_ = bf1622float2(a);
float2 res = mul<float2, float2, float2>(a_, b);
__nv_bfloat162 bf16_res = float22bf162(res);
return bf16_res;
}
template <>
inline __device__ bf16_4_t mul(bf16_4_t a, float b) {
__nv_bfloat162 s = __float2bfloat162_rn(b);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, s);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, s);
return c;
}
template <>
inline __device__ bf16_4_t mul(bf16_4_t a, float4 b) {
Float4_& b_ = *reinterpret_cast<Float4_*>(&b);
float2 a1 = bf1622float2(a.x);
float2 a2 = bf1622float2(a.y);
bf16_4_t c;
c.x = mul<__nv_bfloat162, float2, float2>(a1, b_.x);
c.y = mul<__nv_bfloat162, float2, float2>(a2, b_.y);
return c;
}
template <>
inline __device__ bf16_8_t mul(bf16_8_t a, float b) {
__nv_bfloat162 s = __float2bfloat162_rn(b);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, s);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, s);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, s);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, s);
return c;
}
template <>
inline __device__ bf16_8_t mul(bf16_8_t a, Float8_ b) {
float2 a1 = bf1622float2(a.x);
float2 a2 = bf1622float2(a.y);
float2 a3 = bf1622float2(a.z);
float2 a4 = bf1622float2(a.w);
bf16_8_t c;
c.x = mul<__nv_bfloat162, float2, float2>(a1, b.x);
c.y = mul<__nv_bfloat162, float2, float2>(a2, b.y);
c.z = mul<__nv_bfloat162, float2, float2>(a3, b.z);
c.w = mul<__nv_bfloat162, float2, float2>(a4, b.w);
return c;
}
template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
#endif // ENABLE_BF16
template <typename Qk_vec, typename Qk_vec_RoPE>
inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left,
Qk_vec input_right,
Qk_vec_RoPE cos_emb,
Qk_vec_RoPE sin_emb,
float alpha) {
Qk_vec res1 = mul<Qk_vec, Qk_vec, Qk_vec_RoPE>(input_left, cos_emb);
Qk_vec res2 = mul<Qk_vec, Qk_vec, Qk_vec_RoPE>(input_right, sin_emb);
res2 = mul<Qk_vec, Qk_vec, float>(res2, alpha);
return add(res1, res2);
}
inline __device__ float sum(float v) { return v; }
inline __device__ float sum(float2 v) { return v.x + v.y; }
inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }
inline __device__ float sum(uint16_t v) { return half_to_float(v); }
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}
inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
return sum(c);
}
#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); }
inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
#endif // ENABLE_BF16
template <typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}
template <typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}
inline __device__ constexpr uint32_t shfl_mask(int threads) {
return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
}
template <typename T>
inline __device__ __host__ T div_up(T m, T n) {
return (m + n - 1) / n;
}
inline __device__ float fma(float a, float b, float c) { return a * b + c; }
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ float2 fma(float2 a, uint32_t b, float2 c) {
float2 tmp_b = half2_to_float2(b);
float2 d = fma(a, tmp_b, c);
return d;
}
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
return d;
}
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
uint2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
uint4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ float2 fma(float a, float2 b, float2 c) {
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ float4 fma(float a, float4 b, float4 c) {
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
Float4_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 fma(float a, float2 b, __nv_bfloat162 c) {
return bf16hfma2(__float2bfloat162_rn(a), float22bf162(b), c);
}
inline __device__ bf16_4_t fma(float a, Float4_ b, bf16_4_t c) {
bf16_4_t d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
#endif // ENABLE_BF16
inline __device__ uint32_t h0_h0(uint16_t a) {
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
}
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
return fma(h0_h0(a), b, c);
}
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
uint32_t s = h0_h0(a);
uint2 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
uint32_t s = h0_h0(a);
uint4 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a,
__nv_bfloat162 b,
__nv_bfloat162 c) {
return bf16hfma2(a, b, c);
}
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a,
__nv_bfloat162 b,
__nv_bfloat162 c) {
return bf16hfma2(bf162bf162(a), b, c);
}
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
return fma(bf162bf162(a), b, fc);
}
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
#endif // ENABLE_BF16
inline __device__ float cast_to_float(float u) { return u; }
inline __device__ float2 cast_to_float(float2 u) { return u; }
inline __device__ float4 cast_to_float(float4 u) { return u; }
inline __device__ float2 cast_to_float(uint32_t u) {
return half2_to_float2(u);
}
inline __device__ Float4_ cast_to_float(uint2 u) {
Float4_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
return tmp;
}
inline __device__ Float8_ cast_to_float(uint4 u) {
Float8_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
tmp.z = half2_to_float2(u.z);
tmp.w = half2_to_float2(u.w);
return tmp;
}
inline __device__ Float4_ cast_to_float(Float4_ u) { return u; }
inline __device__ Float8_ cast_to_float(Float8_ u) { return u; }
#ifdef ENABLE_BF16
inline __device__ float cast_to_float(__nv_bfloat16 u) {
return __bfloat162float(u);
}
inline __device__ float2 cast_to_float(__nv_bfloat162 u) {
return bf1622float2(u);
}
inline __device__ Float4_ cast_to_float(bf16_4_t u) {
Float4_ tmp;
tmp.x = bf1622float2(u.x);
tmp.y = bf1622float2(u.y);
return tmp;
}
inline __device__ Float8_ cast_to_float(bf16_8_t u) {
Float8_ tmp;
tmp.x = bf1622float2(u.x);
tmp.y = bf1622float2(u.y);
tmp.z = bf1622float2(u.z);
tmp.w = bf1622float2(u.w);
return tmp;
}
#endif // ENABLE_BF16
inline __device__ float2 rotary_embedding_coefficient(const int zid,
const int rot_embed_dim,
const float t_step) {
const float inv_freq =
t_step / pow(10000.0f, zid / static_cast<float>(rot_embed_dim));
return {cos(inv_freq), sin(inv_freq)};
}
inline __device__ float2 rotary_embedding_transform(const float2 v,
const float2 coef) {
float2 rot_v;
rot_v.x = coef.x * v.x - coef.y * v.y;
rot_v.y = coef.x * v.y + coef.y * v.x;
return rot_v;
}
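// Note: rotary_embedding_coefficient() above returns (cos(theta), sin(theta))
// with theta = t_step / 10000^(zid / rot_embed_dim), and this transform treats
// (v.x, v.y) as a complex number multiplied by that unit coefficient:
//   (cos + i*sin) * (v.x + i*v.y) = (cos*v.x - sin*v.y) + i*(cos*v.y + sin*v.x)
// i.e. each (even, odd) channel pair is rotated by theta in the plane.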
inline __device__ float2 rotary_embedding_transform(const float2 v,
const float2 cos,
const float2 sin) {
float2 rot_v;
rot_v.x = v.x * cos.x - v.y * sin.x;
rot_v.y = v.y * cos.y + v.x * sin.y;
return rot_v;
}
inline __device__ uint32_t rotary_embedding_transform(const uint32_t v,
const float2 coef) {
float2 fv = half2_to_float2(v);
float2 rot_fv = rotary_embedding_transform(fv, coef);
return float2_to_half2(rot_fv);
}
inline __device__ uint32_t rotary_embedding_transform(const uint32_t v,
const uint32_t cos,
const uint32_t sin) {
float2 fv = half2_to_float2(v);
float2 fcos = half2_to_float2(cos);
float2 fsin = half2_to_float2(sin);
float2 rot_fv = rotary_embedding_transform(fv, fcos, fsin);
return float2_to_half2(rot_fv);
}
inline __device__ uint32_t rotary_embedding_transform(const uint32_t v,
const float2 cos,
const float2 sin) {
float2 fv = half2_to_float2(v);
float2 rot_fv = rotary_embedding_transform(fv, cos, sin);
return float2_to_half2(rot_fv);
}
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162
rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) {
float2 fv = bf1622float2(v);
float2 rot_fv = rotary_embedding_transform(fv, coef);
return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
inline __device__ __nv_bfloat162
rotary_embedding_transform(const __nv_bfloat162 v,
const __nv_bfloat162 cos,
const __nv_bfloat162 sin) {
float2 fv = bf1622float2(v);
float2 fcos = bf1622float2(cos);
float2 fsin = bf1622float2(sin);
float2 rot_fv = rotary_embedding_transform(fv, fcos, fsin);
return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
inline __device__ __nv_bfloat162 rotary_embedding_transform(
const __nv_bfloat162 v, const float2 cos, const float2 sin) {
float2 fv = bf1622float2(v);
float2 rot_fv = rotary_embedding_transform(fv, cos, sin);
return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif
inline __device__ void apply_rotary_embedding(float& q, // NOLINT
float& k, // NOLINT
float& cos, // NOLINT
float& sin) { // NOLINT
return;
}
inline __device__ void apply_rotary_embedding(float2& q, // NOLINT
float2& k, // NOLINT
float2& cos, // NOLINT
float2& sin) { // NOLINT
q = rotary_embedding_transform(q, cos, sin);
k = rotary_embedding_transform(k, cos, sin);
}
inline __device__ void apply_rotary_embedding(float4& q, // NOLINT
float4& k, // NOLINT
float4& cos, // NOLINT
float4& sin) { // NOLINT
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
Float4_& cos_ = *reinterpret_cast<Float4_*>(&cos);
Float4_& sin_ = *reinterpret_cast<Float4_*>(&sin);
q_.x = rotary_embedding_transform(q_.x, cos_.x, sin_.x);
k_.x = rotary_embedding_transform(k_.x, cos_.x, sin_.x);
q_.y = rotary_embedding_transform(q_.y, cos_.y, sin_.y);
k_.y = rotary_embedding_transform(k_.y, cos_.y, sin_.y);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, // NOLINT
uint32_t& k, // NOLINT
uint32_t& cos, // NOLINT
uint32_t& sin) { // NOLINT
q = rotary_embedding_transform(q, cos, sin);
k = rotary_embedding_transform(k, cos, sin);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, // NOLINT
uint32_t& k, // NOLINT
float2& cos, // NOLINT
float2& sin) { // NOLINT
q = rotary_embedding_transform(q, cos, sin);
k = rotary_embedding_transform(k, cos, sin);
}
inline __device__ void apply_rotary_embedding(uint2& q, // NOLINT
uint2& k, // NOLINT
uint2& cos, // NOLINT
uint2& sin) { // NOLINT
q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
  k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
}
inline __device__ void apply_rotary_embedding(uint2& q, // NOLINT
uint2& k, // NOLINT
float4& cos, // NOLINT
float4& sin) { // NOLINT
Float4_& cos_ = *reinterpret_cast<Float4_*>(&cos);
Float4_& sin_ = *reinterpret_cast<Float4_*>(&sin);
q.x = rotary_embedding_transform(q.x, cos_.x, sin_.x);
k.x = rotary_embedding_transform(k.x, cos_.x, sin_.x);
q.y = rotary_embedding_transform(q.y, cos_.y, sin_.y);
  k.y = rotary_embedding_transform(k.y, cos_.y, sin_.y);
}
inline __device__ void apply_rotary_embedding(uint4& q, // NOLINT
uint4& k, // NOLINT
uint4& cos, // NOLINT
uint4& sin) { // NOLINT
q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
q.z = rotary_embedding_transform(q.z, cos.z, sin.z);
k.z = rotary_embedding_transform(k.z, cos.z, sin.z);
q.w = rotary_embedding_transform(q.w, cos.w, sin.w);
k.w = rotary_embedding_transform(k.w, cos.w, sin.w);
}
inline __device__ void apply_rotary_embedding(uint4& q, // NOLINT
uint4& k, // NOLINT
Float8_& cos, // NOLINT
Float8_& sin) { // NOLINT
q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
q.z = rotary_embedding_transform(q.z, cos.z, sin.z);
k.z = rotary_embedding_transform(k.z, cos.z, sin.z);
q.w = rotary_embedding_transform(q.w, cos.w, sin.w);
k.w = rotary_embedding_transform(k.w, cos.w, sin.w);
}
inline __device__ void apply_rotary_embedding(float& q, // NOLINT
int zid,
int rot_embed_dim,
int t_step) {
return;
}
inline __device__ void apply_rotary_embedding(
float& q, float& k, int zid, int rot_embed_dim, int t_step) { // NOLINT
return;
}
inline __device__ void apply_rotary_embedding(float2& q, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef =
rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(
float2& q, float2& k, int tid, int rot_embed_dim, int t_step) { // NOLINT
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef =
rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(float4& q, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (4 * tid >= rot_embed_dim) {
return;
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
const auto coef0 =
rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q_.x = rotary_embedding_transform(q_.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q_.y = rotary_embedding_transform(q_.y, coef1);
}
inline __device__ void apply_rotary_embedding(
float4& q, float4& k, int tid, int rot_embed_dim, int t_step) { // NOLINT
if (4 * tid >= rot_embed_dim) {
return;
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
const auto coef0 =
rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q_.x = rotary_embedding_transform(q_.x, coef0);
k_.x = rotary_embedding_transform(k_.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q_.y = rotary_embedding_transform(q_.y, coef1);
k_.y = rotary_embedding_transform(k_.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef =
rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, // NOLINT
uint32_t& k, // NOLINT
int tid,
int rot_embed_dim,
int t_step) { // NOLINT
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef =
rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(uint2& q, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 =
rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
}
inline __device__ void apply_rotary_embedding(
uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step) { // NOLINT
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 =
rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint4& q, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 =
rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 =
rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 =
rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
}
inline __device__ void apply_rotary_embedding(
uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step) { // NOLINT
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 =
rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 =
rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 =
rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, // NOLINT
__nv_bfloat162& k, // NOLINT
__nv_bfloat162& cos, // NOLINT
__nv_bfloat162& sin) { // NOLINT
q = rotary_embedding_transform(q, cos, sin);
k = rotary_embedding_transform(k, cos, sin);
}
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, // NOLINT
__nv_bfloat162& k, // NOLINT
float2& cos, // NOLINT
float2& sin) { // NOLINT
q = rotary_embedding_transform(q, cos, sin);
k = rotary_embedding_transform(k, cos, sin);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, // NOLINT
bf16_4_t& k, // NOLINT
bf16_4_t& cos, // NOLINT
bf16_4_t& sin) { // NOLINT
q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, // NOLINT
bf16_4_t& k, // NOLINT
float4& cos, // NOLINT
float4& sin) { // NOLINT
Float4_& cos_ = *reinterpret_cast<Float4_*>(&cos);
Float4_& sin_ = *reinterpret_cast<Float4_*>(&sin);
q.x = rotary_embedding_transform(q.x, cos_.x, sin_.x);
k.x = rotary_embedding_transform(k.x, cos_.x, sin_.x);
q.y = rotary_embedding_transform(q.y, cos_.y, sin_.y);
k.y = rotary_embedding_transform(k.y, cos_.y, sin_.y);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, // NOLINT
bf16_8_t& k, // NOLINT
bf16_8_t& cos, // NOLINT
bf16_8_t& sin) { // NOLINT
q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
q.z = rotary_embedding_transform(q.z, cos.z, sin.z);
k.z = rotary_embedding_transform(k.z, cos.z, sin.z);
q.w = rotary_embedding_transform(q.w, cos.w, sin.w);
k.w = rotary_embedding_transform(k.w, cos.w, sin.w);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, // NOLINT
bf16_8_t& k, // NOLINT
Float8_& cos, // NOLINT
Float8_& sin) { // NOLINT
q.x = rotary_embedding_transform(q.x, cos.x, sin.x);
k.x = rotary_embedding_transform(k.x, cos.x, sin.x);
q.y = rotary_embedding_transform(q.y, cos.y, sin.y);
k.y = rotary_embedding_transform(k.y, cos.y, sin.y);
q.z = rotary_embedding_transform(q.z, cos.z, sin.z);
k.z = rotary_embedding_transform(k.z, cos.z, sin.z);
q.w = rotary_embedding_transform(q.w, cos.w, sin.w);
k.w = rotary_embedding_transform(k.w, cos.w, sin.w);
}
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef =
rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, // NOLINT
__nv_bfloat162& k, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef =
rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 =
rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, // NOLINT
bf16_4_t& k, // NOLINT
int tid,
int rot_embed_dim,
int t_step) { // NOLINT
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 =
rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, // NOLINT
int tid,
int rot_embed_dim,
int t_step) {
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 =
rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 =
rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 =
rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, // NOLINT
bf16_8_t& k, // NOLINT
int tid,
int rot_embed_dim,
int t_step) { // NOLINT
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 =
rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 =
rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 =
rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 =
rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#endif // ENABLE_BF16
} // namespace
} // namespace fusion
} // namespace phi
#endif // PADDLE_WITH_HIP
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/gpu/masked_multihead_attention.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void MMHAKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& cache_kv,
const paddle::optional<DenseTensor>& src_mask,
const paddle::optional<DenseTensor>& cum_offsets,
const paddle::optional<DenseTensor>& sequence_lengths,
const paddle::optional<DenseTensor>& rotary_tensor,
const paddle::optional<DenseTensor>& beam_cache_offset,
const paddle::optional<DenseTensor>& qkv_out_scale,
const paddle::optional<DenseTensor>& out_shift,
const paddle::optional<DenseTensor>& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* cache_kv_out,
DenseTensor* beam_cache_offset_out) {
#ifndef PADDLE_WITH_HIP
const auto& x_dims = x.dims();
int bsz = x_dims[0];
int cache_bsz = cache_kv.dims()[1];
int num_head = cache_kv.dims()[2];
int max_seq_len = cache_kv.dims()[3];
int dim_head = cache_kv.dims()[4];
int timestep = max_seq_len;
float inv_sqrt_dh = 1. / sqrt(dim_head);
Masked_multihead_attention_params<T> params;
bool mask_broadcast_num_heads = true;
if (src_mask) {
if (src_mask->dims()[1] == 1) {
mask_broadcast_num_heads = true;
} else if (src_mask->dims()[1] == num_head) {
mask_broadcast_num_heads = false;
} else {
PADDLE_THROW(errors::InvalidArgument(
"Unknow dimension for attn_mask, the num_head(2nd) "
"dimension is invalid, it should be 1 or num_head(%d), "
"but got %d",
num_head,
src_mask->dims()[1]));
}
params.attn_mask = src_mask->data<T>();
params.mask_length = src_mask->dims()[3];
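// attn_mask is laid out as [bsz, 1 or num_head, 1, cache_seq_length + 1]
// (see Masked_multihead_attention_params), so the last mask dim minus the
// current token gives the cached length used as `timestep`.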
timestep = src_mask->dims()[3] - 1;
}
if (out_scale > 0) {
dev_ctx.template Alloc<int8_t>(out);
} else {
dev_ctx.template Alloc<T>(out);
}
if (sequence_lengths) {
params.sequence_lengths = sequence_lengths->data<int>();
}
if (cum_offsets) {
params.cum_offsets = cum_offsets->data<int>();
} else {
params.cum_offsets = nullptr;
}
if (rotary_emb_dims > 0) {
params.rotary_emb = rotary_tensor->data<float>();
} else {
params.rotary_emb = nullptr;
}
if (beam_cache_offset) {
params.beam_cache_offset = beam_cache_offset->data<int>();
params.beam_width = beam_cache_offset->dims()[1];
}
params.mask_broadcast_num_heads = mask_broadcast_num_heads;
params.cache_kv = const_cast<T*>(cache_kv_out->data<T>());
params.neox_rotary_style = use_neox_rotary_style;
params.add_qkv_bias = false;
params.batch_size = bsz;
params.cache_batch_size = cache_bsz;
params.num_head = num_head;
params.timestep = timestep;
params.seq_len = seq_len;
params.max_seq_length = max_seq_len;
params.inv_sqrt_dh = inv_sqrt_dh;
params.rotary_emb_dims = rotary_emb_dims;
if (out_shift) {
DispatchFMHA<T>(dev_ctx,
x,
*(out_shift.get_ptr()),
*(out_smooth.get_ptr()),
params,
num_head,
dim_head,
out,
qkv_out_scale.get_ptr(),
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
} else {
DispatchFMHA<T>(dev_ctx,
x,
params,
num_head,
dim_head,
out,
qkv_out_scale.get_ptr(),
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound);
}
#endif // PADDLE_WITH_HIP
}
} // namespace fusion
} // namespace phi
#if CUDA_VERSION >= 11000
PD_REGISTER_KERNEL(masked_multihead_attention,
GPU,
ALL_LAYOUT,
phi::fusion::MMHAKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(masked_multihead_attention,
GPU,
ALL_LAYOUT,
phi::fusion::MMHAKernel,
float,
phi::dtype::float16) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/fusion/gpu/masked_multihead_attention_utils.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void MMHAKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& cache_kv,
const paddle::optional<DenseTensor>& src_mask,
const paddle::optional<DenseTensor>& cum_offsets,
const paddle::optional<DenseTensor>& sequence_lengths,
const paddle::optional<DenseTensor>& rotary_tensor,
const paddle::optional<DenseTensor>& beam_cache_offset,
const paddle::optional<DenseTensor>& qkv_out_scale,
const paddle::optional<DenseTensor>& out_shift,
const paddle::optional<DenseTensor>& out_smooth,
int seq_len,
int rotary_emb_dims,
const bool use_neox_rotary_style,
const float out_scale,
const int quant_round_type,
const float quant_max_bound,
const float quant_min_bound,
DenseTensor* out,
DenseTensor* cache_kv_out,
DenseTensor* beam_cache_offset_out);
} // namespace fusion
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for mmha kernel.
*/
#ifndef PADDLE_WITH_HIP
#pragma once
#include "glog/logging.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/mmha_util.cu.h"
namespace phi {
namespace fusion {
#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#define MMHA_USE_FP32_ACUM_FOR_OUT
#define MMHA_USE_FP32_ACUM_FOR_FMA
template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
if (v > max) return max;
if (v < min) return min;
return v;
}
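// QuantHelperFunc maps a floating-point activation into the quantized range:
// it scales the input by max_bound * scale, rounds (rint when round_type == 0,
// round-half-away-from-zero otherwise), and clips to [min_bound, max_bound].
// For example, with scale = 0.01, max_bound = 127 and input = 0.5,
// 127 * 0.01 * 0.5 = 0.635, which rounds to 1 and is unchanged by clipping.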
template <typename InType, typename OutType>
__forceinline__ __device__ OutType QuantHelperFunc(const InType input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * input;
if (round_type == 0) {
quant_value = static_cast<float>(rint(quant_value));
} else {
quant_value = static_cast<float>(round(quant_value));
}
return static_cast<OutType>(
ClipFunc<float>(quant_value, min_bound, max_bound));
}
template <typename T>
struct Masked_multihead_attention_params {
// output buffer, [B, 1(seq_len), num_head * dim_head]
T *out;
// qkv_out, [B, 1(seq_len), 3, num_head * dim_head]
const T *qkv;
// bias, [3, num_head, dim_head]
T *qkv_bias;
// [bsz, seq_len]
const int *cum_offsets;
// TODO(wangxi): optimize with input_lengths and max_input_len?
// [bsz, 1, 1, time_step(cache_seq_length)+1]
const T *attn_mask;
int mask_length;
// whether to broadcast num_heads(2nd) dimension for attn_mask
// in MMHA, if false, attn_mask shape should be
// [bsz, num_heads, 1, time_step(cache_seq_length)+1]
bool mask_broadcast_num_heads;
// [2, B, num_head, max_seq_len(valid cache_seq_len), dim_head]
// k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first
// v [B, num_head, max_seq_len, dim_head]
T *cache_kv;
// [B, max_seq_len]
const int *beam_cache_offset = nullptr;
const int *sequence_lengths{nullptr};
// The RoPE embedding, [2, B, rotary_seq_len, 1, dim_head]
// rotary_emb_dims = 1 if pos_ids_extra is null else 2
const float *rotary_emb;
int rotary_emb_dims;
int rotary_seq_len = 1;
int batch_size; // batch * beam
int beam_width;
int cache_batch_size;
int num_head;
int timestep; // cache_seq_length
int seq_len;
int max_seq_length;
// 1.f / sqrt(Dh)
float inv_sqrt_dh;
bool add_qkv_bias;
bool neox_rotary_style;
};
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template <typename T>
struct K_vec_acum_fp32_ {};
template <>
struct K_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
#endif
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template <typename T>
struct V_vec_acum_fp32_ {};
// template <> struct V_vec_acum_fp32_<float> { using Type = float; };
// template <> struct V_vec_acum_fp32_<float2> { using Type = float2; };
template <>
struct V_vec_acum_fp32_<float4> {
using Type = float4;
};
// template <> struct V_vec_acum_fp32_<uint32_t> { using Type = float2; };
// template <> struct V_vec_acum_fp32_<uint2 > { using Type = Float4_; };
template <>
struct V_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
#ifdef ENABLE_BF16
template <>
struct V_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template <>
struct V_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template <>
struct V_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
#endif // ENABLE_BF16
#endif
// clang-format on
////////////////////////////////////////////////////////////////////////////////////////////////////
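// qk_dot_ computes one Q·K logit: each group of THREADS_PER_KEY lanes holds a
// slice of the query (pre-scaled by 1/sqrt(Dh)) and of the key, accumulates
// the element-wise products with fma, and then reduces the partial sums across
// the group with __shfl_xor_sync.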
template <int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N],
const K_vec (&k)[N],
float inv_sqrt_dh) {
K_vec inv_q = mul<K_vec, K_vec, float>(q[0], inv_sqrt_dh);
K_vec qk_vec = mul<K_vec, K_vec, K_vec>(inv_q, k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
inv_q = mul<K_vec, K_vec, float>(q[ii], inv_sqrt_dh);
qk_vec = fma(inv_q, k[ii], qk_vec);
}
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) {
float4 c;
float zero = 0.f;
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5}, \n"
" {%6}, \n"
" {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
return c;
}
template <int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N],
const uint32_t (&k)[N],
float inv_sqrt_dh) {
#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
__CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
#else
using K_vec_acum = uint32_t;
#endif
K_vec_acum inv_q = mul<K_vec_acum, uint32_t, float>(q[0], inv_sqrt_dh);
K_vec_acum qk_vec = mul<K_vec_acum, K_vec_acum, uint32_t>(inv_q, k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
inv_q = mul<K_vec_acum, uint32_t, float>(q[ii], inv_sqrt_dh);
qk_vec = fma(inv_q, k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t qk_vec_ = float2_to_half2(qk_vec);
return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
return 0.f;
#endif
}
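// Qk_dot selects the reduction strategy: the float16 / 4-threads-per-key
// specialization can take the HMMA tensor-core path on sm_75+ when
// MMHA_USE_HMMA_FOR_REDUCTION is defined; every other case falls back to the
// shuffle-based qk_dot_ above.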
template <typename T, int THREADS_PER_KEY>
struct Qk_dot {
template <typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N],
const K_vec (&k)[N],
float inv_sqrt_dh) {
return qk_dot_<THREADS_PER_KEY>(q, k, inv_sqrt_dh);
}
};
template <>
struct Qk_dot<float16, 4> {
template <int N>
static inline __device__ float dot(const uint32_t (&q)[N],
const uint32_t (&k)[N],
float inv_sqrt_dh) {
#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
__CUDA_ARCH__ >= 750
return qk_hmma_dot_(q, k, inv_sqrt_dh);
#else
return qk_dot_<4>(q, k, inv_sqrt_dh);
#endif
}
};
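// block_sum reduces a per-thread value over the whole thread block in two
// stages: a butterfly reduction inside each warp, then a reduction of the
// per-warp partial sums staged through shared memory; the result is finally
// broadcast from lane 0 to every thread.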
template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float *red_smem, float sum) {
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
if (lane == 0) {
red_smem[warp] = sum;
}
__syncthreads();
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
return __shfl_sync(uint32_t(-1), sum, 0);
}
inline __device__ void convert_from_float(float &dst, float src) { // NOLINT
dst = src;
}
inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT
dst = src;
}
inline __device__ void convert_from_float(phi::float16 &dst, // NOLINT
float src) {
dst = static_cast<phi::float16>(src);
}
inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ void convert_from_float(__nv_bfloat16 &dst, // NOLINT
float src) { // NOLINT
dst = __float2bfloat16(src);
}
inline __device__ void convert_from_float(__nv_bfloat162 &dst, // NOLINT
float2 src) { // NOLINT
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst = __float22bfloat162_rn(src);
#else
dst = __floats2bfloat162_rn(src.x, src.y);
#endif
}
inline __device__ void convert_from_float(bf16_4_t &dst, // NOLINT
Float4_ src) { // NOLINT
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
#else
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(bf16_4_t &dst, // NOLINT
float4 src) { // NOLINT
convert_from_float(
dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
}
inline __device__ void convert_from_float(bf16_8_t &dst, // NOLINT
Float8_ src) { // NOLINT
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
dst.z = __float22bfloat162_rn(src.z);
dst.w = __float22bfloat162_rn(src.w);
#else
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void zero(uint16_t &dst) { dst = uint16_t(0); } // NOLINT
template <typename T>
inline __device__ void zero(T &dst) { // NOLINT
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
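// One thread block handles one (batch, head) pair for a single decode step:
// it loads q/k of the current token (optionally adding bias and applying
// RoPE), appends k to the cache, computes logits against all cached keys,
// runs a block-wide softmax, accumulates the weighted sum over cached values,
// and writes the per-head output through StoreFunc.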
template <typename T,
int Dh,
int Dh_MAX,
int THREADS_PER_KEY,
int THREADS_PER_VALUE,
int THREADS_PER_BLOCK,
typename LoadFunc,
typename StoreFunc>
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params,
LoadFunc load_func,
StoreFunc store_func) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const int bi = blockIdx.y;
if (params.sequence_lengths && params.sequence_lengths[bi] == 0) {
return;
}
typedef PDDataTypeTraits<T> traits_;
typedef typename traits_::DataType DataType_;
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
constexpr int WARP_SIZE = 32;
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
extern __shared__ char smem_[];
float *qk_smem = reinterpret_cast<float *>(smem_);
char *logits_smem_ = smem_;
// fp32 accum for logits
float *logits_smem = reinterpret_cast<float *>(logits_smem_);
T *out_smem = reinterpret_cast<T *>(smem_);
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
using Qk_vec_RoPE = typename Qk_vec_RoPE_<T, float, Dh_MAX>::Type;
__shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
// beam id
const int beami = bi % params.beam_width;
// real batch id
const int bbi = bi / params.beam_width;
const int hi = blockIdx.x;
const int bhi = bi * params.num_head + hi;
const int bbhi = bbi * params.beam_width * params.num_head + hi;
const int ti =
params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1;
const int thi = params.cum_offsets ? ti * params.num_head + hi : -1;
const int tid = threadIdx.x;
const int bi_seq_len_offset = bi * params.max_seq_length;
float qk_max = -FLT_MAX;
float qk = 0;
int act_time_step = params.sequence_lengths == nullptr
? params.timestep
: params.sequence_lengths[bi];
// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
// Use block reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
// const T *q_base = params.qkv;
// const T *k_base = params.qkv + params.num_head * Dh;
T *q_bias_base = nullptr;
T *k_bias_base = nullptr;
if (params.add_qkv_bias) {
q_bias_base = params.qkv_bias;
k_bias_base = params.qkv_bias + params.num_head * Dh;
}
if (tid < QK_VECS_PER_WARP) {
int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
Qk_vec q;
zero(q);
// q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset])
// : q;
if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(q, qk_offset);
}
Qk_vec k;
zero(k);
// k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
// : k;
if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(k, params.num_head * Dh + qk_offset);
}
if (params.add_qkv_bias) {
Qk_vec q_bias;
zero(q_bias);
Qk_vec k_bias;
zero(k_bias);
q_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
: q_bias;
k_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
: k_bias;
q = add(q, q_bias);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
// we may not require k_bias.
k = add(k, k_bias);
}
if (!params.neox_rotary_style) {
if (params.rotary_emb_dims != 0) {
int rotary_offset = bi * Dh + tid * QK_VEC_SIZE;
const float *cos_base = params.rotary_emb;
const float *sin_base = params.rotary_emb + params.batch_size * Dh;
Qk_vec_RoPE cos_emb, sin_emb;
zero(cos_emb);
zero(sin_emb);
cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec_RoPE *>(
&cos_base[rotary_offset])
: cos_emb;
sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec_RoPE *>(
&sin_base[rotary_offset])
: sin_emb;
apply_rotary_embedding(q, k, cos_emb, sin_emb);
}
} else {
/* old rotary pos emb */
if (params.rotary_emb_dims != 0) {
int last_dim = Dh / params.rotary_emb_dims;
int half_lastdim = last_dim / 2;
int rotary_offset = bi * Dh + tid * QK_VEC_SIZE;
const float *cos_base = params.rotary_emb;
const float *sin_base = params.rotary_emb + params.batch_size * Dh;
int stride = half_lastdim / QK_VEC_SIZE;
int stride_all_lastdim = 2 * stride;
int right_id = tid / stride_all_lastdim * stride_all_lastdim +
(tid + stride) % (stride_all_lastdim);
int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE;
int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE;
Qk_vec q_right;
zero(q_right);
// q_right =
// (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(&q_base[qk_right_offset])
// : q_right;
if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(q_right, qk_right_offset);
}
Qk_vec k_right;
zero(k_right);
// k_right =
// (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
// ? *reinterpret_cast<const Qk_vec *>(&k_base[qk_right_offset])
// : k_right;
if (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(
k_right, params.num_head * Dh + qk_right_offset);
}
if (params.add_qkv_bias) {
Qk_vec q_right_bias;
zero(q_right_bias);
q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(
&q_bias_base[qk_right_bias_offset])
: q_right_bias;
Qk_vec k_right_bias;
zero(k_right_bias);
k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(
&k_bias_base[qk_right_bias_offset])
: k_right_bias;
q_right = add(q_right, q_right_bias);
k_right = add(k_right, k_right_bias);
}
Qk_vec_RoPE cos_emb;
zero(cos_emb);
cos_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec_RoPE *>(
&cos_base[rotary_offset])
: cos_emb;
Qk_vec_RoPE sin_emb;
zero(sin_emb);
sin_emb = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec_RoPE *>(
&sin_base[rotary_offset])
: sin_emb;
float alpha = (tid % stride_all_lastdim) < stride
? static_cast<float>(-1)
: static_cast<float>(1);
q = apply_rotary_emb<Qk_vec, Qk_vec_RoPE>(
q, q_right, cos_emb, sin_emb, alpha);
k = apply_rotary_emb<Qk_vec, Qk_vec_RoPE>(
k, k_right, cos_emb, sin_emb, alpha);
}
}
*reinterpret_cast<Qk_vec *>(&q_smem[tid * QK_VEC_SIZE]) = q;
int co = tid / QK_VECS_IN_16B;
int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
int offset = bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
act_time_step * QK_ELTS_IN_16B + ci;
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
}
qk = dot<Qk_vec, Qk_vec>(q, k);
if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
}
}
}
if (QK_VECS_PER_WARP > WARP_SIZE) {
constexpr int WARPS_PER_RED =
(QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
}
if (tid == 0) {
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[act_time_step] = qk;
}
__syncthreads();
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
int ko = tid / THREADS_PER_KEY;
int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE;
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, "");
K_vec q[K_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < K_VECS_PER_THREAD; ++i) {
q[i] = *reinterpret_cast<const K_vec *>(
&q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]);
}
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
T *k_cache_batch = &params.cache_kv[bbhi * params.max_seq_length * Dh + ki];
int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP;
const int *beam_offsets = params.beam_cache_offset
? &params.beam_cache_offset[bi_seq_len_offset]
: nullptr;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head *
params.max_seq_length * Dh
: 0;
K_vec k[K_VECS_PER_THREAD];
K_vec k_vec_zero;
zero(k_vec_zero);
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_seq_length + ti;
if (ti < act_time_step) {
if (beam_offset) {
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])
: k_vec_zero;
} else {
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
&k_cache[jj * QK_ELTS_IN_16B])
: k_vec_zero;
}
}
}
// NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k)
// may overflow with FP16 in large model.
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);
// bool is_mask = false;
if (ti < act_time_step && tid % THREADS_PER_KEY == 0) {
// qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
auto mask_bhi = params.mask_broadcast_num_heads ? bi : bhi;
// T mask = params.attn_mask[mask_bhi * (params.timestep + 1) + ti];
if (params.attn_mask) {
T mask = params.attn_mask[mask_bhi * params.mask_length + ti];
qk += static_cast<float>(mask);
}
qk_max = fmaxf(qk_max, qk);
qk_smem[ti] = qk;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
const int warp = tid / WARP_SIZE;
const int lane = tid % WARP_SIZE;
if (lane == 0) {
red_smem[warp] = qk_max;
}
__syncthreads();
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
float sum = 0.f;
for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
// bool is_mask = false;
// float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
float logit = __expf(qk_smem[ti] - qk_max);
sum += logit;
qk_smem[ti] = logit;
}
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
// FIXME(wangxi): need add 1.e-6f?
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
}
__syncthreads();
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
int vo = tid / THREADS_PER_VALUE;
int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE;
T *v_cache = &params.cache_kv[params.cache_batch_size * params.num_head *
params.max_seq_length * Dh +
bhi * params.max_seq_length * Dh + vi];
T *v_cache_batch = &params.cache_kv[params.batch_size * params.num_head *
params.max_seq_length * Dh +
bbhi * params.max_seq_length * Dh + vi];
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
using V_vec_acum = V_vec;
#endif
V_vec_acum out;
zero(out);
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) {
const int beam_offset =
beam_offsets
? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh
: 0;
V_vec v;
if (beam_offset) {
v = *reinterpret_cast<const V_vec *>(
&v_cache_batch[beam_offset + ti * Dh]);
} else {
v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
}
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
#else
DataType_ logit = static_cast<DataType_>(logits_smem[ti]);
// Update the partial sums.
out = fma(logit, v, out);
#endif
}
}
V_vec v_bias;
zero(v_bias);
if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
// V_vec v = *reinterpret_cast<const V_vec *>(
// &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
V_vec v;
load_func.template load<V_vec>(
v, 2 * params.num_head * Dh + qkv_base_offset + vi);
if (params.add_qkv_bias) {
v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
v = add(v, v_bias);
}
*reinterpret_cast<V_vec *>(&v_cache[act_time_step * Dh]) = v;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
out = fma(logits_smem[act_time_step], cast_to_float(v), out);
#else
out = fma(logits_smem[act_time_step], v, out);
#endif
}
__syncthreads();
if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
for (int active_groups = V_PER_ITER; active_groups >= 2;
active_groups /= 2) {
int midpoint = active_groups / 2;
if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
out);
#else
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
}
__syncthreads();
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
out =
add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
}
}
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
// convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh +
// vi]),
// out);
V_vec tmp_out;
convert_from_float(tmp_out, out);
store_func.template store<V_vec>(tmp_out,
thi != -1 ? thi * Dh + vi : bhi * Dh + vi);
#else
// *reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]) = out;
store_func.template store<V_vec>(out,
thi != -1 ? thi * Dh + vi : bhi * Dh + vi);
#endif
}
#else
assert(false);
#endif
}
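// Dynamic shared memory is sized as the maximum of the softmax working set
// (timestep + 1 float logits padded to 16 bytes, plus a separate reduced
// precision logits buffer when fp32 logit accumulation is disabled) and the
// buffer used by the final cross-row reduction of the output values.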
template <typename T>
inline size_t smem_size_in_bytes(
const Masked_multihead_attention_params<T> &params,
int dim_head,
int threads_per_value,
int threads_per_block) {
size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
size_t logits_sz = 0;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS // NOLINT
if (sizeof(T) != 4) {
logits_sz = div_up(params.max_seq_length, 4) * 4 * sizeof(T);
}
#endif // NOLINT
size_t softmax_sz = qk_sz + logits_sz;
int rows_per_red = threads_per_block / threads_per_value;
size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;
return max(softmax_sz, red_sz);
}
#define MMHA_LAUNCH_KERNEL(T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK, \
stream, \
load_func, \
store_func) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
constexpr auto kernel_fn = \
masked_multihead_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK, \
decltype(load_func), \
decltype(store_func)>; \
if (smem_sz > 0xc000) { \
cudaFuncSetAttribute( \
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
dim3 grid(params.num_head, params.batch_size); \
kernel_fn<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>( \
params, load_func, store_func)
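// The launch configuration is chosen from the current cache length: short
// sequences (< 32) use 64 threads with 4 threads per key, medium sequences
// (< 2048) use 256 threads on the HMMA path (128 threads with 2 threads per
// key otherwise), and longer sequences use 256 threads with 1 thread per key.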
template <typename T, int Dh, int Dh_MAX, typename LoadFunc, typename StoreFunc>
void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
const cudaStream_t &stream,
LoadFunc load_func,
StoreFunc store_func) {
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
if (params.timestep < 32) {
MMHA_LAUNCH_KERNEL(
T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream, load_func, store_func);
} else if (params.timestep < 2048) {
#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
__CUDA_ARCH__ >= 750
MMHA_LAUNCH_KERNEL(T,
Dh,
Dh_MAX,
4,
THREADS_PER_VALUE,
256,
stream,
load_func,
store_func);
#else
MMHA_LAUNCH_KERNEL(T,
Dh,
Dh_MAX,
2,
THREADS_PER_VALUE,
128,
stream,
load_func,
store_func);
#endif
} else {
MMHA_LAUNCH_KERNEL(T,
Dh,
Dh_MAX,
1,
THREADS_PER_VALUE,
256,
stream,
load_func,
store_func);
}
}
template <typename T, typename LoadFunc, typename StoreFunc>
void fmha_impl(const phi::GPUContext &dev_ctx,
const Masked_multihead_attention_params<T> &params,
int dim_head,
LoadFunc load_func,
StoreFunc store_func) {
switch (dim_head) {
case 10:
fmha_launch_kernel<T, 10, 32>(
params, dev_ctx.stream(), load_func, store_func);
break;
case 26:
fmha_launch_kernel<T, 26, 32>(
params, dev_ctx.stream(), load_func, store_func);
break;
case 32:
fmha_launch_kernel<T, 32, 32>(
params, dev_ctx.stream(), load_func, store_func);
break;
case 64:
fmha_launch_kernel<T, 64, 64>(
params, dev_ctx.stream(), load_func, store_func);
break;
case 96:
fmha_launch_kernel<T, 96, 128>(
params, dev_ctx.stream(), load_func, store_func);
break;
case 128:
fmha_launch_kernel<T, 128, 128>(
params, dev_ctx.stream(), load_func, store_func);
break;
case 192:
fmha_launch_kernel<T, 192, 256>(
params, dev_ctx.stream(), load_func, store_func);
break;
default:
PADDLE_THROW(
phi::errors::Unimplemented("Dim_head = %d is unsupport!", dim_head));
}
}
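// MMHALoad / MMHAStore abstract the data movement around the kernel: the
// int32_t load specialization dequantizes quantized QKV activations with
// per-channel scales, the int8_t store specializations re-quantize the
// attention output, and the `Smooth` variants additionally apply the
// shift/smooth vectors before writing.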
template <typename T, typename LoadT = T>
struct MMHALoad {
explicit MMHALoad(const LoadT *src) : src_(src) {}
template <typename Vec>
__device__ void load(Vec &dst, int idx) {
dst = *reinterpret_cast<const Vec *>(src_ + idx);
}
const LoadT *src_;
};
template <typename T, typename StoreT = T, bool Smooth = false>
struct MMHAStore {
explicit MMHAStore(StoreT *dst) : dst_(dst) {}
template <typename Vec>
__device__ void store(Vec &src, int idx) {
*reinterpret_cast<Vec *>(dst_ + idx) = src;
}
StoreT *dst_;
};
template <typename T>
struct MMHAStore<T, T, true> {
MMHAStore(T *dst, const T *shift, const T *smooth, const int cols)
: dst_(dst), shift_(shift), smooth_(smooth), cols_(cols) {}
template <typename Vec>
__device__ void store(Vec &src, int idx) {
constexpr int VecSize = sizeof(Vec) / sizeof(T);
using TVec = phi::AlignedVector<T, VecSize>;
TVec src_vec;
TVec shift_vec;
TVec smooth_vec;
*reinterpret_cast<Vec *>(&src_vec) = src;
phi::Load<T, VecSize>(shift_ + idx % cols_, &shift_vec);
phi::Load<T, VecSize>(smooth_ + idx % cols_, &smooth_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src_vec[i] = (src_vec[i] + shift_vec[i]) * smooth_vec[i];
}
phi::Store<T, VecSize>(src_vec, dst_ + idx);
}
T *dst_;
const T *shift_;
const T *smooth_;
const int cols_;
};
template <typename T>
struct MMHALoad<T, int32_t> {
MMHALoad(const int32_t *src, const float *dequant_scales, const int cols)
: src_(src), dequant_scales_(dequant_scales), cols_(cols) {}
template <typename Vec>
__device__ void load(Vec &dst, int idx) {
constexpr int VecSize = sizeof(Vec) / sizeof(T);
using SrcVec = phi::AlignedVector<int32_t, VecSize>;
using DstVec = phi::AlignedVector<T, VecSize>;
using ScaleVec = phi::AlignedVector<float, VecSize>;
SrcVec src_vec;
DstVec dst_vec;
ScaleVec scale_vec;
phi::Load<int32_t, VecSize>(src_ + idx, &src_vec);
phi::Load<float, VecSize>(dequant_scales_ + idx % cols_, &scale_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
dst_vec[i] =
static_cast<T>(static_cast<float>(src_vec[i]) * scale_vec[i]);
}
dst = *reinterpret_cast<Vec *>(&dst_vec);
}
const int32_t *src_;
const float *dequant_scales_;
const int cols_;
};
template <typename T>
struct MMHAStore<T, int8_t> {
MMHAStore(int8_t *dst,
const int quant_round_type,
const float quant_scale,
const float quant_max_bound,
const float quant_min_bound)
: dst_(dst),
quant_round_type_(quant_round_type),
quant_scale_(quant_scale),
quant_max_bound_(quant_max_bound),
quant_min_bound_(quant_min_bound) {}
template <typename Vec>
__device__ void store(Vec &src, int idx) { // NOLINT
constexpr int VecSize = sizeof(Vec) / sizeof(T);
using SrcVec = phi::AlignedVector<T, VecSize>;
using DstVec = phi::AlignedVector<int8_t, VecSize>;
SrcVec src_vec;
*reinterpret_cast<Vec *>(&src_vec) = src;
DstVec dst_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
dst_vec[i] =
QuantHelperFunc<float, int8_t>(static_cast<float>(src_vec[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
}
phi::Store<int8_t, VecSize>(dst_vec, dst_ + idx);
}
int8_t *dst_;
const int quant_round_type_;
const float quant_scale_;
const float quant_max_bound_;
const float quant_min_bound_;
};
template <typename T>
struct MMHAStore<T, int8_t, true> {
MMHAStore(int8_t *dst,
const T *shift,
const T *smooth,
const int cols,
const int quant_round_type,
const float quant_scale,
const float quant_max_bound,
const float quant_min_bound)
: dst_(dst),
quant_round_type_(quant_round_type),
quant_scale_(quant_scale),
quant_max_bound_(quant_max_bound),
quant_min_bound_(quant_min_bound),
shift_(shift),
smooth_(smooth),
cols_(cols) {}
template <typename Vec>
__device__ void store(Vec &src, int idx) { // NOLINT
constexpr int VecSize = sizeof(Vec) / sizeof(T);
using SrcVec = phi::AlignedVector<T, VecSize>;
using DstVec = phi::AlignedVector<int8_t, VecSize>;
SrcVec src_vec;
DstVec dst_vec;
SrcVec shift_vec;
SrcVec smooth_vec;
*reinterpret_cast<Vec *>(&src_vec) = src;
phi::Load<T, VecSize>(shift_ + idx % cols_, &shift_vec);
phi::Load<T, VecSize>(smooth_ + idx % cols_, &smooth_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src_vec[i] = (src_vec[i] + shift_vec[i]) * smooth_vec[i];
dst_vec[i] =
QuantHelperFunc<float, int8_t>(static_cast<float>(src_vec[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
}
phi::Store<int8_t, VecSize>(dst_vec, dst_ + idx);
}
int8_t *dst_;
const T *shift_;
const T *smooth_;
const int cols_;
const int quant_round_type_;
const float quant_scale_;
const float quant_max_bound_;
const float quant_min_bound_;
};
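// DispatchFMHA picks one of four load/store combinations depending on whether
// the QKV input needs dequantization (dequant_qkv_scales) and whether the
// output must be quantized to int8 (quant_fmha_out_scale > 0); the second
// overload below does the same with the smooth/shift epilogue enabled.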
template <typename T>
void DispatchFMHA(const phi::GPUContext &dev_ctx,
const phi::DenseTensor &qkv_tensor,
const Masked_multihead_attention_params<T> &params,
int num_head,
int dim_head,
phi::DenseTensor *out_tensor,
const phi::DenseTensor *dequant_qkv_scales = nullptr,
const float quant_fmha_out_scale = -1,
const int quant_round_type = 1,
const float quant_max_bound = 127.0f,
const float quant_min_bound = -127.0f) {
if (dequant_qkv_scales != nullptr && quant_fmha_out_scale > 0) {
MMHALoad<T, int32_t> load_func(qkv_tensor.data<int32_t>(),
dequant_qkv_scales->data<float>(),
3 * num_head * dim_head);
MMHAStore<T, int8_t> store_func(out_tensor->data<int8_t>(),
quant_round_type,
quant_fmha_out_scale,
quant_max_bound,
quant_min_bound);
fmha_impl(dev_ctx, params, dim_head, load_func, store_func);
} else if (dequant_qkv_scales == nullptr && quant_fmha_out_scale > 0) {
MMHALoad<T> load_func(qkv_tensor.data<T>());
MMHAStore<T, int8_t> store_func(out_tensor->data<int8_t>(),
quant_round_type,
quant_fmha_out_scale,
quant_max_bound,
quant_min_bound);
fmha_impl(dev_ctx, params, dim_head, load_func, store_func);
} else if (dequant_qkv_scales != nullptr && quant_fmha_out_scale <= 0) {
MMHALoad<T, int32_t> load_func(qkv_tensor.data<int32_t>(),
dequant_qkv_scales->data<float>(),
3 * num_head * dim_head);
MMHAStore<T> store_func(out_tensor->data<T>());
fmha_impl(dev_ctx, params, dim_head, load_func, store_func);
} else {
MMHALoad<T> load_func(qkv_tensor.data<T>());
MMHAStore<T> store_func(out_tensor->data<T>());
fmha_impl(dev_ctx, params, dim_head, load_func, store_func);
}
}
template <typename T>
void DispatchFMHA(const phi::GPUContext &dev_ctx,
const phi::DenseTensor &qkv_tensor,
const phi::DenseTensor &shift,
const phi::DenseTensor &smooth,
const Masked_multihead_attention_params<T> &params,
int num_head,
int dim_head,
phi::DenseTensor *out_tensor,
const phi::DenseTensor *dequant_qkv_scales = nullptr,
const float quant_fmha_out_scale = -1,
const int quant_round_type = 1,
const float quant_max_bound = 127.0f,
const float quant_min_bound = -127.0f) {
if (dequant_qkv_scales != nullptr && quant_fmha_out_scale > 0) {
MMHALoad<T, int32_t> load_func(qkv_tensor.data<int32_t>(),
dequant_qkv_scales->data<float>(),
3 * num_head * dim_head);
MMHAStore<T, int8_t, true> store_func(out_tensor->data<int8_t>(),
shift.data<T>(),
smooth.data<T>(),
num_head * dim_head,
quant_round_type,
quant_fmha_out_scale,
quant_max_bound,
quant_min_bound);
fmha_impl(dev_ctx, params, dim_head, load_func, store_func);
} else if (dequant_qkv_scales == nullptr && quant_fmha_out_scale > 0) {
MMHALoad<T> load_func(qkv_tensor.data<T>());
MMHAStore<T, int8_t, true> store_func(out_tensor->data<int8_t>(),
shift.data<T>(),
smooth.data<T>(),
num_head * dim_head,
quant_round_type,
quant_fmha_out_scale,
quant_max_bound,
quant_min_bound);
fmha_impl(dev_ctx, params, dim_head, load_func, store_func);
} else if (dequant_qkv_scales != nullptr && quant_fmha_out_scale <= 0) {
MMHALoad<T, int32_t> load_func(qkv_tensor.data<int32_t>(),
dequant_qkv_scales->data<float>(),
3 * num_head * dim_head);
MMHAStore<T, T, true> store_func(out_tensor->data<T>(),
shift.data<T>(),
smooth.data<T>(),
num_head * dim_head);
fmha_impl(dev_ctx, params, dim_head, load_func, store_func);
} else {
MMHALoad<T> load_func(qkv_tensor.data<T>());
MMHAStore<T, T, true> store_func(out_tensor->data<T>(),
shift.data<T>(),
smooth.data<T>(),
num_head * dim_head);
fmha_impl(dev_ctx, params, dim_head, load_func, store_func);
}
}
} // namespace fusion
} // namespace phi
#endif // PADDLE_WITH_HIP
...@@ -30,6 +30,7 @@ from .variable_length_memory_efficient_attention import (
)
from .fused_rms_norm import fused_rms_norm
from .fused_layer_norm import fused_layer_norm
from .masked_multihead_attention import masked_multihead_attention
__all__ = [
    'fused_multi_head_attention',
...@@ -45,4 +46,5 @@ __all__ = [
    'variable_length_memory_efficient_attention',
    "fused_rms_norm",
    "fused_layer_norm",
    "masked_multihead_attention",
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode
def masked_multihead_attention(
x,
cache_kv=None,
src_mask=None,
cum_offsets=None,
sequence_lengths=None,
rotary_tensor=None,
beam_cache_offset=None,
qkv_out_scale=None,
out_shift=None,
out_smooth=None,
seq_len=1,
rotary_emb_dims=0,
use_neox_rotary_style=False,
out_scale=-1,
quant_round_type=1,
quant_max_bound=127.0,
quant_min_bound=-127.0,
):
r"""
Masked Multi-head attention for text summarization.
This is a fusion operator to compute masked multihead attention in transformer model architecture.
This operator only supports running on GPU.
Args:
        x (Tensor): The input tensor. It is a 2-D tensor with shape [batch_size, 3 * num_head * head_dim].
        cache_kv (Tensor): The cache structure tensor for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
        src_mask (Tensor, optional): The src_mask tensor. Its shape is [batch_size, 1, 1, sequence_length].
        cum_offsets (Tensor, optional): The cumulative offsets tensor, used to compute the index of each output token. Default None.
        sequence_lengths (Tensor, optional): The sequence_lengths tensor, used to index input. Its shape is [batch_size, 1].
rotary_tensor (Tensor, optional): The rotary_tensor tensor. The dtype must be float. Its shape is [batch_size, 1, 1, sequence_length, head_dim].
beam_cache_offset (Tensor, optional): The beam_cache_offset tensor. Its shape is [batch_size, beam_size, max_seq_len + max_dec_len].
qkv_out_scale (Tensor, optional): The qkv_out_scale tensor, used in quant. Its shape is [3, num_head, head_dim].
out_shift (Tensor, optional): The out_shift tensor, used in quant.
out_smooth (Tensor, optional): The out_smooth tensor, used in quant.
seq_len (int, optional): The seq_len, used to get input length. Default 1.
        rotary_emb_dims (int, optional): The rotary_emb_dims. Default 0.
use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False.
        out_scale (float, optional): The out_scale, used in quant. Default -1.
quant_round_type (int, optional): The quant_round_type, used in quant. Default 1.
quant_max_bound (float, optional): The quant_max_bound, used in quant. Default 127.0.
quant_min_bound (float, optional): The quant_min_bound, used in quant. Default -127.0.
Returns:
Tensor|tuple: If "beam_cache_offset_out" is not none, return the
tuple (output, cache_kvs_out, beam_cache_offset_out), which output is the output of
masked_multihead_attention layers, cache_kvs_out is inplace with input `cache_kvs`.
If "beam_cache_offset_out" is none, return the tuple (output, cache_kvs_out).
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.incubate.nn.functional as F
# input: [batch_size, 3 * num_head * dim_head]
x = paddle.rand(shape=(2, 3 * 32 * 128), dtype="float32")
# src_mask: [batch_size, 1, 1, sequence_length]
src_mask = paddle.rand(shape=(2, 1, 1, 10), dtype="float32")
# cache_kv: [2, batch_size, num_head, max_seq_len, dim_head]
cache_kv = paddle.rand(shape=(2, 2, 32, 64, 128), dtype="float32")
output = F.masked_multihead_attention(
x, src_mask=src_mask, cache_kv=cache_kv)
"""
if in_dynamic_mode():
return _C_ops.masked_multihead_attention_(
x,
cache_kv,
src_mask,
cum_offsets,
sequence_lengths,
rotary_tensor,
beam_cache_offset,
qkv_out_scale,
out_shift,
out_smooth,
seq_len,
rotary_emb_dims,
use_neox_rotary_style,
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
helper = LayerHelper('masked_multihead_attention', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
inputs = {}
inputs['x'] = x
inputs['cache_kv'] = cache_kv
if src_mask is not None:
inputs['src_mask'] = src_mask
if cum_offsets is not None:
inputs['cum_offsets'] = cum_offsets
if sequence_lengths is not None:
inputs['sequence_lengths'] = sequence_lengths
if rotary_tensor is not None:
inputs['rotary_tensor'] = rotary_tensor
beam_cache_offset_flag = False
if beam_cache_offset is not None:
inputs['beam_cache_offset'] = beam_cache_offset
beam_cache_offset_flag = True
else:
beam_cache_offset = helper.create_variable_for_type_inference(
dtype="int"
)
if qkv_out_scale is not None:
inputs['qkv_out_scale'] = qkv_out_scale
if out_shift is not None:
inputs['out_shift'] = out_shift
if out_smooth is not None:
inputs['out_smooth'] = out_smooth
outputs = {
'out': out,
'cache_kv_out': cache_kv,
'beam_cache_offset_out': beam_cache_offset,
}
helper.append_op(
type='masked_multihead_attention',
inputs=inputs,
outputs=outputs,
attrs={
'seq_len': seq_len,
'rotary_emb_dims': rotary_emb_dims,
'use_neox_rotary_style': use_neox_rotary_style,
'out_scale': out_scale,
'quant_round_type': quant_round_type,
'quant_max_bound': quant_max_bound,
'quant_min_bound': quant_min_bound,
},
)
return (
(out, cache_kv, beam_cache_offset)
        if beam_cache_offset_flag
else (out, cache_kv)
)
...@@ -155,6 +155,7 @@ if(WIN32)
  list(REMOVE_ITEM TEST_OPS test_ops_nms)
  list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
  list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
  list(REMOVE_ITEM TEST_OPS test_masked_multihead_attention_op)
  list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
  list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
  list(REMOVE_ITEM TEST_OPS test_fused_layernorm_op)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.framework import core
from paddle.incubate.nn.functional import masked_multihead_attention
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMMHAOp(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.bsz = 2
self.cache_bsz = 2
self.num_head = 32
self.dim_head = 128
self.beam_size = 1
self.max_seq_len = 33
self.sequence_length = 32
self.x = np.random.uniform(
-0.05, 0.05, [self.bsz, 3, self.num_head, self.dim_head]
)
self.x_int = np.random.randint(
2, 10, size=(self.bsz, 3, self.num_head, self.dim_head)
).astype("int")
self.src_mask = np.zeros([self.bsz, 1, 1, self.sequence_length + 1])
self.cum_offsets = None
self.sequence_lengths = None
self.rotary_tensor = None
self.beam_cache_offset = None
self.cache_kv_out = np.random.uniform(
-0.05,
0.05,
[
2,
self.cache_bsz,
self.num_head,
self.sequence_length,
self.dim_head,
],
)
        numpy_zeros = np.zeros(
            [2, self.cache_bsz, self.num_head, 1, self.dim_head]
        )
        self.cache_kv_mmha_out = np.concatenate(
            (self.cache_kv_out, numpy_zeros), axis=3
        )
self.qkv_out_scale = np.random.uniform(
-0.05, 0.05, [3, self.num_head, self.dim_head]
)
self.out_shift = None
self.out_smooth = None
self.seq_len = 1
self.rotary_emb_dims = 0
self.use_neox_rotary_style = False
self.out_scale = 10
self.quant_round_type = 1
self.quant_max_bound = 126
self.quant_min_bound = -126
def quant_helper(
self, x, quant_scale, quant_round_type, quant_max_bound, quant_min_bound
):
quant_value = quant_max_bound * quant_scale * x
if quant_round_type == 0:
quant_value = paddle.to_tensor(np.rint(quant_value.numpy()))
else:
quant_value = paddle.round(quant_value)
return paddle.cast(
paddle.clip(quant_value, quant_min_bound, quant_max_bound),
paddle.int8,
)
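    # Reference implementation: optionally dequantize the int qkv input,
    # append the current step's k/v to the cache, and compute standard
    # scaled-dot-product attention with the additive mask; an int8-quantized
    # copy of the output is also returned for the out_scale test.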
def mmha_naive(
self,
x,
cache_kv_out,
src_mask,
qkv_out_scale,
seq_len,
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
bsz,
):
        if qkv_out_scale is not None:
            x = x.cast(cache_kv_out.dtype) * qkv_out_scale
x = paddle.transpose(
x, [0, 2, 1, 3]
) # [bz, seqlen, nhead, head_dim] --> [bz, nhead, seqlen, head_dim]
q, k, v = paddle.split(x, 3, axis=2)
cache_k, cache_v = paddle.split(cache_kv_out, 2, axis=0)
k = paddle.concat([cache_k.squeeze(0), k], axis=2)
v = paddle.concat([cache_v.squeeze(0), v], axis=2)
product = paddle.matmul(
x=q * (x.shape[3] ** -0.5), y=k, transpose_y=True
)
product = product + src_mask
product = paddle.nn.functional.softmax(product)
out = (
paddle.matmul(product, v).transpose([0, 2, 1, 3]).reshape([bsz, -1])
)
normalized_out = self.quant_helper(
out,
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
).reshape([bsz, -1])
return out, normalized_out
def check_main(
self,
x,
cache_kv_out,
cache_kv_mmha_out,
src_mask,
qkv_out_scale,
out_scale,
dtype,
):
paddle.disable_static()
if qkv_out_scale is not None:
x = paddle.to_tensor(x).cast("int32")
qkv_out_scale = paddle.to_tensor(qkv_out_scale).cast("float32")
else:
x = paddle.to_tensor(x).cast(dtype)
src_mask = paddle.to_tensor(src_mask).cast(dtype)
cache_kv_out = paddle.to_tensor(cache_kv_out).cast(dtype)
cache_kv_mmha_out = paddle.to_tensor(cache_kv_mmha_out).cast(dtype)
        paddle_naive_mmha_out = self.mmha_naive(
x,
cache_kv_out,
src_mask,
qkv_out_scale,
self.seq_len,
out_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
self.bsz,
)
x = x.reshape([self.bsz, -1])
paddle_mmha_out = masked_multihead_attention(
x,
cache_kv_mmha_out,
src_mask,
None,
None,
None,
None,
qkv_out_scale,
None,
None,
self.seq_len,
self.rotary_emb_dims,
self.use_neox_rotary_style,
out_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
)
paddle.enable_static()
return paddle_naive_mmha_out, paddle_mmha_out
def test_mmha_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_naive_mmha, paddle_mmha_out = self.check_main(
self.x,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.src_mask,
None,
-1,
'float16',
)
np.testing.assert_allclose(
paddle_mmha_out[0].numpy(),
paddle_naive_mmha[0].numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_mmha_qkv_out_scale(self):
if not paddle.is_compiled_with_cuda():
return
paddle_naive_mmha, paddle_mmha_out = self.check_main(
self.x_int,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.src_mask,
self.qkv_out_scale,
-1,
'float16',
)
np.testing.assert_allclose(
paddle_mmha_out[0].numpy(),
paddle_naive_mmha[0].numpy(),
rtol=1e-3,
atol=1e-3,
)
def test_mmha_outlinear_in_scale(self):
if not paddle.is_compiled_with_cuda():
return
paddle_naive_mmha, paddle_mmha_out = self.check_main(
self.x,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.src_mask,
None,
self.out_scale,
'float16',
)
np.testing.assert_allclose(
paddle_mmha_out[0].numpy(),
paddle_naive_mmha[1].numpy(),
rtol=1,
atol=1,
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMMHAOpStatic(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.bsz = 2
self.cache_bsz = 2
self.num_head = 32
self.dim_head = 128
self.beam_size = 1
self.max_seq_len = 33
self.sequence_length = 32
self.x = np.random.uniform(
-0.05, 0.05, [self.bsz, 3, self.num_head, self.dim_head]
)
self.src_mask = np.zeros([self.bsz, 1, 1, self.sequence_length + 1])
self.cum_offsets = None
self.sequence_lengths = None
self.rotary_tensor = None
self.beam_cache_offset = None
self.cache_kv_out = np.random.uniform(
-0.05,
0.05,
[
2,
self.cache_bsz,
self.num_head,
self.sequence_length,
self.dim_head,
],
)
        numpy_zeros = np.zeros(
            [2, self.cache_bsz, self.num_head, 1, self.dim_head]
        )
        self.cache_kv_mmha_out = np.concatenate(
            (self.cache_kv_out, numpy_zeros), axis=3
        )
self.qkv_out_scale = None
self.out_shift = None
self.out_smooth = None
self.seq_len = 1
self.rotary_emb_dims = 0
self.use_neox_rotary_style = False
self.out_scale = -1
self.quant_round_type = 1
self.quant_max_bound = 127
self.quant_min_bound = -127
self.place = paddle.CUDAPlace(0)
def mmha_naive(
self,
x,
cache_kv_out,
src_mask,
qkv_out_scale,
seq_len,
out_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
bsz,
):
if qkv_out_scale is not None:
x = x.cast(cache_kv_out.dtype) * qkv_out_scale
x = paddle.transpose(
x, [0, 2, 1, 3]
) # [bz, seqlen, nhead, head_dim] --> [bz, nhead, seqlen, head_dim]
q, k, v = paddle.split(x, 3, axis=2)
cache_k, cache_v = paddle.split(cache_kv_out, 2, axis=0)
k = paddle.concat([cache_k.squeeze(0), k], axis=2)
v = paddle.concat([cache_v.squeeze(0), v], axis=2)
product = paddle.matmul(
x=q * (x.shape[3] ** -0.5), y=k, transpose_y=True
)
product = product + src_mask
product = paddle.nn.functional.softmax(product)
out = (
paddle.matmul(product, v).transpose([0, 2, 1, 3]).reshape([bsz, -1])
)
return out
def check_main(
self,
x,
src_mask,
cache_kv_out,
cache_kv_mmha_out,
qkv_out_scale,
out_scale,
dtype,
):
paddle.disable_static()
x_tensor = paddle.to_tensor(x).cast(dtype)
src_mask_tensor = paddle.to_tensor(src_mask).cast(dtype)
cache_kv_out = paddle.to_tensor(cache_kv_out).cast(dtype)
paddle_naive_mmha_out = self.mmha_naive(
x_tensor,
cache_kv_out,
src_mask_tensor,
None,
self.seq_len,
out_scale,
self.quant_round_type,
self.quant_max_bound,
self.quant_min_bound,
self.bsz,
)
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x_static = paddle.static.data(
name="x_static",
shape=[self.bsz, 3 * self.num_head * self.dim_head],
dtype=dtype,
)
src_mask_static = paddle.static.data(
name="src_mask_static",
shape=[self.bsz, 1, 1, self.sequence_length + 1],
dtype=dtype,
)
cache_kv_mmha_out_static = paddle.static.data(
name="cache_kv_mmha_out_static",
shape=[
2,
self.cache_bsz,
self.num_head,
self.sequence_length + 1,
self.dim_head,
],
dtype=dtype,
)
outs = masked_multihead_attention(
x_static,
cache_kv_mmha_out_static,
src_mask_static,
None,
None,
None,
None,
None,
None,
None,
32,
0,
False,
-1,
1,
127.0,
-127.0,
)
exe = paddle.static.Executor(self.place)
out_s = exe.run(
feed={
"x_static": x.reshape(self.bsz, -1).astype(dtype),
"cache_kv_mmha_out_static": cache_kv_mmha_out.astype(dtype),
"src_mask_static": src_mask.astype(dtype),
},
fetch_list=[outs],
)
return paddle_naive_mmha_out, out_s
def test_mmha_fp16(self):
if not paddle.is_compiled_with_cuda():
return
paddle_naive_mmha_out, paddle_mmha_out = self.check_main(
self.x,
self.src_mask,
self.cache_kv_out,
self.cache_kv_mmha_out,
self.qkv_out_scale,
self.out_scale,
'float16',
)
np.testing.assert_allclose(
paddle_mmha_out[0],
paddle_naive_mmha_out.numpy(),
rtol=1e-3,
atol=1e-3,
)
if __name__ == '__main__':
unittest.main()