From 644dfc6014fdd8275d59f4ef6660af15332ddb30 Mon Sep 17 00:00:00 2001 From: lzy <569782149@qq.com> Date: Fri, 23 Dec 2022 20:37:17 +0800 Subject: [PATCH] make FusedMultiTransformer supports RoPE (#48842) --- .../multihead_matmul_roformer_plugin.cu | 4 +- paddle/fluid/operators/fused/fmha_ref.h | 3 +- .../fused/fused_multi_transformer_op.cc | 15 ++ .../fused/fused_multi_transformer_op.cu | 96 ++++++++- .../fused/fused_multi_transformer_op.cu.h | 203 ++++++++++++++++++ paddle/fluid/pybind/op_function_generator.h | 1 + .../test_fused_multi_transformer_op.py | 123 ++++++++++- .../nn/functional/fused_transformer.py | 11 + .../incubate/nn/layer/fused_transformer.py | 14 +- 9 files changed, 451 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu index 50f263ad61..43bc124bb9 100644 --- a/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu @@ -137,7 +137,7 @@ __global__ void apply_scale(T *data, T scale, int n) { template __global__ void RotrayKernel(const T *inputact, const T *input1, - const T *intput2, + const T *input2, T *output, const int nElement, const int lastdim) { @@ -147,7 +147,7 @@ __global__ void RotrayKernel(const T *inputact, int col = index % lastdim; int half_lastdim = lastdim / 2; const int right_index = index - col + (col + half_lastdim) % lastdim; - output[index] = left_elemul_out + intput2[index] * inputact[right_index]; + output[index] = left_elemul_out + input2[index] * inputact[right_index]; } inline int round_up(int seq_len, int multiple = 32) { diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 80f8ccde26..483ebdfa05 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -256,8 +256,7 @@ class FMHARef { dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); } - void ComputeForwardWithoutTranspose(const phi::DenseTensor& qkv_input_tensor, - const phi::DenseTensor* cache_kv_tensor, + void ComputeForwardWithoutTranspose(const phi::DenseTensor* cache_kv_tensor, const phi::DenseTensor* src_mask_tensor, phi::DenseTensor* q_transpose_out_tensor, phi::DenseTensor* kv_transpose_out_tensor, diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index 92b782c44c..89d2275e06 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -174,6 +174,9 @@ class FusedMultiTransformerOpOpMaker "(optional) The prefix caches for generation inference.") .AsDispensable() .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); AddInput("TimeStep", "(optional, int) The time step for generation inference.") .AsDispensable(); @@ -209,6 +212,18 @@ class FusedMultiTransformerOpOpMaker "else, uses post_layer_norm architecuture. 
" "[default true].") .SetDefault(true); + AddAttr("rotary_emb_dims", + "the Attr(dims) for RotaryPosEmb's Computation [default 0].") + .SetDefault(0) + .AddCustomChecker([](const int &rotary_emb_dims) { + PADDLE_ENFORCE_EQ( + rotary_emb_dims >= 0 && rotary_emb_dims <= 2, + true, + platform::errors::InvalidArgument( + "'rotary_emb_dims' in Op(Rotray) should be between" + "0 and 2, But received [%s].", + rotary_emb_dims)); + }); AddAttr("epsilon", "Constant for numerical stability [default 1e-5].") .SetDefault(1e-5) diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 5ca66cb132..9eab7c6bcb 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -77,6 +77,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto *qkv_out_data = dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + // 2.1 rotary + auto *rotary_tensor = ctx.Input("RotaryPosEmb"); + const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); + // 3. fmha AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); @@ -297,6 +301,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { qkv_out, *qkv_bias, *src_mask, + rotary_tensor, cache_kv_out, &fmha_out, bsz, @@ -304,6 +309,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { num_head, dim_head, time_step->data()[0], + rotary_emb_dims, 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage const phi::DenseTensor *pre_cache_kv_tensor = @@ -322,8 +328,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { seq_len, dim_head, compute_bias); - fmha_compute.ComputeForwardWithoutTranspose(qkv_out, - pre_cache_kv_tensor, + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, src_mask, &q_transpose_out, &kv_transpose_out, @@ -383,8 +406,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { seq_len, dim_head, compute_bias); - fmha_compute.ComputeForwardWithoutTranspose(qkv_out, - cache_kv, + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, src_mask, &q_transpose_out, &kv_transpose_out, @@ -594,6 +634,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { auto *qkv_out_data = dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + // 2.1 rotary + auto *rotary_tensor = ctx.Input("RotaryPosEmb"); + const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); + // 3. 
fmha AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); @@ -821,6 +865,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { qkv_out, *qkv_bias, *src_mask, + rotary_tensor, cache_kv_out, &fmha_out, bsz, @@ -828,6 +873,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { num_head, dim_head, time_step->data()[0], + rotary_emb_dims, 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage const phi::DenseTensor *pre_cache_kv_tensor = @@ -846,8 +892,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { seq_len, dim_head, compute_bias); - fmha_compute.ComputeForwardWithoutTranspose(qkv_out, - pre_cache_kv_tensor, + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, src_mask, &q_transpose_out, &kv_transpose_out, @@ -907,8 +970,25 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { seq_len, dim_head, compute_bias); - fmha_compute.ComputeForwardWithoutTranspose(qkv_out, - cache_kv, + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, src_mask, &q_transpose_out, &kv_transpose_out, diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index f9454cb45c..6a276de9e6 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -127,6 +127,11 @@ struct Masked_multihead_attention_params { // v [B, num_head, max_seq_len, dim_head] T *cache_kv; + // The RoPE embedding, [B, 1, 1, dim_head] + // rotary_emb_dims = 1 if pos_ids_extra is null else 2 + const T *rotary_emb; + int rotary_emb_dims; + int batch_size; int num_head; int timestep; // cache_seq_length @@ -404,6 +409,18 @@ inline __device__ float4 mul(float4 a, float b) { return res; } +template +inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left, + Qk_vec input_right, + Qk_vec cos_emb, + Qk_vec sin_emb, + float alpha) { + Qk_vec res1 = mul(input_left, cos_emb); + Qk_vec res2 = mul(input_right, sin_emb); + res2 = mul(res2, alpha); + return add(res1, res2); +} + inline __device__ float sum(float v) { return v; } inline __device__ float sum(float2 v) { return v.x + v.y; } inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } @@ -804,6 +821,67 @@ __global__ void masked_multihead_attention_kernel( // we may not require k_bias. 
k = add(k, k_bias); + // rotary pos emb + if (params.rotary_emb_dims != 0) { + int last_dim = Dh / params.rotary_emb_dims; + int half_lastdim = last_dim / 2; + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const T *cos_base = params.rotary_emb; + const T *sin_base = params.rotary_emb + params.batch_size * Dh; + int stride = half_lastdim / QK_VEC_SIZE; + int stride_all_lastdim = 2 * stride; + int right_id = tid / stride_all_lastdim * stride_all_lastdim + + (tid + stride) % (stride_all_lastdim); + int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; + int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; + Qk_vec q_right; + zero(q_right); + q_right = + (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_base[qk_right_offset]) + : q_right; + Qk_vec k_right; + zero(k_right); + k_right = + (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&k_base[qk_right_offset]) + : k_right; + + Qk_vec q_right_bias; + zero(q_right_bias); + q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &q_bias_base[qk_right_bias_offset]) + : q_right_bias; + Qk_vec k_right_bias; + zero(k_right_bias); + k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &k_bias_base[qk_right_bias_offset]) + : k_right_bias; + + q_right = add(q_right, q_right_bias); + k_right = add(k_right, k_right_bias); + + Qk_vec cos_emb; + zero(cos_emb); + cos_emb = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&cos_base[rotary_offset]) + : cos_emb; + + Qk_vec sin_emb; + zero(sin_emb); + sin_emb = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&sin_base[rotary_offset]) + : sin_emb; + float alpha = (tid % stride_all_lastdim) < stride ? static_cast(-1) + : static_cast(1); + q = apply_rotary_emb(q, q_right, cos_emb, sin_emb, alpha); + k = apply_rotary_emb(k, k_right, cos_emb, sin_emb, alpha); + } + *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; int co = tid / QK_VECS_IN_16B; @@ -1120,6 +1198,7 @@ void fmha(const phi::GPUContext &dev_ctx, const phi::DenseTensor &qkv_tensor, const phi::DenseTensor &qkv_bias_tensor, const phi::DenseTensor &src_mask_tensor, + const phi::DenseTensor *rotary_tensor, phi::DenseTensor *cache_kv_tensor, phi::DenseTensor *out_tensor, int batch_size, @@ -1127,6 +1206,7 @@ void fmha(const phi::GPUContext &dev_ctx, int num_head, int dim_head, int timestep, + int rotary_emb_dims, float inv_sqrt_dh) { Masked_multihead_attention_params params; params.out = out_tensor->data(); @@ -1134,12 +1214,18 @@ void fmha(const phi::GPUContext &dev_ctx, params.qkv_bias = qkv_bias_tensor.data(); params.attn_mask = src_mask_tensor.data(); params.cache_kv = cache_kv_tensor->data(); + if (rotary_emb_dims > 0) { + params.rotary_emb = rotary_tensor->data(); + } else { + params.rotary_emb = nullptr; + } params.batch_size = batch_size; params.num_head = num_head; params.timestep = timestep; params.max_seq_length = max_seq_length; params.inv_sqrt_dh = inv_sqrt_dh; + params.rotary_emb_dims = rotary_emb_dims; switch (dim_head) { case 10: @@ -1169,6 +1255,35 @@ void fmha(const phi::GPUContext &dev_ctx, } } +template +void fmha(const phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv_tensor, + const phi::DenseTensor &qkv_bias_tensor, + const phi::DenseTensor &src_mask_tensor, + phi::DenseTensor *cache_kv_tensor, + phi::DenseTensor *out_tensor, + int batch_size, + int max_seq_length, + int num_head, + int dim_head, + int timestep, + float inv_sqrt_dh) { + fmha(dev_ctx, + qkv_tensor, + 
qkv_bias_tensor, + src_mask_tensor, + nullptr, + cache_kv_tensor, + out_tensor, + batch_size, + max_seq_length, + num_head, + dim_head, + timestep, + 0, + inv_sqrt_dh); +} + // NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8 constexpr int VEC_16B = 16; @@ -1405,6 +1520,94 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx, } } +template +__global__ void RotrayKernel(const T *input, + const T *cos_emb, + const T *sin_emb, + T *output, + const int batch_size, + const int head_num, + const int seq_len, + const int last_dim) { + int bi = blockIdx.x; + int hi = blockIdx.y; + int si = blockIdx.z; + int half_lastdim = last_dim / 2; + // Note(ZhenyuLi): Calculate the relevant data at one time, so that no + // additional space is required. + for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) { + int base_idx = bi * head_num * seq_len * last_dim + + hi * seq_len * last_dim + si * last_dim; + int left_idx = base_idx + ti; + const int right_idx = base_idx + ti + half_lastdim; + int emb_idx = bi * seq_len * last_dim + si * last_dim + ti; + T input_left = input[left_idx]; + T input_right = input[right_idx]; + T cos_tmp = cos_emb[emb_idx]; + T sin_tmp = sin_emb[emb_idx]; + T res1 = input_left * cos_tmp - input_right * sin_tmp; + T res2 = input_right * cos_tmp + input_left * sin_tmp; + output[left_idx] = res1; + output[right_idx] = res2; + } +} + +template +void rotary_qk(const phi::GPUContext &dev_ctx, + T *q, + T *k, // kv + const T *q_input, // q + const T *k_input, // kv + const T *rotary_emb, + const int rotary_emb_dims, + const int batch_size, + const int head_num, + const int seq_len, + const int dim_head) { + // q_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, + // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] + // kv_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, + // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] rotary_emb [2, bs, + // 1, seq_len, dim_head] -> [2, bs, 1, seq_len * rotary_emb_dims, dim_head / + // rotary_emb_dims] + dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims); + const int last_dim = dim_head / rotary_emb_dims; + auto getBlockSize = [](int dim) { + if (dim > 256) { + return 512; + } else if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; + } + }; + int BlockSize = getBlockSize(last_dim / 2); + const T *cos_emb = rotary_emb; + const T *sin_emb = rotary_emb + batch_size * seq_len * dim_head; + RotrayKernel<<>>( + q_input, + cos_emb, + sin_emb, + q, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + RotrayKernel<<>>( + k_input, + cos_emb, + sin_emb, + k, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); +} + #if CUDA_VERSION >= 11060 // Only Used in Inference template diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 21f97ecb48..68ed995403 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -62,6 +62,7 @@ std::map> op_ins_map = { "QKVBias", "CacheKV", "PreCaches", + "RotaryPosEmb", "TimeStep", "SrcMask", "OutLinearW", diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py index abbfb3f08b..f1b048102e 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py @@ -120,6 +120,8 @@ class TestFusedMultiTransformerOp(OpTest): self.has_cache_kv = False self.gen_cache_kv = False self.has_pre_cache = False + self.rotary_embs = None + self.rotary_emb_dims = 0 self.training = False @@ -213,12 +215,53 @@ class TestFusedMultiTransformerOp(OpTest): ) else: self.attn_mask = None + + if self.rotary_emb_dims > 0: + self.rotary_emb = np.random.uniform( + -1, + 1, + ( + 2, + self.batch_size, + 1, + self.query_length, + self.head_dim // 2 // self.rotary_emb_dims, + ), + ).astype(self.x_type) + concat_nums = 2 * self.rotary_emb_dims + rotary_embs = [] + for _ in range(concat_nums): + rotary_embs.append(self.rotary_emb) + self.rotary_embs = np.concatenate(rotary_embs, -1) + self.key, self.value = self.query, self.query self.dout = np.random.uniform( -1, 1, (self.batch_size, self.query_length, self.embed_dim) ).astype(self.x_type) + def rotate_half(self, x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return paddle.concat((-x2, x1), axis=-1) + + def apply_rotary_emb(self, x, cos_emb, sin_emb, rotary_emb_dims): + # x shape [bsz, num_heads, seq_len, head_dim] + # cos_emb, sin_emb shape [bsz, 1, seq_len, head_dim] + x_dims = paddle.split(x, num_or_sections=rotary_emb_dims, axis=-1) + cos_dims = paddle.split( + cos_emb, num_or_sections=rotary_emb_dims, axis=-1 + ) + sin_dims = paddle.split( + sin_emb, num_or_sections=rotary_emb_dims, axis=-1 + ) + + rotary_dims = [] + for x_dim, cos_dim, sin_dim in zip(x_dims, cos_dims, sin_dims): + rotary_dims.append( + x_dim * cos_dim + self.rotate_half(x_dim) * sin_dim + ) + return paddle.concat(rotary_dims, axis=-1) + def GetBaselineOut(self): paddle.disable_static(place=paddle.CUDAPlace(0)) tensor_query = paddle.to_tensor(self.query, stop_gradient=False) @@ -238,6 +281,11 @@ class TestFusedMultiTransformerOp(OpTest): else: attn_mask = None + if self.rotary_emb_dims > 0: + rotary_embs = paddle.to_tensor( + self.rotary_embs, stop_gradient=False + ) + for i in range(self.layers): residual = tensor_query ln1_out = tensor_query @@ -254,6 +302,16 @@ class TestFusedMultiTransformerOp(OpTest): v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + if self.rotary_emb_dims > 0: + cos_emb = rotary_embs[0] + sin_emb = rotary_embs[1] + q_out = self.apply_rotary_emb( + q_out, cos_emb, sin_emb, self.rotary_emb_dims + ) + k_out = self.apply_rotary_emb( + k_out, cos_emb, sin_emb, self.rotary_emb_dims + ) + if self.has_cache_kv: # [1, B, n_head, cache_seq_len, head_dim] cache_k, cache_v = paddle.split(cache_kv, 2) @@ -414,6 +472,13 @@ class TestFusedMultiTransformerOp(OpTest): (3, self.num_heads, self.head_dim, self.embed_dim) ) + if self.rotary_emb_dims > 0: + rotary_embs = paddle.to_tensor( + self.rotary_embs, stop_gradient=False + ) + else: + rotary_embs = None + x = paddle.to_tensor(self.query, stop_gradient=False) cache_kvs, cache_kv = None, None time_step = None @@ -550,6 +615,8 @@ class TestFusedMultiTransformerOp(OpTest): pre_layer_norm=self.pre_layer_norm, epsilon=epsilon, cache_kvs=cache_kvs, + rotary_embs=rotary_embs, + rotary_emb_dims=self.rotary_emb_dims, pre_caches=pre_caches, time_step=time_step, attn_mask=attn_mask, @@ -573,6 +640,11 @@ class TestFusedMultiTransformerOp(OpTest): time_step = None time_step_feed = None pre_caches, pre_cache = None, None + rotary_embs = None + + if self.rotary_emb_dims > 0: + rotary_embs = paddle.to_tensor(self.rotary_embs) + if 
self.has_cache_kv: cache_kvs = [] @@ -727,6 +799,8 @@ class TestFusedMultiTransformerOp(OpTest): attn_mask=attn_mask, caches=cache_kvs, pre_caches=pre_caches, + rotary_embs=rotary_embs, + rotary_emb_dims=self.rotary_emb_dims, time_step=time_step, )[0] exe = paddle.static.Executor(place=paddle.CUDAPlace(0)) @@ -735,7 +809,9 @@ class TestFusedMultiTransformerOp(OpTest): 'x': self.query, 'cache_kvs': cache_kvs_feed, 'pre_caches': pre_caches_feed, + 'rotary_embs': rotary_embs, 'time_step': time_step_feed, + 'rotary_emb_dims': self.rotary_emb_dims, 'attn_mask': attn_mask, } out = exe.run( @@ -802,6 +878,38 @@ class TestFusedMultiTransformerOp(OpTest): ) +class TestFusedMultiTransformerOpRotaryFP16(TestFusedMultiTransformerOp): + def config(self): + super().config() + self.x_type = np.float16 + self.rotary_emb_dims = 1 + + +class TestFusedMultiTransformerOpGenRotaryFP16(TestFusedMultiTransformerOp): + def config(self): + super().config() + self.x_type = np.float16 + self.has_cache_kv = True + self.gen_cache_kv = False + self.query_length = 1 + self.key_length, self.value_length = ( + self.query_length, + self.query_length, + ) + self.rotary_emb_dims = 2 + + +class TestFusedMultiTransformerOpGenCacheRotaryFP16( + TestFusedMultiTransformerOp +): + def config(self): + super().config() + self.x_type = np.float16 + self.has_cache_kv = True + self.gen_cache_kv = True + self.rotary_emb_dims = 1 + + class TestFusedMultiTransformerOpFp16(TestFusedMultiTransformerOp): def config(self): super().config() @@ -932,12 +1040,15 @@ class TestFusedMultiTransformerOpPreCacheStatic(TestFusedMultiTransformerOp): ) def test_fused_multi_transformer_op(self): - final_out_ref = self.GetBaselineOut() - final_out = self.GetFusedMultiTransformerOutStatic() - - np.testing.assert_allclose( - final_out_ref, final_out, rtol=self.rtol, atol=self.atol - ) + for i in range(3): + self.rotary_emb_dims = i + self.generate_input_data() + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedMultiTransformerOutStatic() + + np.testing.assert_allclose( + final_out_ref, final_out, rtol=self.rtol, atol=self.atol + ) if __name__ == "__main__": diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index a7e342a8ac..61270f86d3 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -845,9 +845,11 @@ def fused_multi_transformer( epsilon=1e-05, cache_kvs=None, pre_caches=None, + rotary_embs=None, time_step=None, attn_mask=None, dropout_rate=0.0, + rotary_emb_dims=0, activation="gelu", training=False, mode='upscale_in_train', @@ -912,11 +914,14 @@ def fused_multi_transformer( epsilon (float, optional): Small float value added to denominator of the layer_norm to avoid dividing by zero. Default is 1e-5. cache_kvs (list(Tensor)|tuple(Tensor), optional): The cache structure tensors for the generation model. The shape is `[2, bsz, num\_head, max\_seq\_len, head\_dim]`. Default None. pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. + rotary_embs (Tensor optional): The RoPE embs for rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. 
The shape is `[1]`, must be in CPUPlace. Default None. attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape `[batch_size, 1, sequence_length, sequence_length]`. Default None. dropout_rate (float, optional): The dropout probability of setting units to zero. Default 0.0. + rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, + 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. activation (str, optional): The activation. Default "gelu". training (bool, optional): A flag indicating whether it is in train phrase or not. Default False. mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] @@ -1006,6 +1011,7 @@ def fused_multi_transformer( qkv_biases, cache_kvs, pre_caches, + rotary_embs, time_step, attn_mask, linear_weights, @@ -1023,6 +1029,8 @@ def fused_multi_transformer( epsilon, 'dropout_rate', dropout_rate, + 'rotary_emb_dims', + rotary_emb_dims, 'is_test', not training, 'dropout_implementation', @@ -1063,6 +1071,8 @@ def fused_multi_transformer( inputs['TimeStep'] = time_step if pre_caches is not None: inputs['PreCaches'] = pre_caches + if rotary_emb_dims > 0: + inputs['RotaryPosEmb'] = rotary_embs inputs['SrcMask'] = attn_mask inputs['OutLinearW'] = linear_weights if linear_biases is not None: @@ -1082,6 +1092,7 @@ def fused_multi_transformer( 'pre_layer_norm': pre_layer_norm, 'epsilon': epsilon, 'dropout_rate': dropout_rate, + 'rotary_emb_dims': rotary_emb_dims, 'is_test': not training, 'dropout_implementation': mode, 'act_method': activation, diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 2f745a3feb..78fc72794e 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1357,7 +1357,14 @@ class FusedMultiTransformer(Layer): self.name = name def forward( - self, src, attn_mask=None, caches=None, pre_caches=None, time_step=None + self, + src, + attn_mask=None, + caches=None, + pre_caches=None, + rotary_embs=None, + rotary_emb_dims=0, + time_step=None, ): r""" Applies multi transformer layers on the input. @@ -1378,6 +1385,9 @@ class FusedMultiTransformer(Layer): `[2, batch_size, num_head, max_seq_len, head_dim]`. Default None. pre_caches (list(Tensor)|tuple(Tensor), optional): The prefix caches for the generation model. The shape is `[2, bsz, num\_head, cache\_len, head\_dim]`. Default None. + rotary_embs (Tensor optional): The RoPE embs for the rotary computation. The shape is `[2, bsz, 1, seq\_len, head\_dim]`. Default None. + rotary_emb_dims (int, optional): The rotary_emb_dims of rotary computation, and it is 0 when rotary_embs is None, + 1 when rotary_embs is not None and pos_extra_ids is None, 2 when rotary_embs and pos_extra_ids are both not None. Default 0. time_step (Tensor, optional): The time step tensor for the generation model. Which used in decode stage, to represent the time step, that is, the real seq_len of CacheKV. 
The shape is `[1]`, must be @@ -1411,9 +1421,11 @@ class FusedMultiTransformer(Layer): epsilon=self._epsilon, cache_kvs=caches, pre_caches=pre_caches, + rotary_embs=rotary_embs, time_step=time_step, attn_mask=attn_mask, dropout_rate=self.dropout_rate, + rotary_emb_dims=rotary_emb_dims, activation=self.activation, training=self.training, mode='upscale_in_train', -- GitLab
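
A usage sketch of the new arguments (not part of the patch itself): it shows how the `rotary_embs` / `rotary_emb_dims` inputs added above could be driven from Python. Only the `[2, bsz, 1, seq_len, head_dim]` layout (cos at index 0, sin at index 1) and the forward arguments follow the code and docstrings in this change; the `FusedMultiTransformer` constructor call and the sinusoidal table construction are assumptions, and running it requires a CUDA build of Paddle that contains this patch.

    import numpy as np
    import paddle
    from paddle.incubate.nn import FusedMultiTransformer

    bsz, seq_len, num_heads, head_dim = 2, 128, 8, 64
    embed_dim = num_heads * head_dim

    # Conventional RoPE tables for rotary_emb_dims=1, packed as
    # [2, bsz, 1, seq_len, head_dim] with cos at index 0 and sin at index 1,
    # which is the RotaryPosEmb layout used by the op and its unit test.
    inv_freq = 1.0 / 10000.0 ** (np.arange(0, head_dim, 2, dtype="float32") / head_dim)
    freqs = np.outer(np.arange(seq_len, dtype="float32"), inv_freq)  # [seq_len, head_dim // 2]
    emb = np.concatenate([freqs, freqs], axis=-1)                    # [seq_len, head_dim]
    cos = np.broadcast_to(np.cos(emb)[None, None], (bsz, 1, seq_len, head_dim))
    sin = np.broadcast_to(np.sin(emb)[None, None], (bsz, 1, seq_len, head_dim))
    rotary_embs = paddle.to_tensor(np.stack([cos, sin], axis=0), dtype="float32")

    # Assumed constructor signature: (embed_dim, num_heads, dim_feedforward, ...).
    layer = FusedMultiTransformer(embed_dim, num_heads, 4 * embed_dim, num_layers=2)

    x = paddle.randn([bsz, seq_len, embed_dim], dtype="float32")
    # Additive attention mask; zeros mean "attend everywhere".
    attn_mask = paddle.zeros([bsz, 1, seq_len, seq_len], dtype="float32")

    # rotary_emb_dims=1 applies plain RoPE over the whole head_dim;
    # 2 splits head_dim in half for an extra position-id stream (pos_ids_extra).
    out = layer(x, attn_mask=attn_mask, rotary_embs=rotary_embs, rotary_emb_dims=1)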