Unverified commit 644dfc60 authored by lzy, committed by GitHub

make FusedMultiTransformer support RoPE (#48842)

Parent 1bbc9b64
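This commit adds an optional RotaryPosEmb input and a rotary_emb_dims attribute to the fused_multi_transformer op so rotary position embedding (RoPE) can be applied to Q and K inside the fused kernels. As the unit test below constructs it, RotaryPosEmb is a single tensor of shape [2, bsz, 1, seq_len, head_dim] with the cosine table at index 0 and the sine table at index 1, each table repeated across the two halves of every rotary block. The following NumPy sketch shows one way to build such a tensor; the inverse-frequency schedule (base**(-i/half)) is the standard RoPE formula and is an assumption here, since the test only feeds random values:

import numpy as np

def build_rotary_pos_emb(batch_size, seq_len, head_dim, rotary_emb_dims=1, base=10000.0):
    """Build a [2, bsz, 1, seq_len, head_dim] cos/sin table for RotaryPosEmb.

    Assumes the standard RoPE inverse-frequency schedule; the layout
    (cos stacked on sin, values repeated over the two halves of each
    rotary block) mirrors what the unit test feeds to the fused op.
    """
    block = head_dim // rotary_emb_dims              # each block is rotated independently
    half = block // 2
    inv_freq = base ** (-np.arange(0, half, dtype="float32") / half)
    pos = np.arange(seq_len, dtype="float32")
    angles = np.outer(pos, inv_freq)                 # [seq_len, half]
    # repeat over both halves of the block, then over the rotary blocks
    angles = np.tile(angles, (1, 2 * rotary_emb_dims))   # [seq_len, head_dim]
    cos = np.broadcast_to(np.cos(angles)[None], (batch_size, seq_len, head_dim))
    sin = np.broadcast_to(np.sin(angles)[None], (batch_size, seq_len, head_dim))
    out = np.stack([cos, sin], axis=0)               # [2, bsz, seq_len, head_dim]
    return out[:, :, None, :, :]                     # [2, bsz, 1, seq_len, head_dim]

rotary_embs = build_rotary_pos_emb(batch_size=2, seq_len=8, head_dim=64)
print(rotary_embs.shape)  # (2, 2, 1, 8, 64)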
...@@ -137,7 +137,7 @@ __global__ void apply_scale(T *data, T scale, int n) {
template <typename T>
__global__ void RotrayKernel(const T *inputact,
const T *input1,
- const T *intput2,
const T *input2,
T *output,
const int nElement,
const int lastdim) {
...@@ -147,7 +147,7 @@ __global__ void RotrayKernel(const T *inputact,
int col = index % lastdim;
int half_lastdim = lastdim / 2;
const int right_index = index - col + (col + half_lastdim) % lastdim;
- output[index] = left_elemul_out + intput2[index] * inputact[right_index];
output[index] = left_elemul_out + input2[index] * inputact[right_index];
}
inline int round_up(int seq_len, int multiple = 32) {
......
...@@ -256,8 +256,7 @@ class FMHARef {
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
}
- void ComputeForwardWithoutTranspose(const phi::DenseTensor& qkv_input_tensor,
- const phi::DenseTensor* cache_kv_tensor,
void ComputeForwardWithoutTranspose(const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
phi::DenseTensor* q_transpose_out_tensor,
phi::DenseTensor* kv_transpose_out_tensor,
......
...@@ -174,6 +174,9 @@ class FusedMultiTransformerOpOpMaker
"(optional) The prefix caches for generation inference.")
.AsDispensable()
.AsDuplicable();
AddInput("RotaryPosEmb",
"(optional) The RoPE embeddings for generation inference.")
.AsDispensable();
AddInput("TimeStep", AddInput("TimeStep",
"(optional, int) The time step for generation inference.") "(optional, int) The time step for generation inference.")
.AsDispensable(); .AsDispensable();
...@@ -209,6 +212,18 @@ class FusedMultiTransformerOpOpMaker ...@@ -209,6 +212,18 @@ class FusedMultiTransformerOpOpMaker
"else, uses post_layer_norm architecuture. " "else, uses post_layer_norm architecuture. "
"[default true].") "[default true].")
.SetDefault(true); .SetDefault(true);
AddAttr<int>("rotary_emb_dims",
"the Attr(dims) for RotaryPosEmb's Computation [default 0].")
.SetDefault(0)
.AddCustomChecker([](const int &rotary_emb_dims) {
PADDLE_ENFORCE_EQ(
rotary_emb_dims >= 0 && rotary_emb_dims <= 2,
true,
platform::errors::InvalidArgument(
"'rotary_emb_dims' in Op(Rotray) should be between"
"0 and 2, But received [%s].",
rotary_emb_dims));
});
AddAttr<float>("epsilon", AddAttr<float>("epsilon",
"Constant for numerical stability [default 1e-5].") "Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5) .SetDefault(1e-5)
......
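The new attribute is only meaningful for the values 0, 1, and 2 (disabled, rotary applied over the whole rotary block, or split over two position-id sets). A hypothetical Python-side guard that mirrors the C++ AddCustomChecker above, shown for illustration only:

def check_rotary_emb_dims(rotary_emb_dims: int) -> None:
    # Mirrors the C++ checker: only 0, 1 or 2 are legal values.
    if not 0 <= rotary_emb_dims <= 2:
        raise ValueError(
            "'rotary_emb_dims' should be between 0 and 2, "
            f"but received {rotary_emb_dims}."
        )

check_rotary_emb_dims(1)    # ok
# check_rotary_emb_dims(3)  # would raise ValueError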
...@@ -77,6 +77,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto *qkv_out_data =
dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
// 2.1 rotary
auto *rotary_tensor = ctx.Input<phi::DenseTensor>("RotaryPosEmb");
const int rotary_emb_dims = ctx.Attr<int>("rotary_emb_dims");
// 3. fmha
AttnDropoutParam attn_param(
true, "upscale_in_train", 0.0, true, true, 0, nullptr);
...@@ -297,6 +301,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
qkv_out,
*qkv_bias,
*src_mask,
rotary_tensor,
cache_kv_out,
&fmha_out,
bsz,
...@@ -304,6 +309,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
num_head,
dim_head,
time_step->data<int>()[0],
rotary_emb_dims,
1. / sqrt(dim_head));
} else if (cache_kv_out) { // generation context stage
const phi::DenseTensor *pre_cache_kv_tensor =
...@@ -322,8 +328,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
seq_len,
dim_head,
compute_bias);
- fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
- pre_cache_kv_tensor,
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
rotary_emb_dims,
bsz,
num_head,
seq_len,
dim_head);
}
fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor,
src_mask,
&q_transpose_out,
&kv_transpose_out,
...@@ -383,8 +406,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
seq_len,
dim_head,
compute_bias);
- fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
- cache_kv,
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
rotary_emb_dims,
bsz,
num_head,
seq_len,
dim_head);
}
fmha_compute.ComputeForwardWithoutTranspose(cache_kv,
src_mask,
&q_transpose_out,
&kv_transpose_out,
...@@ -594,6 +634,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto *qkv_out_data =
dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
// 2.1 rotary
auto *rotary_tensor = ctx.Input<phi::DenseTensor>("RotaryPosEmb");
const int rotary_emb_dims = ctx.Attr<int>("rotary_emb_dims");
// 3. fmha
AttnDropoutParam attn_param(
true, "upscale_in_train", 0.0, true, true, 0, nullptr);
...@@ -821,6 +865,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
qkv_out,
*qkv_bias,
*src_mask,
rotary_tensor,
cache_kv_out,
&fmha_out,
bsz,
...@@ -828,6 +873,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
num_head,
dim_head,
time_step->data<int>()[0],
rotary_emb_dims,
1. / sqrt(dim_head));
} else if (cache_kv_out) { // generation context stage
const phi::DenseTensor *pre_cache_kv_tensor =
...@@ -846,8 +892,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
seq_len,
dim_head,
compute_bias);
- fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
- pre_cache_kv_tensor,
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
rotary_emb_dims,
bsz,
num_head,
seq_len,
dim_head);
}
fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor,
src_mask,
&q_transpose_out,
&kv_transpose_out,
...@@ -907,8 +970,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
seq_len,
dim_head,
compute_bias);
- fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
- cache_kv,
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
rotary_emb_dims,
bsz,
num_head,
seq_len,
dim_head);
}
fmha_compute.ComputeForwardWithoutTranspose(cache_kv,
src_mask,
&q_transpose_out,
&kv_transpose_out,
......
...@@ -127,6 +127,11 @@ struct Masked_multihead_attention_params {
// v [B, num_head, max_seq_len, dim_head]
T *cache_kv;
// The RoPE embedding, [B, 1, 1, dim_head]
// rotary_emb_dims = 1 if pos_ids_extra is null else 2
const T *rotary_emb;
int rotary_emb_dims;
int batch_size;
int num_head;
int timestep; // cache_seq_length
...@@ -404,6 +409,18 @@ inline __device__ float4 mul(float4 a, float b) {
return res;
}
template <typename Qk_vec>
inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left,
Qk_vec input_right,
Qk_vec cos_emb,
Qk_vec sin_emb,
float alpha) {
Qk_vec res1 = mul<Qk_vec, Qk_vec, Qk_vec>(input_left, cos_emb);
Qk_vec res2 = mul<Qk_vec, Qk_vec, Qk_vec>(input_right, sin_emb);
res2 = mul<Qk_vec, Qk_vec, float>(res2, alpha);
return add(res1, res2);
}
inline __device__ float sum(float v) { return v; }
inline __device__ float sum(float2 v) { return v.x + v.y; }
inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }
...@@ -804,6 +821,67 @@ __global__ void masked_multihead_attention_kernel(
// we may not require k_bias.
k = add(k, k_bias);
// rotary pos emb
if (params.rotary_emb_dims != 0) {
int last_dim = Dh / params.rotary_emb_dims;
int half_lastdim = last_dim / 2;
int rotary_offset = bi * Dh + tid * QK_VEC_SIZE;
const T *cos_base = params.rotary_emb;
const T *sin_base = params.rotary_emb + params.batch_size * Dh;
int stride = half_lastdim / QK_VEC_SIZE;
int stride_all_lastdim = 2 * stride;
int right_id = tid / stride_all_lastdim * stride_all_lastdim +
(tid + stride) % (stride_all_lastdim);
int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE;
int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE;
Qk_vec q_right;
zero(q_right);
q_right =
(Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_base[qk_right_offset])
: q_right;
Qk_vec k_right;
zero(k_right);
k_right =
(Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_base[qk_right_offset])
: k_right;
Qk_vec q_right_bias;
zero(q_right_bias);
q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(
&q_bias_base[qk_right_bias_offset])
: q_right_bias;
Qk_vec k_right_bias;
zero(k_right_bias);
k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(
&k_bias_base[qk_right_bias_offset])
: k_right_bias;
q_right = add(q_right, q_right_bias);
k_right = add(k_right, k_right_bias);
Qk_vec cos_emb;
zero(cos_emb);
cos_emb =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&cos_base[rotary_offset])
: cos_emb;
Qk_vec sin_emb;
zero(sin_emb);
sin_emb =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&sin_base[rotary_offset])
: sin_emb;
float alpha = (tid % stride_all_lastdim) < stride ? static_cast<float>(-1)
: static_cast<float>(1);
q = apply_rotary_emb(q, q_right, cos_emb, sin_emb, alpha);
k = apply_rotary_emb(k, k_right, cos_emb, sin_emb, alpha);
}
*reinterpret_cast<Qk_vec *>(&q_smem[tid * QK_VEC_SIZE]) = q;
int co = tid / QK_VECS_IN_16B;
...@@ -1120,6 +1198,7 @@ void fmha(const phi::GPUContext &dev_ctx,
const phi::DenseTensor &qkv_tensor,
const phi::DenseTensor &qkv_bias_tensor,
const phi::DenseTensor &src_mask_tensor,
const phi::DenseTensor *rotary_tensor,
phi::DenseTensor *cache_kv_tensor,
phi::DenseTensor *out_tensor,
int batch_size,
...@@ -1127,6 +1206,7 @@ void fmha(const phi::GPUContext &dev_ctx,
int num_head,
int dim_head,
int timestep,
int rotary_emb_dims,
float inv_sqrt_dh) {
Masked_multihead_attention_params<T> params;
params.out = out_tensor->data<T>();
...@@ -1134,12 +1214,18 @@ void fmha(const phi::GPUContext &dev_ctx,
params.qkv_bias = qkv_bias_tensor.data<T>();
params.attn_mask = src_mask_tensor.data<T>();
params.cache_kv = cache_kv_tensor->data<T>();
if (rotary_emb_dims > 0) {
params.rotary_emb = rotary_tensor->data<T>();
} else {
params.rotary_emb = nullptr;
}
params.batch_size = batch_size;
params.num_head = num_head;
params.timestep = timestep;
params.max_seq_length = max_seq_length;
params.inv_sqrt_dh = inv_sqrt_dh;
params.rotary_emb_dims = rotary_emb_dims;
switch (dim_head) {
case 10:
...@@ -1169,6 +1255,35 @@ void fmha(const phi::GPUContext &dev_ctx,
}
}
template <typename T>
void fmha(const phi::GPUContext &dev_ctx,
const phi::DenseTensor &qkv_tensor,
const phi::DenseTensor &qkv_bias_tensor,
const phi::DenseTensor &src_mask_tensor,
phi::DenseTensor *cache_kv_tensor,
phi::DenseTensor *out_tensor,
int batch_size,
int max_seq_length,
int num_head,
int dim_head,
int timestep,
float inv_sqrt_dh) {
fmha<T>(dev_ctx,
qkv_tensor,
qkv_bias_tensor,
src_mask_tensor,
nullptr,
cache_kv_tensor,
out_tensor,
batch_size,
max_seq_length,
num_head,
dim_head,
timestep,
0,
inv_sqrt_dh);
}
// NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8
constexpr int VEC_16B = 16;
...@@ -1405,6 +1520,94 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
}
}
template <typename T>
__global__ void RotrayKernel(const T *input,
const T *cos_emb,
const T *sin_emb,
T *output,
const int batch_size,
const int head_num,
const int seq_len,
const int last_dim) {
int bi = blockIdx.x;
int hi = blockIdx.y;
int si = blockIdx.z;
int half_lastdim = last_dim / 2;
// Note(ZhenyuLi): Calculate the relevant data at one time, so that no
// additional space is required.
for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) {
int base_idx = bi * head_num * seq_len * last_dim +
hi * seq_len * last_dim + si * last_dim;
int left_idx = base_idx + ti;
const int right_idx = base_idx + ti + half_lastdim;
int emb_idx = bi * seq_len * last_dim + si * last_dim + ti;
T input_left = input[left_idx];
T input_right = input[right_idx];
T cos_tmp = cos_emb[emb_idx];
T sin_tmp = sin_emb[emb_idx];
T res1 = input_left * cos_tmp - input_right * sin_tmp;
T res2 = input_right * cos_tmp + input_left * sin_tmp;
output[left_idx] = res1;
output[right_idx] = res2;
}
}
template <typename T>
void rotary_qk(const phi::GPUContext &dev_ctx,
T *q,
T *k, // kv
const T *q_input, // q
const T *k_input, // kv
const T *rotary_emb,
const int rotary_emb_dims,
const int batch_size,
const int head_num,
const int seq_len,
const int dim_head) {
// q_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num,
// seq_len * rotary_emb_dims, dim_head / rotary_emb_dims]
// kv_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num,
// seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] rotary_emb [2, bs,
// 1, seq_len, dim_head] -> [2, bs, 1, seq_len * rotary_emb_dims, dim_head /
// rotary_emb_dims]
dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims);
const int last_dim = dim_head / rotary_emb_dims;
auto getBlockSize = [](int dim) {
if (dim > 256) {
return 512;
} else if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
};
int BlockSize = getBlockSize(last_dim / 2);
const T *cos_emb = rotary_emb;
const T *sin_emb = rotary_emb + batch_size * seq_len * dim_head;
RotrayKernel<<<grid, BlockSize, 0, dev_ctx.stream()>>>(
q_input,
cos_emb,
sin_emb,
q,
batch_size,
head_num,
seq_len * rotary_emb_dims,
last_dim);
RotrayKernel<<<grid, BlockSize, 0, dev_ctx.stream()>>>(
k_input,
cos_emb,
sin_emb,
k,
batch_size,
head_num,
seq_len * rotary_emb_dims,
last_dim);
}
#if CUDA_VERSION >= 11060
// Only Used in Inference
template <typename T>
......
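To make the math in the new RotrayKernel and rotary_qk above easier to follow, here is a small NumPy reference of the same rotation (an illustrative re-statement of the kernel's per-element formula, not code from the commit): each rotary block of the last dimension is split into two halves, and the output is left*cos - right*sin for the left half and right*cos + left*sin for the right half.

import numpy as np

def rotary_qk_reference(x, cos_emb, sin_emb, rotary_emb_dims=1):
    """NumPy reference of the rotation done by RotrayKernel / rotary_qk.

    x:       [bsz, num_head, seq_len, dim_head]
    cos_emb: [bsz, 1, seq_len, dim_head] (same layout for sin_emb)
    The last dimension is treated as `rotary_emb_dims` independent blocks,
    matching the (seq_len * rotary_emb_dims, dim_head / rotary_emb_dims)
    reshape performed by the CUDA launcher.
    """
    out = np.empty_like(x)
    block = x.shape[-1] // rotary_emb_dims
    half = block // 2
    for b in range(rotary_emb_dims):
        s = b * block
        left = x[..., s:s + half]
        right = x[..., s + half:s + block]
        cos = cos_emb[..., s:s + half]
        sin = sin_emb[..., s:s + half]
        out[..., s:s + half] = left * cos - right * sin
        out[..., s + half:s + block] = right * cos + left * sin
    return out

Applying this function to both q and k with the same cos/sin tables reproduces what the two kernel launches in rotary_qk compute.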
...@@ -62,6 +62,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"QKVBias",
"CacheKV",
"PreCaches",
"RotaryPosEmb",
"TimeStep", "TimeStep",
"SrcMask", "SrcMask",
"OutLinearW", "OutLinearW",
......
...@@ -120,6 +120,8 @@ class TestFusedMultiTransformerOp(OpTest):
self.has_cache_kv = False
self.gen_cache_kv = False
self.has_pre_cache = False
self.rotary_embs = None
self.rotary_emb_dims = 0
self.training = False
...@@ -213,12 +215,53 @@ class TestFusedMultiTransformerOp(OpTest):
)
else:
self.attn_mask = None
if self.rotary_emb_dims > 0:
self.rotary_emb = np.random.uniform(
-1,
1,
(
2,
self.batch_size,
1,
self.query_length,
self.head_dim // 2 // self.rotary_emb_dims,
),
).astype(self.x_type)
concat_nums = 2 * self.rotary_emb_dims
rotary_embs = []
for _ in range(concat_nums):
rotary_embs.append(self.rotary_emb)
self.rotary_embs = np.concatenate(rotary_embs, -1)
self.key, self.value = self.query, self.query
self.dout = np.random.uniform(
-1, 1, (self.batch_size, self.query_length, self.embed_dim)
).astype(self.x_type)
def rotate_half(self, x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return paddle.concat((-x2, x1), axis=-1)
def apply_rotary_emb(self, x, cos_emb, sin_emb, rotary_emb_dims):
# x shape [bsz, num_heads, seq_len, head_dim]
# cos_emb, sin_emb shape [bsz, 1, seq_len, head_dim]
x_dims = paddle.split(x, num_or_sections=rotary_emb_dims, axis=-1)
cos_dims = paddle.split(
cos_emb, num_or_sections=rotary_emb_dims, axis=-1
)
sin_dims = paddle.split(
sin_emb, num_or_sections=rotary_emb_dims, axis=-1
)
rotary_dims = []
for x_dim, cos_dim, sin_dim in zip(x_dims, cos_dims, sin_dims):
rotary_dims.append(
x_dim * cos_dim + self.rotate_half(x_dim) * sin_dim
)
return paddle.concat(rotary_dims, axis=-1)
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
...@@ -238,6 +281,11 @@ class TestFusedMultiTransformerOp(OpTest):
else:
attn_mask = None
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(
self.rotary_embs, stop_gradient=False
)
for i in range(self.layers):
residual = tensor_query
ln1_out = tensor_query
...@@ -254,6 +302,16 @@ class TestFusedMultiTransformerOp(OpTest):
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
if self.rotary_emb_dims > 0:
cos_emb = rotary_embs[0]
sin_emb = rotary_embs[1]
q_out = self.apply_rotary_emb(
q_out, cos_emb, sin_emb, self.rotary_emb_dims
)
k_out = self.apply_rotary_emb(
k_out, cos_emb, sin_emb, self.rotary_emb_dims
)
if self.has_cache_kv:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k, cache_v = paddle.split(cache_kv, 2)
...@@ -414,6 +472,13 @@ class TestFusedMultiTransformerOp(OpTest):
(3, self.num_heads, self.head_dim, self.embed_dim)
)
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(
self.rotary_embs, stop_gradient=False
)
else:
rotary_embs = None
x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kvs, cache_kv = None, None
time_step = None
...@@ -550,6 +615,8 @@ class TestFusedMultiTransformerOp(OpTest):
pre_layer_norm=self.pre_layer_norm,
epsilon=epsilon,
cache_kvs=cache_kvs,
rotary_embs=rotary_embs,
rotary_emb_dims=self.rotary_emb_dims,
pre_caches=pre_caches,
time_step=time_step,
attn_mask=attn_mask,
...@@ -573,6 +640,11 @@ class TestFusedMultiTransformerOp(OpTest):
time_step = None
time_step_feed = None
pre_caches, pre_cache = None, None
rotary_embs = None
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(self.rotary_embs)
if self.has_cache_kv:
cache_kvs = []
...@@ -727,6 +799,8 @@ class TestFusedMultiTransformerOp(OpTest):
attn_mask=attn_mask,
caches=cache_kvs,
pre_caches=pre_caches,
rotary_embs=rotary_embs,
rotary_emb_dims=self.rotary_emb_dims,
time_step=time_step,
)[0]
exe = paddle.static.Executor(place=paddle.CUDAPlace(0))
...@@ -735,7 +809,9 @@ class TestFusedMultiTransformerOp(OpTest):
'x': self.query,
'cache_kvs': cache_kvs_feed,
'pre_caches': pre_caches_feed,
'rotary_embs': rotary_embs,
'time_step': time_step_feed,
'rotary_emb_dims': self.rotary_emb_dims,
'attn_mask': attn_mask,
}
out = exe.run(
...@@ -802,6 +878,38 @@ class TestFusedMultiTransformerOp(OpTest):
)
class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.rotary_emb_dims = 1
class TestFusedMultiTransformerOpGenRotaryFP16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.has_cache_kv = True
self.gen_cache_kv = False
self.query_length = 1
self.key_length, self.value_length = (
self.query_length,
self.query_length,
)
self.rotary_emb_dims = 2
class TestFusedMultiTransformerOpGenCacheRotaryFP16(
TestFusedMultiTransformerOp
):
def config(self):
super().config()
self.x_type = np.float16
self.has_cache_kv = True
self.gen_cache_kv = True
self.rotary_emb_dims = 1
class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
...@@ -932,12 +1040,15 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
)
def test_fused_multi_transformer_op(self):
- final_out_ref = self.GetBaselineOut()
- final_out = self.GetFusedMultiTransformerOutStatic()
- np.testing.assert_allclose(
- final_out_ref, final_out, rtol=self.rtol, atol=self.atol
- )
for i in range(3):
    self.rotary_emb_dims = i
    self.generate_input_data()
    final_out_ref = self.GetBaselineOut()
    final_out = self.GetFusedMultiTransformerOutStatic()
    np.testing.assert_allclose(
        final_out_ref, final_out, rtol=self.rtol, atol=self.atol
    )
if __name__ == "__main__":
......
...@@ -845,9 +845,11 @@ def fused_multi_transformer(
epsilon=1e-05,
cache_kvs=None,
pre_caches=None,
rotary_embs=None,
time_step=None,
attn_mask=None,
dropout_rate=0.0,
rotary_emb_dims=0,
activation="gelu", activation="gelu",
training=False, training=False,
mode='upscale_in_train', mode='upscale_in_train',
...@@ -912,11 +914,14 @@ def fused_multi_transformer( ...@@ -912,11 +914,14 @@ def fused_multi_transformer(
epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5. epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
rotary_embs (Tensor, optional): The RoPE embeddings for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None.
dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0.
rotary_emb_dims (int, optional): The number of dimensions used in the rotary computation: 0 when rotary_embs is None,
1 when rotary_embs is not None and pos_extra_ids is None, and 2 when both rotary_embs and pos_extra_ids are not None. Default 0.
activation (str, optional): The activation. Default "gelu".
training (bool, optional): A flag indicating whether it is in train phrase or not. Default False.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
...@@ -1006,6 +1011,7 @@ def fused_multi_transformer(
qkv_biases,
cache_kvs,
pre_caches,
rotary_embs,
time_step,
attn_mask,
linear_weights,
...@@ -1023,6 +1029,8 @@ def fused_multi_transformer(
epsilon,
'dropout_rate',
dropout_rate,
'rotary_emb_dims',
rotary_emb_dims,
'is_test',
not training,
'dropout_implementation',
...@@ -1063,6 +1071,8 @@ def fused_multi_transformer(
inputs['TimeStep'] = time_step
if pre_caches is not None:
inputs['PreCaches'] = pre_caches
if rotary_emb_dims > 0:
inputs['RotaryPosEmb'] = rotary_embs
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = linear_weights
if linear_biases is not None:
...@@ -1082,6 +1092,7 @@ def fused_multi_transformer(
'pre_layer_norm': pre_layer_norm,
'epsilon': epsilon,
'dropout_rate': dropout_rate,
'rotary_emb_dims': rotary_emb_dims,
'is_test': not training,
'dropout_implementation': mode,
'act_method': activation,
......
...@@ -1357,7 +1357,14 @@ class FusedMultiTransformer(Layer):
self.name = name
def forward(
- self, src, attn_mask=None, caches=None, pre_caches=None, time_step=None
self,
src,
attn_mask=None,
caches=None,
pre_caches=None,
rotary_embs=None,
rotary_emb_dims=0,
time_step=None,
):
r"""
Applies multi transformer layers on the input.
...@@ -1378,6 +1385,9 @@ class FusedMultiTransformer(Layer):
`[2, batch_size, num_head, max_seq_len, head_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches
for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
rotary_embs (Tensor, optional): The RoPE embeddings for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
rotary_emb_dims (int, optional): The number of dimensions used in the rotary computation: 0 when rotary_embs is None,
1 when rotary_embs is not None and pos_extra_ids is None, and 2 when both rotary_embs and pos_extra_ids are not None. Default 0.
time_step (Tensor, optional): The time step tensor for the generation
model. Which used in decode stage, to represent the time step,
that is, the real seq_len of CacheKV. The shape is `[1]`, must be
...@@ -1411,9 +1421,11 @@ class FusedMultiTransformer(Layer): ...@@ -1411,9 +1421,11 @@ class FusedMultiTransformer(Layer):
epsilon=self._epsilon,
cache_kvs=caches,
pre_caches=pre_caches,
rotary_embs=rotary_embs,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
rotary_emb_dims=rotary_emb_dims,
activation=self.activation,
training=self.training,
mode='upscale_in_train',
......
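A minimal usage sketch of the new forward arguments, assuming the layer's existing constructor signature FusedMultiTransformer(embed_dim, num_heads, dim_feedforward, num_layers=...), a CUDA device, and the build_rotary_pos_emb helper sketched near the top of this page (the helper and the concrete sizes below are illustrative, not part of the commit):

import paddle
from paddle.incubate.nn import FusedMultiTransformer

paddle.set_device("gpu")

bsz, seq_len, num_heads, head_dim = 2, 8, 4, 64
embed_dim = num_heads * head_dim

layer = FusedMultiTransformer(
    embed_dim, num_heads, dim_feedforward=4 * embed_dim, num_layers=2
)

x = paddle.randn([bsz, seq_len, embed_dim])
# [2, bsz, 1, seq_len, head_dim]: cos table stacked on sin table
rotary_embs = paddle.to_tensor(
    build_rotary_pos_emb(bsz, seq_len, head_dim, rotary_emb_dims=1),
    dtype=x.dtype,
)

out = layer(x, rotary_embs=rotary_embs, rotary_emb_dims=1)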