16import math
17from typing import Set
18
19import torch
20from torch import nn
21
22from labml.logger import inspect
We use rotary position embeddings in the self-attention layers. The cross-attention layers use no explicit positional encodings; there we assume the positional information is already embedded in the embeddings. Non-causal self-attention, on the other hand, needs explicit positional information because it cannot infer position on its own.
25class RotaryPositionalEmbeddings(nn.Module):
d is the number of features
base is the constant used for calculating the rotation frequencies θ_i = base^(-2i/d)
36 def __init__(self, d: int, base: int = 10_000):
41 super().__init__()
43 self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
x is the tensor at the head of a key or a query, with shape [batch_size, seq_len, n_heads, d]
45 def forward(self, x: torch.Tensor):
Extract the shape
50 batch_size, seq_len, n_heads, d = x.shape
Half the number of features, d/2
53 d_2 = d // 2
Create position indexes [0, 1, ..., seq_len - 1]
56 seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
Calculate the product of the position index and θ_i
59 idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
Concatenate so that for row m we have [mθ_0, mθ_1, ..., mθ_{d/2-1}, mθ_0, mθ_1, ..., mθ_{d/2-1}]
63 idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
Calculate [-x^(d/2), -x^(d/2+1), ..., -x^(d-1), x^(0), x^(1), ..., x^(d/2-1)]
67 neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
Calculate rx = x cos(mθ) + neg_half_x sin(mθ), which rotates each feature pair (x^(i), x^(i+d/2)) by the angle mθ_i
79 rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])
82 return rx
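As a quick sanity check (not part of the original file, sizes chosen arbitrarily), the rotary module can be exercised on a random tensor; since RoPE is a pure rotation it preserves both the shape and the norm of each head vector.

rotary = RotaryPositionalEmbeddings(d=8)
x = torch.randn(2, 10, 4, 8)  # [batch_size, seq_len, n_heads, d]
rx = rotary(x)
# The rotation keeps the shape and each head vector's norm unchanged
assert rx.shape == x.shape
assert torch.allclose(rx.norm(dim=-1), x.norm(dim=-1), atol=1e-5)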
85class SelfAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
is_causal indicates whether this is causal attention (masked)
92 def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
99 super().__init__()
100
101 self.is_causal = is_causal
102 self.n_heads = n_heads
103 self.d_k = d_k
To scale attention scores before the softmax by 1/√d_k
106 self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
109 self.query = nn.Linear(d_model, n_heads * d_k)
110 self.key = nn.Linear(d_model, n_heads * d_k)
111 self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer. The paper uses RMSNorm instead.
114 self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
117 self.softmax = nn.Softmax(dim=-1)
Rotary positional embeddings
120 self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
Final linear layer
123 self.output = nn.Linear(n_heads * d_k, d_model)
attn is the attention matrix of shape [batch_size, n_heads, seq_len, seq_len]
125 def mask_attention(self, attn: torch.Tensor):
No masking for non-causal attention
133 if not self.is_causal:
134 return attn
Create a triangular mask
137 mask = torch.tril(attn.new_ones(attn.shape[-2:]))
Filter by the mask
139 return attn.masked_fill(mask == 0, float('-inf'))
h are the transformer embeddings, of shape [batch_size, seq_len, d_model]
141 def forward(self, h: torch.Tensor):
Residual connection
147 h_res = h
Pre-normalization
150 h = self.norm(h)
Get queries, keys, and values and split them into heads. These will have shape [batch_size, seq_len, n_heads, d_k]
154 mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
155 q = self.query(h).view(mh_shape)
156 k = self.key(h).view(mh_shape)
157 v = self.value(h).view(mh_shape)
Apply rotary positional embeddings
160 q = self.rotary_pe(q)
161 k = self.rotary_pe(k)
Calculate attentions
164 attn = torch.einsum('bihd,bjhd->bhij', q, k)
Scale it by 1/√d_k
166 attn = attn * self.scale
Apply the mask if this is causal attention
169 attn = self.mask_attention(attn)
Calculate attention probabilities
172 attn = self.softmax(attn)
Compute the attention-weighted sum of values
175 h = torch.einsum("bhij,bjhd->bihd", attn, v)
Change from shape [batch_size, seq_len, n_heads, d_k]
to [batch_size, seq_len, n_heads * d_k]
179 h = h.reshape(*h.shape[:-2], -1)
Apply final linear layer. The result will have shape [batch_size, seq_len, d_model]
183 h = self.output(h)
Add the residual connection
186 return h + h_res
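Here is a minimal sketch of using the causal self-attention layer, with arbitrary hyper-parameters and random (untrained) weights. The second check illustrates the causal mask: perturbing the last token leaves the outputs at earlier positions unchanged.

self_attn = SelfAttention(d_model=16, n_heads=2, d_k=8, is_causal=True)
h = torch.randn(3, 12, 16)  # [batch_size, seq_len, d_model]
out = self_attn(h)
assert out.shape == h.shape
# Perturb only the last position; with causal masking earlier outputs are unaffected
h2 = h.clone()
h2[:, -1] += 1.0
out2 = self_attn(h2)
assert torch.allclose(out[:, :-1], out2[:, :-1], atol=1e-6)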
This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.
This is used in the encoder to encode the retrieved chunks based on the input chunks.
We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
189class CrossAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
203 def __init__(self, d_model: int, n_heads: int, d_k: int):
209 super().__init__()
210
211 self.n_heads = n_heads
212 self.d_k = d_k
To scale attention scores before the softmax by 1/√d_k
215 self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
218 self.query = nn.Linear(d_model, n_heads * d_k)
219 self.key = nn.Linear(d_model, n_heads * d_k)
220 self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
223 self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
226 self.softmax = nn.Softmax(dim=-1)
Final linear layer
229 self.output = nn.Linear(n_heads * d_k, d_model)
e are the retrieved nearest neighbor chunk embeddings, with shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input chunks from which the nearest neighbors were retrieved, with shape [batch_size, chunks, chunk_len, d_model]. This is already normalized.
231 def forward(self, e: torch.Tensor, h: torch.Tensor):
Residual connection
240 e_res = e
Normalize retrieved chunks
243 e = self.norm(e)
Get query from the retrieved chunks
246 q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
Get keys and values from the input chunks
248 k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
249 v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
Calculate attention scores for all chunks. Each retrieved neighbor will pay attention to the original chunk that retrieved it. This will have shape [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]
254 attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
Scale attention scores
256 attn = attn * self.scale
Calculate softmax across the last dimension
259 attn = self.softmax(attn)
Gather values
262 e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
Change from shape [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k]
to [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]
266 e = e.reshape(*e.shape[:-2], -1)
Apply final linear layer. The result will have shape [batch_size, chunks, neighbors, neighbor_len, d_model]
270 e = self.output(e)
Add residual connection
273 return e + e_res
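A minimal shape-only sketch of the cross-attention layer (random values, arbitrary sizes): every retrieved neighbor attends to the chunk that retrieved it, and the neighbor embeddings keep their shape.

cross_attn = CrossAttention(d_model=16, n_heads=2, d_k=8)
e = torch.randn(2, 3, 2, 6, 16)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 3, 4, 16)     # [batch_size, chunks, chunk_len, d_model], assumed already normalized
out = cross_attn(e, h)
assert out.shape == e.shape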
This is similar to the cross-attention layer defined above.
This is used in the decoder to pay attention to the retrieved neighbor chunks.
We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
276class ChunkedCrossAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
chunk_len is the length of a chunk
288 def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
296 super().__init__()
297
298 self.chunk_len = chunk_len
299 self.n_heads = n_heads
300 self.d_k = d_k
To scale attention scores before the softmax by 1/√d_k
303 self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
306 self.query = nn.Linear(d_model, n_heads * d_k)
307 self.key = nn.Linear(d_model, n_heads * d_k)
308 self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
311 self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
314 self.softmax = nn.Softmax(dim=-1)
Final linear layer
317 self.output = nn.Linear(n_heads * d_k, d_model)
h are the input embeddings, of shape [batch_size, seq_len, d_model]
e are the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
319 def forward(self, h: torch.Tensor, e: torch.Tensor):
Get shape
326 batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
No attention if there are no chunks (for short inputs when sampling)
329 if chunks == 0:
330 return h
Residual connection
333 h_res = h
Remove the first chunk_len - 1 embeddings. The input pays attention only to neighbors that were retrieved and encoded using past tokens, so that there is no information leakage: the neighbors retrieved for the first chunk carry information from the whole of that first chunk. By shifting the sequence to the left by chunk_len - 1 we make sure that information only flows to the right (see the usage sketch after this class).
341 h = h[:, self.chunk_len - 1:]
Pre-norm
343 h = self.norm(h)
Append empty embeddings to the end to be able to split the input into chunks
345 if h.shape[1] < chunks * self.chunk_len:
346 h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
Reshape the input into chunks.
348 h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
Get query from the input
351 q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
Get keys and values from the retrieved neighbors
353 k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
354 v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
Calculate attention scores for input chunks. Each chunk will pay attention to neighbors retrieved by the previous chunk. This will have shape [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]
359 attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
Scale attention scores
361 attn = attn * self.scale
Apply softmax over the last two dimensions (neighbors, neighbor_len)
364 attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
Gather values
367 h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
Change from shape [batch_size, chunks, chunk_len, n_heads, d_k]
to [batch_size, chunks * chunk_len, n_heads * d_k]
371 h = h.reshape(batch_size, chunks * self.chunk_len, -1)
Apply final linear layer. The result will have shape [batch_size, chunks * chunk_len, d_model]
375 h = self.output(h)
Prepend chunk_len - 1 zero embeddings on the left; i.e. shift the sequence back to the right
378 h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
Truncate and add the residual connection
381 return h[:, :h_res.shape[1]] + h_res
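The shifting logic is easiest to see with concrete numbers. In this sketch (arbitrary sizes, not the author's test) a sequence of 10 tokens with chunk_len = 4 attends to neighbors retrieved for 2 chunks, and the layer returns a tensor of the same shape as its input; the chunks == 0 branch used while sampling short inputs is also exercised.

cca = ChunkedCrossAttention(d_model=16, n_heads=2, d_k=8, chunk_len=4)
h = torch.randn(2, 10, 16)       # [batch_size, seq_len, d_model]
e = torch.randn(2, 2, 3, 6, 16)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
out = cca(h, e)
assert out.shape == h.shape
# With no chunks (short inputs while sampling) the layer is a no-op
e_empty = torch.randn(2, 0, 3, 6, 16)
assert torch.equal(cca(h, e_empty), h)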
This consists of two linear layers and an activation in the middle.
384class FeedForward(nn.Module):
d_model is the number of features in transformer embeddings
d_ff is the number of features in the hidden layer
391 def __init__(self, d_model: int, d_ff: int):
397 super().__init__()
The two linear layers
400 self.lin1 = nn.Linear(d_model, d_ff)
401 self.lin2 = nn.Linear(d_ff, d_model)
ReLU Activation
404 self.act = nn.ReLU()
Pre-norm layer
407 self.norm = nn.LayerNorm(d_model)
h are the embeddings of shape [batch_size, seq_len, d_model]
409 def forward(self, h: torch.Tensor):
Residual
415 h_res = h
Pre-norm
417 h = self.norm(h)
First linear layer
419 h = self.lin1(h)
Activation
421 h = self.act(h)
Second linear layer
423 h = self.lin2(h)
Add the residual connection
426 return h + h_res
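A one-line sanity check (arbitrary sizes): the pre-norm MLP with its residual connection keeps the embedding shape.

ffw = FeedForward(d_model=16, d_ff=64)
h = torch.randn(2, 10, 16)  # [batch_size, seq_len, d_model]
assert ffw(h).shape == h.shape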
429class NearestNeighborEncoder(nn.Module):
chunk_len is the length of a chunk
n_layers is the number of layers in the encoder
ca_layers are the layers with cross attention
d_model is the number of features in embeddings
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the feed-forward networks' hidden layers
436 def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
437 d_model: int, n_heads: int, d_k: int, d_ff: int):
448 super().__init__()
449 self.ca_layers = ca_layers
450 self.chunk_len = chunk_len
Cross-attention layers
452 self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
Bi-directional self attention layers
454 self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
Feed forward layers
456 self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
Pre-normalization layer for the input embeddings h
459 self.norm_h = nn.LayerNorm(d_model)
e are the token embeddings of the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input token embeddings, of shape [batch_size, seq_len, d_model]
The chunks and neighbors are processed in parallel.
461 def forward(self, e: torch.Tensor, h: torch.Tensor):
Get shape
474 batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
Split the input into chunks of length chunk_len
477 h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
Pre-norm
480 h_split = self.norm_h(h_split)
Keep the index of the cross attention layer
483 p_ca = 0
For all layers
485 for p in range(len(self.attn)):
Bi-directional self attention
488 e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
Cross attention if p is in ca_layers
491 if p in self.ca_layers:
493 e = self.ca[p_ca](e, h_split)
Increment the cross-attention index
495 p_ca += 1
Feed forward layer
498 e = self.ffw[p](e)
Return the encoded neighbors
501 return e
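A minimal usage sketch of the encoder (arbitrary sizes, random weights). Note that seq_len must cover chunks * chunk_len so the input can be split into the chunks that the neighbors cross-attend to.

encoder = NearestNeighborEncoder(chunk_len=4, n_layers=2, ca_layers={1},
                                 d_model=16, n_heads=2, d_k=8, d_ff=64)
e = torch.randn(2, 3, 2, 6, 16)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 12, 16)       # [batch_size, seq_len, d_model], seq_len >= chunks * chunk_len
out = encoder(e, h)
assert out.shape == e.shape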
504class RetroModel(nn.Module):
n_vocab is the number of tokens in the vocabulary
d_model is the number of features in embeddings
n_layers is the number of layers in the decoder
ca_layers are the layers with cross attention
chunk_len is the length of a chunk
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the feed-forward networks' hidden layers
encoder is the nearest neighbor encoder
511 def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
512 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
524 super().__init__()
525
526 self.ca_layers = ca_layers
527 self.encoder = encoder
Token embedding layer
530 self.emb = nn.Embedding(n_vocab, d_model)
Chunked cross attention layers
532 self.cca = nn.ModuleList(
533 [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
Attention layers
535 self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
Feed forward layers
537 self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
Readout layer
539 self.read = nn.Linear(d_model, n_vocab)
Pre-normalization layer for the nearest neighbor embeddings from the encoder
543 self.norm_e = nn.LayerNorm(d_model)
x is the input sequence, of shape [batch_size, seq_len]
ret are the retrieved neighbors, of shape [batch_size, chunks, neighbors, neighbor_len]
545 def forward(self, x: torch.Tensor, ret: torch.Tensor):
Get input embeddings
554 h = self.emb(x)
Embed the retrieved neighbors
560 ret_emb = self.emb(ret)
Keep index of the chunked cross attention layer
563 p_ca = 0
For all layers
565 for p in range(len(self.attn)):
Causal self attention
567 h = self.attn[p](h)
Get the encoder embeddings before the first cross-attention layer, i.e. when p == min(ca_layers)
571 if self.ca_layers and p == min(self.ca_layers):
575 e = self.encoder(ret_emb, h)
Normalize encoder embeddings
577 e = self.norm_e(e)
Chunked cross-attention if p is in ca_layers
580 if p in self.ca_layers:
582 h = self.cca[p_ca](h, e)
Increment chunked cross-attention index
584 p_ca += 1
Feed forward layer
587 h = self.ffw[p](h)
Get logits from the readout layer
590 return self.read(h)
Test the model with fake data
593def _test():
597 chunk_len = 4
598 d_model = 8
599 d_ff = 32
600 n_heads = 2
601 d_k = 4
602
603 device = torch.device('cuda:0')
604
605 m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
606 encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
607
608 m.to(device)
609 x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
610 ret = [
611 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
612 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
613 ]
614 res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
615
616 inspect(res)
620if __name__ == '__main__':
621 _test()