diff --git a/docs/index.html b/docs/index.html index c69dad4a22737751fd22ac7b484758bdbc0f48cb..fd59953224d07921a334924f6a192ecbf37a676e 100644 --- a/docs/index.html +++ b/docs/index.html @@ -154,15 +154,19 @@ implementations.
+pip install labml-nn
If you use LabML for academic research, please cite the library using the following BibTeX entry.
+If you use this for academic research, please cite it using the following BibTeX entry.
@misc{labml,
author = {Varuna Jayasiri, Nipun Wijerathne},
- title = {LabML: A library to organize machine learning experiments},
+ title = {labml.ai Annotated Paper Implementations},
year = {2020},
url = {https://nn.labml.ai/},
}
diff --git a/docs/normalization/batch_norm/mnist.html b/docs/normalization/batch_norm/mnist.html
index c283324e275bb1611e0b2a08323b7ea0b0c64215..af6cd3185f27edb3f320d82e65cae55547e6fa73 100644
--- a/docs/normalization/batch_norm/mnist.html
+++ b/docs/normalization/batch_norm/mnist.html
@@ -268,7 +268,10 @@ and set a new function to calculate the model.
Load configurations
- 75 experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
+ 75 experiment.configs(conf, {
+76 'optimizer.optimizer': 'Adam',
+77 'optimizer.learning_rate': 0.001,
+78 })
@@ -279,8 +282,8 @@ and set a new function to calculate the model.
Start the experiment and run the training loop
- 77 with experiment.start():
-78 conf.run()
+ 80 with experiment.start():
+81 conf.run()
@@ -291,8 +294,8 @@ and set a new function to calculate the model.
- 82if __name__ == '__main__':
-83 main()
+ 85if __name__ == '__main__':
+86 main()
diff --git a/docs/papers.json b/docs/papers.json
index d22adf24b934acf7b7742710b287cb2cea94920c..0f45ac1cd13d20e14e9517d77770d90f7d06217b 100644
--- a/docs/papers.json
+++ b/docs/papers.json
@@ -114,6 +114,9 @@
"1704.03477": [
"https://nn.labml.ai/sketch_rnn/index.html"
],
+ "1806.01768": [
+ "https://nn.labml.ai/uncertainty/evidence/index.html"
+ ],
"1509.06461": [
"https://nn.labml.ai/rl/dqn/index.html"
],
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 4f0461377f053512a0b81c293a75f315c10b328c..7579d06e90ae1821d8949f9a760cff3dbbad29dc 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -204,7 +204,7 @@
https://nn.labml.ai/normalization/batch_norm/mnist.html
- 2021-08-19T16:30:00+00:00
+ 2021-08-20T16:30:00+00:00
1.00
@@ -281,7 +281,7 @@
https://nn.labml.ai/index.html
- 2021-08-12T16:30:00+00:00
+ 2021-08-21T16:30:00+00:00
1.00
@@ -797,6 +797,27 @@
+
+ https://nn.labml.ai/uncertainty/evidence/index.html
+ 2021-08-21T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/uncertainty/evidence/experiment.html
+ 2021-08-21T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/uncertainty/index.html
+ 2021-08-21T16:30:00+00:00
+ 1.00
+
+
+
https://nn.labml.ai/rl/game.html
2020-12-10T16:30:00+00:00
diff --git a/docs/uncertainty/evidence/experiment.html b/docs/uncertainty/evidence/experiment.html
new file mode 100644
index 0000000000000000000000000000000000000000..1a977b2408391e338a009106bf185a61803a0de2
--- /dev/null
+++ b/docs/uncertainty/evidence/experiment.html
@@ -0,0 +1,867 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Evidential Deep Learning to Quantify Classification Uncertainty Experiment
+
+
+
+
+
+
+
+
+
+
+
+
+
+ home
+ uncertainty
+ evidence
+
+
+
+
+
+
+
+ #
+
+ Evidential Deep Learning to Quantify Classification Uncertainty Experiment
+This trains a model based on Evidential Deep Learning to Quantify Classification Uncertainty
+ on MNIST dataset.
+
+
+ 14from typing import Any
+15
+16import torch.nn as nn
+17import torch.utils.data
+18
+19from labml import tracker, experiment
+20from labml.configs import option, calculate
+21from labml_helpers.module import Module
+22from labml_helpers.schedule import Schedule, RelativePiecewise
+23from labml_helpers.train_valid import BatchIndex
+24from labml_nn.experiments.mnist import MNISTConfigs
+25from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
+26 CrossEntropyBayesRisk, SquaredErrorBayesRisk
+
+
+
+
+
+ #
+
+ LeNet based model for MNIST classification
+
+
+ 29class Model(Module):
+
+
+
+
+
+ #
+
+
+
+
+ 34 def __init__(self, dropout: float):
+35 super().__init__()
+
+
+
+
+
+ #
+
+ First $5x5$ convolution layer
+
+
+ 37 self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
+
+
+
+
+
+ #
+
+ ReLU activation
+
+
+ 39 self.act1 = nn.ReLU()
+
+
+
+
+
+ #
+
+ $2x2$ max-pooling
+
+
+ 41 self.max_pool1 = nn.MaxPool2d(2, 2)
+
+
+
+
+
+ #
+
+ Second $5x5$ convolution layer
+
+
+ 43 self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
+
+
+
+
+
+ #
+
+ ReLU activation
+
+
+ 45 self.act2 = nn.ReLU()
+
+
+
+
+
+ #
+
+ $2x2$ max-pooling
+
+
+ 47 self.max_pool2 = nn.MaxPool2d(2, 2)
+
+
+
+
+
+ #
+
+ First fully-connected layer that maps to $500$ features
+
+
+ 49 self.fc1 = nn.Linear(50 * 4 * 4, 500)
+
+
+
+
+
+ #
+
+ ReLU activation
+
+
+ 51 self.act3 = nn.ReLU()
+
+
+
+
+
+ #
+
+ Final fully connected layer to output evidence for $10$ classes.
+The ReLU or Softplus activation is applied to this outside the model to get the
+non-negative evidence
+
+
+ 55 self.fc2 = nn.Linear(500, 10)
+
+
+
+
+
+ #
+
+ Dropout for the hidden layer
+
+
+ 57 self.dropout = nn.Dropout(p=dropout)
+
+
+
+
+
+ #
+
+
+x
is the batch of MNIST images of shape [batch_size, 1, 28, 28]
+
+
+
+ 59 def __call__(self, x: torch.Tensor):
+
+
+
+
+
+ #
+
+ Apply first convolution and max pooling.
+The result has shape [batch_size, 20, 12, 12]
+
+
+ 65 x = self.max_pool1(self.act1(self.conv1(x)))
+
+
+
+
+
+ #
+
+ Apply second convolution and max pooling.
+The result has shape [batch_size, 50, 4, 4]
+
+
+ 68 x = self.max_pool2(self.act2(self.conv2(x)))
+
+
+
+
+
+ #
+
+ Flatten the tensor to shape [batch_size, 50 * 4 * 4]
+
+
+ 70 x = x.view(x.shape[0], -1)
+
+
+
+
+
+ #
+
+ Apply hidden layer
+
+
+ 72 x = self.act3(self.fc1(x))
+
+
+
+
+
+ #
+
+ Apply dropout
+
+
+ 74 x = self.dropout(x)
+
+
+
+
+
+ #
+
+ Apply final layer and return
+
+
+ 76 return self.fc2(x)
+
+
+
+
+
+ 79class Configs(MNISTConfigs):
+
+
+
+
+
+ #
+
+
+
+
+ 87 kl_div_loss = KLDivergenceLoss()
+
+
+
+
+
+ #
+
+ KL Divergence regularization coefficient schedule
+
+
+ 89 kl_div_coef: Schedule
+
+
+
+
+
+ #
+
+ KL Divergence regularization coefficient schedule
+
+
+ 91 kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
+
+
+
+
+
+ #
+
+ Stats module for tracking
+
+
+ 93 stats = TrackStatistics()
+
+
+
+
+
+ #
+
+ Dropout
+
+
+ 95 dropout: float = 0.5
+
+
+
+
+
+ #
+
+ Module to convert the model output to non-zero evidences
+
+
+ 97 outputs_to_evidence: Module
+
+
+
+
+
+ #
+
+ Initialization
+
+
+ 99 def init(self):
+
+
+
+
+
+ #
+
+ Set tracker configurations
+
+
+ 104 tracker.set_scalar("loss.*", True)
+105 tracker.set_scalar("accuracy.*", True)
+106 tracker.set_histogram('u.*', True)
+107 tracker.set_histogram('prob.*', False)
+108 tracker.set_scalar('annealing_coef.*', False)
+109 tracker.set_scalar('kl_div_loss.*', False)
+
+
+
+
+
+ #
+
+
+
+
+ 112 self.state_modules = []
+
+
+
+
+
+ #
+
+ Training or validation step
+
+
+ 114 def step(self, batch: Any, batch_idx: BatchIndex):
+
+
+
+
+
+ #
+
+ Training/Evaluation mode
+
+
+ 120 self.model.train(self.mode.is_train)
+
+
+
+
+
+ #
+
+ Move data to the device
+
+
+ 123 data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+
+
+
+
+ #
+
+ One-hot coded targets
+
+
+ 126 eye = torch.eye(10).to(torch.float).to(self.device)
+127 target = eye[target]
+
+
+
+
+
+ #
+
+ Update global step (number of samples processed) when in training mode
+
+
+ 130 if self.mode.is_train:
+131 tracker.add_global_step(len(data))
+
+
+
+
+
+ #
+
+ Get model outputs
+
+
+ 134 outputs = self.model(data)
+
+
+
+
+
+ #
+
+ Get evidences $e_k \ge 0$
+
+
+ 136 evidence = self.outputs_to_evidence(outputs)
+
+
+
+
+
+ #
+
+ Calculate loss
+
+
+ 139 loss = self.loss_func(evidence, target)
+
+
+
+
+
+ #
+
+ Calculate KL Divergence regularization loss
+
+
+ 141 kl_div_loss = self.kl_div_loss(evidence, target)
+142 tracker.add("loss.", loss)
+143 tracker.add("kl_div_loss.", kl_div_loss)
+
+
+
+
+
+ #
+
+ KL Divergence loss coefficient $\lambda_t$
+
+
+ 146 annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
+147 tracker.add("annealing_coef.", annealing_coef)
+
+
+
+
+
+ #
+
+ Total loss
+
+
+ 150 loss = loss + annealing_coef * kl_div_loss
+
+
+
+
+
+ #
+
+ Track statistics
+
+
+ 153 self.stats(evidence, target)
+
+
+
+
+
+ #
+
+ Train the model
+
+
+ 156 if self.mode.is_train:
+
+
+
+
+
+ #
+
+ Calculate gradients
+
+
+ 158 loss.backward()
+
+
+
+
+
+ #
+
+ Take optimizer step
+
+
+ 160 self.optimizer.step()
+
+
+
+
+
+ #
+
+ Clear the gradients
+
+
+ 162 self.optimizer.zero_grad()
+
+
+
+
+
+ #
+
+ Save the tracked metrics
+
+
+ 165 tracker.save()
+
+
+
+
+
+ #
+
+ Create model
+
+
+ 168@option(Configs.model)
+169def mnist_model(c: Configs):
+
+
+
+
+
+ #
+
+
+
+
+ 173 return Model(c.dropout).to(c.device)
+
+
+
+
+
+ #
+
+ KL Divergence Loss Coefficient Schedule
+
+
+ 176@option(Configs.kl_div_coef)
+177def kl_div_coef(c: Configs):
+
+
+
+
+
+ #
+
+ Create a relative piecewise schedule
+
+
+ 183 return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
+
+
+
+
+
+ #
+
+
+
+
+ 187calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
+
+
+
+
+
+ #
+
+
+
+
+ 189calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
+
+
+
+
+
+ #
+
+
+
+
+ 191calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())
+
+
+
+
+
+ #
+
+ ReLU to calculate evidence
+
+
+ 194calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
+
+
+
+
+
+ #
+
+ Softplus to calculate evidence
+
+
+ 196calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
+
+
+
+
+
+ #
+
+
+
+
+ 199def main():
+
+
+
+
+
+ #
+
+ Create experiment
+
+
+ 201 experiment.create(name='evidence_mnist')
+
+
+
+
+
+ #
+
+ Create configurations
+
+
+ 203 conf = Configs()
+
+
+
+
+
+ #
+
+ Load configurations
+
+
+ 205 experiment.configs(conf, {
+206 'optimizer.optimizer': 'Adam',
+207 'optimizer.learning_rate': 0.001,
+208 'optimizer.weight_decay': 0.005,
+
+
+
+
+
+ #
+
+ ‘loss_func’: ‘max_likelihood_loss’,
+‘loss_func’: ‘cross_entropy_bayes_risk’,
+
+
+ 212 'loss_func': 'squared_error_bayes_risk',
+213
+214 'outputs_to_evidence': 'softplus',
+215
+216 'dropout': 0.5,
+217 })
+
+
+
+
+
+ #
+
+ Start the experiment and run the training loop
+
+
+ 219 with experiment.start():
+220 conf.run()
+
+
+
+
+
+ #
+
+
+
+
+ 224if __name__ == '__main__':
+225 main()
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/uncertainty/evidence/index.html b/docs/uncertainty/evidence/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..9cce99edeef49732ec94ff6613a129198d13a84a
--- /dev/null
+++ b/docs/uncertainty/evidence/index.html
@@ -0,0 +1,798 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Evidential Deep Learning to Quantify Classification Uncertainty
+
+
+
+
+
+
+
+
+
+
+
+
+
+ home
+ uncertainty
+ evidence
+
+
+
+
+
+
+
+ #
+
+ Evidential Deep Learning to Quantify Classification Uncertainty
+This is a PyTorch implementation of the paper
+Evidential Deep Learning to Quantify Classification Uncertainty.
+Dempster-Shafer Theory of Evidence
+assigns belief masses to a set of classes (unlike assigning a probability to a single class).
+Sum of the masses of all subsets is $1$.
+Individual class probabilities (plausibilities) can be derived from these masses.
+Assigning a mass to the set of all classes means it can be any one of the classes; i.e. saying “I don’t know”.
+If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and
+ an overall uncertainty mass $u \ge 0$ to all classes.
+
+
+
+Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$
+and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
+The paper uses the term evidence as a measure of the amount of support
+collected from data in favor of a sample to be classified into a certain class.
+This corresponds to a Dirichlet distribution
+with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
+ $\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
+Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
+ is a distribution over categorical distribution; i.e. you can sample class probabilities
+from a Dirichlet distribution.
+The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.
+We get the model to output evidences
+
+ for a given input $\mathbf{x}$.
+We use a function such as
+ ReLU or a
+ Softplus
+ at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.
+The paper proposes a few loss functions to train the model, which we have implemented below.
+Here is the training code experiment.py
to train a model on MNIST dataset.
+
+
+
+ 54import torch
+55
+56from labml import tracker
+57from labml_helpers.module import Module
+
+
+
+
+
+ #
+
+
+Type II Maximum Likelihood Loss
+The distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$ is a prior on the likelihood
+$Multi(\mathbf{y} \vert p)$,
+ and the negative log marginal likelihood is calculated by integrating over class probabilities
+ $\mathbf{p}$.
+If target probabilities (one-hot targets) are $y_k$ for a given sample the loss is,
+
+
+
+
+
+ 60class MaximumLikelihoodLoss(Module):
+
+
+
+
+
+ #
+
+
+evidence
is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
+target
is $\mathbf{y}$ with shape [batch_size, n_classes]
+
+
+
+ 84 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
+
+
+
+
+
+ #
+
+ $\color{cyan}{\alpha_k} = e_k + 1$
+
+
+ 90 alpha = evidence + 1.
+
+
+
+
+
+ #
+
+ $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+
+
+ 92 strength = alpha.sum(dim=-1)
+
+
+
+
+
+ #
+
+ Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$
+
+
+ 95 loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
+
+
+
+
+
+ #
+
+ Mean loss over the batch
+
+
+ 98 return loss.mean()
+
+
+
+
+
+ #
+
+
+Bayes Risk with Cross Entropy Loss
+Bayes risk is the overall maximum cost of making incorrect estimates.
+It takes a cost function that gives the cost of making an incorrect estimate
+and sums it over all possible outcomes based on probability distribution.
+Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$
+
+
+We integrate this cost over all $\mathbf{p}$
+
+
+
+where $\psi(\cdot)$ is the $digamma$ function.
+
+
+ 101class CrossEntropyBayesRisk(Module):
+
+
+
+
+
+ #
+
+
+evidence
is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
+target
is $\mathbf{y}$ with shape [batch_size, n_classes]
+
+
+
+ 130 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
+
+
+
+
+
+ #
+
+ $\color{cyan}{\alpha_k} = e_k + 1$
+
+
+ 136 alpha = evidence + 1.
+
+
+
+
+
+ #
+
+ $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+
+
+ 138 strength = alpha.sum(dim=-1)
+
+
+
+
+
+ #
+
+ Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$
+
+
+ 141 loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
+
+
+
+
+
+ #
+
+ Mean loss over the batch
+
+
+ 144 return loss.mean()
+
+
+
+
+
+ #
+
+
+Bayes Risk with Squared Error Loss
+Here the cost function is squared error,
+
+
+We integrate this cost over all $\mathbf{p}$
+
+
+
+Where
+is the expected probability when sampled from the Dirichlet distribution
+and
+ where
+
+ is the variance.
+This gives,
+
+
+This first part of the equation $\big(y_k -\mathbb{E}[p_k]\big)^2$ is the error term and
+the second part is the variance.
+
+
+ 147class SquaredErrorBayesRisk(Module):
+
+
+
+
+
+ #
+
+
+evidence
is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
+target
is $\mathbf{y}$ with shape [batch_size, n_classes]
+
+
+
+ 191 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
+
+
+
+
+
+ #
+
+ $\color{cyan}{\alpha_k} = e_k + 1$
+
+
+ 197 alpha = evidence + 1.
+
+
+
+
+
+ #
+
+ $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+
+
+ 199 strength = alpha.sum(dim=-1)
+
+
+
+
+
+ #
+
+ $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
+
+
+ 201 p = alpha / strength[:, None]
+
+
+
+
+
+ #
+
+ Error $(y_k -\hat{p}_k)^2$
+
+
+ 204 err = (target - p) ** 2
+
+
+
+
+
+ #
+
+ Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$
+
+
+ 206 var = p * (1 - p) / (strength[:, None] + 1)
+
+
+
+
+
+ #
+
+ Sum of them
+
+
+ 209 loss = (err + var).sum(dim=-1)
+
+
+
+
+
+ #
+
+ Mean loss over the batch
+
+
+ 212 return loss.mean()
+
+
+
+
+
+ #
+
+
+KL Divergence Regularization Loss
+This tries to shrink the total evidence to zero if the sample cannot be correctly classified.
+First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$ the
+Dirichlet parameters after removing the correct evidence.
+
+
+
+where $\Gamma(\cdot)$ is the gamma function,
+$\psi(\cdot)$ is the $digamma$ function and
+$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
+
+
+ 215class KLDivergenceLoss(Module):
+
+
+
+
+
+ #
+
+
+evidence
is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
+target
is $\mathbf{y}$ with shape [batch_size, n_classes]
+
+
+
+ 238 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
+
+
+
+
+
+ #
+
+ $\color{cyan}{\alpha_k} = e_k + 1$
+
+
+ 244 alpha = evidence + 1.
+
+
+
+
+
+ #
+
+ Number of classes
+
+
+ 246 n_classes = evidence.shape[-1]
+
+
+
+
+
+ #
+
+ Remove non-misleading evidence
+
+
+
+
+ 249 alpha_tilde = target + (1 - target) * alpha
+
+
+
+
+
+ #
+
+ $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
+
+
+ 251 strength_tilde = alpha_tilde.sum(dim=-1)
+
+
+
+
+
+ #
+
+ The first term
+
+
+
+
+ 261 first = (torch.lgamma(alpha_tilde.sum(dim=-1))
+262 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
+263 - (torch.lgamma(alpha_tilde)).sum(dim=-1))
+
+
+
+
+
+ #
+
+ The second term
+
+
+
+
+ 268 second = (
+269 (alpha_tilde - 1) *
+270 (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
+271 ).sum(dim=-1)
+
+
+
+
+
+ #
+
+ Sum of the terms
+
+
+ 274 loss = first + second
+
+
+
+
+
+ #
+
+ Mean loss over the batch
+
+
+ 277 return loss.mean()
+
+
+
+
+
+ 280class TrackStatistics(Module):
+
+
+
+
+
+ #
+
+
+
+
+ 287 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
+
+
+
+
+
+ #
+
+ Number of classes
+
+
+ 289 n_classes = evidence.shape[-1]
+
+
+
+
+
+ #
+
+ Predictions that correctly match with the target (greedy sampling based on highest probability)
+
+
+ 291 match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
+
+
+
+
+
+ #
+
+ Track accuracy
+
+
+ 293 tracker.add('accuracy.', match.sum() / match.shape[0])
+
+
+
+
+
+ #
+
+ $\color{cyan}{\alpha_k} = e_k + 1$
+
+
+ 296 alpha = evidence + 1.
+
+
+
+
+
+ #
+
+ $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+
+
+ 298 strength = alpha.sum(dim=-1)
+
+
+
+
+
+ #
+
+ $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
+
+
+ 301 expected_probability = alpha / strength[:, None]
+
+
+
+
+
+ #
+
+ Expected probability of the selected (greedy highest probability) class
+
+
+ 303 expected_probability, _ = expected_probability.max(dim=-1)
+
+
+
+
+
+ #
+
+ Uncertainty mass $u = \frac{K}{S}$
+
+
+ 306 uncertainty_mass = n_classes / strength
+
+
+
+
+
+ #
+
+ Track $u$ for correct predictions
+
+
+ 309 tracker.add('u.succ.', uncertainty_mass.masked_select(match))
+
+
+
+
+
+ #
+
+ Track $u$ for incorrect predictions
+
+
+ 311 tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
+
+
+
+
+
+ #
+
+ Track $\hat{p}_k$ for correct predictions
+
+
+ 313 tracker.add('prob.succ.', expected_probability.masked_select(match))
+
+
+
+
+
+ #
+
+ Track $\hat{p}_k$ for incorrect predictions
+
+
+ 315 tracker.add('prob.fail.', expected_probability.masked_select(~match))
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/uncertainty/evidence/readme.html b/docs/uncertainty/evidence/readme.html
new file mode 100644
index 0000000000000000000000000000000000000000..25a1693b03933a5087ff82c005ec828c0fac5a1a
--- /dev/null
+++ b/docs/uncertainty/evidence/readme.html
@@ -0,0 +1,144 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Evidential Deep Learning to Quantify Classification Uncertainty
+
+
+
+
+
+
+
+
+
+
+
+
+
+ home
+ uncertainty
+ evidence
+
+
+
+
+
+
+
+ #
+
+ Evidential Deep Learning to Quantify Classification Uncertainty
+This is a PyTorch implementation of the paper
+Evidential Deep Learning to Quantify Classification Uncertainty.
+Here is the training code experiment.py
to train a model on MNIST dataset.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/uncertainty/index.html b/docs/uncertainty/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..a4ff24eeddc9549bddf5d6c4ed7f49cb9c6d2dfa
--- /dev/null
+++ b/docs/uncertainty/index.html
@@ -0,0 +1,143 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Neural Networks with Uncertainty Estimation
+
+
+
+
+
+
+
+
+
+
+
+
+
+ home
+ uncertainty
+
+
+
+
+
+
+
+ #
+
+ Neural Networks with Uncertainty Estimation
+These are neural network architectures that estimate the uncertainty of the predictions.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/uncertainty/readme.html b/docs/uncertainty/readme.html
new file mode 100644
index 0000000000000000000000000000000000000000..97a403fb194153dddb1c907b11b6242d7811b2d7
--- /dev/null
+++ b/docs/uncertainty/readme.html
@@ -0,0 +1,143 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Neural Networks with Uncertainty Estimation
+
+
+
+
+
+
+
+
+
+
+
+
+
+ home
+ uncertainty
+
+
+
+
+
+
+
+ #
+
+ Neural Networks with Uncertainty Estimation
+These are neural network architectures that estimate the uncertainty of the predictions.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py
index 0293af96aa83e8fd3be2f15094e9693f14f60701..5022f4634c6e826a8af4c08a3355ebaccbb0ae64 100644
--- a/labml_nn/__init__.py
+++ b/labml_nn/__init__.py
@@ -94,6 +94,10 @@ Solving games with incomplete information such as poker with CFR.
* [PonderNet](adaptive_computation/ponder_net/index.html)
+#### ✨ [Uncertainty](uncertainty/index.html)
+
+* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
+
### Installation
```bash
@@ -102,12 +106,12 @@ pip install labml-nn
### Citing LabML
-If you use LabML for academic research, please cite the library using the following BibTeX entry.
+If you use this for academic research, please cite it using the following BibTeX entry.
```bibtex
@misc{labml,
author = {Varuna Jayasiri, Nipun Wijerathne},
- title = {LabML: A library to organize machine learning experiments},
+ title = {labml.ai Annotated Paper Implementations},
year = {2020},
url = {https://nn.labml.ai/},
}
diff --git a/labml_nn/normalization/batch_norm/mnist.py b/labml_nn/normalization/batch_norm/mnist.py
index e06890843c6f6b5506d5283835e75f4f1d83dd14..e194628dbf4940785e587e5c91775e3c1c282081 100644
--- a/labml_nn/normalization/batch_norm/mnist.py
+++ b/labml_nn/normalization/batch_norm/mnist.py
@@ -72,7 +72,10 @@ def main():
# Create configurations
conf = MNISTConfigs()
# Load configurations
- experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
+ experiment.configs(conf, {
+ 'optimizer.optimizer': 'Adam',
+ 'optimizer.learning_rate': 0.001,
+ })
# Start the experiment and run the training loop
with experiment.start():
conf.run()
diff --git a/labml_nn/uncertainty/__init__.py b/labml_nn/uncertainty/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..54077fe1676e53388a89e8df8a09f3adac043b70
--- /dev/null
+++ b/labml_nn/uncertainty/__init__.py
@@ -0,0 +1,13 @@
+"""
+---
+title: Neural Networks with Uncertainty Estimation
+summary: >
+ A set of PyTorch implementations/tutorials related to uncertainty estimation
+---
+
+# Neural Networks with Uncertainty Estimation
+
+These are neural network architectures that estimate the uncertainty of the predictions.
+
+* [Evidential Deep Learning to Quantify Classification Uncertainty](evidence/index.html)
+"""
diff --git a/labml_nn/uncertainty/evidence/__init__.py b/labml_nn/uncertainty/evidence/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0a911d45bae6fc84c33c3c16224f3955a9c50f8
--- /dev/null
+++ b/labml_nn/uncertainty/evidence/__init__.py
@@ -0,0 +1,315 @@
+"""
+---
+title: "Evidential Deep Learning to Quantify Classification Uncertainty"
+summary: >
+ A PyTorch implementation/tutorial of the paper Evidential Deep Learning to Quantify Classification
+ Uncertainty.
+---
+
+# Evidential Deep Learning to Quantify Classification Uncertainty
+
+This is a [PyTorch](https://pytorch.org) implementation of the paper
+[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
+
+[Dempster-Shafer Theory of Evidence](https://en.wikipedia.org/wiki/Dempster%E2%80%93Shafer_theory)
+assigns belief masses to a set of classes (unlike assigning a probability to a single class).
+Sum of the masses of all subsets is $1$.
+Individual class probabilities (plausibilities) can be derived from these masses.
+
+Assigning a mass to the set of all classes means it can be any one of the classes; i.e. saying "I don't know".
+
+If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and
+ an overall uncertainty mass $u \ge 0$ to all classes.
+
+$$u + \sum_{k=1}^K b_k = 1$$
+
+Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$
+and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$.
+The paper uses the term evidence as a measure of the amount of support
+collected from data in favor of a sample to be classified into a certain class.
+
+This corresponds to a [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution)
+with parameters $\color{cyan}{\alpha_k} = e_k + 1$, and
+ $\color{cyan}{\alpha_0} = S = \sum_{k=1}^K \color{cyan}{\alpha_k}$ is known as the Dirichlet strength.
+Dirichlet distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$
+ is a distribution over categorical distribution; i.e. you can sample class probabilities
+from a Dirichlet distribution.
+The expected probability for class $k$ is $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$.
+
+We get the model to output evidences
+$$\mathbf{e} = \color{cyan}{\mathbf{\alpha}} - 1 = f(\mathbf{x} | \Theta)$$
+ for a given input $\mathbf{x}$.
+We use a function such as
+ [ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) or a
+ [Softplus](https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html)
+ at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.
+
+The paper proposes a few loss functions to train the model, which we have implemented below.
+
+Here is the [training code `experiment.py`](experiment.html) to train a model on MNIST dataset.
+
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
+"""
+
+import torch
+
+from labml import tracker
+from labml_helpers.module import Module
+
+
+class MaximumLikelihoodLoss(Module):
+ """
+
+ ## Type II Maximum Likelihood Loss
+
+    The distribution $D(\mathbf{p} \vert \color{cyan}{\mathbf{\alpha}})$ is a prior on the likelihood
+ $Multi(\mathbf{y} \vert p)$,
+ and the negative log marginal likelihood is calculated by integrating over class probabilities
+ $\mathbf{p}$.
+
+ If target probabilities (one-hot targets) are $y_k$ for a given sample the loss is,
+
+ \begin{align}
+ \mathcal{L}(\Theta)
+ &= -\log \Bigg(
+ \int
+ \prod_{k=1}^K p_k^{y_k}
+ \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
+ \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
+ d\mathbf{p}
+ \Bigg ) \\
+ &= \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)
+ \end{align}
+ """
+ def forward(self, evidence: torch.Tensor, target: torch.Tensor):
+ """
+ * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
+ * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
+ """
+ # $\color{cyan}{\alpha_k} = e_k + 1$
+ alpha = evidence + 1.
+ # $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+ strength = alpha.sum(dim=-1)
+
+ # Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{cyan}{\alpha_k} \bigg)$
+ loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
+
+ # Mean loss over the batch
+ return loss.mean()
+
+
+class CrossEntropyBayesRisk(Module):
+ """
+
+ ## Bayes Risk with Cross Entropy Loss
+
+ Bayes risk is the overall maximum cost of making incorrect estimates.
+ It takes a cost function that gives the cost of making an incorrect estimate
+ and sums it over all possible outcomes based on probability distribution.
+
+ Here the cost function is cross-entropy loss, for one-hot coded $\mathbf{y}$
+ $$\sum_{k=1}^K -y_k \log p_k$$
+
+ We integrate this cost over all $\mathbf{p}$
+
+ \begin{align}
+ \mathcal{L}(\Theta)
+ &= -\log \Bigg(
+ \int
+ \Big[ \sum_{k=1}^K -y_k \log p_k \Big]
+ \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
+ \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
+ d\mathbf{p}
+ \Bigg ) \\
+ &= \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)
+ \end{align}
+
+ where $\psi(\cdot)$ is the $digamma$ function.
+ """
+
+ def forward(self, evidence: torch.Tensor, target: torch.Tensor):
+ """
+ * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
+ * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
+ """
+ # $\color{cyan}{\alpha_k} = e_k + 1$
+ alpha = evidence + 1.
+ # $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
+ strength = alpha.sum(dim=-1)
+
+ # Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{cyan}{\alpha_k} ) \bigg)$
+ loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
+
+ # Mean loss over the batch
+ return loss.mean()
+
+
class SquaredErrorBayesRisk(Module):
    """
    ## Bayes Risk with Squared Error Loss

    Here the cost function is squared error,
    $$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$

    We integrate this cost over all $\mathbf{p}$

    \begin{align}
    \mathcal{L}(\Theta)
    &= -\log \Bigg(
     \int
      \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big]
      \frac{1}{B(\color{cyan}{\mathbf{\alpha}})}
      \prod_{k=1}^K p_k^{\color{cyan}{\alpha_k} - 1}
     d\mathbf{p}
     \Bigg ) \\
    &= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 -2 y_k p_k + p_k^2 \Big] \\
    &= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
    \end{align}

    Where $$\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$$
    is the expected probability when sampled from the Dirichlet distribution
    and $$\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$$
    where
    $$\text{Var}(p_k) = \frac{\color{cyan}{\alpha_k}(S - \color{cyan}{\alpha_k})}{S^2 (S + 1)}
    = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$$
    is the variance.

    Substituting these gives,
    \begin{align}
    \mathcal{L}(\Theta)
    &= \sum_{k=1}^K \Big( y_k^2 -2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
    &= \sum_{k=1}^K \Big( \big( y_k -\mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
    &= \sum_{k=1}^K \Big( ( y_k -\hat{p}_k)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Big)
    \end{align}

    so the loss decomposes into a squared-error term $\big(y_k -\mathbb{E}[p_k]\big)^2$
    plus a variance term.
    """

    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
        """
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
        # Dirichlet parameters $\color{cyan}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
        # Total evidence strength $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$;
        # keep the reduced dimension so it broadcasts against `alpha`
        strength = alpha.sum(dim=-1, keepdim=True)
        # Expected probabilities $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$
        expected_p = alpha / strength

        # Error term $(y_k -\hat{p}_k)^2$
        squared_error = (target - expected_p) ** 2
        # Variance term $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$
        variance = expected_p * (1. - expected_p) / (strength + 1.)

        # Per-sample loss: sum both terms over the classes
        per_sample_loss = (squared_error + variance).sum(dim=-1)

        # Mean loss over the batch
        return per_sample_loss.mean()
+
class KLDivergenceLoss(Module):
    """
    ## KL Divergence Regularization Loss

    This tries to shrink the total evidence to zero if the sample cannot be correctly classified.

    First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$, the
    Dirichlet parameters after removing the correct (non-misleading) evidence.

    \begin{align}
    &KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big \Vert
    D(\mathbf{p} \vert <1, \dots, 1>) \Big] \\
    &= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
    {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg)
    + \sum_{k=1}^K (\tilde{\alpha}_k - 1)
    \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
    \end{align}

    where $\Gamma(\cdot)$ is the gamma function,
    $\psi(\cdot)$ is the $digamma$ function and
    $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
    """

    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
        """
        * `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
        * `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`
        """
        # Dirichlet parameters $\color{cyan}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
        # Number of classes $K$
        n_classes = evidence.shape[-1]
        # Remove non-misleading evidence:
        # $$\tilde{\alpha}_k = y_k + (1 - y_k) \color{cyan}{\alpha_k}$$
        alpha_tilde = target + (1 - target) * alpha
        # $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$
        strength_tilde = alpha_tilde.sum(dim=-1)

        # Log-normalizer term:
        # \begin{align}
        # &\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}
        # {\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\
        # &= \log \Gamma(\tilde{S}) - \log \Gamma(K) - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)
        # \end{align}
        log_norm = (torch.lgamma(strength_tilde)
                    - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                    - torch.lgamma(alpha_tilde).sum(dim=-1))

        # Digamma correction term:
        # $$\sum_{k=1}^K (\tilde{\alpha}_k - 1)
        # \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$$
        digamma_diff = torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None]
        correction = ((alpha_tilde - 1) * digamma_diff).sum(dim=-1)

        # Per-sample KL divergence, averaged over the batch
        return (log_norm + correction).mean()
+
+
class TrackStatistics(Module):
    """
    ### Track statistics

    This module computes statistics and tracks them with [labml `tracker`](https://docs.labml.ai/api/tracker.html).
    It returns nothing; its only effect is the tracker calls.
    """

    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
        # Number of classes $K$
        n_classes = evidence.shape[-1]
        # Whether the greedy prediction (class with the highest evidence)
        # matches the target class
        is_correct = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
        # Track accuracy over the batch
        tracker.add('accuracy.', is_correct.sum() / is_correct.shape[0])

        # Dirichlet parameters $\color{cyan}{\alpha_k} = e_k + 1$
        alpha = evidence + 1.
        # Total evidence strength $S = \sum_{k=1}^K \color{cyan}{\alpha_k}$
        strength = alpha.sum(dim=-1)

        # Expected probability of the greedily selected class:
        # $\hat{p}_k = \frac{\color{cyan}{\alpha_k}}{S}$, reduced with `max`
        top_probability = (alpha / strength[:, None]).max(dim=-1).values

        # Uncertainty mass $u = \frac{K}{S}$
        uncertainty_mass = n_classes / strength

        # Track $u$ split by correct/incorrect predictions
        tracker.add('u.succ.', uncertainty_mass.masked_select(is_correct))
        tracker.add('u.fail.', uncertainty_mass.masked_select(~is_correct))
        # Track $\hat{p}_k$ split by correct/incorrect predictions
        tracker.add('prob.succ.', top_probability.masked_select(is_correct))
        tracker.add('prob.fail.', top_probability.masked_select(~is_correct))
diff --git a/labml_nn/uncertainty/evidence/experiment.py b/labml_nn/uncertainty/evidence/experiment.py
new file mode 100644
index 0000000000000000000000000000000000000000..96285810b119f39bfd116fc1aafdc23004aa4b1e
--- /dev/null
+++ b/labml_nn/uncertainty/evidence/experiment.py
@@ -0,0 +1,225 @@
+"""
+---
+title: "Evidential Deep Learning to Quantify Classification Uncertainty Experiment"
+summary: >
+ This trains an EDL model on MNIST
+---
+
+# [Evidential Deep Learning to Quantify Classification Uncertainty](index.html) Experiment
+
+This trains a model based on [Evidential Deep Learning to Quantify Classification Uncertainty](index.html)
+ on MNIST dataset.
+"""
+
+from typing import Any
+
+import torch.nn as nn
+import torch.utils.data
+
+from labml import tracker, experiment
+from labml.configs import option, calculate
+from labml_helpers.module import Module
+from labml_helpers.schedule import Schedule, RelativePiecewise
+from labml_helpers.train_valid import BatchIndex
+from labml_nn.experiments.mnist import MNISTConfigs
+from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
+ CrossEntropyBayesRisk, SquaredErrorBayesRisk
+
+
class Model(Module):
    """
    ## LeNet based model for MNIST classification
    """

    def __init__(self, dropout: float):
        """
        * `dropout` is the dropout probability for the hidden layer
        """
        super().__init__()
        # First $5 \times 5$ convolution layer with ReLU
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.act1 = nn.ReLU()
        # $2 \times 2$ max-pooling
        self.max_pool1 = nn.MaxPool2d(2, 2)
        # Second $5 \times 5$ convolution layer with ReLU
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        self.act2 = nn.ReLU()
        # $2 \times 2$ max-pooling
        self.max_pool2 = nn.MaxPool2d(2, 2)
        # Fully-connected hidden layer with $500$ features
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        self.act3 = nn.ReLU()
        # Final fully connected layer that outputs evidence for $10$ classes.
        # The ReLU or Softplus activation is applied to this outside the model
        # to get the non-negative evidence
        self.fc2 = nn.Linear(500, 10)
        # Dropout applied to the hidden layer
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`
        """
        # First conv + pool stage; result has shape `[batch_size, 20, 12, 12]`
        features = self.max_pool1(self.act1(self.conv1(x)))
        # Second conv + pool stage; result has shape `[batch_size, 50, 4, 4]`
        features = self.max_pool2(self.act2(self.conv2(features)))
        # Flatten to `[batch_size, 50 * 4 * 4]`
        flat = features.view(features.shape[0], -1)
        # Hidden layer followed by dropout
        hidden = self.dropout(self.act3(self.fc1(flat)))
        # Final layer gives the raw per-class outputs
        return self.fc2(hidden)
+
+
class Configs(MNISTConfigs):
    """
    ## Configurations

    We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations.

    NOTE(review): class-level attributes and annotations here are parsed by the
    labml configs system, so they double as configuration declarations.
    """

    # [KL Divergence regularization](index.html#KLDivergenceLoss)
    kl_div_loss = KLDivergenceLoss()
    # KL Divergence regularization coefficient schedule (computed by the
    # `kl_div_coef` option below)
    kl_div_coef: Schedule
    # Breakpoints for the coefficient schedule as
    # `(relative training position, coefficient)` pairs
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
    # [Stats module](index.html#TrackStatistics) for tracking
    stats = TrackStatistics()
    # Dropout probability for the model's hidden layer
    dropout: float = 0.5
    # Module to convert the model output to non-negative evidence
    # (ReLU or Softplus, selected by name)
    outputs_to_evidence: Module

    def init(self):
        """
        ### Initialization

        Configure how each tracked metric is displayed/stored.
        """
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)

        # No stateful modules (e.g. running metrics) need saving/restoring
        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        * `batch` is a `(data, labels)` pair from the MNIST data loader
        * `batch_idx` indicates the position of the batch within the epoch
        """

        # Training/Evaluation mode (affects dropout)
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # One-hot coded targets $\mathbf{y}$ of shape `[batch_size, 10]`
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get model outputs (raw, unconstrained)
        outputs = self.model(data)
        # Get evidences $e_k \ge 0$ by applying the configured activation
        evidence = self.outputs_to_evidence(outputs)

        # Calculate the main loss (one of the three Bayes-risk losses)
        loss = self.loss_func(evidence, target)
        # Calculate KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)

        # KL Divergence loss coefficient $\lambda_t$, annealed over training
        # and clamped at $1$
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)

        # Total loss = main loss + annealed KL regularization
        loss = loss + annealing_coef * kl_div_loss

        # Track statistics (accuracy, uncertainty mass, probabilities)
        self.stats(evidence, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
+
+
@option(Configs.model)
def mnist_model(c: Configs):
    """
    ### Create model

    Build the [LeNet based model](#Model) with the configured dropout and
    move it to the configured device.
    """
    model = Model(c.dropout)
    return model.to(c.device)
+
+
@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
    """
    ### KL Divergence Loss Coefficient Schedule

    Build a [relative piecewise schedule](https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise)
    over the total number of training samples.
    """
    # Total samples processed over the whole run; the schedule's breakpoints
    # are relative to this
    total_samples = c.epochs * len(c.train_dataset)
    return RelativePiecewise(c.kl_div_coef_schedule, total_samples)
+
+
# Register the three loss function choices for the `loss_func` config option.
# [Maximum Likelihood Loss](index.html#MaximumLikelihoodLoss)
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
# [Cross Entropy Bayes Risk](index.html#CrossEntropyBayesRisk)
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
# [Squared Error Bayes Risk](index.html#SquaredErrorBayesRisk)
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

# Register the activations for the `outputs_to_evidence` config option,
# which maps raw model outputs to non-negative evidence.
# ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
# Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
+
+
def main():
    """
    Create and run the EDL MNIST experiment.
    """
    # Create experiment
    experiment.create(name='evidence_mnist')
    # Create configurations
    conf = Configs()
    # Configuration overrides; the commented-out entries are the
    # alternative loss functions
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    }
    # Load configurations
    experiment.configs(conf, overrides)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()
diff --git a/labml_nn/uncertainty/evidence/readme.md b/labml_nn/uncertainty/evidence/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..da54b56690591284203b3c32c90c7ae40526dd46
--- /dev/null
+++ b/labml_nn/uncertainty/evidence/readme.md
@@ -0,0 +1,8 @@
+# [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
+
+This is a [PyTorch](https://pytorch.org) implementation of the paper
+[Evidential Deep Learning to Quantify Classification Uncertainty](https://papers.labml.ai/paper/1806.01768).
+
+Here is the [training code `experiment.py`](https://nn.labml.ai/uncertainty/evidence/experiment.html) to train a model on MNIST dataset.
+
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/f82b2bfc01ba11ecbb2aa16a33570106)
diff --git a/labml_nn/uncertainty/readme.md b/labml_nn/uncertainty/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..798271f62a0da599e1282f5677e03bf7d97da39c
--- /dev/null
+++ b/labml_nn/uncertainty/readme.md
@@ -0,0 +1,5 @@
+# [Neural Networks with Uncertainty Estimation](https://nn.labml.ai/uncertainty/index.html)
+
+These are neural network architectures that estimate the uncertainty of the predictions.
+
+* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
diff --git a/readme.md b/readme.md
index d070372b08c14ee29c1678deaccbad6bf76a14de..13deca16798f4a5d805118de29e123b6565275bf 100644
--- a/readme.md
+++ b/readme.md
@@ -99,6 +99,10 @@ Solving games with incomplete information such as poker with CFR.
* [PonderNet](https://nn.labml.ai/adaptive_computation/ponder_net/index.html)
+#### ✨ [Uncertainty](https://nn.labml.ai/uncertainty/index.html)
+
+* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
+
### Installation
```bash
diff --git a/requirements.txt b/requirements.txt
index 69cd5c8d24a5ebb1abf367824fec893c95f38343..2771391e74a41ec99f187be133fddb1ee0e7bd4b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
torch>=1.7
-labml>=0.4.94
-labml-helpers>=0.4.77
+labml>=0.4.132
+labml-helpers>=0.4.81
torchvision
numpy>=1.16.3
matplotlib>=3.0.3
diff --git a/setup.py b/setup.py
index 90557d02616468d8ad0657cdf61b8241897a80f9..da43851f4aa1af9986ba068d62f2289c5b46f22b 100644
--- a/setup.py
+++ b/setup.py
@@ -5,10 +5,10 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
- version='0.4.109',
+ version='0.4.110',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
- description="🧠 Implementations/tutorials of deep learning papers with side-by-side notes; including transformers (original, xl, switch, feedback, vit), optimizers(adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), reinforcement learning (ppo, dqn), capsnet, distillation, etc.",
+ description="🧑🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/labmlai/annotated_deep_learning_paper_implementations",
@@ -20,7 +20,7 @@ setuptools.setup(
'labml_helpers', 'labml_helpers.*',
'test',
'test.*')),
- install_requires=['labml>=0.4.129',
+ install_requires=['labml>=0.4.132',
'labml-helpers>=0.4.81',
'torch',
'einops',