Recurrent Highway Networks

This is a PyTorch implementation of Recurrent Highway Networks.

```python
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
```

Recurrent Highway Network Cell

This implements equations (6) to (9) of the paper.

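$s_d^t = h_d^t \odot g_d^t + s_{d - 1}^t \odot c_d^t$

where

$h_0^t = \tanh(lin_{hx}(x^t) + lin_{hs}^0(s_{-1}^t))$

$g_0^t = \sigma(lin_{gx}(x^t) + lin_{gs}^0(s_{-1}^t))$

and for $0 < d < D$

$h_d^t = \tanh(lin_{hs}^d(s_{d-1}^t))$

$g_d^t = \sigma(lin_{gs}^d(s_{d-1}^t))$

Here $s_{-1}^t = s_{D-1}^{t-1}$ is the output state of the previous time step, $\sigma$ is the sigmoid function, and $\odot$ stands for element-wise multiplication.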

Here we have made a couple of changes to the notation from the paper. To avoid confusion with time, the gate is represented with $g$, which was $t$ in the paper. To avoid confusion with multiple layers, we use $d$ for depth and $D$ for total depth instead of $l$ and $L$ from the paper.

We have also replaced the weight matrices and bias vectors from the equations with linear transforms, because that is how the implementation looks.

We implement weight tying, as described in the paper: $c_d^t = 1 - g_d^t$.
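With the tied carry gate, each micro-step is a convex combination of the candidate $h_d^t$ and the previous state, which can be rewritten as $s_{d-1}^t + g_d^t \odot (h_d^t - s_{d-1}^t)$. The minimal sketch below checks this numerically on random tensors; the helper name `highway_update` is hypothetical and not part of the implementation.

```python
import torch


def highway_update(h: torch.Tensor, g: torch.Tensor, s_prev: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: highway update with a tied carry gate, c = 1 - g."""
    return h * g + s_prev * (1 - g)


torch.manual_seed(0)
h = torch.tanh(torch.randn(4, 8))      # candidate state h_d^t, in (-1, 1)
g = torch.sigmoid(torch.randn(4, 8))   # transform gate g_d^t, in (0, 1)
s_prev = torch.randn(4, 8)             # previous state s_{d-1}^t

s_new = highway_update(h, g, s_prev)

# Equivalent interpolation form: move s_prev towards h by a fraction g
assert torch.allclose(s_new, s_prev + g * (h - s_prev), atol=1e-6)

# Limiting cases: g = 0 carries the state through, g = 1 replaces it with h
assert torch.allclose(highway_update(h, torch.zeros_like(g), s_prev), s_prev)
assert torch.allclose(highway_update(h, torch.ones_like(g), s_prev), h)
```

The gate therefore controls, element by element, how far the state moves towards the new candidate at each micro-step.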

```python
class RHNCell(Module):
```

input_size is the feature length of the input and hidden_size is the feature length of the cell. depth is $D$.

```python
    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()

        self.hidden_size = hidden_size
        self.depth = depth
```

We combine $lin_{hs}$ and $lin_{gs}$ into a single linear layer. We can then split the result to get the $lin_{hs}$ and $lin_{gs}$ components. There is one such layer, holding $lin_{hs}^d$ and $lin_{gs}^d$, for each $0 \leq d < D$.

```python
        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
```

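Similarly, we combine $lin_{hx}$ and $lin_{gx}$. This layer has no bias, since the bias at $d = 0$ is already provided by $lin_{hs}^0$ and $lin_{gs}^0$.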

```python
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
```
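Fusing two linear maps into one layer with twice the output features amounts to stacking their weight matrices row-wise: the first half of the output comes from the first block of rows and the second half from the second block. The sketch below uses arbitrary sizes and builds two separate layers (hypothetical names `lin_h` and `lin_g`) from the halves of one fused layer to confirm that they give the same results.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 5, 3, 2

# One fused layer producing [h-part | g-part], laid out like input_lin
fused = nn.Linear(input_size, 2 * hidden_size, bias=False)

# Two separate layers, initialised from the row halves of the fused weight
lin_h = nn.Linear(input_size, hidden_size, bias=False)
lin_g = nn.Linear(input_size, hidden_size, bias=False)
with torch.no_grad():
    lin_h.weight.copy_(fused.weight[:hidden_size])
    lin_g.weight.copy_(fused.weight[hidden_size:])

x = torch.randn(batch_size, input_size)
out = fused(x)

# Slicing the fused output recovers each separate transform
same_h = torch.allclose(out[:, :hidden_size], lin_h(x))
same_g = torch.allclose(out[:, hidden_size:], lin_g(x))
print(same_h, same_g)  # True True
```

One fused matrix multiplication is usually cheaper than two smaller ones, which is the practical reason for combining the transforms.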

x has shape [batch_size, input_size] and s has shape [batch_size, hidden_size].

```python
    def __call__(self, x: torch.Tensor, s: torch.Tensor):
```

Iterate $0 \leq d < D$

```python
        for d in range(self.depth):
```

We calculate the concatenated linear transforms for $h$ and $g$.

```python
            if d == 0:
```

The input is used only when $d$ is $0$.

```python
                hg = self.input_lin(x) + self.hidden_lin[d](s)
            else:
                hg = self.hidden_lin[d](s)
```

Use the first half of hg to get $h_d^t$.

```python
            h = torch.tanh(hg[:, :self.hidden_size])
```

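Use the second half of hg to get $g_d^t$, then apply the highway update with the tied carry gate.

```python
            g = torch.sigmoid(hg[:, self.hidden_size:])

            # s_d^t = h_d^t * g_d^t + s_{d-1}^t * (1 - g_d^t)
            s = h * g + s * (1 - g)

        return s
```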
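To try the cell without `labml_helpers`, here is a minimal sketch of the same computation as a plain `torch.nn.Module` with a `forward` method. The class name `RHNCellSketch` is hypothetical; it uses `Tensor.chunk` to split the fused output, which is equivalent to the slicing above.

```python
import torch
from torch import nn


class RHNCellSketch(nn.Module):
    """Hypothetical stand-alone version of the RHN cell (plain nn.Module)."""

    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()
        self.depth = depth
        # One fused [h | g] transform of the state per micro-step d
        self.hidden_lin = nn.ModuleList(
            [nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
        # Fused [h | g] transform of the input, used only at d = 0
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        for d in range(self.depth):
            hg = self.hidden_lin[d](s)
            if d == 0:
                hg = hg + self.input_lin(x)
            # Same as hg[:, :hidden_size] and hg[:, hidden_size:]
            h_pre, g_pre = hg.chunk(2, dim=-1)
            h = torch.tanh(h_pre)
            g = torch.sigmoid(g_pre)
            # Highway update with tied carry gate c = 1 - g
            s = h * g + s * (1 - g)
        return s


torch.manual_seed(0)
cell = RHNCellSketch(input_size=10, hidden_size=16, depth=3)
x = torch.randn(4, 10)
s = torch.zeros(4, 16)
s_next = cell(x, s)
print(s_next.shape)  # torch.Size([4, 16])
```

Because every micro-step is a convex combination of a tanh output and the previous state, a state that starts at zero stays within $[-1, 1]$.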

Multilayer Recurrent Highway Network

```python
class RHN(Module):
```

Create a network of n_layers recurrent highway network layers, each with depth $D$ (the depth argument).

```python
    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
```

Create cells for each layer. Note that only the first layer gets the input directly. The rest of the layers get their input from the layer below.

```python
        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
```
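Since only the first layer reads the raw input, its input transform differs in size from those of the other layers. Per cell, the state transforms contribute $D(2H^2 + 2H)$ parameters (weights plus biases, with $H$ = hidden_size) and the bias-free input transform contributes $2IH$, where $I$ is the input size of that cell. The sketch below encodes this arithmetic in a hypothetical helper `rhn_param_count` and checks it against a stand-in model built only from `nn.Linear` layers with the same layout (hypothetical helper `build_layout`).

```python
from torch import nn


def rhn_param_count(input_size: int, hidden_size: int, depth: int, n_layers: int) -> int:
    """Hypothetical helper: parameters in an n_layers-deep RHN with the layout above."""
    h = hidden_size
    per_cell_state = depth * (2 * h * h + 2 * h)   # hidden_lin weights + biases
    first_input = 2 * input_size * h                # input_lin of layer 0, no bias
    other_inputs = (n_layers - 1) * 2 * h * h       # input_lin of layers 1..n_layers-1
    return n_layers * per_cell_state + first_input + other_inputs


def build_layout(input_size: int, hidden_size: int, depth: int, n_layers: int) -> nn.Module:
    """Hypothetical stand-in: only the linear layers of each cell, like RHNCell."""
    cells = nn.ModuleList()
    for layer in range(n_layers):
        in_size = input_size if layer == 0 else hidden_size
        cell = nn.Module()
        cell.hidden_lin = nn.ModuleList(
            [nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
        cell.input_lin = nn.Linear(in_size, 2 * hidden_size, bias=False)
        cells.append(cell)
    return cells


model = build_layout(input_size=32, hidden_size=64, depth=4, n_layers=2)
actual = sum(p.numel() for p in model.parameters())
print(actual, rhn_param_count(32, 64, 4, 2))  # 78848 78848
```

The count grows linearly with depth but quadratically with hidden_size, since each micro-step adds another $H \times 2H$ matrix.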

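x has shape [seq_len, batch_size, input_size] and state has shape [n_layers, batch_size, hidden_size]. The method returns the outputs of the final layer, with shape [seq_len, batch_size, hidden_size], and the final state of every layer, with shape [n_layers, batch_size, hidden_size].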

```python
    def __call__(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]
```

Initialize the state with zeros if it is None.

```python
        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
```

Unstack the state to get the state of each layer, as a list so that it can be updated layer by layer.
📝 You could work with the tensor itself, but this is easier to debug.

```python
            # torch.unbind returns a tuple; convert to a list so s[layer] can be assigned
            s = list(torch.unbind(state))
```
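`torch.unbind` returns a tuple of tensors, and tuples do not support item assignment, so the per-layer update `s[layer] = ...` needs a mutable list. A quick check, assuming a state of shape [n_layers, batch_size, hidden_size]:

```python
import torch

state = torch.zeros(3, 2, 4)  # [n_layers, batch_size, hidden_size]

layers = torch.unbind(state)
print(type(layers).__name__, len(layers))  # tuple 3

try:
    layers[0] = torch.ones(2, 4)
except TypeError as err:
    print("tuple assignment fails:", err)

# A list of the unbound tensors can be updated layer by layer
layers = list(torch.unbind(state))
layers[0] = torch.ones(2, 4)
restacked = torch.stack(layers)
print(restacked.shape)  # torch.Size([3, 2, 4])
```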

List to collect the outputs of the final layer at each time step.

```python
        out = []
```

Run through the network for each time step

```python
        for t in range(time_steps):
```

Input to the first layer is the input itself

```python
            inp = x[t]
```

Loop through the layers

```python
            for layer in range(self.n_layers):
```

Update the state of the layer.

```python
                s[layer] = self.cells[layer](inp, s[layer])
```

Input to the next layer is the state of this layer

```python
                inp = s[layer]
```

Collect the output of the final layer

```python
            out.append(s[-1])
```

Stack the outputs and states

```python
        out = torch.stack(out)
        s = torch.stack(s)

        return out, s
```
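An end-to-end shape check for the multilayer network, written as a self-contained sketch with plain `torch.nn.Module` classes (hypothetical names `RHNCellSketch` and `RHNSketch`) so that it runs without `labml_helpers`. It follows the same layout as above and converts the unbound state to a list so that layers can be updated in place.

```python
from typing import List, Optional

import torch
from torch import nn


class RHNCellSketch(nn.Module):
    """Hypothetical plain-PyTorch RHN cell with depth D and a tied carry gate."""

    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()
        self.hidden_lin = nn.ModuleList(
            [nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        for d, lin in enumerate(self.hidden_lin):
            hg = lin(s)
            if d == 0:
                hg = hg + self.input_lin(x)  # the input only enters at d = 0
            h_pre, g_pre = hg.chunk(2, dim=-1)
            g = torch.sigmoid(g_pre)
            s = torch.tanh(h_pre) * g + s * (1 - g)
        return s


class RHNSketch(nn.Module):
    """Hypothetical multilayer RHN: layer 0 reads x, layer k reads layer k - 1."""

    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.cells = nn.ModuleList(
            [RHNCellSketch(input_size, hidden_size, depth)] +
            [RHNCellSketch(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]
        if state is None:
            s: List[torch.Tensor] = [x.new_zeros(batch_size, self.hidden_size)
                                     for _ in self.cells]
        else:
            s = list(torch.unbind(state))  # mutable list of per-layer states
        out = []
        for t in range(time_steps):
            inp = x[t]
            for layer, cell in enumerate(self.cells):
                s[layer] = cell(inp, s[layer])
                inp = s[layer]
            out.append(s[-1])
        return torch.stack(out), torch.stack(s)


torch.manual_seed(0)
rhn = RHNSketch(input_size=8, hidden_size=16, depth=2, n_layers=3)
x = torch.randn(5, 4, 8)  # [seq_len, batch_size, input_size]
out, state = rhn(x)
print(out.shape, state.shape)  # torch.Size([5, 4, 16]) torch.Size([3, 4, 16])

# Feed the returned state back in to continue the sequence
out2, state2 = rhn(x, state)
```

The last output step and the top layer of the returned state are the same tensor values, since both come from the final layer at the final time step.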