# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)


class MultiheadAttention(nn.Layer):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
        \text{where} \; head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: number of parallel attention heads; embed_dim must be
            divisible by num_heads
        dropout: dropout probability applied to the attention weights
        bias: whether the output projection layer uses a bias term
        add_bias_kv, add_zero_attn: accepted for interface compatibility;
            not used by this implementation

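    Example (an illustrative sketch only; the 512/8 and tensor sizes below
    are arbitrary placeholders, not values prescribed by this module):

    .. code-block:: python

        attn = MultiheadAttention(embed_dim=512, num_heads=8)
        x = paddle.rand([10, 4, 512])  # [target length, batch size, embed dim]
        out, weights = attn(x, x, x)   # self-attention
        # out: [10, 4, 512], weights: [4, 10, 10]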
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
        self._reset_parameters()
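        # The query/key/value input projections are implemented as 1x1
        # convolutions; see _in_proj_q/_in_proj_k/_in_proj_v below.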
        self.conv1 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv2 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv3 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))

    def _reset_parameters(self):
        xavier_uniform_(self.out_proj.weight)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                need_weights=True,
                static_kv=False,
                attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: [batch size, sequence length]; non-zero entries
                mark key positions that are masked out of the attention
            incremental_state: if provided, previous time steps are cached
            need_weights: output attn_output_weights
            static_kv: key and value are static

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
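
        Example (an illustrative sketch; ``layer`` is assumed to be a
        MultiheadAttention(embed_dim=512, num_heads=8) instance and the
        shapes below are arbitrary placeholders):

        .. code-block:: python

            x = paddle.rand([5, 2, 512])   # [length, batch, embed dim]
            # 1 marks padded key positions that should be ignored
            pad = paddle.to_tensor([[0., 0., 0., 0., 1.],
                                    [0., 0., 0., 1., 1.]])
            out, weights = layer(x, x, x, key_padding_mask=pad)
            # out: [5, 2, 512], weights: [2, 5, 5]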
        """
        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim
        assert list(query.shape) == [tgt_len, bsz, embed_dim]
        assert key.shape == value.shape

        q = self._in_proj_q(query)
        k = self._in_proj_k(key)
        v = self._in_proj_v(value)
        q *= self.scaling

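        # Split embed_dim into num_heads heads of head_dim each and fold the
        # heads into the batch axis: [bsz * num_heads, seq_len, head_dim].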
        q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose(
            [1, 0, 2])
        k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
            [1, 0, 2])
        v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
            [1, 0, 2])

        src_len = k.shape[1]

        if key_padding_mask is not None:
            assert key_padding_mask.shape[0] == bsz
            assert key_padding_mask.shape[1] == src_len

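        # Scaled dot-product scores between every query and key position:
        # [bsz * num_heads, tgt_len, src_len].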
        attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
        assert list(attn_output_weights.
                    shape) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask
        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.reshape(
                [bsz, self.num_heads, tgt_len, src_len])
            # Additive mask: 0 where the key is kept, -inf where the key
            # position is padded (key_padding_mask != 0).
            key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
            y = paddle.full(
                shape=key.shape, dtype='float32', fill_value=float('-inf'))
            y = paddle.where(key == 0., key, y)
            attn_output_weights += y
            attn_output_weights = attn_output_weights.reshape(
                [bsz * self.num_heads, tgt_len, src_len])

        # Softmax over the source dimension, computed in float32 for
        # numerical stability (covers float16 inputs as well).
        attn_output_weights = F.softmax(
            attn_output_weights.astype('float32'), axis=-1)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training)

        attn_output = paddle.bmm(attn_output_weights, v)
        assert list(attn_output.
                    shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn_output = attn_output.transpose([1, 0, 2]).reshape(
            [tgt_len, bsz, embed_dim])
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.reshape(
                [bsz, self.num_heads, tgt_len, src_len])
            attn_output_weights = attn_output_weights.sum(
                axis=1) / self.num_heads
        else:
            attn_output_weights = None
        return attn_output, attn_output_weights

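    # The three projections below reshape [seq_len, bsz, embed_dim] inputs to
    # [bsz, embed_dim, 1, seq_len] so that a 1x1 Conv2D acts as a per-position
    # linear projection, then restore the original layout.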
    def _in_proj_q(self, query):
        query = query.transpose([1, 2, 0])
        query = paddle.unsqueeze(query, axis=2)
        res = self.conv1(query)
        res = paddle.squeeze(res, axis=2)
        res = res.transpose([2, 0, 1])
        return res

    def _in_proj_k(self, key):
        key = key.transpose([1, 2, 0])
        key = paddle.unsqueeze(key, axis=2)
        res = self.conv2(key)
        res = paddle.squeeze(res, axis=2)
        res = res.transpose([2, 0, 1])
        return res

    def _in_proj_v(self, value):
        value = value.transpose([1, 2, 0])
        value = paddle.unsqueeze(value, axis=2)
        res = self.conv3(value)
        res = paddle.squeeze(res, axis=2)
        res = res.transpose([2, 0, 1])
        return res