Commit 1690395b authored by: Javier Rodriguez Zaurin

Re-distributed the code. Added encoder/decoder self-supervision for non-attention-based models

Parent: cdd674ed
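The diff below reorganises the self-supervised code into a contrastive/denoising route (attention-based models) and an encoder/decoder route (non-attention models such as TabMlp, TabResnet and TabNet). For orientation only, here is a minimal usage sketch of the new encoder/decoder route. It is not part of the commit: the import paths for TabMlpDecoder and EncoderDecoderTrainer are assumptions, and `df`, `cat_embed_cols` and `continuous_cols` are assumed to be defined as in the example script further down.

    # Hedged sketch, NOT part of this commit. Import paths below are assumptions.
    from pytorch_widedeep.models import TabMlp
    from pytorch_widedeep.models.tabular.mlp.tab_mlp import TabMlpDecoder  # assumed path
    from pytorch_widedeep.preprocessing import TabPreprocessor
    from pytorch_widedeep.self_supervised_training.encoder_decoder_trainer import (  # assumed path
        EncoderDecoderTrainer,
    )

    tab_preprocessor = TabPreprocessor(
        cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols
    )
    X_tab = tab_preprocessor.fit_transform(df)

    encoder = TabMlp(
        column_idx=tab_preprocessor.column_idx,
        cat_embed_input=tab_preprocessor.cat_embed_input,
        continuous_cols=continuous_cols,
        mlp_hidden_dims=[200, 100],
    )

    # The decoder reconstructs the encoder's embedded inputs. Assumption: its first
    # hidden dim matches the encoder's output dim (100 here) and `embed_dim` matches
    # the embedded input size (sum of categorical embedding dims + n continuous cols).
    embed_dim = sum(e[2] for e in tab_preprocessor.cat_embed_input) + len(continuous_cols)
    decoder = TabMlpDecoder(embed_dim=embed_dim, mlp_hidden_dims=[100, 200])

    ed_trainer = EncoderDecoderTrainer(
        encoder=encoder,
        decoder=decoder,
        masked_prob=0.2,
        optimizer=None,  # the base trainer falls back to AdamW
        lr_scheduler=None,
        callbacks=None,
        verbose=1,
        seed=1,
    )
    ed_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)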
from itertools import product
import numpy as np
import torch
import pandas as pd
from pytorch_widedeep.models import (
SAINT,
TabPerceiver,
FTTransformer,
TabFastFormer,
TabTransformer,
SelfAttentionMLP,
ContextAttentionMLP,
)
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.self_supervised_training.contrastive_denoising_trainer import (
ContrastiveDenoisingTrainer,
)
use_cuda = torch.cuda.is_available()
......@@ -55,53 +58,100 @@ if __name__ == "__main__":
target = "income_label"
target = df[target].values
transformer_models = [
"tab_transformer",
"saint",
"tab_fastformer",
"ft_transformer",
]
with_cls_token = [True, False]
for w_cls_tok, transf_model in product(with_cls_token, transformer_models):
processor = TabPreprocessor(
cat_embed_cols=cat_embed_cols,
continuous_cols=continuous_cols,
with_attention=True,
with_cls_token=w_cls_tok,
)
X_tab = processor.fit_transform(df)
if transf_model == "tab_transformer":
model = TabTransformer(
column_idx=processor.column_idx,
cat_embed_input=processor.cat_embed_input,
continuous_cols=continuous_cols,
embed_continuous=True,
n_blocks=4,
)
if transf_model == "saint":
model = SAINT(
column_idx=processor.column_idx,
cat_embed_input=processor.cat_embed_input,
continuous_cols=continuous_cols,
cont_norm_layer="batchnorm",
n_blocks=4,
)
if transf_model == "tab_fastformer":
model = TabFastFormer(
column_idx=processor.column_idx,
cat_embed_input=processor.cat_embed_input,
continuous_cols=continuous_cols,
n_blocks=4,
n_heads=4,
share_qv_weights=False,
share_weights=False,
)
if transf_model == "ft_transformer":
model = FTTransformer(
column_idx=processor.column_idx,
cat_embed_input=processor.cat_embed_input,
continuous_cols=continuous_cols,
input_dim=32,
kv_compression_factor=0.5,
n_blocks=3,
n_heads=4,
)
ss_trainer = ContrastiveDenoisingTrainer(
base_model=model,
preprocessor=processor,
)
ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
mlp_attn_model = ["context_attention", "self_attention"]
for w_cls_tok, attn_model in product(with_cls_token, mlp_attn_model):
processor = TabPreprocessor(
cat_embed_cols=cat_embed_cols,
continuous_cols=continuous_cols,
with_attention=True,
with_cls_token=w_cls_tok,
)
X_tab = processor.fit_transform(df)
if attn_model == "context_attention":
model = ContextAttentionMLP(
column_idx=processor.column_idx,
cat_embed_input=processor.cat_embed_input,
continuous_cols=continuous_cols,
input_dim=16,
attn_dropout=0.2,
n_blocks=3,
)
if attn_model == "self_attention":
model = SelfAttentionMLP(
column_idx=processor.column_idx,
cat_embed_input=processor.cat_embed_input,
continuous_cols=continuous_cols,
input_dim=16,
attn_dropout=0.2,
n_blocks=3,
)
ss_trainer = ContrastiveDenoisingTrainer(
base_model=model,
preprocessor=processor,
)
ss_trainer.pretrain(X_tab, n_epochs=1, batch_size=256)
......@@ -901,3 +901,29 @@ class DenoisingLoss(nn.Module):
loss_cont += F.mse_loss(x_, x, reduction=self.reduction)
return loss_cont
class EncoderDecoderLoss(nn.Module):
def __init__(self, eps=1e-9):
super(EncoderDecoderLoss, self).__init__()
self.eps = eps
def forward(self, x_true, x_pred, mask):
errors = x_pred - x_true
reconstruction_errors = torch.mul(errors, mask) ** 2
x_true_means = torch.mean(x_true, dim=0)
x_true_means[x_true_means == 0] = 1
x_true_stds = torch.std(x_true, dim=0) ** 2
x_true_stds[x_true_stds == 0] = x_true_means[x_true_stds == 0]
features_loss = torch.matmul(reconstruction_errors, 1 / x_true_stds)
nb_reconstructed_variables = torch.sum(mask, dim=1)
features_loss_norm = features_loss / (nb_reconstructed_variables + self.eps)
loss = torch.mean(features_loss_norm)
return loss
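A quick illustrative exercise of the loss above on toy tensors (not part of the commit; the numbers are arbitrary): positions where mask == 1 are the masked/reconstructed ones, and the squared errors there are normalised by each feature's variance and by the number of masked entries per row.

    # Illustrative only; assumes the EncoderDecoderLoss defined above is importable.
    import torch

    loss_fn = EncoderDecoderLoss()
    x_true = torch.randn(4, 3)                 # (batch, embed_dim) original embeddings
    x_pred = x_true + 0.1 * torch.randn(4, 3)  # imperfect reconstruction
    mask = torch.tensor(                       # 1 -> entry was masked and must be reconstructed
        [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    )
    loss = loss_fn(x_true, x_pred, mask)       # scalar tensor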
from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
......@@ -152,3 +154,45 @@ class TabMlp(BaseTabularModelWithoutAttention):
@property
def output_dim(self):
return self.mlp_hidden_dims[-1]
# This is a companion Decoder for the TabMlp. We prefer not to refer to the
# 'TabMlp' model as 'TabMlpEncoder' (despite the fact that it is indeed an
# encoder) for two reasons: 1. for convenience across the library and 2.
# because decoders are only going to be used in one of our implementations
# of Self Supervised pretraining, and we prefer to keep the names of the
# 'general' DL models as they are (e.g. TabMlp), as opposed to carrying the
# 'Encoder' description (e.g. TabMlpEncoder) throughout the library.
class TabMlpDecoder(nn.Module):
def __init__(
self,
embed_dim: int,
mlp_hidden_dims: List[int] = [100, 200],
mlp_activation: str = "relu",
mlp_dropout: Union[float, List[float]] = 0.1,
mlp_batchnorm: bool = False,
mlp_batchnorm_last: bool = False,
mlp_linear_first: bool = False,
):
super(TabMlpDecoder, self).__init__()
self.embed_dim = embed_dim
self.mlp_hidden_dims = mlp_hidden_dims
self.mlp_activation = mlp_activation
self.mlp_dropout = mlp_dropout
self.mlp_batchnorm = mlp_batchnorm
self.mlp_batchnorm_last = mlp_batchnorm_last
self.mlp_linear_first = mlp_linear_first
self.decoder = MLP(
mlp_hidden_dims + [self.embed_dim],
mlp_activation,
mlp_dropout,
mlp_batchnorm,
mlp_batchnorm_last,
mlp_linear_first,
)
def forward(self, X: Tensor) -> Tensor:
return self.decoder(X)
from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.resnet._layers import DenseResnet
......@@ -204,3 +206,65 @@ class TabResnet(BaseTabularModelWithoutAttention):
if self.mlp_hidden_dims is not None
else self.blocks_dims[-1]
)
class TabResnetDecoder(nn.Module):
def __init__(
self,
embed_dim: int,
blocks_dims: List[int] = [100, 100, 200],
blocks_dropout: float = 0.1,
simplify_blocks: bool = False,
mlp_hidden_dims: Optional[List[int]] = None,
mlp_activation: str = "relu",
mlp_dropout: float = 0.1,
mlp_batchnorm: bool = False,
mlp_batchnorm_last: bool = False,
mlp_linear_first: bool = False,
):
super(TabResnetDecoder, self).__init__()
if len(blocks_dims) < 2:
raise ValueError(
"'blocks' must contain at least two elements, e.g. [256, 128]"
)
self.embed_dim = embed_dim
self.blocks_dims = blocks_dims
self.blocks_dropout = blocks_dropout
self.simplify_blocks = simplify_blocks
self.mlp_hidden_dims = mlp_hidden_dims
self.mlp_activation = mlp_activation
self.mlp_dropout = mlp_dropout
self.mlp_batchnorm = mlp_batchnorm
self.mlp_batchnorm_last = mlp_batchnorm_last
self.mlp_linear_first = mlp_linear_first
if self.mlp_hidden_dims is not None:
self.mlp = MLP(
mlp_hidden_dims,
mlp_activation,
mlp_dropout,
mlp_batchnorm,
mlp_batchnorm_last,
mlp_linear_first,
)
else:
self.mlp = None
if self.mlp is not None:
self.decoder = DenseResnet(
mlp_hidden_dims[-1], blocks_dims, blocks_dropout, self.simplify_blocks
)
else:
self.decoder = DenseResnet(
blocks_dims[0], blocks_dims, blocks_dropout, self.simplify_blocks
)
self.reconstruction_layer = nn.Linear(blocks_dims[-1], embed_dim, bias=False)
def forward(self, X: Tensor) -> Tensor:
x = self.mlp(X) if self.mlp is not None else X
return self.reconstruction_layer(self.decoder(x))
import torch
from torch import nn
class RandomObfuscator(nn.Module):
def __init__(self, p):
super(RandomObfuscator, self).__init__()
self.p = p
def forward(self, x):
mask = torch.bernoulli(self.p * torch.ones(x.shape)).to(x.device)
masked_input = torch.mul(1 - mask, x)
return masked_input, mask
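For reference, a small illustrative snippet (not part of the commit) showing the masking convention used by RandomObfuscator: entries drawn as 1 in the Bernoulli mask are zeroed out in the returned input, and those are the positions the decoder is later asked to reconstruct.

    # Illustrative only.
    import torch

    obfuscator = RandomObfuscator(p=0.25)  # each entry is masked with probability ~0.25
    x = torch.randn(8, 16)                 # (batch, embed_dim) embedded inputs
    masked_x, mask = obfuscator(x)
    assert torch.all(masked_x[mask.bool()] == 0)  # masked positions are zeroed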
......@@ -2,22 +2,22 @@ from torch import Tensor, nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.self_supervised._denoise_mlps import (
CatSingleMlp,
ContSingleMlp,
CatFeaturesMlp,
ContFeaturesMlp,
)
from pytorch_widedeep.models.tabular.self_supervised._augmentations import (
mix_up,
cut_mix,
)
class ContrastiveDenoisingModel(nn.Module):
def __init__(
self,
model: ModelWithAttention,
encoding_dict: Dict[str, Dict[str, int]],
loss_type: Literal["contrastive", "denoising", "both"],
projection_head1_dims: Optional[List],
......@@ -27,7 +27,7 @@ class SelfSupervisedModel(nn.Module):
cont_mlp_type: Literal["single", "multiple"],
denoise_mlps_activation: str,
):
super(ContrastiveDenoisingModel, self).__init__()
self.model = model
self.loss_type = loss_type
......@@ -55,19 +55,22 @@ class SelfSupervisedModel(nn.Module):
Optional[Tuple[Tensor, Tensor]],
]:
# "uncorrupted branch"
# "uncorrupted" branch
embed = self.model._get_embeddings(X)
if self.model.with_cls_token:
embed[:, 0, :] = 0.0
encoded = self.model.encoder(embed)
# cut_mixed and mixed_up branch
if self.training:
cut_mixed = cut_mix(X)
cut_mixed_embed = self.model._get_embeddings(cut_mixed)
if self.model.with_cls_token:
cut_mixed_embed[:, 0, :] = 0.0
cut_mixed_embed_mixed_up = mix_up(cut_mixed_embed)
encoded_ = self.model.encoder(cut_mixed_embed_mixed_up)
else:
encoded_ = encoded.clone()
# projections for contrastive loss
if self.loss_type in ["contrastive", "both"]:
......
from torch import Tensor, nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tabular.self_supervised._random_obfuscator import (
RandomObfuscator,
)
class EncoderDecoderModel(nn.Module):
def __init__(
self,
encoder: ModelWithoutAttention,
decoder: nn.Module,
masked_prob: float,
):
super(EncoderDecoderModel, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.masker = RandomObfuscator(p=masked_prob)
def forward(self, X: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
if self.encoder.is_tabnet:
return self._forward_tabnet(X)
else:
return self._forward(X)
def _forward(self, X: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
x_embed = self.encoder._get_embeddings(X)
if self.training:
masked_x, mask = self.masker(x_embed)
x_enc = self.encoder(X)
x_embed_rec = self.decoder(x_enc)
else:
x_embed_rec = self.decoder(self.encoder(X))
mask = torch.ones(x_embed.shape).to(X.device)
return x_embed, x_embed_rec, mask
def _forward_tabnet(self, X: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
x_embed = self.encoder._get_embeddings(X)
if self.training:
masked_x, mask = self.masker(x_embed)
prior = 1 - mask
steps_out, _ = self.encoder(masked_x, prior=prior)
x_embed_rec = self.decoder(steps_out)
else:
steps_out, _ = self.encoder(x_embed)
x_embed_rec = self.decoder(steps_out)
mask = torch.ones(x_embed.shape).to(X.device)
return x_embed, x_embed_rec, mask
......@@ -294,11 +294,14 @@ class TabNetEncoder(nn.Module):
self.feat_transformers.append(feat_transformer)
self.attn_transformers.append(attn_transformer)
def forward(
self, X: Tensor, prior: Optional[Tensor] = None
) -> Tuple[List[Tensor], Tensor]:
x = self.initial_bn(X)
if prior is None:
# P[n_step = 0] is initialized as all ones, 1^(B×D)
prior = torch.ones(x.shape).to(x.device)
# sparsity regularization
M_loss = torch.FloatTensor([0.0]).to(x.device)
......
......@@ -4,6 +4,7 @@ from torch import nn
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tabular.tabnet._layers import (
TabNetEncoder,
FeatTransformer,
initialize_non_glu,
)
from pytorch_widedeep.models.tabular._base_tabular_model import (
......@@ -189,9 +190,11 @@ class TabNet(BaseTabularModelWithoutAttention):
mask_type,
)
def forward(
self, X: Tensor, prior: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
x = self._get_embeddings(X)
steps_output, M_loss = self.encoder(x, prior)
res = torch.sum(torch.stack(steps_output, dim=0), dim=0)
return (res, M_loss)
......@@ -223,3 +226,78 @@ class TabNetPredLayer(nn.Module):
def forward(self, tabnet_tuple: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
res, M_loss = tabnet_tuple[0], tabnet_tuple[1]
return self.pred_layer(res), M_loss
class TabNetDecoder(nn.Module):
def __init__(
self,
embed_dim: int,
n_steps: int = 3,
step_dim: int = 8,
attn_dim: int = 8,
dropout: float = 0.0,
n_glu_step_dependent: int = 2,
n_glu_shared: int = 2,
ghost_bn: bool = True,
virtual_batch_size: int = 128,
momentum: float = 0.02,
gamma: float = 1.3,
epsilon: float = 1e-15,
mask_type: str = "sparsemax",
):
super(TabNetDecoder, self).__init__()
self.n_steps = n_steps
self.step_dim = step_dim
self.attn_dim = attn_dim
self.dropout = dropout
self.n_glu_step_dependent = n_glu_step_dependent
self.n_glu_shared = n_glu_shared
self.ghost_bn = ghost_bn
self.virtual_batch_size = virtual_batch_size
self.momentum = momentum
self.gamma = gamma
self.epsilon = epsilon
self.mask_type = mask_type
shared_layers = nn.ModuleList()
for i in range(n_glu_shared):
if i == 0:
shared_layers.append(
nn.Linear(embed_dim, 2 * (step_dim + attn_dim), bias=False)
)
else:
shared_layers.append(
nn.Linear(
step_dim + attn_dim, 2 * (step_dim + attn_dim), bias=False
)
)
self.feat_transformers = nn.ModuleList()
for step in range(n_steps):
transformer = FeatTransformer(
embed_dim,
embed_dim,
dropout,
shared_layers,
n_glu_step_dependent,
ghost_bn,
virtual_batch_size,
momentum=momentum,
)
self.feat_transformers.append(transformer)
self.reconstruction_layer = nn.Linear(step_dim, embed_dim, bias=False)
initialize_non_glu(self.reconstruction_layer, step_dim, embed_dim)
def forward(self, X):
out = 0.0
for step_nb, x in enumerate(X):
x = self.feat_transformers[step_nb](x)
out = torch.add(out, x)
out = self.reconstruction_layer(out)
return out
def forward_masks(self, X: Tensor) -> Tuple[Tensor, Dict[int, Tensor]]:
x = self._get_embeddings(X)
return self.encoder.forward_masks(x)
......@@ -2,10 +2,12 @@ import os
import sys
from abc import ABC, abstractmethod
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_widedeep.losses import InfoNCELoss, DenoisingLoss
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import (
History,
Callback,
......@@ -13,15 +15,15 @@ from pytorch_widedeep.callbacks import (
LRShedulerCallback,
)
from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor
from pytorch_widedeep.models.tabular.self_supervised.contrastive_denoising_model import (
ContrastiveDenoisingModel,
)
class BaseContrastiveDenoisingTrainer(ABC):
def __init__(
self,
base_model: ModelWithAttention,
preprocessor: TabPreprocessor,
optimizer: Optional[Optimizer],
lr_scheduler: Optional[LRScheduler],
......@@ -38,8 +40,15 @@ class BaseSelfSupervisedTrainer(ABC):
**kwargs,
):
self._check_model_is_supported(base_model)
self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)
self.early_stop = False
self.verbose = verbose
self.seed = seed
self.model = ContrastiveDenoisingModel(
base_model,
preprocessor.label_encoder.encoding_dict,
loss_type,
projection_head1_dims,
......@@ -49,31 +58,38 @@ class BaseSelfSupervisedTrainer(ABC):
cont_mlp_type,
denoise_mlps_activation,
)
self.model.to(self.device)
self.loss_type = loss_type
self._set_loss_fn(**kwargs)
self.optimizer = (
optimizer
if optimizer is not None
else torch.optim.AdamW(self.model.parameters())
)
self.lr_scheduler = self._set_lr_scheduler_running_params(lr_scheduler)
self._set_callbacks(callbacks)
@abstractmethod
def pretrain(
self,
X_tab: np.ndarray,
X_val: Optional[np.ndarray],
val_split: Optional[float],
validation_freq: int,
n_epochs: int,
batch_size: int,
):
raise NotImplementedError("Trainer.pretrain method not implemented")
@abstractmethod
def save(
self,
path: str,
save_state_dict: bool,
model_filename: str,
):
raise NotImplementedError("Trainer.save method not implemented")
def _set_loss_fn(self, **kwargs):
......@@ -102,7 +118,29 @@ class BaseSelfSupervisedTrainer(ABC):
return contrastive_loss + denoising_loss
def _set_reduce_on_plateau_criterion(
self, lr_scheduler, reducelronplateau_criterion
):
self.reducelronplateau = False
if isinstance(lr_scheduler, ReduceLROnPlateau):
self.reducelronplateau = True
if self.reducelronplateau and not reducelronplateau_criterion:
UserWarning(
"The learning rate scheduler is of type ReduceLROnPlateau. The step method in this"
" scheduler requires a 'metrics' param that can be either the validation loss or the"
" validation metric. Please, when instantiating the Trainer, specify which quantity"
" will be tracked using reducelronplateau_criterion = 'loss' (default) or"
" reducelronplateau_criterion = 'metric'"
)
else:
self.reducelronplateau_criterion = "loss"
def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs):
reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)
self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion)
if lr_scheduler is not None:
self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
else:
......@@ -116,7 +154,7 @@ class BaseSelfSupervisedTrainer(ABC):
callback = callback()
self.callbacks.append(callback)
self.callback_container = CallbackContainer(self.callbacks)
self.callback_container.set_model(self.model)
self.callback_container.set_trainer(self)
def _restore_best_weights(self):
......@@ -139,7 +177,7 @@ class BaseSelfSupervisedTrainer(ABC):
print(
f"Model weights restored to best epoch: {callback.best_epoch + 1}"
)
self.model.load_state_dict(callback.best_state_dict)
else:
if self.verbose:
print(
......@@ -160,3 +198,16 @@ class BaseSelfSupervisedTrainer(ABC):
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
@staticmethod
def _check_model_is_supported(model: ModelWithAttention):
if model.__class__.__name__ == "TabPerceiver":
raise ValueError(
"Self-Supervised pretraining is not supported for the 'TabPerceiver'"
)
if model.__class__.__name__ == "TabTransformer" and not model.embed_continuous:
raise ValueError(
"Self-Supervised pretraining is only supported if both categorical and "
"continuum columns are embedded. Please set 'embed_continuous = True'"
)
import os
import sys
from abc import ABC, abstractmethod
import numpy as np
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_widedeep.losses import EncoderDecoderLoss
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import (
History,
Callback,
CallbackContainer,
LRShedulerCallback,
)
from pytorch_widedeep.self_supervised_training.self_supervised_models import (
EncoderDecoderModel,
)
class BaseEncoderDecoderTrainer(ABC):
def __init__(
self,
encoder: ModelWithoutAttention,
decoder: nn.Module,
masked_prob: float,
optimizer: Optional[Optimizer],
lr_scheduler: Optional[LRScheduler],
callbacks: Optional[List[Callback]],
verbose: int,
seed: int,
**kwargs,
):
# self._check_model_is_supported(encoder)
self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)
self.early_stop = False
self.verbose = verbose
self.seed = seed
self.model = EncoderDecoderModel(
encoder,
decoder,
masked_prob,
)
self.model.to(self.device)
self.loss_fn = EncoderDecoderLoss()
self.optimizer = (
optimizer
if optimizer is not None
else torch.optim.AdamW(self.model.parameters())
)
self.lr_scheduler = self._set_lr_scheduler_running_params(lr_scheduler)
self._set_callbacks(callbacks)
@abstractmethod
def pretrain(
self,
X_tab: np.ndarray,
X_val: Optional[np.ndarray],
val_split: Optional[float],
validation_freq: int,
n_epochs: int,
batch_size: int,
):
raise NotImplementedError("Trainer.pretrain method not implemented")
@abstractmethod
def save(
self,
path: str,
save_state_dict: bool,
model_filename: str,
):
raise NotImplementedError("Trainer.save method not implemented")
def _set_reduce_on_plateau_criterion(
self, lr_scheduler, reducelronplateau_criterion
):
self.reducelronplateau = False
if isinstance(lr_scheduler, ReduceLROnPlateau):
self.reducelronplateau = True
if self.reducelronplateau and not reducelronplateau_criterion:
UserWarning(
"The learning rate scheduler is of type ReduceLROnPlateau. The step method in this"
" scheduler requires a 'metrics' param that can be either the validation loss or the"
" validation metric. Please, when instantiating the Trainer, specify which quantity"
" will be tracked using reducelronplateau_criterion = 'loss' (default) or"
" reducelronplateau_criterion = 'metric'"
)
else:
self.reducelronplateau_criterion = "loss"
def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs):
reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)
self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion)
if lr_scheduler is not None:
self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
else:
self.cyclic_lr = False
def _set_callbacks(self, callbacks):
self.callbacks: List = [History(), LRShedulerCallback()]
if callbacks is not None:
for callback in callbacks:
if isinstance(callback, type):
callback = callback()
self.callbacks.append(callback)
self.callback_container = CallbackContainer(self.callbacks)
self.callback_container.set_model(self.model)
self.callback_container.set_trainer(self)
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
if already_restored:
pass
else:
for callback in self.callback_container.callbacks:
if callback.__class__.__name__ == "ModelCheckpoint":
if callback.save_best_only:
if self.verbose:
print(
f"Model weights restored to best epoch: {callback.best_epoch + 1}"
)
self.model.load_state_dict(callback.best_state_dict)
else:
if self.verbose:
print(
"Model weights after training corresponds to the those of the "
"final epoch which might not be the best performing weights. Use "
"the 'ModelCheckpoint' Callback to restore the best epoch weights."
)
@staticmethod
def _set_device_and_num_workers(**kwargs):
default_num_workers = (
0
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
import json
from pathlib import Path
import numpy as np
import torch
from tqdm import trange
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.training._trainer_utils import (
save_epoch_logs,
print_loss_and_metric,
)
from pytorch_widedeep.self_supervised_training._base_contrastive_denoising_trainer import (
BaseContrastiveDenoisingTrainer,
)
class ContrastiveDenoisingTrainer(BaseContrastiveDenoisingTrainer):
def __init__(
self,
base_model: ModelWithAttention,
preprocessor: TabPreprocessor,
optimizer: Optional[Optimizer] = None,
lr_scheduler: Optional[LRScheduler] = None,
callbacks: Optional[List[Callback]] = None,
......@@ -34,7 +39,7 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
**kwargs,
):
super().__init__(
base_model=base_model,
preprocessor=preprocessor,
loss_type=loss_type,
optimizer=optimizer,
......@@ -54,18 +59,28 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
def pretrain(
self,
X_tab: np.ndarray,
X_val: Optional[np.ndarray] = None,
val_split: Optional[float] = None,
validation_freq: int = 1,
n_epochs: int = 1,
batch_size: int = 32,
):
self.batch_size = batch_size
train_set, eval_set = self._train_eval_split(X_tab, X_val, val_split)
train_loader = DataLoader(
dataset=train_set, batch_size=batch_size, num_workers=self.num_workers
)
train_steps = len(train_loader)
if eval_set is not None:
eval_loader = DataLoader(
dataset=eval_set,
batch_size=batch_size,
num_workers=self.num_workers,
shuffle=False,
)
eval_steps = len(eval_loader)
self.callback_container.on_train_begin(
{
......@@ -80,13 +95,31 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
self.train_running_loss = 0.0
with trange(train_steps, disable=self.verbose != 1) as t:
for batch_idx, X in zip(t, train_loader):
t.set_description("epoch %i" % (epoch + 1))
train_loss = self._train_step(X[0], batch_idx)
self.callback_container.on_batch_end(batch=batch_idx)
print_loss_and_metric(t, train_loss)
epoch_logs = save_epoch_logs(epoch_logs, train_loss, None, "train")
if eval_set is not None and epoch % validation_freq == (
validation_freq - 1
):
self.callback_container.on_eval_begin()
self.valid_running_loss = 0.0
with trange(eval_steps, disable=self.verbose != 1) as v:
for batch_idx, X in zip(v, eval_loader):
v.set_description("valid")
val_loss = self._eval_step(X[0], batch_idx)
print_loss_and_metric(v, val_loss)
epoch_logs = save_epoch_logs(epoch_logs, val_loss, None, "val")
else:
if self.reducelronplateau:
raise NotImplementedError(
"ReduceLROnPlateau scheduler can be used only with validation data."
)
self.callback_container.on_epoch_end(epoch, epoch_logs)
if self.early_stop:
......@@ -95,14 +128,41 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
self.callback_container.on_train_end(epoch_logs)
self._restore_best_weights()
self.model.train()
def save(
self,
path: str,
save_state_dict: bool = False,
model_filename: str = "cd_model.pt",
):
save_dir = Path(path)
history_dir = save_dir / "history"
history_dir.mkdir(exist_ok=True, parents=True)
# the trainer is run with the History Callback by default
with open(history_dir / "train_eval_history.json", "w") as teh:
json.dump(self.history, teh) # type: ignore[attr-defined]
has_lr_history = any(
[clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
)
if self.lr_scheduler is not None and has_lr_history:
with open(history_dir / "lr_history.json", "w") as lrh:
json.dump(self.lr_history, lrh) # type: ignore[attr-defined]
model_path = save_dir / model_filename
if save_state_dict:
torch.save(self.model.state_dict(), model_path)
else:
torch.save(self.model, model_path)
def _train_step(self, X_tab: Tensor, batch_idx: int):
X = X_tab.to(self.device)
self.optimizer.zero_grad()
g_projs, cat_x_and_x_, cont_x_and_x_ = self.model(X)
loss = self._compute_loss(g_projs, cat_x_and_x_, cont_x_and_x_)
loss.backward()
self.optimizer.step()
......@@ -111,3 +171,40 @@ class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
avg_loss = self.train_running_loss / (batch_idx + 1)
return avg_loss
def _eval_step(self, X_tab: Tensor, batch_idx: int):
self.model.eval()
with torch.no_grad():
X = X_tab.to(self.device)
g_projs, cat_x_and_x_, cont_x_and_x_ = self.model(X)
loss = self._compute_loss(g_projs, cat_x_and_x_, cont_x_and_x_)
self.valid_running_loss += loss.item()
avg_loss = self.valid_running_loss / (batch_idx + 1)
return avg_loss
def _train_eval_split(
self,
X: np.ndarray,
X_val: Optional[np.ndarray] = None,
val_split: Optional[float] = None,
):
if X_val is not None:
train_set = TensorDataset(torch.from_numpy(X))
eval_set = TensorDataset(torch.from_numpy(X_val))
elif val_split is not None:
X_tr, X_val = train_test_split(
X, test_size=val_split, random_state=self.seed
)
train_set = TensorDataset(torch.from_numpy(X_tr))
eval_set = TensorDataset(torch.from_numpy(X_val))
else:
train_set = TensorDataset(torch.from_numpy(X))
eval_set = None
return train_set, eval_set
import json
from pathlib import Path
import numpy as np
import torch
from tqdm import trange
from torch import nn
from scipy.sparse import csc_matrix
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.training._trainer_utils import (
save_epoch_logs,
print_loss_and_metric,
)
from pytorch_widedeep.self_supervised_training._base_encoder_decoder_trainer import (
BaseEncoderDecoderTrainer,
)
class EncoderDecoderTrainer(BaseEncoderDecoderTrainer):
def __init__(
self,
encoder: ModelWithoutAttention,
decoder: nn.Module,
masked_prob: float,
optimizer: Optional[Optimizer],
lr_scheduler: Optional[LRScheduler],
callbacks: Optional[List[Callback]],
verbose: int,
seed: int,
**kwargs,
):
super().__init__(
encoder=encoder,
decoder=decoder,
masked_prob=masked_prob,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
callbacks=callbacks,
verbose=verbose,
seed=seed,
**kwargs,
)
def pretrain(
self,
X_tab: np.ndarray,
X_val: Optional[np.ndarray] = None,
val_split: Optional[float] = None,
validation_freq: int = 1,
n_epochs: int = 1,
batch_size: int = 32,
):
self.batch_size = batch_size
train_set, eval_set = self._train_eval_split(X_tab, X_val, val_split)
train_loader = DataLoader(
dataset=train_set, batch_size=batch_size, num_workers=self.num_workers
)
train_steps = len(train_loader)
if eval_set is not None:
eval_loader = DataLoader(
dataset=eval_set,
batch_size=batch_size,
num_workers=self.num_workers,
shuffle=False,
)
eval_steps = len(eval_loader)
self.callback_container.on_train_begin(
{
"batch_size": batch_size,
"train_steps": train_steps,
"n_epochs": n_epochs,
}
)
for epoch in range(n_epochs):
epoch_logs: Dict[str, float] = {}
self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
self.train_running_loss = 0.0
with trange(train_steps, disable=self.verbose != 1) as t:
for batch_idx, X in zip(t, train_loader):
t.set_description("epoch %i" % (epoch + 1))
train_loss = self._train_step(X[0], batch_idx)
self.callback_container.on_batch_end(batch=batch_idx)
print_loss_and_metric(t, train_loss)
epoch_logs = save_epoch_logs(epoch_logs, train_loss, None, "train")
if eval_set is not None and epoch % validation_freq == (
validation_freq - 1
):
self.callback_container.on_eval_begin()
self.valid_running_loss = 0.0
with trange(eval_steps, disable=self.verbose != 1) as v:
for batch_idx, X in zip(v, eval_loader):
v.set_description("valid")
val_loss = self._eval_step(X[0], batch_idx)
print_loss_and_metric(v, val_loss)
epoch_logs = save_epoch_logs(epoch_logs, val_loss, None, "val")
else:
if self.reducelronplateau:
raise NotImplementedError(
"ReduceLROnPlateau scheduler can be used only with validation data."
)
self.callback_container.on_epoch_end(epoch, epoch_logs)
if self.early_stop:
self.callback_container.on_train_end(epoch_logs)
break
self.callback_container.on_train_end(epoch_logs)
self._restore_best_weights()
self.model.train()
def save(
self,
path: str,
save_state_dict: bool = False,
model_filename: str = "ed_model.pt",
):
save_dir = Path(path)
history_dir = save_dir / "history"
history_dir.mkdir(exist_ok=True, parents=True)
# the trainer is run with the History Callback by default
with open(history_dir / "train_eval_history.json", "w") as teh:
json.dump(self.history, teh) # type: ignore[attr-defined]
has_lr_history = any(
[clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
)
if self.lr_scheduler is not None and has_lr_history:
with open(history_dir / "lr_history.json", "w") as lrh:
json.dump(self.lr_history, lrh) # type: ignore[attr-defined]
model_path = save_dir / model_filename
if save_state_dict:
torch.save(self.model.state_dict(), model_path)
else:
torch.save(self.model, model_path)
def explain(self, X_tab: np.ndarray, save_step_masks: bool = False):
# TO DO: Adjust this to Self Supervised (e.g. no need of data["deeptabular"])
loader = DataLoader(
dataset=self._train_eval_split(X_tab),
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
self.model.eval()
tabnet_backbone = list(self.encoder.children())[0]
m_explain_l = []
for batch_nb, data in enumerate(loader):
X = data["deeptabular"].to(self.device)
M_explain, masks = tabnet_backbone.forward_masks(X)
m_explain_l.append(
csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
)
if save_step_masks:
for key, value in masks.items():
masks[key] = csc_matrix.dot(
value.cpu().detach().numpy(), self.reducing_matrix
)
if batch_nb == 0:
m_explain_step = masks
else:
for key, value in masks.items():
m_explain_step[key] = np.vstack([m_explain_step[key], value])
m_explain_agg = np.vstack(m_explain_l)
m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis]
res = (
(m_explain_agg_norm, m_explain_step)
if save_step_masks
else np.vstack(m_explain_agg_norm)
)
return res
def _train_step(self, X_tab: Tensor, batch_idx: int):
X = X_tab.to(self.device)
self.optimizer.zero_grad()
x_embed, x_embed_rec, mask = self.model(X)
loss = self.loss_fn(x_embed, x_embed_rec, mask)
loss.backward()
self.optimizer.step()
self.train_running_loss += loss.item()
avg_loss = self.train_running_loss / (batch_idx + 1)
return avg_loss
def _eval_step(self, X_tab: Tensor, batch_idx: int):
self.model.eval()
with torch.no_grad():
X = X_tab.to(self.device)
x_embed, x_embed_rec, mask = self.model(X)
loss = self.loss_fn(x_embed, x_embed_rec, mask)
self.valid_running_loss += loss.item()
avg_loss = self.valid_running_loss / (batch_idx + 1)
return avg_loss
def _train_eval_split(
self,
X: np.ndarray,
X_val: Optional[np.ndarray] = None,
val_split: Optional[float] = None,
):
if X_val is not None:
train_set = TensorDataset(torch.from_numpy(X))
eval_set = TensorDataset(torch.from_numpy(X_val))
elif val_split is not None:
X_tr, X_val = train_test_split(
X, test_size=val_split, random_state=self.seed
)
train_set = TensorDataset(torch.from_numpy(X_tr))
eval_set = TensorDataset(torch.from_numpy(X_val))
else:
train_set = TensorDataset(torch.from_numpy(X))
eval_set = None
return train_set, eval_set
import os
import sys
from abc import ABC, abstractmethod
import numpy as np
import torch
from torchmetrics import Metric as TorchMetric
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_widedeep.metrics import Metric, MultipleMetrics
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.callbacks import (
History,
Callback,
MetricCallback,
CallbackContainer,
LRShedulerCallback,
)
from pytorch_widedeep.training._trainer_utils import bayesian_alias_to_loss
from pytorch_widedeep.bayesian_models._base_bayesian_model import (
BaseBayesianModel,
)
# There are some nuances in the Bayesian Trainer that make it hard to build an
# overall BaseTrainer. We could still perhaps code a very basic Trainer and
# then pass it to a BaseTrainer and BaseBayesianTrainer. However in this
# particular case we prefer code repetition as we believe is a simpler
# solution
class BaseBayesianTrainer(ABC):
def __init__(
self,
model: BaseBayesianModel,
objective: str,
custom_loss_function: Optional[Module],
optimizer: Optimizer,
lr_scheduler: LRScheduler,
callbacks: Optional[List[Callback]],
metrics: Optional[Union[List[Metric], List[TorchMetric]]],
verbose: int,
seed: int,
**kwargs,
):
if objective not in ["binary", "multiclass", "regression"]:
raise ValueError(
"If 'custom_loss_function' is not None, 'objective' must be 'binary' "
"'multiclass' or 'regression', consistent with the loss function"
)
self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)
self.early_stop = False
self.model = model
self.model.to(self.device)
self.verbose = verbose
self.seed = seed
self.objective = objective
self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs)
self.optimizer = (
optimizer
if optimizer is not None
else torch.optim.AdamW(self.model.parameters())
)
self.lr_scheduler = lr_scheduler
self._set_lr_scheduler_running_params(lr_scheduler, **kwargs)
self._set_callbacks_and_metrics(callbacks, metrics)
@abstractmethod
def fit(
self,
X_tab: np.ndarray,
target: np.ndarray,
X_tab_val: Optional[np.ndarray],
target_val: Optional[np.ndarray],
val_split: Optional[float],
n_epochs: int,
val_freq: int,
batch_size: int,
n_train_samples: int,
n_val_samples: int,
):
raise NotImplementedError("Trainer.fit method not implemented")
@abstractmethod
def predict(
self,
X_tab: np.ndarray,
n_samples: int,
return_samples: bool,
batch_size: int,
) -> np.ndarray:
raise NotImplementedError("Trainer.predict method not implemented")
@abstractmethod
def predict_proba(
self,
X_tab: np.ndarray,
n_samples: int,
return_samples: bool,
batch_size: int,
) -> np.ndarray:
raise NotImplementedError("Trainer.predict_proba method not implemented")
@abstractmethod
def save(
self,
path: str,
save_state_dict: bool,
model_filename: str,
):
raise NotImplementedError("Trainer.save method not implemented")
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
if already_restored:
pass
else:
for callback in self.callback_container.callbacks:
if callback.__class__.__name__ == "ModelCheckpoint":
if callback.save_best_only:
if self.verbose:
print(
f"Model weights restored to best epoch: {callback.best_epoch + 1}"
)
self.model.load_state_dict(callback.best_state_dict)
else:
if self.verbose:
print(
"Model weights after training corresponds to the those of the "
"final epoch which might not be the best performing weights. Use "
"the 'ModelCheckpoint' Callback to restore the best epoch weights."
)
def _set_loss_fn(self, objective, custom_loss_function, **kwargs):
if custom_loss_function is not None:
return custom_loss_function
class_weight = (
torch.tensor(kwargs["class_weight"]).to(self.device)
if "class_weight" in kwargs
else None
)
if self.objective != "regression":
return bayesian_alias_to_loss(objective, weight=class_weight)
else:
return bayesian_alias_to_loss(objective)
def _set_reduce_on_plateau_criterion(
self, lr_scheduler, reducelronplateau_criterion
):
self.reducelronplateau = False
if isinstance(lr_scheduler, ReduceLROnPlateau):
self.reducelronplateau = True
if self.reducelronplateau and not reducelronplateau_criterion:
UserWarning(
"The learning rate scheduler is of type ReduceLROnPlateau. The step method in this"
" scheduler requires a 'metrics' param that can be either the validation loss or the"
" validation metric. Please, when instantiating the Trainer, specify which quantity"
" will be tracked using reducelronplateau_criterion = 'loss' (default) or"
" reducelronplateau_criterion = 'metric'"
)
else:
self.reducelronplateau_criterion = "loss"
def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs):
reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)
self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion)
if lr_scheduler is not None:
self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
else:
self.cyclic_lr = False
def _set_callbacks_and_metrics(self, callbacks, metrics):
self.callbacks: List = [History(), LRShedulerCallback()]
if callbacks is not None:
for callback in callbacks:
if isinstance(callback, type):
callback = callback()
self.callbacks.append(callback)
if metrics is not None:
self.metric = MultipleMetrics(metrics)
self.callbacks += [MetricCallback(self.metric)]
else:
self.metric = None
self.callback_container = CallbackContainer(self.callbacks)
self.callback_container.set_model(self.model)
self.callback_container.set_trainer(self)
@staticmethod
def _set_device_and_num_workers(**kwargs):
default_num_workers = (
0
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
......@@ -17,10 +17,7 @@ from pytorch_widedeep.callbacks import (
LRShedulerCallback,
)
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
from pytorch_widedeep.training._trainer_utils import alias_to_loss
from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
......@@ -28,9 +25,6 @@ from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod
from pytorch_widedeep.training._multiple_lr_scheduler import (
MultipleLRScheduler,
)
from pytorch_widedeep.bayesian_models._base_bayesian_model import (
BaseBayesianModel,
)
class BaseTrainer(ABC):
......@@ -346,197 +340,3 @@ class BaseTrainer(ABC):
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
......@@ -12,12 +12,14 @@ from pytorch_widedeep.metrics import Metric
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.training._trainer_utils import (
save_epoch_logs,
print_loss_and_metric,
tabular_train_val_split,
)
from pytorch_widedeep.training._base_bayesian_trainer import (
BaseBayesianTrainer,
)
from pytorch_widedeep.bayesian_models._base_bayesian_model import (
BaseBayesianModel,
)
......@@ -128,7 +130,7 @@ class BayesianTrainer(BaseBayesianTrainer):
target_val: Optional[np.ndarray] = None,
val_split: Optional[float] = None,
n_epochs: int = 1,
validation_freq: int = 1,
batch_size: int = 32,
n_train_samples: int = 2,
n_val_samples: int = 2,
......@@ -202,7 +204,9 @@ class BayesianTrainer(BaseBayesianTrainer):
epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")
on_epoch_end_metric = None
if eval_set is not None and epoch % validation_freq == (
validation_freq - 1
):
self.callback_container.on_eval_begin()
self.valid_running_loss = 0.0
with trange(eval_steps, disable=self.verbose != 1) as v:
......
......@@ -21,7 +21,7 @@ from pytorch_widedeep.initializers import Initializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training._base_trainer import BaseTrainer
from pytorch_widedeep.training._trainer_utils import (
save_epoch_logs,
wd_train_val_split,
......
......@@ -71,7 +71,7 @@ from torchvision.transforms import (
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data.dataloader import DataLoader
from pytorch_widedeep.models import *
from pytorch_widedeep.models.tabular.tabnet.sparsemax import (
Entmax15,
Sparsemax,
......@@ -122,3 +122,15 @@ Transforms = Union[
LRScheduler = _LRScheduler
ModelParams = Generator[Tensor, Tensor, Tensor]
NormLayers = Union[torch.nn.Identity, torch.nn.LayerNorm, torch.nn.BatchNorm1d]
ModelWithAttention = Union[
TabTransformer,
SAINT,
FTTransformer,
TabFastFormer,
TabPerceiver,
ContextAttentionMLP,
SelfAttentionMLP,
]
ModelWithoutAttention = Union[TabMlp, TabResnet, TabNet]