# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program

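# Fix the seed of the default main program so random initialization is reproducible.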
default_main_program().random_seed = 42


class TestFusedAttentionOp(OpTest):

    def setUp(self):
        self.config()
        self.generate_input_data()

        self.rtol = 1e-5
        # FIXME(limin29): Because there is a problem with the test precision
        #  on A100, atol is temporarily set to 1e-2, and it will be
        #  changed back after the precision problem is solved.
        self.atol = 1e-2
        # Use a tighter tolerance for local development on V100.
        if "V100" in paddle.device.cuda.get_device_name():
            self.atol = 1e-4
        if self.x_type is np.float16:
            self.atol = 1e-1

        paddle.set_default_dtype(self.x_type)
        self.__class__.op_type = "fused_attention"
        # use autograd to check grad in this unittest.
        self.__class__.no_need_check_grad = True
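        # Reference (unfused) layers; GetBaselineOut uses them directly and
        # GetFusedAttentionOut reuses their weights to build the fused inputs.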
        self.q_proj = Linear(self.embed_dim,
                             self.embed_dim,
                             self.weight_attr,
                             bias_attr=self.bias_attr)
        self.k_proj = Linear(self.kdim,
                             self.embed_dim,
                             self.weight_attr,
                             bias_attr=self.bias_attr)
        self.v_proj = Linear(self.vdim,
                             self.embed_dim,
                             self.weight_attr,
                             bias_attr=self.bias_attr)
        self.out_proj = Linear(self.embed_dim,
                               self.embed_dim,
                               self.weight_attr,
                               bias_attr=self.bias_attr)
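        # Keep the LayerNorm parameters in float32 even when x_type is float16.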
        paddle.set_default_dtype(np.float32)
        self.norm1 = LayerNorm(self.embed_dim)
        self.norm2 = LayerNorm(self.embed_dim)
        paddle.set_default_dtype(self.x_type)
        self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")

    def config(self):
        self.x_type = np.float32
        self.attn_mask_type = np.float64
        self.pre_layer_norm = False
        self.has_attn_mask = True
        self.has_cache_kv = False
        self.training = True

        self.batch_size = 8
        self.query_length = 128
        self.cache_length = 128
        self.head_dim = 64
        self.num_heads = 16
        self.embed_dim = self.head_dim * self.num_heads

        self.dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.weight_attr = None
        self.bias_attr = None
        self.kdim, self.vdim = self.embed_dim, self.embed_dim
        self.key_length, self.value_length = self.query_length, self.query_length

    def generate_input_data(self):
        self.query = np.random.rand(self.batch_size, self.query_length,
                                    self.embed_dim).astype(self.x_type)
        out_seq_len = self.key_length
        if self.has_cache_kv:
            assert self.training is False, (
                'cache_kv can only be used in inference')
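            # Stacked K/V cache: [2, batch_size, num_heads, cache_length, head_dim].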
            self.cache_kv = np.random.rand(2, self.batch_size, self.num_heads,
                                           self.cache_length,
                                           self.head_dim).astype(self.x_type)
            out_seq_len += self.cache_length
        else:
            self.cache_kv = None

        if self.has_attn_mask:
            # [B, n_head, seq_len, out_seq_len]
            self.attn_mask = np.ones((self.batch_size, self.num_heads,
                                      self.query_length, out_seq_len),
                                     dtype=self.attn_mask_type)
            if self.attn_mask_type == np.int64:
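                # Causal 0/1 mask: 1 keeps a position, 0 masks it out.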
                self.attn_mask = np.tril(self.attn_mask)
            elif self.attn_mask_type == np.float64:
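                # Additive causal mask: 0.0 for visible positions, -1e9 for masked ones.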
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
                raise ValueError(
                    "'attn_mask_type' should be 'int64' or 'float64'.")
        else:
            self.attn_mask = None
        self.key, self.value = self.query, self.query

        self.dout = np.random.random((self.batch_size, self.query_length,
                                      self.embed_dim)).astype(self.x_type)

    def GetBaselineOut(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        residual = tensor_query

        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm1(tensor_query)

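        # Project into per-head Q, K, V:
        # [B, seq_len, embed_dim] -> [B, num_heads, seq_len, head_dim].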
        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = paddle.split(cache_kv, 2)
            cache_k = paddle.squeeze(cache_k, axis=0)
            cache_v = paddle.squeeze(cache_v, axis=0)
            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            k_out = paddle.concat([cache_k, k_out], axis=-2)
            v_out = paddle.concat([cache_v, v_out], axis=-2)

        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
        qk_out = layers.matmul(x=q_out,
                               y=k_out,
                               transpose_y=True,
                               alpha=self.head_dim**-0.5)

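        # Add the attention mask to the scaled logits, then softmax over out_seq_len.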
        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        if self.dropout_prob:
            dropout_out = F.dropout(softmax_out,
                                    self.dropout_prob,
                                    training=self.training,
                                    mode="upscale_in_train")
            # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, head_dim]
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

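        # Merge heads back: [B, num_heads, seq_len, head_dim] -> [B, seq_len, embed_dim].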
        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            final_out = self.norm1(residual_out)
        else:
            final_out = residual_out

        if self.has_cache_kv:
            return final_out

        paddle.autograd.backward([final_out], [paddle.to_tensor(self.dout)],
                                 retain_graph=True)
        return final_out, tensor_query.grad

    def GetFusedAttentionOut(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_proj_weight = paddle.to_tensor(self.q_proj.weight,
                                         stop_gradient=False)
        k_proj_weight = paddle.to_tensor(self.k_proj.weight,
                                         stop_gradient=False)
        v_proj_weight = paddle.to_tensor(self.v_proj.weight,
                                         stop_gradient=False)
        out_linear_weight = paddle.to_tensor(self.out_proj.weight,
                                             stop_gradient=False)

        if self.bias_attr is False:
            qkv_bias_tensor = None
            out_linear_bias = None
        else:
            q_proj_bias = paddle.to_tensor(self.q_proj.bias,
                                           stop_gradient=False)
            k_proj_bias = paddle.to_tensor(self.k_proj.bias,
                                           stop_gradient=False)
            v_proj_bias = paddle.to_tensor(self.v_proj.bias,
                                           stop_gradient=False)
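            # Pack the three projection biases into the [3, num_heads, head_dim]
            # layout expected by the fused op.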
            qkv_bias = np.concatenate(
                (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
            qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
            qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
            out_linear_bias = paddle.to_tensor(self.out_proj.bias,
                                               stop_gradient=False)

        ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
        ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
        ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
        ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)

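        # Linear weights are stored as [in_features, out_features]; transpose and
        # stack them into a single [3, num_heads, head_dim, embed_dim] QKV weight.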
        q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
        k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
        v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
        qkv_weight = np.concatenate(
            (q_proj_weight, k_proj_weight, v_proj_weight))
        qkv_weight = qkv_weight.reshape(
            (3, self.num_heads, self.head_dim, self.embed_dim))

        x = paddle.to_tensor(self.query, stop_gradient=False)
        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
        epsilon = 1e-05
        ln2_epsilon = 1e-05

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, x.dtype)
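        # One fused call covering the whole attention block that GetBaselineOut
        # computes step by step.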
        final_out = incubate_f.fused_multi_head_attention(
            x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm,
            ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor,
            out_linear_bias, cache_kv, attn_mask, self.dropout_prob,
            self.attn_dropout_prob, ln2_epsilon)

        if self.has_cache_kv:
            return final_out[0], final_out[1]

        paddle.autograd.backward([final_out], [paddle.to_tensor(self.dout)],
                                 retain_graph=True)
        return final_out, x.grad

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(final_out_ref,
                                   final_out.numpy(),
                                   rtol=self.rtol,
                                   atol=self.atol)
        np.testing.assert_allclose(x_grad_ref,
                                   x_grad.numpy(),
                                   rtol=self.rtol,
                                   atol=self.atol)


class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.bias_attr = False


class TestFusedAttentionOpPreLn(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.pre_layer_norm = True


class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.pre_layer_norm = True
        self.has_attn_mask = False


class TestFusedAttentionOpFp16(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.x_type = np.float16

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(final_out_ref,
                                   final_out.numpy(),
                                   rtol=self.rtol,
                                   atol=self.atol)
        np.testing.assert_allclose(x_grad_ref,
                                   x_grad.numpy(),
                                   rtol=self.rtol,
                                   atol=self.atol)


class TestFusedAttentionOpCacheKV(TestFusedAttentionOp):

    def config(self):
        super().config()
        self.has_cache_kv = True
        self.training = False
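        # Simulate a single decoding step: one new token attends to the cached keys/values.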
        self.query_length = 1
        self.key_length, self.value_length = 1, 1

    def test_fused_attention_op(self):
        with paddle.no_grad():
            final_out_ref = self.GetBaselineOut()
            final_out, cache_kv_out = self.GetFusedAttentionOut()
            np.testing.assert_allclose(final_out_ref,
                                       final_out.numpy(),
                                       rtol=self.rtol,
                                       atol=self.atol)


if __name__ == "__main__":
    unittest.main()