Commit 7626bcc9 authored by: J jrzaurin

Added explain and compute-feature-importance methods to the trainer. Error handling and messaging still need to be added, since these functionalities are intended only for TabNet. Moved the general_utils module to the training module and renamed it trainer_utils. Adapted the create_explain_matrix function to WideDeep.

Parent: c36b38c7
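Based on the changes in this commit, the feature-importance workflow at the Trainer level would look roughly as follows. This is a minimal sketch: the model and data setup are elided, and the import path, objective and fit arguments are assumed rather than taken from this diff.

from pytorch_widedeep import Trainer  # import path assumed

# `model` is assumed to be a WideDeep model whose deeptabular component is a TabNet,
# and `X_tab` / `target` the preprocessed tabular array and target vector
trainer = Trainer(model, objective="binary", compute_feature_importance=True)
trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)

# after fit, global importances are stored as {column name: normalized importance}
print(trainer.feature_importance)

# per-instance importances via the new explain method
feat_imp = trainer.explain(X_tab)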
@@ -84,6 +84,7 @@ if __name__ == "__main__":
optimizers=optimizers,
callbacks=callbacks,
metrics=metrics,
compute_feature_importance=False,
)
trainer.fit(
@@ -334,7 +334,7 @@ class TabNetEncoder(nn.Module):
out = self.feat_transformers[step](masked_x)
attn = out[:, self.step_dim :]
# 'decision contribution' in the paper
d_out = ReLU()(out[:, : self.step_dim])
d_out = nn.ReLU()(out[:, : self.step_dim])
# aggregate decision contribution
agg_decision_contrib = torch.sum(d_out, dim=1)
@@ -437,6 +437,7 @@ class TabNet(nn.Module):
self.embed_and_cont = EmbeddingsAndContinuous(
column_idx, embed_input, embed_dropout, continuous_cols, batchnorm_cont
)
self.embed_and_cont_dim = self.embed_and_cont.output_dim
self.tabnet_encoder = TabNetEncoder(
self.embed_and_cont.output_dim,
step_dim,
import numpy as np
import scipy
from pytorch_widedeep.wdtypes import WideDeep
def create_explain_matrix(n_feat, cat_emb_dim, cat_idxs, post_embed_dim):
"""
This is a computational trick.
In order to rapidly sum importances from same embeddings
to the initial index.
Parameters
----------
n_feat : int
Initial input dim
cat_emb_dim : int or list of int
if int : size of embedding for all categorical features
if list of int : size of embedding for each categorical feature
cat_idxs : list of int
Initial position of categorical features
post_embed_dim : int
Post embedding inputs dimension
Returns
-------
reducing_matrix : np.array
Matrix of dim (post_embed_dim, n_feat) used to perform the reduction
"""
if isinstance(cat_emb_dim, int):
all_emb_impact = [cat_emb_dim - 1] * len(cat_idxs)
else:
all_emb_impact = [emb_dim - 1 for emb_dim in cat_emb_dim]
acc_emb = 0
nb_emb = 0
def create_explain_matrix(model: WideDeep):
(
embed_input,
column_idx,
embed_and_cont_dim,
) = _extract_tabnet_params(model)
n_feat = len(column_idx)
col_embeds = {e[0]: e[2] - 1 for e in embed_input}
embed_colname = [e[0] for e in embed_input]
cont_colname = [c for c in column_idx.keys() if c not in embed_colname]
embed_cum_counter = 0
indices_trick = []
for i in range(n_feat):
if i not in cat_idxs:
indices_trick.append([i + acc_emb])
else:
for colname, idx in column_idx.items():
if colname in cont_colname:
indices_trick.append([idx + embed_cum_counter])
elif colname in embed_colname:
indices_trick.append(
range(i + acc_emb, i + acc_emb + all_emb_impact[nb_emb] + 1)
range( # type: ignore[arg-type]
idx + embed_cum_counter,
idx + embed_cum_counter + col_embeds[colname] + 1,
)
)
acc_emb += all_emb_impact[nb_emb]
nb_emb += 1
embed_cum_counter += col_embeds[colname]
reducing_matrix = np.zeros((post_embed_dim, n_feat))
reducing_matrix = np.zeros((embed_and_cont_dim, n_feat))
for i, cols in enumerate(indices_trick):
reducing_matrix[cols, i] = 1
return scipy.sparse.csc_matrix(reducing_matrix)
def _extract_tabnet_params(model: WideDeep):
tabnet_backbone = list(model.deeptabular.children())[0]
column_idx = tabnet_backbone.column_idx
embed_input = tabnet_backbone.embed_input
embed_and_cont_dim = tabnet_backbone.embed_and_cont_dim
return embed_input, column_idx, embed_and_cont_dim
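To make the "computational trick" in create_explain_matrix concrete, here is a small hand-built example with a hypothetical column layout (three input columns, one of them embedded); the library builds the equivalent matrix automatically from the WideDeep model.

import numpy as np
from scipy.sparse import csc_matrix

# Hypothetical layout: column_idx = {"city": 0, "age": 1, "income": 2}, where "city" is
# categorical with embedding dim 3 and "age"/"income" are continuous, so the
# post-embedding input has 3 + 1 + 1 = 5 columns
reducing_matrix = np.zeros((5, 3))
reducing_matrix[0:3, 0] = 1.0  # the 3 embedding columns of "city" map back to column 0
reducing_matrix[3, 1] = 1.0    # "age"
reducing_matrix[4, 2] = 1.0    # "income"
reducing_matrix = csc_matrix(reducing_matrix)  # same sparse format as returned above

# M_explain holds one importance value per post-embedding column; the product sums the
# importances of the embedding columns back onto the original "city" column
M_explain = np.array([[0.1, 0.2, 0.3, 0.5, 0.9]])
per_column_importance = M_explain @ reducing_matrix.toarray()  # -> [[0.6, 0.5, 0.9]]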
@@ -5,18 +5,28 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange
from scipy.sparse import csc_matrix
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pytorch_widedeep.losses import MSLELoss, RMSELoss, FocalLoss, RMSLELoss
from pytorch_widedeep.models import WideDeep
from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.callbacks import History, LRShedulerCallback, Callback, CallbackContainer
from pytorch_widedeep.callbacks import (
History,
Callback,
CallbackContainer,
LRShedulerCallback,
)
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training.trainer_utils import (
Alias,
save_epoch_logs,
print_loss_and_metric,
)
from pytorch_widedeep.models.tabnet.tabnet_utils import create_explain_matrix
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
from pytorch_widedeep.training._loss_and_obj_aliases import (
@@ -54,6 +64,7 @@ class Trainer:
alpha: float = 0.25,
gamma: float = 2,
lambda_sparse: float = 1e-3,
compute_feature_importance: bool = True,
verbose: int = 1,
seed: int = 1,
):
@@ -248,8 +259,12 @@
else:
self.model = model
# TabNet-related setups
self.compute_feature_importance = compute_feature_importance
if self.model.is_tabnet:
self.lambda_sparse = lambda_sparse
if self.model.is_tabnet and self.compute_feature_importance:
self.reducing_matrix = create_explain_matrix(model)
self.verbose = verbose
self.seed = seed
@@ -268,19 +283,19 @@
self.model.to(device)
@Alias("finetune", ["warmup"]) # noqa: C901
@Alias("finetune_epochs", ["warmup_epochs"])
@Alias("finetune_max_lr", ["warmup_max_lr"])
@Alias("finetune_deeptabular_gradual", ["warmup_deeptabular_gradual"])
@Alias("finetune_deeptabular_max_lr", ["warmup_deeptabular_max_lr"])
@Alias("finetune_deeptabular_layers", ["warmup_deeptabular_layers"])
@Alias("finetune_deeptext_gradual", ["warmup_deeptext_gradual"])
@Alias("finetune_deeptext_max_lr", ["warmup_deeptext_max_lr"])
@Alias("finetune_deeptext_layers", ["warmup_deeptext_layers"])
@Alias("finetune_deepimage_gradual", ["warmup_deepimage_gradual"])
@Alias("finetune_deepimage_max_lr", ["warmup_deepimage_max_lr"])
@Alias("finetune_deepimage_layers", ["warmup_deepimage_layers"])
@Alias("finetune_routine", ["warmup_routine"])
@Alias("finetune", "warmup") # noqa: C901
@Alias("finetune_epochs", "warmup_epochs")
@Alias("finetune_max_lr", "warmup_max_lr")
@Alias("finetune_deeptabular_gradual", "warmup_deeptabular_gradual")
@Alias("finetune_deeptabular_max_lr", "warmup_deeptabular_max_lr")
@Alias("finetune_deeptabular_layers", "warmup_deeptabular_layers")
@Alias("finetune_deeptext_gradual", "warmup_deeptext_gradual")
@Alias("finetune_deeptext_max_lr", "warmup_deeptext_max_lr")
@Alias("finetune_deeptext_layers", "warmup_deeptext_layers")
@Alias("finetune_deepimage_gradual", "warmup_deepimage_gradual")
@Alias("finetune_deepimage_max_lr", "warmup_deepimage_max_lr")
@Alias("finetune_deepimage_layers", "warmup_deepimage_layers")
@Alias("finetune_routine", "warmup_routine")
def fit( # noqa: C901
self,
X_wide: Optional[np.ndarray] = None,
@@ -542,24 +557,16 @@
for epoch in range(n_epochs):
epoch_logs: Dict[str, float] = {}
self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
self.train_running_loss = 0.0
with trange(train_steps, disable=self.verbose != 1) as t:
for batch_idx, (data, targett) in zip(t, train_loader):
t.set_description("epoch %i" % (epoch + 1))
score, train_loss = self._training_step(data, targett, batch_idx)
if score is not None:
t.set_postfix(
metrics={k: np.round(v, 4) for k, v in score.items()},
loss=train_loss,
)
else:
t.set_postfix(loss=train_loss)
score, train_loss = self._train_step(data, targett, batch_idx)
print_loss_and_metric(t, train_loss, score)
self.callback_container.on_batch_end(batch=batch_idx)
epoch_logs["train_loss"] = train_loss
if score is not None:
for k, v in score.items():
log_k = "_".join(["train", k])
epoch_logs[log_k] = v
epoch_logs = save_epoch_logs(epoch_logs, train_loss, score, "train")
if eval_set is not None and epoch % validation_freq == (
validation_freq - 1
):
@@ -568,23 +575,17 @@
for i, (data, targett) in zip(v, eval_loader):
v.set_description("valid")
score, val_loss = self._validation_step(data, targett, i)
if score is not None:
v.set_postfix(
metrics={k: np.round(v, 4) for k, v in score.items()},
loss=val_loss,
)
else:
v.set_postfix(loss=val_loss)
epoch_logs["val_loss"] = val_loss
if score is not None:
for k, v in score.items():
log_k = "_".join(["val", k])
epoch_logs[log_k] = v
print_loss_and_metric(v, val_loss, score)
epoch_logs = save_epoch_logs(epoch_logs, val_loss, score, "val")
self.callback_container.on_epoch_end(epoch, epoch_logs)
if self.early_stop:
self.callback_container.on_train_end(epoch_logs)
break
self.callback_container.on_train_end(epoch_logs)
if self.compute_feature_importance and self.model.is_tabnet:
self._compute_feature_importance(train_loader)
self.model.train()
def predict( # type: ignore[return]
@@ -727,6 +728,65 @@
cat_embed_dict[value] = embed_mtx[idx]
return cat_embed_dict
def explain(self, X_tab: np.ndarray, save_step_masks: bool = False):
"""
Returns the aggregated feature importance for each instance (or
observation) in the ``X_tab`` array. If ``save_step_masks`` is set to
``True``, the masks per step will also be returned.
Parameters
----------
X_tab: np.ndarray
Input array corresponding **only** to the deeptabular component
save_step_masks: bool
Boolean indicating if the masks per step will be returned
Returns
-------
res: np.ndarray, Tuple
Array or Tuple of two arrays with the corresponding aggregated
feature importance and the masks per step if ``save_step_masks``
is set to ``True``
"""
loader = DataLoader(
dataset=WideDeepDataset(**{"X_tab": X_tab}),
batch_size=self.batch_size,
num_workers=n_cpus,
shuffle=False,
)
self.model.eval()
tabnet_backbone = list(self.model.deeptabular.children())[0]
m_explain_l = []
for batch_nb, data in enumerate(loader):
X = data["deeptabular"].to(device)
M_explain, masks = tabnet_backbone.forward_masks(X)
m_explain_l.append(
csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
)
if save_step_masks:
for key, value in masks.items():
masks[key] = csc_matrix.dot(
value.cpu().detach().numpy(), self.reducing_matrix
)
if batch_nb == 0:
m_explain_step = masks
else:
for key, value in masks.items():
m_explain_step[key] = np.vstack([m_explain_step[key], value])
m_explain_agg = np.vstack(m_explain_l)
m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis]
res = (
(m_explain_agg_norm, m_explain_step)
if save_step_masks
else np.vstack(m_explain_agg_norm)
)
return res
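A hedged usage sketch of the method above; `trainer` is assumed to be a fitted Trainer wrapping a TabNet-based model, and `X_tab_test` a hypothetical preprocessed tabular test array.

# aggregated importances: one row per observation, one column per input column
feat_imp = trainer.explain(X_tab_test)

# additionally return the attention masks per decision step, keyed by step index
feat_imp, step_masks = trainer.explain(X_tab_test, save_step_masks=True)
for step, mask in step_masks.items():
    print(step, mask.shape)  # (n_observations, n_input_columns) for each step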
def save_model(self, path: str):
"""Saves the model to disk
@@ -932,7 +992,7 @@ class Trainer:
self.model.deepimage, "deepimage", loader, n_epochs, max_lr
)
def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
self.model.train()
X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
y = target.view(-1, 1).float() if self.method != "multiclass" else target
@@ -987,6 +1047,24 @@
else:
return None
def _compute_feature_importance(self, loader: DataLoader):
self.model.eval()
tabnet_backbone = list(self.model.deeptabular.children())[0]
feat_imp = np.zeros((tabnet_backbone.embed_and_cont_dim))
for data, target in loader:
X = data["deeptabular"].to(device)
y = target.view(-1, 1).float() if self.method != "multiclass" else target
y = y.to(device)
M_explain, masks = tabnet_backbone.forward_masks(X)
feat_imp += M_explain.sum(dim=0).cpu().detach().numpy()
feat_imp = csc_matrix.dot(feat_imp, self.reducing_matrix)
feat_imp = feat_imp / np.sum(feat_imp)
self.feature_importance = {
k: v for k, v in zip(tabnet_backbone.column_idx.keys(), feat_imp)
}
def _predict(
self,
X_wide: Optional[np.ndarray] = None,
"""
Code here taken from the one and only Hunter McGushion and his fantastic
library: https://github.com/HunterMcGushion/hyperparameter_hunter
Code for 'Alias' and 'set_default_attr' taken from the one and only Hunter
McGushion and his library:
https://github.com/HunterMcGushion/hyperparameter_hunter
"""
import numpy as np
import wrapt
from tqdm import tqdm
from pytorch_widedeep.wdtypes import Any, Dict, List, Union
class Alias:
def __init__(self, primary_name, aliases):
def __init__(self, primary_name: str, aliases: Union[str, List[str]]):
"""Convert uses of `aliases` to `primary_name` upon calling the decorated function/method
Parameters
@@ -56,7 +61,7 @@ class Alias:
return wrapped(*args, **kwargs)
def set_default_attr(obj, name, value):
def set_default_attr(obj: Any, name: str, value: Any):
"""Set the `name` attribute of `obj` to `value` if the attribute does not already exist
Parameters
@@ -88,3 +93,52 @@ def set_default_attr(obj, name, value):
except AttributeError:
setattr(obj, name, value)
return value
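For context on how the Alias decorator is applied to the fit signature above, a minimal self-contained sketch of the idea follows. This is not the library's implementation (which is built on wrapt); it only illustrates mapping alias keywords onto the primary keyword.

def alias(primary_name, aliases):
    """Rename any keyword in `aliases` to `primary_name` before calling the function."""
    if isinstance(aliases, str):
        aliases = [aliases]

    def decorator(func):
        def wrapper(*args, **kwargs):
            for alias_name in aliases:
                if alias_name in kwargs:
                    kwargs.setdefault(primary_name, kwargs.pop(alias_name))
            return func(*args, **kwargs)
        return wrapper
    return decorator


@alias("finetune", "warmup")
def fit(finetune=False):
    return finetune


assert fit(warmup=True) is True  # the "warmup" alias is mapped onto "finetune"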
def print_loss_and_metric(pb: tqdm, loss: float, score: Dict):
"""
Little function to improve readability and avoid code repetition in the
training/validation loop within the Trainer's fit method
Parameters
----------
pb: tqdm
tqdm Object defined as trange(...)
loss: float
loss value
score: Dict
Dictionary where the keys are the metric names and the values are the
corresponding values
"""
if score is not None:
pb.set_postfix(
metrics={k: np.round(v, 4) for k, v in score.items()},
loss=loss,
)
else:
pb.set_postfix(loss=loss)
def save_epoch_logs(epoch_logs: Dict, loss: float, score: Dict, stage: str):
"""
Little function to improve readability and avoid code repetition in the
training/validation loop within the Trainer's fit method
Parameters
----------
epoch_logs: Dict
Dict containing the epoch logs
loss: float
loss value
score: Dict
Dictionary where the keys are the metric names and the values are the
corresponding values
stage: str
one of 'train' or 'val'
"""
epoch_logs["_".join([stage, "loss"])] = loss
if score is not None:
for k, v in score.items():
log_k = "_".join([stage, k])
epoch_logs[log_k] = v
return epoch_logs
@@ -49,6 +49,7 @@ from torchvision.transforms import (
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data.dataloader import DataLoader
from pytorch_widedeep.models import WideDeep
from pytorch_widedeep.models.tabnet.sparsemax import Entmax15, Sparsemax
ListRules = Collection[Callable[[str], str]]