From 2ca85492e18829aef2cc9b05a2701035ed55024e Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Mon, 7 Sep 2020 18:27:18 +0530
Subject: [PATCH] rhn docs

---
 labml_nn/__init__.py                       |   9 +-
 .../recurrent_highway_networks/__init__.py | 119 ++++++++++++++++--
 readme.rst                                 |  15 ++-
 3 files changed, 120 insertions(+), 23 deletions(-)

diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py
index 5819e8da..d6e6e3f3 100644
--- a/labml_nn/__init__.py
+++ b/labml_nn/__init__.py
@@ -2,10 +2,9 @@
 # LabML Models
 
 * [Transformers](transformers/index.html)
+* [Recurrent Highway Networks](recurrent_highway_networks/index.html)
+* [LSTM](lstm/index.html)
 
-TODO:
-
-* LSTM
-* Highway Networks
-* 🤔
+If you have any suggestions for other new implementations,
+please create a [Github Issue](https://github.com/lab-ml/labml_nn/issues).
 """
diff --git a/labml_nn/recurrent_highway_networks/__init__.py b/labml_nn/recurrent_highway_networks/__init__.py
index ddb32ac4..f52c312c 100644
--- a/labml_nn/recurrent_highway_networks/__init__.py
+++ b/labml_nn/recurrent_highway_networks/__init__.py
@@ -1,3 +1,8 @@
+"""
+This is an implementation of [Recurrent Highway Networks](https://arxiv.org/abs/1607.03474).
+"""
+from typing import Optional
+
 import torch
 from torch import nn
 
@@ -5,54 +10,142 @@ from labml_helpers.module import Module
 
 
 class RHNCell(Module):
+    """
+    ## Recurrent Highway Network Cell
+
+    This implements equations $(6) - (9)$.
+
+    $s_d^t = h_d^t \cdot g_d^t + s_{d - 1}^t \cdot c_d^t$
+
+    where
+
+    \begin{align}
+    h_0^t &= \tanh(lin_{hx}(x) + lin_{hs}(s_D^{t-1})) \\
+    g_0^t &= \sigma(lin_{gx}(x) + lin_{gs}^1(s_D^{t-1})) \\
+    c_0^t &= \sigma(lin_{cx}(x) + lin_{cs}^1(s_D^{t-1}))
+    \end{align}
+
+    and for $0 < d < D$
+
+    \begin{align}
+    h_d^t &= \tanh(lin_{hs}^d(s_d^t)) \\
+    g_d^t &= \sigma(lin_{gs}^d(s_d^t)) \\
+    c_d^t &= \sigma(lin_{cs}^d(s_d^t))
+    \end{align}
+
+    Here we have made a couple of changes to the notation used in the paper.
+    To avoid confusion with time, the gate is represented with $g$,
+    which was $t$ in the paper.
+    To avoid confusion with multiple layers, we use $d$ for depth and $D$ for
+    total depth, instead of $l$ and $L$ from the paper.
+
+    We have also replaced the weight matrices and bias vectors from the equations with
+    linear transforms, because that is how the implementation looks.
+
+    We implement weight tying, as described in the paper: $c_d^t = 1 - g_d^t$.
+    """
+
     def __init__(self, input_size: int, hidden_size: int, depth: int):
+        """
+        `input_size` is the feature length of the input and `hidden_size` is
+        the feature length of the cell state.
+        `depth` is $D$.
+        """
         super().__init__()
         self.hidden_size = hidden_size
         self.depth = depth
+        # We combine $lin_{hs}$ and $lin_{gs}$ into a single linear layer.
+        # We can then split the result to get the $lin_{hs}$ and $lin_{gs}$ components.
+        # This is $lin_{hs}^d$ and $lin_{gs}^d$ for $0 \leq d < D$.
         self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
+
+        # Similarly we combine $lin_{hx}$ and $lin_{gx}$.
         self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
 
-    def __call__(self, x, s):
-        for i in range(self.depth):
-            if i == 0:
-                ht = self.input_lin(x) + self.hidden_lin[i](s)
+    def __call__(self, x: torch.Tensor, s: torch.Tensor):
+        """
+        `x` has shape `[batch_size, input_size]` and
+        `s` has shape `[batch_size, hidden_size]`.
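+        The cell returns the new state $s_D^t$, which also has shape `[batch_size, hidden_size]`.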
+        """
+
+        # Iterate $0 \leq d < D$
+        for d in range(self.depth):
+            # We calculate the concatenation of the linear transforms for $h$ and $g$
+            if d == 0:
+                # The input is used only when $d$ is $0$.
+                hg = self.input_lin(x) + self.hidden_lin[d](s)
             else:
-                ht = self.hidden_lin[i](s)
+                hg = self.hidden_lin[d](s)
 
-            h = torch.tanh(ht[:, :self.hidden_size])
-            t = torch.sigmoid(ht[:, self.hidden_size:])
+            # Use the first half of `hg` to get $h_d^t$
+            # \begin{align}
+            # h_0^t &= \tanh(lin_{hx}(x) + lin_{hs}(s_D^{t-1})) \\
+            # h_d^t &= \tanh(lin_{hs}^d(s_d^t))
+            # \end{align}
+            h = torch.tanh(hg[:, :self.hidden_size])
+            # Use the second half of `hg` to get $g_d^t$
+            # \begin{align}
+            # g_0^t &= \sigma(lin_{gx}(x) + lin_{gs}^1(s_D^{t-1})) \\
+            # g_d^t &= \sigma(lin_{gs}^d(s_d^t))
+            # \end{align}
+            g = torch.sigmoid(hg[:, self.hidden_size:])
 
-            s = s + (h - s) * t
+            s = h * g + s * (1 - g)
 
         return s
 
 
 class RHN(Module):
+    """
+    ### Multilayer Recurrent Highway Network
+    """
+
     def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
+        """
+        Create a network of `n_layers` recurrent highway network layers, each of depth `depth` ($D$).
+        """
+
         super().__init__()
         self.n_layers = n_layers
         self.hidden_size = hidden_size
+        # Create cells for each layer. Note that only the first layer gets the input directly.
+        # The rest of the layers get their input from the layer below.
         self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                    [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
 
-    def __call__(self, x: torch.Tensor, state=None):
-        # x [seq_len, batch, d_model]
+    def __call__(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
+        """
+        `x` has shape `[seq_len, batch_size, input_size]` and
+        `state`, if given, has shape `[n_layers, batch_size, hidden_size]`.
+        """
         time_steps, batch_size = x.shape[:2]
 
+        # Initialize the state if `None`
         if state is None:
             s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
         else:
+            # Unstack the state to get the state of each layer
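+            # `torch.unbind` splits `state` along its first dimension, giving one
+            # `[batch_size, hidden_size]` state per layer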
+            # 📝 You can just work with the tensor itself, but this is easier to debug
             s = list(torch.unbind(state))
 
+        # List to collect the outputs of the final layer at each time step.
         out = []
+
+        # Run through the network for each time step
        for t in range(time_steps):
+            # Input to the first layer is the input itself
             inp = x[t]
-            for i in range(self.n_layers):
-                s[i] = self.cells[i](inp, s[i])
-                inp = s[i]
+            # Loop through the layers
+            for layer in range(self.n_layers):
+                # Compute the new state of this layer
+                s[layer] = self.cells[layer](inp, s[layer])
+                # Input to the next layer is the state of this layer
+                inp = s[layer]
 
+            # Collect the output of the final layer
             out.append(s[-1])
 
+        # Stack the outputs and states
         out = torch.stack(out)
         s = torch.stack(s)
 
         return out, s
diff --git a/readme.rst b/readme.rst
index 531ec1d2..79655d1f 100644
--- a/readme.rst
+++ b/readme.rst
@@ -16,13 +16,18 @@ Transformers
 and
 `relative multi-headed attention `_.
 
-✅ TODO
--------
+Recurrent Highway Networks
+--------------------------
 
-* Recurrent Highway Networks
-* LSTMs
+This is the implementation of `Recurrent Highway Networks `_.
 
-Please create a Github issue if there's something you'ld like to see implemented here.
+
+LSTM
+----
+
+This is the implementation of `LSTMs `_.
+
+✅ Please create a Github issue if there's something you'd like to see implemented here.
 
 Installation
 ------------
-- 
GitLab
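
Below is a minimal usage sketch of the `RHN` module added in this patch. It is for illustration only and not part of the commit; it assumes `labml_nn` and its `labml_helpers` dependency are installed, and the sizes are made up. Shapes follow the docstrings above.

import torch

from labml_nn.recurrent_highway_networks import RHN

# Hypothetical sizes, chosen only for illustration
seq_len, batch_size, input_size, hidden_size = 10, 4, 32, 64
depth, n_layers = 3, 2

rhn = RHN(input_size, hidden_size, depth, n_layers)

# The input has shape `[seq_len, batch_size, input_size]`
x = torch.randn(seq_len, batch_size, input_size)

# `out` stacks the final layer's state at every time step: `[seq_len, batch_size, hidden_size]`
# `state` stacks the last state of every layer: `[n_layers, batch_size, hidden_size]`,
# and can be fed back in as the initial state for the next chunk of the sequence
out, state = rhn(x)

assert out.shape == (seq_len, batch_size, hidden_size)
assert state.shape == (n_layers, batch_size, hidden_size)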