# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.incubate.nn.functional as incubate_f
import paddle.nn.functional as F
from paddle import tensor
from paddle.fluid.framework import default_main_program
from paddle.nn.layer.common import Dropout, Linear
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.transformer import _convert_attention_mask

# Set the random seed of the default main program to a fixed value (42).
default_main_program().random_seed = 42


class TestFusedAttentionOp(OpTest):
    """Check ``fused_multi_head_attention`` against an unfused reference.

    ``GetBaselineOut`` assembles multi-head self-attention from separate
    ``Linear``, ``LayerNorm`` and ``Dropout`` layers, while
    ``GetFusedAttentionOut`` repacks the very same parameters and calls
    ``incubate_f.fused_multi_head_attention``. The test compares the two
    outputs and the gradients with respect to the input. Subclasses select
    a scenario by overriding ``config``.
    """

    def setUp(self):
        """Prepare the configuration, inputs, tolerances and layers."""
        self.config()
        self.generate_input_data()

        self.rtol = 1e-5
        # FIXME(limin29): Because there is a problem with the test precision
        #  on A100, atol is temporarily set to 1e-2, and it will be
        #  changed back after the precision problem is solved.
        self.atol = 1e-2
        # make sure local development precision
        if "V100" in paddle.device.cuda.get_device_name():
            self.atol = 1e-4
        if self.x_type == np.float16:
            self.atol = 1e-1

        # The projection layers are created while the default dtype is the
        # dtype under test.
        paddle.set_default_dtype(self.x_type)
        self.__class__.op_type = "fused_attention"
        # use autograd to check grad in this unittest.
        self.__class__.no_need_check_grad = True
        self.q_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.k_proj = Linear(
            self.kdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.v_proj = Linear(
            self.vdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.out_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        # The layer norms are created while the default dtype is float32;
        # the default is switched back to the tested dtype right after.
        paddle.set_default_dtype(np.float32)
        self.norm1 = LayerNorm(self.embed_dim)
        self.norm2 = LayerNorm(self.embed_dim)
        paddle.set_default_dtype(self.x_type)
        self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")

    def config(self):
        """Set the default scenario; subclasses override selected fields."""
        self.x_type = np.float32
        self.attn_mask_type = np.float64
        self.pre_layer_norm = False
        self.has_attn_mask = True
        self.has_cache_kv = False
        self.training = True

        self.batch_size = 8
        self.query_length = 128
        self.cache_length = 128
        self.head_dim = 64
        self.num_heads = 16
        self.embed_dim = self.head_dim * self.num_heads

        self.dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.weight_attr = None
        self.bias_attr = None
        self.kdim, self.vdim = self.embed_dim, self.embed_dim
        self.key_length, self.value_length = (
            self.query_length,
            self.query_length,
        )

    def generate_input_data(self):
        """Create the random query, optional cache/mask and output gradient.

        Raises:
            ValueError: if ``has_cache_kv`` is set while ``training`` is not
                ``False``, or if ``attn_mask_type`` is neither ``np.int64``
                nor ``np.float64``.
        """
        self.query = np.random.rand(
            self.batch_size, self.query_length, self.embed_dim
        ).astype(self.x_type)
        out_seq_len = self.key_length
        if self.has_cache_kv:
            # Raised explicitly (instead of ``assert``) so the check is not
            # stripped when Python runs with -O.
            if self.training is not False:
                raise ValueError('cache_kv can only used in inference')
            self.cache_kv = np.random.rand(
                2,
                self.batch_size,
                self.num_heads,
                self.cache_length,
                self.head_dim,
            ).astype(self.x_type)
            out_seq_len += self.cache_length
        else:
            self.cache_kv = None

        if self.has_attn_mask:
            # [B, n_head, seq_len, out_seq_len]
            self.attn_mask = np.ones(
                (
                    self.batch_size,
                    self.num_heads,
                    self.query_length,
                    out_seq_len,
                ),
                dtype=self.attn_mask_type,
            )
            if self.attn_mask_type == np.int64:
                # Lower-triangular mask of ones.
                self.attn_mask = np.tril(self.attn_mask)
            elif self.attn_mask_type == np.float64:
                # Additive mask: 0 on/below the diagonal, -1e9 above it.
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
                raise ValueError(
                    "'attn_mask_type' should be 'int64' or 'float64'."
                )
        else:
            self.attn_mask = None
        self.key, self.value = self.query, self.query

        # Upstream gradient fed to backward() in both code paths.
        self.dout = np.random.random(
            (self.batch_size, self.query_length, self.embed_dim)
        ).astype(self.x_type)

    def GetBaselineOut(self):
        """Run the unfused reference attention.

        Returns:
            ``final_out`` alone when ``has_cache_kv`` is set; otherwise the
            tuple ``(final_out, tensor_query.grad)`` obtained after
            back-propagating ``self.dout``.
        """
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        residual = tensor_query

        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm1(tensor_query)

        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = paddle.split(cache_kv, 2)
            cache_k = paddle.squeeze(cache_k, axis=0)
            cache_v = paddle.squeeze(cache_v, axis=0)
            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            k_out = paddle.concat([cache_k, k_out], axis=-2)
            v_out = paddle.concat([cache_v, v_out], axis=-2)

        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
        qk_out = paddle.matmul(x=q_out, y=k_out, transpose_y=True)
        qk_out = paddle.scale(qk_out, scale=self.head_dim**-0.5)

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        # Dropout on the attention probabilities uses attn_dropout_prob, the
        # rate the fused call receives for this stage; dropout_prob is the
        # rate of the dropout applied after the output projection.
        if self.attn_dropout_prob:
            dropout_out = F.dropout(
                softmax_out,
                self.attn_dropout_prob,
                training=self.training,
                mode="upscale_in_train",
            )
            # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, head_dim]
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]
        )
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            # Post layer norm uses norm2: its weight/bias are what the fused
            # call is given as ln2_scale/ln2_bias, while norm1 only feeds the
            # pre-layer-norm slots.
            final_out = self.norm2(residual_out)
        else:
            final_out = residual_out

        if self.has_cache_kv:
            return final_out

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
        )
        return final_out, tensor_query.grad

    def GetFusedAttentionOut(self):
        """Run ``fused_multi_head_attention`` with the reference parameters.

        Returns:
            ``(final_out[0], final_out[1])`` when ``has_cache_kv`` is set;
            otherwise ``(final_out, x.grad)`` obtained after back-propagating
            ``self.dout``.
        """
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_proj_weight = paddle.to_tensor(
            self.q_proj.weight, stop_gradient=False
        )
        k_proj_weight = paddle.to_tensor(
            self.k_proj.weight, stop_gradient=False
        )
        v_proj_weight = paddle.to_tensor(
            self.v_proj.weight, stop_gradient=False
        )
        out_linear_weight = paddle.to_tensor(
            self.out_proj.weight, stop_gradient=False
        )

        if self.bias_attr is False:
            qkv_bias_tensor = None
            out_linear_bias = None
        else:
            q_proj_bias = paddle.to_tensor(
                self.q_proj.bias, stop_gradient=False
            )
            k_proj_bias = paddle.to_tensor(
                self.k_proj.bias, stop_gradient=False
            )
            v_proj_bias = paddle.to_tensor(
                self.v_proj.bias, stop_gradient=False
            )
            # Stack the q/k/v biases into one (3, num_heads, head_dim) array.
            qkv_bias = np.concatenate(
                (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())
            )
            qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
            qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
            out_linear_bias = paddle.to_tensor(
                self.out_proj.bias, stop_gradient=False
            )

        ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
        ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
        ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
        ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)

        # Transpose each projection weight, stack q/k/v along axis 0 and
        # reshape to (3, num_heads, head_dim, embed_dim).
        q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
        k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
        v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
        qkv_weight = np.concatenate(
            (q_proj_weight, k_proj_weight, v_proj_weight)
        )
        qkv_weight = qkv_weight.reshape(
            (3, self.num_heads, self.head_dim, self.embed_dim)
        )

        x = paddle.to_tensor(self.query, stop_gradient=False)
        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
        epsilon = 1e-05
        ln2_epsilon = 1e-05

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, x.dtype)
        final_out = incubate_f.fused_multi_head_attention(
            x,
            qkv_weight_tensor,
            out_linear_weight,
            self.pre_layer_norm,
            ln1_scale,
            ln1_bias,
            ln2_scale,
            ln2_bias,
            epsilon,
            qkv_bias_tensor,
            out_linear_bias,
            cache_kv,
            attn_mask,
            self.dropout_prob,
            self.attn_dropout_prob,
            ln2_epsilon,
        )

        if self.has_cache_kv:
            return final_out[0], final_out[1]

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
        )
        return final_out, x.grad

    def test_fused_attention_op(self):
        """Compare output and input gradient of both implementations."""
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(
            final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
        )
        np.testing.assert_allclose(
            x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
        )


class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
    """Scenario in which the Linear layers are built with ``bias_attr=False``."""

    def config(self):
        """Start from the parent configuration and disable the biases."""
        super().config()
        self.bias_attr = False


class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
    """Scenario with ``pre_layer_norm`` switched on."""

    def config(self):
        """Start from the parent configuration and enable pre layer norm."""
        super().config()
        self.pre_layer_norm = True


class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
    """Scenario with pre layer norm and without an attention mask."""

    def config(self):
        """Start from the parent configuration, then drop the mask."""
        super().config()
        self.has_attn_mask = False
        self.pre_layer_norm = True


class TestFusedAttentionOpFp16(TestFusedAttentionOp):
    """Scenario that runs the comparison with ``np.float16`` inputs."""

    def config(self):
        """Start from the parent configuration and switch to float16."""
        super().config()
        self.x_type = np.float16

    def test_fused_attention_op(self):
        """Run the parent's output and input-gradient comparison."""
        # The parent test performs exactly the same two assert_allclose
        # checks with self.rtol / self.atol, so it is reused here.
        super().test_fused_attention_op()


class TestFusedAttentionOpCacheKV(TestFusedAttentionOp):
    """Inference scenario with a key/value cache and a single query step."""

    def config(self):
        """Enable ``cache_kv``, disable training and use length-1 inputs."""
        super().config()
        self.has_cache_kv = True
        self.training = False
        self.query_length = 1
        self.key_length = 1
        self.value_length = 1

    def test_fused_attention_op(self):
        """Compare only the forward outputs, with gradients disabled."""
        with paddle.no_grad():
            expected = self.GetBaselineOut()
            # The second item returned by the fused path is not checked.
            actual, _ = self.GetFusedAttentionOut()
            np.testing.assert_allclose(
                expected, actual.numpy(), rtol=self.rtol, atol=self.atol
            )


class TestFusedAttentionOpParamStopGradient(OpTest):
    """Check ``fused_multi_head_attention`` with gradient-free parameters.

    The reference path (``GetBaselineOut``) assembles multi-head
    self-attention from separate ``Linear``, ``LayerNorm`` and ``Dropout``
    layers. The fused path (``GetFusedAttentionOut``) repacks the same
    parameters, marks every parameter tensor with ``stop_gradient = True``
    and calls ``incubate_f.fused_multi_head_attention``. The test compares
    the outputs and the gradients with respect to the input.
    """

    def setUp(self):
        """Prepare the configuration, inputs, tolerances and layers."""
        self.config()
        self.generate_input_data()

        self.rtol = 1e-5
        # FIXME(limin29): Because there is a problem with the test precision
        #  on A100, atol is temporarily set to 1e-2, and it will be
        #  changed back after the precision problem is solved.
        self.atol = 1e-2
        # make sure local development precision
        if "V100" in paddle.device.cuda.get_device_name():
            self.atol = 1e-4
        if self.x_type == np.float16:
            self.atol = 1e-1

        # The projection layers are created while the default dtype is the
        # dtype under test.
        paddle.set_default_dtype(self.x_type)
        self.__class__.op_type = "fused_attention"
        # use autograd to check grad in this unittest.
        self.__class__.no_need_check_grad = True
        self.q_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.k_proj = Linear(
            self.kdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.v_proj = Linear(
            self.vdim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        self.out_proj = Linear(
            self.embed_dim,
            self.embed_dim,
            self.weight_attr,
            bias_attr=self.bias_attr,
        )
        # The layer norms are created while the default dtype is float32;
        # the default is switched back to the tested dtype right after.
        paddle.set_default_dtype(np.float32)
        self.norm1 = LayerNorm(self.embed_dim)
        self.norm2 = LayerNorm(self.embed_dim)
        paddle.set_default_dtype(self.x_type)
        self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")

    def config(self):
        """Set the default scenario; subclasses may override fields."""
        self.x_type = np.float32
        self.attn_mask_type = np.float64
        self.pre_layer_norm = False
        self.has_attn_mask = True
        self.has_cache_kv = False
        self.training = True

        self.batch_size = 8
        self.query_length = 128
        self.cache_length = 128
        self.head_dim = 64
        self.num_heads = 16
        self.embed_dim = self.head_dim * self.num_heads

        self.dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.weight_attr = None
        self.bias_attr = None
        self.kdim, self.vdim = self.embed_dim, self.embed_dim
        self.key_length, self.value_length = (
            self.query_length,
            self.query_length,
        )

    def generate_input_data(self):
        """Create the random query, optional cache/mask and output gradient.

        Raises:
            ValueError: if ``has_cache_kv`` is set while ``training`` is not
                ``False``, or if ``attn_mask_type`` is neither ``np.int64``
                nor ``np.float64``.
        """
        self.query = np.random.rand(
            self.batch_size, self.query_length, self.embed_dim
        ).astype(self.x_type)
        out_seq_len = self.key_length
        if self.has_cache_kv:
            # Raised explicitly (instead of ``assert``) so the check is not
            # stripped when Python runs with -O.
            if self.training is not False:
                raise ValueError('cache_kv can only used in inference')
            self.cache_kv = np.random.rand(
                2,
                self.batch_size,
                self.num_heads,
                self.cache_length,
                self.head_dim,
            ).astype(self.x_type)
            out_seq_len += self.cache_length
        else:
            self.cache_kv = None

        if self.has_attn_mask:
            # [B, n_head, seq_len, out_seq_len]
            self.attn_mask = np.ones(
                (
                    self.batch_size,
                    self.num_heads,
                    self.query_length,
                    out_seq_len,
                ),
                dtype=self.attn_mask_type,
            )
            if self.attn_mask_type == np.int64:
                # Lower-triangular mask of ones.
                self.attn_mask = np.tril(self.attn_mask)
            elif self.attn_mask_type == np.float64:
                # Additive mask: 0 on/below the diagonal, -1e9 above it.
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
                raise ValueError(
                    "'attn_mask_type' should be 'int64' or 'float64'."
                )
        else:
            self.attn_mask = None
        self.key, self.value = self.query, self.query

        # Upstream gradient fed to backward() in both code paths.
        self.dout = np.random.random(
            (self.batch_size, self.query_length, self.embed_dim)
        ).astype(self.x_type)

    def GetBaselineOut(self):
        """Run the unfused reference attention.

        Returns:
            ``final_out`` alone when ``has_cache_kv`` is set; otherwise the
            tuple ``(final_out, tensor_query.grad)`` obtained after
            back-propagating ``self.dout``.
        """
        paddle.disable_static(place=paddle.CUDAPlace(0))
        tensor_query = paddle.to_tensor(self.query, stop_gradient=False)

        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)

        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        residual = tensor_query

        ln1_out = tensor_query
        if self.pre_layer_norm:
            ln1_out = self.norm1(tensor_query)

        q = self.q_proj(ln1_out)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
        k = self.k_proj(ln1_out)
        v = self.v_proj(ln1_out)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        if self.has_cache_kv:
            # [1, B, n_head, cache_seq_len, head_dim]
            cache_k, cache_v = paddle.split(cache_kv, 2)
            cache_k = paddle.squeeze(cache_k, axis=0)
            cache_v = paddle.squeeze(cache_v, axis=0)
            # [B, n_head, cache_seq_len + seq_len, head_dim]
            # out_seq_len = cache_seq_len + seq_len
            k_out = paddle.concat([cache_k, k_out], axis=-2)
            v_out = paddle.concat([cache_v, v_out], axis=-2)

        # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
        # --> [B, n_head, seq_len, out_seq_len]
        qk_out = paddle.matmul(x=q_out, y=k_out, transpose_y=True)
        qk_out = paddle.scale(qk_out, scale=self.head_dim**-0.5)

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
            attn_mask_out = qk_out + attn_mask
            softmax_out = F.softmax(attn_mask_out)
        else:
            softmax_out = F.softmax(qk_out)

        # Dropout on the attention probabilities uses attn_dropout_prob, the
        # rate the fused call receives for this stage; dropout_prob is the
        # rate of the dropout applied after the output projection.
        if self.attn_dropout_prob:
            dropout_out = F.dropout(
                softmax_out,
                self.attn_dropout_prob,
                training=self.training,
                mode="upscale_in_train",
            )
            # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
            # --> [B, n_head, seq_len, head_dim]
            qktv_out = tensor.matmul(dropout_out, v_out)
        else:
            qktv_out = tensor.matmul(softmax_out, v_out)

        fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
        out_linear_in = tensor.reshape(
            x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]
        )
        out = self.out_proj(out_linear_in)

        residual_out = residual + self.dropout(out)
        if not self.pre_layer_norm:
            # Post layer norm uses norm2: its weight/bias are what the fused
            # call is given as ln2_scale/ln2_bias, while norm1 only feeds the
            # pre-layer-norm slots.
            final_out = self.norm2(residual_out)
        else:
            final_out = residual_out

        if self.has_cache_kv:
            return final_out

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
        )
        return final_out, tensor_query.grad

    def GetFusedAttentionOut(self):
        """Run the fused op with all parameters marked ``stop_gradient``.

        Returns:
            ``(final_out[0], final_out[1])`` when ``has_cache_kv`` is set;
            otherwise ``(final_out, x.grad)`` obtained after back-propagating
            ``self.dout``.
        """
        paddle.disable_static(place=paddle.CUDAPlace(0))
        q_proj_weight = paddle.to_tensor(
            self.q_proj.weight, stop_gradient=False
        )
        k_proj_weight = paddle.to_tensor(
            self.k_proj.weight, stop_gradient=False
        )
        v_proj_weight = paddle.to_tensor(
            self.v_proj.weight, stop_gradient=False
        )
        out_linear_weight = paddle.to_tensor(
            self.out_proj.weight, stop_gradient=False
        )

        if self.bias_attr is False:
            qkv_bias_tensor = None
            out_linear_bias = None
        else:
            q_proj_bias = paddle.to_tensor(
                self.q_proj.bias, stop_gradient=False
            )
            k_proj_bias = paddle.to_tensor(
                self.k_proj.bias, stop_gradient=False
            )
            v_proj_bias = paddle.to_tensor(
                self.v_proj.bias, stop_gradient=False
            )
            # Stack the q/k/v biases into one (3, num_heads, head_dim) array.
            qkv_bias = np.concatenate(
                (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())
            )
            qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
            qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
            out_linear_bias = paddle.to_tensor(
                self.out_proj.bias, stop_gradient=False
            )

        ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
        ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
        ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
        ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)

        # Transpose each projection weight, stack q/k/v along axis 0 and
        # reshape to (3, num_heads, head_dim, embed_dim).
        q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
        k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
        v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
        qkv_weight = np.concatenate(
            (q_proj_weight, k_proj_weight, v_proj_weight)
        )
        qkv_weight = qkv_weight.reshape(
            (3, self.num_heads, self.head_dim, self.embed_dim)
        )

        x = paddle.to_tensor(self.query, stop_gradient=False)
        cache_kv = None
        if self.has_cache_kv:
            cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False)
        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
        else:
            attn_mask = None
        qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
        epsilon = 1e-05
        ln2_epsilon = 1e-05

        if attn_mask is not None:
            attn_mask = _convert_attention_mask(attn_mask, x.dtype)

        # Freeze every parameter so that only the input ``x`` requires a
        # gradient. The bias tensors are None when bias_attr is False, so
        # they are skipped instead of raising AttributeError.
        frozen_params = (
            qkv_weight_tensor,
            out_linear_weight,
            ln1_scale,
            ln1_bias,
            ln2_scale,
            ln2_bias,
            qkv_bias_tensor,
            out_linear_bias,
        )
        for param in frozen_params:
            if param is not None:
                param.stop_gradient = True

        final_out = incubate_f.fused_multi_head_attention(
            x,
            qkv_weight_tensor,
            out_linear_weight,
            self.pre_layer_norm,
            ln1_scale,
            ln1_bias,
            ln2_scale,
            ln2_bias,
            epsilon,
            qkv_bias_tensor,
            out_linear_bias,
            cache_kv,
            attn_mask,
            self.dropout_prob,
            self.attn_dropout_prob,
            ln2_epsilon,
        )

        if self.has_cache_kv:
            return final_out[0], final_out[1]

        paddle.autograd.backward(
            [final_out], [paddle.to_tensor(self.dout)], retain_graph=True
        )
        return final_out, x.grad

    def test_fused_attention_op(self):
        """Compare output and input gradient of both implementations."""
        final_out_ref, x_grad_ref = self.GetBaselineOut()
        final_out, x_grad = self.GetFusedAttentionOut()
        np.testing.assert_allclose(
            final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
        )
        np.testing.assert_allclose(
            x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
        )


# Run the unittest test runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()