fused_rotary_position_embedding.py 4.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from paddle import _C_ops
from paddle.framework import in_dynamic_mode


20 21 22 23 24 25 26 27 28
def fused_rotary_position_embedding(
    q,
    k=None,
    v=None,
    sin=None,
    cos=None,
    position_ids=None,
    use_neox_rotary_style=True,
):
    r"""
    Fused rotary position embedding.

    Applies the rotary position embedding to ``q`` (and to ``k`` / ``v`` when
    they are given) through a single fused operator call.

    Args:
        q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
        k (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
        v (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
        sin (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of sin must be [seq_len, head_dim] or [1, seq_len, 1, head_dim] and head_dim must be a multiple of 2.
        cos (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of cos must be [seq_len, head_dim] or [1, seq_len, 1, head_dim] and head_dim must be a multiple of 2.
        position_ids (Tensor, optional): The input tensor. The data type is int64. The shape of position_ids must be [batch_size, seq_len].
        use_neox_rotary_style (bool, optional): When the use_neox_rotary_style is True, every two adjacent numbers are calculated. When the use_neox_rotary_style is False, the numbers corresponding to the positions of the front half and back half segments are calculated. Default True.

    Returns:
        out_q/out_k/out_v: Tensors representing the fused rotary position embedding of ``q``, ``k`` and ``v``. Each has the same shape and data type as ``q``.

    Raises:
        RuntimeError: If the function is called outside dynamic mode.

    Examples:

        ..  code-block:: python

            >>> # required: gpu
            >>> # doctest: +REQUIRES(env:GPU)
            >>> import paddle
            >>> from paddle.incubate.nn.functional import fused_rotary_position_embedding

            >>> paddle.device.set_device('gpu')

            >>> # batch_size = 2
            >>> # seq_len = 2
            >>> # num_heads = 2
            >>> # head_dim = 2

            >>> paddle.seed(1204)

            >>> # q, k, v: [batch_size, seq_len, num_heads, head_dim]
            >>> q = paddle.randn([2, 2, 2, 2], dtype='float16')
            >>> k = paddle.randn([2, 2, 2, 2], dtype='float16')
            >>> v = paddle.randn([2, 2, 2, 2], dtype='float16')

            >>> # sin, cos: [1, seq_len, 1, head_dim]
            >>> x = paddle.randn([1, 2, 1, 2], dtype='float16')
            >>> y = paddle.randn([1, 2, 1, 2], dtype='float16')
            >>> sin = paddle.sin(x)
            >>> cos = paddle.cos(y)

            >>> # position_ids: [batch_size, seq_len]
            >>> position_ids = paddle.randint(high=2, shape=[2, 2], dtype='int64')

            >>> # out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
            >>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
            >>> print(out_q)
            Tensor(shape=[2, 2, 2, 2], dtype=float16, place=Place(gpu:0), stop_gradient=True,
            [[[[-0.54931641,  0.64990234],
               [-1.08691406,  1.18261719]],
              [[ 0.57812500,  0.11749268],
               [-0.63281250,  0.15551758]]],
             [[[-0.77050781,  0.07733154],
               [-0.73730469, -0.16735840]],
              [[ 0.07116699, -0.90966797],
               [-0.03628540, -0.20202637]]]])
    """
    # Dynamic mode is the only path wired up here: the call is forwarded
    # unchanged to the fused operator, with arguments in the operator's
    # positional order.
    if in_dynamic_mode():
        return _C_ops.fused_rotary_position_embedding(
            q, k, v, sin, cos, position_ids, use_neox_rotary_style
        )

    # No static-graph branch exists in this function; fail loudly instead of
    # silently returning nothing. Note that only the execution mode is checked
    # here — the device placement named in the message is not validated in
    # Python.
    raise RuntimeError(
        "This feature is currently supported only in dynamic mode and with CUDAPlace."
    )