diff --git a/paddle/phi/api/yaml/fused_backward.yaml b/paddle/phi/api/yaml/fused_backward.yaml
index 5f49e790e2550bff155aec0a003cfea2df2d0e0f..8dfa117d44c976fc9519340aa1f077bf26bb21a4 100644
--- a/paddle/phi/api/yaml/fused_backward.yaml
+++ b/paddle/phi/api/yaml/fused_backward.yaml
@@ -17,10 +17,10 @@
     support_dygraph_mode : true

 - backward_op : fused_rotary_position_embedding_grad
-  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
-  args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
+  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
+  args : (Tensor sin, Tensor cos, Tensor position_ids, Tensor out_q_grad, Tensor out_k_grad, Tensor out_v_grad, bool use_neox_rotary_style)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
-  optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad
+  optional : sin, cos, position_ids, out_k_grad, out_v_grad, k_grad, v_grad
   infer_meta :
     func : FusedRopeGradInferMeta
   kernel :
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index 90593ed4d43eab58d083c9393f4f302eb534e57c..9820fd614055720d899d7d32ad41a3ca37a91fff 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -149,11 +149,11 @@
   optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index

 - op : fused_rotary_position_embedding
-  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
+  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true)
   output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
   infer_meta :
     func : FusedRopeInferMeta
-  optional : k,v,sin,cos, out_k, out_v
+  optional : k, v, sin, cos, position_ids, out_k, out_v
   kernel :
     func : fused_rotary_position_embedding
     data_type : q
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 31da8f079bb4322fc964a0ab801eaffca410e181..5dedf81e937ac841847a0563915ddca3f1529491 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -1219,9 +1219,11 @@ void IndexPutGradInferMeta(const MetaTensor& x,

 void FusedRopeGradInferMeta(const MetaTensor& sin,
                             const MetaTensor& cos,
+                            const MetaTensor& position_ids,
                             const MetaTensor& dout_q,
                             const MetaTensor& dout_k,
                             const MetaTensor& dout_v,
+                            bool use_neox_rotary_style,
                             MetaTensor* dq,
                             MetaTensor* dk,
                             MetaTensor* dv) {
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index b8a3aee03bde451d2a1d950231e615d418898384..a00bc2cde450fceec74bb140bb539d6b980c9c21 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -186,9 +186,11 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,

 void FusedRopeGradInferMeta(const MetaTensor& sin,
                             const MetaTensor& cos,
+                            const MetaTensor& position_ids,
                             const MetaTensor& dout_q,
                             const MetaTensor& dout_k,
                             const MetaTensor& dout_v,
+                            bool use_neox_rotary_style,
                             MetaTensor* dq,
                             MetaTensor* dk,
                             MetaTensor* dv);
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index e54c6ab3ba7dc4332a64012084dd3fec51034daa..63c421dcdcdb989bd7c039e65b43f59a3a2b4e4a 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -4041,6 +4041,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         const MetaTensor& v,
                         const MetaTensor& sin,
                         const MetaTensor& cos,
+                        const MetaTensor& position_ids,
+                        bool use_neox_rotary_style,
                         MetaTensor* out_q,
                         MetaTensor* out_k,
                         MetaTensor* out_v) {
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index 8a782676fe9175c0c8e1e077b5396c36465249dc..9beb9d213899d954597e2def53e98245e322772e 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -807,6 +807,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         const MetaTensor& v,
                         const MetaTensor& sin,
                         const MetaTensor& cos,
+                        const MetaTensor& position_ids,
+                        bool use_neox_rotary_style,
                         MetaTensor* out_q,
                         MetaTensor* out_k,
                         MetaTensor* out_v);
diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
index 442317eb53d98062d0cbc7fc8d0b9b743be940b2..70ea70912f6397c562b526b4659637b866452dda 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -27,9 +27,11 @@ template <typename T, typename Context>
 void FusedRopeGradKernel(const Context& dev_ctx,
                          const paddle::optional<DenseTensor>& sin,
                          const paddle::optional<DenseTensor>& cos,
+                         const paddle::optional<DenseTensor>& position_ids,
                          const DenseTensor& dout_q,
                          const paddle::optional<DenseTensor>& dout_k,
                          const paddle::optional<DenseTensor>& dout_v,
+                         bool use_neox_rotary_style,
                          DenseTensor* dq,
                          DenseTensor* dk,
                          DenseTensor* dv) {
@@ -58,6 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   phi::Array<T*, 3> outs_data;
   phi::Array<const T*, 3> ins_data;
   phi::Array<const T*, 2> sin_cos_data;
+  const int64_t* position_ids_data = NULL;

   ins_data[0] = dout_q.data<T>();
   outs_data[0] = dq->data<T>();
@@ -86,21 +89,42 @@ void FusedRopeGradKernel(const Context& dev_ctx,
     sin_cos_data[1] = cos->data<T>();

     flag_sin_cos = true;
+
+    if (position_ids.get_ptr()) {
+      position_ids_data = position_ids->data<int64_t>();
+    }
   }

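+  // sign = -1 applies the inverse rotation for the backward pass; the same
+  // vectorized kernels as the forward pass are reused, with
+  // use_neox_rotary_style choosing between the rotate-every-two (NeoX) and
+  // rotate-half element layouts.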
   int sign = -1;
-  VectorizedFusedRopeKernel<T, MPType, vec_size>
-      <<<grid, block, 0, stream>>>(ins_data,
-                                   sin_cos_data,
-                                   flag_sin_cos,
-                                   sign,
-                                   batch_size,
-                                   seq_len,
-                                   num_heads,
-                                   head_dim,
-                                   outs_data,
-                                   num_inputs,
-                                   div_c);
+  if (use_neox_rotary_style) {
+    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  } else {
+    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  }
 }
 }  // namespace fusion
diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
index f6dcbc2a9038f06c11e4509edac047efeb867103..6e032211cc6a09f8b56596dc5942dff8bdecfe1e 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
                      const paddle::optional<DenseTensor>& v,
                      const paddle::optional<DenseTensor>& sin,
                      const paddle::optional<DenseTensor>& cos,
+                     const paddle::optional<DenseTensor>& position_ids,
+                     bool use_neox_rotary_style,
                      DenseTensor* out_q,
                      DenseTensor* out_k,
                      DenseTensor* out_v) {
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
   phi::Array<T*, 3> outs_data;
   phi::Array<const T*, 3> ins_data;
   phi::Array<const T*, 2> sin_cos_data;
+  const int64_t* position_ids_data = NULL;

   ins_data[0] = q.data<T>();
   outs_data[0] = out_q->data<T>();
@@ -109,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx,
             "The batch_size and num_heads of sin and cos must be 1."));
     }
     int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0;
-    PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
-                       sin_dims[sin_seq_len_dim] == seq_len),
-                      true,
-                      phi::errors::InvalidArgument(
-                          "The seq_len and head_dim of sin and cos "
-                          "must be the same as those of q. But recieved sin's "
-                          "shape is {%s}, q's shape is {%s}.",
-                          sin_dims,
-                          q.dims()));
+
+    if (position_ids.get_ptr()) {
+      PADDLE_ENFORCE_EQ(
+          (sin_dims[dims_size - 1] == head_dim &&
+           sin_dims[sin_seq_len_dim] >= seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The seq_len of sin and cos must be greater than or equal to "
+              "that of q. The head_dim of sin and cos must be the same as that "
+              "of q. But received sin's "
+              "shape is {%s}, q's shape is {%s}.",
+              sin_dims,
+              q.dims()));
+
+      auto position_ids_dims = position_ids.get_ptr()->dims();
+      PADDLE_ENFORCE_EQ(position_ids_dims.size(),
+                        2,
+                        phi::errors::InvalidArgument(
+                            "The dims of position_ids is expected to "
+                            "be 2, but received %d.",
+                            position_ids_dims.size()));
+
+      PADDLE_ENFORCE_EQ(
+          (position_ids_dims[0] == batch_size &&
+           position_ids_dims[1] == seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The batch_size and seq_len of position_ids must be the same as "
+              "those of q. But received position_ids's "
+              "shape is {%s}, q's shape is {%s}.",
+              position_ids_dims,
+              q.dims()));
+
+      position_ids_data = position_ids->data<int64_t>();
+    } else {
+      PADDLE_ENFORCE_EQ(
+          (sin_dims[dims_size - 1] == head_dim &&
+           sin_dims[sin_seq_len_dim] == seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The seq_len and head_dim of sin and cos "
+              "must be the same as those of q. But received sin's "
+              "shape is {%s}, q's shape is {%s}.",
+              sin_dims,
+              q.dims()));
+    }

     sin_cos_data[0] = sin->data<T>();
     sin_cos_data[1] = cos->data<T>();
@@ -126,18 +166,35 @@ void FusedRopeKernel(const Context& dev_ctx,
   }

   int sign = 1;
-  VectorizedFusedRopeKernel<T, MPType, vec_size>
-      <<<grid, block, 0, stream>>>(ins_data,
-                                   sin_cos_data,
-                                   flag_sin_cos,
-                                   sign,
-                                   batch_size,
-                                   seq_len,
-                                   num_heads,
-                                   head_dim,
-                                   outs_data,
-                                   num_inputs,
-                                   div_c);
+  if (use_neox_rotary_style) {
+    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  } else {
+    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  }
 }
 }  // namespace fusion
 }  // namespace phi
diff --git a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h
index 54ffba19e60c0e04761187737f13039c005e1134..972f5ee633bbb02ef2ff263de6518466f8d2dca2 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h
+++ b/paddle/phi/kernels/fusion/gpu/fused_rope_utils.h
@@ -20,17 +20,71 @@ namespace phi {
 namespace fusion {

 template <typename T, typename MPType, int VecSize = 2>
-__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
-                                          phi::Array<const T*, 2> sin_cos_data,
-                                          bool flag_sin_cos,
-                                          int sign,
-                                          int64_t batch_size,
-                                          int64_t seq_len,
-                                          int64_t num_heads,
-                                          int64_t head_dim,
-                                          phi::Array<T*, 3> outs_data,
-                                          int num_inputs,
-                                          MPType div_c) {
+__device__ void VectorizedGetSinCos(phi::Array<const T*, 2> sin_cos_data,
+                                    const int64_t* position_ids_data,
+                                    bool flag_sin_cos,
+                                    int64_t index,
+                                    int64_t seq_len,
+                                    int64_t num_heads,
+                                    int64_t head_dim,
+                                    MPType* out_sin,
+                                    MPType* out_cos,
+                                    MPType div_c) {
+  MPType* sin_value = out_sin;
+  MPType* cos_value = out_cos;
+
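+  // Gather or compute sin/cos for the VecSize elements starting at `index`.
+  // When sin/cos tensors are given, position_ids_data (if non-null) remaps
+  // each sequence position to a row of the sin/cos cache, so the cache may
+  // cover more positions than the current seq_len. Otherwise the default
+  // rotary frequencies are derived on the fly from pow(10000, idx * div_c).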
+  if (flag_sin_cos) {
+#pragma unroll
+    for (int64_t nx = 0; nx < VecSize; ++nx) {
+      int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+      int64_t pos_seq_ori = index_wc / (num_heads * head_dim);
+      int64_t pos_seq;
+      if (position_ids_data) {
+        int64_t pos_bs = (index + nx) / (seq_len * num_heads * head_dim);
+        int64_t index_ids = pos_bs * seq_len + pos_seq_ori;
+        pos_seq = position_ids_data[index_ids];
+      } else {
+        pos_seq = pos_seq_ori;
+      }
+      int64_t pos_head = index_wc % head_dim;
+      int64_t index_sc = pos_seq * head_dim + pos_head;
+      const T* sin_input = sin_cos_data[0] + index_sc;
+      const T* cos_input = sin_cos_data[1] + index_sc;
+
+      sin_value[nx] = static_cast<MPType>(sin_input[0]);
+      cos_value[nx] = static_cast<MPType>(cos_input[0]);
+    }
+  } else {
+#pragma unroll
+    for (int nx = 0; nx < VecSize; ++nx) {
+      // get sin_index and cos_index
+      int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
+      int64_t pos_seq = index_wc / (num_heads * head_dim);
+      MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
+      MPType indicses =
+          static_cast<MPType>(1) /
+          pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
+      MPType value = pos_seq * indicses;
+      sin_value[nx] = sin(value);
+      cos_value[nx] = cos(value);
+    }
+  }
+}
+
+template <typename T, typename MPType, int VecSize = 2>
+__global__ void VectorizedFusedRopeWithRotateEveryTwoKernel(
+    phi::Array<const T*, 3> ins_data,
+    phi::Array<const T*, 2> sin_cos_data,
+    const int64_t* position_ids_data,
+    bool flag_sin_cos,
+    int sign,
+    int64_t batch_size,
+    int64_t seq_len,
+    int64_t num_heads,
+    int64_t head_dim,
+    phi::Array<T*, 3> outs_data,
+    int num_inputs,
+    MPType div_c) {
   int64_t index =
       (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
        threadIdx.x) *
       VecSize;
@@ -46,34 +100,16 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
   constexpr int kVectorsPerThread = VecSize / 2;

   for (; index < size; index += stride) {
-    if (flag_sin_cos) {
-#pragma unroll
-      for (int64_t nx = 0; nx < VecSize; ++nx) {
-        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-        int64_t pos_seq = index_wc / (num_heads * head_dim);
-        int64_t pos_head = index_wc % head_dim;
-        int64_t index_sc = pos_seq * head_dim + pos_head;
-        const T* sin_input = sin_cos_data[0] + index_sc;
-        const T* cos_input = sin_cos_data[1] + index_sc;
-
-        sin_value[nx] = static_cast<MPType>(sin_input[0]);
-        cos_value[nx] = static_cast<MPType>(cos_input[0]);
-      }
-    } else {
-#pragma unroll
-      for (int nx = 0; nx < VecSize; ++nx) {
-        // get sin_index and cos_index
-        int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
-        int64_t pos_seq = index_wc / (num_heads * head_dim);
-        MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
-        MPType indicses =
-            static_cast<MPType>(1) /
-            pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
-        MPType value = pos_seq * indicses;
-        sin_value[nx] = sin(value);
-        cos_value[nx] = cos(value);
-      }
-    }
+    VectorizedGetSinCos<T, MPType, VecSize>(sin_cos_data,
+                                            position_ids_data,
+                                            flag_sin_cos,
+                                            index,
+                                            seq_len,
+                                            num_heads,
+                                            head_dim,
+                                            sin_value,
+                                            cos_value,
+                                            div_c);

 #pragma unroll
     for (int iter = 0; iter < 3; iter++) {
@@ -102,5 +138,74 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
   }
 }

+template <typename T, typename MPType, int VecSize = 2>
+__global__ void VectorizedFusedRopeWithRotateHalfKernel(
+    phi::Array<const T*, 3> ins_data,
+    phi::Array<const T*, 2> sin_cos_data,
+    const int64_t* position_ids_data,
+    bool flag_sin_cos,
+    int sign,
+    int64_t batch_size,
+    int64_t seq_len,
+    int64_t num_heads,
+    int64_t head_dim,
+    phi::Array<T*, 3> outs_data,
+    int num_inputs,
+    MPType div_c) {
+  int64_t index =
+      (static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
+       threadIdx.x) *
+      VecSize;
+  int64_t stride = static_cast<int64_t>(gridDim.x) *
+                   static_cast<int64_t>(blockDim.x) * VecSize;
+  int64_t size = batch_size * seq_len * num_heads * head_dim;
+  MPType sin_value[VecSize];
+  MPType cos_value[VecSize];
+  MPType result[VecSize];
+  T store[VecSize];
+  using VecType = phi::AlignedVector<T, VecSize>;
+  constexpr int kVectorsPerThread = VecSize / 2;
+
+  for (; index < size; index += stride) {
+    VectorizedGetSinCos<T, MPType, VecSize>(sin_cos_data,
+                                            position_ids_data,
+                                            flag_sin_cos,
+                                            index,
+                                            seq_len,
+                                            num_heads,
+                                            head_dim,
+                                            sin_value,
+                                            cos_value,
+                                            div_c);
+
+    // use rotate_half mode
+    int stride_r = head_dim / 2;
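+    // rotate_half pairs element i of the first half of head_dim with element
+    // i + head_dim / 2: the first half uses sign_r = -1 (out = cos * x_i -
+    // sin * x_{i + d/2}), the second half uses sign_r = +1, and `sign` flips
+    // the rotation for the backward pass.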
+#pragma unroll
+    for (int iter = 0; iter < 3; iter++) {
+      if (iter > num_inputs) break;
+      // get value_index and rotate_half_index
+      int index_v = index;
+      int index_r = (index % head_dim) < stride_r ? (index + stride_r)
+                                                  : (index - stride_r);
+      MPType sign_r = (index % head_dim) < stride_r ? static_cast<MPType>(-1)
+                                                    : static_cast<MPType>(1);
+      const T* input_v = ins_data[iter] + index_v;
+      const T* input_r = ins_data[iter] + index_r;
+      VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
+
+#pragma unroll
+      for (int nx = 0; nx < VecSize; ++nx) {
+        MPType p0 = static_cast<MPType>(input_v[nx]);
+        MPType p1 = static_cast<MPType>(input_r[nx]);
+
+        result[nx] = cos_value[nx] * p0 + sign * sign_r * sin_value[nx] * p1;
+
+        store[nx] = static_cast<T>(result[nx]);
+      }
+      out[0] = *(reinterpret_cast<VecType*>(store));
+    }
+  }
+}
+
 }  // namespace fusion
 }  // namespace phi
diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
index e05ae63f07807e548b58c38e8419cbda3b12102d..f68dfb1dcd53f9b14b10b14faef0edcc8e5e1fdc 100644
--- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
+++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py
@@ -17,16 +17,26 @@
 from paddle import _C_ops
 from paddle.framework import in_dynamic_mode


-def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
+def fused_rotary_position_embedding(
+    q,
+    k=None,
+    v=None,
+    sin=None,
+    cos=None,
+    position_ids=None,
+    use_neox_rotary_style=True,
+):
     r"""
     Fused rotary position embedding.

     Args:
-        q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
-        k (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
-        v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
-        sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
-        cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
+        q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
+        k (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
+        v (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
+        sin (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of sin must be [seq_len, head_dim] or [1, seq_len, 1, head_dim] and head_dim must be a multiple of 2.
+        cos (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of cos must be [seq_len, head_dim] or [1, seq_len, 1, head_dim] and head_dim must be a multiple of 2.
+        position_ids (Tensor, optional): The input tensor. The data type is int64. The shape of position_ids must be [batch_size, seq_len].
+        use_neox_rotary_style (bool, optional): When use_neox_rotary_style is True, the rotation pairs every two adjacent elements of head_dim (NeoX style). When it is False, each element in the first half of head_dim is paired with the corresponding element in the second half (rotate-half style). Default: True.

     Returns:
         out_q/out_k/out_v Tensor representing the fused rotary position embedding, has same shape and data type as `q` .
@@ -36,23 +46,51 @@ def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):

         .. code-block:: python

-            # required: gpu
-            import paddle
-            from paddle.incubate.nn.functional import fused_rotary_position_embedding
+            >>> # required: gpu
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> import paddle
+            >>> from paddle.incubate.nn.functional import fused_rotary_position_embedding

-            q = paddle.randn([1, 1, 4, 10], dtype='float16')
-            k = paddle.randn([1, 1, 4, 10], dtype='float16')
-            v = paddle.randn([1, 1, 4, 10], dtype='float16')
-            out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
+            >>> paddle.device.set_device('gpu')

-            x = paddle.randn([1, 1, 1, 10], dtype='float16')
-            y = paddle.randn([1, 1, 1, 10], dtype='float16')
-            sin = paddle.sin(x)
-            cos = paddle.cos(y)
-            out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos)
+            >>> # batch_size = 2
+            >>> # seq_len = 2
+            >>> # num_heads = 2
+            >>> # head_dim = 2
+
+            >>> paddle.seed(1204)
+
+            >>> # q, k, v: [batch_size, seq_len, num_heads, head_dim]
+            >>> q = paddle.randn([2, 2, 2, 2], dtype='float16')
+            >>> k = paddle.randn([2, 2, 2, 2], dtype='float16')
+            >>> v = paddle.randn([2, 2, 2, 2], dtype='float16')
+
+            >>> # sin, cos: [1, seq_len, 1, head_dim]
+            >>> x = paddle.randn([1, 2, 1, 2], dtype='float16')
+            >>> y = paddle.randn([1, 2, 1, 2], dtype='float16')
+            >>> sin = paddle.sin(x)
+            >>> cos = paddle.cos(y)
+
+            >>> # position_ids: [batch_size, seq_len]
+            >>> position_ids = paddle.randint(high=2, shape=[2, 2], dtype='int64')
+
+            >>> # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
+            >>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
+            >>> print(out_q)
+            Tensor(shape=[2, 2, 2, 2], dtype=float16, place=Place(gpu:0), stop_gradient=True,
+            [[[[-0.54931641,  0.64990234],
+               [-1.08691406,  1.18261719]],
+              [[ 0.57812500,  0.11749268],
+               [-0.63281250,  0.15551758]]],
+             [[[-0.77050781,  0.07733154],
+               [-0.73730469, -0.16735840]],
+              [[ 0.07116699, -0.90966797],
+               [-0.03628540, -0.20202637]]]])
     """
     if in_dynamic_mode():
-        return _C_ops.fused_rotary_position_embedding(q, k, v, sin, cos)
+        return _C_ops.fused_rotary_position_embedding(
+            q, k, v, sin, cos, position_ids, use_neox_rotary_style
+        )

     raise RuntimeError(
         "This feature is currently supported only in dynamic mode and with CUDAPlace."
diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py
index 9842fbf1f4ee8c4b0f1140f8928a0da4fb609f1f..de6355d56a5ee6b0b2598b94c38ff2696e1da96a 100644
--- a/test/legacy_test/test_fused_rotary_position_embedding.py
+++ b/test/legacy_test/test_fused_rotary_position_embedding.py
@@ -41,6 +41,24 @@ def mult_qkv(value, cos_tensor, sin_tensor):
     return query


+def mult_qkv_rotate_half(value, cos_tensor, sin_tensor):
+    rotate_half_q = paddle.reshape(
+        paddle.concat(
+            [
+                -value[..., value.shape[-1] // 2 :],
+                value[..., : value.shape[-1] // 2],
+            ],
+            axis=-1,
+        ),
+        paddle.shape(value),
+    )
+    query = paddle.add(
+        paddle.multiply(value, cos_tensor),
+        paddle.multiply(rotate_half_q, sin_tensor),
+    )
+    return query
+
+
 def get_sin_cos_tensor(seq_len, head_dim, sign):
     pos_seq = paddle.arange(0, seq_len, 1, dtype="float32")
     indices = paddle.arange(0, head_dim, 2, dtype="float32")
@@ -74,22 +92,38 @@ def get_sin_cos_tensor(seq_len, head_dim, sign):
     return tensor_sin, tensor_cos


-def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
+def paddle_fused_rotary_position_embedding(
+    init_q, init_k, init_v, position_ids=None, use_neox_rotary_style=True
+):
     # permute q, k, v from [batch_size, seq_len, num_heads, head_dim]
     # to [batch_size, num_heads, seq_len, head_dim]
     q, k, v = deal_qkv(init_q, init_k, init_v)

-    sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1)
+    sign = -1 if use_neox_rotary_style else 1
+    sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], sign)
+
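+    # position_ids selects, per batch row, which rows of the [seq_len, dim]
+    # sin/cos tables each token uses, mirroring the gather that the fused GPU
+    # kernel performs through position_ids_data.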
+    if position_ids is not None:
+        sin_tensor = sin_tensor.squeeze(axis=[0, 2])  # [seq_len, dim]
+        cos_tensor = cos_tensor.squeeze(axis=[0, 2])  # [seq_len, dim]
+        sin_tensor = sin_tensor[position_ids].unsqueeze(
+            2
+        )  # [bs, seq_len, 1, dim]
+        cos_tensor = cos_tensor[position_ids].unsqueeze(
+            2
+        )  # [bs, seq_len, 1, dim]

-    # permute sin, cos from [1, seq_len, 1, head_dim]
-    # to [1, 1, seq_len, head_dim]
     perm = [0, 2, 1, 3]
     sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
     cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)

-    query = mult_qkv(q, cos_tensor, sin_tensor)
-    value = mult_qkv(v, cos_tensor, sin_tensor)
-    key = mult_qkv(k, cos_tensor, sin_tensor)
+    if use_neox_rotary_style:
+        query = mult_qkv(q, cos_tensor, sin_tensor)
+        value = mult_qkv(v, cos_tensor, sin_tensor)
+        key = mult_qkv(k, cos_tensor, sin_tensor)
+    else:
+        query = mult_qkv_rotate_half(q, cos_tensor, sin_tensor)
+        value = mult_qkv_rotate_half(v, cos_tensor, sin_tensor)
+        key = mult_qkv_rotate_half(k, cos_tensor, sin_tensor)

     # permute the result back to [batch_size, seq_len, num_heads, head_dim]
     r_query, r_key, r_value = deal_qkv(query, key, value)
@@ -102,7 +136,7 @@ def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
 )
 class TestFusedRotaryPositionEmbedding(unittest.TestCase):
     def setUp(self):
-        self.shape = [1, 8, 2, 16]
+        self.shape = [2, 8, 2, 16]
         self.dtype = 'float32'
         self.training = True
         self.seed = 1203
@@ -112,7 +146,14 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
         tmp.stop_gradient = False
         return tmp

-    def get_forward_backward(self, rope_function, seed, flag=0):
+    def get_forward_backward(
+        self,
+        rope_function,
+        seed,
+        flag=False,
+        use_neox_rotary_style=True,
+        position_ids=None,
+    ):
         paddle.disable_static()
         paddle.seed(seed)
         fw = []
@@ -120,15 +161,45 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
         tensor_q = self.get_paddle_tensor()
         tensor_k = self.get_paddle_tensor()
         tensor_v = self.get_paddle_tensor()
-        if flag:
-            tensor_sin, tensor_cos = get_sin_cos_tensor(
-                tensor_q.shape[1], tensor_q.shape[3], 1
-            )
-            out_q, out_k, out_v = rope_function(
-                tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos
-            )
+        if use_neox_rotary_style:
+            if flag:
+                tensor_sin, tensor_cos = get_sin_cos_tensor(
+                    tensor_q.shape[1], tensor_q.shape[3], 1
+                )
+                out_q, out_k, out_v = rope_function(
+                    tensor_q,
+                    tensor_k,
+                    tensor_v,
+                    tensor_sin,
+                    tensor_cos,
+                    position_ids=position_ids,
+                )
+            else:
+                out_q, out_k, out_v = rope_function(
+                    tensor_q, tensor_k, tensor_v, position_ids=position_ids
+                )
         else:
-            out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v)
+            if flag:
+                tensor_sin, tensor_cos = get_sin_cos_tensor(
+                    tensor_q.shape[1], tensor_q.shape[3], 1
+                )
+                out_q, out_k, out_v = rope_function(
+                    tensor_q,
+                    tensor_k,
+                    tensor_v,
+                    tensor_sin,
+                    tensor_cos,
+                    position_ids=position_ids,
+                    use_neox_rotary_style=False,
+                )
+            else:
+                out_q, out_k, out_v = rope_function(
+                    tensor_q,
+                    tensor_k,
+                    tensor_v,
+                    position_ids=position_ids,
+                    use_neox_rotary_style=False,
+                )

         fw.append(out_q)
         fw.append(out_k)
@@ -166,7 +237,49 @@ class TestFusedRotaryPositionEmbedding(unittest.TestCase):
             paddle_fused_rotary_position_embedding, seed=self.seed
         )
         f_fw, f_bw = self.get_forward_backward(
-            fused_rotary_position_embedding, seed=self.seed, flag=1
+            fused_rotary_position_embedding, seed=self.seed, flag=True
+        )
+        for i in range(len(p_fw)):
+            np.testing.assert_allclose(
+                p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
+            )
+            np.testing.assert_allclose(
+                p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
+            )
+
+    def test_fused_rope_rotate_half(self):
+        p_fw, p_bw = self.get_forward_backward(
+            paddle_fused_rotary_position_embedding,
+            seed=self.seed,
+            use_neox_rotary_style=False,
+        )
+        f_fw, f_bw = self.get_forward_backward(
+            fused_rotary_position_embedding,
+            seed=self.seed,
+            use_neox_rotary_style=False,
+        )
+        for i in range(len(p_fw)):
+            np.testing.assert_allclose(
+                p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
+            )
+            np.testing.assert_allclose(
+                p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
+            )
+
+    def test_fused_rope_position_ids(self):
+        position_ids = paddle.to_tensor(
+            [[7, 5, 4, 6, 3, 1, 2, 0], [3, 1, 4, 0, 7, 6, 5, 2]]
+        )
+        p_fw, p_bw = self.get_forward_backward(
+            paddle_fused_rotary_position_embedding,
+            seed=self.seed,
+            position_ids=position_ids,
+        )
+        f_fw, f_bw = self.get_forward_backward(
+            fused_rotary_position_embedding,
+            seed=self.seed,
+            flag=True,
+            position_ids=position_ids,
         )
         for i in range(len(p_fw)):
             np.testing.assert_allclose(