diff --git a/.gitignore b/.gitignore
index 35634031b724ae5ff9c4618e4fb722141db96d89..28172990e2556a59ec5a85bd78328054e3ba8b93 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,5 @@ build/
 .idea/*
 !.idea/dictionaries
 html/
+labml
+labml_helpers
diff --git a/.labml.yaml b/.labml.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..700ebb73bc5acefa21102467c4127926b2df9130
--- /dev/null
+++ b/.labml.yaml
@@ -0,0 +1 @@
+web_api: https://api.lab-ml.com/api/v1/track?labml_token=903c84fba8ca49ca9f215922833e08cf&channel=app-updates-test
diff --git a/Makefile b/Makefile
index 9910f1b6bebe727957e706f089a607b4166644a9..b88884dcec3c4649e5edc989a2847befd934dd86 100644
--- a/Makefile
+++ b/Makefile
@@ -22,7 +22,7 @@ uninstall: ## Uninstall
 	pip uninstall labml_nn
 
 docs: ## Render annotated HTML
-	python ../../pylit/pylit.py -t ../../pylit/template_docs.html -d html -w labml_nn
+	python ../../pylit/pylit.py --remove_empty_sections -s ../../pylit/pylit_docs.css -t ../../pylit/template_docs.html -d html -w labml_nn
 
 pages: ## Copy to lab-ml site
 	@cd ../lab-ml.github.io; git pull
diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py
index d6e6e3f3e96ab6dbc26c2150723275314ce8e97c..1ecedc490d9f316bbdafa2542d457fed4806de82 100644
--- a/labml_nn/__init__.py
+++ b/labml_nn/__init__.py
@@ -4,6 +4,7 @@
 * [Transformers](transformers/index.html)
 * [Recurrent Highway Networks](recurrent_highway_networks/index.html)
 * [LSTM](lstm/index.html)
+* [Capsule Networks](capsule_networks/index.html)
 
 If you have any suggestions for other new implementations,
 please create a [Github Issue](https://github.com/lab-ml/labml_nn/issues).
diff --git a/labml_nn/capsule_networks/__init__.py b/labml_nn/capsule_networks/__init__.py
index f3b394d6cf4209106d7e9d7da634897a74bd60cc..dbc7306ebd13b6fcdf6e58cecbe3e5ce3db683da 100644
--- a/labml_nn/capsule_networks/__init__.py
+++ b/labml_nn/capsule_networks/__init__.py
@@ -1,23 +1,33 @@
 """
-This is an implementation of paper
+This is an implementation of
 [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829).
+
+Unlike our other implementations, we've included a sample along with the modules, because
+it is difficult to understand some of the concepts with just the modules.
+[This is the annotated code for a model that uses capsules to classify the MNIST dataset](mnist.html).
+
+This file holds the implementations of the core modules of Capsule Networks.
 """
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 
-from labml import experiment, tracker
-from labml.configs import option
-from labml.utils.pytorch import get_device
-from labml_helpers.datasets.mnist import MNISTConfigs
-from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
-from labml_helpers.train_valid import TrainValidConfigs, BatchStep
 
 
 class Squash(Module):
     """
-    This is **squashing** function from paper.
+    ## Squash
+
+    This is the **squashing** function from the paper, given by equation $(1)$.
+
+    $$\mathbf{v}_j = \frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}
+    \frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$$
+
+    $\frac{\mathbf{s}_j}{\lVert \mathbf{s}_j \rVert}$
+    normalizes the length of all the capsules, whilst
+    $\frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2}$
+    shrinks the capsules that have a length smaller than one.
""" def __init__(self, epsilon=1e-8): @@ -25,42 +35,103 @@ class Squash(Module): self.epsilon = epsilon def __call__(self, s: torch.Tensor): - # shape: batch, caps, features + """ + The shape of `s` is `[batch_size, n_capsules, n_features]` + """ + + # ${\lVert \mathbf{s}_j \rVert}^2$ s2 = (s ** 2).sum(dim=-1, keepdims=True) + + # We add an epsilon when calculating $\lVert \mathbf{s}_j \rVert$ to make sure it doesn't become zero. + # If this becomes zero it starts giving out `nan` values and training fails. + # $$\mathbf{v}_j = \frac{{\lVert \mathbf{s}_j \rVert}^2}{1 + {\lVert \mathbf{s}_j \rVert}^2} + # \frac{\mathbf{s}_j}{\sqrt{{\lVert \mathbf{s}_j \rVert}^2 + \epsilon}}$$ return (s2 / (1 + s2)) * (s / torch.sqrt(s2 + self.epsilon)) class Router(Module): """ - The routing mechanism + ## Routing Algorithm + + This is the routing mechanism described in the paper. + You can use multiple routing layers in your models. + + This combines calculating $\mathbf{s}_j$ for this layer and + the routing algorithm described in *Procedure 1*. """ - def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, - iterations: int): + def __init__(self, in_caps: int, out_caps: int, in_d: int, out_d: int, iterations: int): + """ + `in_caps` is the number of capsules, and `in_d` is the number of features per capsule from the layer below. + `out_caps` and `out_d` are the same for this layer. + + `iterations` is the number of routing iterations, symbolized by $r$ in the paper. + """ super().__init__() self.in_caps = in_caps self.out_caps = out_caps self.iterations = iterations - self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d)) self.softmax = nn.Softmax(dim=1) self.squash = Squash() + # This is the weight matrix $\mathbf{W}_{ij}$. It maps each capsule in the + # lower layer to each capsule in this layer + self.weight = nn.Parameter(torch.randn(in_caps, out_caps, in_d, out_d), requires_grad=True) + def __call__(self, u: torch.Tensor): - # batch, in_caps, in_d + """ + The shape of `u` is `[batch_size, n_capsules, n_features]`. + These are the capsules from the lower layer. + """ + + # $$\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \mathbf{u}_i$$ + # Here $j$ is used to index capsules in this layer, whilst $i$ is + # used to index capsules in the layer below (previous). u_hat = torch.einsum('ijnm,bin->bijm', self.weight, u) + + # Initial logits $b_{ij}$ are the log prior probabilities that capsule $i$ + # should be coupled with $j$. + # We initialize these at zero b = u.new_zeros(u.shape[0], self.in_caps, self.out_caps) + v = None + + # Iterate for i in range(self.iterations): + # routing softmax $$c_{ij} = \frac{\exp({b_{ij}})}{\sum_k\exp({b_{ik}})}$$ c = self.softmax(b) + # $$\mathbf{s}_j = \sum_i{c_{ij} \hat{\mathbf{u}}_{j|i}}$$ s = torch.einsum('bij,bijm->bjm', c, u_hat) + # $$\mathbf{v}_j = squash(\mathbf{s}_j)$$ v = self.squash(s) + # $$a_{ij} = \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}$$ a = torch.einsum('bjm,bijm->bij', v, u_hat) + # $$b_{ij} \gets b_{ij} + \mathbf{v}_j \cdot \hat{\mathbf{u}}_{j|i}$$ b = b + a return v class MarginLoss(Module): + """ + ## Margin loss for class existence + + A separate margin loss is used for each output capsule and the total loss is the sum of them. + The length of each output capsule is the probability that class is present in the input. 
+
+    Loss for each output capsule or class $k$ is,
+    $$L_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 +
+    \lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2$$
+
+    $T_k$ is $1$ if the class $k$ is present and $0$ otherwise.
+    The first component of the loss is $0$ if the class is not present,
+    and the second component is $0$ if the class is present.
+    The $\max(0, x)$ is used to avoid predictions going to extremes.
+    $m^{+}$ is set to be $0.9$ and $m^{-}$ to be $0.1$ in the paper.
+
+    The $\lambda$ down-weighting is used to stop the length of all capsules from
+    falling during the initial phase of training.
+    """
     def __init__(self, *, n_labels: int, lambda_: float = 0.5, m_positive: float = 0.9, m_negative: float = 0.1):
         super().__init__()
@@ -70,104 +141,25 @@ class MarginLoss(Module):
         self.n_labels = n_labels
 
     def __call__(self, v: torch.Tensor, labels: torch.Tensor):
+        """
+        `v`, $\mathbf{v}_j$ are the squashed output capsules.
+        This has shape `[batch_size, n_labels, n_features]`; that is, there is a capsule for each label.
+
+        `labels` are the labels, and have shape `[batch_size]`.
+        """
+        # $$\lVert \mathbf{v}_j \rVert$$
         v_norm = torch.sqrt((v ** 2).sum(dim=-1))
+
+        # `labels` is the one-hot encoded labels, of shape `[batch_size, n_labels]`
         labels = torch.eye(self.n_labels, device=labels.device)[labels]
+
+        # $$L_k = T_k \max(0, m^{+} - \lVert\mathbf{v}_k\rVert)^2 +
+        # \lambda (1 - T_k) \max(0, \lVert\mathbf{v}_k\rVert - m^{-})^2$$
+        # `loss` has shape `[batch_size, n_labels]`. We have parallelized the computation
+        # of $L_k$ for all $k$.
         loss = labels * F.relu(self.m_positive - v_norm) + \
                self.lambda_ * (1.0 - labels) * F.relu(v_norm - self.m_negative)
-        loss = loss.sum(dim=-1).mean()
-
-        return loss
-
-
-class MNISTCapsuleNetworkModel(Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
-        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
-        self.squash = Squash()
-
-        # self.digit_capsules = DigitCaps()
-        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
-        self.reconstruct = nn.Sequential(
-            nn.Linear(16 * 10, 512),
-            nn.ReLU(),
-            nn.Linear(512, 1024),
-            nn.ReLU(),
-            nn.Linear(1024, 784),
-            nn.Sigmoid()
-        )
-
-        self.mse_loss = nn.MSELoss()
-
-    def forward(self, data):
-        x = F.relu(self.conv1(data))
-        caps = self.conv2(x).view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
-        caps = self.squash(caps)
-        caps = self.digit_capsules(caps)
-
-        with torch.no_grad():
-            pred = (caps ** 2).sum(-1).argmax(-1)
-            masked = torch.eye(10, device=x.device)[pred]
-
-        reconstructions = self.reconstruct((caps * masked[:, :, None]).view(x.shape[0], -1))
-        reconstructions = reconstructions.view(-1, 1, 28, 28)
-
-        return caps, reconstructions, pred
-
-
-class CapsuleNetworkBatchStep(BatchStep):
-    def __init__(self, *, model, optimizer):
-        super().__init__(model=model, optimizer=optimizer, loss_func=None, accuracy_func=None)
-        self.reconstruction_loss = nn.MSELoss()
-        self.margin_loss = MarginLoss(n_labels=10)
-
-    def calculate_loss(self, batch: any, state: any):
-        device = get_device(self.model)
-        data, target = batch
-        data, target = data.to(device), target.to(device)
-        stats = {'samples': len(data)}
-
-        caps, reconstructions, pred = self.model(data)
-
-        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
-
-        stats['correct'] = pred.eq(target).sum().item()
-        stats['loss'] = loss.detach().item() * stats['samples']
-        tracker.add("loss.", loss)
-
-        return loss, stats, None
-
-
-class Configs(MNISTConfigs, TrainValidConfigs):
-    batch_step = 'capsule_network_batch_step'
-    device: torch.device = DeviceConfigs()
-    epochs: int = 10
-
-    loss_func = None
-    accuracy_func = None
-
-
-@option(Configs.model)
-def model(c: Configs):
-    return MNISTCapsuleNetworkModel().to(c.device)
-
-
-@option(Configs.batch_step)
-def capsule_network_batch_step(c: TrainValidConfigs):
-    return CapsuleNetworkBatchStep(model=c.model, optimizer=c.optimizer)
-
-
-def main():
-    conf = Configs()
-    experiment.create(name='mnist_latest', writers={})
-    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
-                              'device.cuda_device': 1},
-                       'run')
-    experiment.add_pytorch_models(dict(model=conf.model))
-    with experiment.start():
-        conf.run()
-
-if __name__ == '__main__':
-    main()
+        # $$\sum_k L_k$$
+        return loss.sum(dim=-1).mean()
diff --git a/labml_nn/capsule_networks/mnist.py b/labml_nn/capsule_networks/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2a85c6e09378705ea88b105b9c728fa90cfcb63
--- /dev/null
+++ b/labml_nn/capsule_networks/mnist.py
@@ -0,0 +1,105 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+
+from labml import experiment, tracker
+from labml.configs import option
+from labml.utils.pytorch import get_device
+from labml_helpers.datasets.mnist import MNISTConfigs
+from labml_helpers.device import DeviceConfigs
+from labml_helpers.module import Module
+from labml_helpers.train_valid import TrainValidConfigs, BatchStep
+from labml_nn.capsule_networks import Squash, Router, MarginLoss
+
+
+class MNISTCapsuleNetworkModel(Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
+        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
+        self.squash = Squash()
+
+        # self.digit_capsules = DigitCaps()
+        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
+        self.reconstruct = nn.Sequential(
+            nn.Linear(16 * 10, 512),
+            nn.ReLU(),
+            nn.Linear(512, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, 784),
+            nn.Sigmoid()
+        )
+
+        self.mse_loss = nn.MSELoss()
+
+    def forward(self, data):
+        x = F.relu(self.conv1(data))
+        caps = self.conv2(x).view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
+        caps = self.squash(caps)
+        caps = self.digit_capsules(caps)
+
+        with torch.no_grad():
+            pred = (caps ** 2).sum(-1).argmax(-1)
+            masked = torch.eye(10, device=x.device)[pred]
+
+        reconstructions = self.reconstruct((caps * masked[:, :, None]).view(x.shape[0], -1))
+        reconstructions = reconstructions.view(-1, 1, 28, 28)
+
+        return caps, reconstructions, pred
+
+
+class CapsuleNetworkBatchStep(BatchStep):
+    def __init__(self, *, model, optimizer):
+        super().__init__(model=model, optimizer=optimizer, loss_func=None, accuracy_func=None)
+        self.reconstruction_loss = nn.MSELoss()
+        self.margin_loss = MarginLoss(n_labels=10)
+
+    def calculate_loss(self, batch: any, state: any):
+        device = get_device(self.model)
+        data, target = batch
+        data, target = data.to(device), target.to(device)
+        stats = {'samples': len(data)}
+
+        caps, reconstructions, pred = self.model(data)
+
+        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
+
+        stats['correct'] = pred.eq(target).sum().item()
+        stats['loss'] = loss.detach().item() * stats['samples']
+        tracker.add("loss.", loss)
+
+        return loss, stats, None
+
+
+class Configs(MNISTConfigs, TrainValidConfigs):
+    batch_step = 'capsule_network_batch_step'
+    device: torch.device = DeviceConfigs()
+    epochs: int = 10
+    model = 'capsule_network_model'
+
+    loss_func = None
+    accuracy_func = None
+
+
+@option(Configs.model)
+def capsule_network_model(c: Configs):
+    return MNISTCapsuleNetworkModel().to(c.device)
+
+
+@option(Configs.batch_step)
+def capsule_network_batch_step(c: TrainValidConfigs):
+    return CapsuleNetworkBatchStep(model=c.model, optimizer=c.optimizer)
+
+
+def main():
+    conf = Configs()
+    experiment.create(name='mnist_latest')
+    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
+                              'device.cuda_device': 1},
+                       'run')
+    with experiment.start():
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/setup.py b/setup.py
index adc42691240d96cbe0ffe3b8705623c4310d0a14..2bd4755d3bacd31aaae9fc423bffdf1984780132 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("readme.rst", "r") as f:
 
 setuptools.setup(
     name='labml_nn',
-    version='0.4.1',
+    version='0.4.2',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="A collection of PyTorch implementations of neural network architectures and layers.",
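
A quick way to sanity-check the new modules after applying this patch (a minimal sketch, not part of the patch; the shapes follow the MNIST model above, where 32 * 6 * 6 = 1152 primary capsules of 8 features are routed to 10 capsules of 16 features):

    import torch
    from labml_nn.capsule_networks import Squash, Router, MarginLoss

    u = torch.randn(4, 1152, 8)                     # [batch_size, n_capsules, n_features]
    v = Router(1152, 10, 8, 16, iterations=3)(Squash()(u))
    print(v.shape)                                  # torch.Size([4, 10, 16])
    labels = torch.randint(0, 10, (4,))             # one label per sample
    print(MarginLoss(n_labels=10)(v, labels))       # scalar margin loss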