test_fused_rotary_position_embedding.py 6.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
#   Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.fluid import core
from paddle.incubate.nn.functional import fused_rotary_position_embedding


def deal_qkv(init_q, init_k, init_v):
    perm = [0, 2, 1, 3]
    q = paddle.transpose(x=init_q, perm=perm)
    k = paddle.transpose(x=init_k, perm=perm)
    v = paddle.transpose(x=init_v, perm=perm)
    return q, k, v


def mult_qkv(value, cos_tensor, sin_tensor):
    rotate_half_q = paddle.reshape(
        paddle.stack([value[:, :, :, 1::2], value[:, :, :, 0::2]], axis=-1),
        paddle.shape(value),
    )
    query = paddle.add(
        paddle.multiply(value, cos_tensor),
        paddle.multiply(rotate_half_q, sin_tensor),
    )
    return query


44 45 46
def get_sin_cos_tensor(seq_len, head_dim, sign):
    pos_seq = paddle.arange(0, seq_len, 1, dtype="float32")
    indices = paddle.arange(0, head_dim, 2, dtype="float32")
47

48
    indices = 1 / 10000 ** (indices / head_dim)
49 50
    sinusoid_inp = pos_seq.unsqueeze(1) * indices.unsqueeze(0)

51 52
    sin_sin = np.empty((seq_len * head_dim), dtype=np.float32)
    cos_cos = np.empty((seq_len * head_dim), dtype=np.float32)
53 54 55 56 57 58
    numpy_array = sinusoid_inp.numpy()
    iter_array = np.nditer(numpy_array)

    i = 0

    for value in iter_array:
59
        sin_sin[i * 2] = sign * np.sin(value)
60 61 62 63 64
        cos_cos[i * 2 + 0] = np.cos(value)
        sin_sin[i * 2 + 1] = np.sin(value)
        cos_cos[i * 2 + 1] = np.cos(value)
        i += 1

65 66
    tensor_sin = paddle.reshape(
        paddle.to_tensor(sin_sin),
67
        [1, seq_len, 1, head_dim],
68
    )
69 70
    tensor_cos = paddle.reshape(
        paddle.to_tensor(cos_cos),
71
        [1, seq_len, 1, head_dim],
72 73
    )

74 75 76 77
    return tensor_sin, tensor_cos


def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
78 79
    # permute q, k, v from [batch_size, seq_len, num_heads, head_dim]
    # to [batch_size, num_heads, seq_len, head_dim]
80 81 82 83
    q, k, v = deal_qkv(init_q, init_k, init_v)

    sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1)

84 85 86 87 88 89
    # permute sin, cos from [1, seq_len, 1, head_dim]
    # to [1, 1, seq_len, head_dim]
    perm = [0, 2, 1, 3]
    sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
    cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)

90 91 92 93
    query = mult_qkv(q, cos_tensor, sin_tensor)
    value = mult_qkv(v, cos_tensor, sin_tensor)
    key = mult_qkv(k, cos_tensor, sin_tensor)

94
    # permute the result back to [batch_size, seq_len, num_heads, head_dim]
95 96 97 98 99 100 101 102 103 104
    r_query, r_key, r_value = deal_qkv(query, key, value)
    return r_query, r_key, r_value


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA ",
)
class TestFusedRotaryPositionEmbedding(unittest.TestCase):
    def setUp(self):
105
        self.shape = [1, 8, 2, 16]
106 107 108 109 110 111 112 113 114
        self.dtype = 'float32'
        self.training = True
        self.seed = 1203

    def get_paddle_tensor(self):
        tmp = paddle.randn(self.shape, self.dtype)
        tmp.stop_gradient = False
        return tmp

115
    def get_forward_backward(self, rope_function, seed, flag=0):
116 117 118 119 120 121 122
        paddle.disable_static()
        paddle.seed(seed)
        fw = []
        bw = []
        tensor_q = self.get_paddle_tensor()
        tensor_k = self.get_paddle_tensor()
        tensor_v = self.get_paddle_tensor()
123 124 125 126 127 128 129 130 131
        if flag:
            tensor_sin, tensor_cos = get_sin_cos_tensor(
                tensor_q.shape[1], tensor_q.shape[3], 1
            )
            out_q, out_k, out_v = rope_function(
                tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos
            )
        else:
            out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v)
132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148

        fw.append(out_q)
        fw.append(out_k)
        fw.append(out_v)

        out_gq = paddle.randn(out_q.shape, self.dtype)
        out_gk = paddle.randn(out_q.shape, self.dtype)
        out_gv = paddle.randn(out_q.shape, self.dtype)
        paddle.autograd.backward(
            [out_q, out_k, out_v], [out_gq, out_gk, out_gv], True
        )
        bw.append(tensor_q)
        bw.append(tensor_k)
        bw.append(tensor_v)

        return fw, bw

149
    def test_fused_rope(self):
150 151 152 153 154 155 156 157 158 159 160 161 162 163
        p_fw, p_bw = self.get_forward_backward(
            paddle_fused_rotary_position_embedding, seed=self.seed
        )
        f_fw, f_bw = self.get_forward_backward(
            fused_rotary_position_embedding, seed=self.seed
        )
        for i in range(len(p_fw)):
            np.testing.assert_allclose(
                p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
            )
            np.testing.assert_allclose(
                p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
            )

164
    def test_fused_rope_with_sin_cos(self):
165 166 167 168 169 170 171 172 173 174 175 176 177 178
        p_fw, p_bw = self.get_forward_backward(
            paddle_fused_rotary_position_embedding, seed=self.seed
        )
        f_fw, f_bw = self.get_forward_backward(
            fused_rotary_position_embedding, seed=self.seed, flag=1
        )
        for i in range(len(p_fw)):
            np.testing.assert_allclose(
                p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
            )
            np.testing.assert_allclose(
                p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
            )

179 180 181 182 183 184 185 186 187
    def test_error(self):
        paddle.enable_static()
        with self.assertRaises(RuntimeError):
            static_q = paddle.static.data(
                name="q", shape=self.shape, dtype=self.dtype
            )
            fused_rotary_position_embedding(static_q, static_q, static_q)
        paddle.disable_static()

188 189 190

if __name__ == '__main__':
    unittest.main()