Commit 2ca85492 authored by Varuna Jayasiri

rhn docs

Parent 0eb20f85
@@ -2,10 +2,9 @@
# LabML Models
* [Transformers](transformers/index.html)
* [Recurrent Highway Networks](recurrent_highway_networks/index.html)
* [LSTM](lstm/index.html)
TODO:
* Highway Networks
* 🤔
If you have any suggestions for other new implementations,
please create a [Github Issue](https://github.com/lab-ml/labml_nn/issues).
"""
"""
This is an implementation of [Recurrent Highway Networks](https://arxiv.org/abs/1607.03474).
"""
from typing import Optional
import torch
from torch import nn
@@ -5,54 +10,142 @@ from labml_helpers.module import Module
class RHNCell(Module):
"""
## Recurrent Highway Network Cell
This implements equations $(6) - (9)$.
$s_d^t = h_d^t \odot g_d^t + s_{d - 1}^t \odot c_d^t$
where
\begin{align}
h_0^t &= tanh(lin_{hx}(x) + lin_{hs}(s_D^{t-1})) \\
g_0^t &= \sigma(lin_{gx}(x) + lin_{gs}^1(s_D^{t-1})) \\
c_0^t &= \sigma(lin_{cx}(x) + lin_{cs}^1(s_D^{t-1}))
\end{align}
and for $0 < d < D$
\begin{align}
h_d^t &= tanh(lin_{hs}^d(s_{d - 1}^t)) \\
g_d^t &= \sigma(lin_{gs}^d(s_{d - 1}^t)) \\
c_d^t &= \sigma(lin_{cs}^d(s_{d - 1}^t))
\end{align}
Here we have made a couple of changes to the notation used in the paper.
To avoid confusion with time, the gate is represented by $g$,
which was $t$ in the paper.
To avoid confusion with multiple layers, we use $d$ for depth and $D$ for
the total depth instead of $l$ and $L$ from the paper.
We have also replaced the weight matrices and bias vectors from the equations with
linear transforms, because that is how the implementation looks.
We implement weight tying, as described in the paper, $c_d^t = 1 - g_d^t$,
so the update reduces to the interpolation
$s_d^t = g_d^t \odot h_d^t + (1 - g_d^t) \odot s_{d - 1}^t$.
"""
def __init__(self, input_size: int, hidden_size: int, depth: int):
"""
`input_size` is the feature length of the input and `hidden_size` is
the feature length of the cell state.
`depth` is $D$.
"""
super().__init__()
self.hidden_size = hidden_size
self.depth = depth
# We combine $lin_{hs}$ and $lin_{gs}$ into a single linear layer.
# We can then split the results to get the $lin_{hs}$ and $lin_{gs}$ components.
# This is the $lin_{hs}^d$ and $lin_{gs}^d$ for $0 \leq d < D$.
self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
# Similarly we combine $lin_{hx}$ and $lin_{gx}$.
self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
def __call__(self, x: torch.Tensor, s: torch.Tensor):
"""
`x` has shape `[batch_size, input_size]` and
`s` has shape `[batch_size, hidden_size]`.
"""
# Iterate $0 \leq d < D$
for d in range(self.depth):
# We calculate the concatenation of linear transforms for $h$ and $g$
if d == 0:
# The input is used only when $d$ is $0$.
hg = self.input_lin(x) + self.hidden_lin[d](s)
else:
hg = self.hidden_lin[d](s)
# Use the first half of `hg` to get $h_d^t$
# \begin{align}
# h_0^t &= tanh(lin_{hx}(x) + lin_{hs}(s_D^{t-1})) \\
# h_d^t &= tanh(lin_{hs}^d(s_{d - 1}^t))
# \end{align}
h = torch.tanh(hg[:, :self.hidden_size])
# Use the second half of `hg` to get $g_d^t$
# \begin{align}
# g_0^t &= \sigma(lin_{gx}(x) + lin_{gs}^1(s_D^{t-1})) \\
# g_d^t &= \sigma(lin_{gs}^d(s_{d - 1}^t))
# \end{align}
g = torch.sigmoid(hg[:, self.hidden_size:])
s = h * g + s * (1 - g)
return s
class RHN(Module):
"""
### Multilayer Recurrent Highway Network
"""
def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
"""
Create a network of `n_layers` recurrent highway network layers, each with depth `depth` ($D$).
"""
super().__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
# Create cells for each layer. Note that only the first layer gets the input directly.
# The rest of the layers get their input from the layer below
self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
[RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
def __call__(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
"""
`x` has shape `[seq_len, batch_size, input_size]` and
`state`, if given, has shape `[n_layers, batch_size, hidden_size]`.
"""
time_steps, batch_size = x.shape[:2]
# Initialize the state if `None`
if state is None:
s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
else:
# Unstack the state to get the state of each layer <br />
# 📝 You can just work with the tensor itself but this is easier to debug
s = list(torch.unbind(state))
# Array to collect the outputs of the final layer at each time step.
out = []
# Run through the network for each time step
for t in range(time_steps):
# Input to the first layer is the input itself
inp = x[t]
for i in range(self.n_layers):
s[i] = self.cells[i](inp, s[i])
inp = s[i]
# Loop through the layers
for layer in range(self.n_layers):
# Update the state of this layer
s[layer] = self.cells[layer](inp, s[layer])
# Input to the next layer is the state of this layer
inp = s[layer]
# Collect the output of the final layer
out.append(s[-1])
# Stack the outputs and states
out = torch.stack(out)
s = torch.stack(s)
return out, s
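# A minimal sanity-check sketch, assuming arbitrary sizes chosen only to
# illustrate the expected tensor shapes (the values below are not from the paper).
if __name__ == '__main__':
    seq_len, batch_size, input_size, hidden_size = 10, 8, 16, 32
    # A single cell keeps the state shape unchanged across one step
    cell = RHNCell(input_size, hidden_size, depth=4)
    s = cell(torch.zeros(batch_size, input_size), torch.zeros(batch_size, hidden_size))
    assert s.shape == (batch_size, hidden_size)
    # Two stacked layers run over a full sequence
    rhn = RHN(input_size, hidden_size, depth=4, n_layers=2)
    x = torch.randn(seq_len, batch_size, input_size)
    # Without an initial state the network starts from zeros
    out, state = rhn(x)
    # `out` collects the final layer's state at every time step
    assert out.shape == (seq_len, batch_size, hidden_size)
    # `state` stacks the per-layer states
    assert state.shape == (2, batch_size, hidden_size)
    # The returned state can seed the next chunk of the sequence
    out2, state2 = rhn(x, state)
    assert out2.shape == out.shape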
@@ -16,13 +16,18 @@ Transformers
and
`relative multi-headed attention <http://lab-ml.com/labml_nn/transformers/relative_mha.html>`_.
Recurrent Highway Networks
--------------------------

This is the implementation for `Recurrent Highway Networks <http://lab-ml.com/labml_nn/recurrent_highway_networks>`_.
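A minimal usage sketch, assuming the module path ``labml_nn.recurrent_highway_networks`` and illustrative sizes:

.. code-block:: python

    import torch
    from labml_nn.recurrent_highway_networks import RHN

    # Two stacked RHN layers, each with recurrence depth 4
    rhn = RHN(input_size=16, hidden_size=32, depth=4, n_layers=2)
    # Input is [seq_len, batch_size, input_size]
    x = torch.randn(10, 8, 16)
    out, state = rhn(x)
    # out is [seq_len, batch_size, hidden_size], state is [n_layers, batch_size, hidden_size]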
LSTM
----

This is the implementation for `LSTMs <http://lab-ml.com/labml_nn/lstm>`_.

Please create a Github issue if there's something you'd like to see implemented here.
Installation
------------
......