diff --git a/labml_nn/transformers/__init__.py b/labml_nn/transformers/__init__.py index 9e18fb04f1b34abbc1216aa5c257ee6f7bb790f2..52ac31d6c63b083f59664a559dbc8297d547629f 100644 --- a/labml_nn/transformers/__init__.py +++ b/labml_nn/transformers/__init__.py @@ -1,3 +1,12 @@ +""" +# Transformers + +* [Multi-head attention](mha.html) +* [Relative multi-head attention](relative_mha.html) +* [Transformer models](models.html) +* [Fixed positional encoding](positional_encoding.html) +""" + from .configs import TransformerConfigs from .models import TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder from .mha import MultiHeadAttention diff --git a/labml_nn/transformers/mha.py b/labml_nn/transformers/mha.py index cb6aecdadf3d9532d8a901bd74b57ad216861ec3..4f635e7967a22915a32f18c1cb6126df2bffce71 100644 --- a/labml_nn/transformers/mha.py +++ b/labml_nn/transformers/mha.py @@ -1,3 +1,9 @@ +""" +# Multi-Headed Attention + +The implementation is inspired by the [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html) +""" + import math from typing import Optional @@ -10,6 +16,10 @@ from labml_helpers.module import Module class PrepareForMultiHeadAttention(Module): + """ + This module does a linear transformation and splits the vector into the given + number of heads for multi-head attention. + """ def __init__(self, d_model: int, heads: int, d_k: int, bias: bool): super().__init__() self.linear = nn.Linear(d_model, heads * d_k, bias=bias) @@ -17,22 +27,27 @@ class PrepareForMultiHeadAttention(Module): self.d_k = d_k def __call__(self, x: torch.Tensor): + # Input has shape `[seq_len, batch_size, d_model]` seq_len, batch_size, _ = x.shape x = self.linear(x) x = x.view(seq_len, batch_size, self.heads, self.d_k) + # Output has shape `[seq_len, batch_size, heads, d_k]` return x class MultiHeadAttention(Module): def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool=True): """ - ### Multi-Head Attention + ## Multi-Head Attention Module This computes multi-headed attention for given `query`, `key` and `value` vectors. `heads` is the number of heads. `d_model` is the number of features in the `query`, `key` and `value` vectors. + + $$Attention(Q, K, V) = softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$ + """ super().__init__() @@ -54,8 +69,12 @@ class MultiHeadAttention(Module): def get_scores(self, query: torch.Tensor, key: torch.Tensor): """ - ### Calculate scores between queries and keys + ### Calculate scores between queries and keys. + + This method can be overridden for other variations like relative attention. """ + + # Calculate $Q K^T$ return torch.einsum('ibhd,jbhd->ijbh', query, key) def __call__(self, *, @@ -69,10 +88,10 @@ class MultiHeadAttention(Module): if mask is not None: # `mask` has shape `[seq_len, seq_len, batch_size]`, # where first dimension is the query dimension. - # If the query dimension is equal to $`$ it will be broadcasted to match + # If the query dimension is equal to $1$ it will be broadcasted assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1] - # Same mask applied to all `h` heads. + # Same mask applied to all heads.
mask = mask.unsqueeze(-1) # Prepare `query`, `key` and `value` for attention computation @@ -81,17 +100,18 @@ class MultiHeadAttention(Module): key = self.key(key) value = self.value(value) - # Compute attention scores + # Compute attention scores $Q K^T$ + # Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]` scores = self.get_scores(query, key) - # Scale scores + # Scale scores $\frac{Q K^T}{\sqrt{d_k}}$ scores *= self.scale # Apply mask if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) - # $softmax$ attention + # $softmax$ attention $softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$ attn = F.softmax(scores, dim=1) # Save attentions if debugging @@ -100,7 +120,7 @@ class MultiHeadAttention(Module): # Apply dropout attn = self.dropout(attn) - # Calculate the attention results + # Multiply by values $softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$ x = torch.einsum("ijbh,jbhd->ibhd", attn, value) # Save attentions for any other calculations diff --git a/labml_nn/transformers/relative_mha.py b/labml_nn/transformers/relative_mha.py index 042db24d192757e515dc9d8f09bf79bc3e897aba..1f6f8bb602d858b42c5be4a53dc37e7f8428cdc2 100644 --- a/labml_nn/transformers/relative_mha.py +++ b/labml_nn/transformers/relative_mha.py @@ -1,29 +1,51 @@ """ -Implementation of "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" -https://arxiv.org/abs/1901.02860 +# Relative Multi-head Attention + +This is an implementation of +[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) """ import torch from torch import nn from labml.logger import inspect -from .mha import MultiHeadAttention +from labml_nn.transformers.mha import MultiHeadAttention + +def shift_right(x: torch.Tensor): + """ + This method shifts the $i^{th}$ row of a matrix by $i$ columns. -def relative_shift(x: torch.Tensor): + If the input is `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`, the shifted + result would be `[[1, 2, 3], [0, 4, 5], [6, 0, 7]]`. + *Ideally we should mask out the lower triangle, but it's okay for our purpose*. + """ + + # Concatenate a column of zeros zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:]) x_padded = torch.cat([x, zero_pad], dim=1) + # Reshape and remove the excess elements from the end x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:]) - x = x_padded[:-1].view_as(x) return x class RelativeMultiHeadAttention(MultiHeadAttention): + """ + ## Relative Multi-Head Attention Module + + We override the [Multi-Head Attention](mha.html) module so that we only need to + write the `get_scores` method. + """ + def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1): + # The linear transformations don't need a bias since we take care of it when + # calculating scores. + # However, having a bias for `value` might make sense. super().__init__(heads, d_model, dropout_prob, False) + self.P = 2 ** 12 self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True) @@ -31,27 +53,63 @@ class RelativeMultiHeadAttention(MultiHeadAttention): self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True) def get_scores(self, query: torch.Tensor, key: torch.Tensor): + """ + With absolute attention, + + \begin{align} + A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\ + &= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j + \end{align} + + where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of + original embeddings and positional encodings.
+ + The authors reason that the attention to a given key should be the same regardless of + the position of the query. Hence, they replace $V_i^T K_j$ with a constant $v^T K_j$. + 🤔 It may be worthwhile to test without this assumption. + + For the second and fourth terms, relative positional encodings are introduced. + So $Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$. + + \begin{align} + A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j} + \end{align} + + """ + + # $R_{i-j}$ pre-shift key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]] + # $S_{i-j}$ pre-shift key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]] + # $v^T$ + query_pos_bias = self.query_pos_bias[None, None, :, :] - ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key) + # $Q_i^T K_j + v^T K_j$ + ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key) + # $Q_i^T R_{i - j}$ pre-shift b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb) + # $S_{i-j}$ pre-shift d = key_pos_bias[None, :, None, :] - bd = relative_shift(b + d) + # $Q_i^T R_{i - j} + S_{i-j}$ + bd = shift_right(b + d) bd = bd[:, -key.shape[0]:] return ac + bd -def _test_relative_shift(): +def _test_shift_right(): + x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + inspect(x) + inspect(shift_right(x)) + x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1) inspect(x[:, :, 0, 0]) - inspect(relative_shift(x)[:, :, 0, 0]) + inspect(shift_right(x)[:, :, 0, 0]) x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1) inspect(x[:, :, 0, 0]) - inspect(relative_shift(x)[:, :, 0, 0]) + inspect(shift_right(x)[:, :, 0, 0]) if __name__ == '__main__': - _test_relative_shift() + _test_shift_right()
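For reference, a minimal usage sketch of the two attention modules touched by this diff. It assumes the `__call__` signature (cut off by the hunk above) accepts `mask` as a keyword argument, and that the attention output is projected back to `[seq_len, batch_size, d_model]` in the part of `__call__` not shown here.

import torch

from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.transformers.relative_mha import RelativeMultiHeadAttention

seq_len, batch_size, d_model, heads = 10, 2, 512, 8

# Self-attention: the same `[seq_len, batch_size, d_model]` tensor is used
# as `query`, `key` and `value`
x = torch.randn(seq_len, batch_size, d_model)

# Causal mask of shape `[seq_len, seq_len, batch_size]`; positions where the
# mask is zero are filled with -1e9 before the softmax
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1).expand(-1, -1, batch_size)

mha = MultiHeadAttention(heads=heads, d_model=d_model)
out = mha(query=x, key=x, value=x, mask=mask)
print(out.shape)  # expected: torch.Size([10, 2, 512])

relative_mha = RelativeMultiHeadAttention(heads=heads, d_model=d_model)
out = relative_mha(query=x, key=x, value=x, mask=mask)
print(out.shape)  # expected: torch.Size([10, 2, 512])

Both modules share the same interface, since `RelativeMultiHeadAttention` only overrides `get_scores`.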