Commit 3e240cfe authored by Varuna Jayasiri

annotations

Parent 89ca5604
"""
# Transformers
* [Multi-head attention](mha.html)
* [Relative multi-head attention](relative_mha.html)
* [Transformer models](models.html)
* [Fixed positional encoding](positional_encoding.html)
"""
from .configs import TransformerConfigs
from .models import TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder
from .mha import MultiHeadAttention
......
"""
# Multi-Headed Attention
The implementation is inspired by [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html)
"""
import math
from typing import Optional
......@@ -10,6 +16,10 @@ from labml_helpers.module import Module
class PrepareForMultiHeadAttention(Module):
"""
This module does a linear transformation and splits the vector into the given
number of heads for multi-head attention.
"""
def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
super().__init__()
self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
......@@ -17,22 +27,27 @@ class PrepareForMultiHeadAttention(Module):
self.d_k = d_k
def __call__(self, x: torch.Tensor):
# Input has shape `[seq_len, batch_size, d_model]`
seq_len, batch_size, _ = x.shape
x = self.linear(x)
x = x.view(seq_len, batch_size, self.heads, self.d_k)
# Output has shape `[seq_len, batch_size, heads, d_k]`
return x
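# A minimal shape-check sketch for `PrepareForMultiHeadAttention`; the sizes
# (`d_model=512`, `heads=8`, `d_k=64`) and the helper name are illustrative only.
def _test_prepare_for_mha():
    prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
    # `[seq_len, batch_size, d_model]`
    x = torch.randn(10, 32, 512)
    # Each position is projected and split into `heads` vectors of size `d_k`
    assert prep(x).shape == (10, 32, 8, 64)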
class MultiHeadAttention(Module):
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool=True):
"""
## Multi-Head Attention Module
This computes multi-headed attention for the given `query`, `key` and `value` vectors.
`heads` is the number of heads.
`d_model` is the number of features in the `query`, `key` and `value` vectors.
$$Attention(Q, K, V) = softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$$
"""
super().__init__()
......@@ -54,8 +69,12 @@ class MultiHeadAttention(Module):
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
"""
### Calculate scores between queries and keys
This method can be overridden for other variations like relative attention.
"""
# Calculate $Q K^T$
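# The einsum takes the dot product along the feature dimension `d` for every pair of
# query position `i` and key position `j`, separately for each batch `b` and head `h`,
# giving a tensor of shape `[seq_len, seq_len, batch_size, heads]`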
return torch.einsum('ibhd,jbhd->ijbh', query, key)
def __call__(self, *,
......@@ -69,10 +88,10 @@ class MultiHeadAttention(Module):
if mask is not None:
# `mask` has shape `[seq_len, seq_len, batch_size]`,
# where first dimension is the query dimension.
# If the query dimension is equal to $1$ it will be broadcast
assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
# Same mask applied to all heads.
mask = mask.unsqueeze(-1)
# Prepare `query`, `key` and `value` for attention computation
......@@ -81,17 +100,18 @@ class MultiHeadAttention(Module):
key = self.key(key)
value = self.value(value)
# Compute attention scores $Q K^T$
# Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
scores = self.get_scores(query, key)
# Scale scores $\frac{Q K^T}{\sqrt{d_k}}$
scores *= self.scale
# Apply mask
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# $softmax$ attention along the key dimension: $softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)$
attn = F.softmax(scores, dim=1)
# Save attentions if debugging
......@@ -100,7 +120,7 @@ class MultiHeadAttention(Module):
# Apply dropout
attn = self.dropout(attn)
# Multiply by values $softmax\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# Save attentions for any other calculations
......
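# A minimal sketch of how the `[seq_len, seq_len, batch_size]` mask broadcasts over
# the `[seq_len, seq_len, batch_size, heads]` scores; it reuses the `torch` and `F`
# names already used in `mha.py`, and the sizes and helper name are illustrative only.
def _test_mask_broadcast():
    seq_len, batch_size, heads = 4, 2, 3
    # Causal mask: query `i` may attend only to keys `j <= i`
    mask = torch.tril(torch.ones(seq_len, seq_len))[:, :, None].expand(-1, -1, batch_size)
    # Add the head dimension, as in `MultiHeadAttention.__call__`
    mask = mask.unsqueeze(-1)
    # Random scores, with the same masking and softmax as above
    scores = torch.randn(seq_len, seq_len, batch_size, heads)
    scores = scores.masked_fill(mask == 0, -1e9)
    attn = F.softmax(scores, dim=1)
    # Masked keys get (practically) zero attention weight
    assert attn[0, 1:, 0, 0].sum().item() < 1e-6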
"""
# Relative Multi-head Attention
This is an implementation of
[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860).
"""
import torch
from torch import nn
from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
def shift_right(x: torch.Tensor):
"""
This method shifts the $i^{th}$ row of a matrix by $i$ columns.
If the input is `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`, the shifted
result would be `[[1, 2, 3], [0, 4, 5], [6, 0, 7]]`.
*Ideally we should mask out the lower triangle but it's ok for our purpose*.
"""
# Concatenate a column of zeros
zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
x_padded = torch.cat([x, zero_pad], dim=1)
# Remove excess elements from the end
x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
x = x_padded[:-1].view_as(x)
return x
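# For the `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]` example above: padding gives
# `[[1, 2, 3, 0], [4, 5, 6, 0], [7, 8, 9, 0]]`, viewing the same buffer as four
# rows of three gives `[[1, 2, 3], [0, 4, 5], [6, 0, 7], [8, 9, 0]]`, and dropping
# the last row leaves `[[1, 2, 3], [0, 4, 5], [6, 0, 7]]`.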
class RelativeMultiHeadAttention(MultiHeadAttention):
"""
## Relative Multi-Head Attention Module
We override the [Multi-Head Attention](mha.html) module, so we only need to
write the `get_scores` method.
"""
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
# The linear transformations don't need a bias since we take care of it when
# calculating the scores.
# However, having a bias for `value` might make sense.
super().__init__(heads, d_model, dropout_prob, False)
self.P = 2 ** 12
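# The relative distance $i - j$ can range up to `P` in either direction, hence the
# `2 * P` rows in the positional embedding and bias tables below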
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
......@@ -31,27 +53,63 @@ class RelativeMultiHeadAttention(MultiHeadAttention):
self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
"""
With absolute attention
\begin{align}
A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^T lin_k(X^k_j + P_j) \\
&= Q_i^T K_j + Q_i^T U_j + V_i^T K_j + V_i^T U_j
\end{align}
where $Q_i$, $K_j$, $V_i$, and $U_j$ are linear transformations of the
original embeddings and positional encodings.
They reason that the attention to a given key should be the same regardless of
the position of the query. Hence they replace $V_i^T K_j$ with a constant $v^T K_j$.
🤔 May be worthwhile testing without this assumption.
For the second and fourth terms, relative positional encodings are introduced:
$Q_i^T U_j$ is replaced with $Q_i^T R_{i - j}$ and $V_i^T U_j$ with $S_{i-j}$.
\begin{align}
A^{rel}_{i,j} &= Q_i^T K_j + Q_i^T R_{i - j} + v^T K_j + S_{i-j}
\end{align}
"""
# $R_{i-j}$ pre-shift
key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
# $S_{i-j}$ pre-shift
key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
# $v^T$
query_pos_bias = self.query_pos_bias[None, None, :, :]
# $Q_i^T K_j + v^T K_j$
ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
# $Q_i^T R_{i - j}$ pre-shift
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
# $S_{i-j}$ pre-shift
d = key_pos_bias[None, :, None, :]
# $Q_i^T R_{i - j} + S_{i-j}$
bd = shift_right(b + d)
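# Keep only the last `key.shape[0]` columns, so `bd` matches the
# `[seq_len, seq_len, batch_size, heads]` shape of `ac`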
bd = bd[:, -key.shape[0]:]
return ac + bd
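# A minimal shape-check sketch for the relative attention scores; the sizes
# (`heads=4`, `d_model=32`, and hence `d_k=8`, assuming `d_k = d_model // heads`
# in the truncated `__init__`) and the helper name are illustrative only.
def _test_relative_scores():
    mha = RelativeMultiHeadAttention(heads=4, d_model=32, dropout_prob=0.0)
    # `query` and `key` as produced by `PrepareForMultiHeadAttention`:
    # `[seq_len, batch_size, heads, d_k]`
    query = torch.randn(5, 2, 4, 8)
    key = torch.randn(7, 2, 4, 8)
    # One score per (query position, key position, batch, head)
    assert mha.get_scores(query, key).shape == (5, 7, 2, 4)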
def _test_shift_right():
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inspect(x)
inspect(shift_right(x))
x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
inspect(x[:, :, 0, 0])
inspect(shift_right(x)[:, :, 0, 0])
x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
inspect(x[:, :, 0, 0])
inspect(shift_right(x)[:, :, 0, 0])
if __name__ == '__main__':
_test_shift_right()