Commit 7626bcc9 authored by jrzaurin


Added explain and compute feature importance methods to the trainer. Need to add error handling and messaging since these functionalities are intended only for TabNet. Moved the general_utils module to the training module and renamed it as trainer_utils. Adapted the create_explain_matrix function to WideDeep.
Parent c36b38c7
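In short, this commit lets a Trainer built around a TabNet deeptabular component expose TabNet's attentive masks as feature importances. The sketch below is only an illustration of how the new pieces fit together: the data and model objects (tabnet_model, X_tab, X_tab_test, target) and the top-level Trainer import path are assumptions, while compute_feature_importance, explain and feature_importance are what this commit adds.

# Hypothetical usage sketch; not part of this commit.
from pytorch_widedeep import Trainer                # assumed top-level export
from pytorch_widedeep.models import WideDeep

model = WideDeep(deeptabular=tabnet_model)          # tabnet_model: a TabNet instance (assumed)
trainer = Trainer(
    model,
    objective="binary",                             # assumed Trainer argument for this sketch
    compute_feature_importance=True,                # new flag (default True per the signature change below)
)
trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)

# new in this commit: global importances aggregated over the training loader
print(trainer.feature_importance)                   # {column_name: normalised importance}

# new in this commit: per-instance explanation matrix for unseen data
explain_mtx = trainer.explain(X_tab_test)           # shape (n_rows, n_initial_features)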
@@ -84,6 +84,7 @@ if __name__ == "__main__":
         optimizers=optimizers,
         callbacks=callbacks,
         metrics=metrics,
+        compute_feature_importance=False,
     )
     trainer.fit(
...
@@ -334,7 +334,7 @@ class TabNetEncoder(nn.Module):
             out = self.feat_transformers[step](masked_x)
             attn = out[:, self.step_dim :]
             # 'decision contribution' in the paper
-            d_out = ReLU()(out[:, : self.step_dim])
+            d_out = nn.ReLU()(out[:, : self.step_dim])
             # aggregate decision contribution
             agg_decision_contrib = torch.sum(d_out, dim=1)
@@ -437,6 +437,7 @@ class TabNet(nn.Module):
         self.embed_and_cont = EmbeddingsAndContinuous(
             column_idx, embed_input, embed_dropout, continuous_cols, batchnorm_cont
         )
+        self.embed_and_cont_dim = self.embed_and_cont.output_dim
         self.tabnet_encoder = TabNetEncoder(
             self.embed_and_cont.output_dim,
             step_dim,
...
 import numpy as np
 import scipy
 
+from pytorch_widedeep.wdtypes import WideDeep
 
-def create_explain_matrix(n_feat, cat_emb_dim, cat_idxs, post_embed_dim):
-    """
-    This is a computational trick.
-    In order to rapidly sum importances from same embeddings
-    to the initial index.
-
-    Parameters
-    ----------
-    n_feat : int
-        Initial input dim
-    cat_emb_dim : int or list of int
-        if int : size of embedding for all categorical feature
-        if list of int : size of embedding for each categorical feature
-    cat_idxs : list of int
-        Initial position of categorical features
-    post_embed_dim : int
-        Post embedding inputs dimension
-
-    Returns
-    -------
-    reducing_matrix : np.array
-        Matrix of dim (post_embed_dim, n_feat) to performe reduce
-    """
-    if isinstance(cat_emb_dim, int):
-        all_emb_impact = [cat_emb_dim - 1] * len(cat_idxs)
-    else:
-        all_emb_impact = [emb_dim - 1 for emb_dim in cat_emb_dim]
-    acc_emb = 0
-    nb_emb = 0
+
+def create_explain_matrix(model: WideDeep):
+    (
+        embed_input,
+        column_idx,
+        embed_and_cont_dim,
+    ) = _extract_tabnet_params(model)
+
+    n_feat = len(column_idx)
+    col_embeds = {e[0]: e[2] - 1 for e in embed_input}
+    embed_colname = [e[0] for e in embed_input]
+    cont_colname = [c for c in column_idx.keys() if c not in embed_colname]
+
+    embed_cum_counter = 0
     indices_trick = []
-    for i in range(n_feat):
-        if i not in cat_idxs:
-            indices_trick.append([i + acc_emb])
-        else:
+    for colname, idx in column_idx.items():
+        if colname in cont_colname:
+            indices_trick.append([idx + embed_cum_counter])
+        elif colname in embed_colname:
             indices_trick.append(
-                range(i + acc_emb, i + acc_emb + all_emb_impact[nb_emb] + 1)
+                range(  # type: ignore[arg-type]
+                    idx + embed_cum_counter,
+                    idx + embed_cum_counter + col_embeds[colname] + 1,
+                )
             )
-            acc_emb += all_emb_impact[nb_emb]
-            nb_emb += 1
-    reducing_matrix = np.zeros((post_embed_dim, n_feat))
+            embed_cum_counter += col_embeds[colname]
+    reducing_matrix = np.zeros((embed_and_cont_dim, n_feat))
     for i, cols in enumerate(indices_trick):
         reducing_matrix[cols, i] = 1
     return scipy.sparse.csc_matrix(reducing_matrix)
+
+
+def _extract_tabnet_params(model: WideDeep):
+    tabnet_backbone = list(model.deeptabular.children())[0]
+    column_idx = tabnet_backbone.column_idx
+    embed_input = tabnet_backbone.embed_input
+    embed_and_cont_dim = tabnet_backbone.embed_and_cont_dim
+    return embed_input, column_idx, embed_and_cont_dim
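To make the "computational trick" above concrete, here is a small self-contained sketch with toy numbers (not library code): with one continuous column and one categorical column whose embedding has 3 dimensions, the post-embedding width is 4, and the (4, 2) reducing matrix collapses the three embedding columns back onto the original categorical feature when an explanation matrix is multiplied by it.

import numpy as np
from scipy.sparse import csc_matrix

# post-embedding layout: [cont, cat_emb_0, cat_emb_1, cat_emb_2] -> 4 columns
# original layout:       [cont, cat]                             -> 2 columns
reducing_matrix = csc_matrix(
    np.array(
        [
            [1.0, 0.0],  # cont      -> cont
            [0.0, 1.0],  # cat_emb_0 -> cat
            [0.0, 1.0],  # cat_emb_1 -> cat
            [0.0, 1.0],  # cat_emb_2 -> cat
        ]
    )
)

# a made-up per-instance mask over the post-embedding columns (2 rows)
M_explain = np.array([[0.2, 0.1, 0.3, 0.4], [0.5, 0.0, 0.1, 0.4]])

# same operation the new Trainer.explain uses: csc_matrix.dot(dense, sparse)
per_feature = csc_matrix.dot(M_explain, reducing_matrix)
print(per_feature)  # [[0.2 0.8] [0.5 0.5]]: embedding columns summed back per feature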
@@ -5,18 +5,28 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from tqdm import trange
+from scipy.sparse import csc_matrix
 from torch.utils.data import DataLoader
 from sklearn.model_selection import train_test_split
 
 from pytorch_widedeep.losses import MSLELoss, RMSELoss, FocalLoss, RMSLELoss
+from pytorch_widedeep.models import WideDeep
 from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics
 from pytorch_widedeep.wdtypes import *  # noqa: F403
-from pytorch_widedeep.callbacks import History, LRShedulerCallback, Callback, CallbackContainer
+from pytorch_widedeep.callbacks import (
+    History,
+    Callback,
+    CallbackContainer,
+    LRShedulerCallback,
+)
 from pytorch_widedeep.initializers import Initializer, MultipleInitializer
 from pytorch_widedeep.training._finetune import FineTune
-from pytorch_widedeep.utils.general_utils import Alias
 from pytorch_widedeep.training._wd_dataset import WideDeepDataset
+from pytorch_widedeep.training.trainer_utils import (
+    Alias,
+    save_epoch_logs,
+    print_loss_and_metric,
+)
+from pytorch_widedeep.models.tabnet.tabnet_utils import create_explain_matrix
 from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
 from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
 from pytorch_widedeep.training._loss_and_obj_aliases import (
@@ -54,6 +64,7 @@ class Trainer:
         alpha: float = 0.25,
         gamma: float = 2,
         lambda_sparse: float = 1e-3,
+        compute_feature_importance: bool = True,
         verbose: int = 1,
         seed: int = 1,
     ):
@@ -248,8 +259,12 @@ class Trainer:
         else:
             self.model = model
 
+        # Tabnet related set ups
+        self.compute_feature_importance = compute_feature_importance
         if self.model.is_tabnet:
             self.lambda_sparse = lambda_sparse
+        if self.model.is_tabnet and self.compute_feature_importance:
+            self.reducing_matrix = create_explain_matrix(model)
 
         self.verbose = verbose
         self.seed = seed
@@ -268,19 +283,19 @@ class Trainer:
         self.model.to(device)
 
-    @Alias("finetune", ["warmup"])  # noqa: C901
-    @Alias("finetune_epochs", ["warmup_epochs"])
-    @Alias("finetune_max_lr", ["warmup_max_lr"])
-    @Alias("finetune_deeptabular_gradual", ["warmup_deeptabular_gradual"])
-    @Alias("finetune_deeptabular_max_lr", ["warmup_deeptabular_max_lr"])
-    @Alias("finetune_deeptabular_layers", ["warmup_deeptabular_layers"])
-    @Alias("finetune_deeptext_gradual", ["warmup_deeptext_gradual"])
-    @Alias("finetune_deeptext_max_lr", ["warmup_deeptext_max_lr"])
-    @Alias("finetune_deeptext_layers", ["warmup_deeptext_layers"])
-    @Alias("finetune_deepimage_gradual", ["warmup_deepimage_gradual"])
-    @Alias("finetune_deepimage_max_lr", ["warmup_deepimage_max_lr"])
-    @Alias("finetune_deepimage_layers", ["warmup_deepimage_layers"])
-    @Alias("finetune_routine", ["warmup_routine"])
+    @Alias("finetune", "warmup")  # noqa: C901
+    @Alias("finetune_epochs", "warmup_epochs")
+    @Alias("finetune_max_lr", "warmup_max_lr")
+    @Alias("finetune_deeptabular_gradual", "warmup_deeptabular_gradual")
+    @Alias("finetune_deeptabular_max_lr", "warmup_deeptabular_max_lr")
+    @Alias("finetune_deeptabular_layers", "warmup_deeptabular_layers")
+    @Alias("finetune_deeptext_gradual", "warmup_deeptext_gradual")
+    @Alias("finetune_deeptext_max_lr", "warmup_deeptext_max_lr")
+    @Alias("finetune_deeptext_layers", "warmup_deeptext_layers")
+    @Alias("finetune_deepimage_gradual", "warmup_deepimage_gradual")
+    @Alias("finetune_deepimage_max_lr", "warmup_deepimage_max_lr")
+    @Alias("finetune_deepimage_layers", "warmup_deepimage_layers")
+    @Alias("finetune_routine", "warmup_routine")
     def fit(  # noqa: C901
         self,
         X_wide: Optional[np.ndarray] = None,
@@ -542,24 +557,16 @@ class Trainer:
         for epoch in range(n_epochs):
             epoch_logs: Dict[str, float] = {}
             self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
 
             self.train_running_loss = 0.0
             with trange(train_steps, disable=self.verbose != 1) as t:
                 for batch_idx, (data, targett) in zip(t, train_loader):
                     t.set_description("epoch %i" % (epoch + 1))
-                    score, train_loss = self._training_step(data, targett, batch_idx)
-                    if score is not None:
-                        t.set_postfix(
-                            metrics={k: np.round(v, 4) for k, v in score.items()},
-                            loss=train_loss,
-                        )
-                    else:
-                        t.set_postfix(loss=train_loss)
+                    score, train_loss = self._train_step(data, targett, batch_idx)
+                    print_loss_and_metric(t, train_loss, score)
                     self.callback_container.on_batch_end(batch=batch_idx)
-            epoch_logs["train_loss"] = train_loss
-            if score is not None:
-                for k, v in score.items():
-                    log_k = "_".join(["train", k])
-                    epoch_logs[log_k] = v
+            epoch_logs = save_epoch_logs(epoch_logs, train_loss, score, "train")
 
             if eval_set is not None and epoch % validation_freq == (
                 validation_freq - 1
             ):
@@ -568,23 +575,17 @@ class Trainer:
                     for i, (data, targett) in zip(v, eval_loader):
                         v.set_description("valid")
                         score, val_loss = self._validation_step(data, targett, i)
-                        if score is not None:
-                            v.set_postfix(
-                                metrics={k: np.round(v, 4) for k, v in score.items()},
-                                loss=val_loss,
-                            )
-                        else:
-                            v.set_postfix(loss=val_loss)
-                epoch_logs["val_loss"] = val_loss
-                if score is not None:
-                    for k, v in score.items():
-                        log_k = "_".join(["val", k])
-                        epoch_logs[log_k] = v
+                        print_loss_and_metric(v, val_loss, score)
+                epoch_logs = save_epoch_logs(epoch_logs, val_loss, score, "val")
 
             self.callback_container.on_epoch_end(epoch, epoch_logs)
 
             if self.early_stop:
                 self.callback_container.on_train_end(epoch_logs)
                 break
 
         self.callback_container.on_train_end(epoch_logs)
+        if self.compute_feature_importance and self.model.is_tabnet:
+            self._compute_feature_importance(train_loader)
         self.model.train()
 
     def predict(  # type: ignore[return]
@@ -727,6 +728,65 @@ class Trainer:
                 cat_embed_dict[value] = embed_mtx[idx]
         return cat_embed_dict
 
+    def explain(self, X_tab: np.ndarray, save_step_masks: bool = False):
+        """
+        Returns the aggregated feature importance for each instance (or
+        observation) in the ``X_tab`` array. If ``save_step_masks`` is set to
+        ``True``, the masks per step will also be returned.
+
+        Parameters
+        ----------
+        X_tab: np.ndarray
+            Input array corresponding **only** to the deeptabular component
+        save_step_masks: bool
+            Boolean indicating if the masks per step will be returned
+
+        Returns
+        -------
+        res: np.ndarray, Tuple
+            Array or Tuple of two arrays with the corresponding aggregated
+            feature importance and the masks per step if ``save_step_masks``
+            is set to ``True``
+        """
+        loader = DataLoader(
+            dataset=WideDeepDataset(**{"X_tab": X_tab}),
+            batch_size=self.batch_size,
+            num_workers=n_cpus,
+            shuffle=False,
+        )
+
+        self.model.eval()
+        tabnet_backbone = list(self.model.deeptabular.children())[0]
+
+        m_explain_l = []
+        for batch_nb, data in enumerate(loader):
+            X = data["deeptabular"].to(device)
+            M_explain, masks = tabnet_backbone.forward_masks(X)
+            m_explain_l.append(
+                csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
+            )
+            if save_step_masks:
+                for key, value in masks.items():
+                    masks[key] = csc_matrix.dot(
+                        value.cpu().detach().numpy(), self.reducing_matrix
+                    )
+                if batch_nb == 0:
+                    m_explain_step = masks
+                else:
+                    for key, value in masks.items():
+                        m_explain_step[key] = np.vstack([m_explain_step[key], value])
+
+        m_explain_agg = np.vstack(m_explain_l)
+        m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis]
+
+        res = (
+            (m_explain_agg_norm, m_explain_step)
+            if save_step_masks
+            else np.vstack(m_explain_agg_norm)
+        )
+
+        return res
+
     def save_model(self, path: str):
         """Saves the model to disk
@@ -932,7 +992,7 @@ class Trainer:
             self.model.deepimage, "deepimage", loader, n_epochs, max_lr
         )
 
-    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
+    def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
         self.model.train()
         X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
         y = target.view(-1, 1).float() if self.method != "multiclass" else target
@@ -987,6 +1047,24 @@ class Trainer:
         else:
             return None
 
+    def _compute_feature_importance(self, loader: DataLoader):
+        self.model.eval()
+        tabnet_backbone = list(self.model.deeptabular.children())[0]
+        feat_imp = np.zeros((tabnet_backbone.embed_and_cont_dim))
+        for data, target in loader:
+            X = data["deeptabular"].to(device)
+            y = target.view(-1, 1).float() if self.method != "multiclass" else target
+            y = y.to(device)
+            M_explain, masks = tabnet_backbone.forward_masks(X)
+            feat_imp += M_explain.sum(dim=0).cpu().detach().numpy()
+
+        feat_imp = csc_matrix.dot(feat_imp, self.reducing_matrix)
+        feat_imp = feat_imp / np.sum(feat_imp)
+
+        self.feature_importance = {
+            k: v for k, v in zip(tabnet_backbone.column_idx.keys(), feat_imp)
+        }
+
     def _predict(
         self,
         X_wide: Optional[np.ndarray] = None,
...
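Once fit has run with the flag enabled, the dictionary built by _compute_feature_importance above can be inspected directly. A small sketch, assuming trainer is a fitted Trainer around a TabNet model:

# Global, normalised importances (they sum to ~1.0), sorted from high to low
sorted_imp = sorted(trainer.feature_importance.items(), key=lambda kv: kv[1], reverse=True)
for colname, imp in sorted_imp:
    print(f"{colname}: {imp:.4f}")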
""" """
Code here taken from the one and only Hunter McGushion and his fantastic Code for 'Alias' and 'set_default_attr' taken from the one and only Hunter
library: https://github.com/HunterMcGushion/hyperparameter_hunter McGushion and his library:
https://github.com/HunterMcGushion/hyperparameter_hunter
""" """
import numpy as np
import wrapt import wrapt
from tqdm import tqdm
from pytorch_widedeep.wdtypes import Any, Dict, List, Union
class Alias: class Alias:
def __init__(self, primary_name, aliases): def __init__(self, primary_name: str, aliases: Union[str, List[str]]):
"""Convert uses of `aliases` to `primary_name` upon calling the decorated function/method """Convert uses of `aliases` to `primary_name` upon calling the decorated function/method
Parameters Parameters
...@@ -56,7 +61,7 @@ class Alias: ...@@ -56,7 +61,7 @@ class Alias:
return wrapped(*args, **kwargs) return wrapped(*args, **kwargs)
def set_default_attr(obj, name, value): def set_default_attr(obj: Any, name: str, value: Any):
"""Set the `name` attribute of `obj` to `value` if the attribute does not already exist """Set the `name` attribute of `obj` to `value` if the attribute does not already exist
Parameters Parameters
...@@ -88,3 +93,52 @@ def set_default_attr(obj, name, value): ...@@ -88,3 +93,52 @@ def set_default_attr(obj, name, value):
except AttributeError: except AttributeError:
setattr(obj, name, value) setattr(obj, name, value)
return value return value
def print_loss_and_metric(pb: tqdm, loss: float, score: Dict):
"""
Little function to improve readability and avoid code repetition in the
training/validation loop within the Trainer's fit method
Parameters
----------
pb: tqdm
tqdm Object defined as trange(...)
loss: float
loss value
score: Dict
Dictionary where the keys are the metric names and the values are the
corresponding values
"""
if score is not None:
pb.set_postfix(
metrics={k: np.round(v, 4) for k, v in score.items()},
loss=loss,
)
else:
pb.set_postfix(loss=loss)
def save_epoch_logs(epoch_logs: Dict, loss: float, score: Dict, stage: str):
"""
Little function to improve readability and avoid code repetition in the
training/validation loop within the Trainer's fit method
Parameters
----------
epoch_logs: Dict
Dict containing the epoch logs
loss: float
loss value
score: Dict
Dictionary where the keys are the metric names and the values are the
corresponding values
stage: str
one of 'train' or 'val'
"""
epoch_logs["_".join([stage, "loss"])] = loss
if score is not None:
for k, v in score.items():
log_k = "_".join([stage, k])
epoch_logs[log_k] = v
return epoch_logs
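Since Alias now also accepts a single string (not only a list), here is a minimal sketch of the decorator in action. The ToyTrainer class is made up for illustration; only the decorator and its new import path come from this commit.

from pytorch_widedeep.training.trainer_utils import Alias


class ToyTrainer:
    @Alias("finetune", "warmup")  # "warmup=True" is rewritten to "finetune=True"
    def fit(self, finetune: bool = False):
        return finetune


print(ToyTrainer().fit(warmup=True))  # True: the alias kwarg reaches the primary name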
@@ -49,6 +49,7 @@ from torchvision.transforms import (
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data.dataloader import DataLoader
 
+from pytorch_widedeep.models import WideDeep
 from pytorch_widedeep.models.tabnet.sparsemax import Entmax15, Sparsemax
 
 ListRules = Collection[Callable[[str], str]]
...