RETRO model

This is the model definition for RETRO.

16import math
17from typing import Set
18
19import torch
20from torch import nn
21
22from labml.logger import inspect

RoPE embeddings

We use rotary position embeddings in the self-attention layers, both causal and non-causal. Non-causal self-attention in particular needs explicit positional information because it cannot infer it. The cross-attention layers below use no explicit positional embeddings; we assume the positional information is embedded in the representations implicitly.

25class RotaryPositionalEmbeddings(nn.Module):
  • d is the number of features
  • base is the constant used for calculating $\theta_i = \frac{1}{\text{base}^{2i/d}}$
36    def __init__(self, d: int, base: int = 10_000):
41        super().__init__()

43        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
  • x is the tensor at the head of a key or a query, with shape [batch_size, seq_len, n_heads, d]
45    def forward(self, x: torch.Tensor):

Extract the shape

50        batch_size, seq_len, n_heads, d = x.shape

53        d_2 = d // 2

Create position indexes [0, 1, ..., seq_len - 1]

56        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

Calculate the product of the position index $m$ and $\theta_i$

59        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)

Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, ..., m\theta_{d/2-1}, m\theta_0, m\theta_1, ..., m\theta_{d/2-1}]$

63        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

Calculate $[-x^{(d/2)}, -x^{(d/2+1)}, ..., -x^{(d-1)}, x^{(0)}, x^{(1)}, ..., x^{(d/2-1)}]$

67        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

Calculate

$rx^{(i)}_m = x^{(i)}_m \cos m\theta_i - x^{(i+d/2)}_m \sin m\theta_i$
$rx^{(i+d/2)}_m = x^{(i+d/2)}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i$

for $i \in \{0, 1, ..., d/2 - 1\}$; i.e. each pair of features $(x^{(i)}, x^{(i+d/2)})$ is rotated by the angle $m\theta_i$.

79        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

82        return rx
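
A quick sanity check (an addition for illustration, not part of the original file; the sizes are arbitrary): the rotation preserves the shape, leaves position 0 unchanged, and makes the query-key score depend only on the relative position.

import torch

rope = RotaryPositionalEmbeddings(d=16)
x = torch.randn(2, 10, 4, 16)  # [batch_size, seq_len, n_heads, d]
rx = rope(x)
assert rx.shape == x.shape                 # shape is preserved
assert torch.allclose(rx[:, 0], x[:, 0])   # at position 0 the rotation angle is zero

# The score between a rotated query and key depends only on the relative
# position: the position pair (2, 5) gives the same score as (0, 3)
q = torch.randn(1, 1, 1, 16).expand(1, 10, 1, 16)
k = torch.randn(1, 1, 1, 16).expand(1, 10, 1, 16)
rq, rk = rope(q), rope(k)
assert torch.allclose((rq[0, 2, 0] * rk[0, 5, 0]).sum(),
                      (rq[0, 0, 0] * rk[0, 3, 0]).sum(), atol=1e-5)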

Self-Attention Layer

This applies causal and non-causal multi-headed self-attention.

85class SelfAttention(nn.Module):
  • d_model is the number of features in transformer embeddings
  • n_heads is the number of attention heads
  • d_k is the number of features per head
  • is_causal indicates whether this is causal attention (masked)
92    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
99        super().__init__()
100
101        self.is_causal = is_causal
102        self.n_heads = n_heads
103        self.d_k = d_k

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

106        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

109        self.query = nn.Linear(d_model, n_heads * d_k)
110        self.key = nn.Linear(d_model, n_heads * d_k)
111        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer. The paper uses RMSNorm instead.

114        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

117        self.softmax = nn.Softmax(dim=-1)

Rotary positional embeddings

120        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)

Final linear layer

123        self.output = nn.Linear(n_heads * d_k, d_model)

Mask the attention scores for causal attention

  • attn is the attention matrix of shape [batch_size, n_heads, seq_len, seq_len]
125    def mask_attention(self, attn: torch.Tensor):

No masking for non-causal attention

133        if not self.is_causal:
134            return attn

Create a triangular mask

137        mask = torch.tril(attn.new_ones(attn.shape[-2:]))

Filter by the mask

139        return attn.masked_fill(mask == 0, float('-inf'))
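
For illustration (an addition, not in the original file), here is the lower triangular mask for a sequence of length 4; each position can attend only to itself and earlier positions.

import torch

print(torch.tril(torch.ones(4, 4)))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
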
  • h is the transformer embeddings of shape [batch_size, seq_len, d_model]
141    def forward(self, h: torch.Tensor):

Residual connection

147        h_res = h

Pre-normalization

150        h = self.norm(h)

Get query, key, and value projections and split them into heads. These will have shapes [batch_size, seq_len, n_heads, d_k]

154        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
155        q = self.query(h).view(mh_shape)
156        k = self.key(h).view(mh_shape)
157        v = self.value(h).view(mh_shape)

Apply rotary positional embeddings

160        q = self.rotary_pe(q)
161        k = self.rotary_pe(k)

Calculate attention scores

164        attn = torch.einsum('bihd,bjhd->bhij', q, k)

Scale them by $\frac{1}{\sqrt{d_k}}$

166        attn = attn * self.scale

Apply masks if it's causal attention

169        attn = self.mask_attention(attn)

Calculate attention probabilities

172        attn = self.softmax(attn)

Get values

175        h = torch.einsum("bhij,bjhd->bihd", attn, v)

Change from shape [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_heads * d_k]

179        h = h.reshape(*h.shape[:-2], -1)

Apply final linear layer. The result will have shape [batch_size, seq_len, d_model]

183        h = self.output(h)

Add the residual connection

186        return h + h_res
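
A minimal usage sketch (an addition; the hyper-parameter values are arbitrary) showing that the layer maps [batch_size, seq_len, d_model] back to the same shape:

import torch

self_attn = SelfAttention(d_model=64, n_heads=4, d_k=16, is_causal=True)
h = torch.randn(2, 12, 64)  # [batch_size, seq_len, d_model]
assert self_attn(h).shape == (2, 12, 64)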

Cross-Attention Layer

This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.

This is used in the encoder to encode the retrieved chunks based on the input chunks.

We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.

189class CrossAttention(nn.Module):
  • d_model is the number of features in transformer embeddings
  • n_heads is the number of attention heads
  • d_k is the number of features per head
203    def __init__(self, d_model: int, n_heads: int, d_k: int):
209        super().__init__()
210
211        self.n_heads = n_heads
212        self.d_k = d_k

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

215        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

218        self.query = nn.Linear(d_model, n_heads * d_k)
219        self.key = nn.Linear(d_model, n_heads * d_k)
220        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.

223        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

226        self.softmax = nn.Softmax(dim=-1)

Final linear layer

229        self.output = nn.Linear(n_heads * d_k, d_model)
  • e are the retrieved nearest neighbor chunk embeddings with shape [batch_size, chunks, neighbors, neighbor_len, d_model]
  • h are the input chunks from which the nearest neighbors were retrieved, with shape [batch_size, chunks, chunk_len, d_model]. This is already normalized.
231    def forward(self, e: torch.Tensor, h: torch.Tensor):

Residual connection

240        e_res = e

Normalize retrieved chunks

243        e = self.norm(e)

Get query from the retrieved chunks

246        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)

Get keys and values from the input chunks

248        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
249        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

Calculate attention scores for all chunks. Each retrieved neighbor will pay attention to the original chunk that retrieved it. This will have shape [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]

254        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)

Scale attention scores

256        attn = attn * self.scale

Calculate softmax across the last dimension

259        attn = self.softmax(attn)

Gather values

262        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)

Change from shape [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] to [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]

266        e = e.reshape(*e.shape[:-2], -1)

Apply final linear layer. The result will have shape [batch_size, chunks, neighbors, neighbor_len, d_model]

270        e = self.output(e)

Add residual connection

273        return e + e_res
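
A minimal usage sketch (an addition; all sizes are arbitrary): two retrieved neighbors per chunk attend to the chunk that retrieved them, and the output keeps the shape of e.

import torch

cross_attn = CrossAttention(d_model=64, n_heads=4, d_k=16)
e = torch.randn(2, 3, 2, 6, 64)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 3, 4, 64)     # [batch_size, chunks, chunk_len, d_model]
assert cross_attn(e, h).shape == e.shape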

Chunked Cross-Attention Layer

This is similar to the cross-attention layer defined above.

This is used in the decoder to pay attention to the retrieved neighbor chunks.

We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.

276class ChunkedCrossAttention(nn.Module):
  • d_model is the number of features in transformer embeddings
  • n_heads is the number of attention heads
  • d_k is the number of features per head
  • chunk_len is the length of a chunk
288    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
296        super().__init__()
297
298        self.chunk_len = chunk_len
299        self.n_heads = n_heads
300        self.d_k = d_k

To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$

303        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

306        self.query = nn.Linear(d_model, n_heads * d_k)
307        self.key = nn.Linear(d_model, n_heads * d_k)
308        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.

311        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

314        self.softmax = nn.Softmax(dim=-1)

Final linear layer

317        self.output = nn.Linear(n_heads * d_k, d_model)

  • h are the input embeddings of shape [batch_size, seq_len, d_model]
  • e are the retrieved nearest neighbors of shape [batch_size, chunks, neighbors, neighbor_len, d_model]

319    def forward(self, h: torch.Tensor, e: torch.Tensor):

Get shape

326        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

No attention if there are no chunks (for short inputs when sampling)

329        if chunks == 0:
330            return h

Residual connection

333        h_res = h

Remove the first chunk_len - 1 embeddings. The input should only pay attention to neighbors that were retrieved and encoded using past tokens, so that there is no information leakage. That is, the neighbors retrieved for the first chunk contain information from the first chunk itself; so by shifting the sequence to the left by chunk_len - 1 we make sure that information only flows to the right.

341        h = h[:, self.chunk_len - 1:]

Pre-norm

343        h = self.norm(h)

Append empty embeddings to the end to be able to split the input into chunks

345        if h.shape[1] < chunks * self.chunk_len:
346            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)

Reshape the input into chunks.

348        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)

Get query from the input

351        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)

Get keys and values from the retrieved neighbors

353        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
354        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

Calculate attention scores for input chunks. Each chunk will pay attention to neighbors retrieved by the previous chunk. This will have shape [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]

359        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)

Scale attention scores

361        attn = attn * self.scale

Apply softmax over the last two dimensions neighbors, neighbor_len

364        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)

Gather values

367        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)

Change from shape [batch_size, chunks, chunk_len, n_heads, d_k] to [batch_size, chunks * chunk_len, n_heads * d_k]

371        h = h.reshape(batch_size, chunks * self.chunk_len, -1)

Apply final linear layer. The result will have shape [batch_size, chunks * chunk_len, d_model]

375        h = self.output(h)

Prepend chunk_len - 1 zero embeddings to the start; i.e. shift the sequence back to the right

378        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)

Truncate and add the residual connection

381        return h[:, :h_res.shape[1]] + h_res
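
A minimal usage sketch (an addition; all sizes are arbitrary): with seq_len = 10 and chunk_len = 4 there are two complete chunks, and the output keeps the shape of h.

import torch

cca = ChunkedCrossAttention(d_model=64, n_heads=4, d_k=16, chunk_len=4)
h = torch.randn(2, 10, 64)       # [batch_size, seq_len, d_model]
e = torch.randn(2, 2, 2, 6, 64)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
assert cca(h, e).shape == h.shape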

Position-wise Feed Forward Layer

This consists of two linear layers and an activation in the middle.

384class FeedForward(nn.Module):
  • d_model is the number of features in transformer embeddings
  • d_ff is the number of features in the hidden layer
391    def __init__(self, d_model: int, d_ff: int):
397        super().__init__()

The two linear layers

400        self.lin1 = nn.Linear(d_model, d_ff)
401        self.lin2 = nn.Linear(d_ff, d_model)

ReLU Activation

404        self.act = nn.ReLU()

Pre-norm layer

407        self.norm = nn.LayerNorm(d_model)

  • h are the embeddings of shape [batch_size, seq_len, d_model]

409    def forward(self, h: torch.Tensor):

Residual

415        h_res = h

Pre-norm

417        h = self.norm(h)

First linear layer

419        h = self.lin1(h)

Activation

421        h = self.act(h)

Second linear layer

423        h = self.lin2(h)

Add the residual connection

426        return h + h_res
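
A minimal usage sketch (an addition): the feed-forward layer is applied position-wise, so it maps [batch_size, seq_len, d_model] back to the same shape.

import torch

ffw = FeedForward(d_model=64, d_ff=256)
h = torch.randn(2, 10, 64)
assert ffw(h).shape == h.shape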

Nearest Neighbor Encoder

This module encodes the retrieved nearest neighbors

429class NearestNeighborEncoder(nn.Module):
  • chunk_len is the length of a chunk
  • n_layers is the number of layers in the encoder
  • ca_layers are the layers with cross attention
  • d_model is the number of features in embeddings
  • n_heads is the number of heads in attention layers
  • d_k is the size of attention heads
  • d_ff is the size of the hidden layer of the feed-forward networks
436    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
437                 d_model: int, n_heads: int, d_k: int, d_ff: int):
448        super().__init__()
449        self.ca_layers = ca_layers
450        self.chunk_len = chunk_len

Cross-attention layers

452        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])

Bi-directional self attention layers

454        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])

Feed forward layers

456        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

Pre-normalization layer for the input chunks h

459        self.norm_h = nn.LayerNorm(d_model)
  • e are token embeddings of the retrieved nearest neighbors, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
  • h are the input token embeddings, of shape [batch_size, seq_len, d_model]

The chunks and neighbors are processed in parallel.

461    def forward(self, e: torch.Tensor, h: torch.Tensor):

Get shape

474        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
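
Split the input into chunks of shape [batch_size, chunks, chunk_len, d_model]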

477        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)

Pre-norm

480        h_split = self.norm_h(h_split)

Keep the index of the cross attention layer

483        p_ca = 0

For all layers

485        for p in range(len(self.attn)):

Bi-directional self attention

488            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)

Cross attention if p is in ca_layers

491            if p in self.ca_layers:

493                e = self.ca[p_ca](e, h_split)

Increment the cross attention index

495                p_ca += 1

Feed forward layer

498            e = self.ffw[p](e)

Return the encoded neighbors

501        return e
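
A minimal usage sketch (an addition; sizes are arbitrary): two encoder layers with cross-attention at the second layer. The output has the same shape as the neighbor embeddings e.

import torch

encoder = NearestNeighborEncoder(chunk_len=4, n_layers=2, ca_layers={1},
                                 d_model=64, n_heads=4, d_k=16, d_ff=256)
e = torch.randn(2, 2, 2, 6, 64)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 10, 64)       # [batch_size, seq_len, d_model]
assert encoder(e, h).shape == e.shape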

Retro Model

This is the Retro decoder

504class RetroModel(nn.Module):
  • n_vocab is the number of tokens in the vocabulary
  • d_model is the number of features in embeddings
  • n_layers is the number of layers in the decoder
  • ca_layers are the layers with cross attention
  • chunk_len is the length of a chunk
  • n_heads is the number of heads in attention layers
  • d_k is the size of attention heads
  • d_ff is the size of the hidden layer of the feed-forward networks
  • encoder is the nearest neighbor encoder
511    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
512                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
524        super().__init__()
525
526        self.ca_layers = ca_layers
527        self.encoder = encoder

Token embedding layer

530        self.emb = nn.Embedding(n_vocab, d_model)

Chunked cross attention layers

532        self.cca = nn.ModuleList(
533            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])

Attention layers

535        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])

Feed forward layers

537        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

Readout layer

539        self.read = nn.Linear(d_model, n_vocab)

Pre-normalization layer for the nearest neighbor embeddings from the encoder

543        self.norm_e = nn.LayerNorm(d_model)
  • x is the input sequence, of shape [batch_size, seq_len]
  • ret are the retrieved neighbors of shape [batch_size, chunks, neighbors, neighbor_len]
545    def forward(self, x: torch.Tensor, ret: torch.Tensor):

Get input embeddings

554        h = self.emb(x)

Embeddings of the retrieved neighbors.

We use the same embedding layer for both the input and the neighbors.

560        ret_emb = self.emb(ret)

Keep index of the chunked cross attention layer

563        p_ca = 0

For all layers

565        for p in range(len(self.attn)):

Causal self attention

567            h = self.attn[p](h)

Get the encoder embeddings before the first cross-attention layer, i.e. when p == min(ca_layers)

571            if self.ca_layers and p == min(self.ca_layers):

We pass both the neighbor embeddings and the current embeddings h to the encoder.

575                e = self.encoder(ret_emb, h)

Normalize encoder embeddings

577                e = self.norm_e(e)

Chunked cross-attention if p is in ca_layers

580            if p in self.ca_layers:

582                h = self.cca[p_ca](h, e)

Increment chunked cross-attention index

584                p_ca += 1
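
Feed forward layer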

587            h = self.ffw[p](h)
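
Get the logits from the readout layer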

590        return self.read(h)
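
The logits above would typically be trained with a next-token prediction objective. A minimal sketch of the loss (an assumption, not from the original training code; retro_loss is a hypothetical helper):

import torch.nn.functional as F

def retro_loss(model: RetroModel, x: torch.Tensor, tgt: torch.Tensor, ret: torch.Tensor):
    # x are the input tokens and tgt their next-token targets, both of
    # shape [batch_size, seq_len]; ret are the retrieved neighbor tokens
    logits = model(x, ret)  # [batch_size, seq_len, n_vocab]
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), tgt.reshape(-1))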

Test the model with fake data

593def _test():
597    chunk_len = 4
598    d_model = 8
599    d_ff = 32
600    n_heads = 2
601    d_k = 4
602
603    device = torch.device('cuda:0')
604
605    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
606                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
607
608    m.to(device)
609    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
610    ret = [
611        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
612        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
613    ]
614    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
615
616    inspect(res)

620if __name__ == '__main__':
621    _test()