Unverified commit 53df50c7, authored by lzy, committed by GitHub

make FusedMultiTransformer support variable lengths (#49560)

* make FusedMultiTransformer support variable lengths.

* modify ffn2 when cuda_version >= 11.6 because of #49392.

* code style

* delete remove_padding
Parent: a3989b5e
@@ -62,6 +62,78 @@ class AttnDropoutParam {
  const phi::DenseTensor* seed_;
};
template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int elem_cnt,
const int* padding_offset) {
// transpose and remove padding
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
const int dim_embed = num_head * head_dim;
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
for (int32_t linear_index = idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / dim_embed;
const int ori_token_idx =
token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int ori_batch_id = ori_token_idx / seq_len;
const int ori_seq_id = ori_token_idx % seq_len;
const int ori_head_id = (linear_index % dim_embed) / head_dim;
const int ori_head_lane = (linear_index % dim_embed) % head_dim;
const int ori_idx = ori_batch_id * num_head * seq_len * head_dim +
ori_head_id * seq_len * head_dim +
ori_seq_id * head_dim + ori_head_lane;
phi::Load<T, VecSize>(&input_data[ori_idx], &src_vec);
phi::Store<T, VecSize>(src_vec, &output_data[linear_index]);
}
}
template <typename T>
void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx,
const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int* padding_offset) {
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
constexpr int VEC_16B = 16;
const int elem_cnt = token_num * num_head * head_dim;
constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(
head_dim % PackSize,
0,
platform::errors::PreconditionNotMet(
"dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize));
const int32_t pack_num = elem_cnt / PackSize;
const int32_t block_size = 128;
int32_t grid_size = (pack_num + block_size - 1) / block_size;
TransposeRemovingPadding<T, PackSize>
<<<grid_size, block_size, 0, dev_ctx.stream()>>>(input_data,
output_data,
batch_size,
num_head,
seq_len,
head_dim,
token_num,
elem_cnt,
padding_offset);
}
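// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this patch: a hypothetical host-side
// reference (RefTransposeRemovePadding is an invented name) of the index
// mapping performed by TransposeRemovingPadding above. Each packed output
// token gathers its num_head * head_dim values from the padded
// [batch_size, num_head, seq_len, head_dim] tensor; padding_offset recovers
// the original (batch, position) of every packed token. Assumes <vector> and
// <cstdint> are available.
// ---------------------------------------------------------------------------
template <typename T>
void RefTransposeRemovePadding(const std::vector<T>& input,   // padded input
                               std::vector<T>* output,        // packed output
                               int num_head,
                               int seq_len,
                               int head_dim,
                               int token_num,
                               const std::vector<int>& padding_offset) {
  const int dim_embed = num_head * head_dim;
  output->resize(static_cast<size_t>(token_num) * dim_embed);
  for (int64_t i = 0; i < static_cast<int64_t>(token_num) * dim_embed; ++i) {
    const int token_idx = static_cast<int>(i / dim_embed);
    const int ori_token_idx = token_idx + padding_offset[token_idx];
    const int b = ori_token_idx / seq_len;  // original batch id
    const int s = ori_token_idx % seq_len;  // original position in sequence
    const int h = static_cast<int>(i % dim_embed) / head_dim;
    const int d = static_cast<int>(i % dim_embed) % head_dim;
    (*output)[i] =
        input[((static_cast<int64_t>(b) * num_head + h) * seq_len + s) *
                  head_dim +
              d];
  }
}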
template <typename T>
class FMHARef {
 public:
@@ -256,18 +328,21 @@ class FMHARef {
        dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
  }

  void ComputeForwardWithoutTranspose(
      const phi::DenseTensor* cache_kv_tensor,
      const phi::DenseTensor* src_mask_tensor,
      const phi::DenseTensor* padding_offset_tensor,
      phi::DenseTensor* q_transpose_out_tensor,
      phi::DenseTensor* kv_transpose_out_tensor,
      phi::DenseTensor* cache_kv_out_tensor,
      phi::DenseTensor* qk_out_tensor,
      phi::DenseTensor* src_mask_out_tensor,
      phi::DenseTensor* softmax_out_tensor,
      phi::DenseTensor* dropout_mask_out_tensor,
      phi::DenseTensor* dropout_out_tensor,
      phi::DenseTensor* qktv_out_tensor,
      phi::DenseTensor* fmha_out_tensor,
      const int token_num) {
    // input shape: [bs, seq_len, 3, num_head, head_dim]
    // transpose with perm [2, 0, 3, 1, 4],
    // output_shape: [3, bs, num_head, seq_len, head_dim]
@@ -424,9 +499,21 @@ class FMHARef {
    }
    // transpose: [0, 2, 1, 3]
    // output shape: [batch_size, seq_len, num_heads, head_dim]
    if (!padding_offset_tensor) {
      std::vector<int> perm_3 = {0, 2, 1, 3};
      phi::funcs::TransposeGPUKernelDriver<T>(
          dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
    } else {
      InvokeTransposeRemovePadding<T>(dev_ctx_,
                                      qktv_out_data,
                                      fmha_out_data,
                                      batch_size_,
                                      num_head_,
                                      seq_len_,
                                      head_dim_,
                                      token_num,
                                      padding_offset_tensor->data<int>());
    }
  }

  void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor,
......
@@ -182,6 +182,8 @@ class FusedMultiTransformerOpOpMaker
    AddInput("TimeStep",
             "(optional, int) The time step for generation inference.")
        .AsDispensable();
    AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.")
        .AsDispensable();
    AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
        .AsDispensable();
    AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
......
@@ -32,6 +32,48 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    int dim_embed = input_x_dims[2];
    int bsz_seq = bsz * seq_len;
    const std::string act_method = ctx.Attr<std::string>("act_method");
bool remove_padding = false;
auto *sequence_lengths = ctx.Input<phi::DenseTensor>("SeqLengths");
if (sequence_lengths) {
remove_padding = true;
}
phi::DenseTensor d_token_tensor;
phi::DenseTensor padding_offset_tensor;
phi::DenseTensor x_remove_padding;
bool encoder_remove_padding = (remove_padding && !time_step);
int token_num = 0;
// remove padding in encoder
if (encoder_remove_padding) {
// just for encoder
d_token_tensor.Resize({{1}});
auto *d_token_num = dev_ctx.Alloc<int>(
&d_token_tensor, d_token_tensor.numel() * sizeof(int));
// alloc the max size of padding_offset_tensor
padding_offset_tensor.Resize({{bsz_seq}});
dev_ctx.Alloc<int>(&padding_offset_tensor,
padding_offset_tensor.numel() * sizeof(int));
InvokeGetPaddingOffset(dev_ctx,
&token_num,
d_token_num,
padding_offset_tensor.data<int>(),
sequence_lengths->data<int>(),
bsz,
seq_len);
padding_offset_tensor.Resize({{token_num}});
x_remove_padding.Resize({{token_num, dim_embed}});
dev_ctx.Alloc<T>(&x_remove_padding, x_remove_padding.numel() * sizeof(T));
InvokeRemovePadding(dev_ctx,
x_remove_padding.data<T>(),
input_x->data<T>(),
padding_offset_tensor.data<int>(),
token_num,
dim_embed);
} else {
token_num = bsz_seq;
}
auto *padding_offset_data =
encoder_remove_padding ? padding_offset_tensor.data<int>() : nullptr;
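    // Worked example (illustrative, not part of the patch): with bsz = 2,
    // seq_len = 4 and SeqLengths = {2, 3}, InvokeGetPaddingOffset produces
    // token_num = 5 and padding_offset = {0, 0, 2, 2, 2}; packed token i came
    // from row i + padding_offset[i] of the padded [bsz * seq_len, dim_embed]
    // input, so x_remove_padding holds rows {0, 1, 4, 5, 6} of input_x and
    // every following matmul / layer_norm runs over token_num rows instead of
    // bsz_seq rows.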
    // 1. layer norm
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
@@ -39,12 +81,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    auto ln_scales = ctx.MultiInput<phi::DenseTensor>("LnScale");
    auto ln_biases = ctx.MultiInput<phi::DenseTensor>("LnBias");

    auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, token_num, dim_embed);
    phi::DenseTensor ln_mean, ln_var;
    ln_mean.Resize({{token_num}});
    auto *ln_mean_data =
        dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
    ln_var.Resize({{token_num}});
    auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));

    // 2. qkv
@@ -67,13 +109,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    auto qkv_compute = AttnMatMul<T>(dev_ctx,
                                     false,
                                     trans_qkvw,
                                     token_num,
                                     output_size,
                                     input_size,
                                     /*compute_bias=*/false);

    phi::DenseTensor qkv_out;
    qkv_out.Resize({{token_num, 3, num_head, dim_head}});
    auto *qkv_out_data =
        dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
@@ -175,23 +217,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    int ring_id = ctx.Attr<int>("ring_id");
    // (transA, transB, compute_bias) = (false, false, false)
    auto out_linear_compute = AttnMatMul<T>(
        dev_ctx, false, false, token_num, dim_embed, hidden_size, false);

    // 5. ln(residual + bias)
    DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        dev_ctx, token_num, dim_embed, dropout_param2, epsilon);
    auto ffn_ln_scales = ctx.MultiInput<phi::DenseTensor>("FFNLnScale");
    auto ffn_ln_biases = ctx.MultiInput<phi::DenseTensor>("FFNLnBias");
    phi::DenseTensor bias_dropout_residual_out, dropout_mask_out;
    T *bias_dropout_residual_out_data = nullptr;
    if (pre_layer_norm) {
      bias_dropout_residual_out.Resize({{token_num, dim_embed}});
      bias_dropout_residual_out_data =
          dev_ctx.Alloc<T>(&bias_dropout_residual_out,
                           bias_dropout_residual_out.numel() * sizeof(T));
    }
    dropout_mask_out.Resize({{token_num, dim_embed}});
    auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
        &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));
@@ -203,11 +245,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    int dim_ffn = ffn1_weight_dim[1];

    auto ffn1_cublas_linear = CublasFusedMLP<T>(dev_ctx);
    const phi::DDim ffn1_input_shape({token_num, dim_embed});
    ffn1_cublas_linear.Setup(ffn1_input_shape, ffn1_weight_dim, false, false);

    phi::DenseTensor ffn1_out;
    ffn1_out.Resize({{token_num, dim_ffn}});
    auto *ffn1_out_data =
        dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));
@@ -216,23 +258,33 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
    auto ffn2_linear_compute = AttnMatMul<T>(
        dev_ctx, false, false, token_num, dim_embed, dim_ffn, false);

    // 8. ffn2 Layernorm residual bias
    DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
        dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon);

    // calc
    auto *out = ctx.Output<phi::DenseTensor>("Out");
    auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
    phi::DenseTensor *from_tensor = out;
    phi::DenseTensor tmp_out, tmp_out_rm_padding;
    tmp_out.Resize({{token_num, dim_embed}});
    if (encoder_remove_padding) {
      tmp_out_rm_padding.Resize({{token_num, dim_embed}});
      auto *tmp_out_rm_padding_data = dev_ctx.Alloc<T>(
          &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T));
    }
    auto *tmp_out_data =
        dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));

    const T *x_data;
if (encoder_remove_padding) {
x_data = x_remove_padding.data<T>();
} else {
x_data = input_x->data<T>();
}
    phi::DenseTensor *buf0 = nullptr;
    phi::DenseTensor *buf1 = nullptr;
@@ -240,19 +292,27 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    // step1: buf1 --> buf0
    // step2: buf0 --> buf1
    int layers = qkv_weights.size();
    if (encoder_remove_padding) {
      // In the case of variable lengths, the padding needs to be rebuilt
      // eventually. So buf0 and buf1 do not need to be changed according to
      // the pre_layer_norm and the number of layers.
      buf0 = &tmp_out;
      buf1 = &tmp_out_rm_padding;
    } else {
      if (pre_layer_norm) {
        if (layers & 1) {
          // odd, set buf1 as out
          buf0 = &tmp_out;
          buf1 = out;
        } else {
          // even, set buf0 as out
          buf0 = out;
          buf1 = &tmp_out;
        }
      } else {
        buf0 = &tmp_out;
        buf1 = out;
      }
    }

    for (int i = 0; i < layers; ++i) {
@@ -278,8 +338,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
      // NOTE: in decoder stage, bias is fused in fmha
      const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias;
      if (!pre_layer_norm && i == 0) {
        const phi::DenseTensor *tmp_input_x =
            (encoder_remove_padding) ? &x_remove_padding : input_x;
        qkv_compute.ComputeForward(
            qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out);
      } else {
        qkv_compute.ComputeForward(
            qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
@@ -300,6 +362,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
             qkv_out,
             *qkv_bias,
             *src_mask,
             sequence_lengths,
             rotary_tensor,
             cache_kv_out,
             &fmha_out,
@@ -322,22 +385,26 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
            kv_transpose_out_data,
            qkv_out_data,
            qkv_bias->data<T>(),
            padding_offset_data,
            token_num,
            bsz,
            num_head,
            seq_len,
            dim_head,
            compute_bias);
        // q_transpose_out_data [bs, head_num, seq_len, dim_head]
        // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
        if (rotary_emb_dims != 0) {
          auto *rotary_emb_data = rotary_tensor->data<T>();
          const int *sequence_lengths_data =
              encoder_remove_padding ? sequence_lengths->data<int>() : nullptr;
          rotary_qk(dev_ctx,
                    q_transpose_out_data,
                    kv_transpose_out_data,
                    q_transpose_out_data,
                    kv_transpose_out_data,
                    rotary_emb_data,
                    sequence_lengths_data,
                    rotary_emb_dims,
                    bsz,
                    num_head,
@@ -345,8 +412,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                    dim_head);
        }

        phi::DenseTensor *tmp_padding_offset_tensor =
            encoder_remove_padding ? &padding_offset_tensor : nullptr;
        fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor,
                                                    src_mask,
                                                    tmp_padding_offset_tensor,
                                                    &q_transpose_out,
                                                    &kv_transpose_out,
                                                    pre_cache_kv_out_tmp,
@@ -356,7 +426,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                                                    &attn_dropout_mask_out,
                                                    &attn_dropout_out,
                                                    &qktv_out,
                                                    &fmha_out,
                                                    token_num);

        const T *k_ptr = nullptr;
        const T *v_ptr = nullptr;
@@ -400,6 +471,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
            kv_transpose_out_data,
            qkv_out_data,
            qkv_bias->data<T>(),
            padding_offset_data,
            token_num,
            bsz,
            num_head,
            seq_len,
@@ -410,12 +483,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
        // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
        if (rotary_emb_dims != 0) {
          auto *rotary_emb_data = rotary_tensor->data<T>();
          const int *sequence_lengths_data =
              encoder_remove_padding ? sequence_lengths->data<int>() : nullptr;
          rotary_qk(dev_ctx,
                    q_transpose_out_data,
                    kv_transpose_out_data,
                    q_transpose_out_data,
                    kv_transpose_out_data,
                    rotary_emb_data,
                    sequence_lengths_data,
                    rotary_emb_dims,
                    bsz,
                    num_head,
@@ -423,8 +499,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                    dim_head);
        }

        phi::DenseTensor *tmp_padding_offset_tensor =
            encoder_remove_padding ? &padding_offset_tensor : nullptr;
        fmha_compute.ComputeForwardWithoutTranspose(cache_kv,
                                                    src_mask,
                                                    tmp_padding_offset_tensor,
                                                    &q_transpose_out,
                                                    &kv_transpose_out,
                                                    cache_kv_out,
@@ -434,7 +513,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                                                    &attn_dropout_mask_out,
                                                    &attn_dropout_out,
                                                    &qktv_out,
                                                    &fmha_out,
                                                    token_num);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step3";
@@ -580,6 +660,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
        std::swap(buf0, buf1);
      }
    }
if (encoder_remove_padding) {
if (pre_layer_norm) {
InvokeRebuildPadding(dev_ctx,
from_data,
buf0->data<T>(),
padding_offset_data,
token_num,
dim_embed);
} else {
InvokeRebuildPadding(dev_ctx,
from_data,
buf1->data<T>(),
padding_offset_data,
token_num,
dim_embed);
}
}
  }
};
@@ -601,6 +698,48 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    int dim_embed = input_x_dims[2];
    int bsz_seq = bsz * seq_len;
    const std::string act_method = ctx.Attr<std::string>("act_method");
bool remove_padding = false;
auto *sequence_lengths = ctx.Input<phi::DenseTensor>("SeqLengths");
if (sequence_lengths) {
remove_padding = true;
}
phi::DenseTensor d_token_tensor;
phi::DenseTensor padding_offset_tensor;
phi::DenseTensor x_remove_padding;
bool encoder_remove_padding = (remove_padding && !time_step);
int token_num = 0;
// remove padding in encoder
if (encoder_remove_padding) {
// just for encoder
d_token_tensor.Resize({{1}});
auto *d_token_num = dev_ctx.Alloc<int>(
&d_token_tensor, d_token_tensor.numel() * sizeof(int));
// alloc the max size of padding_offset_tensor
padding_offset_tensor.Resize({{bsz_seq}});
dev_ctx.Alloc<int>(&padding_offset_tensor,
padding_offset_tensor.numel() * sizeof(int));
InvokeGetPaddingOffset(dev_ctx,
&token_num,
d_token_num,
padding_offset_tensor.data<int>(),
sequence_lengths->data<int>(),
bsz,
seq_len);
padding_offset_tensor.Resize({{token_num}});
x_remove_padding.Resize({{token_num, dim_embed}});
dev_ctx.Alloc<T>(&x_remove_padding, x_remove_padding.numel() * sizeof(T));
InvokeRemovePadding(dev_ctx,
x_remove_padding.data<T>(),
input_x->data<T>(),
padding_offset_tensor.data<int>(),
token_num,
dim_embed);
} else {
token_num = bsz_seq;
}
auto *padding_offset_data =
encoder_remove_padding ? padding_offset_tensor.data<int>() : nullptr;
    // 1. layer norm
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
@@ -608,12 +747,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    auto ln_scales = ctx.MultiInput<phi::DenseTensor>("LnScale");
    auto ln_biases = ctx.MultiInput<phi::DenseTensor>("LnBias");

    auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, token_num, dim_embed);
    phi::DenseTensor ln_mean, ln_var;
    ln_mean.Resize({{token_num}});
    auto *ln_mean_data =
        dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
    ln_var.Resize({{token_num}});
    auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));

    // 2. qkv
@@ -636,13 +775,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    auto qkv_compute = AttnMatMul<T>(dev_ctx,
                                     false,
                                     trans_qkvw,
                                     token_num,
                                     output_size,
                                     input_size,
                                     /*compute_bias=*/false);

    phi::DenseTensor qkv_out;
    qkv_out.Resize({{token_num, 3, num_head, dim_head}});
    auto *qkv_out_data =
        dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
@@ -744,23 +883,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    int ring_id = ctx.Attr<int>("ring_id");
    // (transA, transB, compute_bias) = (false, false, false)
    auto out_linear_compute = AttnMatMul<T>(
        dev_ctx, false, false, token_num, dim_embed, hidden_size, false);

    // 5. ln(residual + bias)
    DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        dev_ctx, token_num, dim_embed, dropout_param2, epsilon);
    auto ffn_ln_scales = ctx.MultiInput<phi::DenseTensor>("FFNLnScale");
    auto ffn_ln_biases = ctx.MultiInput<phi::DenseTensor>("FFNLnBias");
    phi::DenseTensor bias_dropout_residual_out, dropout_mask_out;
    T *bias_dropout_residual_out_data = nullptr;
    if (pre_layer_norm) {
      bias_dropout_residual_out.Resize({{token_num, dim_embed}});
      bias_dropout_residual_out_data =
          dev_ctx.Alloc<T>(&bias_dropout_residual_out,
                           bias_dropout_residual_out.numel() * sizeof(T));
    }
    dropout_mask_out.Resize({{token_num, dim_embed}});
    auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
        &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));
@@ -771,21 +910,21 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    int dim_ffn = ffn1_weight_dim[1];

    auto ffn1_linear_compute = AttnMatMul<T>(
        dev_ctx, false, false, token_num, dim_ffn, dim_embed, false);
    phi::DenseTensor ffn1_out;
    ffn1_out.Resize({{token_num, dim_ffn}});
    auto *ffn1_out_data =
        dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));

    // 7. ffn act + bias
    DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
    FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
        dev_ctx, token_num, dim_ffn, ffn1_dropout_param);
    phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask;
    ffn1_dropout_out.Resize({{token_num, dim_ffn}});
    auto *ffn1_dropout_out_data = dev_ctx.Alloc<T>(
        &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T));
    ffn1_dropout_mask.Resize({{token_num, dim_ffn}});
    auto *ffn1_dropout_mask_data = dev_ctx.Alloc<uint8_t>(
        &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t));
@@ -793,23 +932,33 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
    auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
    auto ffn2_linear_compute = AttnMatMul<T>(
        dev_ctx, false, false, token_num, dim_embed, dim_ffn, false);

    // 9. ffn2 residual bias
    DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
        dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon);

    // calc
    auto *out = ctx.Output<phi::DenseTensor>("Out");
    auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
    phi::DenseTensor *from_tensor = out;
    phi::DenseTensor tmp_out, tmp_out_rm_padding;
    tmp_out.Resize({{token_num, dim_embed}});
    if (encoder_remove_padding) {
      tmp_out_rm_padding.Resize({{token_num, dim_embed}});
      auto *tmp_out_rm_padding_data = dev_ctx.Alloc<T>(
          &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T));
    }
    auto *tmp_out_data =
        dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));

    const T *x_data;
if (encoder_remove_padding) {
x_data = x_remove_padding.data<T>();
} else {
x_data = input_x->data<T>();
}
    phi::DenseTensor *buf0 = nullptr;
    phi::DenseTensor *buf1 = nullptr;
@@ -817,19 +966,27 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
    // step1: buf1 --> buf0
    // step2: buf0 --> buf1
    int layers = qkv_weights.size();
    if (encoder_remove_padding) {
      // In the case of variable lengths, the padding needs to be rebuilt
      // eventually. So buf0 and buf1 do not need to be changed according to
      // the pre_layer_norm and the number of layers.
      buf0 = &tmp_out;
      buf1 = &tmp_out_rm_padding;
    } else {
      if (pre_layer_norm) {
        if (layers & 1) {
          // odd, set buf1 as out
          buf0 = &tmp_out;
          buf1 = out;
        } else {
          // even, set buf0 as out
          buf0 = out;
          buf1 = &tmp_out;
        }
      } else {
        buf0 = &tmp_out;
        buf1 = out;
      }
    }

    for (int i = 0; i < layers; ++i) {
@@ -855,8 +1012,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
      // NOTE: in decoder stage, bias is fused in fmha
      const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias;
      if (!pre_layer_norm && i == 0) {
        const phi::DenseTensor *tmp_input_x =
            (encoder_remove_padding) ? &x_remove_padding : input_x;
        qkv_compute.ComputeForward(
            qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out);
      } else {
        qkv_compute.ComputeForward(
            qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
@@ -877,6 +1036,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
             qkv_out,
             *qkv_bias,
             *src_mask,
             sequence_lengths,
             rotary_tensor,
             cache_kv_out,
             &fmha_out,
@@ -899,6 +1059,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
            kv_transpose_out_data,
            qkv_out_data,
            qkv_bias->data<T>(),
            padding_offset_data,
            token_num,
            bsz,
            num_head,
            seq_len,
@@ -909,12 +1071,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
        // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
        if (rotary_emb_dims != 0) {
          auto *rotary_emb_data = rotary_tensor->data<T>();
          const int *sequence_lengths_data =
              encoder_remove_padding ? sequence_lengths->data<int>() : nullptr;
          rotary_qk(dev_ctx,
                    q_transpose_out_data,
                    kv_transpose_out_data,
                    q_transpose_out_data,
                    kv_transpose_out_data,
                    rotary_emb_data,
                    sequence_lengths_data,
                    rotary_emb_dims,
                    bsz,
                    num_head,
@@ -922,8 +1087,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                    dim_head);
        }

        phi::DenseTensor *tmp_padding_offset_tensor =
            encoder_remove_padding ? &padding_offset_tensor : nullptr;
        fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor,
                                                    src_mask,
                                                    tmp_padding_offset_tensor,
                                                    &q_transpose_out,
                                                    &kv_transpose_out,
                                                    pre_cache_kv_out_tmp,
@@ -933,7 +1101,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                                                    &attn_dropout_mask_out,
                                                    &attn_dropout_out,
                                                    &qktv_out,
                                                    &fmha_out,
                                                    token_num);

        const T *k_ptr = nullptr;
        const T *v_ptr = nullptr;
@@ -977,6 +1146,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
            kv_transpose_out_data,
            qkv_out_data,
            qkv_bias->data<T>(),
            padding_offset_data,
            token_num,
            bsz,
            num_head,
            seq_len,
@@ -987,12 +1158,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
        // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
        if (rotary_emb_dims != 0) {
          auto *rotary_emb_data = rotary_tensor->data<T>();
          const int *sequence_lengths_data =
              encoder_remove_padding ? sequence_lengths->data<int>() : nullptr;
          rotary_qk(dev_ctx,
                    q_transpose_out_data,
                    kv_transpose_out_data,
                    q_transpose_out_data,
                    kv_transpose_out_data,
                    rotary_emb_data,
                    sequence_lengths_data,
                    rotary_emb_dims,
                    bsz,
                    num_head,
@@ -1000,8 +1174,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                    dim_head);
        }

        phi::DenseTensor *tmp_padding_offset_tensor =
            encoder_remove_padding ? &padding_offset_tensor : nullptr;
        fmha_compute.ComputeForwardWithoutTranspose(cache_kv,
                                                    src_mask,
                                                    tmp_padding_offset_tensor,
                                                    &q_transpose_out,
                                                    &kv_transpose_out,
                                                    cache_kv_out,
@@ -1011,7 +1188,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                                                    &attn_dropout_mask_out,
                                                    &attn_dropout_out,
                                                    &qktv_out,
                                                    &fmha_out,
                                                    token_num);
      }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
      VLOG(0) << "step3";
@@ -1162,6 +1340,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
        std::swap(buf0, buf1);
      }
    }
if (encoder_remove_padding) {
if (pre_layer_norm) {
InvokeRebuildPadding(dev_ctx,
from_data,
buf0->data<T>(),
padding_offset_data,
token_num,
dim_embed);
} else {
InvokeRebuildPadding(dev_ctx,
from_data,
buf1->data<T>(),
padding_offset_data,
token_num,
dim_embed);
}
}
  }
};
......
@@ -127,6 +127,8 @@ struct Masked_multihead_attention_params {
  // v [B, num_head, max_seq_len, dim_head]
  T *cache_kv;

  const int *sequence_lengths{nullptr};

  // The RoPE embedding, [B, 1, 1, dim_head]
  // rotary_emb_dims = 1 if pos_ids_extra is null else 2
  const T *rotary_emb;
@@ -769,6 +771,10 @@ __global__ void masked_multihead_attention_kernel(
  float qk_max = -FLT_MAX;
  float qk = 0;

  int act_time_step = params.sequence_lengths == nullptr
                          ? params.timestep
                          : params.sequence_lengths[bi];

  // qkv [B, S=1, 3, num_head, head_dim]
  int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
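  // Note (explanatory comment added for this write-up, not in the patch):
  // params.timestep is the global generation step shared by the whole batch,
  // while params.sequence_lengths[bi], when provided, is the number of cache
  // entries batch bi has actually filled. act_time_step is therefore used
  // below for the cache-write offsets, the QK/softmax loop bounds and the slot
  // the new k/v vectors are appended to, so a batch with a shorter prompt only
  // attends over its own valid cache entries; the attention-mask lookup keeps
  // indexing with params.timestep, matching the padded mask layout.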
@@ -888,7 +894,7 @@ __global__ void masked_multihead_attention_kernel(
    int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
    int offset = bhi * params.max_seq_length * Dh +
                 co * params.max_seq_length * QK_ELTS_IN_16B +
                 act_time_step * QK_ELTS_IN_16B + ci;
    if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
      *reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
    }
@@ -914,7 +920,7 @@ __global__ void masked_multihead_attention_kernel(
    // qk += static_cast<float>(mask);
    qk *= params.inv_sqrt_dh;
    qk_max = qk;
    qk_smem[act_time_step] = qk;
  }
  __syncthreads();
@@ -949,7 +955,7 @@ __global__ void masked_multihead_attention_kernel(
  constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;

  T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
  int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP;

  for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
    K_vec k[K_VECS_PER_THREAD];
@@ -958,7 +964,7 @@ __global__ void masked_multihead_attention_kernel(
#pragma unroll
    for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
      int jj = ii * params.max_seq_length + ti;
      if (ti < act_time_step) {
        k[ii] =
            (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
                ? *reinterpret_cast<const K_vec *>(
@@ -972,7 +978,7 @@ __global__ void masked_multihead_attention_kernel(
    float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);

    // bool is_mask = false;
    if (ti < act_time_step && tid % THREADS_PER_KEY == 0) {
      // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
      T mask = params.attn_mask[bi * (params.timestep + 1) + ti];
      qk += static_cast<float>(mask);
@@ -1014,7 +1020,7 @@ __global__ void masked_multihead_attention_kernel(
#endif

  float sum = 0.f;
  for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
    // bool is_mask = false;
    // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
    float logit = __expf(qk_smem[ti] - qk_max);
@@ -1026,7 +1032,7 @@ __global__ void masked_multihead_attention_kernel(
  // FIXME(wangxi): need add 1.e-6f?
  float inv_sum = __fdividef(1.f, sum + 1.e-6f);
  for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
    convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
  }
  __syncthreads();
@@ -1052,7 +1058,7 @@ __global__ void masked_multihead_attention_kernel(
  constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
  if (Dh == Dh_MAX || vi < Dh) {
    for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) {
      V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
      float logit = logits_smem[ti];
@@ -1076,18 +1082,18 @@ __global__ void masked_multihead_attention_kernel(
  V_vec v_bias;
  zero(v_bias);
  if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
    V_vec v = *reinterpret_cast<const V_vec *>(
        &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
    v_bias = *reinterpret_cast<const V_vec *>(
        &params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
    v = add(v, v_bias);
    *reinterpret_cast<V_vec *>(&v_cache[act_time_step * Dh]) = v;

#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
    out = fma(logits_smem[act_time_step], cast_to_float(v), out);
#else
    out = fma(logits_smem[act_time_step], v, out);
#endif
  }
@@ -1198,6 +1204,7 @@ void fmha(const phi::GPUContext &dev_ctx,
          const phi::DenseTensor &qkv_tensor,
          const phi::DenseTensor &qkv_bias_tensor,
          const phi::DenseTensor &src_mask_tensor,
          const phi::DenseTensor *sequence_lengths_tensor,
          const phi::DenseTensor *rotary_tensor,
          phi::DenseTensor *cache_kv_tensor,
          phi::DenseTensor *out_tensor,
@@ -1214,6 +1221,11 @@ void fmha(const phi::GPUContext &dev_ctx,
  params.qkv_bias = qkv_bias_tensor.data<T>();
  params.attn_mask = src_mask_tensor.data<T>();
  params.cache_kv = cache_kv_tensor->data<T>();

  if (sequence_lengths_tensor) {
    params.sequence_lengths = sequence_lengths_tensor->data<int>();
  }

  if (rotary_emb_dims > 0) {
    params.rotary_emb = rotary_tensor->data<T>();
  } else {
@@ -1273,6 +1285,7 @@ void fmha(const phi::GPUContext &dev_ctx,
       qkv_bias_tensor,
       src_mask_tensor,
       nullptr,
       nullptr,
       cache_kv_tensor,
       out_tensor,
       batch_size,
@@ -1395,6 +1408,7 @@ __global__ void add_fusedQKV_bias_transpose_split_kernel(
    T *kv_buf,
    const T *qkv,
    const T *qkv_bias,
    const int *padding_offset,
    const int32_t elem_cnt,
    const int batch_size,
    const int seq_len,
@@ -1423,10 +1437,10 @@ __global__ void add_fusedQKV_bias_transpose_split_kernel(
      }
    }
    const int32_t token_idx = linear_index / fused_hidden_size;
    const int32_t ori_token_idx =
        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
    const int32_t target_batch_id = ori_token_idx / seq_len;
    const int32_t seq_id = ori_token_idx % seq_len;

    // equal to:
    // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size;
@@ -1475,12 +1489,13 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
                                  T *kv_buf,
                                  const T *qkv,
                                  const T *qkv_bias,
                                  const int *padding_offset,
                                  const int token_num,
                                  const int batch_size,
                                  const int head_num,
                                  const int seq_len,
                                  const int size_per_head,
                                  bool compute_bias) {
  const int32_t elem_cnt = token_num * head_num * size_per_head * 3;
  constexpr int PackSize = VEC_16B / sizeof(T);
  PADDLE_ENFORCE_EQ(size_per_head % PackSize,
@@ -1499,6 +1514,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
        kv_buf,
        qkv,
        qkv_bias,
        padding_offset,
        elem_cnt,
        batch_size,
        seq_len,
@@ -1511,6 +1527,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
        kv_buf,
        qkv,
        qkv_bias,
        padding_offset,
        elem_cnt,
        batch_size,
        seq_len,
@@ -1524,7 +1541,9 @@ template <typename T>
__global__ void RotrayKernel(const T *input,
                             const T *cos_emb,
                             const T *sin_emb,
                             const int *sequence_lengths,
                             T *output,
                             const int rotary_emb_dims,
                             const int batch_size,
                             const int head_num,
                             const int seq_len,
@@ -1532,6 +1551,7 @@ __global__ void RotrayKernel(const T *input,
  int bi = blockIdx.x;
  int hi = blockIdx.y;
  int si = blockIdx.z;
  if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return;
  int half_lastdim = last_dim / 2;
  // Note(ZhenyuLi): Calculate the relevant data at one time, so that no
  // additional space is required.
@@ -1559,6 +1579,7 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
               const T *q_input,  // q
               const T *k_input,  // kv
               const T *rotary_emb,
               const int *sequence_lengths,
               const int rotary_emb_dims,
               const int batch_size,
               const int head_num,
@@ -1592,7 +1613,9 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
      q_input,
      cos_emb,
      sin_emb,
      sequence_lengths,
      q,
      rotary_emb_dims,
      batch_size,
      head_num,
      seq_len * rotary_emb_dims,
@@ -1601,13 +1624,109 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
      k_input,
      cos_emb,
      sin_emb,
      sequence_lengths,
      k,
      rotary_emb_dims,
      batch_size,
      head_num,
      seq_len * rotary_emb_dims,
      last_dim);
}
__global__ void GetPaddingOffset(int *d_token_num,
int *padding_offset,
const int *sequence_lengths,
const int batch_size,
const int max_seq_len) {
// get padding offset of each batch
int total_seq_len = 0;
int cum_offset = 0;
int index = 0;
for (int i = 0; i < batch_size; i++) {
const int seq_len = sequence_lengths[i];
for (int j = 0; j < seq_len; j++) {
padding_offset[index] = cum_offset;
index++;
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
d_token_num[0] = total_seq_len;
}
void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx,
int *h_token_num,
int *d_token_num,
int *padding_offset,
const int *sequence_lengths,
const int batch_size,
const int max_seq_len) {
GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>(
d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len);
memory::Copy(platform::CPUPlace(),
h_token_num,
dev_ctx.GetPlace(),
d_token_num,
sizeof(int),
dev_ctx.stream());
}
template <typename T>
__global__ void RemovePadding(T *output_data,
const T *input_data,
const int *padding_offset,
const int dim_embed) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int src_seq_id = bid + padding_offset[bid];
const int tgt_seq_id = bid;
for (int i = tid; i < dim_embed; i += blockDim.x) {
output_data[tgt_seq_id * dim_embed + i] =
input_data[src_seq_id * dim_embed + i];
}
}
template <typename T>
void InvokeRemovePadding(const phi::GPUContext &dev_ctx,
T *output_data,
const T *input_data,
const int *padding_offset,
const int token_num,
const int dim_embed) {
RemovePadding<<<token_num, 256, 0, dev_ctx.stream()>>>(
output_data, input_data, padding_offset, dim_embed);
}
template <typename T>
__global__ void RebuildPadding(T *output_data,
const T *input_data,
const int *padding_offset,
const int dim_embed) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int dst_seq_id = bid + padding_offset[bid];
const int src_seq_id = bid;
for (int i = tid; i < dim_embed; i += blockDim.x) {
output_data[dst_seq_id * dim_embed + i] =
input_data[src_seq_id * dim_embed + i];
}
}
template <typename T>
void InvokeRebuildPadding(const phi::GPUContext &dev_ctx,
T *output_data,
const T *input_data,
const int *padding_offset,
const int token_num,
const int dim_embed) {
// src: [token_num, dim_embed]
// dst: [batch_size * max_seq_len, dim_embed]
RebuildPadding<<<token_num, 256, 0, dev_ctx.stream()>>>(
output_data, input_data, padding_offset, dim_embed);
}
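// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this patch: hypothetical CPU mirrors of
// RemovePadding / RebuildPadding (CpuRemovePadding / CpuRebuildPadding are
// invented names) showing that packing with padding_offset and then
// scattering back with the same offsets restores exactly the valid rows of
// the padded tensor; padded rows keep whatever the destination buffer already
// held. Assumes <vector> is available and that callers size dst correctly
// (token_num * dim_embed for remove, bsz * max_seq_len * dim_embed for
// rebuild).
// ---------------------------------------------------------------------------
inline void CpuRemovePadding(std::vector<float> *dst,
                             const std::vector<float> &src,
                             const std::vector<int> &padding_offset,
                             int dim_embed) {
  for (size_t token = 0; token < padding_offset.size(); ++token) {
    const size_t src_row = token + padding_offset[token];
    for (int i = 0; i < dim_embed; ++i) {
      (*dst)[token * dim_embed + i] = src[src_row * dim_embed + i];
    }
  }
}

inline void CpuRebuildPadding(std::vector<float> *dst,
                              const std::vector<float> &src,
                              const std::vector<int> &padding_offset,
                              int dim_embed) {
  for (size_t token = 0; token < padding_offset.size(); ++token) {
    const size_t dst_row = token + padding_offset[token];
    for (int i = 0; i < dim_embed; ++i) {
      (*dst)[dst_row * dim_embed + i] = src[token * dim_embed + i];
    }
  }
}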
#if CUDA_VERSION >= 11060
// Only Used in Inference
template <typename T>
......
...@@ -64,6 +64,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -64,6 +64,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"PreCaches", "PreCaches",
"RotaryPosEmb", "RotaryPosEmb",
"TimeStep", "TimeStep",
"SeqLengths",
"SrcMask", "SrcMask",
"OutLinearW", "OutLinearW",
"OutLinearBias", "OutLinearBias",
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
import unittest import unittest
import numpy as np import numpy as np
...@@ -28,6 +29,7 @@ from paddle.nn.layer.common import Dropout, Linear ...@@ -28,6 +29,7 @@ from paddle.nn.layer.common import Dropout, Linear
from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.transformer import _convert_attention_mask from paddle.nn.layer.transformer import _convert_attention_mask
random.seed(42)
default_main_program().random_seed = 42 default_main_program().random_seed = 42
...@@ -123,6 +125,8 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -123,6 +125,8 @@ class TestFusedMultiTransformerOp(OpTest):
self.rotary_embs = None self.rotary_embs = None
self.rotary_emb_dims = 0 self.rotary_emb_dims = 0
self.remove_padding = False
self.training = False self.training = False
self.layers = 4 self.layers = 4
...@@ -175,6 +179,27 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -175,6 +179,27 @@ class TestFusedMultiTransformerOp(OpTest):
else: else:
self.cache_kv = None self.cache_kv = None
if self.remove_padding:
if self.has_cache_kv and not self.gen_cache_kv:
# decoder
self.seq_lens = [
random.randint(1, self.cache_length)
for _ in range(self.batch_size)
]
# make sure at least one sample uses the full cache length
self.seq_lens[random.randint(0, self.batch_size - 1)] = self.cache_length
self.seq_lens = np.array(self.seq_lens).astype(np.int32)
else:
self.seq_lens = [
random.randint(1, self.query_length)
for _ in range(self.batch_size)
]
# make sure at least one sample uses the full query length
self.seq_lens[random.randint(0, self.batch_size - 1)] = self.query_length
self.seq_lens = np.array(self.seq_lens).astype(np.int32)
if self.has_pre_cache: if self.has_pre_cache:
out_seq_len += self.pre_cache_num out_seq_len += self.pre_cache_num
self.pre_cache_kv = np.random.uniform( self.pre_cache_kv = np.random.uniform(
...@@ -406,6 +431,138 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -406,6 +431,138 @@ class TestFusedMultiTransformerOp(OpTest):
return final_out, cache_kvs return final_out, cache_kvs
return final_out return final_out
def GetVariableDecoderBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
final_outs = []
cache_outs = []
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(
self.rotary_embs, stop_gradient=False
)
for i in range(self.batch_size):
tensor_query = paddle.to_tensor(
self.query[i : i + 1], stop_gradient=False
)
cache_kvs = []
cache_kv = None
if self.has_cache_kv:
cache_kv = paddle.to_tensor(
self.cache_kv[:, i : i + 1, :, : self.seq_lens[i], :],
stop_gradient=False,
)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(
self.attn_mask[i : i + 1, :, :, : self.seq_lens[i] + 1],
stop_gradient=False,
)
else:
attn_mask = None
for j in range(self.layers):
residual = tensor_query
ln1_out = tensor_query
if self.pre_layer_norm:
ln1_out = self.norm(tensor_query)
q = self.q_proj(ln1_out)
q = tensor.reshape(
x=q, shape=[0, 0, self.num_heads, self.head_dim]
)
q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(ln1_out)
v = self.v_proj(ln1_out)
k = tensor.reshape(
x=k, shape=[0, 0, self.num_heads, self.head_dim]
)
k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(
x=v, shape=[0, 0, self.num_heads, self.head_dim]
)
v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
if self.rotary_emb_dims > 0:
cos_emb = rotary_embs[0][i : i + 1]
sin_emb = rotary_embs[1][i : i + 1]
q_out = self.apply_rotary_emb(
q_out, cos_emb, sin_emb, self.rotary_emb_dims
)
k_out = self.apply_rotary_emb(
k_out, cos_emb, sin_emb, self.rotary_emb_dims
)
if self.has_cache_kv:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k, cache_v = paddle.split(cache_kv, 2)
cache_k = paddle.squeeze(cache_k, axis=0)
cache_v = paddle.squeeze(cache_v, axis=0)
# [B, n_head, cache_seq_len + seq_len, head_dim]
# out_seq_len = cache_seq_len + seq_len
if self.gen_cache_kv:
cache_kvs.append((k_out, v_out))
else:
cache_outs.append([k_out, v_out])
k_out = paddle.concat([cache_k, k_out], axis=-2)
v_out = paddle.concat([cache_v, v_out], axis=-2)
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out = paddle.matmul(x=q_out, y=k_out, transpose_y=True)
qk_out = paddle.scale(qk_out, scale=self.head_dim**-0.5)
if attn_mask is not None:
attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
attn_mask_out = qk_out + attn_mask
softmax_out = F.softmax(attn_mask_out)
else:
softmax_out = F.softmax(qk_out)
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train",
)
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out = tensor.matmul(dropout_out, v_out)
else:
qktv_out = tensor.matmul(softmax_out, v_out)
fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
out_linear_in = tensor.reshape(
x=fmha_out,
shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]],
)
out = self.out_proj(out_linear_in)
residual_out = residual + self.dropout(out)
if not self.pre_layer_norm:
attn_out = self.norm(residual_out)
else:
attn_out = residual_out
ffn_ln_out = attn_out
if self.pre_layer_norm:
ffn_ln_out = self.ffn_norm(attn_out)
ffn1_out = self.ffn1_proj(ffn_ln_out)
ffn1_out = self.dropout(self.activation(ffn1_out))
ffn2_out = self.ffn2_proj(ffn1_out)
residual_out = attn_out + self.dropout(ffn2_out)
final_out = residual_out
if not self.pre_layer_norm:
final_out = self.ffn_norm(residual_out)
tensor_query = final_out
final_outs.append(final_out)
final_out = paddle.concat(final_outs, axis=0)
return final_out, cache_outs
def GetFusedMultiTransformerOut(self): def GetFusedMultiTransformerOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0)) paddle.disable_static(place=paddle.CUDAPlace(0))
q_proj_weight = paddle.to_tensor( q_proj_weight = paddle.to_tensor(
...@@ -482,7 +639,7 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -482,7 +639,7 @@ class TestFusedMultiTransformerOp(OpTest):
x = paddle.to_tensor(self.query, stop_gradient=False) x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kvs, cache_kv = None, None cache_kvs, cache_kv = None, None
time_step = None time_step = None
pre_caches, pre_cache = None, None pre_caches = None
if self.has_cache_kv: if self.has_cache_kv:
cache_kvs = [] cache_kvs = []
...@@ -539,6 +696,11 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -539,6 +696,11 @@ class TestFusedMultiTransformerOp(OpTest):
[self.cache_length], dtype='int32', place=paddle.CPUPlace() [self.cache_length], dtype='int32', place=paddle.CPUPlace()
) )
if self.remove_padding:
seq_lens = paddle.to_tensor(self.seq_lens, dtype='int32')
else:
seq_lens = None
if self.has_pre_cache: if self.has_pre_cache:
cache_kvs = [] cache_kvs = []
max_seq_length = ( max_seq_length = (
...@@ -619,6 +781,7 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -619,6 +781,7 @@ class TestFusedMultiTransformerOp(OpTest):
rotary_emb_dims=self.rotary_emb_dims, rotary_emb_dims=self.rotary_emb_dims,
pre_caches=pre_caches, pre_caches=pre_caches,
time_step=time_step, time_step=time_step,
seq_lens=seq_lens,
attn_mask=attn_mask, attn_mask=attn_mask,
dropout_rate=self.dropout_prob, dropout_rate=self.dropout_prob,
activation=self.act_method, activation=self.act_method,
...@@ -637,13 +800,19 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -637,13 +800,19 @@ class TestFusedMultiTransformerOp(OpTest):
paddle.enable_static() paddle.enable_static()
x = paddle.fluid.data('x', self.query.shape, self.query.dtype) x = paddle.fluid.data('x', self.query.shape, self.query.dtype)
cache_kvs, cache_kv = None, None cache_kvs, cache_kv = None, None
cache_kvs_feed = None
time_step = None time_step = None
time_step_feed = None time_step_feed = None
pre_caches, pre_cache = None, None seq_lens = None
seq_lens_feed = None
pre_caches = None
pre_caches_feed = None
rotary_embs = None rotary_embs = None
if self.rotary_emb_dims > 0: if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(self.rotary_embs) rotary_embs = paddle.fluid.data(
'rotary_embs', self.rotary_embs.shape, self.rotary_embs.dtype
)
if self.has_cache_kv: if self.has_cache_kv:
cache_kvs = [] cache_kvs = []
...@@ -698,6 +867,12 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -698,6 +867,12 @@ class TestFusedMultiTransformerOp(OpTest):
) )
time_step_feed = self.cache_length time_step_feed = self.cache_length
if self.remove_padding:
seq_lens = paddle.fluid.data(
'seq_lens', self.seq_lens.shape, self.seq_lens.dtype
)
seq_lens_feed = self.seq_lens
if self.has_pre_cache: if self.has_pre_cache:
cache_kvs = [] cache_kvs = []
max_seq_length = ( max_seq_length = (
...@@ -799,31 +974,43 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -799,31 +974,43 @@ class TestFusedMultiTransformerOp(OpTest):
attn_mask=attn_mask, attn_mask=attn_mask,
caches=cache_kvs, caches=cache_kvs,
pre_caches=pre_caches, pre_caches=pre_caches,
seq_lens=seq_lens,
rotary_embs=rotary_embs, rotary_embs=rotary_embs,
rotary_emb_dims=self.rotary_emb_dims, rotary_emb_dims=self.rotary_emb_dims,
time_step=time_step, time_step=time_step,
)[0] )
exe = paddle.static.Executor(place=paddle.CUDAPlace(0)) exe = paddle.static.Executor(place=paddle.CUDAPlace(0))
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
feed_data = { feed_data = {
'x': self.query, 'x': self.query,
'cache_kvs': cache_kvs_feed, 'cache_kvs': cache_kvs_feed,
'pre_caches': pre_caches_feed, 'pre_caches': pre_caches_feed,
'rotary_embs': rotary_embs, 'rotary_embs': self.rotary_embs,
'time_step': time_step_feed, 'time_step': time_step_feed,
'rotary_emb_dims': self.rotary_emb_dims, 'rotary_emb_dims': self.rotary_emb_dims,
'attn_mask': attn_mask, 'attn_mask': attn_mask,
'seq_lens': seq_lens_feed,
} }
out = exe.run( if self.has_pre_cache:
paddle.fluid.default_main_program(), out = exe.run(
feed=feed_data, paddle.fluid.default_main_program(),
fetch_list=[final_out], feed=feed_data,
) fetch_list=[final_out[0].name],
)
else:
out = exe.run(
paddle.fluid.default_main_program(),
feed=feed_data,
fetch_list=[final_out.name],
)
paddle.disable_static() paddle.disable_static()
return out[0] return out
def test_fused_multi_transformer_op(self): def test_fused_multi_transformer_op(self):
final_out_ref = self.GetBaselineOut() if self.has_cache_kv and not self.gen_cache_kv and self.remove_padding:
final_out_ref = self.GetVariableDecoderBaselineOut()
else:
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOut() final_out = self.GetFusedMultiTransformerOut()
if self.has_cache_kv: if self.has_cache_kv:
final_out, cache_kv_out = final_out final_out, cache_kv_out = final_out
...@@ -846,6 +1033,39 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -846,6 +1033,39 @@ class TestFusedMultiTransformerOp(OpTest):
print("cache_v out timestep=128") print("cache_v out timestep=128")
print(cache_kv_out[0][1, 0, 0, self.cache_length, :]) print(cache_kv_out[0][1, 0, 0, self.cache_length, :])
if self.remove_padding and not self.gen_cache_kv:
# test decoder
final_out_ref, cache_kvs = final_out_ref
for i in range(self.batch_size):
for j in range(self.layers):
cache_k = cache_kv_out[j][0, :]
cache_k = cache_k.reshape(
[bsz, num_head, v_elems, max_seq_len, elems]
)
cache_k = cache_k[:, :, :, : self.seq_lens[i] + 1, :]
cache_k = cache_k.transpose([0, 1, 3, 2, 4])
cache_k = cache_k.reshape(
[bsz, num_head, self.seq_lens[i] + 1, head_dim]
)
cache_v = cache_kv_out[j][
1, :, :, : self.seq_lens[i] + 1, :
]
cache_k_ref = cache_kvs[i * self.layers + j][0]
cache_v_ref = cache_kvs[i * self.layers + j][1]
np.testing.assert_allclose(
cache_k_ref,
cache_k[i : i + 1, :, -1:, :],
rtol=self.rtol,
atol=self.atol,
)
np.testing.assert_allclose(
cache_v_ref,
cache_v[i : i + 1, :, -1:, :],
rtol=self.rtol,
atol=self.atol,
)
if self.gen_cache_kv: if self.gen_cache_kv:
final_out_ref, cache_kvs = final_out_ref final_out_ref, cache_kvs = final_out_ref
for i in range(self.layers): for i in range(self.layers):
...@@ -864,18 +1084,42 @@ class TestFusedMultiTransformerOp(OpTest): ...@@ -864,18 +1084,42 @@ class TestFusedMultiTransformerOp(OpTest):
cache_v = cache_kv_out[i][1, :, :, : self.cache_length, :] cache_v = cache_kv_out[i][1, :, :, : self.cache_length, :]
np.testing.assert_allclose( if self.remove_padding:
cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol for i in range(self.batch_size):
) np.testing.assert_allclose(
np.testing.assert_allclose( cache_k_ref[i, :, : self.seq_lens[i], :],
cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol cache_k[i, :, : self.seq_lens[i], :],
) rtol=self.rtol,
atol=self.atol,
)
np.testing.assert_allclose(
cache_v_ref[i, :, : self.seq_lens[i], :],
cache_v[i, :, : self.seq_lens[i], :],
rtol=self.rtol,
atol=self.atol,
)
else:
np.testing.assert_allclose(
cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol
)
np.testing.assert_allclose(
cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol
)
if i == 0: if i == 0:
break break
np.testing.assert_allclose( if self.remove_padding:
final_out_ref, final_out, rtol=self.rtol, atol=self.atol for i in range(self.batch_size):
) np.testing.assert_allclose(
final_out_ref[i, : self.seq_lens[i]],
final_out[i, : self.seq_lens[i]],
rtol=self.rtol,
atol=self.atol,
)
else:
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp): class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp):
...@@ -1020,10 +1264,109 @@ class TestFusedMultiTransformerOpPreCache(TestFusedMultiTransformerOp): ...@@ -1020,10 +1264,109 @@ class TestFusedMultiTransformerOpPreCache(TestFusedMultiTransformerOp):
self.x_type = np.float16 self.x_type = np.float16
class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp): class TestFusedMultiTransformerOpVariableGenCache1(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.remove_padding = True
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpVariableGenCache2(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.remove_padding = True
self.layers = 4 # even layers
class TestFusedMultiTransformerOpVariableGenCache3(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.remove_padding = True
self.layers = 4 # even layers
self.rotary_emb_dims = 2
class TestFusedMultiTransformerOpVariableGenCache4(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.remove_padding = True
self.layers = 3 # odd layers
self.rotary_emb_dims = 2
class TestFusedMultiTransformerOpVariableNormTransformer1(
TestFusedMultiTransformerOp
):
def config(self):
super().config()
self.has_cache_kv = False
self.gen_cache_kv = False
self.remove_padding = True
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpVariableNormTransformer2(
TestFusedMultiTransformerOp
):
def config(self):
super().config()
self.has_cache_kv = False
self.gen_cache_kv = False
self.remove_padding = True
self.layers = 4 # even layers
class TestFusedMultiTransformerOpVariableDecoder1(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = False
self.remove_padding = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpVariableDecoder2(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = False
self.remove_padding = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.layers = 4 # even layers
class TestFusedMultiTransformerOpVariableDecoder3(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = False
self.remove_padding = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.layers = 4 # even layers
self.rotary_emb_dims = 2
class TestFusedMultiTransformerOpPreCacheStatic1(TestFusedMultiTransformerOp):
def config(self): def config(self):
super().config() super().config()
self.has_pre_cache = True
self.has_attn_mask = False self.has_attn_mask = False
self.x_type = np.float32 self.x_type = np.float32
self.weight_attr = paddle.ParamAttr( self.weight_attr = paddle.ParamAttr(
...@@ -1040,14 +1383,29 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp): ...@@ -1040,14 +1383,29 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
) )
def test_fused_multi_transformer_op(self): def test_fused_multi_transformer_op(self):
for i in range(3): self.has_pre_cache = True
self.rotary_emb_dims = i self.remove_padding = False
self.generate_input_data() self.rotary_emb_dims = 2
final_out_ref = self.GetBaselineOut() self.generate_input_data()
final_out = self.GetFusedMultiTransformerOutStatic() final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOutStatic()[0]
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
self.has_pre_cache = False
self.remove_padding = True
self.generate_input_data()
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOutStatic()[0]
for i in range(self.batch_size):
np.testing.assert_allclose( np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol final_out_ref[i, : self.seq_lens[i]],
final_out[i, : self.seq_lens[i]],
rtol=self.rtol,
atol=self.atol,
) )
......
...@@ -887,6 +887,7 @@ def fused_multi_transformer( ...@@ -887,6 +887,7 @@ def fused_multi_transformer(
epsilon=1e-05, epsilon=1e-05,
cache_kvs=None, cache_kvs=None,
pre_caches=None, pre_caches=None,
seq_lens=None,
rotary_embs=None, rotary_embs=None,
time_step=None, time_step=None,
attn_mask=None, attn_mask=None,
...@@ -956,6 +957,7 @@ def fused_multi_transformer( ...@@ -956,6 +957,7 @@ def fused_multi_transformer(
epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5. epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None.
rotary_embs (Tensor optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. rotary_embs (Tensor optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None. time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
...@@ -1055,6 +1057,7 @@ def fused_multi_transformer( ...@@ -1055,6 +1057,7 @@ def fused_multi_transformer(
pre_caches, pre_caches,
rotary_embs, rotary_embs,
time_step, time_step,
seq_lens,
attn_mask, attn_mask,
linear_weights, linear_weights,
linear_biases, linear_biases,
...@@ -1115,6 +1118,7 @@ def fused_multi_transformer( ...@@ -1115,6 +1118,7 @@ def fused_multi_transformer(
inputs['PreCaches'] = pre_caches inputs['PreCaches'] = pre_caches
if rotary_emb_dims > 0: if rotary_emb_dims > 0:
inputs['RotaryPosEmb'] = rotary_embs inputs['RotaryPosEmb'] = rotary_embs
inputs['SeqLengths'] = seq_lens
inputs['SrcMask'] = attn_mask inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = linear_weights inputs['OutLinearW'] = linear_weights
if linear_biases is not None: if linear_biases is not None:
......
...@@ -1382,6 +1382,7 @@ class FusedMultiTransformer(Layer): ...@@ -1382,6 +1382,7 @@ class FusedMultiTransformer(Layer):
pre_caches=None, pre_caches=None,
rotary_embs=None, rotary_embs=None,
rotary_emb_dims=0, rotary_emb_dims=0,
seq_lens=None,
time_step=None, time_step=None,
): ):
r""" r"""
...@@ -1406,6 +1407,7 @@ class FusedMultiTransformer(Layer): ...@@ -1406,6 +1407,7 @@ class FusedMultiTransformer(Layer):
rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None,
1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0.
seq_lens (Tensor optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation time_step (Tensor, optional): The time step tensor for the generation
model. Which used in decode stage, to represent the time step, model. Which used in decode stage, to represent the time step,
that is, the real seq_len of CacheKV. The shape is `[1]`, must be that is, the real seq_len of CacheKV. The shape is `[1]`, must be
...@@ -1441,6 +1443,7 @@ class FusedMultiTransformer(Layer): ...@@ -1441,6 +1443,7 @@ class FusedMultiTransformer(Layer):
pre_caches=pre_caches, pre_caches=pre_caches,
rotary_embs=rotary_embs, rotary_embs=rotary_embs,
time_step=time_step, time_step=time_step,
seq_lens=seq_lens,
attn_mask=attn_mask, attn_mask=attn_mask,
dropout_rate=self.dropout_rate, dropout_rate=self.dropout_rate,
rotary_emb_dims=rotary_emb_dims, rotary_emb_dims=rotary_emb_dims,
......