# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program

default_main_program().random_seed = 42
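
# This unittest compares paddle.incubate.nn.functional.fused_multi_head_attention
# against a reference attention assembled from separate Linear / LayerNorm /
# softmax / dropout ops: GetBaselineOut() computes the unfused result,
# GetFusedAttentionOut() the fused one, and the forward outputs and input
# gradients are compared with np.testing.assert_allclose.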


class TestFusedAttentionOp(OpTest):
    def setUp(self):
        self.config()
        self.generate_input_data()

        self.rtol = 1e-5
        # FIXME(limin29): Because there is a problem with the test precision
        #  on A100, atol is temporarily set to 1e-2, and it will be
        #  changed back after the precision problem is solved.
        self.atol = 1e-2
        # keep a tighter tolerance for local development on V100
        if "V100" in paddle.device.cuda.get_device_name():
            self.atol = 1e-4
        if self.x_type is np.float16:
            self.atol = 1e-1

        paddle.set_default_dtype(self.x_type)
        self.__class__.op_type = "fused_attention"
        # use autograd to check grad in this unittest.
        self.__class__.no_need_check_grad = True
        self.q_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.k_proj = Linear(
            self.kdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.v_proj = Linear(
            self.vdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.out_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        paddle.set_default_dtype(np.float32)
        self.norm1 = LayerNorm(self.embed_dim)
        self.norm2 = LayerNorm(self.embed_dim)
        paddle.set_default_dtype(self.x_type)
        self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")

    def config(self):
        self.x_type = np.float32
        self.attn_mask_type = np.float64
        self.pre_layer_norm = False
        self.has_attn_mask = True
        self.has_cache_kv = False
        self.training = True

        self.batch_size = 8
        self.query_length = 128
        self.cache_length = 128
        self.head_dim = 64
        self.num_heads = 16
        self.embed_dim = self.head_dim * self.num_heads

        self.dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.weight_attr = None
        self.bias_attr = None
        self.kdim, self.vdim = self.embed_dim, self.embed_dim
        self.key_length, self.value_length = (
            self.query_length,
            self.query_length,
        )

    def generate_input_data(self):
        self.query = np.random.rand(
            self.batch_size, self.query_length, self.embed_dim
        ).astype(self.x_type)
        out_seq_len = self.key_length
        if self.has_cache_kv:
            assert self.training is False, ValueError(
                'cache_kv can only be used in inference'
            )
            self.cache_kv = np.random.rand(
                2,
                self.batch_size,
                self.num_heads,
                self.cache_length,
                self.head_dim,
            ).astype(self.x_type)
            out_seq_len += self.cache_length
        else:
            self.cache_kv = None

        if self.has_attn_mask:
            # [B, n_head, seq_len, out_seq_len]
            self.attn_mask = np.ones(
                (
                    self.batch_size,
                    self.num_heads,
                    self.query_length,
                    out_seq_len,
                ),
                dtype=self.attn_mask_type,
            )
            if self.attn_mask_type == np.int64:
                self.attn_mask = np.tril(self.attn_mask)
            elif self.attn_mask_type == np.float64:
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
                raise ValueError(
                    "'attn_mask_type' should be 'int64' or 'float64'."
                )
        else:
            self.attn_mask = None
        self.key, self.value = self.query, self.query

        self.dout = np.random.random(
            (self.batch_size, self.query_length, self.embed_dim)
        ).astype(self.x_type)

    def GetBaselineOut(self):
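        # Reference (unfused) path: optional pre-LayerNorm, separate q/k/v
        # projections, scaled dot-product attention with optional mask and
        # cache, output projection, residual add + dropout, and a final
        # LayerNorm when pre_layer_norm is False.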
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        residual = tensor_query

        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm1(tensor_query)

        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = paddle.split(cache_kv, 2)
            cache_k = paddle.squeeze(cache_k, axis=0)
            cache_v = paddle.squeeze(cache_v, axis=0)
            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            k_out = paddle.concat([cache_k, k_out], axis=-2)
            v_out = paddle.concat([cache_v, v_out], axis=-2)

        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
        qk_out = layers.matmul(
            x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5
        )

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        if self.dropout_prob:
            dropout_out = F.dropout(
                softmax_out,
                self.dropout_prob,
                training=self.training,
                mode="upscale_in_train",
            )
            # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, head_dim]
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]
        )
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            final_out = self.norm1(residual_out)
        else:
            final_out = residual_out

        if self.has_cache_kv:
            return final_out

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
        )
        return final_out, tensor_query.grad

    def GetFusedAttentionOut(self):
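        # Fused path: pack the separate q/k/v projection weights and biases
        # into the qkv_weight [3, num_heads, head_dim, embed_dim] and
        # qkv_bias [3, num_heads, head_dim] layouts used below, then call
        # incubate_f.fused_multi_head_attention once.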
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_proj_weight = paddle.to_tensor(
            self.q_proj.weight, stop_gradient=False
        )
        k_proj_weight = paddle.to_tensor(
            self.k_proj.weight, stop_gradient=False
        )
        v_proj_weight = paddle.to_tensor(
            self.v_proj.weight, stop_gradient=False
        )
        out_linear_weight = paddle.to_tensor(
            self.out_proj.weight, stop_gradient=False
        )

        if self.bias_attr is False:
            qkv_bias_tensor = None
            out_linear_bias = None
        else:
            q_proj_bias = paddle.to_tensor(
                self.q_proj.bias, stop_gradient=False
            )
            k_proj_bias = paddle.to_tensor(
                self.k_proj.bias, stop_gradient=False
            )
            v_proj_bias = paddle.to_tensor(
                self.v_proj.bias, stop_gradient=False
            )
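            # concatenate the q/k/v biases and reshape them to the
            # [3, num_heads, head_dim] layout passed to the fused op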
            qkv_bias = np.concatenate(
                (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())
            )
            qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
            qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
            out_linear_bias = paddle.to_tensor(
                self.out_proj.bias, stop_gradient=False
            )

        ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
        ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
        ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
        ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)

        q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
        k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
        v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
        qkv_weight = np.concatenate(
            (q_proj_weight, k_proj_weight, v_proj_weight)
        )
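        # reshape the packed weights to [3, num_heads, head_dim, embed_dim]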
        qkv_weight = qkv_weight.reshape(
            (3, self.num_heads, self.head_dim, self.embed_dim)
        )

        x = paddle.to_tensor(self.query, stop_gradient=False)
        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
        epsilon = 1e-05
        ln2_epsilon = 1e-05

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, x.dtype)
        final_out = incubate_f.fused_multi_head_attention(
            x,
            qkv_weight_tensor,
            out_linear_weight,
            self.pre_layer_norm,
            ln1_scale,
            ln1_bias,
            ln2_scale,
            ln2_bias,
            epsilon,
            qkv_bias_tensor,
            out_linear_bias,
            cache_kv,
            attn_mask,
            self.dropout_prob,
            self.attn_dropout_prob,
            ln2_epsilon,
        )

        if self.has_cache_kv:
            return final_out[0], final_out[1]

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
        )
        return final_out, x.grad

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(
            final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
        )
        np.testing.assert_allclose(
            x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
        )


class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.bias_attr = False


class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.pre_layer_norm = True


class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.pre_layer_norm = True
        self.has_attn_mask = False


class TestFusedAttentionOpFp16(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.x_type = np.float16

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(
            final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
        )
        np.testing.assert_allclose(
            x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
        )


class TestFusedAttentionOpCacheKV(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.has_cache_kv = True
        self.training = False
        self.query_length = 1
        self.key_length, self.value_length = 1, 1

    def test_fused_attention_op(self):
        with paddle.no_grad():
            final_out_ref = self.GetBaselineOut()
            final_out, cache_kv_out = self.GetFusedAttentionOut()
            np.testing.assert_allclose(
                final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
            )


if __name__ == "__main__":
    unittest.main()