# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.incubate.nn.functional as incubate_f
import paddle.nn.functional as F
from paddle import tensor
from paddle.fluid import layers
from paddle.fluid.framework import default_main_program
from paddle.nn.layer.common import Dropout, Linear
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.transformer import _convert_attention_mask

default_main_program().random_seed = 42


class TestFusedAttentionOp(OpTest):
    def setUp(self):
        self.config()
        self.generate_input_data()

        self.rtol = 1e-5
        # FIXME(limin29): Because of a precision problem with this test on
        #  A100, atol is temporarily relaxed to 1e-2; it will be changed back
        #  once the precision problem is solved.
        self.atol = 1e-2
        # Keep the tighter tolerance for local development on V100.
        if "V100" in paddle.device.cuda.get_device_name():
            self.atol = 1e-4
        if self.x_type is np.float16:
            self.atol = 1e-1

        paddle.set_default_dtype(self.x_type)
        self.__class__.op_type = "fused_attention"
        # use autograd to check grad in this unittest.
        self.__class__.no_need_check_grad = True
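        # Reference (unfused) projection layers; their weights and biases are
        # reused below to build the packed QKV parameters for the fused op.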
        self.q_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.k_proj = Linear(
            self.kdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.v_proj = Linear(
            self.vdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.out_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        paddle.set_default_dtype(np.float32)
        self.norm1 = LayerNorm(self.embed_dim)
        self.norm2 = LayerNorm(self.embed_dim)
        paddle.set_default_dtype(self.x_type)
        self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")

    def config(self):
        self.x_type = np.float32
        self.attn_mask_type = np.float64
        self.pre_layer_norm = False
        self.has_attn_mask = True
        self.has_cache_kv = False
        self.training = True

        self.batch_size = 8
        self.query_length = 128
        self.cache_length = 128
        self.head_dim = 64
        self.num_heads = 16
        self.embed_dim = self.head_dim * self.num_heads

        self.dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.weight_attr = None
        self.bias_attr = None
        self.kdim, self.vdim = self.embed_dim, self.embed_dim
        self.key_length, self.value_length = (
            self.query_length,
            self.query_length,
        )

    def generate_input_data(self):
        self.query = np.random.rand(
            self.batch_size, self.query_length, self.embed_dim
        ).astype(self.x_type)
        out_seq_len = self.key_length
        if self.has_cache_kv:
            assert self.training is False, ValueError(
                'cache_kv can only be used in inference'
            )
            self.cache_kv = np.random.rand(
                2,
                self.batch_size,
                self.num_heads,
                self.cache_length,
                self.head_dim,
            ).astype(self.x_type)
            out_seq_len += self.cache_length
        else:
            self.cache_kv = None

        if self.has_attn_mask:
            # [B, n_head, seq_len, out_seq_len]
            self.attn_mask = np.ones(
                (
                    self.batch_size,
                    self.num_heads,
                    self.query_length,
                    out_seq_len,
                ),
                dtype=self.attn_mask_type,
            )
            if self.attn_mask_type == np.int64:
                self.attn_mask = np.tril(self.attn_mask)
            elif self.attn_mask_type == np.float64:
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
                raise ValueError(
                    "'attn_mask_type' should be 'int64' or 'float64'."
                )
        else:
            self.attn_mask = None
        self.key, self.value = self.query, self.query

        self.dout = np.random.random(
            (self.batch_size, self.query_length, self.embed_dim)
        ).astype(self.x_type)

    def GetBaselineOut(self):
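        # Compute the reference result with separate (unfused) paddle layers,
        # mirroring what the fused kernel is expected to produce.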
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        residual = tensor_query

        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm1(tensor_query)

        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = paddle.split(cache_kv, 2)
            cache_k = paddle.squeeze(cache_k, axis=0)
            cache_v = paddle.squeeze(cache_v, axis=0)
            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            k_out = paddle.concat([cache_k, k_out], axis=-2)
            v_out = paddle.concat([cache_v, v_out], axis=-2)

        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
        qk_out = layers.matmul(
            x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5
        )

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        if self.dropout_prob:
            dropout_out = F.dropout(
                softmax_out,
                self.dropout_prob,
                training=self.training,
                mode="upscale_in_train",
            )
            # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, head_dim]
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]
        )
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            final_out = self.norm1(residual_out)
        else:
            final_out = residual_out

        if self.has_cache_kv:
            return final_out

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
        )
        return final_out, tensor_query.grad

    def GetFusedAttentionOut(self):
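        # Run the same computation through the fused kernel, feeding it packed
        # weights/biases built from the reference projection layers above.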
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_proj_weight = paddle.to_tensor(
            self.q_proj.weight, stop_gradient=False
        )
        k_proj_weight = paddle.to_tensor(
            self.k_proj.weight, stop_gradient=False
        )
        v_proj_weight = paddle.to_tensor(
            self.v_proj.weight, stop_gradient=False
        )
        out_linear_weight = paddle.to_tensor(
            self.out_proj.weight, stop_gradient=False
        )

        if self.bias_attr is False:
            qkv_bias_tensor = None
            out_linear_bias = None
        else:
            q_proj_bias = paddle.to_tensor(
                self.q_proj.bias, stop_gradient=False
            )
            k_proj_bias = paddle.to_tensor(
                self.k_proj.bias, stop_gradient=False
            )
            v_proj_bias = paddle.to_tensor(
                self.v_proj.bias, stop_gradient=False
            )
            qkv_bias = np.concatenate(
                (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())
            )
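            # Pack the q/k/v biases into the [3, num_heads, head_dim] layout
            # consumed by the fused op.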
            qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
            qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
            out_linear_bias = paddle.to_tensor(
                self.out_proj.bias, stop_gradient=False
            )

        ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
        ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
        ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
        ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)

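        # Linear stores weights as [in_features, out_features]; transpose to
        # [out_features, in_features] and pack q/k/v into a single
        # [3, num_heads, head_dim, embed_dim] weight for the fused op.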
        q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
        k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
        v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
        qkv_weight = np.concatenate(
            (q_proj_weight, k_proj_weight, v_proj_weight)
        )
        qkv_weight = qkv_weight.reshape(
            (3, self.num_heads, self.head_dim, self.embed_dim)
        )

        x = paddle.to_tensor(self.query, stop_gradient=False)
        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
        epsilon = 1e-05
        ln2_epsilon = 1e-05

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, x.dtype)
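        # The fused op covers the whole block in one go: (pre- or post-)
        # layer_norm, packed QKV projection, scaled dot-product attention,
        # output projection, and the residual + dropout path.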
        final_out = incubate_f.fused_multi_head_attention(
            x,
            qkv_weight_tensor,
            out_linear_weight,
            self.pre_layer_norm,
            ln1_scale,
            ln1_bias,
            ln2_scale,
            ln2_bias,
            epsilon,
            qkv_bias_tensor,
            out_linear_bias,
            cache_kv,
            attn_mask,
            self.dropout_prob,
            self.attn_dropout_prob,
            ln2_epsilon,
        )

        if self.has_cache_kv:
            return final_out[0], final_out[1]

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
        )
        return final_out, x.grad

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
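        # Both the forward output and the input gradient of the fused op must
        # match the baseline within the configured tolerances.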
        np.testing.assert_allclose(
            final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
        )
        np.testing.assert_allclose(
            x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
        )


class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.bias_attr = False


class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.pre_layer_norm = True


class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.pre_layer_norm = True
        self.has_attn_mask = False


class TestFusedAttentionOpFp16(TestFusedAttentionOp):
    def config(self):
        super().config()
        self.x_type = np.float16

    def test_fused_attention_op(self):
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(
            final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
        )
        np.testing.assert_allclose(
            x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
        )


class TestFusedAttentionOpCacheKV(TestFusedAttentionOp):
    def config(self):
        super().config()
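        # Exercise the decoder cache path: a single-token query attends over
        # cached keys/values, mimicking one step of autoregressive decoding.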
        self.has_cache_kv = True
        self.training = False
        self.query_length = 1
        self.key_length, self.value_length = 1, 1

    def test_fused_attention_op(self):
        with paddle.no_grad():
            final_out_ref = self.GetBaselineOut()
            final_out, cache_kv_out = self.GetFusedAttentionOut()
            np.testing.assert_allclose(
                final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
            )


if __name__ == "__main__":
    unittest.main()