# test_fused_rotary_position_embedding.py
#   Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.fluid import core
from paddle.incubate.nn.functional import fused_rotary_position_embedding


def deal_qkv(init_q, init_k, init_v):
    """Swap axes 1 and 2 of each of the three given tensors.

    Every input is transposed with the permutation ``[0, 2, 1, 3]``, so each
    one is expected to be 4-D.

    Args:
        init_q: Query tensor.
        init_k: Key tensor.
        init_v: Value tensor.

    Returns:
        A ``(q, k, v)`` tuple holding the transposed tensors, in input order.
    """
    swap_axes_1_2 = [0, 2, 1, 3]
    q, k, v = (
        paddle.transpose(x=tensor, perm=swap_axes_1_2)
        for tensor in (init_q, init_k, init_v)
    )
    return q, k, v


def mult_qkv(value, cos_tensor, sin_tensor):
    """Return ``value * cos_tensor + swapped(value) * sin_tensor``.

    ``swapped(value)`` is ``value`` with every adjacent pair of elements along
    the last axis exchanged (``[x0, x1, x2, x3, ...]`` becomes
    ``[x1, x0, x3, x2, ...]``). No sign is applied here; any sign needed for
    the rotation has to be carried by ``sin_tensor``.

    Args:
        value: 4-D tensor (it is sliced with four subscripts).
        cos_tensor: Tensor multiplied element-wise with ``value``.
        sin_tensor: Tensor multiplied element-wise with the pair-swapped value.

    Returns:
        The sum of the two products.
    """
    odd_lanes = value[:, :, :, 1::2]
    even_lanes = value[:, :, :, 0::2]
    # Stacking (odd, even) on a new trailing axis and reshaping back to the
    # shape of `value` interleaves them, which swaps each adjacent pair.
    pair_swapped = paddle.reshape(
        paddle.stack([odd_lanes, even_lanes], axis=-1),
        paddle.shape(value),
    )
    scaled_by_cos = paddle.multiply(value, cos_tensor)
    scaled_by_sin = paddle.multiply(pair_swapped, sin_tensor)
    return paddle.add(scaled_by_cos, scaled_by_sin)


def get_sin_cos_tensor(seq_len, head_dim, sign):
    """Build the interleaved sin / cos tables used by the rotary embedding.

    For every position ``p`` in ``range(seq_len)`` and every ``i`` in
    ``range(0, head_dim, 2)`` the angle is ``p / 10000 ** (i / head_dim)``.
    Each angle fills two neighbouring slots along the last axis:

    * sin table: ``sign * sin(angle)`` in the even slot, ``sin(angle)`` in the
      odd slot;
    * cos table: ``cos(angle)`` in both slots.

    Args:
        seq_len: Number of positions.
        head_dim: Size of the last axis of the returned tables. It must be
            even, otherwise the final reshape cannot succeed.
        sign: Factor applied to the sine stored in the even slots.

    Returns:
        ``(tensor_sin, tensor_cos)``, two float32 tensors of shape
        ``[1, 1, seq_len, head_dim]``.
    """
    pos_seq = paddle.arange(0, seq_len, 1, dtype="float32")
    indices = paddle.arange(0, head_dim, 2, dtype="float32")
    indices = 1 / 10000 ** (indices / head_dim)
    sinusoid_inp = pos_seq.unsqueeze(1) * indices.unsqueeze(0)

    # Flatten in row-major order (the order the angles are laid out in) and
    # duplicate every value so it occupies an (even, odd) pair of slots.
    angles = sinusoid_inp.numpy().reshape(-1)
    sin_sin = np.repeat(np.sin(angles), 2).astype(np.float32, copy=False)
    cos_cos = np.repeat(np.cos(angles), 2).astype(np.float32, copy=False)
    # Only the even slots of the sin table carry the caller-supplied sign.
    sin_sin[0::2] *= sign

    table_shape = [1, 1, seq_len, head_dim]
    tensor_sin = paddle.reshape(paddle.to_tensor(sin_sin), table_shape)
    tensor_cos = paddle.reshape(paddle.to_tensor(cos_cos), table_shape)

    return tensor_sin, tensor_cos


def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
    """Reference rotary position embedding built from plain paddle ops.

    The inputs have axes 1 and 2 swapped, are combined with the sin / cos
    tables (generated with ``sign=-1`` from axes 2 and 3 of the transposed
    query), and are then transposed back to the input layout.

    Args:
        init_q: Query tensor (4-D).
        init_k: Key tensor (4-D).
        init_v: Value tensor (4-D).

    Returns:
        A ``(query, key, value)`` tuple in the same layout as the inputs.
    """
    q, k, v = deal_qkv(init_q, init_k, init_v)

    sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], -1)

    query = mult_qkv(q, cos_tensor, sin_tensor)
    value = mult_qkv(v, cos_tensor, sin_tensor)
    key = mult_qkv(k, cos_tensor, sin_tensor)

    r_query, r_key, r_value = deal_qkv(query, key, value)

    return r_query, r_key, r_value


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA ",
)
class TestFusedRotaryPositionEmbedding(unittest.TestCase):
    """Compare ``fused_rotary_position_embedding`` with the reference ops.

    The whole case is skipped when paddle is not compiled with CUDA.
    """

    def setUp(self):
        """Set the input shape, dtype and random seed shared by the tests."""
        self.shape = [1, 16, 1, 16]
        self.dtype = 'float32'
        self.training = True
        self.seed = 1203

    def get_paddle_tensor(self):
        """Return a random tensor of ``self.shape`` that tracks gradients."""
        tmp = paddle.randn(self.shape, self.dtype)
        tmp.stop_gradient = False
        return tmp

    def get_forward_backward(self, rope_function, seed, flag=0):
        """Run ``rope_function`` forward and backward on seeded random inputs.

        Args:
            rope_function: Callable taking ``(q, k, v)`` or, when ``flag`` is
                truthy, ``(q, k, v, sin, cos)``, and returning three tensors.
            seed: Seed passed to ``paddle.seed`` before any tensor is drawn,
                so two calls with the same seed see the same inputs and the
                same upstream gradients.
            flag: When truthy, sin / cos tables are built with ``sign=1``
                from axes 1 and 3 of the query and passed to
                ``rope_function``.

        Returns:
            ``(fw, bw)``: ``fw`` lists the three outputs and ``bw`` lists the
            gradients accumulated on the q, k and v inputs.
        """
        paddle.disable_static()
        paddle.seed(seed)
        tensor_q = self.get_paddle_tensor()
        tensor_k = self.get_paddle_tensor()
        tensor_v = self.get_paddle_tensor()
        if flag:
            tensor_sin, tensor_cos = get_sin_cos_tensor(
                tensor_q.shape[1], tensor_q.shape[3], 1
            )
            out_q, out_k, out_v = rope_function(
                tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos
            )
        else:
            out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v)

        fw = [out_q, out_k, out_v]

        out_gq = paddle.randn(out_q.shape, self.dtype)
        out_gk = paddle.randn(out_q.shape, self.dtype)
        out_gv = paddle.randn(out_q.shape, self.dtype)
        paddle.autograd.backward(
            [out_q, out_k, out_v], [out_gq, out_gk, out_gv], True
        )
        # Return the gradients, not the inputs themselves: the inputs are
        # identical for a given seed, so comparing them would pass regardless
        # of what the backward pass computed.
        bw = [tensor_q.grad, tensor_k.grad, tensor_v.grad]

        return fw, bw

    def _check_against_reference(self, flag=0):
        """Assert the fused op matches the reference, forward and backward."""
        p_fw, p_bw = self.get_forward_backward(
            paddle_fused_rotary_position_embedding, seed=self.seed
        )
        f_fw, f_bw = self.get_forward_backward(
            fused_rotary_position_embedding, seed=self.seed, flag=flag
        )
        for ref_out, fused_out, ref_grad, fused_grad in zip(
            p_fw, f_fw, p_bw, f_bw
        ):
            np.testing.assert_allclose(
                ref_out.numpy(), fused_out.numpy(), rtol=1e-05
            )
            np.testing.assert_allclose(
                ref_grad.numpy(), fused_grad.numpy(), rtol=1e-05
            )

    def test_fused_dropout_add(self):
        """Fused op called with q, k, v only."""
        self._check_against_reference(flag=0)

    def test_fused_dropout_add_sin_cos(self):
        """Fused op called with explicitly supplied sin / cos tables."""
        self._check_against_reference(flag=1)

    def test_error(self):
        """The fused op must raise ``RuntimeError`` in static-graph mode."""
        paddle.enable_static()
        try:
            with self.assertRaises(RuntimeError):
                static_q = paddle.static.data(
                    name="q", shape=self.shape, dtype=self.dtype
                )
                fused_rotary_position_embedding(static_q, static_q, static_q)
        finally:
            # Restore dynamic mode even if the assertion above fails, so a
            # failure here cannot leak static mode into other tests.
            paddle.disable_static()

# Run the test cases in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()