import math
import numpy as np
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
from parakeet.modules.layers import Linear

class ScaledDotProductAttention(dg.Layer):
    def __init__(self, d_key):
        super(ScaledDotProductAttention, self).__init__()

        self.d_key = d_key
    
    # Note: the mask convention used here differs from the PyTorch implementation
    def forward(self, key, value, query, mask=None, query_mask=None, dropout=0.1):
        """
        Scaled Dot Product Attention.
        
        Args:
            key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
            value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
            query (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
            mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
            query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
            dropout (float): The dropout probability.
        Returns:
            result (Variable), Shape(B, T, C), the result of scaled dot product attention.
            attention (Variable), Shape(B, len_q, len_k), the attention weights.
        """
        # Compute attention score
        attention = layers.matmul(query, key, transpose_y=True)  # transpose_y transposes the last two dims of key
        attention = attention / math.sqrt(self.d_key)

        # Mask key to ignore padding
        if mask is not None:
            attention = attention * mask
            # Set masked (zero) positions to a large negative value before the softmax
            mask = (mask == 0).astype(np.float32) * (-2 ** 32 + 1)
            attention = attention + mask

        # Normalize the scores over the key dimension, then apply dropout
        attention = layers.softmax(attention)
        attention = layers.dropout(attention, dropout)

        # Mask query to ignore padding
        if query_mask is not None:
            attention = attention * query_mask
        
        result = layers.matmul(attention, value)
        return result, attention
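
# A minimal usage sketch for ScaledDotProductAttention (illustrative only, not
# part of the original module). With assumed shapes B=2, T=16 and d_key=64 the
# layer returns the attended values of shape (2, 16, 64) and the attention
# weights of shape (2, 16, 16). Note the (key, value, query) argument order.
#
#     with dg.guard():
#         attn_layer = ScaledDotProductAttention(d_key=64)
#         q = dg.to_variable(np.random.randn(2, 16, 64).astype(np.float32))
#         k = dg.to_variable(np.random.randn(2, 16, 64).astype(np.float32))
#         v = dg.to_variable(np.random.randn(2, 16, 64).astype(np.float32))
#         result, weights = attn_layer(k, v, q)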

class MultiheadAttention(dg.Layer):
    def __init__(self, num_hidden, d_k, d_q, num_head=4, is_bias=False, dropout=0.1, is_concat=True):
        super(MultiheadAttention, self).__init__()
        self.num_hidden = num_hidden
        self.num_head = num_head
        self.d_k = d_k
        self.d_q = d_q
        self.dropout = dropout
        self.is_concat = is_concat

        # Per-head linear projections for key, value, and query
        self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
        self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
        self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)

        self.scal_attn = ScaledDotProductAttention(d_k)

        # Output projection; when is_concat, the input is the query concatenated with the multihead output
        if self.is_concat:
            self.fc = Linear(num_head * d_q * 2, num_hidden)
        else:
            self.fc = Linear(num_head * d_q, num_hidden)

        self.layer_norm = dg.LayerNorm(num_hidden)

    def forward(self, key, value, query_input, mask=None, query_mask=None):
        """
        Multihead Attention.
        
        Args:
            key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
            value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
            query_input (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
            mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
            query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
        Returns:
            result (Variable), Shape(B, T, C), the result of multihead attention.
            attention (Variable), Shape(num_head * B, len_q, len_k), the attention weights.
        """
        batch_size = key.shape[0]
        seq_len_key = key.shape[1]
        seq_len_query = query_input.shape[1]

        # Repeat the masks num_head times, one copy per attention head
        if query_mask is not None:
            query_mask = layers.expand(query_mask, [self.num_head, 1, seq_len_key])
        if mask is not None:
            mask = layers.expand(mask, (self.num_head, 1, 1))
        
        # Project the inputs and split them into heads
        # key & value shape: (batch_size, seq_len, feature), where feature = num_head * d_k
        key = layers.reshape(self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
        value = layers.reshape(self.value(value), [batch_size, seq_len_key, self.num_head, self.d_k])
        query = layers.reshape(self.query(query_input), [batch_size, seq_len_query, self.num_head, self.d_q])

        # Fold the head dimension into the batch: (num_head * batch_size, seq_len, d)
        key = layers.reshape(layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
        value = layers.reshape(layers.transpose(value, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
        query = layers.reshape(layers.transpose(query, [2, 0, 1, 3]), [-1, seq_len_query, self.d_q])

        result, attention = self.scal_attn(key, value, query, mask=mask, query_mask=query_mask)
        
        # Concatenate the results of all heads back along the feature dimension
        result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q])
        result = layers.reshape(layers.transpose(result, [1, 2, 0, 3]), [batch_size, seq_len_query, -1])
        if self.is_concat:
            result = layers.concat([query_input, result], axis=-1)
        # Final projection, dropout, residual connection, and layer normalization
        result = layers.dropout(self.fc(result), self.dropout)
        result = result + query_input

        result = self.layer_norm(result)
        return result, attention
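

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): the
    # hyperparameters and shapes below are assumed for illustration only, with
    # num_hidden == num_head * d_q so the concat branch matches the fc input size.
    with dg.guard():
        layer = MultiheadAttention(num_hidden=256, d_k=64, d_q=64, num_head=4)
        x = dg.to_variable(np.random.randn(2, 16, 256).astype(np.float32))
        # Self-attention: the same tensor serves as key, value, and query input
        out, attn = layer(x, x, x)
        print(out.shape, attn.shape)  # expected: [2, 16, 256] and [8, 16, 16]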