diff --git a/labml_nn/transformers/__init__.py b/labml_nn/transformers/__init__.py
index 7bf6c9cfa0e9828b0bd91d2a0b080e17b7574559..52ac31d6c63b083f59664a559dbc8297d547629f 100644
--- a/labml_nn/transformers/__init__.py
+++ b/labml_nn/transformers/__init__.py
@@ -1,6 +1,4 @@
"""
-Star
-
# Transformers
* [Multi-head attention](mha.html)
diff --git a/labml_nn/transformers/mha.py b/labml_nn/transformers/mha.py
index 863ba3515131a9f633123c4a9c69ecda5e235eaf..4f635e7967a22915a32f18c1cb6126df2bffce71 100644
--- a/labml_nn/transformers/mha.py
+++ b/labml_nn/transformers/mha.py
@@ -1,6 +1,4 @@
"""
-Star
-
# Multi-Headed Attention
The implementation is inspired from [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html)
diff --git a/labml_nn/transformers/models.py b/labml_nn/transformers/models.py
index e9aaa81a109820f111e520e9f9dac9534fa448ac..361477b4bf6179b0eb2704e8190fb18d896984f7 100644
--- a/labml_nn/transformers/models.py
+++ b/labml_nn/transformers/models.py
@@ -11,6 +11,9 @@ from .positional_encoding import get_positional_encoding
class EmbeddingsWithPositionalEncoding(Module):
+ """
+ ## Embed tokenas and add [fixed positional encoding](positional_encoding.html)
+ """
def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
super().__init__()
self.linear = nn.Embedding(n_vocab, d_model)
@@ -23,6 +26,9 @@ class EmbeddingsWithPositionalEncoding(Module):
class EmbeddingsWithLearnedPositionalEncoding(Module):
+ """
+ ## Embed tokenas and add parameterized positional encodings
+ """
def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
super().__init__()
self.linear = nn.Embedding(n_vocab, d_model)
@@ -35,6 +41,9 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
class FeedForward(Module):
+ """
+ ## Position-wise feed-forward network with hidden layer
+ """
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.layer1 = nn.Linear(d_model, d_ff)
@@ -49,6 +58,20 @@ class FeedForward(Module):
class TransformerLayer(Module):
+ """
+ ## Transformer Layer
+
+ This can act as a encoder layer or a decoder layer.
+
+ 🗒 Some implementations, including the paper seem to have differences
+ in where the layer-normalization is done.
+ Here we do a layer normalization before attention and feed-forward networks,
+ and add the original residual vectors.
+ Alternative is to do a layer normalzation after adding the residuals.
+ But we found this to be less stable when training.
+ We found a detailed discussion about this in paper
+ [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745).
+ """
def __init__(self, *,
d_model: int,
self_attn: MultiHeadAttention,
@@ -71,47 +94,77 @@ class TransformerLayer(Module):
mask: torch.Tensor,
src: torch.Tensor = None,
src_mask: torch.Tensor = None):
+ # Normalize the vectors before doing self attention
z = self.norm_self_attn(x)
- attn_self = self.self_attn(query=z, key=z, value=z, mask=mask)
- x = x + self.dropout(attn_self)
-
+ # Run through self attention, i.e. keys and values are from self
+ self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+ # Add the self attention results
+ x = x + self.dropout(self_attn)
+
+ # If a source is provided, get results from attention to source.
+ # This is when you have a decoder layer that pays attention to
+ # encoder outputs
if src is not None:
+ # Normalize vectors
z = self.norm_src_attn(x)
+ # Attention to source. i.e. keys and values are from source
attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
+ # Add the source attention results
x = x + self.dropout(attn_src)
+ # Normalize for feed-forward
z = self.norm_ff(x)
+ # Pass through the feed-forward network
ff = self.feed_forward(z)
+ # Add the feed-forward results back
x = x + self.dropout(ff)
return x
class Encoder(Module):
+ """
+ ## Transformer Encoder
+ """
def __init__(self, layer: TransformerLayer, n_layers: int):
super().__init__()
+ # Make copies of the transformer layer
self.layers = clone_module_list(layer, n_layers)
self.norm = nn.LayerNorm([layer.size])
def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+ # Run through each transformer layer
for layer in self.layers:
x = layer(x=x, mask=mask)
+ # Finally, normalize the vectors
return self.norm(x)
class Decoder(Module):
+ """
+ ## Transformer Decoder
+ """
def __init__(self, layer: TransformerLayer, n_layers: int):
super().__init__()
+ # Make copies of the transformer layer
self.layers = clone_module_list(layer, n_layers)
self.norm = nn.LayerNorm([layer.size])
- def __call__(self, x, memory, src_mask, tgt_mask):
+ def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+ # Run through each transformer layer
for layer in self.layers:
x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
+ # Finally, normalize the vectors
return self.norm(x)
class Generator(Module):
+ """
+ ## Generator
+
+ This predicts the tokens and gives the lof softmaxes of those.
+ You don't need this if you are using `nn.CrossEntropyLoss`.
+ """
def __init__(self, n_vocab: int, d_model: int):
super().__init__()
self.projection = nn.Linear(d_model, n_vocab)
@@ -121,6 +174,9 @@ class Generator(Module):
class EncoderDecoder(Module):
+ """
+ ## Combined Encoder-Decoder
+ """
def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
super().__init__()
self.encoder = encoder
@@ -135,10 +191,11 @@ class EncoderDecoder(Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
- def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor,
- tgt_mask: torch.Tensor):
- return self.decode(self.encode(src, src_mask), src_mask,
- tgt, tgt_mask)
+ def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+ # Runs the source through encoder
+ enc = self.encode(src, src_mask)
+ # Run encodings and targets through decoder
+ return self.decode(enc, src_mask, tgt, tgt_mask)
def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
return self.encoder(self.src_embed(src), src_mask)
diff --git a/labml_nn/transformers/positional_encoding.py b/labml_nn/transformers/positional_encoding.py
index 731fc0084e76c05bcfe7cabc736695e8e2ddb497..82957b46b45da3ba689a66e5cea836acd71ede64 100644
--- a/labml_nn/transformers/positional_encoding.py
+++ b/labml_nn/transformers/positional_encoding.py
@@ -1,3 +1,18 @@
+"""
+# Fixed Positional Encodings
+
+The positional encoding encodes the position along the sequence into
+ a vector of size `d_model`.
+
+\begin{align}
+PE_{p,2i} &= sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg) \\
+PE_{p,2i + 1} &= cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)
+\end{align}
+
+Where $1 \leq 2i, 2i + 1 \leq d_{model}$ are the feature indexes in the encoding,
+and $p$ is the position.
+"""
+
import math
import matplotlib.pyplot as plt
@@ -23,12 +38,20 @@ class PositionalEncoding(Module):
def get_positional_encoding(d_model: int, max_len: int = 5000):
+ # Empty encodings vectors
encodings = torch.zeros(max_len, d_model)
+ # Position indexes
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+ # $2 * i$
two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
+ # $10000^{\frac{2i}{d_{model}}$
div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
+ # $PE_{p,2i} = sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
encodings[:, 0::2] = torch.sin(position * div_term)
+ # $PE_{p,2i + 1} = cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
encodings[:, 1::2] = torch.cos(position * div_term)
+
+ # Add batch dimension
encodings = encodings.unsqueeze(1).requires_grad_(False)
return encodings
diff --git a/labml_nn/transformers/relative_mha.py b/labml_nn/transformers/relative_mha.py
index 43d61aeff593f3067bcfb69dd841567ed422e362..1f6f8bb602d858b42c5be4a53dc37e7f8428cdc2 100644
--- a/labml_nn/transformers/relative_mha.py
+++ b/labml_nn/transformers/relative_mha.py
@@ -1,6 +1,4 @@
"""
-Star
-
# Relative Multi-head Attention
This is an implementation of