# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.fluid.core as core
import paddle.nn.functional as F
from paddle.distributed.passes import PassManager, new_pass

paddle.enable_static()
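# static mode is required here: the fused_attention pass below is applied to
# paddle.static.Program objects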


class MultiHeadAttention(paddle.nn.Layer):
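    """Multi-head self-attention built from plain (unfused) ops.

    The forward pass (optional pre/post layer_norm, a packed QKV projection,
    scaled dot-product attention with optional mask and dropout, an output
    projection, dropout and an optional residual add) is the op pattern the
    fused_attention pass is expected to match.
    """
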
    def __init__(
        self,
        embed_dim,
        num_heads,
        add_residual=True,
        pre_ln=True,
        post_ln=False,
        attn_dropout=True,
    ):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = embed_dim
        self.vdim = embed_dim
        self.num_heads = num_heads

        self.add_residual = add_residual
        self.pre_ln = pre_ln
        self.post_ln = post_ln
        self.attn_dropout = attn_dropout

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.norm1 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
        self.norm2 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)

        self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = paddle.nn.Linear(embed_dim, embed_dim)
        self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train")

    def forward(self, x, attn_mask=None):
        residual = x

        if self.pre_ln:
            # pre layer norm
            x = self.norm1(x)

        # compute qkv
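        # qkv_proj packs q, k and v into one [batch, seq, 3 * embed_dim] tensor;
        # a 0 in the reshape shape keeps that dimension of the input, so after
        # the transpose each of q, k, v has shape [batch, num_heads, seq, head_dim]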
        qkv = self.qkv_proj(x)
        qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim])
        qkv = paddle.transpose(qkv, [0, 2, 1, 3])
        q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1)

        # compute core attention
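        # i.e. softmax(q @ k^T / sqrt(head_dim) + mask) @ v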
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        product = paddle.scale(product, scale=self.head_dim**-0.5)
        if attn_mask is not None:
            product = product + attn_mask
        weights = F.softmax(product)
        if self.attn_dropout:
            weights = F.dropout(
                weights, 0.1, training=self.training, mode="upscale_in_train"
            )
        out = paddle.matmul(weights, v)
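        # merge the heads back: [batch, num_heads, seq, head_dim] -> [batch, seq, embed_dim]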
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        out = self.dropout(out)
        if self.add_residual:
            out = residual + out

        if self.post_ln:
            # post layer norm
            out = self.norm2(out)

        return out


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedAttentionPass(unittest.TestCase):
    def setUp(self):
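        # configuration flags for the attention block built in test_pass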
        self.add_residual = True
        self.pre_ln = True
        self.post_ln = True
        self.attn_dropout = True
        self.add_mask = True

    def test_pass(self):
        batch_size = 2
        seq_len = 1024
        hidden_size = 768
        num_heads = 12

        x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32')
        mask_data = np.random.rand(
            batch_size, num_heads, seq_len, seq_len
        ).astype('float32')

        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()

        with paddle.static.program_guard(main_prog, startup_prog):
            data = paddle.static.data(
                name="x",
                shape=[-1, seq_len, seq_len],
                dtype='float32',
            )
            if self.add_mask:
                attn_mask = paddle.static.data(
                    name="attn_mask",
                    shape=[-1, num_heads, seq_len, seq_len],
                    dtype='float32',
                )
            else:
                attn_mask = None
            data_linear = paddle.nn.Linear(seq_len, hidden_size)
            multi_head_attn = MultiHeadAttention(
                hidden_size,
                num_heads,
                add_residual=self.add_residual,
                pre_ln=self.pre_ln,
                post_ln=self.post_ln,
                attn_dropout=self.attn_dropout,
            )

            attn_input = data_linear(data)
            out = multi_head_attn(attn_input, attn_mask)
            loss = paddle.mean(out)
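            # a scalar loss so minimize() below can build the backward graph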

            sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
            sgd_optimizer.minimize(loss)

        pass_manager = PassManager([new_pass("fused_attention")])
        pass_manager.apply([main_prog], [startup_prog])
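        # the pass rewrites main_prog in place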

        ops = main_prog.global_block().ops
        assert ops[2].type == 'reduce_mean'
        assert ops[4].type == 'reduce_mean_grad'
        # expected layout: two ops for the input linear, one reduce_mean for
        # the loss, one fill_constant, one reduce_mean_grad, two ops for the
        # linear backward, so the eighth op should be the sgd optimizer
        assert ops[7].type == 'sgd'


if __name__ == "__main__":
    np.random.seed(0)
    unittest.main()