multihead_attention.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers


class Linear(dg.Layer):
    def __init__(self,
                 in_features,
                 out_features,
                 is_bias=True,
                 dtype="float32"):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype
        self.weight = fluid.ParamAttr(
            initializer=fluid.initializer.XavierInitializer())
        self.bias = is_bias

        if is_bias is not False:
            k = math.sqrt(1 / in_features)
            self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
                low=-k, high=k))

        self.linear = dg.Linear(
            in_features,
            out_features,
            param_attr=self.weight,
            bias_attr=self.bias, )

    def forward(self, x):
        x = self.linear(x)
        return x


class ScaledDotProductAttention(dg.Layer):
    def __init__(self, d_key):
        """Scaled dot product attention module.

        Args:
            d_key (int): the dim of key in multihead attention.
        """
        super(ScaledDotProductAttention, self).__init__()

        self.d_key = d_key

    # Note: this mask convention differs from PyTorch: the mask here is added
    # directly to the raw attention scores (use large negative values to mask out positions).
    def forward(self,
                key,
                value,
                query,
                mask=None,
                query_mask=None,
                dropout=0.1):
        """
        Compute scaled dot product attention.

        Args:
            key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention.
            value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention.
            query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention.
            mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of key. Defaults to None.
            query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of query. Defaults to None.
            dropout (float, optional): the probability of dropout. Defaults to 0.1.

        Returns:
            result (Variable): shape(B, T, C), the weighted sum of value (the attention result).
            attention (Variable): shape(B, T_q, T_k), the attention weights.
        """
        # Compute the scaled dot-product scores: Q K^T / sqrt(d_key)
        attention = layers.matmul(
            query, key, transpose_y=True,
            alpha=self.d_key**-0.5)  # transpose_y transposes the last two dims of key

        # Mask key to ignore padding
        if mask is not None:
            attention = attention + mask
        attention = layers.softmax(attention)
        attention = layers.dropout(
            attention, dropout, dropout_implementation='upscale_in_train')

        # Mask query to ignore padding
        if query_mask is not None:
            attention = attention * query_mask

        result = layers.matmul(attention, value)
        return result, attention


class MultiheadAttention(dg.Layer):
    def __init__(self,
                 num_hidden,
                 d_k,
                 d_q,
                 num_head=4,
                 is_bias=False,
                 dropout=0.1,
                 is_concat=True):
        """Multihead attention module.

        Args:
            num_hidden (int): the size of the hidden feature dimension.
            d_k (int): the dimension of key (and value) for each head.
            d_q (int): the dimension of query for each head.
            num_head (int, optional): the number of attention heads. Defaults to 4.
            is_bias (bool, optional): whether the linear projections have a bias. Defaults to False.
            dropout (float, optional): the probability of dropout. Defaults to 0.1.
            is_concat (bool, optional): whether to concatenate the query with the attention result
                before the output projection. Defaults to True.
        """
        super(MultiheadAttention, self).__init__()
        self.num_hidden = num_hidden
        self.num_head = num_head
        self.d_k = d_k
        self.d_q = d_q
        self.dropout = dropout
        self.is_concat = is_concat

        self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
        self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
        self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)

        self.scal_attn = ScaledDotProductAttention(d_k)

        if self.is_concat:
            self.fc = Linear(num_head * d_q * 2, num_hidden)
        else:
            self.fc = Linear(num_head * d_q, num_hidden)

        self.layer_norm = dg.LayerNorm(num_hidden)

    def forward(self, key, value, query_input, mask=None, query_mask=None):
        """
        Compute multihead attention.

        Args:
            key (Variable): shape(B, T, C), dtype float32, the input key of attention.
            value (Variable): shape(B, T, C), dtype float32, the input value of attention.
            query_input (Variable): shape(B, T, C), dtype float32, the input query of attention.
            mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of key. Defaults to None.
            query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of query. Defaults to None.

        Returns:
            result (Variable): shape(B, T, C), the result of multihead attention.
            attention (Variable): shape(num_head * B, T_query, T_key), the attention weights of each head.
        """

        batch_size = key.shape[0]
        seq_len_key = key.shape[1]
        seq_len_query = query_input.shape[1]

        # Make multihead attention
        key = layers.reshape(
            self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
        value = layers.reshape(
            self.value(value),
            [batch_size, seq_len_key, self.num_head, self.d_k])
        query = layers.reshape(
            self.query(query_input),
            [batch_size, seq_len_query, self.num_head, self.d_q])

        key = layers.reshape(
            layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
        value = layers.reshape(
            layers.transpose(value, [2, 0, 1, 3]),
            [-1, seq_len_key, self.d_k])
        query = layers.reshape(
            layers.transpose(query, [2, 0, 1, 3]),
            [-1, seq_len_query, self.d_q])
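        # key, value and query now have shape (num_head * B, T, d): the heads are
        # folded into the batch dimension so a single batched matmul covers them all.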

        result, attention = self.scal_attn(
            key, value, query, mask=mask, query_mask=query_mask)

        # concat all multihead result
        result = layers.reshape(
            result, [self.num_head, batch_size, seq_len_query, self.d_q])
        result = layers.reshape(
            layers.transpose(result, [1, 2, 0, 3]),
            [batch_size, seq_len_query, -1])
        if self.is_concat:
            result = layers.concat([query_input, result], axis=-1)
        result = layers.dropout(
            self.fc(result),
            self.dropout,
            dropout_implementation='upscale_in_train')
        result = result + query_input

        result = self.layer_norm(result)
        return result, attention
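

# A minimal usage sketch, assuming the Paddle 1.x dygraph APIs imported above;
# it pushes random tensors through MultiheadAttention to illustrate the expected
# input/output shapes. The sizes below are arbitrary example values.
if __name__ == "__main__":
    batch_size, seq_len, num_hidden = 2, 16, 128

    with dg.guard():
        net = MultiheadAttention(num_hidden, d_k=32, d_q=32, num_head=4)
        x = dg.to_variable(
            np.random.randn(batch_size, seq_len, num_hidden).astype("float32"))
        # Self-attention: key, value and query all come from the same tensor.
        result, attention = net(x, x, x)
        print(result.shape)     # [2, 16, 128], same as the input
        print(attention.shape)  # [num_head * batch_size, seq_len, seq_len] = [8, 16, 16]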