Unverified commit 644dfc60, authored by lzy, committed by GitHub

make FusedMultiTransformer support RoPE (#48842)

Parent 1bbc9b64
......@@ -137,7 +137,7 @@ __global__ void apply_scale(T *data, T scale, int n) {
template <typename T>
__global__ void RotrayKernel(const T *inputact,
const T *input1,
const T *intput2,
const T *input2,
T *output,
const int nElement,
const int lastdim) {
......@@ -147,7 +147,7 @@ __global__ void RotrayKernel(const T *inputact,
int col = index % lastdim;
int half_lastdim = lastdim / 2;
const int right_index = index - col + (col + half_lastdim) % lastdim;
output[index] = left_elemul_out + intput2[index] * inputact[right_index];
output[index] = left_elemul_out + input2[index] * inputact[right_index];
}
inline int round_up(int seq_len, int multiple = 32) {
......
......@@ -256,8 +256,7 @@ class FMHARef {
dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
}
void ComputeForwardWithoutTranspose(const phi::DenseTensor& qkv_input_tensor,
const phi::DenseTensor* cache_kv_tensor,
void ComputeForwardWithoutTranspose(const phi::DenseTensor* cache_kv_tensor,
const phi::DenseTensor* src_mask_tensor,
phi::DenseTensor* q_transpose_out_tensor,
phi::DenseTensor* kv_transpose_out_tensor,
......
......@@ -174,6 +174,9 @@ class FusedMultiTransformerOpOpMaker
"(optional) The prefix caches for generation inference.")
.AsDispensable()
.AsDuplicable();
AddInput("RotaryPosEmb",
"(optional) The RoPE embeddings for generation inference.")
.AsDispensable();
AddInput("TimeStep",
"(optional, int) The time step for generation inference.")
.AsDispensable();
......@@ -209,6 +212,18 @@ class FusedMultiTransformerOpOpMaker
"else, uses post_layer_norm architecuture. "
"[default true].")
.SetDefault(true);
AddAttr<int>("rotary_emb_dims",
"the Attr(dims) for RotaryPosEmb's Computation [default 0].")
.SetDefault(0)
.AddCustomChecker([](const int &rotary_emb_dims) {
PADDLE_ENFORCE_EQ(
rotary_emb_dims >= 0 && rotary_emb_dims <= 2,
true,
platform::errors::InvalidArgument(
"'rotary_emb_dims' in Op(Rotray) should be between"
"0 and 2, But received [%s].",
rotary_emb_dims));
});
AddAttr<float>("epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
......
......@@ -77,6 +77,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto *qkv_out_data =
dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
// 2.1 rotary
auto *rotary_tensor = ctx.Input<phi::DenseTensor>("RotaryPosEmb");
const int rotary_emb_dims = ctx.Attr<int>("rotary_emb_dims");
// 3. fmha
AttnDropoutParam attn_param(
true, "upscale_in_train", 0.0, true, true, 0, nullptr);
......@@ -297,6 +301,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
qkv_out,
*qkv_bias,
*src_mask,
rotary_tensor,
cache_kv_out,
&fmha_out,
bsz,
......@@ -304,6 +309,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
num_head,
dim_head,
time_step->data<int>()[0],
rotary_emb_dims,
1. / sqrt(dim_head));
} else if (cache_kv_out) { // generation context stage
const phi::DenseTensor *pre_cache_kv_tensor =
......@@ -322,8 +328,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
seq_len,
dim_head,
compute_bias);
fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
pre_cache_kv_tensor,
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
rotary_emb_dims,
bsz,
num_head,
seq_len,
dim_head);
}
fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor,
src_mask,
&q_transpose_out,
&kv_transpose_out,
......@@ -383,8 +406,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
seq_len,
dim_head,
compute_bias);
fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
cache_kv,
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
rotary_emb_dims,
bsz,
num_head,
seq_len,
dim_head);
}
fmha_compute.ComputeForwardWithoutTranspose(cache_kv,
src_mask,
&q_transpose_out,
&kv_transpose_out,
......@@ -594,6 +634,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto *qkv_out_data =
dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
// 2.1 rotary
auto *rotary_tensor = ctx.Input<phi::DenseTensor>("RotaryPosEmb");
const int rotary_emb_dims = ctx.Attr<int>("rotary_emb_dims");
// 3. fmha
AttnDropoutParam attn_param(
true, "upscale_in_train", 0.0, true, true, 0, nullptr);
......@@ -821,6 +865,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
qkv_out,
*qkv_bias,
*src_mask,
rotary_tensor,
cache_kv_out,
&fmha_out,
bsz,
......@@ -828,6 +873,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
num_head,
dim_head,
time_step->data<int>()[0],
rotary_emb_dims,
1. / sqrt(dim_head));
} else if (cache_kv_out) { // generation context stage
const phi::DenseTensor *pre_cache_kv_tensor =
......@@ -846,8 +892,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
seq_len,
dim_head,
compute_bias);
fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
pre_cache_kv_tensor,
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
rotary_emb_dims,
bsz,
num_head,
seq_len,
dim_head);
}
fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor,
src_mask,
&q_transpose_out,
&kv_transpose_out,
......@@ -907,8 +970,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
seq_len,
dim_head,
compute_bias);
fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
cache_kv,
// q_transpose_out_data [bs, head_num, seq_len, dim_head]
// kv_transpose_out_data [2, bs, head_num, seq_len, dim_head]
if (rotary_emb_dims != 0) {
auto *rotary_emb_data = rotary_tensor->data<T>();
rotary_qk(dev_ctx,
q_transpose_out_data,
kv_transpose_out_data,
q_transpose_out_data,
kv_transpose_out_data,
rotary_emb_data,
rotary_emb_dims,
bsz,
num_head,
seq_len,
dim_head);
}
fmha_compute.ComputeForwardWithoutTranspose(cache_kv,
src_mask,
&q_transpose_out,
&kv_transpose_out,
......
......@@ -127,6 +127,11 @@ struct Masked_multihead_attention_params {
// v [B, num_head, max_seq_len, dim_head]
T *cache_kv;
// The RoPE embedding, [B, 1, 1, dim_head]
// rotary_emb_dims = 1 if pos_ids_extra is null else 2
const T *rotary_emb;
int rotary_emb_dims;
int batch_size;
int num_head;
int timestep; // cache_seq_length
......@@ -404,6 +409,18 @@ inline __device__ float4 mul(float4 a, float b) {
return res;
}
template <typename Qk_vec>
inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left,
Qk_vec input_right,
Qk_vec cos_emb,
Qk_vec sin_emb,
float alpha) {
Qk_vec res1 = mul<Qk_vec, Qk_vec, Qk_vec>(input_left, cos_emb);
Qk_vec res2 = mul<Qk_vec, Qk_vec, Qk_vec>(input_right, sin_emb);
res2 = mul<Qk_vec, Qk_vec, float>(res2, alpha);
return add(res1, res2);
}
inline __device__ float sum(float v) { return v; }
inline __device__ float sum(float2 v) { return v.x + v.y; }
inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }
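
Note on apply_rotary_emb above: in the decoder kernel further down, each thread's Qk_vec holds either the "left" or the "right" elements of rotary pairs, and alpha supplies the sign, so the two calls together form a 2D rotation. A minimal NumPy sketch of that identity (names are illustrative, not part of the patch):

import numpy as np

def apply_rotary_emb_ref(x, x_partner, cos_emb, sin_emb, alpha):
    # Mirrors the device helper: x * cos + alpha * x_partner * sin.
    # alpha is -1.0 for elements in the first half of a rotary section
    # and +1.0 for elements in the second half (see the kernel below).
    return x * cos_emb + alpha * x_partner * sin_emb

# With one rotary section, the pair (x1, x2) becomes the usual rotation
# (x1*cos - x2*sin, x2*cos + x1*sin).
x1, x2, theta = 0.3, -0.7, 0.5
c, s = np.cos(theta), np.sin(theta)
assert np.isclose(apply_rotary_emb_ref(x1, x2, c, s, -1.0), x1 * c - x2 * s)
assert np.isclose(apply_rotary_emb_ref(x2, x1, c, s, +1.0), x2 * c + x1 * s)
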
......@@ -804,6 +821,67 @@ __global__ void masked_multihead_attention_kernel(
// we may not require k_bias.
k = add(k, k_bias);
// rotary pos emb
if (params.rotary_emb_dims != 0) {
int last_dim = Dh / params.rotary_emb_dims;
int half_lastdim = last_dim / 2;
int rotary_offset = bi * Dh + tid * QK_VEC_SIZE;
const T *cos_base = params.rotary_emb;
const T *sin_base = params.rotary_emb + params.batch_size * Dh;
int stride = half_lastdim / QK_VEC_SIZE;
int stride_all_lastdim = 2 * stride;
int right_id = tid / stride_all_lastdim * stride_all_lastdim +
(tid + stride) % (stride_all_lastdim);
int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE;
int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE;
Qk_vec q_right;
zero(q_right);
q_right =
(Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_base[qk_right_offset])
: q_right;
Qk_vec k_right;
zero(k_right);
k_right =
(Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_base[qk_right_offset])
: k_right;
Qk_vec q_right_bias;
zero(q_right_bias);
q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(
&q_bias_base[qk_right_bias_offset])
: q_right_bias;
Qk_vec k_right_bias;
zero(k_right_bias);
k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(
&k_bias_base[qk_right_bias_offset])
: k_right_bias;
q_right = add(q_right, q_right_bias);
k_right = add(k_right, k_right_bias);
Qk_vec cos_emb;
zero(cos_emb);
cos_emb =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&cos_base[rotary_offset])
: cos_emb;
Qk_vec sin_emb;
zero(sin_emb);
sin_emb =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&sin_base[rotary_offset])
: sin_emb;
float alpha = (tid % stride_all_lastdim) < stride ? static_cast<float>(-1)
: static_cast<float>(1);
q = apply_rotary_emb(q, q_right, cos_emb, sin_emb, alpha);
k = apply_rotary_emb(k, k_right, cos_emb, sin_emb, alpha);
}
*reinterpret_cast<Qk_vec *>(&q_smem[tid * QK_VEC_SIZE]) = q;
int co = tid / QK_VECS_IN_16B;
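
The right_id arithmetic in the block above pairs each thread's QK_VEC_SIZE-wide chunk with the chunk half a rotary section away, wrapping inside that section. A scalar sketch of the same indexing (a hedged reference; function and variable names are illustrative):

def rotary_partner_chunk(tid, dh, rotary_emb_dims, qk_vec_size):
    # Mirrors right_id in the decoder kernel: chunks are grouped into
    # sections of 2 * stride chunks; within a section, chunk i is paired
    # with chunk (i + stride) % (2 * stride).
    last_dim = dh // rotary_emb_dims
    half_lastdim = last_dim // 2
    stride = half_lastdim // qk_vec_size
    stride_all_lastdim = 2 * stride
    return (tid // stride_all_lastdim) * stride_all_lastdim + \
           (tid + stride) % stride_all_lastdim

# Example: Dh = 128, rotary_emb_dims = 2, QK_VEC_SIZE = 4 gives sections of
# 16 chunks; chunk 0 pairs with chunk 8, and chunk 8 pairs back with chunk 0.
assert rotary_partner_chunk(0, 128, 2, 4) == 8
assert rotary_partner_chunk(8, 128, 2, 4) == 0
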
......@@ -1120,6 +1198,7 @@ void fmha(const phi::GPUContext &dev_ctx,
const phi::DenseTensor &qkv_tensor,
const phi::DenseTensor &qkv_bias_tensor,
const phi::DenseTensor &src_mask_tensor,
const phi::DenseTensor *rotary_tensor,
phi::DenseTensor *cache_kv_tensor,
phi::DenseTensor *out_tensor,
int batch_size,
......@@ -1127,6 +1206,7 @@ void fmha(const phi::GPUContext &dev_ctx,
int num_head,
int dim_head,
int timestep,
int rotary_emb_dims,
float inv_sqrt_dh) {
Masked_multihead_attention_params<T> params;
params.out = out_tensor->data<T>();
......@@ -1134,12 +1214,18 @@ void fmha(const phi::GPUContext &dev_ctx,
params.qkv_bias = qkv_bias_tensor.data<T>();
params.attn_mask = src_mask_tensor.data<T>();
params.cache_kv = cache_kv_tensor->data<T>();
if (rotary_emb_dims > 0) {
params.rotary_emb = rotary_tensor->data<T>();
} else {
params.rotary_emb = nullptr;
}
params.batch_size = batch_size;
params.num_head = num_head;
params.timestep = timestep;
params.max_seq_length = max_seq_length;
params.inv_sqrt_dh = inv_sqrt_dh;
params.rotary_emb_dims = rotary_emb_dims;
switch (dim_head) {
case 10:
......@@ -1169,6 +1255,35 @@ void fmha(const phi::GPUContext &dev_ctx,
}
}
template <typename T>
void fmha(const phi::GPUContext &dev_ctx,
const phi::DenseTensor &qkv_tensor,
const phi::DenseTensor &qkv_bias_tensor,
const phi::DenseTensor &src_mask_tensor,
phi::DenseTensor *cache_kv_tensor,
phi::DenseTensor *out_tensor,
int batch_size,
int max_seq_length,
int num_head,
int dim_head,
int timestep,
float inv_sqrt_dh) {
fmha<T>(dev_ctx,
qkv_tensor,
qkv_bias_tensor,
src_mask_tensor,
nullptr,
cache_kv_tensor,
out_tensor,
batch_size,
max_seq_length,
num_head,
dim_head,
timestep,
0,
inv_sqrt_dh);
}
// NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8
constexpr int VEC_16B = 16;
......@@ -1405,6 +1520,94 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
}
}
template <typename T>
__global__ void RotrayKernel(const T *input,
const T *cos_emb,
const T *sin_emb,
T *output,
const int batch_size,
const int head_num,
const int seq_len,
const int last_dim) {
int bi = blockIdx.x;
int hi = blockIdx.y;
int si = blockIdx.z;
int half_lastdim = last_dim / 2;
// Note(ZhenyuLi): Calculate the relevant data at one time, so that no
// additional space is required.
for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) {
int base_idx = bi * head_num * seq_len * last_dim +
hi * seq_len * last_dim + si * last_dim;
int left_idx = base_idx + ti;
const int right_idx = base_idx + ti + half_lastdim;
int emb_idx = bi * seq_len * last_dim + si * last_dim + ti;
T input_left = input[left_idx];
T input_right = input[right_idx];
T cos_tmp = cos_emb[emb_idx];
T sin_tmp = sin_emb[emb_idx];
T res1 = input_left * cos_tmp - input_right * sin_tmp;
T res2 = input_right * cos_tmp + input_left * sin_tmp;
output[left_idx] = res1;
output[right_idx] = res2;
}
}
template <typename T>
void rotary_qk(const phi::GPUContext &dev_ctx,
T *q,
T *k, // kv
const T *q_input, // q
const T *k_input, // kv
const T *rotary_emb,
const int rotary_emb_dims,
const int batch_size,
const int head_num,
const int seq_len,
const int dim_head) {
// q_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num,
// seq_len * rotary_emb_dims, dim_head / rotary_emb_dims]
// kv_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num,
// seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] rotary_emb [2, bs,
// 1, seq_len, dim_head] -> [2, bs, 1, seq_len * rotary_emb_dims, dim_head /
// rotary_emb_dims]
dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims);
const int last_dim = dim_head / rotary_emb_dims;
auto getBlockSize = [](int dim) {
if (dim > 256) {
return 512;
} else if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
};
int BlockSize = getBlockSize(last_dim / 2);
const T *cos_emb = rotary_emb;
const T *sin_emb = rotary_emb + batch_size * seq_len * dim_head;
RotrayKernel<<<grid, BlockSize, 0, dev_ctx.stream()>>>(
q_input,
cos_emb,
sin_emb,
q,
batch_size,
head_num,
seq_len * rotary_emb_dims,
last_dim);
RotrayKernel<<<grid, BlockSize, 0, dev_ctx.stream()>>>(
k_input,
cos_emb,
sin_emb,
k,
batch_size,
head_num,
seq_len * rotary_emb_dims,
last_dim);
}
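
A NumPy reference for rotary_qk, following the layout comments above (rotary_emb stores cos in its first half and sin in its second half, and every tensor is viewed with last dim dim_head / rotary_emb_dims). This is a reading aid, not part of the patch:

import numpy as np

def rotary_qk_ref(x, rotary_emb, rotary_emb_dims):
    # x:          [bs, head_num, seq_len, dim_head]
    # rotary_emb: [2, bs, 1, seq_len, dim_head], index 0 = cos, index 1 = sin
    bs, hn, sl, dh = x.shape
    last_dim = dh // rotary_emb_dims
    xr = x.reshape(bs, hn, sl * rotary_emb_dims, last_dim)
    cos = rotary_emb[0].reshape(bs, 1, sl * rotary_emb_dims, last_dim)
    sin = rotary_emb[1].reshape(bs, 1, sl * rotary_emb_dims, last_dim)
    x1, x2 = xr[..., : last_dim // 2], xr[..., last_dim // 2 :]
    # RotrayKernel reads cos/sin at the left index for both halves of a pair.
    c, s = cos[..., : last_dim // 2], sin[..., : last_dim // 2]
    out = np.concatenate([x1 * c - x2 * s, x2 * c + x1 * s], axis=-1)
    return out.reshape(bs, hn, sl, dh)
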
#if CUDA_VERSION >= 11060
// Only Used in Inference
template <typename T>
......
......@@ -62,6 +62,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"QKVBias",
"CacheKV",
"PreCaches",
"RotaryPosEmb",
"TimeStep",
"SrcMask",
"OutLinearW",
......
......@@ -120,6 +120,8 @@ class TestFusedMultiTransformerOp(OpTest):
self.has_cache_kv = False
self.gen_cache_kv = False
self.has_pre_cache = False
self.rotary_embs = None
self.rotary_emb_dims = 0
self.training = False
......@@ -213,12 +215,53 @@ class TestFusedMultiTransformerOp(OpTest):
)
else:
self.attn_mask = None
if self.rotary_emb_dims > 0:
self.rotary_emb = np.random.uniform(
-1,
1,
(
2,
self.batch_size,
1,
self.query_length,
self.head_dim // 2 // self.rotary_emb_dims,
),
).astype(self.x_type)
concat_nums = 2 * self.rotary_emb_dims
rotary_embs = []
for _ in range(concat_nums):
rotary_embs.append(self.rotary_emb)
self.rotary_embs = np.concatenate(rotary_embs, -1)
self.key, self.value = self.query, self.query
self.dout = np.random.uniform(
-1, 1, (self.batch_size, self.query_length, self.embed_dim)
).astype(self.x_type)
def rotate_half(self, x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return paddle.concat((-x2, x1), axis=-1)
def apply_rotary_emb(self, x, cos_emb, sin_emb, rotary_emb_dims):
# x shape [bsz, num_heads, seq_len, head_dim]
# cos_emb, sin_emb shape [bsz, 1, seq_len, head_dim]
x_dims = paddle.split(x, num_or_sections=rotary_emb_dims, axis=-1)
cos_dims = paddle.split(
cos_emb, num_or_sections=rotary_emb_dims, axis=-1
)
sin_dims = paddle.split(
sin_emb, num_or_sections=rotary_emb_dims, axis=-1
)
rotary_dims = []
for x_dim, cos_dim, sin_dim in zip(x_dims, cos_dims, sin_dims):
rotary_dims.append(
x_dim * cos_dim + self.rotate_half(x_dim) * sin_dim
)
return paddle.concat(rotary_dims, axis=-1)
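
Because generate_input_data builds rotary_embs by concatenating identical cos/sin blocks for the two halves of each rotary section, the rotate_half formulation used by this baseline matches the paired-halves formulation in the CUDA kernels. A small NumPy check of that equivalence (illustrative, not part of the test):

import numpy as np

def rotate_half_np(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return np.concatenate([-x2, x1], axis=-1)

rng = np.random.default_rng(0)
head_dim = 8
x = rng.standard_normal(head_dim)
angles = rng.standard_normal(head_dim // 2)
cos = np.concatenate([np.cos(angles), np.cos(angles)])   # duplicated halves
sin = np.concatenate([np.sin(angles), np.sin(angles)])

baseline = x * cos + rotate_half_np(x) * sin
x1, x2 = x[: head_dim // 2], x[head_dim // 2 :]
kernel_like = np.concatenate(
    [x1 * cos[: head_dim // 2] - x2 * sin[: head_dim // 2],
     x2 * cos[: head_dim // 2] + x1 * sin[: head_dim // 2]])
assert np.allclose(baseline, kernel_like)
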
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
......@@ -238,6 +281,11 @@ class TestFusedMultiTransformerOp(OpTest):
else:
attn_mask = None
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(
self.rotary_embs, stop_gradient=False
)
for i in range(self.layers):
residual = tensor_query
ln1_out = tensor_query
......@@ -254,6 +302,16 @@ class TestFusedMultiTransformerOp(OpTest):
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
if self.rotary_emb_dims > 0:
cos_emb = rotary_embs[0]
sin_emb = rotary_embs[1]
q_out = self.apply_rotary_emb(
q_out, cos_emb, sin_emb, self.rotary_emb_dims
)
k_out = self.apply_rotary_emb(
k_out, cos_emb, sin_emb, self.rotary_emb_dims
)
if self.has_cache_kv:
# [1, B, n_head, cache_seq_len, head_dim]
cache_k, cache_v = paddle.split(cache_kv, 2)
......@@ -414,6 +472,13 @@ class TestFusedMultiTransformerOp(OpTest):
(3, self.num_heads, self.head_dim, self.embed_dim)
)
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(
self.rotary_embs, stop_gradient=False
)
else:
rotary_embs = None
x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kvs, cache_kv = None, None
time_step = None
......@@ -550,6 +615,8 @@ class TestFusedMultiTransformerOp(OpTest):
pre_layer_norm=self.pre_layer_norm,
epsilon=epsilon,
cache_kvs=cache_kvs,
rotary_embs=rotary_embs,
rotary_emb_dims=self.rotary_emb_dims,
pre_caches=pre_caches,
time_step=time_step,
attn_mask=attn_mask,
......@@ -573,6 +640,11 @@ class TestFusedMultiTransformerOp(OpTest):
time_step = None
time_step_feed = None
pre_caches, pre_cache = None, None
rotary_embs = None
if self.rotary_emb_dims > 0:
rotary_embs = paddle.to_tensor(self.rotary_embs)
if self.has_cache_kv:
cache_kvs = []
......@@ -727,6 +799,8 @@ class TestFusedMultiTransformerOp(OpTest):
attn_mask=attn_mask,
caches=cache_kvs,
pre_caches=pre_caches,
rotary_embs=rotary_embs,
rotary_emb_dims=self.rotary_emb_dims,
time_step=time_step,
)[0]
exe = paddle.static.Executor(place=paddle.CUDAPlace(0))
......@@ -735,7 +809,9 @@ class TestFusedMultiTransformerOp(OpTest):
'x': self.query,
'cache_kvs': cache_kvs_feed,
'pre_caches': pre_caches_feed,
'rotary_embs': rotary_embs,
'time_step': time_step_feed,
'rotary_emb_dims': self.rotary_emb_dims,
'attn_mask': attn_mask,
}
out = exe.run(
......@@ -802,6 +878,38 @@ class TestFusedMultiTransformerOp(OpTest):
)
class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.rotary_emb_dims = 1
class TestFusedMultiTransformerOpGenRotaryFP16(TestFusedMultiTransformerOp):
def config(self):
super().config()
self.x_type = np.float16
self.has_cache_kv = True
self.gen_cache_kv = False
self.query_length = 1
self.key_length, self.value_length = (
self.query_length,
self.query_length,
)
self.rotary_emb_dims = 2
class TestFusedMultiTransformerOpGenCacheRotaryFP16(
TestFusedMultiTransformerOp
):
def config(self):
super().config()
self.x_type = np.float16
self.has_cache_kv = True
self.gen_cache_kv = True
self.rotary_emb_dims = 1
class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp):
def config(self):
super().config()
......@@ -932,12 +1040,15 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp):
)
def test_fused_multi_transformer_op(self):
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOutStatic()
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
for i in range(3):
self.rotary_emb_dims = i
self.generate_input_data()
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedMultiTransformerOutStatic()
np.testing.assert_allclose(
final_out_ref, final_out, rtol=self.rtol, atol=self.atol
)
if __name__ == "__main__":
......
......@@ -845,9 +845,11 @@ def fused_multi_transformer(
epsilon=1e-05,
cache_kvs=None,
pre_caches=None,
rotary_embs=None,
time_step=None,
attn_mask=None,
dropout_rate=0.0,
rotary_emb_dims=0,
activation="gelu",
training=False,
mode='upscale_in_train',
......@@ -912,11 +914,14 @@ def fused_multi_transformer(
epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5.
cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
rotary_embs (Tensor optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. The shape is `[1]`, must be in CPUPlace. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None.
dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0.
rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None,
1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0.
activation (str, optional): The activation. Default "gelu".
training (bool, optional): A flag indicating whether it is in train phrase or not. Default False.
mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
......@@ -1006,6 +1011,7 @@ def fused_multi_transformer(
qkv_biases,
cache_kvs,
pre_caches,
rotary_embs,
time_step,
attn_mask,
linear_weights,
......@@ -1023,6 +1029,8 @@ def fused_multi_transformer(
epsilon,
'dropout_rate',
dropout_rate,
'rotary_emb_dims',
rotary_emb_dims,
'is_test',
not training,
'dropout_implementation',
......@@ -1063,6 +1071,8 @@ def fused_multi_transformer(
inputs['TimeStep'] = time_step
if pre_caches is not None:
inputs['PreCaches'] = pre_caches
if rotary_emb_dims > 0:
inputs['RotaryPosEmb'] = rotary_embs
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = linear_weights
if linear_biases is not None:
......@@ -1082,6 +1092,7 @@ def fused_multi_transformer(
'pre_layer_norm': pre_layer_norm,
'epsilon': epsilon,
'dropout_rate': dropout_rate,
'rotary_emb_dims': rotary_emb_dims,
'is_test': not training,
'dropout_implementation': mode,
'act_method': activation,
......
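
For the rotary_embs argument documented above, cos sits at index 0 and sin at index 1 of the leading dimension of size 2. A hedged sketch of building such a tensor with the standard RoPE frequency schedule; the schedule itself is an assumption, only the [2, bsz, 1, seq_len, head_dim] layout is dictated by the op:

import numpy as np
import paddle

def build_rotary_embs(batch_size, seq_len, head_dim, base=10000.0):
    # Standard RoPE angles (assumed); the two halves of each head_dim row
    # share the same angles, as the fused kernels expect for rotary_emb_dims=1.
    inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
    angles = np.outer(np.arange(seq_len), inv_freq)        # [seq_len, head_dim // 2]
    angles = np.concatenate([angles, angles], axis=-1)     # [seq_len, head_dim]
    cos = np.tile(np.cos(angles)[None, None], (batch_size, 1, 1, 1))
    sin = np.tile(np.sin(angles)[None, None], (batch_size, 1, 1, 1))
    # [2, bsz, 1, seq_len, head_dim], cos first, then sin.
    return paddle.to_tensor(np.stack([cos, sin]), dtype='float32')

# Passed alongside rotary_emb_dims=1:
#   fused_multi_transformer(..., rotary_embs=build_rotary_embs(bsz, seq_len, head_dim),
#                           rotary_emb_dims=1, ...)
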
......@@ -1357,7 +1357,14 @@ class FusedMultiTransformer(Layer):
self.name = name
def forward(
self, src, attn_mask=None, caches=None, pre_caches=None, time_step=None
self,
src,
attn_mask=None,
caches=None,
pre_caches=None,
rotary_embs=None,
rotary_emb_dims=0,
time_step=None,
):
r"""
Applies multi transformer layers on the input.
......@@ -1378,6 +1385,9 @@ class FusedMultiTransformer(Layer):
`[2, batch_size, num_head, max_seq_len, head_dim]`. Default None.
pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches
for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None.
rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None.
rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None,
1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0.
time_step (Tensor, optional): The time step tensor for the generation
model. Which used in decode stage, to represent the time step,
that is, the real seq_len of CacheKV. The shape is `[1]`, must be
......@@ -1411,9 +1421,11 @@ class FusedMultiTransformer(Layer):
epsilon=self._epsilon,
cache_kvs=caches,
pre_caches=pre_caches,
rotary_embs=rotary_embs,
time_step=time_step,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
rotary_emb_dims=rotary_emb_dims,
activation=self.activation,
training=self.training,
mode='upscale_in_train',
......
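
A call sketch for the layer interface above; constructing the FusedMultiTransformer layer and its weights is assumed to happen elsewhere, and only the arguments added by this patch are highlighted:

# `layer` is an already-constructed paddle.incubate.nn.FusedMultiTransformer;
# `rotary_embs` has shape [2, bsz, 1, seq_len, head_dim] (cos at index 0, sin at index 1).
out = layer(
    src,
    attn_mask=attn_mask,
    caches=caches,              # optional, generation only
    pre_caches=None,
    rotary_embs=rotary_embs,    # new in this patch
    rotary_emb_dims=1,          # 0 disables RoPE; 2 when pos_extra_ids are used
    time_step=time_step,        # optional, decode stage only
)
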