From 0454492408678db232da018eda0df1a4e27a84d6 Mon Sep 17 00:00:00 2001 From: Javier Rodriguez Zaurin Date: Sun, 24 Apr 2022 20:34:05 +0200 Subject: [PATCH] Added base trainers for the Trainer and the BayesianTrainer classes. The base trainers will change as I add self supervised options --- .../scripts/adult_census_bayesian_tabmlp.py | 2 +- pytorch_widedeep/__init__.py | 2 +- pytorch_widedeep/training/_base_trainers.py | 542 ++++++++++++++++++ pytorch_widedeep/training/bayesian_trainer.py | 159 +---- pytorch_widedeep/training/trainer.py | 345 +---------- .../test_miscellaneous.py | 39 -- 6 files changed, 577 insertions(+), 512 deletions(-) create mode 100644 pytorch_widedeep/training/_base_trainers.py diff --git a/examples/scripts/adult_census_bayesian_tabmlp.py b/examples/scripts/adult_census_bayesian_tabmlp.py index 979cbf6..5fc834e 100644 --- a/examples/scripts/adult_census_bayesian_tabmlp.py +++ b/examples/scripts/adult_census_bayesian_tabmlp.py @@ -2,12 +2,12 @@ import numpy as np import torch import pandas as pd +from pytorch_widedeep import BayesianTrainer from pytorch_widedeep.metrics import Accuracy from pytorch_widedeep.datasets import load_adult from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor from pytorch_widedeep.bayesian_models import BayesianWide, BayesianTabMlp -from pytorch_widedeep.training.bayesian_trainer import BayesianTrainer use_cuda = torch.cuda.is_available() diff --git a/pytorch_widedeep/__init__.py b/pytorch_widedeep/__init__.py index cdcad2e..6ac8c11 100644 --- a/pytorch_widedeep/__init__.py +++ b/pytorch_widedeep/__init__.py @@ -14,4 +14,4 @@ from pytorch_widedeep.utils import ( ) from pytorch_widedeep.tab2vec import Tab2Vec from pytorch_widedeep.version import __version__ -from pytorch_widedeep.training import Trainer +from pytorch_widedeep.training import Trainer, BayesianTrainer diff --git a/pytorch_widedeep/training/_base_trainers.py b/pytorch_widedeep/training/_base_trainers.py new file mode 100644 index 0000000..42e0b61 --- /dev/null +++ b/pytorch_widedeep/training/_base_trainers.py @@ -0,0 +1,542 @@ +import os +import sys +from abc import ABC, abstractmethod + +import numpy as np +import torch +from torchmetrics import Metric as TorchMetric +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from pytorch_widedeep.metrics import Metric, MultipleMetrics +from pytorch_widedeep.wdtypes import * # noqa: F403 +from pytorch_widedeep.callbacks import ( + History, + Callback, + MetricCallback, + CallbackContainer, + LRShedulerCallback, +) +from pytorch_widedeep.initializers import Initializer, MultipleInitializer +from pytorch_widedeep.training._trainer_utils import ( + alias_to_loss, + bayesian_alias_to_loss, +) +from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix +from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer +from pytorch_widedeep.training._multiple_transforms import MultipleTransforms +from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod +from pytorch_widedeep.training._multiple_lr_scheduler import ( + MultipleLRScheduler, +) +from pytorch_widedeep.bayesian_models._base_bayesian_model import ( + BaseBayesianModel, +) + + +class BaseTrainer(ABC): + def __init__( + self, + model: WideDeep, + objective: str, + custom_loss_function: Optional[Module], + optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]], + lr_schedulers: Optional[Union[LRScheduler, Dict[str, 
LRScheduler]]], + initializers: Optional[Union[Initializer, Dict[str, Initializer]]], + transforms: Optional[List[Transforms]], + callbacks: Optional[List[Callback]], + metrics: Optional[Union[List[Metric], List[TorchMetric]]], + verbose: int, + seed: int, + **kwargs, + ): + + self._check_inputs( + model, objective, optimizers, lr_schedulers, custom_loss_function + ) + self.device, self.num_workers = self._set_device_and_num_workers(**kwargs) + + self.early_stop = False + self.verbose = verbose + self.seed = seed + + self.model = model + if self.model.is_tabnet: + self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3) + self.reducing_matrix = create_explain_matrix(self.model) + self.model.to(self.device) + self.model.wd_device = self.device + + self.objective = objective + self.method = _ObjectiveToMethod.get(objective) + + self._initialize(initializers) + self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs) + self.optimizer = self._set_optimizer(optimizers) + self.lr_scheduler = self._set_lr_scheduler(lr_schedulers, **kwargs) + self.transforms = self._set_transforms(transforms) + self._set_callbacks_and_metrics(callbacks, metrics) + + @abstractmethod + def fit( + self, + X_wide: Optional[np.ndarray], + X_tab: Optional[np.ndarray], + X_text: Optional[np.ndarray], + X_img: Optional[np.ndarray], + X_train: Optional[Dict[str, np.ndarray]], + X_val: Optional[Dict[str, np.ndarray]], + val_split: Optional[float], + target: Optional[np.ndarray], + n_epochs: int, + validation_freq: int, + batch_size: int, + ): + raise NotImplementedError("Trainer.fit method not implemented") + + @abstractmethod + def predict( + self, + X_wide: Optional[np.ndarray], + X_tab: Optional[np.ndarray], + X_text: Optional[np.ndarray], + X_img: Optional[np.ndarray], + X_test: Optional[Dict[str, np.ndarray]], + batch_size: int, + ) -> np.ndarray: + raise NotImplementedError("Trainer.predict method not implemented") + + @abstractmethod + def predict_proba( + self, + X_wide: Optional[np.ndarray], + X_tab: Optional[np.ndarray], + X_text: Optional[np.ndarray], + X_img: Optional[np.ndarray], + X_test: Optional[Dict[str, np.ndarray]], + batch_size: int, + ) -> np.ndarray: + raise NotImplementedError("Trainer.predict_proba method not implemented") + + @abstractmethod + def save( + self, + path: str, + save_state_dict: bool, + model_filename: str, + ): + raise NotImplementedError("Trainer.save method not implemented") + + def _restore_best_weights(self): + already_restored = any( + [ + ( + callback.__class__.__name__ == "EarlyStopping" + and callback.restore_best_weights + ) + for callback in self.callback_container.callbacks + ] + ) + if already_restored: + pass + else: + for callback in self.callback_container.callbacks: + if callback.__class__.__name__ == "ModelCheckpoint": + if callback.save_best_only: + if self.verbose: + print( + f"Model weights restored to best epoch: {callback.best_epoch + 1}" + ) + self.model.load_state_dict(callback.best_state_dict) + else: + if self.verbose: + print( + "Model weights after training corresponds to the those of the " + "final epoch which might not be the best performing weights. Use" + "the 'ModelCheckpoint' Callback to restore the best epoch weights." 
+ ) + + def _initialize(self, initializers): + if initializers is not None: + if isinstance(initializers, Dict): + self.initializer = MultipleInitializer( + initializers, verbose=self.verbose + ) + self.initializer.apply(self.model) + elif isinstance(initializers, type): + self.initializer = initializers() + self.initializer(self.model) + elif isinstance(initializers, Initializer): + self.initializer = initializers + self.initializer(self.model) + + def _set_loss_fn(self, objective, custom_loss_function, **kwargs): + + class_weight = ( + torch.tensor(kwargs["class_weight"]).to(self.device) + if "class_weight" in kwargs + else None + ) + + if custom_loss_function is not None: + return custom_loss_function + elif ( + self.method not in ["regression", "qregression"] + and "focal_loss" not in objective + ): + return alias_to_loss(objective, weight=class_weight) + elif "focal_loss" in objective: + alpha = kwargs.get("alpha", 0.25) + gamma = kwargs.get("gamma", 2.0) + return alias_to_loss(objective, alpha=alpha, gamma=gamma) + else: + return alias_to_loss(objective) + + def _set_optimizer(self, optimizers): + if optimizers is not None: + if isinstance(optimizers, Optimizer): + optimizer: Union[Optimizer, MultipleOptimizer] = optimizers + elif isinstance(optimizers, Dict): + opt_names = list(optimizers.keys()) + mod_names = [n for n, c in self.model.named_children()] + # if with_fds - the prediction layer is part of the model and + # should be optimized with the rest of deeptabular + # component/model + if self.model.with_fds: + mod_names.remove("enf_pos") + mod_names.remove("fds_layer") + optimizers["deeptabular"].add_param_group( + {"params": self.model.fds_layer.pred_layer.parameters()} + ) + for mn in mod_names: + assert mn in opt_names, "No optimizer found for {}".format(mn) + optimizer = MultipleOptimizer(optimizers) + else: + optimizer = torch.optim.Adam(self.model.parameters()) # type: ignore + return optimizer + + def _set_lr_scheduler(self, lr_schedulers, **kwargs): + + # ReduceLROnPlateau is special + reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None) + + self._set_reduce_on_plateau_criterion( + lr_schedulers, reducelronplateau_criterion + ) + + if lr_schedulers is not None: + + if isinstance(lr_schedulers, LRScheduler) or isinstance( + lr_schedulers, ReduceLROnPlateau + ): + lr_scheduler = lr_schedulers + cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower() + else: + lr_scheduler = MultipleLRScheduler(lr_schedulers) + scheduler_names = [ + sc.__class__.__name__.lower() + for _, sc in lr_scheduler._schedulers.items() + ] + cyclic_lr = any(["cycl" in sn for sn in scheduler_names]) + else: + lr_scheduler, cyclic_lr = None, False + + self.cyclic_lr = cyclic_lr + + return lr_scheduler + + def _set_reduce_on_plateau_criterion( + self, lr_schedulers, reducelronplateau_criterion + ): + + self.reducelronplateau = False + + if isinstance(lr_schedulers, Dict): + for _, scheduler in lr_schedulers.items(): + if isinstance(scheduler, ReduceLROnPlateau): + self.reducelronplateau = True + elif isinstance(lr_schedulers, ReduceLROnPlateau): + self.reducelronplateau = True + + if self.reducelronplateau and not reducelronplateau_criterion: + UserWarning( + "The learning rate scheduler of at least one of the model components is of type " + "ReduceLROnPlateau. The step method in this scheduler requires a 'metrics' param " + "that can be either the validation loss or the validation metric. 
Please, when " + "instantiating the Trainer, specify which quantity will be tracked using " + "reducelronplateau_criterion = 'loss' (default) or reducelronplateau_criterion = 'metric'" + ) + self.reducelronplateau_criterion = "loss" + else: + self.reducelronplateau_criterion = reducelronplateau_criterion + + @staticmethod + def _set_transforms(transforms): + if transforms is not None: + return MultipleTransforms(transforms)() + else: + return None + + def _set_callbacks_and_metrics(self, callbacks, metrics): + self.callbacks: List = [History(), LRShedulerCallback()] + if callbacks is not None: + for callback in callbacks: + if isinstance(callback, type): + callback = callback() + self.callbacks.append(callback) + if metrics is not None: + self.metric = MultipleMetrics(metrics) + self.callbacks += [MetricCallback(self.metric)] + else: + self.metric = None + self.callback_container = CallbackContainer(self.callbacks) + self.callback_container.set_model(self.model) + self.callback_container.set_trainer(self) + + @staticmethod + def _check_inputs( + model, + objective, + optimizers, + lr_schedulers, + custom_loss_function, + ): + + if model.with_fds and _ObjectiveToMethod.get(objective) != "regression": + raise ValueError( + "Feature Distribution Smooting can be used only for regression" + ) + + if _ObjectiveToMethod.get(objective) == "multiclass" and model.pred_dim == 1: + raise ValueError( + "This is a multiclass classification problem but the size of the output layer" + " is set to 1. Please, set the 'pred_dim' param equal to the number of classes " + " when instantiating the 'WideDeep' class" + ) + + if isinstance(optimizers, Dict): + if lr_schedulers is not None and not isinstance(lr_schedulers, Dict): + raise ValueError( + "''optimizers' and 'lr_schedulers' must have consistent type: " + "(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]) " + "Please, read the documentation or see the examples for more details" + ) + + if custom_loss_function is not None and objective not in [ + "binary", + "multiclass", + "regression", + ]: + raise ValueError( + "If 'custom_loss_function' is not None, 'objective' must be 'binary' " + "'multiclass' or 'regression', consistent with the loss function" + ) + + @staticmethod + def _set_device_and_num_workers(**kwargs): + + # Important note for Mac users: Since python 3.8, the multiprocessing + # library start method changed from 'fork' to 'spawn'. This affects the + # data-loaders, which will not run in parallel. + default_num_workers = ( + 0 + if sys.platform == "darwin" and sys.version_info.minor > 7 + else os.cpu_count() + ) + default_device = "cuda" if torch.cuda.is_available() else "cpu" + device = kwargs.get("device", default_device) + num_workers = kwargs.get("num_workers", default_num_workers) + return device, num_workers + + +# There are some nuances in the Bayesian Trainer that make it hard to build an +# overall BaseTrainer. We could still perhaps code a very basic Trainer and +# then pass it to a BaseTrainer and BaseBayesianTrainer. 
However in this +# particular case we prefer code repetition as we believe is a simpler +# solution +class BaseBayesianTrainer(ABC): + def __init__( + self, + model: BaseBayesianModel, + objective: str, + custom_loss_function: Optional[Module], + optimizer: Optimizer, + lr_scheduler: LRScheduler, + callbacks: Optional[List[Callback]], + metrics: Optional[Union[List[Metric], List[TorchMetric]]], + verbose: int, + seed: int, + **kwargs, + ): + + if objective not in ["binary", "multiclass", "regression"]: + raise ValueError( + "If 'custom_loss_function' is not None, 'objective' must be 'binary' " + "'multiclass' or 'regression', consistent with the loss function" + ) + + self.device, self.num_workers = self._set_device_and_num_workers(**kwargs) + + self.model = model + self.early_stop = False + + self.verbose = verbose + self.seed = seed + self.objective = objective + + self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs) + self.optimizer = ( + optimizer + if optimizer is not None + else torch.optim.AdamW(self.model.parameters()) + ) + self.lr_scheduler = lr_scheduler + self._set_lr_scheduler_running_params(lr_scheduler, **kwargs) + self._set_callbacks_and_metrics(callbacks, metrics) + self.model.to(self.device) + + @abstractmethod + def fit( + self, + X_tab: np.ndarray, + target: np.ndarray, + X_tab_val: Optional[np.ndarray], + target_val: Optional[np.ndarray], + val_split: Optional[float], + n_epochs: int, + val_freq: int, + batch_size: int, + n_train_samples: int, + n_val_samples: int, + ): + raise NotImplementedError("Trainer.fit method not implemented") + + @abstractmethod + def predict( + self, + X_tab: np.ndarray, + n_samples: int, + return_samples: bool, + batch_size: int, + ) -> np.ndarray: + raise NotImplementedError("Trainer.predict method not implemented") + + @abstractmethod + def predict_proba( + self, + X_tab: np.ndarray, + n_samples: int, + return_samples: bool, + batch_size: int, + ) -> np.ndarray: + raise NotImplementedError("Trainer.predict_proba method not implemented") + + @abstractmethod + def save( + self, + path: str, + save_state_dict: bool, + model_filename: str, + ): + raise NotImplementedError("Trainer.save method not implemented") + + def _restore_best_weights(self): + already_restored = any( + [ + ( + callback.__class__.__name__ == "EarlyStopping" + and callback.restore_best_weights + ) + for callback in self.callback_container.callbacks + ] + ) + if already_restored: + pass + else: + for callback in self.callback_container.callbacks: + if callback.__class__.__name__ == "ModelCheckpoint": + if callback.save_best_only: + if self.verbose: + print( + f"Model weights restored to best epoch: {callback.best_epoch + 1}" + ) + self.model.load_state_dict(callback.best_state_dict) + else: + if self.verbose: + print( + "Model weights after training corresponds to the those of the " + "final epoch which might not be the best performing weights. Use" + "the 'ModelCheckpoint' Callback to restore the best epoch weights." 
+ ) + + def _set_loss_fn(self, objective, custom_loss_function, **kwargs): + + if custom_loss_function is not None: + return custom_loss_function + + class_weight = ( + torch.tensor(kwargs["class_weight"]).to(self.device) + if "class_weight" in kwargs + else None + ) + + if self.objective != "regression": + return bayesian_alias_to_loss(objective, weight=class_weight) + else: + return bayesian_alias_to_loss(objective) + + def _set_reduce_on_plateau_criterion( + self, lr_scheduler, reducelronplateau_criterion + ): + + self.reducelronplateau = False + + if isinstance(lr_scheduler, ReduceLROnPlateau): + self.reducelronplateau = True + + if self.reducelronplateau and not reducelronplateau_criterion: + UserWarning( + "The learning rate scheduler is of type ReduceLROnPlateau. The step method in this" + " scheduler requires a 'metrics' param that can be either the validation loss or the" + " validation metric. Please, when instantiating the Trainer, specify which quantity" + " will be tracked using reducelronplateau_criterion = 'loss' (default) or" + " reducelronplateau_criterion = 'metric'" + ) + else: + self.reducelronplateau_criterion = "loss" + + def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs): + reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None) + self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion) + if lr_scheduler is not None: + self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower() + else: + self.cyclic_lr = False + + def _set_callbacks_and_metrics(self, callbacks, metrics): + self.callbacks: List = [History(), LRShedulerCallback()] + if callbacks is not None: + for callback in callbacks: + if isinstance(callback, type): + callback = callback() + self.callbacks.append(callback) + if metrics is not None: + self.metric = MultipleMetrics(metrics) + self.callbacks += [MetricCallback(self.metric)] + else: + self.metric = None + self.callback_container = CallbackContainer(self.callbacks) + self.callback_container.set_model(self.model) + self.callback_container.set_trainer(self) + + @staticmethod + def _set_device_and_num_workers(**kwargs): + + default_num_workers = ( + 0 + if sys.platform == "darwin" and sys.version_info.minor > 7 + else os.cpu_count() + ) + default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = kwargs.get("device", default_device) + num_workers = kwargs.get("num_workers", default_num_workers) + return device, num_workers diff --git a/pytorch_widedeep/training/bayesian_trainer.py b/pytorch_widedeep/training/bayesian_trainer.py index a1935ef..42a26e5 100644 --- a/pytorch_widedeep/training/bayesian_trainer.py +++ b/pytorch_widedeep/training/bayesian_trainer.py @@ -1,4 +1,3 @@ -import os import json from pathlib import Path @@ -8,22 +7,15 @@ import torch.nn.functional as F from tqdm import trange from torchmetrics import Metric as TorchMetric from torch.utils.data import DataLoader, TensorDataset -from torch.optim.lr_scheduler import ReduceLROnPlateau -from pytorch_widedeep.metrics import Metric, MultipleMetrics +from pytorch_widedeep.metrics import Metric from pytorch_widedeep.wdtypes import * # noqa: F403 -from pytorch_widedeep.callbacks import ( - History, - Callback, - MetricCallback, - CallbackContainer, - LRShedulerCallback, -) +from pytorch_widedeep.callbacks import Callback from pytorch_widedeep.utils.general_utils import Alias +from pytorch_widedeep.training._base_trainers import BaseBayesianTrainer from pytorch_widedeep.training._trainer_utils import ( 
save_epoch_logs, print_loss_and_metric, - bayesian_alias_to_loss, tabular_train_val_split, ) from pytorch_widedeep.bayesian_models._base_bayesian_model import ( @@ -31,7 +23,7 @@ from pytorch_widedeep.bayesian_models._base_bayesian_model import ( ) -class BayesianTrainer: +class BayesianTrainer(BaseBayesianTrainer): r"""Class to set the of attributes that will be used during the training process. @@ -115,32 +107,18 @@ class BayesianTrainer: seed: int = 1, **kwargs, ): - - if objective not in ["binary", "multiclass", "regression"]: - raise ValueError( - "If 'custom_loss_function' is not None, 'objective' must be 'binary' " - "'multiclass' or 'regression', consistent with the loss function" - ) - - self.device, self.num_workers = self._set_device_and_num_workers(**kwargs) - - self.model = model - self.early_stop = False - - self.verbose = verbose - self.seed = seed - self.objective = objective - - self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs) - self.optimizer = ( - optimizer - if optimizer is not None - else torch.optim.AdamW(self.model.parameters()) + super().__init__( + model=model, + objective=objective, + custom_loss_function=custom_loss_function, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + callbacks=callbacks, + metrics=metrics, + verbose=verbose, + seed=seed, + **kwargs, ) - self.lr_scheduler = lr_scheduler - self._set_lr_scheduler_running_params(lr_scheduler, **kwargs) - self._set_callbacks_and_metrics(callbacks, metrics) - self.model.to(self.device) def fit( # noqa: C901 self, @@ -378,35 +356,6 @@ class BayesianTrainer: else: torch.save(self.model, model_path) - def _restore_best_weights(self): - already_restored = any( - [ - ( - callback.__class__.__name__ == "EarlyStopping" - and callback.restore_best_weights - ) - for callback in self.callback_container.callbacks - ] - ) - if already_restored: - pass - else: - for callback in self.callback_container.callbacks: - if callback.__class__.__name__ == "ModelCheckpoint": - if callback.save_best_only: - if self.verbose: - print( - f"Model weights restored to best epoch: {callback.best_epoch + 1}" - ) - self.model.load_state_dict(callback.best_state_dict) - else: - if self.verbose: - print( - "Model weights after training corresponds to the those of the " - "final epoch which might not be the best performing weights. Use" - "the 'ModelCheckpoint' Callback to restore the best epoch weights." - ) - def _train_step( self, X_tab: Tensor, @@ -526,81 +475,3 @@ class BayesianTrainer: self.model.train() return preds_l - - def _set_loss_fn(self, objective, custom_loss_function, **kwargs): - - if custom_loss_function is not None: - return custom_loss_function - - class_weight = ( - torch.tensor(kwargs["class_weight"]).to(self.device) - if "class_weight" in kwargs - else None - ) - - if self.objective != "regression": - return bayesian_alias_to_loss(objective, weight=class_weight) - else: - return bayesian_alias_to_loss(objective) - - def _set_reduce_on_plateau_criterion( - self, lr_scheduler, reducelronplateau_criterion - ): - - self.reducelronplateau = False - - if isinstance(lr_scheduler, ReduceLROnPlateau): - self.reducelronplateau = True - - if self.reducelronplateau and not reducelronplateau_criterion: - UserWarning( - "The learning rate scheduler is of type ReduceLROnPlateau. The step method in this" - " scheduler requires a 'metrics' param that can be either the validation loss or the" - " validation metric. 
Please, when instantiating the Trainer, specify which quantity" - " will be tracked using reducelronplateau_criterion = 'loss' (default) or" - " reducelronplateau_criterion = 'metric'" - ) - else: - self.reducelronplateau_criterion = "loss" - - def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs): - # ReduceLROnPlateau is special - - reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None) - self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion) - if lr_scheduler is not None: - self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower() - else: - self.cyclic_lr = False - - def _set_callbacks_and_metrics(self, callbacks, metrics): - self.callbacks: List = [History(), LRShedulerCallback()] - if callbacks is not None: - for callback in callbacks: - if isinstance(callback, type): - callback = callback() - self.callbacks.append(callback) - if metrics is not None: - self.metric = MultipleMetrics(metrics) - self.callbacks += [MetricCallback(self.metric)] - else: - self.metric = None - self.callback_container = CallbackContainer(self.callbacks) - self.callback_container.set_model(self.model) - self.callback_container.set_trainer(self) - - @staticmethod - def _set_device_and_num_workers(**kwargs): - - # Important note for Mac users: Since python 3.8, the multiprocessing - # library start method changed from 'fork' to 'spawn'. This affects the - # data-loaders, which will not run in parallel. - default_num_workers = ( - 0 - if sys.platform == "darwin" and sys.version_info.minor > 7 - else os.cpu_count() - ) - default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - device = kwargs.get("device", default_device) - num_workers = kwargs.get("num_workers", default_num_workers) - return device, num_workers diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 6eb0a12..01c89c6 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -1,5 +1,3 @@ -import os -import sys import json import warnings from inspect import signature @@ -13,39 +11,25 @@ from torch import nn from scipy.sparse import csc_matrix from torchmetrics import Metric as TorchMetric from torch.utils.data import DataLoader -from torch.optim.lr_scheduler import ReduceLROnPlateau from pytorch_widedeep.losses import ZILNLoss -from pytorch_widedeep.metrics import Metric, MultipleMetrics +from pytorch_widedeep.metrics import Metric from pytorch_widedeep.wdtypes import * # noqa: F403 -from pytorch_widedeep.callbacks import ( - History, - Callback, - MetricCallback, - CallbackContainer, - LRShedulerCallback, -) +from pytorch_widedeep.callbacks import Callback from pytorch_widedeep.dataloaders import DataLoaderDefault -from pytorch_widedeep.initializers import Initializer, MultipleInitializer +from pytorch_widedeep.initializers import Initializer from pytorch_widedeep.training._finetune import FineTune from pytorch_widedeep.utils.general_utils import Alias from pytorch_widedeep.training._wd_dataset import WideDeepDataset +from pytorch_widedeep.training._base_trainers import BaseTrainer from pytorch_widedeep.training._trainer_utils import ( - alias_to_loss, save_epoch_logs, wd_train_val_split, print_loss_and_metric, ) -from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix -from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer -from pytorch_widedeep.training._multiple_transforms import MultipleTransforms -from 
pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod -from pytorch_widedeep.training._multiple_lr_scheduler import ( - MultipleLRScheduler, -) -class Trainer: +class Trainer(BaseTrainer): r"""Class to set the of attributes that will be used during the training process. @@ -221,34 +205,20 @@ class Trainer: seed: int = 1, **kwargs, ): - - self._check_inputs( - model, objective, optimizers, lr_schedulers, custom_loss_function + super().__init__( + model=model, + objective=objective, + custom_loss_function=custom_loss_function, + optimizers=optimizers, + lr_schedulers=lr_schedulers, + initializers=initializers, + transforms=transforms, + callbacks=callbacks, + metrics=metrics, + verbose=verbose, + seed=seed, + **kwargs, ) - self.device, self.num_workers = self._set_device_and_num_workers(**kwargs) - - # initialize early_stop. If EarlyStopping Callback is used it will - # take care of it - self.early_stop = False - self.verbose = verbose - self.seed = seed - - self.model = model - if self.model.is_tabnet: - self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3) - self.reducing_matrix = create_explain_matrix(self.model) - self.model.to(self.device) - self.model.wd_device = self.device - - self.objective = objective - self.method = _ObjectiveToMethod.get(objective) - - self._initialize(initializers) - self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs) - self.optimizer = self._set_optimizer(optimizers) - self.lr_scheduler = self._set_lr_scheduler(lr_schedulers, **kwargs) - self.transforms = self._set_transforms(transforms) - self._set_callbacks_and_metrics(callbacks, metrics) @Alias("finetune", "warmup") def fit( # noqa: C901 @@ -695,64 +665,6 @@ class Trainer: if self.method == "multiclass": return np.vstack(preds_l) - def get_embeddings( - self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]] - ) -> Dict[str, np.ndarray]: # pragma: no cover - r"""Returns the learned embeddings for the categorical features passed through - ``deeptabular``. - - .. note:: This function will be deprecated in the next relase. Please consider - using ``Tab2Vec`` instead. - - This method is designed to take an encoding dictionary in the same - format as that of the :obj:`LabelEncoder` Attribute in the class - :obj:`TabPreprocessor`. See - :class:`pytorch_widedeep.preprocessing.TabPreprocessor` and - :class:`pytorch_widedeep.utils.dense_utils.LabelEncder`. - - Parameters - ---------- - col_name: str, - Column name of the feature we want to get the embeddings for - cat_encoding_dict: Dict - Dictionary where the keys are the name of the column for which we - want to retrieve the embeddings and the values are also of type - Dict. These Dict values have keys that are the categories for that - column and the values are the corresponding numberical encodings - - e.g.: {'column': {'cat_0': 1, 'cat_1': 2, ...}} - - Examples - -------- - - For a series of comprehensive examples please, see the `Examples - `__ - folder in the repo - - For completion, here we include a `"fabricated"` example, i.e. - assuming we have already trained the model, that we have the - categorical encodings in a dictionary name ``encoding_dict``, and that - there is a column called `'education'`: - - .. code-block:: python - - trainer.get_embeddings(col_name="education", cat_encoding_dict=encoding_dict) - """ - warnings.warn( - "'get_embeddings' will be deprecated in the next release. 
" - "Please consider using 'Tab2vec' instead", - DeprecationWarning, - ) - for n, p in self.model.named_parameters(): - if "embed_layers" in n and col_name in n: - embed_mtx = p.cpu().data.numpy() - encoding_dict = cat_encoding_dict[col_name] - inv_encoding_dict = {v: k for k, v in encoding_dict.items()} - cat_embed_dict = {} - for idx, value in inv_encoding_dict.items(): - cat_embed_dict[value] = embed_mtx[idx] - return cat_embed_dict - def explain(self, X_tab: np.ndarray, save_step_masks: bool = False): """ if the ``deeptabular`` component is a :obj:`Tabnet` model, returns the @@ -876,35 +788,6 @@ class Trainer: with open(save_dir / "feature_importance.json", "w") as fi: json.dump(self.feature_importance, fi) - def _restore_best_weights(self): - already_restored = any( - [ - ( - callback.__class__.__name__ == "EarlyStopping" - and callback.restore_best_weights - ) - for callback in self.callback_container.callbacks - ] - ) - if already_restored: - pass - else: - for callback in self.callback_container.callbacks: - if callback.__class__.__name__ == "ModelCheckpoint": - if callback.save_best_only: - if self.verbose: - print( - f"Model weights restored to best epoch: {callback.best_epoch + 1}" - ) - self.model.load_state_dict(callback.best_state_dict) - else: - if self.verbose: - print( - "Model weights after training corresponds to the those of the " - "final epoch which might not be the best performing weights. Use" - "the 'ModelCheckpoint' Callback to restore the best epoch weights." - ) - @Alias("n_epochs", ["finetune_epochs", "warmup_epochs"]) @Alias("max_lr", ["finetune_max_lr", "warmup_max_lr"]) def _finetune( @@ -1287,195 +1170,3 @@ class Trainer: finetune_args[k] = v return lds_args, dataloader_args, finetune_args - - def _initialize(self, initializers): - if initializers is not None: - if isinstance(initializers, Dict): - self.initializer = MultipleInitializer( - initializers, verbose=self.verbose - ) - self.initializer.apply(self.model) - elif isinstance(initializers, type): - self.initializer = initializers() - self.initializer(self.model) - elif isinstance(initializers, Initializer): - self.initializer = initializers - self.initializer(self.model) - - def _set_loss_fn(self, objective, custom_loss_function, **kwargs): - - class_weight = ( - torch.tensor(kwargs["class_weight"]).to(self.device) - if "class_weight" in kwargs - else None - ) - - if custom_loss_function is not None: - return custom_loss_function - elif ( - self.method not in ["regression", "qregression"] - and "focal_loss" not in objective - ): - return alias_to_loss(objective, weight=class_weight) - elif "focal_loss" in objective: - alpha = kwargs.get("alpha", 0.25) - gamma = kwargs.get("gamma", 2.0) - return alias_to_loss(objective, alpha=alpha, gamma=gamma) - else: - return alias_to_loss(objective) - - def _set_optimizer(self, optimizers): - if optimizers is not None: - if isinstance(optimizers, Optimizer): - optimizer: Union[Optimizer, MultipleOptimizer] = optimizers - elif isinstance(optimizers, Dict): - opt_names = list(optimizers.keys()) - mod_names = [n for n, c in self.model.named_children()] - # if with_fds - the prediction layer is part of the model and - # should be optimized with the rest of deeptabular - # component/model - if self.model.with_fds: - mod_names.remove("enf_pos") - mod_names.remove("fds_layer") - optimizers["deeptabular"].add_param_group( - {"params": self.model.fds_layer.pred_layer.parameters()} - ) - for mn in mod_names: - assert mn in opt_names, "No optimizer found for {}".format(mn) - 
optimizer = MultipleOptimizer(optimizers) - else: - optimizer = torch.optim.Adam(self.model.parameters()) # type: ignore - return optimizer - - def _set_reduce_on_plateau_criterion( - self, lr_schedulers, reducelronplateau_criterion - ): - - self.reducelronplateau = False - - if isinstance(lr_schedulers, Dict): - for _, scheduler in lr_schedulers.items(): - if isinstance(scheduler, ReduceLROnPlateau): - self.reducelronplateau = True - elif isinstance(lr_schedulers, ReduceLROnPlateau): - self.reducelronplateau = True - - if self.reducelronplateau and not reducelronplateau_criterion: - UserWarning( - "The learning rate scheduler of at least one of the model components is of type " - "ReduceLROnPlateau. The step method in this scheduler requires a 'metrics' param " - "that can be either the validation loss or the validation metric. Please, when " - "instantiating the Trainer, specify which quantity will be tracked using " - "reducelronplateau_criterion = 'loss' (default) or reducelronplateau_criterion = 'metric'" - ) - self.reducelronplateau_criterion = "loss" - else: - self.reducelronplateau_criterion = reducelronplateau_criterion - - def _set_lr_scheduler(self, lr_schedulers, **kwargs): - - # ReduceLROnPlateau is special - reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None) - - self._set_reduce_on_plateau_criterion( - lr_schedulers, reducelronplateau_criterion - ) - - if lr_schedulers is not None: - - if isinstance(lr_schedulers, LRScheduler) or isinstance( - lr_schedulers, ReduceLROnPlateau - ): - lr_scheduler = lr_schedulers - cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower() - else: - lr_scheduler = MultipleLRScheduler(lr_schedulers) - scheduler_names = [ - sc.__class__.__name__.lower() - for _, sc in lr_scheduler._schedulers.items() - ] - cyclic_lr = any(["cycl" in sn for sn in scheduler_names]) - else: - lr_scheduler, cyclic_lr = None, False - - self.cyclic_lr = cyclic_lr - - return lr_scheduler - - @staticmethod - def _set_transforms(transforms): - if transforms is not None: - return MultipleTransforms(transforms)() - else: - return None - - def _set_callbacks_and_metrics(self, callbacks, metrics): - self.callbacks: List = [History(), LRShedulerCallback()] - if callbacks is not None: - for callback in callbacks: - if isinstance(callback, type): - callback = callback() - self.callbacks.append(callback) - if metrics is not None: - self.metric = MultipleMetrics(metrics) - self.callbacks += [MetricCallback(self.metric)] - else: - self.metric = None - self.callback_container = CallbackContainer(self.callbacks) - self.callback_container.set_model(self.model) - self.callback_container.set_trainer(self) - - @staticmethod - def _check_inputs( - model, - objective, - optimizers, - lr_schedulers, - custom_loss_function, - ): - - if model.with_fds and _ObjectiveToMethod.get(objective) != "regression": - raise ValueError( - "Feature Distribution Smooting can be used only for regression" - ) - - if _ObjectiveToMethod.get(objective) == "multiclass" and model.pred_dim == 1: - raise ValueError( - "This is a multiclass classification problem but the size of the output layer" - " is set to 1. 
Please, set the 'pred_dim' param equal to the number of classes " - " when instantiating the 'WideDeep' class" - ) - - if isinstance(optimizers, Dict): - if lr_schedulers is not None and not isinstance(lr_schedulers, Dict): - raise ValueError( - "''optimizers' and 'lr_schedulers' must have consistent type: " - "(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]) " - "Please, read the documentation or see the examples for more details" - ) - - if custom_loss_function is not None and objective not in [ - "binary", - "multiclass", - "regression", - ]: - raise ValueError( - "If 'custom_loss_function' is not None, 'objective' must be 'binary' " - "'multiclass' or 'regression', consistent with the loss function" - ) - - @staticmethod - def _set_device_and_num_workers(**kwargs): - - # Important note for Mac users: Since python 3.8, the multiprocessing - # library start method changed from 'fork' to 'spawn'. This affects the - # data-loaders, which will not run in parallel. - default_num_workers = ( - 0 - if sys.platform == "darwin" and sys.version_info.minor > 7 - else os.cpu_count() - ) - default_device = "cuda" if torch.cuda.is_available() else "cpu" - device = kwargs.get("device", default_device) - num_workers = kwargs.get("num_workers", default_num_workers) - return device, num_workers diff --git a/tests/test_model_functioning/test_miscellaneous.py b/tests/test_model_functioning/test_miscellaneous.py index 64ef44b..4a8cc7d 100644 --- a/tests/test_model_functioning/test_miscellaneous.py +++ b/tests/test_model_functioning/test_miscellaneous.py @@ -347,11 +347,6 @@ def test_save_load_and_predict(): assert preds.shape[0] == X_tab.shape[0] -############################################################################### -# test get_embeddings DeprecationWarning -############################################################################### - - def create_test_dataset(input_type, input_type_2=None): df = pd.DataFrame() col1 = list(np.random.choice(input_type, 32)) @@ -371,40 +366,6 @@ df["col4"] = np.round(np.random.rand(32), 3) df["target"] = np.random.choice(2, 32) -def test_get_embeddings_deprecation_warning(): - - embed_cols = [("col1", 5), ("col2", 5)] - continuous_cols = ["col3", "col4"] - - tab_preprocessor = TabPreprocessor( - cat_embed_cols=embed_cols, continuous_cols=continuous_cols - ) - X_tab = tab_preprocessor.fit_transform(df) - target = df.target.values - - tabmlp = TabMlp( - mlp_hidden_dims=[32, 16], - mlp_dropout=[0.5, 0.5], - column_idx={k: v for v, k in enumerate(df.columns)}, - cat_embed_input=tab_preprocessor.cat_embed_input, - continuous_cols=tab_preprocessor.continuous_cols, - ) - - model = WideDeep(deeptabular=tabmlp) - trainer = Trainer(model, objective="binary", verbose=0) - trainer.fit( - X_tab=X_tab, - target=target, - batch_size=16, - ) - - with pytest.warns(DeprecationWarning): - trainer.get_embeddings( - col_name="col1", - cat_encoding_dict=tab_preprocessor.label_encoder.encoding_dict, - ) - - ############################################################################### # test test_handle_columns_with_dots ############################################################################### -- GitLab
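
Not part of the patch itself — a minimal smoke-test sketch of the new top-level BayesianTrainer export this change introduces, mirroring the modified adult_census_bayesian_tabmlp.py example. The toy dataframe, the TabPreprocessor arguments and the BayesianTabMlp parameters below are illustrative assumptions (they follow the signatures visible elsewhere in this diff), not something the patch defines.

import numpy as np
import pandas as pd

from pytorch_widedeep import BayesianTrainer  # now importable from the package root
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.bayesian_models import BayesianTabMlp

# small synthetic frame standing in for the adult census data
df = pd.DataFrame(
    {
        "cat_a": np.random.choice(["x", "y", "z"], 256),
        "cat_b": np.random.choice(["p", "q"], 256),
        "num_a": np.random.rand(256),
        "target": np.random.choice(2, 256),
    }
)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["cat_a", "cat_b"], continuous_cols=["num_a"]
)
X_tab = tab_preprocessor.fit_transform(df)

# parameter names assumed to mirror the non-bayesian TabMlp component
model = BayesianTabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=["num_a"],
    mlp_hidden_dims=[16, 8],
)

trainer = BayesianTrainer(model, objective="binary", metrics=[Accuracy()])
trainer.fit(X_tab=X_tab, target=df["target"].values, n_epochs=1, batch_size=32)
preds = trainer.predict(X_tab=X_tab)

This exercises the constructor logic (loss, optimizer, scheduler and callback set-up) that the patch moves into BaseBayesianTrainer in _base_trainers.py, together with the fit/predict methods that remain in the concrete BayesianTrainer.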