From f9b1e8cb173228a9fd8c4025e4eb9f4fc9320d26 Mon Sep 17 00:00:00 2001 From: jrzaurin Date: Sat, 25 Dec 2021 21:12:05 +0100 Subject: [PATCH] re-arranged the code for the bayesian model and added docs --- docs/bayesian_models.rst | 15 + docs/index.rst | 1 + .../scripts/adult_census_bayesian_tabmlp.py | 127 ++++ pytorch_widedeep/bayesian_models/__init__.py | 1 + .../bayesian_models/_base_bayesian_model.py | 21 +- .../bayesian_models/_weight_sampler.py | 14 + .../bayesian_models/bayesian_nn/__init__.py | 1 + .../bayesian_nn/modules/__init__.py | 2 + .../bayesian_nn/modules/bayesian_embedding.py | 181 ++++++ .../modules}/bayesian_linear.py | 77 ++- .../bayesian_embeddings_layers.py | 183 +----- .../tabular/bayesian_linear/bayesian_wide.py | 79 ++- .../tabular/bayesian_mlp/_layers.py | 4 +- .../tabular/bayesian_mlp/bayesian_tab_mlp.py | 101 ++- pytorch_widedeep/losses.py | 12 +- pytorch_widedeep/training/_trainer_utils.py | 112 +++- pytorch_widedeep/training/bayesian_trainer.py | 601 ++++++++++++++++++ 17 files changed, 1324 insertions(+), 208 deletions(-) create mode 100644 docs/bayesian_models.rst create mode 100644 examples/scripts/adult_census_bayesian_tabmlp.py create mode 100644 pytorch_widedeep/bayesian_models/bayesian_nn/__init__.py create mode 100644 pytorch_widedeep/bayesian_models/bayesian_nn/modules/__init__.py create mode 100644 pytorch_widedeep/bayesian_models/bayesian_nn/modules/bayesian_embedding.py rename pytorch_widedeep/bayesian_models/{ => bayesian_nn/modules}/bayesian_linear.py (58%) rename pytorch_widedeep/bayesian_models/{ => tabular}/bayesian_embeddings_layers.py (56%) create mode 100644 pytorch_widedeep/training/bayesian_trainer.py diff --git a/docs/bayesian_models.rst b/docs/bayesian_models.rst new file mode 100644 index 0000000..a85dc36 --- /dev/null +++ b/docs/bayesian_models.rst @@ -0,0 +1,15 @@ +The ``models`` module +====================== + +This module contains the two Bayesian Models available in this library, namely +the bayesian version of the Wide and TabMlp models, referred as ``BayesianWide`` +and ``BayesianTabMlp`` + + +.. autoclass:: pytorch_widedeep.bayesian_models.tabular.bayesian_linear.bayesian_wide.BayesianWide + :exclude-members: forward + :members: + +.. 
autoclass:: pytorch_widedeep.bayesian_models.tabular.bayesian_mlp.bayesian_tab_mlp.BayesianTabMlp + :exclude-members: forward + :members: diff --git a/docs/index.rst b/docs/index.rst index 5153aaf..2e573f6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -18,6 +18,7 @@ Documentation Utilities Preprocessing Model Components + Bayesian Models Metrics Losses Dataloaders diff --git a/examples/scripts/adult_census_bayesian_tabmlp.py b/examples/scripts/adult_census_bayesian_tabmlp.py new file mode 100644 index 0000000..155fa1a --- /dev/null +++ b/examples/scripts/adult_census_bayesian_tabmlp.py @@ -0,0 +1,127 @@ +from pathlib import Path + +import numpy as np +import torch +import pandas as pd + +from pytorch_widedeep.metrics import Accuracy +from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor +from pytorch_widedeep.bayesian_models import BayesianWide, BayesianTabMlp +from pytorch_widedeep.training.bayesian_trainer import BayesianTrainer + +use_cuda = torch.cuda.is_available() + +if __name__ == "__main__": + + DATA_PATH = Path("../tmp_data") + + df = pd.read_csv(DATA_PATH / "adult/adult.csv.zip") + df.columns = [c.replace("-", "_") for c in df.columns] + df["age_buckets"] = pd.cut( + df.age, bins=[16, 25, 30, 35, 40, 45, 50, 55, 60, 91], labels=np.arange(9) + ) + df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int) + df.drop("income", axis=1, inplace=True) + df.head() + + for model_name in ["linear", "mlp"]: + for objective in ["binary", "multiclass", "regression"]: + + cat_cols = [ + "workclass", + "education", + "marital_status", + "occupation", + "relationship", + "native_country", + "race", + "gender", + ] + + if model_name == "linear": + crossed_cols = [ + ("education", "occupation"), + ("native_country", "occupation"), + ] + + if objective == "binary": + continuous_cols = ["age", "hours_per_week"] + target_name = "income_label" + target = df[target_name].values + elif objective == "multiclass": + continuous_cols = ["hours_per_week"] + target_name = "age_buckets" + target = np.array(df[target_name].tolist()) + elif objective == "regression": + continuous_cols = ["hours_per_week"] + target_name = "age" + target = df[target_name].values + + if model_name == "linear": + prepare_wide = WidePreprocessor( + wide_cols=cat_cols, crossed_cols=crossed_cols + ) + X_tab = prepare_wide.fit_transform(df) + + model = BayesianWide( + input_dim=np.unique(X_tab).shape[0], + pred_dim=df["age_buckets"].nunique() + if objective == "multiclass" + else 1, + prior_sigma_1=1.0, + prior_sigma_2=0.002, + prior_pi=0.8, + posterior_mu_init=0, + posterior_rho_init=-7.0, + ) + + if model_name == "mlp": + prepare_tab = TabPreprocessor( + embed_cols=cat_cols, continuous_cols=continuous_cols # type: ignore[arg-type] + ) + X_tab = prepare_tab.fit_transform(df) + + model = BayesianTabMlp( # type: ignore[assignment] + column_idx=prepare_tab.column_idx, + cat_embed_input=prepare_tab.embeddings_input, + continuous_cols=continuous_cols, + # embed_continuous=True, + mlp_hidden_dims=[128, 64], + prior_sigma_1=1.0, + prior_sigma_2=0.002, + prior_pi=0.8, + posterior_mu_init=0, + posterior_rho_init=-7.0, + pred_dim=df["age_buckets"].nunique() + if objective == "multiclass" + else 1, + ) + + model_checkpoint = ModelCheckpoint( + filepath="model_weights/wd_out", + save_best_only=True, + max_save=1, + ) + early_stopping = EarlyStopping(patience=2) + callbacks = [early_stopping, model_checkpoint] + metrics = 
[Accuracy] if objective != "regression" else None + + trainer = BayesianTrainer( + model, + objective=objective, + optimizer=torch.optim.Adam(model.parameters(), lr=0.01), + callbacks=callbacks, + metrics=metrics, + ) + + trainer.fit( + X_tab=X_tab, + target=target, + val_split=0.2, + n_epochs=1, + batch_size=256, + ) + + # simply to check predicts functions as expected + preds = trainer.predict(X_tab=X_tab) diff --git a/pytorch_widedeep/bayesian_models/__init__.py b/pytorch_widedeep/bayesian_models/__init__.py index 993996c..e9eb278 100644 --- a/pytorch_widedeep/bayesian_models/__init__.py +++ b/pytorch_widedeep/bayesian_models/__init__.py @@ -1,3 +1,4 @@ +from pytorch_widedeep.bayesian_models import bayesian_nn from pytorch_widedeep.bayesian_models.tabular import ( BayesianWide, BayesianTabMlp, diff --git a/pytorch_widedeep/bayesian_models/_base_bayesian_model.py b/pytorch_widedeep/bayesian_models/_base_bayesian_model.py index 48dd736..6a47148 100644 --- a/pytorch_widedeep/bayesian_models/_base_bayesian_model.py +++ b/pytorch_widedeep/bayesian_models/_base_bayesian_model.py @@ -1,14 +1,21 @@ +import torch from torch import nn from pytorch_widedeep.wdtypes import * # noqa: F403 class BayesianModule(nn.Module): + r"""Simply a 'hack' to facilitate the computation of the KL divergence for all + Bayesian models + """ + def init(self): super().__init__() class BaseBayesianModel(nn.Module): + r""" "Base model containing the two methods common to all Bayesian models""" + def _kl_divergence(self): kld = 0 for module in self.modules(): @@ -23,13 +30,15 @@ class BaseBayesianModel(nn.Module): loss_fn: nn.Module, n_samples: int, n_batches: int, - pred_dim: int, - ) -> Tensor: - outputs = torch.zeros(n_samples, target.shape[0], pred_dim) + ) -> Tuple[Tensor, Tensor]: + + outputs_l = [] kld = 0.0 - for i in range(n_samples): - outputs[i] = self(input) + for _ in range(n_samples): + outputs_l.append(self(input)) kld += self._kl_divergence() + outputs = torch.stack(outputs_l) + complexity_cost = kld / n_batches likelihood_cost = loss_fn(outputs.mean(0), target) - return complexity_cost + likelihood_cost + return outputs, complexity_cost + likelihood_cost diff --git a/pytorch_widedeep/bayesian_models/_weight_sampler.py b/pytorch_widedeep/bayesian_models/_weight_sampler.py index 6ae0e5a..4db4458 100644 --- a/pytorch_widedeep/bayesian_models/_weight_sampler.py +++ b/pytorch_widedeep/bayesian_models/_weight_sampler.py @@ -1,9 +1,19 @@ +""" +The code here is greatly insipired by the code at the Blitz package: + +https://github.com/piEsposito/blitz-bayesian-deep-learning +""" + import math from pytorch_widedeep.wdtypes import * # noqa: F403 class ScaleMixtureGaussianPrior(object): + r"""Defines the Scale Mixture Prior as proposed in Weight Uncertainty in + Neural Networks (Eq 7 in the original publication) + """ + def __init__(self, pi: float, sigma1: float, sigma2: float): super().__init__() self.pi = pi @@ -19,6 +29,10 @@ class ScaleMixtureGaussianPrior(object): class GaussianPosterior(object): + r"""Defines the Gaussian variational posterior as proposed in Weight + Uncertainty in Neural Networks + """ + def __init__(self, param_mu: Tensor, param_rho: Tensor): super().__init__() self.param_mu = param_mu diff --git a/pytorch_widedeep/bayesian_models/bayesian_nn/__init__.py b/pytorch_widedeep/bayesian_models/bayesian_nn/__init__.py new file mode 100644 index 0000000..5b69ab5 --- /dev/null +++ b/pytorch_widedeep/bayesian_models/bayesian_nn/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F401, F403 
diff --git a/pytorch_widedeep/bayesian_models/bayesian_nn/modules/__init__.py b/pytorch_widedeep/bayesian_models/bayesian_nn/modules/__init__.py new file mode 100644 index 0000000..6e97be7 --- /dev/null +++ b/pytorch_widedeep/bayesian_models/bayesian_nn/modules/__init__.py @@ -0,0 +1,2 @@ +from .bayesian_linear import BayesianLinear +from .bayesian_embedding import BayesianEmbedding diff --git a/pytorch_widedeep/bayesian_models/bayesian_nn/modules/bayesian_embedding.py b/pytorch_widedeep/bayesian_models/bayesian_nn/modules/bayesian_embedding.py new file mode 100644 index 0000000..4701c51 --- /dev/null +++ b/pytorch_widedeep/bayesian_models/bayesian_nn/modules/bayesian_embedding.py @@ -0,0 +1,181 @@ +""" +The code here is greatly insipired by the code at the Blitz package: + +https://github.com/piEsposito/blitz-bayesian-deep-learning +""" + +import torch.nn.functional as F +from torch import nn + +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.bayesian_models._weight_sampler import ( + GaussianPosterior, + ScaleMixtureGaussianPrior, +) +from pytorch_widedeep.bayesian_models._base_bayesian_model import ( + BayesianModule, +) + + +class BayesianEmbedding(BayesianModule): + r"""A simple lookup table that looks up embeddings in a fixed dictionary and + size. + + Parameters + ---------- + n_embed: int + number of embeddings. Typically referred as size of the vocabulary + embed_dim: int + Dimension of the embeddings + padding_idx: int, optional, default = None + If specified, the entries at ``padding_idx`` do not contribute to the + gradient; therefore, the embedding vector at ``padding_idx`` is not + updated during training, i.e. it remains as a fixed “pad”. For a + newly constructed Embedding, the embedding vector at ``padding_idx`` + will default to all zeros, but can be updated to another value to be + used as the padding vector + max_norm: float, optional, default = None + If given, each embedding vector with norm larger than ``max_norm`` is + renormalized to have norm max_norm + norm_type: float, optional, default = 2. + The p of the p-norm to compute for the ``max_norm`` option. + scale_grad_by_freq: bool, optional, default = False + If given, this will scale gradients by the inverse of frequency of the + words in the mini-batch. + sparse: bool, optional, default = False + If True, gradient w.r.t. weight matrix will be a sparse tensor. See + Notes for more details regarding sparse gradients. + prior_sigma_1: float, default = 1.0 + Prior of the sigma parameter for the first of the two weight Gaussian + distributions that will be mixed to produce the prior weight + distribution + prior_sigma_2: float, default = 0.002 + Prior of the sigma parameter for the second of the two weight Gaussian + distributions that will be mixed to produce the prior weight + distribution + prior_pi: float, default = 0.8 + Scaling factor that will be used to mix the Gaussians to produce the + prior weight distribution + posterior_mu_init: float = 0.0, + The posterior sample of the weights is defined as: + + :math:`\mathbf{w} = \mu + log(1 + exp(\rho))` + + where :math:`\mu` and :math:`\rho` are both sampled from Gaussian + distributions. ``posterior_mu_init`` is the initial mean value for + the Gaussian distribution from which :math:`\mu` is sampled. + + posterior_rho_init: float = -7.0, + The initial mean value for the Gaussian distribution from which `\rho` + is sampled. 
+ + Examples + -------- + >>> import torch + >>> from pytorch_widedeep.bayesian_models import bayesian_nn as bnn + >>> embedding = bnn.BayesianEmbedding(10, 3) + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> out = embedding(input) + """ + + def __init__( + self, + n_embed: int, + embed_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: Optional[float] = 2.0, + scale_grad_by_freq: Optional[bool] = False, + sparse: Optional[bool] = False, + prior_sigma_1: float = 1.0, + prior_sigma_2: float = 0.002, + prior_pi: float = 0.25, + posterior_mu_init: float = 0.0, + posterior_rho_init: float = -3.0, + ): + super(BayesianEmbedding, self).__init__() + + self.n_embed = n_embed + self.embed_dim = embed_dim + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + + self.prior_sigma_1 = prior_sigma_1 + self.prior_sigma_2 = prior_sigma_2 + self.prior_pi = prior_pi + self.posterior_mu_init = posterior_mu_init + self.posterior_rho_init = posterior_rho_init + + # Variational weight parameters and sample + self.weight_mu = nn.Parameter( + torch.Tensor(n_embed, embed_dim).normal_(posterior_mu_init, 0.1) + ) + self.weight_rho = nn.Parameter( + torch.Tensor(n_embed, embed_dim).normal_(posterior_rho_init, 0.1) + ) + self.weight_sampler = GaussianPosterior(self.weight_mu, self.weight_rho) + + # Prior + self.weight_prior_dist = ScaleMixtureGaussianPrior( + self.prior_pi, + self.prior_sigma_1, + self.prior_sigma_2, + ) + + self.log_prior: Union[Tensor, float] = 0.0 + self.log_variational_posterior: Union[Tensor, float] = 0.0 + + def forward(self, X: Tensor) -> Tensor: + + if not self.training: + return F.embedding( + X, + self.weight_mu, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + weight = self.weight_sampler.sample() + + self.log_variational_posterior = self.weight_sampler.log_posterior(weight) + self.log_prior = self.weight_prior_dist.log_prior(weight) + + return F.embedding( + X, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self) -> str: # noqa: C901 + s = "{n_embed}, {embed_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + if self.prior_sigma_1 != 1.0: + s += ", prior_sigma_1={prior_sigma_1}" + if self.prior_sigma_2 != 0.002: + s += ", prior_sigma_2={prior_sigma_2}" + if self.prior_pi != 0.8: + s += ", prior_pi={prior_pi}" + if self.posterior_mu_init != 0.0: + s += ", posterior_mu_init={posterior_mu_init}" + if self.posterior_rho_init != -7.0: + s += ", posterior_rho_init={posterior_rho_init}" + return s.format(**self.__dict__) diff --git a/pytorch_widedeep/bayesian_models/bayesian_linear.py b/pytorch_widedeep/bayesian_models/bayesian_nn/modules/bayesian_linear.py similarity index 58% rename from pytorch_widedeep/bayesian_models/bayesian_linear.py rename to pytorch_widedeep/bayesian_models/bayesian_nn/modules/bayesian_linear.py index 1f15b5d..23412ca 100644 --- a/pytorch_widedeep/bayesian_models/bayesian_linear.py +++ b/pytorch_widedeep/bayesian_models/bayesian_nn/modules/bayesian_linear.py @@ -1,3 +1,13 @@ +""" 
+The code here is greatly insipired by a couple of sources: + +the Blitz package: https://github.com/piEsposito/blitz-bayesian-deep-learning and + +Weight Uncertainty in Neural Networks post by Nitarshan Rajkumar: https://www.nitarshan.com/bayes-by-backprop/ + +and references therein +""" + import torch.nn.functional as F from torch import nn @@ -12,16 +22,60 @@ from pytorch_widedeep.bayesian_models._base_bayesian_model import ( class BayesianLinear(BayesianModule): + r"""Applies a linear transformation to the incoming data as proposed in Weight + Uncertainity on Neural Networks + + Parameters + ---------- + in_features: int + size of each input sample + out_features: int + size of each output sample + use_bias: bool, default = True + Boolean indicating if an additive bias will be learnt + prior_sigma_1: float, default = 1.0 + Prior of the sigma parameter for the first of the two weight Gaussian + distributions that will be mixed to produce the prior weight + distribution + prior_sigma_2: float, default = 0.002 + Prior of the sigma parameter for the second of the two weight Gaussian + distributions that will be mixed to produce the prior weight + distribution + prior_pi: float, default = 0.8 + Scaling factor that will be used to mix the Gaussians to produce the + prior weight distribution + posterior_mu_init: float = 0.0, + The posterior sample of the weights is defined as: + + :math:`\mathbf{w} = \mu + log(1 + exp(\rho))` + + where :math:`\mu` and :math:`\rho` are both sampled from Gaussian + distributions. ``posterior_mu_init`` is the initial mean value for + the Gaussian distribution from which :math:`\mu` is sampled. + + posterior_rho_init: float = -7.0, + The initial mean value for the Gaussian distribution from which `\rho` + is sampled. + + Examples + -------- + >>> import torch + >>> from pytorch_widedeep.bayesian_models import bayesian_nn as bnn + >>> linear = bnn.BayesianLinear(10, 6) + >>> input = torch.rand(6, 10) + >>> out = linear(input) + """ + def __init__( self, in_features: int, out_features: int, use_bias: bool = True, - prior_sigma_1: float = 0.1, + prior_sigma_1: float = 1.0, prior_sigma_2: float = 0.002, - prior_pi: float = 1.0, + prior_pi: float = 0.8, posterior_mu_init: float = 0.0, - posterior_rho_init: float = -6.0, + posterior_rho_init: float = -7.0, ): super(BayesianLinear, self).__init__() @@ -37,8 +91,7 @@ class BayesianLinear(BayesianModule): self.prior_sigma_2 = prior_sigma_2 self.prior_pi = prior_pi - # Variational weight and bias parameters and sample for the posterior - # computation + # Variational Posterior self.weight_mu = nn.Parameter( torch.Tensor(out_features, in_features).normal_(posterior_mu_init, 0.1) ) @@ -103,13 +156,13 @@ class BayesianLinear(BayesianModule): if self.use_bias is not False: s += ", use_bias=True" if self.prior_sigma_1 != 0.1: - s + ", prior_sigma_1={prior_sigma_1}" + s += ", prior_sigma_1={prior_sigma_1}" if self.prior_sigma_2 != 0.002: - s + ", prior_sigma_2={prior_sigma_2}" - if self.prior_pi != 1.0: - s + ", prior_pi={prior_pi}" + s += ", prior_sigma_2={prior_sigma_2}" + if self.prior_pi != 0.8: + s += ", prior_pi={prior_pi}" if self.posterior_mu_init != 0.0: - s + ", posterior_mu_init={posterior_mu_init}" - if self.posterior_rho_init != -6.0: - s + ", posterior_rho_init={posterior_rho_init}" + s += ", posterior_mu_init={posterior_mu_init}" + if self.posterior_rho_init != -8.0: + s += ", posterior_rho_init={posterior_rho_init}" return s.format(**self.__dict__) diff --git 
a/pytorch_widedeep/bayesian_models/bayesian_embeddings_layers.py b/pytorch_widedeep/bayesian_models/tabular/bayesian_embeddings_layers.py similarity index 56% rename from pytorch_widedeep/bayesian_models/bayesian_embeddings_layers.py rename to pytorch_widedeep/bayesian_models/tabular/bayesian_embeddings_layers.py index 09f810f..5aac777 100644 --- a/pytorch_widedeep/bayesian_models/bayesian_embeddings_layers.py +++ b/pytorch_widedeep/bayesian_models/tabular/bayesian_embeddings_layers.py @@ -1,9 +1,9 @@ import numpy as np import einops -import torch.nn.functional as F from torch import nn from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.bayesian_models import bayesian_nn as bnn from pytorch_widedeep.models._get_activation_fn import get_activation_fn from pytorch_widedeep.bayesian_models._weight_sampler import ( GaussianPosterior, @@ -14,150 +14,6 @@ from pytorch_widedeep.bayesian_models._base_bayesian_model import ( ) -class BayesianEmbedding(BayesianModule): - def __init__( - self, - n_embed: int, - embed_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: Optional[float] = 2.0, - scale_grad_by_freq: Optional[bool] = False, - sparse: Optional[bool] = False, - use_bias: bool = False, - prior_sigma_1: float = 0.1, - prior_sigma_2: float = 0.002, - prior_pi: float = 1.0, - posterior_mu_init: float = 0.0, - posterior_rho_init: float = -6.0, - ): - super(BayesianEmbedding, self).__init__() - - self.n_embed = n_embed - self.embed_dim = embed_dim - self.padding_idx = padding_idx - self.max_norm = max_norm - self.norm_type = norm_type - self.scale_grad_by_freq = scale_grad_by_freq - self.sparse = sparse - - self.use_bias = use_bias - - self.prior_sigma_1 = prior_sigma_1 - self.prior_sigma_2 = prior_sigma_2 - self.prior_pi = prior_pi - self.posterior_mu_init = posterior_mu_init - self.posterior_rho_init = posterior_rho_init - - # Variational weight parameters and sample - self.weight_mu = nn.Parameter( - torch.Tensor(n_embed, embed_dim).normal_(posterior_mu_init, 0.1) - ) - self.weight_rho = nn.Parameter( - torch.Tensor(n_embed, embed_dim).normal_(posterior_rho_init, 0.1) - ) - self.weight_sampler = GaussianPosterior(self.weight_mu, self.weight_rho) - - if self.use_bias: - self.bias_mu: Union[nn.Parameter, float] = nn.Parameter( - torch.Tensor(n_embed).normal_(posterior_mu_init, 0.1) - ) - self.bias_rho: Union[nn.Parameter, float] = nn.Parameter( - torch.Tensor(n_embed).normal_(posterior_rho_init, 0.1) - ) - self.bias_sampler = GaussianPosterior(self.bias_mu, self.bias_rho) - else: - self.bias_mu, self.bias_rho = 0.0, 0.0 - - # Prior - self.weight_prior_dist = ScaleMixtureGaussianPrior( - self.prior_pi, - self.prior_sigma_1, - self.prior_sigma_2, - ) - if self.use_bias: - self.bias_prior_dist = ScaleMixtureGaussianPrior( - self.prior_pi, - self.prior_sigma_1, - self.prior_sigma_2, - ) - - self.log_prior: Union[Tensor, float] = 0.0 - self.log_variational_posterior: Union[Tensor, float] = 0.0 - - def forward(self, X: Tensor) -> Tensor: - - if not self.training: - return ( - F.embedding( - X, - self.weight_mu, - self.padding_idx, - self.max_norm, - self.norm_type, - self.scale_grad_by_freq, - self.sparse, - ) - + self.bias_mu - ) - - weight = self.weight_sampler.sample() - if self.use_bias: - bias = self.bias_sampler.sample() - bias_log_posterior: Union[Tensor, float] = self.bias_sampler.log_posterior( - bias - ) - bias_log_prior: Union[Tensor, float] = self.bias_prior_dist.log_prior(bias) - else: - bias = None - bias_log_posterior = 
0.0 - bias_log_prior = 0.0 - - self.log_variational_posterior = ( - self.weight_sampler.log_posterior(weight) + bias_log_posterior - ) - self.log_prior = self.weight_prior_dist.log_prior(weight) + bias_log_prior - - return ( - F.embedding( - X, - weight, - self.padding_idx, - self.max_norm, - self.norm_type, - self.scale_grad_by_freq, - self.sparse, - ) - + bias - ) - - def extra_repr(self) -> str: # noqa: C901 - s = "{n_embed}, {embed_dim}" - if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" - if self.max_norm is not None: - s += ", max_norm={max_norm}" - if self.norm_type != 2: - s += ", norm_type={norm_type}" - if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" - if self.sparse is not False: - s += ", sparse=True" - if self.use_bias: - s += ", use_bias=True" - if self.prior_sigma_1 != 0.1: - s + ", prior_sigma_1={prior_sigma_1}" - if self.prior_sigma_2 != 0.002: - s + ", prior_sigma_2={prior_sigma_2}" - if self.prior_pi != 1.0: - s + ", prior_pi={prior_pi}" - if self.posterior_mu_init != 0.0: - s + ", posterior_mu_init={posterior_mu_init}" - if self.posterior_rho_init != -6.0: - s + ", posterior_rho_init={posterior_rho_init}" - return s.format(**self.__dict__) - - class BayesianContEmbeddings(BayesianModule): def __init__( self, @@ -173,7 +29,10 @@ class BayesianContEmbeddings(BayesianModule): ): super(BayesianContEmbeddings, self).__init__() + self.n_cont_cols = n_cont_cols + self.embed_dim = embed_dim self.use_bias = use_bias + self.activation = activation self.weight_mu = nn.Parameter( torch.Tensor(n_cont_cols, embed_dim).normal_(posterior_mu_init, 0.1) @@ -246,7 +105,7 @@ class BayesianContEmbeddings(BayesianModule): return x def extra_repr(self) -> str: - s = "{n_cont_cols}, {embed_dim}, embed_dropout={embed_dropout}, use_bias={use_bias}" + s = "{n_cont_cols}, {embed_dim}, use_bias={use_bias}" if self.activation is not None: s += ", activation={activation}" return s.format(**self.__dict__) @@ -272,7 +131,7 @@ class BayesianDiffSizeCatEmbeddings(nn.Module): self.embed_layers = nn.ModuleDict( { "emb_layer_" - + col: BayesianEmbedding( + + col: bnn.BayesianEmbedding( val + 1, dim, padding_idx=0, @@ -303,6 +162,7 @@ class BayesianDiffSizeCatAndContEmbeddings(nn.Module): column_idx: Dict[str, int], cat_embed_input: List[Tuple[str, int, int]], continuous_cols: Optional[List[str]], + embed_continuous: bool, cont_embed_dim: int, cont_embed_activation: str, use_cont_bias: bool, @@ -317,6 +177,8 @@ class BayesianDiffSizeCatAndContEmbeddings(nn.Module): self.cat_embed_input = cat_embed_input self.continuous_cols = continuous_cols + self.embed_continuous = embed_continuous + self.cont_embed_dim = cont_embed_dim # Categorical if self.cat_embed_input is not None: @@ -342,18 +204,21 @@ class BayesianDiffSizeCatAndContEmbeddings(nn.Module): self.cont_norm = nn.BatchNorm1d(len(continuous_cols)) else: self.cont_norm = nn.Identity() - self.cont_embed = BayesianContEmbeddings( - len(continuous_cols), - cont_embed_dim, - prior_sigma_1, - prior_sigma_2, - prior_pi, - posterior_mu_init, - posterior_rho_init, - use_cont_bias, - cont_embed_activation, - ) - self.cont_out_dim = len(continuous_cols) * cont_embed_dim + if self.embed_continuous: + self.cont_embed = BayesianContEmbeddings( + len(continuous_cols), + cont_embed_dim, + prior_sigma_1, + prior_sigma_2, + prior_pi, + posterior_mu_init, + posterior_rho_init, + use_cont_bias, + cont_embed_activation, + ) + self.cont_out_dim = len(continuous_cols) * cont_embed_dim + else: + self.cont_out_dim = 
len(continuous_cols) else: self.cont_out_dim = 0 diff --git a/pytorch_widedeep/bayesian_models/tabular/bayesian_linear/bayesian_wide.py b/pytorch_widedeep/bayesian_models/tabular/bayesian_linear/bayesian_wide.py index 96e8381..d3d3741 100644 --- a/pytorch_widedeep/bayesian_models/tabular/bayesian_linear/bayesian_wide.py +++ b/pytorch_widedeep/bayesian_models/tabular/bayesian_linear/bayesian_wide.py @@ -1,36 +1,91 @@ +from torch import nn + from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.bayesian_models import bayesian_nn as bnn from pytorch_widedeep.bayesian_models._base_bayesian_model import ( BaseBayesianModel, ) -from pytorch_widedeep.bayesian_models.bayesian_embeddings_layers import ( - BayesianEmbedding, -) class BayesianWide(BaseBayesianModel): + r"""Creates a so called Wide model. This is a linear model where the + non-linearlities are captured via crossed-columns + + The model implemented via a Bayesian Embedding layer connected to the + output neuron(s). + + Parameters + ---------- + input_dim: int + size of the Embedding layer. `input_dim` is the summation of all the + individual values for all the features that go through the wide + component. For example, if the wide component receives 2 features with + 5 individual values each, `input_dim = 10` + pred_dim: int + size of the ouput tensor containing the predictions + prior_sigma_1: float, default = 1.0 + Prior of the sigma parameter for the first of the two weight Gaussian + distributions that will be mixed to produce the prior weight + distribution + prior_sigma_2: float, default = 0.002 + Prior of the sigma parameter for the second of the two weight Gaussian + distributions that will be mixed to produce the prior weight + distribution + prior_pi: float, default = 0.8 + Scaling factor that will be used to mix the Gaussians to produce the + prior weight distribution + posterior_mu_init: float = 0.0, + The posterior sample of the weights of the Bayesian Embedding layer is + defined as: + + :math:`\mathbf{w} = \mu + log(1 + exp(\rho))` + + where :math:`\mu` and :math:`\rho` are both sampled from Gaussian + distributions. ``posterior_mu_init`` is the initial mean value for + the Gaussian distribution from which :math:`\mu` is sampled. + + posterior_rho_init: float = -7.0, + The initial mean value for the Gaussian distribution from + which :math:`\rho` is sampled. + + Attributes + ----------- + bayesian_wide_linear: ``nn.Module`` + the linear layer that comprises the wide branch of the model + + Examples + -------- + >>> import torch + >>> from pytorch_widedeep.bayesian_models import BayesianWide + >>> X = torch.empty(4, 4).random_(6) + >>> wide = BayesianWide(input_dim=X.unique().size(0), pred_dim=1) + >>> out = wide(X) + """ + def __init__( self, input_dim: int, pred_dim: int = 1, - prior_sigma_1: float = 0.75, - prior_sigma_2: float = 1, - prior_pi: float = 0.25, - posterior_mu_init: float = 0.1, - posterior_rho_init: float = -3.0, + prior_sigma_1: float = 1.0, + prior_sigma_2: float = 0.002, + prior_pi: float = 0.8, + posterior_mu_init: float = 0.0, + posterior_rho_init: float = -8.0, ): super(BayesianWide, self).__init__() - self.bayesian_wide_linear = BayesianEmbedding( - n_embed=input_dim, + # Embeddings: val + 1 because 0 is reserved for padding/unseen cateogories. 
+ self.bayesian_wide_linear = bnn.BayesianEmbedding( + n_embed=input_dim + 1, embed_dim=pred_dim, padding_idx=0, - use_bias=True, prior_sigma_1=prior_sigma_1, prior_sigma_2=prior_sigma_2, prior_pi=prior_pi, posterior_mu_init=posterior_mu_init, posterior_rho_init=posterior_rho_init, ) + self.bias = nn.Parameter(torch.zeros(pred_dim)) def forward(self, X: Tensor) -> Tensor: - out = self.bayesian_wide_linear(X.long()).sum(dim=1) + out = self.bayesian_wide_linear(X.long()).sum(dim=1) + self.bias return out diff --git a/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/_layers.py b/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/_layers.py index 4500402..c4b03b5 100644 --- a/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/_layers.py +++ b/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/_layers.py @@ -1,8 +1,8 @@ from torch import nn from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.bayesian_models import bayesian_nn as bnn from pytorch_widedeep.models._get_activation_fn import get_activation_fn -from pytorch_widedeep.bayesian_models.bayesian_linear import BayesianLinear class BayesianMLP(nn.Module): @@ -27,7 +27,7 @@ class BayesianMLP(nn.Module): for i in range(1, len(d_hidden)): bayesian_dense_layer = nn.Sequential( *[ - BayesianLinear( + bnn.BayesianLinear( d_hidden[i - 1], d_hidden[i], use_bias, diff --git a/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/bayesian_tab_mlp.py b/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/bayesian_tab_mlp.py index a0464ef..977fe45 100644 --- a/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/bayesian_tab_mlp.py +++ b/pytorch_widedeep/bayesian_models/tabular/bayesian_mlp/bayesian_tab_mlp.py @@ -5,15 +5,101 @@ from pytorch_widedeep.models._get_activation_fn import allowed_activations from pytorch_widedeep.bayesian_models._base_bayesian_model import ( BaseBayesianModel, ) -from pytorch_widedeep.bayesian_models.bayesian_embeddings_layers import ( - BayesianDiffSizeCatAndContEmbeddings, -) from pytorch_widedeep.bayesian_models.tabular.bayesian_mlp._layers import ( BayesianMLP, ) +from pytorch_widedeep.bayesian_models.tabular.bayesian_embeddings_layers import ( + BayesianDiffSizeCatAndContEmbeddings, +) class BayesianTabMlp(BaseBayesianModel): + r"""Defines a ``TabMlp`` model that can be used as the ``deeptabular`` + component of a Wide & Deep model. + + This class combines embedding representations of the categorical features + with numerical (aka continuous) features. These are then passed through a + series of dense layers (i.e. a MLP). + + Parameters + ---------- + column_idx: Dict + Dict containing the index of the columns that will be passed through + the ``TabMlp`` model. Required to slice the tensors. e.g. {'education': + 0, 'relationship': 1, 'workclass': 2, ...} + cat_embed_input: List, Optional, default = None + List of Tuples with the column name, number of unique values and + embedding dimension. e.g. [(education, 11, 32), ...] + cat_embed_dropout: float, default = 0.1 + embeddings dropout + continuous_cols: List, Optional, default = None + List with the name of the numeric (aka continuous) columns + embed_continuous: bool, default = False, + Boolean indicating if the continuous columns will be embedded + (i.e. 
passed each through a linear layer with or without activation) + cont_embed_dim: int, default = 32, + Size of the continuous embeddings + cont_embed_dropout: float, default = 0.1, + Dropout for the continuous embeddings + cont_embed_activation: Optional, str, default = None, + Activation function for the continuous embeddings + use_cont_bias: bool, default = True, + Boolean indicating in bias will be used for the continuous embeddings + cont_norm_layer: str, default = "batchnorm" + Type of normalization layer applied to the continuous features. Options + are: 'layernorm', 'batchnorm' or None. + mlp_hidden_dims: List, default = [200, 100] + List with the number of neurons per dense layer in the mlp. + mlp_activation: str, default = "relu" + Activation function for the dense layers of the MLP. Currently + ``tanh``, ``relu``, ``leaky_relu`` and ``gelu`` are supported + prior_sigma_1: float, default = 1.0 + Prior of the sigma parameter for the first of the two weight Gaussian + distributions that will be mixed to produce the prior weight + distribution for each Bayesian linear and embedding layer + prior_sigma_2: float, default = 0.002 + Prior of the sigma parameter for the second of the two weight Gaussian + distributions that will be mixed to produce the prior weight + distribution for each Bayesian linear and embedding layer + prior_pi: float, default = 0.8 + Scaling factor that will be used to mix the Gaussians to produce the + prior weight distribution ffor each Bayesian linear and embedding + layer + posterior_mu_init: float = 0.0, + The posterior sample of the weights is defined as: + + :math:`\mathbf{w} = \mu + log(1 + exp(\rho))` + + where :math:`\mu` and :math:`\rho` are both sampled from Gaussian + distributions. ``posterior_mu_init`` is the initial mean value for + the Gaussian distribution from which :math:`\mu` is sampled for each + Bayesian linear and embedding layers. + posterior_rho_init: float = -7.0, + The initial mean value for the Gaussian distribution from + which :math:`\rho` is sampled for each Bayesian linear and embedding + layers. + + Attributes + ---------- + bayesian_cat_and_cont_embed: ``nn.Module`` + This is the module that processes the categorical and continuous columns + bayesian_tab_mlp: ``nn.Sequential`` + mlp model that will receive the concatenation of the embeddings and + the continuous columns + + Example + -------- + >>> import torch + >>> from pytorch_widedeep.bayesian_models import BayesianTabMlp + >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1) + >>> colnames = ['a', 'b', 'c', 'd', 'e'] + >>> cat_embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)] + >>> column_idx = {k:v for v,k in enumerate(colnames)} + >>> model = BayesianTabMlp(mlp_hidden_dims=[8,4], column_idx=column_idx, cat_embed_input=cat_embed_input, + ... 
continuous_cols = ['e']) + >>> out = model(X_tab) + """ + def __init__( self, column_idx: Dict[str, int], @@ -27,12 +113,11 @@ class BayesianTabMlp(BaseBayesianModel): use_cont_bias: bool = True, cont_norm_layer: str = "batchnorm", mlp_hidden_dims: List[int] = [200, 100], - mlp_activation: str = "relu", - use_bias: bool = True, + mlp_activation: str = "leaky_relu", prior_sigma_1: float = 0.75, prior_sigma_2: float = 0.1, prior_pi: float = 0.25, - posterior_mu_init: float = 0.1, + posterior_mu_init: float = 0.0, posterior_rho_init: float = -3.0, pred_dim=1, # Bayesian models will require their own trainer and need the output layer ): @@ -52,7 +137,6 @@ class BayesianTabMlp(BaseBayesianModel): self.mlp_hidden_dims = mlp_hidden_dims self.mlp_activation = mlp_activation - self.use_bias = use_bias self.prior_sigma_1 = prior_sigma_1 self.prior_sigma_2 = prior_sigma_2 self.prior_pi = prior_pi @@ -73,6 +157,7 @@ class BayesianTabMlp(BaseBayesianModel): column_idx, cat_embed_input, continuous_cols, + embed_continuous, cont_embed_dim, cont_embed_activation, use_cont_bias, @@ -89,7 +174,7 @@ class BayesianTabMlp(BaseBayesianModel): self.bayesian_tab_mlp = BayesianMLP( mlp_hidden_dims, mlp_activation, - use_bias, + True, # use_bias prior_sigma_1, prior_sigma_2, prior_pi, diff --git a/pytorch_widedeep/losses.py b/pytorch_widedeep/losses.py index 0b599e1..7d5a5c2 100644 --- a/pytorch_widedeep/losses.py +++ b/pytorch_widedeep/losses.py @@ -304,13 +304,21 @@ class RMSLELoss(nn.Module): class BayesianRegressionLoss(nn.Module): - def __init__(self, noise_tolerance: float = 0.2): + def __init__(self, noise_tolerance: float): super().__init__() self.noise_tolerance = noise_tolerance def forward(self, input: Tensor, target: Tensor) -> Tensor: return ( - torch.distributions.Normal(input, self.noise_tolerance) + -torch.distributions.Normal(input, self.noise_tolerance) .log_prob(target) .sum() ) + + +class BayesianSELoss(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return (0.5 * (input - target) ** 2).sum() diff --git a/pytorch_widedeep/training/_trainer_utils.py b/pytorch_widedeep/training/_trainer_utils.py index 0ed10b8..49bceaf 100644 --- a/pytorch_widedeep/training/_trainer_utils.py +++ b/pytorch_widedeep/training/_trainer_utils.py @@ -1,6 +1,8 @@ import numpy as np +import torch from tqdm import tqdm from torch import nn +from torch.utils.data import TensorDataset from sklearn.model_selection import train_test_split from pytorch_widedeep.losses import ( @@ -11,7 +13,7 @@ from pytorch_widedeep.losses import ( RMSLELoss, TweedieLoss, QuantileLoss, - BayesianRegressionLoss, + BayesianSELoss, ) from pytorch_widedeep.wdtypes import Dict, List, Optional, Transforms from pytorch_widedeep.training._wd_dataset import WideDeepDataset @@ -21,6 +23,80 @@ from pytorch_widedeep.training._loss_and_obj_aliases import ( ) +def tabular_train_val_split( + seed: int, + method: str, + X: np.ndarray, + y: np.ndarray, + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None, + val_split: Optional[float] = None, +): + r""" + Function to create the train/val split for the BayesianTrainer where only + tabular data is present + + Parameters + ---------- + seed: int + random seed to be used during train/val split + method: str + 'regression', 'binary' or 'multiclass' + X: np.ndarray + tabular dataset (categorical and continuous features) + y: np.ndarray + X_val: np.ndarray, Optional, default = None + Dict with the validation set, where the 
keys are the component names + (e.g: 'wide') and the values the corresponding arrays + y_val: np.ndarray, Optional, default = None + + Returns + ------- + train_set: ``TensorDataset`` + eval_set: ``TensorDataset`` + """ + + if X_val is not None: + assert ( + y_val is not None + ), "if X_val is not None the validation target 'y_val' must also be specified" + + train_set = TensorDataset( + torch.from_numpy(X), + torch.from_numpy(y), + ) + eval_set = TensorDataset( + torch.from_numpy(X_val), + torch.from_numpy(y_val), + ) + elif val_split is not None: + y_tr, y_val, idx_tr, idx_val = train_test_split( + y, + np.arange(len(y)), + test_size=val_split, + random_state=seed, + stratify=y if method != "regression" else None, + ) + X_tr, X_val = X[idx_tr], X[idx_val] + + train_set = TensorDataset( + torch.from_numpy(X_tr), + torch.from_numpy(y_tr), + ) + eval_set = TensorDataset( + torch.from_numpy(X_val), + torch.from_numpy(y_val), + ) + else: + train_set = TensorDataset( + torch.from_numpy(X), + torch.from_numpy(y), + ) + eval_set = None + + return train_set, eval_set + + def wd_train_val_split( # noqa: C901 seed: int, method: str, @@ -185,6 +261,34 @@ def save_epoch_logs(epoch_logs: Dict, loss: float, score: Dict, stage: str): return epoch_logs +def bayesian_alias_to_loss(loss_fn: str, **kwargs): + r""" + Function that returns the corresponding loss function given an alias + + Parameters + ---------- + loss_fn: str + Loss name + + Returns + ------- + Object + loss function + + Examples + -------- + >>> from pytorch_widedeep.training._trainer_utils import bayesian_alias_to_loss + >>> loss_fn = bayesian_alias_to_loss(loss_fn="binary", weight=None) + """ + if loss_fn == "binary": + return nn.BCEWithLogitsLoss(pos_weight=kwargs["weight"], reduction="sum") + if loss_fn == "multiclass": + return nn.CrossEntropyLoss(weight=kwargs["weight"], reduction="sum") + if loss_fn == "regression": + return BayesianSELoss() + # return BayesianRegressionLoss(noise_tolerance=kwargs["noise_tolerance"]) + + def alias_to_loss(loss_fn: str, **kwargs): # noqa: C901 r""" Function that returns the corresponding loss function given an alias @@ -232,9 +336,3 @@ def alias_to_loss(loss_fn: str, **kwargs): # noqa: C901 return TweedieLoss() if "focal_loss" in loss_fn: return FocalLoss(**kwargs) - if "bayesian_binary" in loss_fn: - return nn.BCEWithLogitsLoss(pos_weight=kwargs["weight"], reduction="sum") - if "bayesian_multiclass" in loss_fn: - return nn.CrossEntropyLoss(weight=kwargs["weight"], reduction="sum") - if "bayesian_regression" in loss_fn: - return BayesianRegressionLoss(noise_tolerance=kwargs["noise_tolerance"]) diff --git a/pytorch_widedeep/training/bayesian_trainer.py b/pytorch_widedeep/training/bayesian_trainer.py new file mode 100644 index 0000000..5bfaffb --- /dev/null +++ b/pytorch_widedeep/training/bayesian_trainer.py @@ -0,0 +1,601 @@ +import os +import json +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from tqdm import trange +from torchmetrics import Metric as TorchMetric +from torch.utils.data import DataLoader, TensorDataset +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from pytorch_widedeep.metrics import Metric, MultipleMetrics +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.callbacks import ( + History, + Callback, + MetricCallback, + CallbackContainer, + LRShedulerCallback, +) +from pytorch_widedeep.utils.general_utils import Alias +from pytorch_widedeep.training._trainer_utils import ( + save_epoch_logs, + 
print_loss_and_metric, + bayesian_alias_to_loss, + tabular_train_val_split, +) +from pytorch_widedeep.bayesian_models._base_bayesian_model import ( + BaseBayesianModel, +) + +n_cpus = os.cpu_count() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class BayesianTrainer: + r"""Class to set the of attributes that will be used during the + training process. + + Parameters + ---------- + model: ``BaseBayesianModel`` + An object of class ``BaseBayesianModel`` + objective: str + Defines the objective, loss or cost function. + + Param aliases: ``loss_function``, ``loss_fn``, ``loss``, + ``cost_function``, ``cost_fn``, ``cost`` + + Possible values are: 'binary', 'multiclass', 'regression' + + custom_loss_function: ``nn.Module``, optional, default = None + object of class ``nn.Module``. If none of the loss functions + available suits the user, it is possible to pass a custom loss + function. See for example + :class:`pytorch_widedeep.losses.FocalLoss` for the required + structure of the object or the `Examples + `__ + folder in the repo. + optimizer: ``Optimzer``, optional, default= None + An instance of Pytorch's ``Optimizer`` object + (e.g. :obj:`torch.optim.Adam()`). if no optimizer is passed it will + default to ``AdamW``. + lr_schedulers: ``LRScheduler``, optional, default=None + An instance of Pytorch's ``LRScheduler`` object (e.g + :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) + callbacks: List, optional, default=None + List with :obj:`Callback` objects. The three callbacks available in + ``pytorch-widedeep`` are: ``LRHistory``, ``ModelCheckpoint`` and + ``EarlyStopping``. The ``History`` and the ``LRShedulerCallback`` + callbacks are used by default. This can also be a custom callback as + long as the object of type ``Callback``. See + :obj:`pytorch_widedeep.callbacks.Callback` or the `Examples + `__ + folder in the repo + metrics: List, optional, default=None + - List of objects of type :obj:`Metric`. Metrics available are: + ``Accuracy``, ``Precision``, ``Recall``, ``FBetaScore``, + ``F1Score`` and ``R2Score``. This can also be a custom metric as + long as it is an object of type :obj:`Metric`. See + :obj:`pytorch_widedeep.metrics.Metric` or the `Examples + `__ + folder in the repo + - List of objects of type :obj:`torchmetrics.Metric`. This can be any + metric from torchmetrics library `Examples + `_. This can also be a custom metric as + long as it is an object of type :obj:`Metric`. See `the instructions + `_. + class_weight: float, List or Tuple. optional. default=None + - float indicating the weight of the minority class in binary classification + problems (e.g. 9.) + - a list or tuple with weights for the different classes in multiclass + classification problems (e.g. [1., 2., 3.]). The weights do not + need to be normalised. See `this discussion + `_. + verbose: int, default=1 + Setting it to 0 will print nothing during training. + seed: int, default=1 + Random seed to be used internally for train_test_split + + Attributes + ---------- + cyclic_lr: bool + Attribute that indicates if the lr_scheduler is cyclic_lr + (i.e. ``CyclicLR`` or ``OneCycleLR``). See `Pytorch schedulers + `_. 
+ """ + + @Alias( # noqa: C901 + "objective", + ["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"], + ) + def __init__( + self, + model: BaseBayesianModel, + objective: str, + custom_loss_function: Optional[Module] = None, + optimizer: Optimizer = None, + lr_scheduler: LRScheduler = None, + callbacks: Optional[List[Callback]] = None, + metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None, + class_weight: Optional[Union[float, List[float], Tuple[float]]] = None, + verbose: int = 1, + seed: int = 1, + **kwargs, + ): + + if objective not in ["binary", "multiclass", "regression"]: + raise ValueError( + "If 'custom_loss_function' is not None, 'objective' must be 'binary' " + "'multiclass' or 'regression', consistent with the loss function" + ) + + self.model = model + self.early_stop = False + + self.verbose = verbose + self.seed = seed + self.objective = objective + + self.loss_fn = self._set_loss_fn(objective, class_weight, custom_loss_function) + self.optimizer = ( + optimizer + if optimizer is not None + else torch.optim.AdamW(self.model.parameters()) + ) + self.lr_scheduler = lr_scheduler + try: + self._set_lr_scheduler_running_params( + lr_scheduler, kwargs["reducelronplateau_criterion"] + ) + except KeyError: + self._set_lr_scheduler_running_params(lr_scheduler) + self._set_callbacks_and_metrics(callbacks, metrics) + self.model.to(device) + + def fit( # noqa: C901 + self, + X_tab: np.ndarray, + target: np.ndarray, + X_tab_val: Optional[np.ndarray] = None, + target_val: Optional[np.ndarray] = None, + val_split: Optional[float] = None, + n_epochs: int = 1, + val_freq: int = 1, + batch_size: int = 32, + n_train_samples: int = 2, + n_val_samples: int = 2, + ): + r"""Fit method. + + Parameters + ---------- + X_tab: np.ndarray, + tabular dataset + target: np.ndarray + target values + X_tab_val: np.ndarray, Optional, default = None + validation data + target_val: np.ndarray, Optional, default = None + validation target values + val_split: float, Optional. default=None + An alterative to passing the validation set is to use a train/val + split fraction via 'val_split' + n_epochs: int, default=1 + number of epochs + validation_freq: int, default=1 + epochs validation frequency + batch_size: int, default=32 + batch size + n_train_samples: int, default=2 + number of samples to average over during the training process. + n_val_samples: int, default=2 + number of samples to average over during the validation process. 
+ """ + + self.batch_size = batch_size + + train_set, eval_set = tabular_train_val_split( + self.seed, self.objective, X_tab, target, X_tab_val, target_val, val_split + ) + train_loader = DataLoader( + dataset=train_set, batch_size=batch_size, num_workers=n_cpus + ) + train_steps = len(train_loader) + + if eval_set is not None: + eval_loader = DataLoader( + dataset=eval_set, + batch_size=batch_size, + num_workers=n_cpus, + shuffle=False, + ) + eval_steps = len(eval_loader) + + self.callback_container.on_train_begin( + { + "batch_size": batch_size, + "train_steps": train_steps, + "n_epochs": n_epochs, + } + ) + for epoch in range(n_epochs): + epoch_logs: Dict[str, float] = {} + self.callback_container.on_epoch_begin(epoch, logs=epoch_logs) + + self.train_running_loss = 0.0 + with trange(train_steps, disable=self.verbose != 1) as t: + for batch_idx, (X, y) in zip(t, train_loader): + t.set_description("epoch %i" % (epoch + 1)) + train_score, train_loss = self._train_step( + X, y, n_train_samples, train_steps, batch_idx + ) + print_loss_and_metric(t, train_loss, train_score) + self.callback_container.on_batch_end(batch=batch_idx) + epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train") + + on_epoch_end_metric = None + if eval_set is not None and epoch % val_freq == (val_freq - 1): + self.callback_container.on_eval_begin() + self.valid_running_loss = 0.0 + with trange(eval_steps, disable=self.verbose != 1) as v: + for i, (X, y) in zip(v, eval_loader): + v.set_description("valid") + val_score, val_loss = self._eval_step( + X, y, n_val_samples, train_steps, i + ) + print_loss_and_metric(v, val_loss, val_score) + epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val") + + if self.reducelronplateau: + if self.reducelronplateau_criterion == "loss": + on_epoch_end_metric = val_loss + else: + on_epoch_end_metric = val_score[ + self.reducelronplateau_criterion + ] + + self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric) + + if self.early_stop: + self.callback_container.on_train_end(epoch_logs) + break + + self.callback_container.on_train_end(epoch_logs) + self._restore_best_weights() + self.model.train() + + def predict( # type: ignore[return] + self, + X_tab: np.ndarray, + n_samples: int = 5, + return_samples: bool = False, + batch_size: int = 256, + ) -> np.ndarray: + r"""Returns the predictions + + Parameters + ---------- + X_tab: np.ndarray, + tabular dataset + n_samples: int, default=5 + number of samples that will be either returned or averaged to + produce an overal prediction + return_samples: bool, default = False + Boolean indicating whether the n samples will be averaged or directly returned + batch_size: int, default = 256 + batch size + """ + + preds_l = self._predict(X_tab, n_samples, return_samples, batch_size) + preds = np.hstack(preds_l) if return_samples else np.vstack(preds_l) + axis = 2 if return_samples else 1 + + if self.objective == "regression": + return preds.squeeze(axis) + if self.objective == "binary": + return (preds.squeeze(axis) > 0.5).astype("int") + if self.objective == "multiclass": + return np.argmax(preds, axis) + + def predict_proba( # type: ignore[return] + self, + X_tab: np.ndarray, + n_samples: int = 5, + return_samples: bool = False, + batch_size: int = 256, + ) -> np.ndarray: + r"""Returns the predicted probabilities + + Parameters + ---------- + X_tab: np.ndarray, + tabular dataset + n_samples: int, default=5 + number of samples that will be either returned or averaged to + produce an overal prediction + 
return_samples: bool, default = False + Boolean indicating whether the n samples will be averaged or directly returned + batch_size: int, default = 256 + batch size + """ + preds_l = self._predict(X_tab, n_samples, return_samples, batch_size) + preds = np.hstack(preds_l) if return_samples else np.vstack(preds_l) + + if self.objective == "binary": + if return_samples: + preds = preds.squeeze(2) + probs = np.zeros([n_samples, preds.shape[0], 2]) + for i in range(n_samples): + probs[i, :, 0] = 1 - preds[i] + probs[i, :, 1] = preds[i] + else: + preds = preds.squeeze(1) + probs = np.zeros([preds.shape[0], 2]) + probs[:, 0] = 1 - preds + probs[:, 1] = preds + return probs + if self.objective == "multiclass": + return preds + + def save( + self, + path: str, + save_state_dict: bool = False, + model_filename: str = "wd_model.pt", + ): + r"""Saves the model, training and evaluation history, and the + ``feature_importance`` attribute (if the ``deeptabular`` component is a + Tabnet model) to disk + + The ``Trainer`` class is built so that it 'just' trains a model. With + that in mind, all the torch related parameters (such as optimizers or + learning rate schedulers) have to be defined externally and then + passed to the ``Trainer``. As a result, the ``Trainer`` does not + generate any attribute or additional data products that need to be + saved other than the ``model`` object itself, which can be saved as + any other torch model (e.g. ``torch.save(model, path)``). + + Parameters + ---------- + path: str + path to the directory where the model and the feature importance + attribute will be saved. + save_state_dict: bool, default = False + Boolean indicating whether to save directly the model or the + model's state dictionary + model_filename: str, Optional, default = "wd_model.pt" + filename where the model weights will be store + """ + + save_dir = Path(path) + history_dir = save_dir / "history" + history_dir.mkdir(exist_ok=True, parents=True) + + # the trainer is run with the History Callback by default + with open(history_dir / "train_eval_history.json", "w") as teh: + json.dump(self.history, teh) # type: ignore[attr-defined] + + has_lr_history = any( + [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks] + ) + if self.lr_scheduler is not None and has_lr_history: + with open(history_dir / "lr_history.json", "w") as lrh: + json.dump(self.lr_history, lrh) # type: ignore[attr-defined] + + model_path = save_dir / model_filename + if save_state_dict: + torch.save(self.model.state_dict(), model_path) + else: + torch.save(self.model, model_path) + + def _restore_best_weights(self): + already_restored = any( + [ + ( + callback.__class__.__name__ == "EarlyStopping" + and callback.restore_best_weights + ) + for callback in self.callback_container.callbacks + ] + ) + if already_restored: + pass + else: + for callback in self.callback_container.callbacks: + if callback.__class__.__name__ == "ModelCheckpoint": + if callback.save_best_only: + if self.verbose: + print( + f"Model weights restored to best epoch: {callback.best_epoch + 1}" + ) + self.model.load_state_dict(callback.best_state_dict) + else: + if self.verbose: + print( + "Model weights after training corresponds to the those of the " + "final epoch which might not be the best performing weights. Use" + "the 'ModelCheckpoint' Callback to restore the best epoch weights." 
+ ) + + def _train_step( + self, + X_tab: Tensor, + target: Tensor, + n_samples: int, + n_batches: int, + batch_idx: int, + ): + + self.model.train() + + X = X_tab.to(device) + y = target.view(-1, 1).float() if self.objective != "multiclass" else target + y = y.to(device) + + self.optimizer.zero_grad() + y_pred, loss = self.model.sample_elbo(X, y, self.loss_fn, n_samples, n_batches) # type: ignore[arg-type] + + y_pred = y_pred.mean(dim=0) + score = self._get_score(y_pred, y) + + loss.backward() + self.optimizer.step() + + self.train_running_loss += loss.item() + avg_loss = self.train_running_loss / (batch_idx + 1) + + return score, avg_loss + + def _eval_step( + self, + X_tab: Tensor, + target: Tensor, + n_samples: int, + n_batches: int, + batch_idx: int, + ): + + self.model.eval() + with torch.no_grad(): + X = X_tab.to(device) + y = target.view(-1, 1).float() if self.objective != "multiclass" else target + y = y.to(device) + + y_pred, loss = self.model.sample_elbo( + X, # type: ignore[arg-type] + y, + self.loss_fn, + n_samples, + n_batches, + ) + y_pred = y_pred.mean(dim=0) + score = self._get_score(y_pred, y) + + self.valid_running_loss += loss.item() + avg_loss = self.valid_running_loss / (batch_idx + 1) + + return score, avg_loss + + def _get_score(self, y_pred, y): + if self.metric is not None: + if self.objective == "regression": + score = self.metric(y_pred, y) + if self.objective == "binary": + score = self.metric(torch.sigmoid(y_pred), y) + if self.objective == "multiclass": + score = self.metric(F.softmax(y_pred, dim=1), y) + return score + else: + return None + + def _predict( # noqa: C901 + self, + X_tab: np.ndarray = None, + n_samples: int = 5, + return_samples: bool = False, + batch_size: int = 256, + ) -> List: + + self.batch_size = batch_size + + test_set = TensorDataset(torch.from_numpy(X_tab)) + test_loader = DataLoader( + dataset=test_set, + batch_size=self.batch_size, + num_workers=n_cpus, + shuffle=False, + ) + test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1 # type: ignore[arg-type] + + preds_l = [] + with torch.no_grad(): + with trange(test_steps, disable=self.verbose != 1) as tt: + for j, Xl in zip(tt, test_loader): + tt.set_description("predict") + + X = Xl[0].to(device) + + if return_samples: + preds = torch.stack([self.model(X) for _ in range(n_samples)]) + else: + self.model.eval() + preds = self.model(X) + + if self.objective == "binary": + preds = torch.sigmoid(preds) + if self.objective == "multiclass": + preds = ( + F.softmax(preds, dim=2) + if return_samples + else F.softmax(preds, dim=1) + ) + + preds = preds.cpu().data.numpy() + preds_l.append(preds) + + self.model.train() + + return preds_l + + def _set_loss_fn(self, objective, class_weight, custom_loss_function): + + if custom_loss_function is not None: + return custom_loss_function + + if class_weight is not None: + class_weight = torch.tensor(class_weight).to(device) + elif self.objective != "regression": + return bayesian_alias_to_loss(objective, weight=class_weight) + else: + return bayesian_alias_to_loss(objective) + + def _set_reduce_on_plateau_criterion( + self, lr_scheduler, reducelronplateau_criterion=None + ): + + self.reducelronplateau = False + + if isinstance(lr_scheduler, ReduceLROnPlateau): + self.reducelronplateau = True + + if self.reducelronplateau and not reducelronplateau_criterion: + UserWarning( + "The learning rate scheduler is of type ReduceLROnPlateau. 
The step method in this" + " scheduler requires a 'metrics' param that can be either the validation loss or the" + " validation metric. Please, when instantiating the Trainer, specify which quantity" + " will be tracked using reducelronplateau_criterion = 'loss' (default) or" + " reducelronplateau_criterion = 'metric'" + ) + else: + self.reducelronplateau_criterion = "loss" + + def _set_lr_scheduler_running_params( + self, lr_scheduler, reducelronplateau_criterion=None + ): + # ReduceLROnPlateau is special + self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion) + if lr_scheduler is not None: + self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower() + else: + self.cyclic_lr = False + + def _set_callbacks_and_metrics(self, callbacks, metrics): + self.callbacks: List = [History(), LRShedulerCallback()] + if callbacks is not None: + for callback in callbacks: + if isinstance(callback, type): + callback = callback() + self.callbacks.append(callback) + if metrics is not None: + self.metric = MultipleMetrics(metrics) + self.callbacks += [MetricCallback(self.metric)] + else: + self.metric = None + self.callback_container = CallbackContainer(self.callbacks) + self.callback_container.set_model(self.model) + self.callback_container.set_trainer(self) -- GitLab
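
Note on the weight sampling used throughout this patch: the docstrings added above (BayesianEmbedding, BayesianLinear, BayesianWide, BayesianTabMlp) all refer to a posterior sample of the form w = mu + log(1 + exp(rho)) (times a standard-normal noise term in the reparameterised draw) and to a two-Gaussian scale mixture prior, both defined in _weight_sampler.py. The snippet below is a minimal, self-contained sketch of that machinery, assuming the "Weight Uncertainty in Neural Networks" formulation the docstrings cite; it is illustrative rather than the package's exact implementation, and the shapes and hyperparameters are arbitrary.

import math

import torch

# Variational parameters for a 5 x 3 weight matrix (shapes are illustrative)
mu = torch.zeros(5, 3, requires_grad=True)
rho = torch.full((5, 3), -7.0, requires_grad=True)

# Reparameterised posterior sample: w = mu + log(1 + exp(rho)) * eps, eps ~ N(0, 1)
sigma = torch.log1p(torch.exp(rho))   # softplus keeps sigma positive
eps = torch.randn_like(mu)
w = mu + sigma * eps

# log q(w | mu, rho): diagonal Gaussian evaluated at the sampled weights
log_posterior = (
    -math.log(math.sqrt(2 * math.pi))
    - torch.log(sigma)
    - ((w - mu) ** 2) / (2 * sigma**2)
).sum()

# log p(w): scale mixture prior pi * N(0, sigma_1) + (1 - pi) * N(0, sigma_2)
pi, sigma_1, sigma_2 = 0.8, 1.0, 0.002
prob_1 = torch.exp(torch.distributions.Normal(0.0, sigma_1).log_prob(w))
prob_2 = torch.exp(torch.distributions.Normal(0.0, sigma_2).log_prob(w))
log_prior = torch.log(pi * prob_1 + (1 - pi) * prob_2).sum()

# Per-layer contribution to the KL term that _kl_divergence aggregates
kl_contribution = log_posterior - log_prior
kl_contribution.backward()            # gradients flow back to mu and rho
print(kl_contribution.item())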
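
Note on the loss returned by BaseBayesianModel.sample_elbo above: assuming _kl_divergence sums log q(w) - log p(w) over the Bayesian layers (which is what the per-layer log_variational_posterior and log_prior attributes are there for), the quantity minimised for each mini-batch is

    loss_batch = (1 / n_batches) * sum_{s=1..n_samples} [ log q(w_s) - log p(w_s) ]
                 + loss_fn( (1 / n_samples) * sum_{s} f(X; w_s), y )

i.e. the complexity cost is scaled by 1/n_batches so that, summed over an epoch, it approximates the full KL(q || p) once, while the data-fit term evaluates the sample-averaged network output against the target with a summed-reduction likelihood loss (see bayesian_alias_to_loss, which builds the binary and multiclass losses with reduction="sum", and BayesianSELoss for regression).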
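
A quick way to see the train/eval behaviour implemented in BayesianEmbedding.forward above: in training mode every forward pass draws a fresh set of weights, while in eval mode the layer falls back to the posterior means (weight_mu). The sizes below are arbitrary and mirror the docstring example.

import torch

from pytorch_widedeep.bayesian_models import bayesian_nn as bnn

embedding = bnn.BayesianEmbedding(10, 3)
idx = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

embedding.train()
print(torch.allclose(embedding(idx), embedding(idx)))  # False: weights re-sampled per pass

embedding.eval()
print(torch.allclose(embedding(idx), embedding(idx)))  # True: deterministic lookup of weight_mu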
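
One practical consequence of that sampling behaviour is that BayesianTrainer.predict and predict_proba can expose the individual posterior samples. The sketch below assumes a trainer fitted with the multiclass objective from the example script above (the trainer and X_tab variables are those from that script, and n_samples=20 is an arbitrary choice); the spread across samples gives a rough, model-based uncertainty estimate.

import numpy as np

# probs has shape (n_samples, n_observations, n_classes): one softmax per posterior sample
probs = trainer.predict_proba(X_tab=X_tab, n_samples=20, return_samples=True)

mean_probs = probs.mean(axis=0)       # sample-averaged class probabilities
epistemic_std = probs.std(axis=0)     # spread across samples ~ model uncertainty
pred_class = mean_probs.argmax(axis=1)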