未验证 提交 03f9e598 编写于 作者: Z Zhang Zheng 提交者: GitHub

Support more dimensions in MMHA (#43612)

* support more dimensions

* fix
上级 2ddbc647
...@@ -114,9 +114,11 @@ template <typename T, int Dh> struct Qk_vec_ {}; ...@@ -114,9 +114,11 @@ template <typename T, int Dh> struct Qk_vec_ {};
template <> struct Qk_vec_<float, 32> { using Type = float; }; template <> struct Qk_vec_<float, 32> { using Type = float; };
template <> struct Qk_vec_<float, 64> { using Type = float2; }; template <> struct Qk_vec_<float, 64> { using Type = float2; };
template <> struct Qk_vec_<float, 128> { using Type = float4; }; template <> struct Qk_vec_<float, 128> { using Type = float4; };
template <> struct Qk_vec_<float, 256> { using Type = float4; };
template <> struct Qk_vec_<float16, 32> { using Type = uint32_t; }; template <> struct Qk_vec_<float16, 32> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 64> { using Type = uint32_t; }; template <> struct Qk_vec_<float16, 64> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 128> { using Type = uint2; }; template <> struct Qk_vec_<float16, 128> { using Type = uint2; };
template <> struct Qk_vec_<float16, 256> { using Type = uint4; };
template <typename T, int THREADS_PER_KEY> struct K_vec_ {}; template <typename T, int THREADS_PER_KEY> struct K_vec_ {};
template <> struct K_vec_<float, 4> { using Type = float; }; template <> struct K_vec_<float, 4> { using Type = float; };
...@@ -532,11 +534,11 @@ inline __device__ void zero(T &dst) { // NOLINT ...@@ -532,11 +534,11 @@ inline __device__ void zero(T &dst) { // NOLINT
template <typename T, int Dh, int Dh_MAX, int THREADS_PER_KEY, template <typename T, int Dh, int Dh_MAX, int THREADS_PER_KEY,
int THREADS_PER_VALUE, int THREADS_PER_BLOCK> int THREADS_PER_VALUE, int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel( __global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params, int pad_active_groups) { Masked_multihead_attention_params<T> params) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert(Dh % THREADS_PER_KEY == 0, ""); static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
static_assert(Dh % THREADS_PER_VALUE == 0, ""); static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
constexpr int WARP_SIZE = 32; constexpr int WARP_SIZE = 32;
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
...@@ -552,7 +554,8 @@ __global__ void masked_multihead_attention_kernel( ...@@ -552,7 +554,8 @@ __global__ void masked_multihead_attention_kernel(
T *out_smem = reinterpret_cast<T *>(smem_); T *out_smem = reinterpret_cast<T *>(smem_);
__shared__ float red_smem[WARPS_PER_BLOCK * 2]; __shared__ float red_smem[WARPS_PER_BLOCK * 2];
__shared__ T q_smem[Dh]; using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
__shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
const int bi = blockIdx.y; const int bi = blockIdx.y;
const int hi = blockIdx.x; const int hi = blockIdx.x;
...@@ -565,10 +568,11 @@ __global__ void masked_multihead_attention_kernel( ...@@ -565,10 +568,11 @@ __global__ void masked_multihead_attention_kernel(
// qkv [B, S=1, 3, num_head, head_dim] // qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, ""); static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE; // Use block reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x] // cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte // x == 4/8 for FP32/FP16, 128bit, 16Byte
...@@ -584,13 +588,29 @@ __global__ void masked_multihead_attention_kernel( ...@@ -584,13 +588,29 @@ __global__ void masked_multihead_attention_kernel(
int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE; int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE; int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
Qk_vec q = *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset]); Qk_vec q;
Qk_vec k = *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset]); zero(q);
q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
Qk_vec q_bias = ? *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset])
*reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset]); : q;
Qk_vec k_bias = Qk_vec k;
*reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset]); zero(k);
k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
: k;
Qk_vec q_bias;
zero(q_bias);
q_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
: q_bias;
Qk_vec k_bias;
zero(k_bias);
k_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
: k_bias;
q = add(q, q_bias); q = add(q, q_bias);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510 // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
...@@ -604,24 +624,33 @@ __global__ void masked_multihead_attention_kernel( ...@@ -604,24 +624,33 @@ __global__ void masked_multihead_attention_kernel(
int offset = bhi * params.max_seq_length * Dh + int offset = bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B + co * params.max_seq_length * QK_ELTS_IN_16B +
params.timestep * QK_ELTS_IN_16B + ci; params.timestep * QK_ELTS_IN_16B + ci;
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k; if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
}
qk = dot<Qk_vec, Qk_vec>(q, k); qk = dot<Qk_vec, Qk_vec>(q, k);
}
if (tid < WARP_SIZE) { if (QK_VECS_PER_WARP <= WARP_SIZE) {
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { #pragma unroll
qk += __shfl_xor_sync(uint32_t(-1), qk, mask); for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
} qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
if (tid == 0) { }
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
} }
} }
if (QK_VECS_PER_WARP > WARP_SIZE) {
constexpr int WARPS_PER_RED =
(QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
}
if (tid == 0) {
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
}
__syncthreads(); __syncthreads();
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
...@@ -635,13 +664,15 @@ __global__ void masked_multihead_attention_kernel( ...@@ -635,13 +664,15 @@ __global__ void masked_multihead_attention_kernel(
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type; using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
static_assert(Dh % K_VEC_SIZE == 0, ""); static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
int ko = tid / THREADS_PER_KEY; int ko = tid / THREADS_PER_KEY;
int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE;
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, "");
K_vec q[K_VECS_PER_THREAD]; K_vec q[K_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < K_VECS_PER_THREAD; ++i) { for (int i = 0; i < K_VECS_PER_THREAD; ++i) {
...@@ -657,11 +688,17 @@ __global__ void masked_multihead_attention_kernel( ...@@ -657,11 +688,17 @@ __global__ void masked_multihead_attention_kernel(
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
K_vec k[K_VECS_PER_THREAD]; K_vec k[K_VECS_PER_THREAD];
K_vec k_vec_zero;
zero(k_vec_zero);
#pragma unroll #pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_seq_length + ti; int jj = ii * params.max_seq_length + ti;
if (ti < params.timestep) { if (ti < params.timestep) {
k[ii] = *reinterpret_cast<const K_vec *>(&k_cache[jj * QK_ELTS_IN_16B]); k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
&k_cache[jj * QK_ELTS_IN_16B])
: k_vec_zero;
} }
} }
...@@ -727,7 +764,7 @@ __global__ void masked_multihead_attention_kernel( ...@@ -727,7 +764,7 @@ __global__ void masked_multihead_attention_kernel(
} }
__syncthreads(); __syncthreads();
constexpr int V_VEC_SIZE = Dh / THREADS_PER_VALUE; constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type; using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
int vo = tid / THREADS_PER_VALUE; int vo = tid / THREADS_PER_VALUE;
...@@ -747,7 +784,7 @@ __global__ void masked_multihead_attention_kernel( ...@@ -747,7 +784,7 @@ __global__ void masked_multihead_attention_kernel(
zero(out); zero(out);
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
if (vo < V_PER_ITER) { if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) { for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]); V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
...@@ -770,10 +807,12 @@ __global__ void masked_multihead_attention_kernel( ...@@ -770,10 +807,12 @@ __global__ void masked_multihead_attention_kernel(
__syncthreads(); __syncthreads();
#endif #endif
if (vo == (params.timestep % V_PER_ITER)) { V_vec v_bias;
zero(v_bias);
if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
V_vec v = *reinterpret_cast<const V_vec *>( V_vec v = *reinterpret_cast<const V_vec *>(
&params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
V_vec v_bias = *reinterpret_cast<const V_vec *>( v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); &params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
v = add(v, v_bias); v = add(v, v_bias);
*reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v; *reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
...@@ -787,31 +826,31 @@ __global__ void masked_multihead_attention_kernel( ...@@ -787,31 +826,31 @@ __global__ void masked_multihead_attention_kernel(
__syncthreads(); __syncthreads();
if (vo < pad_active_groups / 2) { if (Dh == Dh_MAX || vi < Dh) {
zero(*reinterpret_cast<V_vec *>(&out_smem[vo * Dh + vi]));
}
#pragma unroll #pragma unroll
for (int active_groups = pad_active_groups; active_groups >= 2; for (int active_groups = V_PER_ITER; active_groups >= 2;
active_groups /= 2) { active_groups /= 2) {
int midpoint = active_groups / 2; int midpoint = active_groups / 2;
if (vo >= midpoint && vo < active_groups) { if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float( convert_from_float(
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]), *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
out); out);
#else #else
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out; *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif #endif
}
__syncthreads();
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
out =
add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
} }
__syncthreads();
if (vo < midpoint) {
out = add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
} }
if (vo == 0) { if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]), convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]),
out); out);
...@@ -837,7 +876,7 @@ __global__ void masked_multihead_attention_kernel( ...@@ -837,7 +876,7 @@ __global__ void masked_multihead_attention_kernel(
template <typename T> template <typename T>
inline size_t smem_size_in_bytes( inline size_t smem_size_in_bytes(
const Masked_multihead_attention_params<T> &params, int dim_head, const Masked_multihead_attention_params<T> &params, int dim_head,
int threads_per_value, int threads_per_block, int pad_active_groups) { int threads_per_value, int threads_per_block) {
size_t qk_sz = div_up(params.timestep + 1, 4) * 16; size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
size_t logits_sz = 0; size_t logits_sz = 0;
...@@ -848,27 +887,25 @@ inline size_t smem_size_in_bytes( ...@@ -848,27 +887,25 @@ inline size_t smem_size_in_bytes(
#endif #endif
size_t softmax_sz = qk_sz + logits_sz; size_t softmax_sz = qk_sz + logits_sz;
int rows_per_red = pad_active_groups; int rows_per_red = threads_per_block / threads_per_value;
size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2; size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;
return max(softmax_sz, red_sz); return max(softmax_sz, red_sz);
} }
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \ #define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, stream) \ THDS_PER_BLOCK, stream) \
int pad_active_groups = \ size_t smem_sz = \
1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \ smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
size_t smem_sz = smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, \ dim3 grid(params.num_head, params.batch_size); \
THDS_PER_BLOCK, pad_active_groups); \ masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, \
dim3 grid(params.num_head, params.batch_size); \ THDS_PER_VALUE, THDS_PER_BLOCK> \
masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, \ <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
THDS_PER_VALUE, THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params, pad_active_groups)
template <typename T, int Dh, int Dh_MAX> template <typename T, int Dh, int Dh_MAX>
void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params, void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
const cudaStream_t &stream) { const cudaStream_t &stream) {
constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16; constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
if (params.timestep < 32) { if (params.timestep < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
} else if (params.timestep < 2048) { } else if (params.timestep < 2048) {
...@@ -898,6 +935,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor, ...@@ -898,6 +935,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
params.inv_sqrt_dh = inv_sqrt_dh; params.inv_sqrt_dh = inv_sqrt_dh;
switch (dim_head) { switch (dim_head) {
case 10:
fmha_launch_kernel<T, 10, 32>(params, dev_ctx.stream());
break;
case 26:
fmha_launch_kernel<T, 26, 32>(params, dev_ctx.stream());
break;
case 32: case 32:
fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream()); fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream());
break; break;
...@@ -910,11 +953,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor, ...@@ -910,11 +953,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
case 128: case 128:
fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream()); fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream());
break; break;
case 192:
fmha_launch_kernel<T, 192, 256>(params, dev_ctx.stream());
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"dim_head = %d is unsupport, only support " "Dim_head = %d is unsupport!", dim_head));
"dim_head = 32, 64, 96 or 128 for now.",
dim_head));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册