Commit 302785a4 authored by Varuna Jayasiri

capsnet annotations

Parent 010f0c56
......@@ -8,3 +8,5 @@ build/
.idea/*
!.idea/dictionaries
html/
labml
labml_helpers
web_api: https://api.lab-ml.com/api/v1/track?labml_token=903c84fba8ca49ca9f215922833e08cf&channel=app-updates-test
......@@ -22,7 +22,7 @@ uninstall: ## Uninstall
pip uninstall labml_nn
docs: ## Render annotated HTML
python ../../pylit/pylit.py -t ../../pylit/template_docs.html -d html -w labml_nn
python ../../pylit/pylit.py --remove_empty_sections -s ../../pylit/pylit_docs.css -t ../../pylit/template_docs.html -d html -w labml_nn
pages: ## Copy to lab-ml site
@cd ../lab-ml.github.io; git pull
......
......@@ -4,6 +4,7 @@
* [Transformers](transformers/index.html)
* [Recurrent Highway Networks](recurrent_highway_networks/index.html)
* [LSTM](lstm/index.html)
* [Capsule Networks](capsule_networks/index.html)
If you have any suggestions for other new implementations,
please create a [Github Issue](https://github.com/lab-ml/labml_nn/issues).
......
"""
This is an implementation of the paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829).
Unlike our other model implementations, we have also included a sample experiment, because
it is difficult to understand some of the concepts with just the modules.
[This is the annotated code for a model that uses capsules to classify the MNIST dataset](mnist.html).
This file holds the implementations of the core modules of Capsule Networks.
"""
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml.utils.pytorch import get_device
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, BatchStep
class Squash(Module):
"""
## Squash
This is the **squashing** function from the paper, given by equation $(1)$.
$$\mathbf{v}_j = \frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}
\frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$$
$\frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$
normalizes the length of all the capsules, whilst
$\frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}$
shrinks the capsules that have a length smaller than one.
"""
def __init__(self, epsilon=1e-8):
......@@ -25,42 +35,103 @@ class Squash(Module):
self.epsilon = epsilon
def __call__(self, s: torch.Tensor):
# shape: batch, caps, features
"""
The shape of `s` is `[batch_size, n_capsules, n_features]`
"""
# ${\lVert \mathbf{s}_j \rVert}^2$
s2 = (s ** 2).sum(dim=-1, keepdims=True)
# We add an epsilon when calculating $\lVert \mathbf{s}_j \rVert$ to make sure it doesn't become zero.
# If this becomes zero it starts giving out `nan` values and training fails.
# $$\mathbf{v}_j = \frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}
# \frac{\mathbf{s}_j}{\sqrt{{\lVert \mathbf{s}_j \rVert}^2 + \epsilon}}$$
return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon))
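# A minimal usage sketch (illustrative, not part of the original code):
# squashing preserves each capsule's direction while mapping its length into $[0, 1)$,
# so the length can be read as a probability.
#
#     squash = Squash()
#     s = torch.randn(2, 5, 8)  # [batch_size, n_capsules, n_features]
#     v = squash(s)
#     assert v.shape == s.shape
#     # every squashed capsule ends up with length strictly below one
#     assert bool(((v ** 2).sum(dim=-1).sqrt() < 1.0).all())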
class Router(Module):
"""
The routing mechanism
## Routing Algorithm
This is the routing mechanism described in the paper.
You can use multiple routing layers in your models.
This combines calculating $\mathbf{s}_j$ for this layer and
the routing algorithm described in *Procedure 1*.
"""
def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int):
"""
`in_caps` is the number of capsules, and `in_d` is the number of features per capsule from the layer below.
`out_caps` and `out_d` are the same for this layer.
`iterations` is the number of routing iterations, symbolized by $r$ in the paper.
"""
super().__init__()
self.in_caps = in_caps
self.out_caps = out_caps
self.iterations = iterations
self.softmax = nn.Softmax(dim=1)
self.squash = Squash()
# This is the weight matrix $\mathbf{W}_{ij}$. It maps each capsule in the
# lower layer to each capsule in this layer
self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True)
def __call__(self, u: torch.Tensor):
# batch, in_caps, in_d
"""
The shape of `u` is `[batch_size, n_capsules, n_features]`.
These are the capsules from the lower layer.
"""
# $$\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \mathbf{u}_i$$
# Here $j$ is used to index capsules in this layer, whilst $i$ is
# used to index capsules in the layer below (previous).
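# In the `einsum` below, `b` indexes the batch, `i` the lower-layer capsules,
# `j` the capsules in this layer, `n` the lower-layer capsule features,
# and `m` the features of the capsules in this layer.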
u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u)
# Initial logits $b_{ij}$ are the log prior probabilities that capsule $i$
# should be coupled with $j$.
# We initialize these at zero.
b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps)
v = None
# Iterate
for i in range(self.iterations):
# routing softmax $$c_{ij} = \frac{\exp({b_{ij}})}{\sum_k\exp({b_{ik}})}$$
c = self.softmax(b)
# $$\mathbf{s}_j = \sum_i{c_{ij} \hat{\mathbf{u}}_{j|i}}$$
s = torch.einsum('bij,bijm->bjm', c, u_hat)
# $$\mathbf{v}_j = squash(\mathbf{s}_j)$$
v = self.squash(s)
# $$a_{ij} = \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}$$
a = torch.einsum('bjm,bijm->bij', v, u_hat)
# $$b_{ij} \gets b_{ij} + \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}$$
b = b + a
return v
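# A minimal shape-check sketch (illustrative): route 8-dimensional capsules from a
# lower layer of 32 * 6 * 6 capsules into 10 capsules of 16 features each, which is
# the configuration the MNIST model below uses.
#
#     router = Router(in_caps=32 * 6 * 6, out_caps=10, in_d=8, out_d=16, iterations=3)
#     u = torch.randn(2, 32 * 6 * 6, 8)  # [batch_size, in_caps, in_d]
#     v = router(u)
#     assert v.shape == (2, 10, 16)      # [batch_size, out_caps, out_d]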
class MarginLoss(Module):
"""
## Margin loss for class existence
A separate margin loss is used for each output capsule and the total loss is the sum of them.
The length of each output capsule is the probability that the class is present in the input.
Loss for each output capsule or class $k$ is,
$$L_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 +
\lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2$$
$T_k$ is $1$ if the class $k$ is present and $0$ otherwise.
The first component of the loss is $0$ if the class is not present,
and the second component is $0$ if the class is present.
The $\max(0, x)$ is used to avoid predictions going to extremes.
$m^{+}$ is set to be $0.9$ and $m^{-}$ to be $0.1$ in the paper.
The $\lambda$ down-weighting is used to stop the lengths of all capsules from
falling during the initial phase of training.
"""
def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
super().__init__()
......@@ -70,104 +141,25 @@ class MarginLoss(Module):
self.n_labels = n_labels
def __call__(self, v: torch.Tensor, labels: torch.Tensor):
"""
`v`, $\mathbf{v}_j$ are the squashed output capsules.
This has shape `[batch_size, n_labels, n_features]`; that is, there is a capsule for each label.
`labels` are the labels, and has shape `[batch_size]`.
"""
# $$\lVert \mathbf{v}_j \rVert$$
v_norm = torch.sqrt((v ** 2).sum(dim=-1))
# $$L$$
# `labels` is one-hot encoded labels of shape `[batch_size, n_labels]`
labels = torch.eye(self.n_labels, device=labels.device)[labels]
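# For example, with `n_labels = 3` and `labels = [2, 0]`, the indexing above
# gives `[[0, 0, 1], [1, 0, 0]]`.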
# $$L_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 +
# \lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2$$
# `loss` has shape `[batch_size, n_labels]`. We have parallelized the computation
# of $L_k$ for all $k$.
loss = labels * F.relu(self.m_positive - v_norm) + \
self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)
loss = loss.sum(dim=-1).mean()
return loss
class MNISTCapsuleNetworkModel(Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
self.squash = Squash()
# self.digit_capsules = DigitCaps()
self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
self.reconstruct = nn.Sequential(
nn.Linear(16 * 10, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 784),
nn.Sigmoid()
)
self.mse_loss = nn.MSELoss()
def forward(self, data):
x = F.relu(self.conv1(data))
caps = self.conv2(x).view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
caps = self.squash(caps)
caps = self.digit_capsules(caps)
with torch.no_grad():
pred = (caps ** 2).sum(-1).argmax(-1)
masked = torch.eye(10, device=x.device)[pred]
reconstructions = self.reconstruct((caps * masked[:, :, None]).view(x.shape[0], -1))
reconstructions = reconstructions.view(-1, 1, 28, 28)
return caps, reconstructions, pred
class CapsuleNetworkBatchStep(BatchStep):
def __init__(self, *, model, optimizer):
super().__init__(model=model, optimizer=optimizer, loss_func=None, accuracy_func=None)
self.reconstruction_loss = nn.MSELoss()
self.margin_loss = MarginLoss(n_labels=10)
def calculate_loss(self, batch: any, state: any):
device = get_device(self.model)
data, target = batch
data, target = data.to(device), target.to(device)
stats = {'samples': len(data)}
caps, reconstructions, pred = self.model(data)
loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
stats['correct'] = pred.eq(target).sum().item()
stats['loss'] = loss.detach().item() * stats['samples']
tracker.add("loss.", loss)
return loss, stats, None
class Configs(MNISTConfigs, TrainValidConfigs):
batch_step = 'capsule_network_batch_step'
device: torch.device = DeviceConfigs()
epochs: int = 10
loss_func = None
accuracy_func = None
@option(Configs.model)
def model(c: Configs):
return MNISTCapsuleNetworkModel().to(c.device)
@option(Configs.batch_step)
def capsule_network_batch_step(c: TrainValidConfigs):
return CapsuleNetworkBatchStep(model=c.model, optimizer=c.optimizer)
def main():
conf = Configs()
experiment.create(name='mnist_latest', writers={})
experiment.configs(conf, {'optimizer.optimizer': 'Adam',
'device.cuda_device': 1},
'run')
experiment.add_pytorch_models(dict(model=conf.model))
with experiment.start():
conf.run()
if __name__ == '__main__':
main()
# $$\sum_k L_k$$
return loss.sum(dim=-1).mean()
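# A minimal usage sketch (illustrative): the loss takes the squashed output capsules
# and the integer class labels, and returns a scalar averaged over the batch.
#
#     margin_loss = MarginLoss(n_labels=10)
#     v = torch.rand(4, 10, 16)            # [batch_size, n_labels, n_features]
#     labels = torch.tensor([3, 1, 7, 0])  # [batch_size]
#     loss = margin_loss(v, labels)
#     assert loss.dim() == 0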
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml.utils.pytorch import get_device
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, BatchStep
from labml_nn.capsule_networks import Squash, Router, MarginLoss
class MNISTCapsuleNetworkModel(Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
self.squash = Squash()
# self.digit_capsules = DigitCaps()
self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
self.reconstruct = nn.Sequential(
nn.Linear(16 * 10, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 784),
nn.Sigmoid()
)
self.mse_loss = nn.MSELoss()
def forward(self, data):
x = F.relu(self.conv1(data))
caps = self.conv2(x).view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
caps = self.squash(caps)
caps = self.digit_capsules(caps)
with torch.no_grad():
pred = (caps ** 2).sum(-1).argmax(-1)
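# Zero out all capsules except the predicted one before reconstruction, so the
# decoder only sees the 16 features of the winning capsule.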
masked = torch.eye(10, device=x.device)[pred]
reconstructions = self.reconstruct((caps * masked[:, :, None]).view(x.shape[0], -1))
reconstructions = reconstructions.view(-1, 1, 28, 28)
return caps, reconstructions, pred
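# A minimal shape-check sketch (illustrative): a batch of MNIST images yields 10 output
# capsules of 16 features each, a reconstruction of the input, and a predicted class per sample.
#
#     model = MNISTCapsuleNetworkModel()
#     data = torch.rand(2, 1, 28, 28)
#     caps, reconstructions, pred = model(data)
#     assert caps.shape == (2, 10, 16)
#     assert reconstructions.shape == (2, 1, 28, 28)
#     assert pred.shape == (2,)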
class CapsuleNetworkBatchStep(BatchStep):
def __init__(self, *, model, optimizer):
super().__init__(model=model, optimizer=optimizer, loss_func=None, accuracy_func=None)
self.reconstruction_loss = nn.MSELoss()
self.margin_loss = MarginLoss(n_labels=10)
def calculate_loss(self, batch: any, state: any):
device = get_device(self.model)
data, target = batch
data, target = data.to(device), target.to(device)
stats = {'samples': len(data)}
caps, reconstructions, pred = self.model(data)
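# The total loss is the margin loss plus the reconstruction loss, with the
# reconstruction loss scaled down by 0.0005 so that it does not dominate
# the margin loss during training (as in the paper).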
loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
stats['correct'] = pred.eq(target).sum().item()
stats['loss'] = loss.detach().item() * stats['samples']
tracker.add("loss.", loss)
return loss, stats, None
class Configs(MNISTConfigs, TrainValidConfigs):
batch_step = 'capsule_network_batch_step'
device: torch.device = DeviceConfigs()
epochs: int = 10
model = 'capsule_network_model'
loss_func = None
accuracy_func = None
@option(Configs.model)
def capsule_network_model(c: Configs):
return MNISTCapsuleNetworkModel().to(c.device)
@option(Configs.batch_step)
def capsule_network_batch_step(c: TrainValidConfigs):
return CapsuleNetworkBatchStep(model=c.model, optimizer=c.optimizer)
def main():
conf = Configs()
experiment.create(name='mnist_latest')
experiment.configs(conf, {'optimizer.optimizer': 'Adam',
'device.cuda_device': 1},
'run')
with experiment.start():
conf.run()
if __name__ == '__main__':
main()
......@@ -5,7 +5,7 @@ with open("readme.rst", "r") as f:
setuptools.setup(
name='labml_nn',
version='0.4.1',
version='0.4.2',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",
......