attention.py 4.0 KB
Newer Older
C
chenfeiyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
import numpy as np
from collections import namedtuple
from paddle import fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as F

from parakeet.modules.weight_norm import Linear
# Attention window expressed as offsets from the last attended encoder
# position: during decoding, Attention keeps positions in
# [last + backward, last + ahead) and masks everything else.
WindowRange = namedtuple("WindowRange", "backward ahead")


def _locality_mask(shape, last_attended, window_range):
    """Build a float32 mask that is 0 inside the attention window, 1 outside.

    Args:
        shape (Sequence[int]): shape of the attention scores,
            (B, T_dec, T_enc).
        last_attended (int or array-like of shape (B,)): position(s) that
            received the most attention at the previous step. A scalar is
            shared by every example in the batch.
        window_range (WindowRange): offsets relative to ``last_attended``.
            Positions ``[last_attended + backward, last_attended + ahead)``,
            clipped to ``[0, T_enc]``, stay unmasked.

    Returns:
        np.ndarray: float32 array of ``shape``, 0 inside each example's
        window and 1 elsewhere. An empty window masks every position.
    """
    mask = np.ones(shape=shape, dtype=np.float32)
    backward, ahead = window_range
    batch_size, length = shape[0], shape[-1]
    # A scalar position is broadcast to every example; a (B,) array gives
    # one window per example. A mismatched length raises ValueError.
    positions = np.broadcast_to(np.asarray(last_attended), (batch_size, ))
    for i, position in enumerate(positions):
        start = max(int(position) + backward, 0)
        end = min(int(position) + ahead, length)
        mask[i, :, start:end] = 0.
    return mask


class Attention(dg.Layer):
    """Dot-product attention over encoder keys/values with a residual.

    Queries are projected to ``embed_dim``. Keys and values are optionally
    projected. Scores are the dot products between queries and keys, with
    optional masking. The softmax weights pool the values. The pooled
    context is projected back to ``query_dim``, added to the input query,
    and the sum is scaled by sqrt(0.5).

    Args:
        query_dim (int): channel size of the query.
        embed_dim (int): channel size of the keys and values.
        dropout (float, optional): dropout probability applied to the
            attention weights during training. Defaults to 0.0.
        window_range (WindowRange, optional): window around
            ``last_attended`` that may receive attention at decoding.
            Defaults to WindowRange(-1, 3).
        key_projection (bool, optional): whether to apply a linear
            projection to the keys. Defaults to True.
        value_projection (bool, optional): whether to apply a linear
            projection to the values. Defaults to True.
    """

    def __init__(self,
                 query_dim,
                 embed_dim,
                 dropout=0.0,
                 window_range=WindowRange(-1, 3),
                 key_projection=True,
                 value_projection=True):
        super(Attention, self).__init__()
        self.query_proj = Linear(query_dim, embed_dim)
        # The key/value projections only exist when requested. The flags
        # stored below tell forward() whether to apply them.
        if key_projection:
            self.key_proj = Linear(embed_dim, embed_dim)
        if value_projection:
            self.value_proj = Linear(embed_dim, embed_dim)
        self.out_proj = Linear(embed_dim, query_dim)

        self.key_projection = key_projection
        self.value_projection = value_projection
        self.dropout = dropout
        self.window_range = window_range

    def forward(self, query, encoder_out, mask=None, last_attended=None):
        """
        Compute pooled context representation and alignment scores.

        Args:
            query (Variable): shape(B, T_dec, C_q), the query tensor,
                where C_q means the channel of query.
            encoder_out (Tuple(Variable, Variable)): (keys, values), where
                keys (Variable): shape(B, T_enc, C_emb), the key
                    representation from an encoder, where C_emb means
                    text embedding size.
                values (Variable): shape(B, T_enc, C_emb), the value
                    representation from an encoder, where C_emb means
                    text embedding size.
            mask (Variable, optional): shape(B, T_enc), mask generated with
                valid text lengths. It is multiplied by -1e30 and added to
                the scores. Positions where the mask is 1 are therefore
                suppressed, and positions where it is 0 are kept.
            last_attended (int or array-like, optional): the position that
                received the most attention at the last time step. Pass
                either one int shared by the whole batch or one int per
                example (shape (B,)). When given, only encoder positions in
                [last_attended + window_range.backward,
                 last_attended + window_range.ahead) can be attended.
                This is only used at decoding.

        Returns:
            x (Variable): shape(B, T_dec, C_q), the context representation
                pooled from attention mechanism.
            attn_scores (Variable): shape(B, T_dec, T_enc), the alignment
                tensor (softmax weights, taken before dropout), where T_dec
                means the number of decoder time steps and T_enc means the
                number of encoder time steps.
        """
        keys, values = encoder_out
        residual = query
        if self.value_projection:
            values = self.value_proj(values)
        if self.key_projection:
            keys = self.key_proj(keys)
        x = self.query_proj(query)

        # Dot-product scores between queries and keys: (B, T_dec, T_enc).
        x = F.matmul(x, keys, transpose_y=True)

        # This large negative constant is added to scores that must not be
        # attended, so they get ~0 weight after the softmax.
        neg_inf = -1.e30
        if mask is not None:
            # (B, T_enc) -> (B, 1, T_enc), broadcast over decoder steps.
            neg_inf_mask = F.scale(F.unsqueeze(mask, [1]), neg_inf)
            x += neg_inf_mask

        # If last_attended is provided, only a window around it may be
        # attended. This enforces monotonic attention.
        if last_attended is not None:
            locality_mask = _locality_mask(x.shape, last_attended,
                                           self.window_range)
            locality_mask = dg.to_variable(locality_mask)
            neg_inf_mask = F.scale(locality_mask, neg_inf)
            x += neg_inf_mask

        # Normalize over the last axis (encoder positions).
        x = F.softmax(x)
        attn_scores = x
        # fluid.layers.dropout defaults to is_test=False. Without this
        # argument, dropout would stay active after eval(). getattr keeps
        # the previous always-on behaviour for Layer versions that lack a
        # `training` attribute.
        x = F.dropout(x,
                      self.dropout,
                      is_test=not getattr(self, "training", True),
                      dropout_implementation="upscale_in_train")
        x = F.matmul(x, values)
        encoder_length = keys.shape[1]
        # Scale the pooled context by T_enc * sqrt(1 / T_enc), i.e.
        # sqrt(T_enc). NOTE(review): the original author questioned this
        # factor. It is kept unchanged so trained models behave the same.
        x = F.scale(x, encoder_length * np.sqrt(1.0 / encoder_length))
        x = self.out_proj(x)
        # Residual connection scaled by sqrt(0.5), presumably to keep the
        # variance of the sum comparable to that of its inputs.
        x = F.scale((x + residual), np.sqrt(0.5))
        return x, attn_scores