未验证 提交 03f9e598 编写于 作者: Z Zhang Zheng 提交者: GitHub

Support more dimensions in MMHA (#43612)

* support more dimensions

* fix
上级 2ddbc647
......@@ -114,9 +114,11 @@ template <typename T, int Dh> struct Qk_vec_ {};
template <> struct Qk_vec_<float, 32> { using Type = float; };
template <> struct Qk_vec_<float, 64> { using Type = float2; };
template <> struct Qk_vec_<float, 128> { using Type = float4; };
template <> struct Qk_vec_<float, 256> { using Type = float4; };
template <> struct Qk_vec_<float16, 32> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 64> { using Type = uint32_t; };
template <> struct Qk_vec_<float16, 128> { using Type = uint2; };
template <> struct Qk_vec_<float16, 256> { using Type = uint4; };
template <typename T, int THREADS_PER_KEY> struct K_vec_ {};
template <> struct K_vec_<float, 4> { using Type = float; };
......@@ -532,11 +534,11 @@ inline __device__ void zero(T &dst) { // NOLINT
template <typename T, int Dh, int Dh_MAX, int THREADS_PER_KEY,
int THREADS_PER_VALUE, int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params, int pad_active_groups) {
Masked_multihead_attention_params<T> params) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert(Dh % THREADS_PER_KEY == 0, "");
static_assert(Dh % THREADS_PER_VALUE == 0, "");
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
constexpr int WARP_SIZE = 32;
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
......@@ -552,7 +554,8 @@ __global__ void masked_multihead_attention_kernel(
T *out_smem = reinterpret_cast<T *>(smem_);
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
__shared__ T q_smem[Dh];
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
__shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
const int bi = blockIdx.y;
const int hi = blockIdx.x;
......@@ -565,10 +568,11 @@ __global__ void masked_multihead_attention_kernel(
// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
// Use block reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte
......@@ -584,13 +588,29 @@ __global__ void masked_multihead_attention_kernel(
int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
Qk_vec q = *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset]);
Qk_vec k = *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset]);
Qk_vec q_bias =
*reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset]);
Qk_vec k_bias =
*reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset]);
Qk_vec q;
zero(q);
q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset])
: q;
Qk_vec k;
zero(k);
k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
: k;
Qk_vec q_bias;
zero(q_bias);
q_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
: q_bias;
Qk_vec k_bias;
zero(k_bias);
k_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
: k_bias;
q = add(q, q_bias);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
......@@ -604,24 +624,33 @@ __global__ void masked_multihead_attention_kernel(
int offset = bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
params.timestep * QK_ELTS_IN_16B + ci;
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
}
qk = dot<Qk_vec, Qk_vec>(q, k);
}
if (tid < WARP_SIZE) {
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
if (tid == 0) {
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
}
}
}
if (QK_VECS_PER_WARP > WARP_SIZE) {
constexpr int WARPS_PER_RED =
(QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
}
if (tid == 0) {
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
}
__syncthreads();
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
......@@ -635,13 +664,15 @@ __global__ void masked_multihead_attention_kernel(
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
static_assert(Dh % K_VEC_SIZE == 0, "");
constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY;
static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
int ko = tid / THREADS_PER_KEY;
int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE;
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, "");
K_vec q[K_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < K_VECS_PER_THREAD; ++i) {
......@@ -657,11 +688,17 @@ __global__ void masked_multihead_attention_kernel(
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
K_vec k[K_VECS_PER_THREAD];
K_vec k_vec_zero;
zero(k_vec_zero);
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_seq_length + ti;
if (ti < params.timestep) {
k[ii] = *reinterpret_cast<const K_vec *>(&k_cache[jj * QK_ELTS_IN_16B]);
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
&k_cache[jj * QK_ELTS_IN_16B])
: k_vec_zero;
}
}
......@@ -727,7 +764,7 @@ __global__ void masked_multihead_attention_kernel(
}
__syncthreads();
constexpr int V_VEC_SIZE = Dh / THREADS_PER_VALUE;
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
int vo = tid / THREADS_PER_VALUE;
......@@ -747,7 +784,7 @@ __global__ void masked_multihead_attention_kernel(
zero(out);
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
if (vo < V_PER_ITER) {
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
......@@ -770,10 +807,12 @@ __global__ void masked_multihead_attention_kernel(
__syncthreads();
#endif
if (vo == (params.timestep % V_PER_ITER)) {
V_vec v_bias;
zero(v_bias);
if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
V_vec v = *reinterpret_cast<const V_vec *>(
&params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
V_vec v_bias = *reinterpret_cast<const V_vec *>(
v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
v = add(v, v_bias);
*reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
......@@ -787,31 +826,31 @@ __global__ void masked_multihead_attention_kernel(
__syncthreads();
if (vo < pad_active_groups / 2) {
zero(*reinterpret_cast<V_vec *>(&out_smem[vo * Dh + vi]));
}
if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
for (int active_groups = pad_active_groups; active_groups >= 2;
active_groups /= 2) {
int midpoint = active_groups / 2;
for (int active_groups = V_PER_ITER; active_groups >= 2;
active_groups /= 2) {
int midpoint = active_groups / 2;
if (vo >= midpoint && vo < active_groups) {
if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
out);
convert_from_float(
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
out);
#else
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
*reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
}
__syncthreads();
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
out =
add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
}
__syncthreads();
if (vo < midpoint) {
out = add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
}
if (vo == 0) {
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]),
out);
......@@ -837,7 +876,7 @@ __global__ void masked_multihead_attention_kernel(
template <typename T>
inline size_t smem_size_in_bytes(
const Masked_multihead_attention_params<T> &params, int dim_head,
int threads_per_value, int threads_per_block, int pad_active_groups) {
int threads_per_value, int threads_per_block) {
size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
size_t logits_sz = 0;
......@@ -848,27 +887,25 @@ inline size_t smem_size_in_bytes(
#endif
size_t softmax_sz = qk_sz + logits_sz;
int rows_per_red = pad_active_groups;
int rows_per_red = threads_per_block / threads_per_value;
size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;
return max(softmax_sz, red_sz);
}
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, stream) \
int pad_active_groups = \
1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
size_t smem_sz = smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, \
THDS_PER_BLOCK, pad_active_groups); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, \
THDS_PER_VALUE, THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params, pad_active_groups)
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, stream) \
size_t smem_sz = \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_head, params.batch_size); \
masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, \
THDS_PER_VALUE, THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
template <typename T, int Dh, int Dh_MAX>
void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
const cudaStream_t &stream) {
constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
if (params.timestep < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
} else if (params.timestep < 2048) {
......@@ -898,6 +935,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
params.inv_sqrt_dh = inv_sqrt_dh;
switch (dim_head) {
case 10:
fmha_launch_kernel<T, 10, 32>(params, dev_ctx.stream());
break;
case 26:
fmha_launch_kernel<T, 26, 32>(params, dev_ctx.stream());
break;
case 32:
fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream());
break;
......@@ -910,11 +953,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
case 128:
fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream());
break;
case 192:
fmha_launch_kernel<T, 192, 256>(params, dev_ctx.stream());
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"dim_head = %d is unsupport, only support "
"dim_head = 32, 64, 96 or 128 for now.",
dim_head));
"Dim_head = %d is unsupport!", dim_head));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册