Unverified · Commit 53df50c7 authored by lzy, committed by GitHub

Make FusedMultiTransformer support variable lengths. (#49560)

* Make FusedMultiTransformer support variable lengths.

* modify ffn2 when cuda_version >= 11.6 because of #49392.

* code style

* delete remove_padding
Parent a3989b5e
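The change lets the fused op consume a packed [token_num, hidden] layout instead of a padded [bsz, max_seq_len, hidden] one whenever per-example sequence lengths are supplied. As a hedged, illustrative sketch (the names input_ids and pad_id are not part of this PR), the new SeqLengths input and the packed token count could be derived from a right-padded batch like this:

import numpy as np

# Hypothetical caller-side pre-processing; only seq_lens is consumed by the op.
pad_id = 0
input_ids = np.array([[11, 12, 13, pad_id],
                      [21, pad_id, pad_id, pad_id],
                      [31, 32, pad_id, pad_id]])               # [bsz=3, max_seq_len=4]
seq_lens = (input_ids != pad_id).sum(axis=1).astype("int32")   # [3, 1, 2]
token_num = int(seq_lens.sum())                                # 6 real tokens instead of 12 padded slots
print(seq_lens, token_num)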
@@ -62,6 +62,78 @@ class AttnDropoutParam {
  const phi::DenseTensor* seed_;
};
template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int elem_cnt,
const int* padding_offset) {
// transpose and remove padding
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
const int dim_embed = num_head * head_dim;
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
for (int32_t linear_index = idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / dim_embed;
const int ori_token_idx =
token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int ori_batch_id = ori_token_idx / seq_len;
const int ori_seq_id = ori_token_idx % seq_len;
const int ori_head_id = (linear_index % dim_embed) / head_dim;
const int ori_head_lane = (linear_index % dim_embed) % head_dim;
const int ori_idx = ori_batch_id * num_head * seq_len * head_dim +
ori_head_id * seq_len * head_dim +
ori_seq_id * head_dim + ori_head_lane;
phi::Load<T, VecSize>(&input_data[ori_idx], &src_vec);
phi::Store<T, VecSize>(src_vec, &output_data[linear_index]);
}
}
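For reference, a minimal NumPy sketch of the index mapping this kernel implements: gathering from the padded [batch_size, num_head, seq_len, head_dim] tensor into the packed [token_num, num_head, head_dim] layout. It assumes padding_offset[i] is the number of pad slots that precede packed token i, which is what GetPaddingOffset below produces.

import numpy as np

def transpose_removing_padding_ref(x, padding_offset):
    # x: [batch_size, num_head, seq_len, head_dim] (padded)
    # returns: [token_num, num_head, head_dim] (packed, pad rows dropped)
    _, num_head, seq_len, head_dim = x.shape
    token_num = len(padding_offset)
    out = np.empty((token_num, num_head, head_dim), dtype=x.dtype)
    for token_idx in range(token_num):
        ori_token_idx = token_idx + padding_offset[token_idx]   # position in the padded layout
        b, s = divmod(int(ori_token_idx), seq_len)
        out[token_idx] = x[b, :, s, :]
    return out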
template <typename T>
void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx,
const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int* padding_offset) {
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
constexpr int VEC_16B = 16;
const int elem_cnt = token_num * num_head * head_dim;
constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(
head_dim % PackSize,
0,
platform::errors::PreconditionNotMet(
"dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize));
const int32_t pack_num = elem_cnt / PackSize;
const int32_t block_size = 128;
int32_t grid_size = (pack_num + block_size - 1) / block_size;
TransposeRemovingPadding<T, PackSize>
<<<grid_size, block_size, 0, dev_ctx.stream()>>>(input_data,
output_data,
batch_size,
num_head,
seq_len,
head_dim,
token_num,
elem_cnt,
padding_offset);
}
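The launcher moves 16 bytes per vectorized access, so PackSize is 8 elements for fp16 (4 for fp32), and head_dim must be a multiple of it. A small sketch of the same launch arithmetic with illustrative sizes:

# Illustrative launch-parameter arithmetic mirroring InvokeTransposeRemovePadding.
VEC_16B = 16
sizeof_t = 2                                   # fp16
pack_size = VEC_16B // sizeof_t                # 8 elements per 16-byte access
token_num, num_head, head_dim = 6, 8, 64
assert head_dim % pack_size == 0               # mirrors the PADDLE_ENFORCE_EQ check
elem_cnt = token_num * num_head * head_dim
pack_num = elem_cnt // pack_size
block_size = 128
grid_size = (pack_num + block_size - 1) // block_size   # ceil division
print(pack_num, grid_size)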
template <typename T>
class FMHARef {
 public:
@@ -256,18 +328,21 @@ class FMHARef {
        dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
  }
  void ComputeForwardWithoutTranspose(
      const phi::DenseTensor* cache_kv_tensor,
      const phi::DenseTensor* src_mask_tensor,
      const phi::DenseTensor* padding_offset_tensor,
      phi::DenseTensor* q_transpose_out_tensor,
      phi::DenseTensor* kv_transpose_out_tensor,
      phi::DenseTensor* cache_kv_out_tensor,
      phi::DenseTensor* qk_out_tensor,
      phi::DenseTensor* src_mask_out_tensor,
      phi::DenseTensor* softmax_out_tensor,
      phi::DenseTensor* dropout_mask_out_tensor,
      phi::DenseTensor* dropout_out_tensor,
      phi::DenseTensor* qktv_out_tensor,
      phi::DenseTensor* fmha_out_tensor,
      const int token_num) {
    // input shape: [bs, seq_len, 3, num_head, head_dim]
    // transpose with perm [2, 0, 3, 1, 4],
    // output_shape: [3, bs, num_head, seq_len, head_dim]
@@ -424,9 +499,21 @@ class FMHARef {
    }
    // transpose: [0, 2, 1, 3]
    // output shape: [batch_size, seq_len, num_heads, head_dim]
    if (!padding_offset_tensor) {
      std::vector<int> perm_3 = {0, 2, 1, 3};
      phi::funcs::TransposeGPUKernelDriver<T>(
          dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
} else {
InvokeTransposeRemovePadding<T>(dev_ctx_,
qktv_out_data,
fmha_out_data,
batch_size_,
num_head_,
seq_len_,
head_dim_,
token_num,
padding_offset_tensor->data<int>());
}
  }
  void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor,
......
@@ -182,6 +182,8 @@ class FusedMultiTransformerOpOpMaker
    AddInput("TimeStep",
             "(optional, int) The time step for generation inference.")
        .AsDispensable();
AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.")
.AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable(); .AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
......
@@ -127,6 +127,8 @@ struct Masked_multihead_attention_params {
  // v [B, num_head, max_seq_len, dim_head]
  T *cache_kv;
const int *sequence_lengths{nullptr};
  // The RoPE embedding, [B, 1, 1, dim_head]
  // rotary_emb_dims = 1 if pos_ids_extra is null else 2
  const T *rotary_emb;
@@ -769,6 +771,10 @@ __global__ void masked_multihead_attention_kernel(
  float qk_max = -FLT_MAX;
  float qk = 0;
int act_time_step = params.sequence_lengths == nullptr
? params.timestep
: params.sequence_lengths[bi];
  // qkv [B, S=1, 3, num_head, head_dim]
  int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
@@ -888,7 +894,7 @@ __global__ void masked_multihead_attention_kernel(
    int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
    int offset = bhi * params.max_seq_length * Dh +
                 co * params.max_seq_length * QK_ELTS_IN_16B +
                 act_time_step * QK_ELTS_IN_16B + ci;
    if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
      *reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
    }
@@ -914,7 +920,7 @@ __global__ void masked_multihead_attention_kernel(
    // qk += static_cast<float>(mask);
    qk *= params.inv_sqrt_dh;
    qk_max = qk;
    qk_smem[act_time_step] = qk;
  }
  __syncthreads();
@@ -949,7 +955,7 @@ __global__ void masked_multihead_attention_kernel(
  constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
  T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
  int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP;
  for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
    K_vec k[K_VECS_PER_THREAD];
@@ -958,7 +964,7 @@ __global__ void masked_multihead_attention_kernel(
#pragma unroll
    for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
      int jj = ii * params.max_seq_length + ti;
      if (ti < act_time_step) {
        k[ii] =
            (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
                ? *reinterpret_cast<const K_vec *>(
@@ -972,7 +978,7 @@ __global__ void masked_multihead_attention_kernel(
    float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);
    // bool is_mask = false;
    if (ti < act_time_step && tid % THREADS_PER_KEY == 0) {
      // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
      T mask = params.attn_mask[bi * (params.timestep + 1) + ti];
      qk += static_cast<float>(mask);
@@ -1014,7 +1020,7 @@ __global__ void masked_multihead_attention_kernel(
#endif
  float sum = 0.f;
  for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
    // bool is_mask = false;
    // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
    float logit = __expf(qk_smem[ti] - qk_max);
@@ -1026,7 +1032,7 @@ __global__ void masked_multihead_attention_kernel(
  // FIXME(wangxi): need add 1.e-6f?
  float inv_sum = __fdividef(1.f, sum + 1.e-6f);
  for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
    convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
  }
  __syncthreads();
@@ -1052,7 +1058,7 @@ __global__ void masked_multihead_attention_kernel(
  constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
  if (Dh == Dh_MAX || vi < Dh) {
    for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) {
      V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
      float logit = logits_smem[ti];
@@ -1076,18 +1082,18 @@ __global__ void masked_multihead_attention_kernel(
  V_vec v_bias;
  zero(v_bias);
  if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
    V_vec v = *reinterpret_cast<const V_vec *>(
        &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
    v_bias = *reinterpret_cast<const V_vec *>(
        &params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
    v = add(v, v_bias);
    *reinterpret_cast<V_vec *>(&v_cache[act_time_step * Dh]) = v;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
    out = fma(logits_smem[act_time_step], cast_to_float(v), out);
#else
    out = fma(logits_smem[act_time_step], v, out);
#endif
  }
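Across this kernel the substantive change is uniform: every loop bound that used to be the global params.timestep now uses the per-batch act_time_step, so each sequence attends over exactly its own number of cached keys/values. A hedged single-head NumPy sketch of that per-batch span (simplified: no attention mask, no vectorization, and the current step's freshly appended K/V is assumed to already sit in the cache):

import numpy as np

def decode_attention_ref(q, k_cache, v_cache, seq_lens, timestep):
    # q: [bsz, head_dim]; k_cache, v_cache: [bsz, max_seq_len, head_dim]
    # seq_lens plays the role of act_time_step; None falls back to the shared timestep.
    bsz, head_dim = q.shape
    out = np.zeros_like(q)
    for b in range(bsz):
        steps = timestep if seq_lens is None else int(seq_lens[b])   # assumes steps >= 1
        scores = k_cache[b, :steps] @ q[b] / np.sqrt(head_dim)
        probs = np.exp(scores - scores.max())
        probs /= probs.sum()
        out[b] = probs @ v_cache[b, :steps]
    return out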
@@ -1198,6 +1204,7 @@ void fmha(const phi::GPUContext &dev_ctx,
          const phi::DenseTensor &qkv_tensor,
          const phi::DenseTensor &qkv_bias_tensor,
          const phi::DenseTensor &src_mask_tensor,
const phi::DenseTensor *sequence_lengths_tensor,
          const phi::DenseTensor *rotary_tensor,
          phi::DenseTensor *cache_kv_tensor,
          phi::DenseTensor *out_tensor,
@@ -1214,6 +1221,11 @@ void fmha(const phi::GPUContext &dev_ctx,
  params.qkv_bias = qkv_bias_tensor.data<T>();
  params.attn_mask = src_mask_tensor.data<T>();
  params.cache_kv = cache_kv_tensor->data<T>();
if (sequence_lengths_tensor) {
params.sequence_lengths = sequence_lengths_tensor->data<int>();
}
  if (rotary_emb_dims > 0) {
    params.rotary_emb = rotary_tensor->data<T>();
  } else {
@@ -1273,6 +1285,7 @@ void fmha(const phi::GPUContext &dev_ctx,
       qkv_bias_tensor,
       src_mask_tensor,
       nullptr,
nullptr,
       cache_kv_tensor,
       out_tensor,
       batch_size,
@@ -1395,6 +1408,7 @@ __global__ void add_fusedQKV_bias_transpose_split_kernel(
    T *kv_buf,
    const T *qkv,
    const T *qkv_bias,
const int *padding_offset,
    const int32_t elem_cnt,
    const int batch_size,
    const int seq_len,
@@ -1423,10 +1437,10 @@ __global__ void add_fusedQKV_bias_transpose_split_kernel(
      }
    }
    const int32_t token_idx = linear_index / fused_hidden_size;
    const int32_t ori_token_idx =
        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
    const int32_t target_batch_id = ori_token_idx / seq_len;
    const int32_t seq_id = ori_token_idx % seq_len;
    // equal to:
    // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size;
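The only change in this kernel: the packed token index is first mapped back to its original padded position before the batch and sequence ids are derived. A tiny sketch of that mapping, using the padding_offset produced for seq_lens = [3, 1, 2] with max_seq_len = 4:

# Illustrative: recover (batch_id, seq_id) for each packed token.
seq_len = 4                              # max_seq_len of the padded layout
padding_offset = [0, 0, 0, 1, 4, 4]      # from GetPaddingOffset for seq_lens = [3, 1, 2]
for token_idx, off in enumerate(padding_offset):
    ori_token_idx = token_idx + off
    batch_id, seq_id = divmod(ori_token_idx, seq_len)
    print(token_idx, "->", (batch_id, seq_id))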
@@ -1475,12 +1489,13 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
    T *kv_buf,
    const T *qkv,
    const T *qkv_bias,
const int *padding_offset,
const int token_num,
    const int batch_size,
    const int head_num,
    const int seq_len,
    const int size_per_head,
    bool compute_bias) {
  const int32_t elem_cnt = token_num * head_num * size_per_head * 3;
  constexpr int PackSize = VEC_16B / sizeof(T);
  PADDLE_ENFORCE_EQ(size_per_head % PackSize,
@@ -1499,6 +1514,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
            kv_buf,
            qkv,
            qkv_bias,
padding_offset,
            elem_cnt,
            batch_size,
            seq_len,
@@ -1511,6 +1527,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
            kv_buf,
            qkv,
            qkv_bias,
padding_offset,
            elem_cnt,
            batch_size,
            seq_len,
@@ -1524,7 +1541,9 @@ template <typename T>
__global__ void RotrayKernel(const T *input,
                             const T *cos_emb,
                             const T *sin_emb,
const int *sequence_lengths,
                             T *output,
const int rotary_emb_dims,
                             const int batch_size,
                             const int head_num,
                             const int seq_len,
@@ -1532,6 +1551,7 @@ __global__ void RotrayKernel(const T *input,
  int bi = blockIdx.x;
  int hi = blockIdx.y;
  int si = blockIdx.z;
if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return;
  int half_lastdim = last_dim / 2;
  // Note(ZhenyuLi): Calculate the relevant data at one time, so that no
  // additional space is required.
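With variable lengths the rotary kernel simply returns early for grid positions at or beyond seq_lens[bi] * rotary_emb_dims, so RoPE is applied only to real tokens. A hedged sketch of that gating; the transform below is a stand-in, not the kernel's actual rotary formula:

import numpy as np

def apply_only_to_valid(x, seq_lens, rotary_emb_dims, transform):
    # x: [bsz, head_num, seq_len * rotary_emb_dims, last_dim]
    out = x.copy()
    for b in range(x.shape[0]):
        valid = int(seq_lens[b]) * rotary_emb_dims    # mirrors the early-return bound
        out[b, :, :valid] = transform(x[b, :, :valid])
    return out

# Positions >= seq_lens[b] * rotary_emb_dims keep their original values.
x = np.random.rand(2, 4, 8, 16).astype("float32")
y = apply_only_to_valid(x, seq_lens=[3, 8], rotary_emb_dims=1, transform=lambda t: -t)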
@@ -1559,6 +1579,7 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
               const T *q_input,  // q
               const T *k_input,  // kv
               const T *rotary_emb,
const int *sequence_lengths,
               const int rotary_emb_dims,
               const int batch_size,
               const int head_num,
@@ -1592,7 +1613,9 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
      q_input,
      cos_emb,
      sin_emb,
sequence_lengths,
      q,
rotary_emb_dims,
      batch_size,
      head_num,
      seq_len * rotary_emb_dims,
@@ -1601,13 +1624,109 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
      k_input,
      cos_emb,
      sin_emb,
sequence_lengths,
      k,
rotary_emb_dims,
      batch_size,
      head_num,
      seq_len * rotary_emb_dims,
      last_dim);
}
__global__ void GetPaddingOffset(int *d_token_num,
int *padding_offset,
const int *sequence_lengths,
const int batch_size,
const int max_seq_len) {
// get padding offset of each batch
int total_seq_len = 0;
int cum_offset = 0;
int index = 0;
for (int i = 0; i < batch_size; i++) {
const int seq_len = sequence_lengths[i];
for (int j = 0; j < seq_len; j++) {
padding_offset[index] = cum_offset;
index++;
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
d_token_num[0] = total_seq_len;
}
void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx,
int *h_token_num,
int *d_token_num,
int *padding_offset,
const int *sequence_lengths,
const int batch_size,
const int max_seq_len) {
GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>(
d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len);
memory::Copy(platform::CPUPlace(),
h_token_num,
dev_ctx.GetPlace(),
d_token_num,
sizeof(int),
dev_ctx.stream());
}
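A NumPy reference of the padding-offset computation above, with a small worked example. token_num is the total number of real tokens, and padding_offset[i] is how many pad slots precede packed token i in the padded layout:

import numpy as np

def get_padding_offset_ref(sequence_lengths, max_seq_len):
    padding_offset = []
    cum_offset = 0
    for seq_len in sequence_lengths:
        padding_offset.extend([cum_offset] * int(seq_len))   # same offset for every token of a batch entry
        cum_offset += max_seq_len - int(seq_len)             # pads contributed by this entry
    return np.array(padding_offset, dtype=np.int32), len(padding_offset)

offsets, token_num = get_padding_offset_ref([3, 1, 2], max_seq_len=4)
print(offsets, token_num)   # [0 0 0 1 4 4] 6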
template <typename T>
__global__ void RemovePadding(T *output_data,
const T *input_data,
const int *padding_offset,
const int dim_embed) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int src_seq_id = bid + padding_offset[bid];
const int tgt_seq_id = bid;
for (int i = tid; i < dim_embed; i += blockDim.x) {
output_data[tgt_seq_id * dim_embed + i] =
input_data[src_seq_id * dim_embed + i];
}
}
template <typename T>
void InvokeRemovePadding(const phi::GPUContext &dev_ctx,
T *output_data,
const T *input_data,
const int *padding_offset,
const int token_num,
const int dim_embed) {
RemovePadding<<<token_num, 256, 0, dev_ctx.stream()>>>(
output_data, input_data, padding_offset, dim_embed);
}
template <typename T>
__global__ void RebuildPadding(T *output_data,
const T *input_data,
const int *padding_offset,
const int dim_embed) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int dst_seq_id = bid + padding_offset[bid];
const int src_seq_id = bid;
for (int i = tid; i < dim_embed; i += blockDim.x) {
output_data[dst_seq_id * dim_embed + i] =
input_data[src_seq_id * dim_embed + i];
}
}
template <typename T>
void InvokeRebuildPadding(const phi::GPUContext &dev_ctx,
T *output_data,
const T *input_data,
const int *padding_offset,
const int token_num,
const int dim_embed) {
// src: [token_num, dim_embed]
// dst: [batch_size * max_seq_len, dim_embed]
RebuildPadding<<<token_num, 256, 0, dev_ctx.stream()>>>(
output_data, input_data, padding_offset, dim_embed);
}
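RemovePadding gathers the rows of real tokens into a packed [token_num, dim_embed] buffer, and RebuildPadding scatters them back into the padded [batch_size * max_seq_len, dim_embed] layout (pad rows stay at whatever the output buffer was initialized to, typically zeros). A NumPy round-trip sketch using the same padding_offset as the reference above:

import numpy as np

def remove_padding_ref(x_padded, padding_offset):
    # [batch_size * max_seq_len, dim_embed] -> [token_num, dim_embed]
    return np.stack([x_padded[bid + off] for bid, off in enumerate(padding_offset)])

def rebuild_padding_ref(x_packed, padding_offset, batch_size, max_seq_len):
    # [token_num, dim_embed] -> [batch_size * max_seq_len, dim_embed], pad rows zeroed
    out = np.zeros((batch_size * max_seq_len, x_packed.shape[1]), dtype=x_packed.dtype)
    for bid, off in enumerate(padding_offset):
        out[bid + off] = x_packed[bid]
    return out

padding_offset = [0, 0, 0, 1, 4, 4]              # seq_lens = [3, 1, 2], max_seq_len = 4
x = np.random.rand(3 * 4, 8).astype("float32")   # padded activations
packed = remove_padding_ref(x, padding_offset)
rebuilt = rebuild_padding_ref(packed, padding_offset, batch_size=3, max_seq_len=4)
for bid, off in enumerate(padding_offset):
    assert np.allclose(rebuilt[bid + off], x[bid + off])   # real-token rows survive the round trip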
#if CUDA_VERSION >= 11060
// Only Used in Inference
template <typename T>
......
@@ -64,6 +64,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
      "PreCaches",
      "RotaryPosEmb",
      "TimeStep",
"SeqLengths",
"SrcMask", "SrcMask",
"OutLinearW", "OutLinearW",
"OutLinearBias", "OutLinearBias",
......
@@ -887,6 +887,7 @@ def fused_multi_transformer(
    epsilon=1e-05,
    cache_kvs=None,
    pre_caches=None,
seq_lens=None,
    rotary_embs=None,
    time_step=None,
    attn_mask=None,
@@ -956,6 +957,7 @@ def fused_multi_transformer(
        epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
        cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
        pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
        seq_lens (Tensor, optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None.
        rotary_embs (Tensor optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
        time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
        attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
@@ -1055,6 +1057,7 @@ def fused_multi_transformer(
            pre_caches,
            rotary_embs,
            time_step,
seq_lens,
            attn_mask,
            linear_weights,
            linear_biases,
@@ -1115,6 +1118,7 @@ def fused_multi_transformer(
            inputs['PreCaches'] = pre_caches
        if rotary_emb_dims > 0:
            inputs['RotaryPosEmb'] = rotary_embs
inputs['SeqLengths'] = seq_lens
        inputs['SrcMask'] = attn_mask
        inputs['OutLinearW'] = linear_weights
        if linear_biases is not None:
......
@@ -1382,6 +1382,7 @@ class FusedMultiTransformer(Layer):
        pre_caches=None,
        rotary_embs=None,
        rotary_emb_dims=0,
seq_lens=None,
        time_step=None,
    ):
        r"""
@@ -1406,6 +1407,7 @@ class FusedMultiTransformer(Layer):
            rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
            rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None,
                1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0.
            seq_lens (Tensor, optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None.
            time_step (Tensor, optional): The time step tensor for the generation
                model. Which used in decode stage, to represent the time step,
                that is, the real seq_len of CacheKV. The shape is `[1]`, must be
@@ -1441,6 +1443,7 @@ class FusedMultiTransformer(Layer):
            pre_caches=pre_caches,
            rotary_embs=rotary_embs,
            time_step=time_step,
seq_lens=seq_lens,
            attn_mask=attn_mask,
            dropout_rate=self.dropout_rate,
            rotary_emb_dims=rotary_emb_dims,
......