Unverified commit 53df50c7, authored by lzy, committed by GitHub

make FusedMultiTransformer support variable lengths. (#49560)

* make FusedMultiTransformer support variable lengths.

* modify ffn2 when cuda_version >= 11.6 because of #49392.

* code style

* delete remove_padding
Parent: a3989b5e
@@ -62,6 +62,78 @@ class AttnDropoutParam {
const phi::DenseTensor* seed_;
};
template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int elem_cnt,
const int* padding_offset) {
// transpose and remove padding
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
const int dim_embed = num_head * head_dim;
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
for (int32_t linear_index = idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / dim_embed;
const int ori_token_idx =
token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int ori_batch_id = ori_token_idx / seq_len;
const int ori_seq_id = ori_token_idx % seq_len;
const int ori_head_id = (linear_index % dim_embed) / head_dim;
const int ori_head_lane = (linear_index % dim_embed) % head_dim;
const int ori_idx = ori_batch_id * num_head * seq_len * head_dim +
ori_head_id * seq_len * head_dim +
ori_seq_id * head_dim + ori_head_lane;
phi::Load<T, VecSize>(&input_data[ori_idx], &src_vec);
phi::Store<T, VecSize>(src_vec, &output_data[linear_index]);
}
}
template <typename T>
void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx,
const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int* padding_offset) {
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
constexpr int VEC_16B = 16;
const int elem_cnt = token_num * num_head * head_dim;
constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(
head_dim % PackSize,
0,
platform::errors::PreconditionNotMet(
"dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize));
const int32_t pack_num = elem_cnt / PackSize;
const int32_t block_size = 128;
int32_t grid_size = (pack_num + block_size - 1) / block_size;
TransposeRemovingPadding<T, PackSize>
<<<grid_size, block_size, 0, dev_ctx.stream()>>>(input_data,
output_data,
batch_size,
num_head,
seq_len,
head_dim,
token_num,
elem_cnt,
padding_offset);
}
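The vectorized index arithmetic above can be hard to follow. Below is a minimal host-side sketch of the same mapping, assuming the layouts stated in the comments (src is the padded [batch_size, num_head, seq_len, head_dim] buffer, dst is the compacted [token_num, num_head, head_dim] buffer, and padding_offset[i] is the number of padding slots preceding compact token i); the function and variable names are illustrative and not part of the patch.

#include <cstddef>
#include <vector>

// Scalar reference for TransposeRemovingPadding: walk the compact tokens,
// recover each token's (batch, seq) position in the padded layout via
// padding_offset, and gather every head's slice for that token.
template <typename T>
void TransposeRemovePaddingRef(const std::vector<T> &src,
                               std::vector<T> *dst,
                               int num_head,
                               int seq_len,
                               int head_dim,
                               int token_num,
                               const std::vector<int> &padding_offset) {
  const int dim_embed = num_head * head_dim;
  dst->assign(static_cast<std::size_t>(token_num) * dim_embed, T{});
  for (int token = 0; token < token_num; ++token) {
    const int ori_token = token + padding_offset[token];
    const int b = ori_token / seq_len;  // batch id in the padded layout
    const int s = ori_token % seq_len;  // position within that sequence
    for (int h = 0; h < num_head; ++h) {
      for (int d = 0; d < head_dim; ++d) {
        const std::size_t src_idx =
            ((static_cast<std::size_t>(b) * num_head + h) * seq_len + s) *
                head_dim +
            d;
        (*dst)[static_cast<std::size_t>(token) * dim_embed + h * head_dim + d] =
            src[src_idx];
      }
    }
  }
}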
template <typename T>
class FMHARef {
public:
@@ -256,18 +328,21 @@ class FMHARef {
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
}
void ComputeForwardWithoutTranspose(const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
phi::DenseTensor* q_transpose_out_tensor,
phi::DenseTensor* kv_transpose_out_tensor,
phi::DenseTensor* cache_kv_out_tensor,
phi::DenseTensor* qk_out_tensor,
phi::DenseTensor* src_mask_out_tensor,
phi::DenseTensor* softmax_out_tensor,
phi::DenseTensor* dropout_mask_out_tensor,
phi::DenseTensor* dropout_out_tensor,
phi::DenseTensor* qktv_out_tensor,
phi::DenseTensor* fmha_out_tensor) {
void ComputeForwardWithoutTranspose(
const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
const phi::DenseTensor* padding_offset_tensor,
phi::DenseTensor* q_transpose_out_tensor,
phi::DenseTensor* kv_transpose_out_tensor,
phi::DenseTensor* cache_kv_out_tensor,
phi::DenseTensor* qk_out_tensor,
phi::DenseTensor* src_mask_out_tensor,
phi::DenseTensor* softmax_out_tensor,
phi::DenseTensor* dropout_mask_out_tensor,
phi::DenseTensor* dropout_out_tensor,
phi::DenseTensor* qktv_out_tensor,
phi::DenseTensor* fmha_out_tensor,
const int token_num) {
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0, 3, 1, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
@@ -424,9 +499,21 @@ class FMHARef {
}
// transpose: [0, 2, 1, 3]
// output shape: [batch_size, seq_len, num_heads, head_dim]
std::vector<int> perm_3 = {0, 2, 1, 3};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
if (!padding_offset_tensor) {
std::vector<int> perm_3 = {0, 2, 1, 3};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
} else {
InvokeTransposeRemovePadding<T>(dev_ctx_,
qktv_out_data,
fmha_out_data,
batch_size_,
num_head_,
seq_len_,
head_dim_,
token_num,
padding_offset_tensor->data<int>());
}
}
void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor,
@@ -182,6 +182,8 @@ class FusedMultiTransformerOpOpMaker
AddInput("TimeStep",
"(optional, int) The time step for generation inference.")
.AsDispensable();
AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.")
.AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
@@ -127,6 +127,8 @@ struct Masked_multihead_attention_params {
// v [B, num_head, max_seq_len, dim_head]
T *cache_kv;
const int *sequence_lengths{nullptr};
// The RoPE embedding, [B, 1, 1, dim_head]
// rotary_emb_dims = 1 if pos_ids_extra is null else 2
const T *rotary_emb;
@@ -769,6 +771,10 @@ __global__ void masked_multihead_attention_kernel(
float qk_max = -FLT_MAX;
float qk = 0;
int act_time_step = params.sequence_lengths == nullptr
? params.timestep
: params.sequence_lengths[bi];
// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
@@ -888,7 +894,7 @@ __global__ void masked_multihead_attention_kernel(
int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
int offset = bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
params.timestep * QK_ELTS_IN_16B + ci;
act_time_step * QK_ELTS_IN_16B + ci;
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
}
@@ -914,7 +920,7 @@ __global__ void masked_multihead_attention_kernel(
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
qk_smem[act_time_step] = qk;
}
__syncthreads();
@@ -949,7 +955,7 @@ __global__ void masked_multihead_attention_kernel(
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
K_vec k[K_VECS_PER_THREAD];
@@ -958,7 +964,7 @@ __global__ void masked_multihead_attention_kernel(
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_seq_length + ti;
if (ti < params.timestep) {
if (ti < act_time_step) {
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
@@ -972,7 +978,7 @@ __global__ void masked_multihead_attention_kernel(
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);
// bool is_mask = false;
if (ti < params.timestep && tid % THREADS_PER_KEY == 0) {
if (ti < act_time_step && tid % THREADS_PER_KEY == 0) {
// qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
T mask = params.attn_mask[bi * (params.timestep + 1) + ti];
qk += static_cast<float>(mask);
@@ -1014,7 +1020,7 @@ __global__ void masked_multihead_attention_kernel(
#endif
float sum = 0.f;
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
// bool is_mask = false;
// float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
float logit = __expf(qk_smem[ti] - qk_max);
@@ -1026,7 +1032,7 @@ __global__ void masked_multihead_attention_kernel(
// FIXME(wangxi): need add 1.e-6f?
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
}
__syncthreads();
@@ -1052,7 +1058,7 @@ __global__ void masked_multihead_attention_kernel(
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti];
@@ -1076,18 +1082,18 @@ __global__ void masked_multihead_attention_kernel(
V_vec v_bias;
zero(v_bias);
if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
V_vec v = *reinterpret_cast<const V_vec *>(
&params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
v = add(v, v_bias);
*reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
*reinterpret_cast<V_vec *>(&v_cache[act_time_step * Dh]) = v;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
out = fma(logits_smem[params.timestep], cast_to_float(v), out);
out = fma(logits_smem[act_time_step], cast_to_float(v), out);
#else
out = fma(logits_smem[params.timestep], v, out);
out = fma(logits_smem[act_time_step], v, out);
#endif
}
@@ -1198,6 +1204,7 @@ void fmha(const phi::GPUContext &dev_ctx,
const phi::DenseTensor &qkv_tensor,
const phi::DenseTensor &qkv_bias_tensor,
const phi::DenseTensor &src_mask_tensor,
const phi::DenseTensor *sequence_lengths_tensor,
const phi::DenseTensor *rotary_tensor,
phi::DenseTensor *cache_kv_tensor,
phi::DenseTensor *out_tensor,
@@ -1214,6 +1221,11 @@ void fmha(const phi::GPUContext &dev_ctx,
params.qkv_bias = qkv_bias_tensor.data<T>();
params.attn_mask = src_mask_tensor.data<T>();
params.cache_kv = cache_kv_tensor->data<T>();
if (sequence_lengths_tensor) {
params.sequence_lengths = sequence_lengths_tensor->data<int>();
}
if (rotary_emb_dims > 0) {
params.rotary_emb = rotary_tensor->data<T>();
} else {
@@ -1273,6 +1285,7 @@ void fmha(const phi::GPUContext &dev_ctx,
qkv_bias_tensor,
src_mask_tensor,
nullptr,
nullptr,
cache_kv_tensor,
out_tensor,
batch_size,
@@ -1395,6 +1408,7 @@ __global__ void add_fusedQKV_bias_transpose_split_kernel(
T *kv_buf,
const T *qkv,
const T *qkv_bias,
const int *padding_offset,
const int32_t elem_cnt,
const int batch_size,
const int seq_len,
@@ -1423,10 +1437,10 @@ __global__ void add_fusedQKV_bias_transpose_split_kernel(
}
}
const int32_t token_idx = linear_index / fused_hidden_size;
// const int32_t token_padded_idx = token_idx + (padding_offset == nullptr ?
// 0 : padding_offset[token_idx]);
const int32_t target_batch_id = token_idx / seq_len;
const int32_t seq_id = token_idx % seq_len;
const int32_t ori_token_idx =
token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int32_t target_batch_id = ori_token_idx / seq_len;
const int32_t seq_id = ori_token_idx % seq_len;
// equal to:
// const int qkv_id = (linear_index % fused_hidden_size) / hidden_size;
@@ -1475,12 +1489,13 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
T *kv_buf,
const T *qkv,
const T *qkv_bias,
const int *padding_offset,
const int token_num,
const int batch_size,
const int head_num,
const int seq_len,
const int size_per_head,
bool compute_bias) {
const int32_t token_num = batch_size * seq_len;
const int32_t elem_cnt = token_num * head_num * size_per_head * 3;
constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(size_per_head % PackSize,
@@ -1499,6 +1514,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
kv_buf,
qkv,
qkv_bias,
padding_offset,
elem_cnt,
batch_size,
seq_len,
@@ -1511,6 +1527,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
kv_buf,
qkv,
qkv_bias,
padding_offset,
elem_cnt,
batch_size,
seq_len,
@@ -1524,7 +1541,9 @@ template <typename T>
__global__ void RotrayKernel(const T *input,
const T *cos_emb,
const T *sin_emb,
const int *sequence_lengths,
T *output,
const int rotary_emb_dims,
const int batch_size,
const int head_num,
const int seq_len,
@@ -1532,6 +1551,7 @@ __global__ void RotrayKernel(const T *input,
int bi = blockIdx.x;
int hi = blockIdx.y;
int si = blockIdx.z;
if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return;
int half_lastdim = last_dim / 2;
// Note(ZhenyuLi): Calculate the relevant data at one time, so that no
// additional space is required.
@@ -1559,6 +1579,7 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
const T *q_input, // q
const T *k_input, // kv
const T *rotary_emb,
const int *sequence_lengths,
const int rotary_emb_dims,
const int batch_size,
const int head_num,
@@ -1592,7 +1613,9 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
q_input,
cos_emb,
sin_emb,
sequence_lengths,
q,
rotary_emb_dims,
batch_size,
head_num,
seq_len * rotary_emb_dims,
@@ -1601,13 +1624,109 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
k_input,
cos_emb,
sin_emb,
sequence_lengths,
k,
rotary_emb_dims,
batch_size,
head_num,
seq_len * rotary_emb_dims,
last_dim);
}
__global__ void GetPaddingOffset(int *d_token_num,
int *padding_offset,
const int *sequence_lengths,
const int batch_size,
const int max_seq_len) {
// get padding offset of each batch
int total_seq_len = 0;
int cum_offset = 0;
int index = 0;
for (int i = 0; i < batch_size; i++) {
const int seq_len = sequence_lengths[i];
for (int j = 0; j < seq_len; j++) {
padding_offset[index] = cum_offset;
index++;
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
d_token_num[0] = total_seq_len;
}
void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx,
int *h_token_num,
int *d_token_num,
int *padding_offset,
const int *sequence_lengths,
const int batch_size,
const int max_seq_len) {
GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>(
d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len);
memory::Copy(platform::CPUPlace(),
h_token_num,
dev_ctx.GetPlace(),
d_token_num,
sizeof(int),
dev_ctx.stream());
}
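As a sanity check on the offset arithmetic, here is a host-side replica of the single-thread GetPaddingOffset kernel together with a tiny worked example; the example values are illustrative, not taken from the patch. For sequence_lengths = {2, 3, 1} and max_seq_len = 4 it yields token_num = 6 and padding_offset = {0, 0, 2, 2, 2, 3}, so compact token i sits at padded position i + padding_offset[i].

#include <cstddef>
#include <cstdio>
#include <vector>

// Host replica of GetPaddingOffset: accumulate, per token, how many padding
// slots precede it in the [batch_size, max_seq_len] layout.
void GetPaddingOffsetRef(int *token_num,
                         std::vector<int> *padding_offset,
                         const std::vector<int> &sequence_lengths,
                         int max_seq_len) {
  int total_seq_len = 0;
  int cum_offset = 0;
  int index = 0;
  for (std::size_t i = 0; i < sequence_lengths.size(); ++i) {
    const int seq_len = sequence_lengths[i];
    for (int j = 0; j < seq_len; ++j) {
      (*padding_offset)[index++] = cum_offset;
    }
    cum_offset += max_seq_len - seq_len;
    total_seq_len += seq_len;
  }
  *token_num = total_seq_len;
}

int main() {
  const std::vector<int> sequence_lengths = {2, 3, 1};
  const int max_seq_len = 4;
  std::vector<int> padding_offset(sequence_lengths.size() * max_seq_len, 0);
  int token_num = 0;
  GetPaddingOffsetRef(&token_num, &padding_offset, sequence_lengths, max_seq_len);
  std::printf("token_num = %d\n", token_num);  // prints 6
  for (int i = 0; i < token_num; ++i) {
    std::printf("%d ", padding_offset[i]);     // prints 0 0 2 2 2 3
  }
  std::printf("\n");
  return 0;
}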
template <typename T>
__global__ void RemovePadding(T *output_data,
const T *input_data,
const int *padding_offset,
const int dim_embed) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int src_seq_id = bid + padding_offset[bid];
const int tgt_seq_id = bid;
for (int i = tid; i < dim_embed; i += blockDim.x) {
output_data[tgt_seq_id * dim_embed + i] =
input_data[src_seq_id * dim_embed + i];
}
}
template <typename T>
void InvokeRemovePadding(const phi::GPUContext &dev_ctx,
T *output_data,
const T *input_data,
const int *padding_offset,
const int token_num,
const int dim_embed) {
RemovePadding<<<token_num, 256, 0, dev_ctx.stream()>>>(
output_data, input_data, padding_offset, dim_embed);
}
template <typename T>
__global__ void RebuildPadding(T *output_data,
const T *input_data,
const int *padding_offset,
const int dim_embed) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int dst_seq_id = bid + padding_offset[bid];
const int src_seq_id = bid;
for (int i = tid; i < dim_embed; i += blockDim.x) {
output_data[dst_seq_id * dim_embed + i] =
input_data[src_seq_id * dim_embed + i];
}
}
template <typename T>
void InvokeRebuildPadding(const phi::GPUContext &dev_ctx,
T *output_data,
const T *input_data,
const int *padding_offset,
const int token_num,
const int dim_embed) {
// src: [token_num, dim_embed]
// dst: [batch_size * max_seq_len, dim_embed]
RebuildPadding<<<token_num, 256, 0, dev_ctx.stream()>>>(
output_data, input_data, padding_offset, dim_embed);
}
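RemovePadding and RebuildPadding are inverses of each other on the valid tokens: one gathers the valid rows of the padded [batch_size * max_seq_len, dim_embed] buffer into a compact [token_num, dim_embed] buffer, the other scatters them back, leaving rows that were padding untouched. A minimal host-side sketch of the pair, assuming those layouts; the names are illustrative, not part of the patch.

#include <cstddef>
#include <vector>

// Gather: copy padded row (t + padding_offset[t]) into compact row t.
template <typename T>
void RemovePaddingRef(std::vector<T> *dst,
                      const std::vector<T> &src,
                      const std::vector<int> &padding_offset,
                      int token_num,
                      int dim_embed) {
  dst->assign(static_cast<std::size_t>(token_num) * dim_embed, T{});
  for (int t = 0; t < token_num; ++t) {
    const int src_row = t + padding_offset[t];
    for (int i = 0; i < dim_embed; ++i) {
      (*dst)[static_cast<std::size_t>(t) * dim_embed + i] =
          src[static_cast<std::size_t>(src_row) * dim_embed + i];
    }
  }
}

// Scatter: write compact row t back to padded row (t + padding_offset[t]);
// rows that correspond to padding are left as they were.
template <typename T>
void RebuildPaddingRef(std::vector<T> *dst,
                       const std::vector<T> &src,
                       const std::vector<int> &padding_offset,
                       int token_num,
                       int dim_embed) {
  for (int t = 0; t < token_num; ++t) {
    const int dst_row = t + padding_offset[t];
    for (int i = 0; i < dim_embed; ++i) {
      (*dst)[static_cast<std::size_t>(dst_row) * dim_embed + i] =
          src[static_cast<std::size_t>(t) * dim_embed + i];
    }
  }
}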
#if CUDA_VERSION >= 11060
// Only Used in Inference
template <typename T>
@@ -64,6 +64,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"PreCaches",
"RotaryPosEmb",
"TimeStep",
"SeqLengths",
"SrcMask",
"OutLinearW",
"OutLinearBias",
@@ -887,6 +887,7 @@ def fused_multi_transformer(
epsilon=1e-05,
cache_kvs=None,
pre_caches=None,
seq_lens=None,
rotary_embs=None,
time_step=None,
attn_mask=None,
@@ -956,6 +957,7 @@ def fused_multi_transformer(
epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
seq_lens (Tensor, optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None.
rotary_embs (Tensor, optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation model. It is used in the decode stage to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevent attention to
@@ -1055,6 +1057,7 @@ def fused_multi_transformer(
pre_caches,
rotary_embs,
time_step,
seq_lens,
attn_mask,
linear_weights,
linear_biases,
@@ -1115,6 +1118,7 @@ def fused_multi_transformer(
inputs['PreCaches'] = pre_caches
if rotary_emb_dims > 0:
inputs['RotaryPosEmb'] = rotary_embs
inputs['SeqLengths'] = seq_lens
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = linear_weights
if linear_biases is not None:
@@ -1382,6 +1382,7 @@ class FusedMultiTransformer(Layer):
pre_caches=None,
rotary_embs=None,
rotary_emb_dims=0,
seq_lens=None,
time_step=None,
):
r"""
@@ -1406,6 +1407,7 @@
rotary_embs (Tensor, optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
rotary_emb_dims (int, optional): The rotary_emb_dims of the rotary computation: 0 when rotary_embs is None,
1 when rotary_embs is not None and pos_extra_ids is None, and 2 when rotary_embs and pos_extra_ids are both not None. Default 0.
seq_lens (Tensor, optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation
model. It is used in the decode stage to represent the time step,
that is, the real seq_len of CacheKV. The shape is `[1]`, must be
@@ -1441,6 +1443,7 @@
pre_caches=pre_caches,
rotary_embs=rotary_embs,
time_step=time_step,
seq_lens=seq_lens,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
rotary_emb_dims=rotary_emb_dims,