"""
The code here is greatly inspired by a couple of sources:

the Blitz package: https://github.com/piEsposito/blitz-bayesian-deep-learning and

Weight Uncertainty in Neural Networks post by Nitarshan Rajkumar: https://www.nitarshan.com/bayes-by-backprop/

and references therein
"""

import torch
import torch.nn.functional as F
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.bayesian_models._weight_sampler import (
    GaussianPosterior,
    ScaleMixtureGaussianPrior,
)
from pytorch_widedeep.bayesian_models._base_bayesian_model import (
    BayesianModule,
)


class BayesianLinear(BayesianModule):
    r"""Applies a linear transformation to the incoming data as proposed in Weight
    Uncertainity on Neural Networks

    Parameters
    ----------
    in_features: int
        size of each input sample
    out_features: int
        size of each output sample
    use_bias: bool, default = True
        Boolean indicating if an additive bias will be learnt
    prior_sigma_1: float, default = 1.0
        Prior of the sigma parameter for the first of the two weight Gaussian
        distributions that will be mixed to produce the prior weight
        distribution
    prior_sigma_2: float, default = 0.002
        Prior of the sigma parameter for the second of the two weight Gaussian
        distributions that will be mixed to produce the prior weight
        distribution
    prior_pi: float, default = 0.8
        Scaling factor that will be used to mix the Gaussians to produce the
        prior weight distribution
    posterior_mu_init: float, default = 0.0
        The posterior sample of the weights is defined as:

            :math:`\mathbf{w} = \mu + \log(1 + \exp(\rho)) \cdot \epsilon`

        where :math:`\epsilon \sim N(0, 1)`, and :math:`\mu` and :math:`\rho`
        are learnable parameters whose initial values are sampled from
        Gaussian distributions. ``posterior_mu_init`` is the mean of the
        Gaussian distribution from which :math:`\mu` is initially sampled.
    posterior_rho_init: float, default = -7.0
        The mean of the Gaussian distribution from which :math:`\rho` is
        initially sampled.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.bayesian_models import bayesian_nn as bnn
    >>> linear = bnn.BayesianLinear(10, 6)
    >>> input = torch.rand(6, 10)
    >>> out = linear(input)
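
    In training mode each forward pass draws a fresh weight sample, so
    repeated calls on the same input differ (the Monte Carlo behaviour the
    layer relies on during training):

    >>> _ = linear.train()
    >>> out1, out2 = linear(input), linear(input)
    >>> bool((out1 != out2).any())
    True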
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        use_bias: bool = True,
        prior_sigma_1: float = 1.0,
        prior_sigma_2: float = 0.002,
        prior_pi: float = 0.8,
        posterior_mu_init: float = 0.0,
        posterior_rho_init: float = -7.0,
    ):
        super(BayesianLinear, self).__init__()

        # main parameters of the layer
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias

        # variational posterior initialisation params
        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        # prior params
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi

        # Variational Posterior
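        # mu and rho parametrise the posterior per weight; a positive standard
        # deviation is recovered as sigma = log(1 + exp(rho))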
        self.weight_mu = nn.Parameter(
            torch.Tensor(out_features, in_features).normal_(posterior_mu_init, 0.1)
        )
        self.weight_rho = nn.Parameter(
            torch.Tensor(out_features, in_features).normal_(posterior_rho_init, 0.1)
        )
        self.weight_sampler = GaussianPosterior(self.weight_mu, self.weight_rho)

        if self.use_bias:
            self.bias_mu = nn.Parameter(
                torch.Tensor(out_features).normal_(posterior_mu_init, 0.1)
            )
            self.bias_rho = nn.Parameter(
                torch.Tensor(out_features).normal_(posterior_rho_init, 0.1)
            )
            self.bias_sampler = GaussianPosterior(self.bias_mu, self.bias_rho)
        else:
            self.bias_mu, self.bias_rho = None, None

        # Prior
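        # scale mixture of two zero-mean Gaussians, as in Bayes by Backprop:
        # p(w) = pi * N(0, sigma_1^2) + (1 - pi) * N(0, sigma_2^2)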
        self.weight_prior_dist = ScaleMixtureGaussianPrior(
            self.prior_pi,
            self.prior_sigma_1,
            self.prior_sigma_2,
        )
        if self.use_bias:
            self.bias_prior_dist = ScaleMixtureGaussianPrior(
                self.prior_pi,
                self.prior_sigma_1,
                self.prior_sigma_2,
            )

        self.log_prior: Union[Tensor, float] = 0.0
        self.log_variational_posterior: Union[Tensor, float] = 0.0

    def forward(self, X: Tensor) -> Tensor:

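        # at inference time the layer is deterministic: use the posterior
        # means of the weights (and bias)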
        if not self.training:
            return F.linear(X, self.weight_mu, self.bias_mu)

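        # at train time, sample weights (and bias) from the variational
        # posterior and keep the log-posterior and log-prior terms needed
        # for the KL contribution to the loss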
        weight = self.weight_sampler.sample()
        if self.use_bias:
            bias = self.bias_sampler.sample()
            bias_log_posterior: Union[Tensor, float] = self.bias_sampler.log_posterior(
                bias
            )
            bias_log_prior: Union[Tensor, float] = self.bias_prior_dist.log_prior(bias)
        else:
            bias = None
            bias_log_posterior = 0.0
            bias_log_prior = 0.0

        self.log_variational_posterior = (
            self.weight_sampler.log_posterior(weight) + bias_log_posterior
        )
        self.log_prior = self.weight_prior_dist.log_prior(weight) + bias_log_prior

        return F.linear(X, weight, bias)

    def extra_repr(self) -> str:  # noqa: C901
        # only report values that differ from the constructor defaults
        s = "{in_features}, {out_features}"
        if not self.use_bias:
            s += ", use_bias=False"
        if self.prior_sigma_1 != 1.0:
            s += ", prior_sigma_1={prior_sigma_1}"
        if self.prior_sigma_2 != 0.002:
            s += ", prior_sigma_2={prior_sigma_2}"
        if self.prior_pi != 0.8:
            s += ", prior_pi={prior_pi}"
        if self.posterior_mu_init != 0.0:
            s += ", posterior_mu_init={posterior_mu_init}"
        if self.posterior_rho_init != -7.0:
            s += ", posterior_rho_init={posterior_rho_init}"
        return s.format(**self.__dict__)
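

if __name__ == "__main__":
    # A minimal usage sketch, not part of the library: average the per-sample
    # KL term log q(w) - log p(w) over a few Monte Carlo forward passes; this
    # is the quantity Bayes by Backprop adds to the data-fit loss.
    torch.manual_seed(0)

    layer = BayesianLinear(10, 6)
    X = torch.rand(32, 10)

    layer.train()  # weights are only sampled in training mode
    n_samples = 5
    kl = torch.zeros(())
    for _ in range(n_samples):
        _ = layer(X)  # each pass refreshes log_prior / log_variational_posterior
        kl = kl + (layer.log_variational_posterior - layer.log_prior) / n_samples
    print(f"approximate KL term: {kl.item():.4f}")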