diff --git a/examples/adult_census_tabnet.py b/examples/adult_census_tabnet.py index 57b44b5cf53e161870f4741d88f8e6096c7214d1..fef09458c92c641d41d0a3e61c84d462b8b32e6d 100644 --- a/examples/adult_census_tabnet.py +++ b/examples/adult_census_tabnet.py @@ -84,6 +84,7 @@ if __name__ == "__main__": optimizers=optimizers, callbacks=callbacks, metrics=metrics, + compute_feature_importance=False, ) trainer.fit( diff --git a/pytorch_widedeep/models/tabnet/tab_net.py b/pytorch_widedeep/models/tabnet/tab_net.py index 7ad23995acffd83018a27740c92c474bba4b8263..4456e528e54359ad8a118d5644cb5e2327d645cf 100644 --- a/pytorch_widedeep/models/tabnet/tab_net.py +++ b/pytorch_widedeep/models/tabnet/tab_net.py @@ -334,7 +334,7 @@ class TabNetEncoder(nn.Module): out = self.feat_transformers[step](masked_x) attn = out[:, self.step_dim :] # 'decision contribution' in the paper - d_out = ReLU()(out[:, : self.step_dim]) + d_out = nn.ReLU()(out[:, : self.step_dim]) # aggregate decision contribution agg_decision_contrib = torch.sum(d_out, dim=1) @@ -437,6 +437,7 @@ class TabNet(nn.Module): self.embed_and_cont = EmbeddingsAndContinuous( column_idx, embed_input, embed_dropout, continuous_cols, batchnorm_cont ) + self.embed_and_cont_dim = self.embed_and_cont.output_dim self.tabnet_encoder = TabNetEncoder( self.embed_and_cont.output_dim, step_dim, diff --git a/pytorch_widedeep/models/tabnet/tabnet_utils.py b/pytorch_widedeep/models/tabnet/tabnet_utils.py index 2099adbb39a61ad2580203d1596f0bbb3421e0bd..ac4b67c884e44ec125d6bec442f5b46fb242cfa9 100644 --- a/pytorch_widedeep/models/tabnet/tabnet_utils.py +++ b/pytorch_widedeep/models/tabnet/tabnet_utils.py @@ -1,51 +1,49 @@ import numpy as np import scipy +from pytorch_widedeep.wdtypes import WideDeep -def create_explain_matrix(n_feat, cat_emb_dim, cat_idxs, post_embed_dim): - """ - This is a computational trick. - In order to rapidly sum importances from same embeddings - to the initial index. 
- - Parameters - ---------- - n_feat : int - Initial input dim - cat_emb_dim : int or list of int - if int : size of embedding for all categorical feature - if list of int : size of embedding for each categorical feature - cat_idxs : list of int - Initial position of categorical features - post_embed_dim : int - Post embedding inputs dimension - - Returns - ------- - reducing_matrix : np.array - Matrix of dim (post_embed_dim, n_feat) to performe reduce - """ - - if isinstance(cat_emb_dim, int): - all_emb_impact = [cat_emb_dim - 1] * len(cat_idxs) - else: - all_emb_impact = [emb_dim - 1 for emb_dim in cat_emb_dim] - - acc_emb = 0 - nb_emb = 0 + +def create_explain_matrix(model: WideDeep): + + ( + embed_input, + column_idx, + embed_and_cont_dim, + ) = _extract_tabnet_params(model) + + n_feat = len(column_idx) + col_embeds = {e[0]: e[2] - 1 for e in embed_input} + embed_colname = [e[0] for e in embed_input] + cont_colname = [c for c in column_idx.keys() if c not in embed_colname] + + embed_cum_counter = 0 indices_trick = [] - for i in range(n_feat): - if i not in cat_idxs: - indices_trick.append([i + acc_emb]) - else: + for colname, idx in column_idx.items(): + if colname in cont_colname: + indices_trick.append([idx + embed_cum_counter]) + elif colname in embed_colname: indices_trick.append( - range(i + acc_emb, i + acc_emb + all_emb_impact[nb_emb] + 1) + range( # type: ignore[arg-type] + idx + embed_cum_counter, + idx + embed_cum_counter + col_embeds[colname] + 1, + ) ) - acc_emb += all_emb_impact[nb_emb] - nb_emb += 1 + embed_cum_counter += col_embeds[colname] - reducing_matrix = np.zeros((post_embed_dim, n_feat)) + reducing_matrix = np.zeros((embed_and_cont_dim, n_feat)) for i, cols in enumerate(indices_trick): reducing_matrix[cols, i] = 1 return scipy.sparse.csc_matrix(reducing_matrix) + + +def _extract_tabnet_params(model: WideDeep): + + tabnet_backbone = list(model.deeptabular.children())[0] + + column_idx = tabnet_backbone.column_idx + embed_input = tabnet_backbone.embed_input + embed_and_cont_dim = tabnet_backbone.embed_and_cont_dim + + return embed_input, column_idx, embed_and_cont_dim diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index d8fbf867878aa75e4ad6a10a5a496a527e524c77..8b0765dcf52155fa1e3e7777f9ef9faece9fe158 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -5,18 +5,28 @@ import torch import torch.nn as nn import torch.nn.functional as F from tqdm import trange +from scipy.sparse import csc_matrix from torch.utils.data import DataLoader from sklearn.model_selection import train_test_split from pytorch_widedeep.losses import MSLELoss, RMSELoss, FocalLoss, RMSLELoss -from pytorch_widedeep.models import WideDeep from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics from pytorch_widedeep.wdtypes import * # noqa: F403 -from pytorch_widedeep.callbacks import History, LRShedulerCallback, Callback, CallbackContainer +from pytorch_widedeep.callbacks import ( + History, + Callback, + CallbackContainer, + LRShedulerCallback, +) from pytorch_widedeep.initializers import Initializer, MultipleInitializer from pytorch_widedeep.training._finetune import FineTune -from pytorch_widedeep.utils.general_utils import Alias from pytorch_widedeep.training._wd_dataset import WideDeepDataset +from pytorch_widedeep.training.trainer_utils import ( + Alias, + save_epoch_logs, + print_loss_and_metric, +) +from pytorch_widedeep.models.tabnet.tabnet_utils import create_explain_matrix from 
pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer from pytorch_widedeep.training._multiple_transforms import MultipleTransforms from pytorch_widedeep.training._loss_and_obj_aliases import ( @@ -54,6 +64,7 @@ class Trainer: alpha: float = 0.25, gamma: float = 2, lambda_sparse: float = 1e-3, + compute_feature_importance: bool = True, verbose: int = 1, seed: int = 1, ): @@ -248,8 +259,12 @@ class Trainer: else: self.model = model + #  Tabnet related set ups + self.compute_feature_importance = compute_feature_importance if self.model.is_tabnet: self.lambda_sparse = lambda_sparse + if self.model.is_tabnet and self.compute_feature_importance: + self.reducing_matrix = create_explain_matrix(model) self.verbose = verbose self.seed = seed @@ -268,19 +283,19 @@ class Trainer: self.model.to(device) - @Alias("finetune", ["warmup"]) # noqa: C901 - @Alias("finetune_epochs", ["warmup_epochs"]) - @Alias("finetune_max_lr", ["warmup_max_lr"]) - @Alias("finetune_deeptabular_gradual", ["warmup_deeptabular_gradual"]) - @Alias("finetune_deeptabular_max_lr", ["warmup_deeptabular_max_lr"]) - @Alias("finetune_deeptabular_layers", ["warmup_deeptabular_layers"]) - @Alias("finetune_deeptext_gradual", ["warmup_deeptext_gradual"]) - @Alias("finetune_deeptext_max_lr", ["warmup_deeptext_max_lr"]) - @Alias("finetune_deeptext_layers", ["warmup_deeptext_layers"]) - @Alias("finetune_deepimage_gradual", ["warmup_deepimage_gradual"]) - @Alias("finetune_deepimage_max_lr", ["warmup_deepimage_max_lr"]) - @Alias("finetune_deepimage_layers", ["warmup_deepimage_layers"]) - @Alias("finetune_routine", ["warmup_routine"]) + @Alias("finetune", "warmup") # noqa: C901 + @Alias("finetune_epochs", "warmup_epochs") + @Alias("finetune_max_lr", "warmup_max_lr") + @Alias("finetune_deeptabular_gradual", "warmup_deeptabular_gradual") + @Alias("finetune_deeptabular_max_lr", "warmup_deeptabular_max_lr") + @Alias("finetune_deeptabular_layers", "warmup_deeptabular_layers") + @Alias("finetune_deeptext_gradual", "warmup_deeptext_gradual") + @Alias("finetune_deeptext_max_lr", "warmup_deeptext_max_lr") + @Alias("finetune_deeptext_layers", "warmup_deeptext_layers") + @Alias("finetune_deepimage_gradual", "warmup_deepimage_gradual") + @Alias("finetune_deepimage_max_lr", "warmup_deepimage_max_lr") + @Alias("finetune_deepimage_layers", "warmup_deepimage_layers") + @Alias("finetune_routine", "warmup_routine") def fit( # noqa: C901 self, X_wide: Optional[np.ndarray] = None, @@ -542,24 +557,16 @@ class Trainer: for epoch in range(n_epochs): epoch_logs: Dict[str, float] = {} self.callback_container.on_epoch_begin(epoch, logs=epoch_logs) + self.train_running_loss = 0.0 with trange(train_steps, disable=self.verbose != 1) as t: for batch_idx, (data, targett) in zip(t, train_loader): t.set_description("epoch %i" % (epoch + 1)) - score, train_loss = self._training_step(data, targett, batch_idx) - if score is not None: - t.set_postfix( - metrics={k: np.round(v, 4) for k, v in score.items()}, - loss=train_loss, - ) - else: - t.set_postfix(loss=train_loss) + score, train_loss = self._train_step(data, targett, batch_idx) + print_loss_and_metric(t, train_loss, score) self.callback_container.on_batch_end(batch=batch_idx) - epoch_logs["train_loss"] = train_loss - if score is not None: - for k, v in score.items(): - log_k = "_".join(["train", k]) - epoch_logs[log_k] = v + epoch_logs = save_epoch_logs(epoch_logs, train_loss, score, "train") + if eval_set is not None and epoch % validation_freq == ( validation_freq - 1 ): @@ -568,23 +575,17 @@ class Trainer: 
for i, (data, targett) in zip(v, eval_loader): v.set_description("valid") score, val_loss = self._validation_step(data, targett, i) - if score is not None: - v.set_postfix( - metrics={k: np.round(v, 4) for k, v in score.items()}, - loss=val_loss, - ) - else: - v.set_postfix(loss=val_loss) - epoch_logs["val_loss"] = val_loss - if score is not None: - for k, v in score.items(): - log_k = "_".join(["val", k]) - epoch_logs[log_k] = v + print_loss_and_metric(v, val_loss, score) + epoch_logs = save_epoch_logs(epoch_logs, val_loss, score, "val") + self.callback_container.on_epoch_end(epoch, epoch_logs) if self.early_stop: self.callback_container.on_train_end(epoch_logs) break + self.callback_container.on_train_end(epoch_logs) + if self.compute_feature_importance and self.model.is_tabnet: + self._compute_feature_importance(train_loader) self.model.train() def predict( # type: ignore[return] @@ -727,6 +728,65 @@ class Trainer: cat_embed_dict[value] = embed_mtx[idx] return cat_embed_dict + def explain(self, X_tab: np.ndarray, save_step_masks: bool = False): + """ + Returns the aggregated feature importance for each instance (or + observation) in the ``X_tab`` array. If ``save_step_masks`` is set to + ``True``, the masks per step will also be returned. + + Parameters + ---------- + X_tab: np.ndarray + Input array corresponding **only** to the deeptabular component + save_step_masks: bool + Boolean indicating if the masks per step will be returned + + Returns + ------- + res: np.ndarray, Tuple + Array or Tuple of two arrays with the corresponding aggregated + feature importance and the masks per step if ``save_step_masks`` + is set to ``True`` + """ + loader = DataLoader( + dataset=WideDeepDataset(**{"X_tab": X_tab}), + batch_size=self.batch_size, + num_workers=n_cpus, + shuffle=False, + ) + + self.model.eval() + tabnet_backbone = list(self.model.deeptabular.children())[0] + + m_explain_l = [] + for batch_nb, data in enumerate(loader): + X = data["deeptabular"].to(device) + M_explain, masks = tabnet_backbone.forward_masks(X) + m_explain_l.append( + csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix) + ) + if save_step_masks: + for key, value in masks.items(): + masks[key] = csc_matrix.dot( + value.cpu().detach().numpy(), self.reducing_matrix + ) + if batch_nb == 0: + m_explain_step = masks + else: + for key, value in masks.items(): + m_explain_step[key] = np.vstack([m_explain_step[key], value]) + + m_explain_agg = np.vstack(m_explain_l) + m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis] + + res = ( + (m_explain_agg_norm, m_explain_step) + if save_step_masks + else np.vstack(m_explain_agg_norm) + ) + + return res + def save_model(self, path: str): """Saves the model to disk @@ -932,7 +992,7 @@ class Trainer: self.model.deepimage, "deepimage", loader, n_epochs, max_lr ) - def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int): + def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int): self.model.train() X = {k: v.cuda() for k, v in data.items()} if use_cuda else data y = target.view(-1, 1).float() if self.method != "multiclass" else target @@ -987,6 +1047,24 @@ class Trainer: else: return None + def _compute_feature_importance(self, loader: DataLoader): + self.model.eval() + tabnet_backbone = list(self.model.deeptabular.children())[0] + feat_imp = np.zeros((tabnet_backbone.embed_and_cont_dim)) + for data, target in loader: + X = data["deeptabular"].to(device) + y = target.view(-1, 1).float() if self.method 
!= "multiclass" else target + y = y.to(device) + M_explain, masks = tabnet_backbone.forward_masks(X) + feat_imp += M_explain.sum(dim=0).cpu().detach().numpy() + + feat_imp = csc_matrix.dot(feat_imp, self.reducing_matrix) + feat_imp = feat_imp / np.sum(feat_imp) + + self.feature_importance = { + k: v for k, v in zip(tabnet_backbone.column_idx.keys(), feat_imp) + } + def _predict( self, X_wide: Optional[np.ndarray] = None, diff --git a/pytorch_widedeep/utils/general_utils.py b/pytorch_widedeep/training/trainer_utils.py similarity index 62% rename from pytorch_widedeep/utils/general_utils.py rename to pytorch_widedeep/training/trainer_utils.py index 5814b791bfb0baecbf9f3169ad785be2a5d02426..6ed470abf07309464014cb840e3f980f024e7b9d 100644 --- a/pytorch_widedeep/utils/general_utils.py +++ b/pytorch_widedeep/training/trainer_utils.py @@ -1,13 +1,18 @@ """ -Code here taken from the one and only Hunter McGushion and his fantastic -library: https://github.com/HunterMcGushion/hyperparameter_hunter +Code for 'Alias' and 'set_default_attr' taken from the one and only Hunter +McGushion and his library: +https://github.com/HunterMcGushion/hyperparameter_hunter """ +import numpy as np import wrapt +from tqdm import tqdm + +from pytorch_widedeep.wdtypes import Any, Dict, List, Union class Alias: - def __init__(self, primary_name, aliases): + def __init__(self, primary_name: str, aliases: Union[str, List[str]]): """Convert uses of `aliases` to `primary_name` upon calling the decorated function/method Parameters @@ -56,7 +61,7 @@ class Alias: return wrapped(*args, **kwargs) -def set_default_attr(obj, name, value): +def set_default_attr(obj: Any, name: str, value: Any): """Set the `name` attribute of `obj` to `value` if the attribute does not already exist Parameters @@ -88,3 +93,52 @@ def set_default_attr(obj, name, value): except AttributeError: setattr(obj, name, value) return value + + +def print_loss_and_metric(pb: tqdm, loss: float, score: Dict): + """ + Little function to improve readability and avoid code repetition in the + training/validation loop within the Trainer's fit method + + Parameters + ---------- + pb: tqdm + tqdm Object defined as trange(...) 
+ loss: float + loss value + score: Dict + Dictionary where the keys are the metric names and the values are the + corresponding values + """ + if score is not None: + pb.set_postfix( + metrics={k: np.round(v, 4) for k, v in score.items()}, + loss=loss, + ) + else: + pb.set_postfix(loss=loss) + + +def save_epoch_logs(epoch_logs: Dict, loss: float, score: Dict, stage: str): + """ + Little function to improve readability and avoid code repetition in the + training/validation loop within the Trainer's fit method + + Parameters + ---------- + epoch_logs: Dict + Dict containing the epoch logs + loss: float + loss value + score: Dict + Dictionary where the keys are the metric names and the values are the + corresponding values + stage: str + one of 'train' or 'val' + """ + epoch_logs["_".join([stage, "loss"])] = loss + if score is not None: + for k, v in score.items(): + log_k = "_".join([stage, k]) + epoch_logs[log_k] = v + return epoch_logs diff --git a/pytorch_widedeep/wdtypes.py b/pytorch_widedeep/wdtypes.py index 328916fcc3393266e88951f33a04bf1bea443e01..9af0c281cd2ec8c476b6e91e0983f3d08739b0c3 100644 --- a/pytorch_widedeep/wdtypes.py +++ b/pytorch_widedeep/wdtypes.py @@ -49,6 +49,7 @@ from torchvision.transforms import ( from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data.dataloader import DataLoader +from pytorch_widedeep.models import WideDeep from pytorch_widedeep.models.tabnet.sparsemax import Entmax15, Sparsemax ListRules = Collection[Callable[[str], str]]
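
For context on the reducing matrix that ``create_explain_matrix`` builds above: it maps the post-embedding column space back to the original tabular features, so per-dimension attributions can be collapsed into per-feature importances with a single sparse dot product. Below is a small, standalone numpy/scipy sketch of that trick (toy sizes, illustrative only, not library code): one continuous feature occupying a single column and one categorical feature embedded into 4 columns.

    import numpy as np
    from scipy.sparse import csc_matrix

    # Toy layout: feature 0 is continuous (1 post-embedding column),
    # feature 1 is categorical with an embedding of dim 4 (4 post-embedding columns)
    post_embed_dim, n_feat = 5, 2
    reducing_matrix = np.zeros((post_embed_dim, n_feat))
    reducing_matrix[0, 0] = 1    # continuous column -> feature 0
    reducing_matrix[1:5, 1] = 1  # the 4 embedding columns -> feature 1
    reducing_matrix = csc_matrix(reducing_matrix)

    # Fake per-dimension attributions for a batch of 3 instances
    M_explain = np.abs(np.random.randn(3, post_embed_dim))

    # Same reduction the Trainer applies: (3, 5) x (5, 2) -> (3, 2),
    # i.e. one aggregated importance per original feature and instance
    per_feature = csc_matrix.dot(M_explain, reducing_matrix)
    print(per_feature.shape)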
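
A minimal usage sketch of the feature-importance API this diff adds to the ``Trainer``. It assumes a ``WideDeep`` model whose ``deeptabular`` component is a ``TabNet``, plus a preprocessed ``X_tab``/``target`` pair as in ``examples/adult_census_tabnet.py``; the fit arguments shown (``n_epochs``, ``batch_size``) and the ``objective`` value are illustrative, not prescriptive.

    from pytorch_widedeep import Trainer

    # model: a WideDeep instance with a TabNet deeptabular component
    # X_tab, target: arrays produced by the tabular preprocessing in the adult census example
    trainer = Trainer(model, objective="binary", compute_feature_importance=True)
    trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)

    # Global importances computed on the training data after fit:
    # {column_name: normalised importance}
    print(trainer.feature_importance)

    # Per-instance importances for new data; optionally also the per-step masks
    feat_imp = trainer.explain(X_tab)
    feat_imp, step_masks = trainer.explain(X_tab, save_step_masks=True)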