# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)


class MultiheadAttention(nn.Layer):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: parallel attention layers, or heads

    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
        self._reset_parameters()
        self.conv1 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv2 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv3 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))

    def _reset_parameters(self):
        xavier_uniform_(self.out_proj.weight)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: if True, mask padding based on batch size
            incremental_state: if provided, previous time steps are cashed
            need_weights: output attn_output_weights
            static_kv: key and value are static

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """
        q_shape = paddle.shape(query)
        src_shape = paddle.shape(key)
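        # Project query/key/value with the 1x1 convolutions defined in
        # __init__ (see _in_proj_*), then scale q by 1/sqrt(head_dim).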
        q = self._in_proj_q(query)
        k = self._in_proj_k(key)
        v = self._in_proj_v(value)
        q *= self.scaling
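        # Split heads: [length, batch, embed_dim] -> [batch, num_heads, length, head_dim].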
        q = paddle.transpose(
            paddle.reshape(
                q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        k = paddle.transpose(
            paddle.reshape(
                k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        v = paddle.transpose(
            paddle.reshape(
                v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        if key_padding_mask is not None:
            assert key_padding_mask.shape[0] == q_shape[1]
            assert key_padding_mask.shape[1] == src_shape[0]
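        # Scaled dot-product scores: [batch, num_heads, tgt_len, src_len].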
        attn_output_weights = paddle.matmul(q,
                                            paddle.transpose(k, [0, 1, 3, 2]))
        if attn_mask is not None:
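            # Broadcast the additive attention mask over batch and heads.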
            attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
            attn_output_weights += attn_mask
        if key_padding_mask is not None:
            attn_output_weights = paddle.reshape(
                attn_output_weights,
                [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
            # Additive padding mask: 0 where the key position is kept,
            # -inf where key_padding_mask is non-zero; broadcast over
            # heads and query positions.
            mask = paddle.cast(
                paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2),
                'float32')
            neg_inf = paddle.full(
                shape=paddle.shape(mask),
                dtype='float32',
                fill_value=float('-inf'))
            y = paddle.where(mask == 0., mask, neg_inf)
            attn_output_weights += y
        attn_output_weights = F.softmax(
            attn_output_weights.astype('float32'),
            axis=-1,
            dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training)

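        # Weighted sum over values, then merge heads back to
        # [target length, batch size, embed_dim].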
        attn_output = paddle.matmul(attn_output_weights, v)
        attn_output = paddle.reshape(
            paddle.transpose(attn_output, [2, 0, 1, 3]),
            [q_shape[0], q_shape[1], self.embed_dim])
        attn_output = self.out_proj(attn_output)

        return attn_output

    def _in_proj_q(self, query):
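        # 1x1 convolution as the input projection:
        # [T, B, C] -> [B, C, T, 1] -> conv1 -> [T, B, C].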
        query = paddle.transpose(query, [1, 2, 0])
        query = paddle.unsqueeze(query, axis=2)
        res = self.conv1(query)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res

    def _in_proj_k(self, key):
        key = paddle.transpose(key, [1, 2, 0])
        key = paddle.unsqueeze(key, axis=2)
        res = self.conv2(key)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res

    def _in_proj_v(self, value):
        value = paddle.transpose(value, [1, 2, 0])
        value = paddle.unsqueeze(value, axis=2)
        res = self.conv3(value)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res
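

# A minimal usage sketch for illustration (not part of the original module).
# It assumes dynamic-graph mode; embed_dim=512, num_heads=8 and the sequence
# and batch sizes below are arbitrary example values.
if __name__ == "__main__":
    paddle.seed(0)
    attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1)
    attn.eval()  # disable dropout for a deterministic forward pass

    seq_len, batch_size = 5, 2
    # Inputs follow the forward() docstring: [length, batch size, embed dim].
    query = paddle.randn([seq_len, batch_size, 512])
    key = paddle.randn([seq_len, batch_size, 512])
    value = paddle.randn([seq_len, batch_size, 512])

    # Optional causal mask: -inf above the diagonal blocks attention to
    # future positions; it is broadcast over batch and heads in forward().
    causal_mask = paddle.triu(
        paddle.full([seq_len, seq_len], float('-inf')), diagonal=1)

    out = attn(query, key, value, attn_mask=causal_mask)
    print(out.shape)  # [5, 2, 512]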