Unverified commit 53df50c7, authored by lzy, committed by GitHub

Make FusedMultiTransformer support variable lengths. (#49560)

* Make FusedMultiTransformer support variable lengths.

* modify ffn2 when cuda_version >= 11.6 because of #49392.

* code style

* delete remove_padding
Parent a3989b5e
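The change packs the padded input of shape [bsz, seq_len, dim_embed] into [token_num, dim_embed] using the per-batch sequence lengths, runs every GEMM, layer norm and dropout on token_num rows, and rebuilds the padded layout only once at the end. Below is a minimal host-side sketch of that bookkeeping, mirroring the GetPaddingOffset / RemovePadding / RebuildPadding kernels added in this diff; the helper name GetPaddingOffsetRef and the sample lengths are illustrative, not part of the change.

#include <cstdio>
#include <vector>

// For each kept token, record how many padding slots precede it in the
// padded [bsz, max_seq_len] layout (same rule as the GetPaddingOffset kernel).
static int GetPaddingOffsetRef(const std::vector<int>& seq_lens,
                               int max_seq_len,
                               std::vector<int>* padding_offset) {
  int token_num = 0, cum_offset = 0;
  for (size_t b = 0; b < seq_lens.size(); ++b) {
    for (int s = 0; s < seq_lens[b]; ++s) padding_offset->push_back(cum_offset);
    cum_offset += max_seq_len - seq_lens[b];
    token_num += seq_lens[b];
  }
  return token_num;
}

int main() {
  // bsz = 3, max_seq_len = 4, seq_lens = {2, 1, 3}  ->  token_num = 6
  std::vector<int> seq_lens = {2, 1, 3};
  std::vector<int> padding_offset;
  int token_num = GetPaddingOffsetRef(seq_lens, 4, &padding_offset);
  // RemovePadding gathers padded row i + padding_offset[i] into packed row i;
  // RebuildPadding scatters it back to the same padded row.
  for (int i = 0; i < token_num; ++i) {
    std::printf("packed row %d <-> padded row %d\n", i, i + padding_offset[i]);
  }
  return 0;
}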
......@@ -62,6 +62,78 @@ class AttnDropoutParam {
const phi::DenseTensor* seed_;
};
template <typename T, int VecSize>
__global__ void TransposeRemovingPadding(const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int elem_cnt,
const int* padding_offset) {
// transpose and remove padding
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
const int dim_embed = num_head * head_dim;
using LoadT = phi::AlignedVector<T, VecSize>;
LoadT src_vec;
for (int32_t linear_index = idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / dim_embed;
const int ori_token_idx =
token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int ori_batch_id = ori_token_idx / seq_len;
const int ori_seq_id = ori_token_idx % seq_len;
const int ori_head_id = (linear_index % dim_embed) / head_dim;
const int ori_head_lane = (linear_index % dim_embed) % head_dim;
const int ori_idx = ori_batch_id * num_head * seq_len * head_dim +
ori_head_id * seq_len * head_dim +
ori_seq_id * head_dim + ori_head_lane;
phi::Load<T, VecSize>(&input_data[ori_idx], &src_vec);
phi::Store<T, VecSize>(src_vec, &output_data[linear_index]);
}
}
template <typename T>
void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx,
const T* input_data,
T* output_data,
const int batch_size,
const int num_head,
const int seq_len,
const int head_dim,
const int token_num,
const int* padding_offset) {
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
constexpr int VEC_16B = 16;
const int elem_cnt = token_num * num_head * head_dim;
constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(
head_dim % PackSize,
0,
platform::errors::PreconditionNotMet(
"dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize));
const int32_t pack_num = elem_cnt / PackSize;
const int32_t block_size = 128;
int32_t grid_size = (pack_num + block_size - 1) / block_size;
TransposeRemovingPadding<T, PackSize>
<<<grid_size, block_size, 0, dev_ctx.stream()>>>(input_data,
output_data,
batch_size,
num_head,
seq_len,
head_dim,
token_num,
elem_cnt,
padding_offset);
}
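// Worked example with assumed sizes (illustrative): batch_size = 2,
// num_head = 2, seq_len = 4, head_dim = 8, seq_lens = {2, 3}, so
// token_num = 5 and padding_offset = {0, 0, 2, 2, 2}.
// For packed token_idx = 3, head 1, lane 0:
//   linear_index  = 3 * 16 + 1 * 8 + 0 = 56 in the packed
//                   [token_num, num_head, head_dim] output;
//   ori_token_idx = 3 + padding_offset[3] = 5, so ori_batch_id = 1 and
//   ori_seq_id = 1, and the value is read from
//   ori_idx = 1*2*4*8 + 1*4*8 + 1*8 + 0 = 104 in the padded
//   [batch_size, num_head, seq_len, head_dim] input.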
template <typename T>
class FMHARef {
public:
......@@ -256,18 +328,21 @@ class FMHARef {
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
}
void ComputeForwardWithoutTranspose(const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
phi::DenseTensor* q_transpose_out_tensor,
phi::DenseTensor* kv_transpose_out_tensor,
phi::DenseTensor* cache_kv_out_tensor,
phi::DenseTensor* qk_out_tensor,
phi::DenseTensor* src_mask_out_tensor,
phi::DenseTensor* softmax_out_tensor,
phi::DenseTensor* dropout_mask_out_tensor,
phi::DenseTensor* dropout_out_tensor,
phi::DenseTensor* qktv_out_tensor,
phi::DenseTensor* fmha_out_tensor) {
void ComputeForwardWithoutTranspose(
const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
const phi::DenseTensor* padding_offset_tensor,
phi::DenseTensor* q_transpose_out_tensor,
phi::DenseTensor* kv_transpose_out_tensor,
phi::DenseTensor* cache_kv_out_tensor,
phi::DenseTensor* qk_out_tensor,
phi::DenseTensor* src_mask_out_tensor,
phi::DenseTensor* softmax_out_tensor,
phi::DenseTensor* dropout_mask_out_tensor,
phi::DenseTensor* dropout_out_tensor,
phi::DenseTensor* qktv_out_tensor,
phi::DenseTensor* fmha_out_tensor,
const int token_num) {
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0, 3, 1, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
......@@ -424,9 +499,21 @@ class FMHARef {
}
// transpose: [0, 2, 1, 3]
// output shape: [batch_size, seq_len, num_heads, head_dim]
std::vector<int> perm_3 = {0, 2, 1, 3};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
if (!padding_offset_tensor) {
std::vector<int> perm_3 = {0, 2, 1, 3};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
} else {
InvokeTransposeRemovePadding<T>(dev_ctx_,
qktv_out_data,
fmha_out_data,
batch_size_,
num_head_,
seq_len_,
head_dim_,
token_num,
padding_offset_tensor->data<int>());
}
}
void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor,
......
......@@ -182,6 +182,8 @@ class FusedMultiTransformerOpOpMaker
AddInput("TimeStep",
"(optional, int) The time step for generation inference.")
.AsDispensable();
AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.")
.AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
......
......@@ -32,6 +32,48 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
int dim_embed = input_x_dims[2];
int bsz_seq = bsz * seq_len;
const std::string act_method = ctx.Attr<std::string>("act_method");
bool remove_padding = false;
auto *sequence_lengths = ctx.Input<phi::DenseTensor>("SeqLengths");
if (sequence_lengths) {
remove_padding = true;
}
phi::DenseTensor d_token_tensor;
phi::DenseTensor padding_offset_tensor;
phi::DenseTensor x_remove_padding;
bool encoder_remove_padding = (remove_padding && !time_step);
int token_num = 0;
// remove padding in encoder
if (encoder_remove_padding) {
// just for encoder
d_token_tensor.Resize({{1}});
auto *d_token_num = dev_ctx.Alloc<int>(
&d_token_tensor, d_token_tensor.numel() * sizeof(int));
// alloc the max size of padding_offset_tensor
padding_offset_tensor.Resize({{bsz_seq}});
dev_ctx.Alloc<int>(&padding_offset_tensor,
padding_offset_tensor.numel() * sizeof(int));
InvokeGetPaddingOffset(dev_ctx,
&token_num,
d_token_num,
padding_offset_tensor.data<int>(),
sequence_lengths->data<int>(),
bsz,
seq_len);
padding_offset_tensor.Resize({{token_num}});
x_remove_padding.Resize({{token_num, dim_embed}});
dev_ctx.Alloc<T>(&x_remove_padding, x_remove_padding.numel() * sizeof(T));
InvokeRemovePadding(dev_ctx,
x_remove_padding.data<T>(),
input_x->data<T>(),
padding_offset_tensor.data<int>(),
token_num,
dim_embed);
} else {
token_num = bsz_seq;
}
auto *padding_offset_data =
encoder_remove_padding ? padding_offset_tensor.data<int>() : nullptr;
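// When SeqLengths is given in the encoder stage, every GEMM, layer norm and
// dropout helper below operates on token_num rows instead of bsz * seq_len,
// and the padded layout is restored by InvokeRebuildPadding after the last
// layer.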
// 1. layer norm
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
......@@ -39,12 +81,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ln_scales = ctx.MultiInput<phi::DenseTensor>("LnScale");
auto ln_biases = ctx.MultiInput<phi::DenseTensor>("LnBias");
auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, token_num, dim_embed);
phi::DenseTensor ln_mean, ln_var;
ln_mean.Resize({{bsz_seq}});
ln_mean.Resize({{token_num}});
auto *ln_mean_data =
dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
ln_var.Resize({{bsz_seq}});
ln_var.Resize({{token_num}});
auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));
// 2. qkv
......@@ -67,13 +109,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto qkv_compute = AttnMatMul<T>(dev_ctx,
false,
trans_qkvw,
bsz_seq,
token_num,
output_size,
input_size,
/*compute_bias=*/false);
phi::DenseTensor qkv_out;
qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}});
qkv_out.Resize({{token_num, 3, num_head, dim_head}});
auto *qkv_out_data =
dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
......@@ -175,23 +217,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
int ring_id = ctx.Attr<int>("ring_id");
// (transA, transB, compute_bias) = (false, false, false)
auto out_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false);
dev_ctx, false, false, token_num, dim_embed, hidden_size, false);
// 5. ln(residual + bias)
DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon);
dev_ctx, token_num, dim_embed, dropout_param2, epsilon);
auto ffn_ln_scales = ctx.MultiInput<phi::DenseTensor>("FFNLnScale");
auto ffn_ln_biases = ctx.MultiInput<phi::DenseTensor>("FFNLnBias");
phi::DenseTensor bias_dropout_residual_out, dropout_mask_out;
T *bias_dropout_residual_out_data = nullptr;
if (pre_layer_norm) {
bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}});
bias_dropout_residual_out.Resize({{token_num, dim_embed}});
bias_dropout_residual_out_data =
dev_ctx.Alloc<T>(&bias_dropout_residual_out,
bias_dropout_residual_out.numel() * sizeof(T));
}
dropout_mask_out.Resize({{bsz, seq_len, dim_embed}});
dropout_mask_out.Resize({{token_num, dim_embed}});
auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
&dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));
......@@ -203,11 +245,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
int dim_ffn = ffn1_weight_dim[1];
auto ffn1_cublas_linear = CublasFusedMLP<T>(dev_ctx);
const phi::DDim ffn1_input_shape({bsz_seq, dim_embed});
const phi::DDim ffn1_input_shape({token_num, dim_embed});
ffn1_cublas_linear.Setup(ffn1_input_shape, ffn1_weight_dim, false, false);
phi::DenseTensor ffn1_out;
ffn1_out.Resize({{bsz_seq, dim_ffn}});
ffn1_out.Resize({{token_num, dim_ffn}});
auto *ffn1_out_data =
dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));
......@@ -216,23 +258,33 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
auto ffn2_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false);
dev_ctx, false, false, token_num, dim_embed, dim_ffn, false);
// 8. ffn2 Layernorm residual bias
DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon);
// calc
auto *out = ctx.Output<phi::DenseTensor>("Out");
auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
phi::DenseTensor *from_tensor = out;
phi::DenseTensor tmp_out;
tmp_out.Resize({{bsz, seq_len, dim_embed}});
phi::DenseTensor tmp_out, tmp_out_rm_padding;
tmp_out.Resize({{token_num, dim_embed}});
if (encoder_remove_padding) {
tmp_out_rm_padding.Resize({{token_num, dim_embed}});
auto *tmp_out_rm_padding_data = dev_ctx.Alloc<T>(
&tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T));
}
auto *tmp_out_data =
dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));
auto *x_data = input_x->data<T>();
const T *x_data;
if (encoder_remove_padding) {
x_data = x_remove_padding.data<T>();
} else {
x_data = input_x->data<T>();
}
phi::DenseTensor *buf0 = nullptr;
phi::DenseTensor *buf1 = nullptr;
......@@ -240,19 +292,27 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step1: buf1 --> buf0
// step2: buf0 --> buf1
int layers = qkv_weights.size();
if (pre_layer_norm) {
if (layers & 1) {
// odd, set buf1 as out
if (encoder_remove_padding) {
// In the variable-length case the padding is rebuilt at the end, so buf0
// and buf1 do not need to depend on pre_layer_norm or the number of layers.
buf0 = &tmp_out;
buf1 = &tmp_out_rm_padding;
} else {
if (pre_layer_norm) {
if (layers & 1) {
// odd, set buf1 as out
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
} else {
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
} else {
buf0 = &tmp_out;
buf1 = out;
}
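// With padding removed, the layers ping-pong between tmp_out and
// tmp_out_rm_padding (both [token_num, dim_embed]); the padded Out tensor is
// written only once, by InvokeRebuildPadding after the layer loop.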
for (int i = 0; i < layers; ++i) {
......@@ -278,8 +338,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// NOTE: in decoder stage, bias is fused in fmha
const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias;
if (!pre_layer_norm && i == 0) {
const phi::DenseTensor *tmp_input_x =
(encoder_remove_padding) ? &x_remove_padding : input_x;
qkv_compute.ComputeForward(
qkv_weights[i], input_x, bias, &qkv_out, &qkv_out);
qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out);
} else {
qkv_compute.ComputeForward(
qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
......@@ -300,6 +362,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
qkv_out,
*qkv_bias,
*src_mask,
sequence_lengths,
rotary_tensor,
cache_kv_out,
&fmha_out,
......@@ -322,22 +385,26 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
kv_transpose_out_data,
qkv_out_data,
qkv_bias->data<T>(),
padding_offset_data,
token_num,
bsz,
num_head,
seq_len,
dim_head,
compute_bias);
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
const int *sequence_lengths_data =
encoder_remove_padding ? sequence_lengths->data<int>() : nullptr;
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
sequence_lengths_data,
rotary_emb_dims,
bsz,
num_head,
......@@ -345,8 +412,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
dim_head);
}
phi::DenseTensor *tmp_padding_offset_tensor =
encoder_remove_padding ? &padding_offset_tensor : nullptr;
fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor,
src_mask,
tmp_padding_offset_tensor,
&q_transpose_out,
&kv_transpose_out,
pre_cache_kv_out_tmp,
......@@ -356,7 +426,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
&fmha_out,
token_num);
const T *k_ptr = nullptr;
const T *v_ptr = nullptr;
......@@ -400,6 +471,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
kv_transpose_out_data,
qkv_out_data,
qkv_bias->data<T>(),
padding_offset_data,
token_num,
bsz,
num_head,
seq_len,
......@@ -410,12 +483,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
const int *sequence_lengths_data =
encoder_remove_padding ? sequence_lengths->data<int>() : nullptr;
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
sequence_lengths_data,
rotary_emb_dims,
bsz,
num_head,
......@@ -423,8 +499,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
dim_head);
}
phi::DenseTensor *tmp_padding_offset_tensor =
encoder_remove_padding ? &padding_offset_tensor : nullptr;
fmha_compute.ComputeForwardWithoutTranspose(cache_kv,
src_mask,
tmp_padding_offset_tensor,
&q_transpose_out,
&kv_transpose_out,
cache_kv_out,
......@@ -434,7 +513,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
&fmha_out,
token_num);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step3";
......@@ -580,6 +660,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
std::swap(buf0, buf1);
}
}
if (encoder_remove_padding) {
if (pre_layer_norm) {
InvokeRebuildPadding(dev_ctx,
from_data,
buf0->data<T>(),
padding_offset_data,
token_num,
dim_embed);
} else {
InvokeRebuildPadding(dev_ctx,
from_data,
buf1->data<T>(),
padding_offset_data,
token_num,
dim_embed);
}
}
}
};
......@@ -601,6 +698,48 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
int dim_embed = input_x_dims[2];
int bsz_seq = bsz * seq_len;
const std::string act_method = ctx.Attr<std::string>("act_method");
bool remove_padding = false;
auto *sequence_lengths = ctx.Input<phi::DenseTensor>("SeqLengths");
if (sequence_lengths) {
remove_padding = true;
}
phi::DenseTensor d_token_tensor;
phi::DenseTensor padding_offset_tensor;
phi::DenseTensor x_remove_padding;
bool encoder_remove_padding = (remove_padding && !time_step);
int token_num = 0;
// remove padding in encoder
if (encoder_remove_padding) {
// just for encoder
d_token_tensor.Resize({{1}});
auto *d_token_num = dev_ctx.Alloc<int>(
&d_token_tensor, d_token_tensor.numel() * sizeof(int));
// alloc the max size of padding_offset_tensor
padding_offset_tensor.Resize({{bsz_seq}});
dev_ctx.Alloc<int>(&padding_offset_tensor,
padding_offset_tensor.numel() * sizeof(int));
InvokeGetPaddingOffset(dev_ctx,
&token_num,
d_token_num,
padding_offset_tensor.data<int>(),
sequence_lengths->data<int>(),
bsz,
seq_len);
padding_offset_tensor.Resize({{token_num}});
x_remove_padding.Resize({{token_num, dim_embed}});
dev_ctx.Alloc<T>(&x_remove_padding, x_remove_padding.numel() * sizeof(T));
InvokeRemovePadding(dev_ctx,
x_remove_padding.data<T>(),
input_x->data<T>(),
padding_offset_tensor.data<int>(),
token_num,
dim_embed);
} else {
token_num = bsz_seq;
}
auto *padding_offset_data =
encoder_remove_padding ? padding_offset_tensor.data<int>() : nullptr;
// 1. layer norm
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
......@@ -608,12 +747,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ln_scales = ctx.MultiInput<phi::DenseTensor>("LnScale");
auto ln_biases = ctx.MultiInput<phi::DenseTensor>("LnBias");
auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, token_num, dim_embed);
phi::DenseTensor ln_mean, ln_var;
ln_mean.Resize({{bsz_seq}});
ln_mean.Resize({{token_num}});
auto *ln_mean_data =
dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
ln_var.Resize({{bsz_seq}});
ln_var.Resize({{token_num}});
auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));
// 2. qkv
......@@ -636,13 +775,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto qkv_compute = AttnMatMul<T>(dev_ctx,
false,
trans_qkvw,
bsz_seq,
token_num,
output_size,
input_size,
/*compute_bias=*/false);
phi::DenseTensor qkv_out;
qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}});
qkv_out.Resize({{token_num, 3, num_head, dim_head}});
auto *qkv_out_data =
dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
......@@ -744,23 +883,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
int ring_id = ctx.Attr<int>("ring_id");
// (transA, transB, compute_bias) = (false, false, false)
auto out_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false);
dev_ctx, false, false, token_num, dim_embed, hidden_size, false);
// 5. ln(residual + bias)
DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon);
dev_ctx, token_num, dim_embed, dropout_param2, epsilon);
auto ffn_ln_scales = ctx.MultiInput<phi::DenseTensor>("FFNLnScale");
auto ffn_ln_biases = ctx.MultiInput<phi::DenseTensor>("FFNLnBias");
phi::DenseTensor bias_dropout_residual_out, dropout_mask_out;
T *bias_dropout_residual_out_data = nullptr;
if (pre_layer_norm) {
bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}});
bias_dropout_residual_out.Resize({{token_num, dim_embed}});
bias_dropout_residual_out_data =
dev_ctx.Alloc<T>(&bias_dropout_residual_out,
bias_dropout_residual_out.numel() * sizeof(T));
}
dropout_mask_out.Resize({{bsz, seq_len, dim_embed}});
dropout_mask_out.Resize({{token_num, dim_embed}});
auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
&dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));
......@@ -771,21 +910,21 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
int dim_ffn = ffn1_weight_dim[1];
auto ffn1_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_ffn, dim_embed, false);
dev_ctx, false, false, token_num, dim_ffn, dim_embed, false);
phi::DenseTensor ffn1_out;
ffn1_out.Resize({{bsz_seq, dim_ffn}});
ffn1_out.Resize({{token_num, dim_ffn}});
auto *ffn1_out_data =
dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));
// 7. ffn act + bias
DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param);
dev_ctx, token_num, dim_ffn, ffn1_dropout_param);
phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask;
ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}});
ffn1_dropout_out.Resize({{token_num, dim_ffn}});
auto *ffn1_dropout_out_data = dev_ctx.Alloc<T>(
&ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T));
ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}});
ffn1_dropout_mask.Resize({{token_num, dim_ffn}});
auto *ffn1_dropout_mask_data = dev_ctx.Alloc<uint8_t>(
&ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t));
......@@ -793,23 +932,33 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
auto ffn2_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false);
dev_ctx, false, false, token_num, dim_embed, dim_ffn, false);
// 9. ffn2 residual bias
DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon);
// calc
auto *out = ctx.Output<phi::DenseTensor>("Out");
auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
phi::DenseTensor *from_tensor = out;
phi::DenseTensor tmp_out;
tmp_out.Resize({{bsz, seq_len, dim_embed}});
phi::DenseTensor tmp_out, tmp_out_rm_padding;
tmp_out.Resize({{token_num, dim_embed}});
if (encoder_remove_padding) {
tmp_out_rm_padding.Resize({{token_num, dim_embed}});
auto *tmp_out_rm_padding_data = dev_ctx.Alloc<T>(
&tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T));
}
auto *tmp_out_data =
dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));
auto *x_data = input_x->data<T>();
const T *x_data;
if (encoder_remove_padding) {
x_data = x_remove_padding.data<T>();
} else {
x_data = input_x->data<T>();
}
phi::DenseTensor *buf0 = nullptr;
phi::DenseTensor *buf1 = nullptr;
......@@ -817,19 +966,27 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step1: buf1 --> buf0
// step2: buf0 --> buf1
int layers = qkv_weights.size();
if (pre_layer_norm) {
if (layers & 1) {
// odd, set buf1 as out
if (encoder_remove_padding) {
// In the variable-length case the padding is rebuilt at the end, so buf0
// and buf1 do not need to depend on pre_layer_norm or the number of layers.
buf0 = &tmp_out;
buf1 = &tmp_out_rm_padding;
} else {
if (pre_layer_norm) {
if (layers & 1) {
// odd, set buf1 as out
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
} else {
buf0 = &tmp_out;
buf1 = out;
} else {
// even, set buf0 as out
buf0 = out;
buf1 = &tmp_out;
}
} else {
buf0 = &tmp_out;
buf1 = out;
}
for (int i = 0; i < layers; ++i) {
......@@ -855,8 +1012,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// NOTE: in decoder stage, bias is fused in fmha
const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias;
if (!pre_layer_norm && i == 0) {
const phi::DenseTensor *tmp_input_x =
(encoder_remove_padding) ? &x_remove_padding : input_x;
qkv_compute.ComputeForward(
qkv_weights[i], input_x, bias, &qkv_out, &qkv_out);
qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out);
} else {
qkv_compute.ComputeForward(
qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
......@@ -877,6 +1036,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
qkv_out,
*qkv_bias,
*src_mask,
sequence_lengths,
rotary_tensor,
cache_kv_out,
&fmha_out,
......@@ -899,6 +1059,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
kv_transpose_out_data,
qkv_out_data,
qkv_bias->data<T>(),
padding_offset_data,
token_num,
bsz,
num_head,
seq_len,
......@@ -909,12 +1071,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
const int *sequence_lengths_data =
encoder_remove_padding ? sequence_lengths->data<int>() : nullptr;
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
sequence_lengths_data,
rotary_emb_dims,
bsz,
num_head,
......@@ -922,8 +1087,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
dim_head);
}
phi::DenseTensor *tmp_padding_offset_tensor =
encoder_remove_padding ? &padding_offset_tensor : nullptr;
fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor,
src_mask,
tmp_padding_offset_tensor,
&q_transpose_out,
&kv_transpose_out,
pre_cache_kv_out_tmp,
......@@ -933,7 +1101,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
&fmha_out,
token_num);
const T *k_ptr = nullptr;
const T *v_ptr = nullptr;
......@@ -977,6 +1146,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
kv_transpose_out_data,
qkv_out_data,
qkv_bias->data<T>(),
padding_offset_data,
token_num,
bsz,
num_head,
seq_len,
......@@ -987,12 +1158,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
const int *sequence_lengths_data =
encoder_remove_padding ? sequence_lengths->data<int>() : nullptr;
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
sequence_lengths_data,
rotary_emb_dims,
bsz,
num_head,
......@@ -1000,8 +1174,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
dim_head);
}
phi::DenseTensor *tmp_padding_offset_tensor =
encoder_remove_padding ? &padding_offset_tensor : nullptr;
fmha_compute.ComputeForwardWithoutTranspose(cache_kv,
src_mask,
tmp_padding_offset_tensor,
&q_transpose_out,
&kv_transpose_out,
cache_kv_out,
......@@ -1011,7 +1188,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
&attn_dropout_mask_out,
&attn_dropout_out,
&qktv_out,
&fmha_out);
&fmha_out,
token_num);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step3";
......@@ -1162,6 +1340,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
std::swap(buf0, buf1);
}
}
if (encoder_remove_padding) {
if (pre_layer_norm) {
InvokeRebuildPadding(dev_ctx,
from_data,
buf0->data<T>(),
padding_offset_data,
token_num,
dim_embed);
} else {
InvokeRebuildPadding(dev_ctx,
from_data,
buf1->data<T>(),
padding_offset_data,
token_num,
dim_embed);
}
}
}
};
......
......@@ -127,6 +127,8 @@ struct Masked_multihead_attention_params {
// v [B, num_head, max_seq_len, dim_head]
T *cache_kv;
const int *sequence_lengths{nullptr};
// The RoPE embedding, [B, 1, 1, dim_head]
// rotary_emb_dims = 1 if pos_ids_extra is null else 2
const T *rotary_emb;
......@@ -769,6 +771,10 @@ __global__ void masked_multihead_attention_kernel(
float qk_max = -FLT_MAX;
float qk = 0;
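// With variable lengths, the effective cache length of this batch comes from
// sequence_lengths; otherwise every batch shares params.timestep.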
int act_time_step = params.sequence_lengths == nullptr
? params.timestep
: params.sequence_lengths[bi];
// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
......@@ -888,7 +894,7 @@ __global__ void masked_multihead_attention_kernel(
int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
int offset = bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
params.timestep * QK_ELTS_IN_16B + ci;
act_time_step * QK_ELTS_IN_16B + ci;
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
}
......@@ -914,7 +920,7 @@ __global__ void masked_multihead_attention_kernel(
// qk += static_cast<float>(mask);
qk *= params.inv_sqrt_dh;
qk_max = qk;
qk_smem[params.timestep] = qk;
qk_smem[act_time_step] = qk;
}
__syncthreads();
......@@ -949,7 +955,7 @@ __global__ void masked_multihead_attention_kernel(
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
K_vec k[K_VECS_PER_THREAD];
......@@ -958,7 +964,7 @@ __global__ void masked_multihead_attention_kernel(
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_seq_length + ti;
if (ti < params.timestep) {
if (ti < act_time_step) {
k[ii] =
(Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
? *reinterpret_cast<const K_vec *>(
......@@ -972,7 +978,7 @@ __global__ void masked_multihead_attention_kernel(
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);
// bool is_mask = false;
if (ti < params.timestep && tid % THREADS_PER_KEY == 0) {
if (ti < act_time_step && tid % THREADS_PER_KEY == 0) {
// qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
T mask = params.attn_mask[bi * (params.timestep + 1) + ti];
qk += static_cast<float>(mask);
......@@ -1014,7 +1020,7 @@ __global__ void masked_multihead_attention_kernel(
#endif
float sum = 0.f;
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
// bool is_mask = false;
// float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
float logit = __expf(qk_smem[ti] - qk_max);
......@@ -1026,7 +1032,7 @@ __global__ void masked_multihead_attention_kernel(
// FIXME(wangxi): need add 1.e-6f?
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) {
for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) {
convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
}
__syncthreads();
......@@ -1052,7 +1058,7 @@ __global__ void masked_multihead_attention_kernel(
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) {
V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti];
......@@ -1076,18 +1082,18 @@ __global__ void masked_multihead_attention_kernel(
V_vec v_bias;
zero(v_bias);
if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
V_vec v = *reinterpret_cast<const V_vec *>(
&params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
v = add(v, v_bias);
*reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
*reinterpret_cast<V_vec *>(&v_cache[act_time_step * Dh]) = v;
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
out = fma(logits_smem[params.timestep], cast_to_float(v), out);
out = fma(logits_smem[act_time_step], cast_to_float(v), out);
#else
out = fma(logits_smem[params.timestep], v, out);
out = fma(logits_smem[act_time_step], v, out);
#endif
}
......@@ -1198,6 +1204,7 @@ void fmha(const phi::GPUContext &dev_ctx,
const phi::DenseTensor &qkv_tensor,
const phi::DenseTensor &qkv_bias_tensor,
const phi::DenseTensor &src_mask_tensor,
const phi::DenseTensor *sequence_lengths_tensor,
const phi::DenseTensor *rotary_tensor,
phi::DenseTensor *cache_kv_tensor,
phi::DenseTensor *out_tensor,
......@@ -1214,6 +1221,11 @@ void fmha(const phi::GPUContext &dev_ctx,
params.qkv_bias = qkv_bias_tensor.data<T>();
params.attn_mask = src_mask_tensor.data<T>();
params.cache_kv = cache_kv_tensor->data<T>();
if (sequence_lengths_tensor) {
params.sequence_lengths = sequence_lengths_tensor->data<int>();
}
if (rotary_emb_dims > 0) {
params.rotary_emb = rotary_tensor->data<T>();
} else {
......@@ -1273,6 +1285,7 @@ void fmha(const phi::GPUContext &dev_ctx,
qkv_bias_tensor,
src_mask_tensor,
nullptr,
nullptr,
cache_kv_tensor,
out_tensor,
batch_size,
......@@ -1395,6 +1408,7 @@ __global__ void add_fusedQKV_bias_transpose_split_kernel(
T *kv_buf,
const T *qkv,
const T *qkv_bias,
const int *padding_offset,
const int32_t elem_cnt,
const int batch_size,
const int seq_len,
......@@ -1423,10 +1437,10 @@ __global__ void add_fusedQKV_bias_transpose_split_kernel(
}
}
const int32_t token_idx = linear_index / fused_hidden_size;
// const int32_t token_padded_idx = token_idx + (padding_offset == nullptr ?
// 0 : padding_offset[token_idx]);
const int32_t target_batch_id = token_idx / seq_len;
const int32_t seq_id = token_idx % seq_len;
const int32_t ori_token_idx =
token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int32_t target_batch_id = ori_token_idx / seq_len;
const int32_t seq_id = ori_token_idx % seq_len;
// equal to:
// const int qkv_id = (linear_index % fused_hidden_size) / hidden_size;
......@@ -1475,12 +1489,13 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
T *kv_buf,
const T *qkv,
const T *qkv_bias,
const int *padding_offset,
const int token_num,
const int batch_size,
const int head_num,
const int seq_len,
const int size_per_head,
bool compute_bias) {
const int32_t token_num = batch_size * seq_len;
const int32_t elem_cnt = token_num * head_num * size_per_head * 3;
constexpr int PackSize = VEC_16B / sizeof(T);
PADDLE_ENFORCE_EQ(size_per_head % PackSize,
......@@ -1499,6 +1514,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
kv_buf,
qkv,
qkv_bias,
padding_offset,
elem_cnt,
batch_size,
seq_len,
......@@ -1511,6 +1527,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
kv_buf,
qkv,
qkv_bias,
padding_offset,
elem_cnt,
batch_size,
seq_len,
......@@ -1524,7 +1541,9 @@ template <typename T>
__global__ void RotrayKernel(const T *input,
const T *cos_emb,
const T *sin_emb,
const int *sequence_lengths,
T *output,
const int rotary_emb_dims,
const int batch_size,
const int head_num,
const int seq_len,
......@@ -1532,6 +1551,7 @@ __global__ void RotrayKernel(const T *input,
int bi = blockIdx.x;
int hi = blockIdx.y;
int si = blockIdx.z;
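// rotary_qk launches this kernel with seq_len * rotary_emb_dims positions
// along the z dimension, so the per-batch valid length is scaled the same way.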
if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return;
int half_lastdim = last_dim / 2;
// Note(ZhenyuLi): Calculate the relevant data at one time, so that no
// additional space is required.
......@@ -1559,6 +1579,7 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
const T *q_input, // q
const T *k_input, // kv
const T *rotary_emb,
const int *sequence_lengths,
const int rotary_emb_dims,
const int batch_size,
const int head_num,
......@@ -1592,7 +1613,9 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
q_input,
cos_emb,
sin_emb,
sequence_lengths,
q,
rotary_emb_dims,
batch_size,
head_num,
seq_len * rotary_emb_dims,
......@@ -1601,13 +1624,109 @@ void rotary_qk(const phi::GPUContext &dev_ctx,
k_input,
cos_emb,
sin_emb,
sequence_lengths,
k,
rotary_emb_dims,
batch_size,
head_num,
seq_len * rotary_emb_dims,
last_dim);
}
__global__ void GetPaddingOffset(int *d_token_num,
int *padding_offset,
const int *sequence_lengths,
const int batch_size,
const int max_seq_len) {
// get padding offset of each batch
int total_seq_len = 0;
int cum_offset = 0;
int index = 0;
for (int i = 0; i < batch_size; i++) {
const int seq_len = sequence_lengths[i];
for (int j = 0; j < seq_len; j++) {
padding_offset[index] = cum_offset;
index++;
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
d_token_num[0] = total_seq_len;
}
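// Example with assumed lengths (illustrative): batch_size = 3,
// max_seq_len = 4, sequence_lengths = {2, 1, 3} gives
// padding_offset = {0, 0, 2, 5, 5, 5} and d_token_num[0] = 6, i.e. packed
// token i lives at padded row i + padding_offset[i] (rows 0, 1, 4, 8, 9, 10).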
void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx,
int *h_token_num,
int *d_token_num,
int *padding_offset,
const int *sequence_lengths,
const int batch_size,
const int max_seq_len) {
GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>(
d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len);
memory::Copy(platform::CPUPlace(),
h_token_num,
dev_ctx.GetPlace(),
d_token_num,
sizeof(int),
dev_ctx.stream());
}
template <typename T>
__global__ void RemovePadding(T *output_data,
const T *input_data,
const int *padding_offset,
const int dim_embed) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int src_seq_id = bid + padding_offset[bid];
const int tgt_seq_id = bid;
for (int i = tid; i < dim_embed; i += blockDim.x) {
output_data[tgt_seq_id * dim_embed + i] =
input_data[src_seq_id * dim_embed + i];
}
}
template <typename T>
void InvokeRemovePadding(const phi::GPUContext &dev_ctx,
T *output_data,
const T *input_data,
const int *padding_offset,
const int token_num,
const int dim_embed) {
RemovePadding<<<token_num, 256, 0, dev_ctx.stream()>>>(
output_data, input_data, padding_offset, dim_embed);
}
template <typename T>
__global__ void RebuildPadding(T *output_data,
const T *input_data,
const int *padding_offset,
const int dim_embed) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int dst_seq_id = bid + padding_offset[bid];
const int src_seq_id = bid;
for (int i = tid; i < dim_embed; i += blockDim.x) {
output_data[dst_seq_id * dim_embed + i] =
input_data[src_seq_id * dim_embed + i];
}
}
template <typename T>
void InvokeRebuildPadding(const phi::GPUContext &dev_ctx,
T *output_data,
const T *input_data,
const int *padding_offset,
const int token_num,
const int dim_embed) {
// src: [token_num, dim_embed]
// dst: [batch_size * max_seq_len, dim_embed]
RebuildPadding<<<token_num, 256, 0, dev_ctx.stream()>>>(
output_data, input_data, padding_offset, dim_embed);
}
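// RemovePadding and RebuildPadding are inverses under the same
// padding_offset: the former gathers padded row (bid + padding_offset[bid])
// into packed row bid, the latter scatters packed row bid back to that padded
// row; rows that are pure padding are not written.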
#if CUDA_VERSION >= 11060
// Only Used in Inference
template <typename T>
......
......@@ -64,6 +64,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"PreCaches",
"RotaryPosEmb",
"TimeStep",
"SeqLengths",
"SrcMask",
"OutLinearW",
"OutLinearBias",
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
......@@ -28,6 +29,7 @@ from paddle.nn.layer.common import Dropout, Linear
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.transformer import _convert_attention_mask
random.seed(42)
default_main_program().random_seed = 42
......@@ -123,6 +125,8 @@ class TestFusedMultiTransformerOp(OpTest):
self.rotary_embs = None
self.rotary_emb_dims = 0
self.remove_padding = False
self.training = False
self.layers = 4
......@@ -175,6 +179,27 @@ class TestFusedMultiTransformerOp(OpTest):
else:
self.cache_kv = None
if self.remove_padding:
if self.has_cache_kv and not self.gen_cache_kv:
# decoder
self.seq_lens = [
random.randint(1, self.cache_length)
for _ in range(self.batch_size)
]
self.seq_lens[
random.randint(0, self.batch_size - 1)
] = self.cache_length
self.seq_lens = np.array(self.seq_lens).astype(np.int32)
else:
self.seq_lens = [
random.randint(1, self.query_length)
for _ in range(self.batch_size)
]
self.seq_lens[
random.randint(0, self.batch_size - 1)
] = self.query_length
self.seq_lens = np.array(self.seq_lens).astype(np.int32)
if self.has_pre_cache:
out_seq_len += self.pre_cache_num
self.pre_cache_kv = np.random.uniform(
......@@ -406,6 +431,138 @@ class TestFusedMultiTransformerOp(OpTest):
return final_out, cache_kvs
return final_out
def GetVariableDecoderBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
final_outs = []
cache_outs = []
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(
self.rotary_embs, stop_gradient=False
)
for i in range(self.batch_size):
tensor_query = paddle.to_tensor(
self.query[i : i + 1], stop_gradient=False
)
cache_kvs = []
cache_kv = None
if self.has_cache_kv:
cache_kv = paddle.to_tensor(
self.cache_kv[:, i : i + 1, :, : self.seq_lens[i], :],
stop_gradient=False,
)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(
self.attn_mask[i : i + 1, :, :, : self.seq_lens[i] + 1],
stop_gradient=False,
)
else:
attn_mask = None
for j in range(self.layers):
residual = tensor_query
ln1_out = tensor_query
if self.pre_layer_norm:
ln1_out = self.norm(tensor_query)
q = self.q_proj(ln1_out)
q = tensor.reshape(
x=q, shape=[0, 0, self.num_heads, self.head_dim]
)
q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(ln1_out)
v = self.v_proj(ln1_out)
k = tensor.reshape(
x=k, shape=[0, 0, self.num_heads, self.head_dim]
)
k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(
x=v, shape=[0, 0, self.num_heads, self.head_dim]
)
v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
if self.rotary_emb_dims > 0:
cos_emb = rotary_embs[0][i : i + 1]
sin_emb = rotary_embs[1][i : i + 1]
q_out = self.apply_rotary_emb(
q_out, cos_emb, sin_emb, self.rotary_emb_dims
)
k_out = self.apply_rotary_emb(
k_out, cos_emb, sin_emb, self.rotary_emb_dims
)
if self.has_cache_kv:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k, cache_v = paddle.split(cache_kv, 2)
cache_k = paddle.squeeze(cache_k, axis=0)
cache_v = paddle.squeeze(cache_v, axis=0)
# [B, n_head, cache_seq_len + seq_len, head_dim]
# out_seq_len = cache_seq_len + seq_len
if self.gen_cache_kv:
cache_kvs.append((k_out, v_out))
else:
cache_outs.append([k_out, v_out])
k_out = paddle.concat([cache_k, k_out], axis=-2)
v_out = paddle.concat([cache_v, v_out], axis=-2)
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out = paddle.matmul(x=q_out, y=k_out, transpose_y=True)
qk_out = paddle.scale(qk_out, scale=self.head_dim**-0.5)
if attn_mask is not None:
attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
attn_mask_out = qk_out + attn_mask
softmax_out = F.softmax(attn_mask_out)
else:
softmax_out = F.softmax(qk_out)
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train",
)
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out = tensor.matmul(dropout_out, v_out)
else:
qktv_out = tensor.matmul(softmax_out, v_out)
fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
out_linear_in = tensor.reshape(
x=fmha_out,
shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]],
)
out = self.out_proj(out_linear_in)
residual_out = residual + self.dropout(out)
if not self.pre_layer_norm:
attn_out = self.norm(residual_out)
else:
attn_out = residual_out
ffn_ln_out = attn_out
if self.pre_layer_norm:
ffn_ln_out = self.ffn_norm(attn_out)
ffn1_out = self.ffn1_proj(ffn_ln_out)
ffn1_out = self.dropout(self.activation(ffn1_out))
ffn2_out = self.ffn2_proj(ffn1_out)
residual_out = attn_out + self.dropout(ffn2_out)
final_out = residual_out
if not self.pre_layer_norm:
final_out = self.ffn_norm(residual_out)
tensor_query = final_out
final_outs.append(final_out)
final_out = paddle.concat(final_outs, axis=0)
return final_out, cache_outs
def GetFusedMultiTransformerOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_proj_weight = paddle.to_tensor(
......@@ -482,7 +639,7 @@ class TestFusedMultiTransformerOp(OpTest):
x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kvs, cache_kv = None, None
time_step = None
pre_caches, pre_cache = None, None
pre_caches = None
if self.has_cache_kv:
cache_kvs = []
......@@ -539,6 +696,11 @@ class TestFusedMultiTransformerOp(OpTest):
[self.cache_length], dtype='int32', place=paddle.CPUPlace()
)
if self.remove_padding:
seq_lens = paddle.to_tensor(self.seq_lens, dtype='int32')
else:
seq_lens = None
if self.has_pre_cache:
cache_kvs = []
max_seq_length = (
......@@ -619,6 +781,7 @@ class TestFusedMultiTransformerOp(OpTest):
rotary_emb_dims=self.rotary_emb_dims,
pre_caches=pre_caches,
time_step=time_step,
seq_lens=seq_lens,
attn_mask=attn_mask,
dropout_rate=self.dropout_prob,
activation=self.act_method,
......@@ -637,13 +800,19 @@ class TestFusedMultiTransformerOp(OpTest):
paddle.enable_static()
x = paddle.fluid.data('x', self.query.shape, self.query.dtype)
cache_kvs, cache_kv = None, None
cache_kvs_feed = None
time_step = None
time_step_feed = None
pre_caches, pre_cache = None, None
seq_lens = None
seq_lens_feed = None
pre_caches = None
pre_caches_feed = None
rotary_embs = None
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(self.rotary_embs)
rotary_embs = paddle.fluid.data(
'rotary_embs', self.rotary_embs.shape, self.rotary_embs.dtype
)
if self.has_cache_kv:
cache_kvs = []
......@@ -698,6 +867,12 @@ class TestFusedMultiTransformerOp(OpTest):
)
time_step_feed = self.cache_length
if self.remove_padding:
seq_lens = paddle.fluid.data(
'seq_lens', self.seq_lens.shape, self.seq_lens.dtype
)
seq_lens_feed = self.seq_lens
if self.has_pre_cache:
cache_kvs = []
max_seq_length = (
......@@ -799,31 +974,43 @@ class TestFusedMultiTransformerOp(OpTest):
attn_mask=attn_mask,
caches=cache_kvs,
pre_caches=pre_caches,
seq_lens=seq_lens,
rotary_embs=rotary_embs,
rotary_emb_dims=self.rotary_emb_dims,
time_step=time_step,
)[0]
)
exe = paddle.static.Executor(place=paddle.CUDAPlace(0))
exe.run(paddle.static.default_startup_program())
feed_data = {
'x': self.query,
'cache_kvs': cache_kvs_feed,
'pre_caches': pre_caches_feed,
'rotary_embs': rotary_embs,
'rotary_embs': self.rotary_embs,
'time_step': time_step_feed,
'rotary_emb_dims': self.rotary_emb_dims,
'attn_mask': attn_mask,
'seq_lens': seq_lens_feed,
}
out = exe.run(
paddle.fluid.default_main_program(),
feed=feed_data,
fetch_list=[final_out],
)
if self.has_pre_cache:
out = exe.run(
paddle.fluid.default_main_program(),
feed=feed_data,
fetch_list=[final_out[0].name],
)
else:
out = exe.run(
paddle.fluid.default_main_program(),
feed=feed_data,
fetch_list=[final_out.name],
)
paddle.disable_static()
return out[0]
return out
def test_fused_multi_transformer_op(self):
final_out_ref = self.GetBaselineOut()
if self.has_cache_kv and not self.gen_cache_kv and self.remove_padding:
final_out_ref = self.GetVariableDecoderBaselineOut()
else:
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOut()
if self.has_cache_kv:
final_out, cache_kv_out = final_out
......@@ -846,6 +1033,39 @@ class TestFusedMultiTransformerOp(OpTest):
print("cache_v out timestep=128")
print(cache_kv_out[0][1, 0, 0, self.cache_length, :])
if self.remove_padding and not self.gen_cache_kv:
# test decoder
final_out_ref, cache_kvs = final_out_ref
for i in range(self.batch_size):
for j in range(self.layers):
cache_k = cache_kv_out[j][0, :]
cache_k = cache_k.reshape(
[bsz, num_head, v_elems, max_seq_len, elems]
)
cache_k = cache_k[:, :, :, : self.seq_lens[i] + 1, :]
cache_k = cache_k.transpose([0, 1, 3, 2, 4])
cache_k = cache_k.reshape(
[bsz, num_head, self.seq_lens[i] + 1, head_dim]
)
cache_v = cache_kv_out[j][
1, :, :, : self.seq_lens[i] + 1, :
]
cache_k_ref = cache_kvs[i * self.layers + j][0]
cache_v_ref = cache_kvs[i * self.layers + j][1]
np.testing.assert_allclose(
cache_k_ref,
cache_k[i : i + 1, :, -1:, :],
rtol=self.rtol,
atol=self.atol,
)
np.testing.assert_allclose(
cache_v_ref,
cache_v[i : i + 1, :, -1:, :],
rtol=self.rtol,
atol=self.atol,
)
if self.gen_cache_kv:
final_out_ref, cache_kvs = final_out_ref
for i in range(self.layers):
......@@ -864,18 +1084,42 @@ class TestFusedMultiTransformerOp(OpTest):
cache_v = cache_kv_out[i][1, :, :, : self.cache_length, :]
np.testing.assert_allclose(
cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol
)
np.testing.assert_allclose(
cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol
)
if self.remove_padding:
for i in range(self.batch_size):
np.testing.assert_allclose(
cache_k_ref[i, :, : self.seq_lens[i], :],
cache_k[i, :, : self.seq_lens[i], :],
rtol=self.rtol,
atol=self.atol,
)
np.testing.assert_allclose(
cache_v_ref[i, :, : self.seq_lens[i], :],
cache_v[i, :, : self.seq_lens[i], :],
rtol=self.rtol,
atol=self.atol,
)
else:
np.testing.assert_allclose(
cache_k_ref, cache_k, rtol=self.rtol, atol=self.atol
)
np.testing.assert_allclose(
cache_v_ref, cache_v, rtol=self.rtol, atol=self.atol
)
if i == 0:
break
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
if self.remove_padding:
for i in range(self.batch_size):
np.testing.assert_allclose(
final_out_ref[i, : self.seq_lens[i]],
final_out[i, : self.seq_lens[i]],
rtol=self.rtol,
atol=self.atol,
)
else:
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp):
......@@ -1020,10 +1264,109 @@ class TestFusedMultiTransformerOpPreCache(TestFusedMultiTransformerOp):
self.x_type = np.float16
class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
class TestFusedMultiTransformerOpVariableGenCache1(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.remove_padding = True
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpVariableGenCache2(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.remove_padding = True
self.layers = 4 # even layers
class TestFusedMultiTransformerOpVariableGenCache3(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.remove_padding = True
self.layers = 4 # even layers
self.rotary_emb_dims = 2
class TestFusedMultiTransformerOpVariableGenCache4(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = True
self.remove_padding = True
self.layers = 3 # odd layers
self.rotary_emb_dims = 2
class TestFusedMultiTransformerOpVariableNormTransformer1(
TestFusedMultiTransformerOp
):
def config(self):
super().config()
self.has_cache_kv = False
self.gen_cache_kv = False
self.remove_padding = True
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpVariableNormTransformer2(
TestFusedMultiTransformerOp
):
def config(self):
super().config()
self.has_cache_kv = False
self.gen_cache_kv = False
self.remove_padding = True
self.layers = 4 # even layers
class TestFusedMultiTransformerOpVariableDecoder1(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = False
self.remove_padding = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.x_type = np.float16
self.layers = 3 # odd layers
self.pre_layer_norm = False
class TestFusedMultiTransformerOpVariableDecoder2(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = False
self.remove_padding = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.layers = 4 # even layers
class TestFusedMultiTransformerOpVariableDecoder3(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_cache_kv = True
self.gen_cache_kv = False
self.remove_padding = True
self.query_length = 1
self.key_length, self.value_length = 1, 1
self.layers = 4 # even layers
self.rotary_emb_dims = 2
class TestFusedMultiTransformerOpPreCacheStatic1(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.has_pre_cache = True
self.has_attn_mask = False
self.x_type = np.float32
self.weight_attr = paddle.ParamAttr(
......@@ -1040,14 +1383,29 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
)
def test_fused_multi_transformer_op(self):
for i in range(3):
self.rotary_emb_dims = i
self.generate_input_data()
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOutStatic()
self.has_pre_cache = True
self.remove_padding = False
self.rotary_emb_dims = 2
self.generate_input_data()
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOutStatic()[0]
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
self.has_pre_cache = False
self.remove_padding = True
self.generate_input_data()
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOutStatic()[0]
for i in range(self.batch_size):
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
final_out_ref[i, : self.seq_lens[i]],
final_out[i, : self.seq_lens[i]],
rtol=self.rtol,
atol=self.atol,
)
......
......@@ -887,6 +887,7 @@ def fused_multi_transformer(
epsilon=1e-05,
cache_kvs=None,
pre_caches=None,
seq_lens=None,
rotary_embs=None,
time_step=None,
attn_mask=None,
......@@ -956,6 +957,7 @@ def fused_multi_transformer(
epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
seq_lens (Tensor, optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None.
rotary_embs (Tensor, optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation model, which is used in the decode stage to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
......@@ -1055,6 +1057,7 @@ def fused_multi_transformer(
pre_caches,
rotary_embs,
time_step,
seq_lens,
attn_mask,
linear_weights,
linear_biases,
......@@ -1115,6 +1118,7 @@ def fused_multi_transformer(
inputs['PreCaches'] = pre_caches
if rotary_emb_dims > 0:
inputs['RotaryPosEmb'] = rotary_embs
inputs['SeqLengths'] = seq_lens
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = linear_weights
if linear_biases is not None:
......
......@@ -1382,6 +1382,7 @@ class FusedMultiTransformer(Layer):
pre_caches=None,
rotary_embs=None,
rotary_emb_dims=0,
seq_lens=None,
time_step=None,
):
r"""
......@@ -1406,6 +1407,7 @@ class FusedMultiTransformer(Layer):
rotary_embs (Tensor, optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None,
1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0.
seq_lens (Tensor, optional): The sequence lengths of this batch. The shape is `[bsz]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation
model, which is used in the decode stage to represent the time step,
that is, the real seq_len of CacheKV. The shape is `[1]`, must be
......@@ -1441,6 +1443,7 @@ class FusedMultiTransformer(Layer):
pre_caches=pre_caches,
rotary_embs=rotary_embs,
time_step=time_step,
seq_lens=seq_lens,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
rotary_emb_dims=rotary_emb_dims,
......