Commit 04544924 authored by Javier Rodriguez Zaurin

Added base trainers for the Trainer and the BayesianTrainer classes. The base trainers will change as I add self supervised options
Parent 487fa1a6
@@ -2,12 +2,12 @@ import numpy as np
import torch
import pandas as pd
from pytorch_widedeep import BayesianTrainer
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor
from pytorch_widedeep.bayesian_models import BayesianWide, BayesianTabMlp
from pytorch_widedeep.training.bayesian_trainer import BayesianTrainer
use_cuda = torch.cuda.is_available()
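For orientation, here is a minimal sketch of how the classes imported in this script fit together. Column names and the exact `fit`/`predict` arguments are assumptions for illustration and are not part of this commit:

```python
# Rough usage sketch for the bayesian models exercised by this script.
# Column names and trainer arguments are assumptions, not part of this diff.
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.bayesian_models import BayesianTabMlp
from pytorch_widedeep.training.bayesian_trainer import BayesianTrainer

df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df = df.drop("income", axis=1)

cat_cols = ["education", "relationship", "workclass", "occupation", "native_country"]
cont_cols = ["age", "hours_per_week"]
tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_cols, continuous_cols=cont_cols)
X_tab = tab_preprocessor.fit_transform(df)
target = df["income_label"].values

model = BayesianTabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=cont_cols,
    mlp_hidden_dims=[64, 32],
)

trainer = BayesianTrainer(model, objective="binary", metrics=[Accuracy()])
trainer.fit(X_tab=X_tab, target=target, n_epochs=2, batch_size=256)
preds = trainer.predict(X_tab=X_tab)
```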
@@ -14,4 +14,4 @@ from pytorch_widedeep.utils import (
)
from pytorch_widedeep.tab2vec import Tab2Vec
from pytorch_widedeep.version import __version__
from pytorch_widedeep.training import Trainer
from pytorch_widedeep.training import Trainer, BayesianTrainer
This diff is collapsed.
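The collapsed file is presumably the new `pytorch_widedeep/training/_base_trainers.py` module referenced by the `BaseBayesianTrainer` and `BaseTrainer` imports further down. Since its contents are not shown here, the following is only a hypothetical sketch of the pattern the commit message describes: the constructor logic the concrete trainers previously carried is pulled up into an abstract base class.

```python
# Hypothetical sketch only -- the actual _base_trainers.py diff is collapsed above.
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import nn


class BaseBayesianTrainer(ABC):
    """Holds the set-up that used to live in BayesianTrainer.__init__."""

    def __init__(
        self,
        model: nn.Module,
        objective: str,
        custom_loss_function: Optional[nn.Module] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        lr_scheduler=None,
        callbacks=None,
        metrics=None,
        verbose: int = 1,
        seed: int = 1,
        **kwargs,
    ):
        self.model = model
        self.objective = objective
        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        # default optimizer mirrors the code removed from BayesianTrainer below
        self.optimizer = (
            optimizer
            if optimizer is not None
            else torch.optim.AdamW(self.model.parameters())
        )

    @abstractmethod
    def fit(self, X_tab, target, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def predict(self, X_tab, **kwargs):
        raise NotImplementedError
```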
import os
import json
from pathlib import Path
@@ -8,22 +7,15 @@ import torch.nn.functional as F
from tqdm import trange
from torchmetrics import Metric as TorchMetric
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_widedeep.metrics import Metric, MultipleMetrics
from pytorch_widedeep.metrics import Metric
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.callbacks import (
History,
Callback,
MetricCallback,
CallbackContainer,
LRShedulerCallback,
)
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.training._base_trainers import BaseBayesianTrainer
from pytorch_widedeep.training._trainer_utils import (
save_epoch_logs,
print_loss_and_metric,
bayesian_alias_to_loss,
tabular_train_val_split,
)
from pytorch_widedeep.bayesian_models._base_bayesian_model import (
@@ -31,7 +23,7 @@ from pytorch_widedeep.bayesian_models._base_bayesian_model import (
)
class BayesianTrainer:
class BayesianTrainer(BaseBayesianTrainer):
r"""Class to set the of attributes that will be used during the
training process.
@@ -115,32 +107,18 @@ class BayesianTrainer:
seed: int = 1,
**kwargs,
):
if objective not in ["binary", "multiclass", "regression"]:
raise ValueError(
"If 'custom_loss_function' is not None, 'objective' must be 'binary' "
"'multiclass' or 'regression', consistent with the loss function"
)
self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)
self.model = model
self.early_stop = False
self.verbose = verbose
self.seed = seed
self.objective = objective
self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs)
self.optimizer = (
optimizer
if optimizer is not None
else torch.optim.AdamW(self.model.parameters())
super().__init__(
model=model,
objective=objective,
custom_loss_function=custom_loss_function,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
callbacks=callbacks,
metrics=metrics,
verbose=verbose,
seed=seed,
**kwargs,
)
self.lr_scheduler = lr_scheduler
self._set_lr_scheduler_running_params(lr_scheduler, **kwargs)
self._set_callbacks_and_metrics(callbacks, metrics)
self.model.to(self.device)
def fit( # noqa: C901
self,
@@ -378,35 +356,6 @@ class BayesianTrainer:
else:
torch.save(self.model, model_path)
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
if already_restored:
pass
else:
for callback in self.callback_container.callbacks:
if callback.__class__.__name__ == "ModelCheckpoint":
if callback.save_best_only:
if self.verbose:
print(
f"Model weights restored to best epoch: {callback.best_epoch + 1}"
)
self.model.load_state_dict(callback.best_state_dict)
else:
if self.verbose:
print(
"Model weights after training corresponds to the those of the "
"final epoch which might not be the best performing weights. Use"
"the 'ModelCheckpoint' Callback to restore the best epoch weights."
)
def _train_step(
self,
X_tab: Tensor,
@@ -526,81 +475,3 @@ class BayesianTrainer:
self.model.train()
return preds_l
def _set_loss_fn(self, objective, custom_loss_function, **kwargs):
if custom_loss_function is not None:
return custom_loss_function
class_weight = (
torch.tensor(kwargs["class_weight"]).to(self.device)
if "class_weight" in kwargs
else None
)
if self.objective != "regression":
return bayesian_alias_to_loss(objective, weight=class_weight)
else:
return bayesian_alias_to_loss(objective)
def _set_reduce_on_plateau_criterion(
self, lr_scheduler, reducelronplateau_criterion
):
self.reducelronplateau = False
if isinstance(lr_scheduler, ReduceLROnPlateau):
self.reducelronplateau = True
if self.reducelronplateau and not reducelronplateau_criterion:
UserWarning(
"The learning rate scheduler is of type ReduceLROnPlateau. The step method in this"
" scheduler requires a 'metrics' param that can be either the validation loss or the"
" validation metric. Please, when instantiating the Trainer, specify which quantity"
" will be tracked using reducelronplateau_criterion = 'loss' (default) or"
" reducelronplateau_criterion = 'metric'"
)
else:
self.reducelronplateau_criterion = "loss"
def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs):
# ReduceLROnPlateau is special
reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)
self._set_reduce_on_plateau_criterion(lr_scheduler, reducelronplateau_criterion)
if lr_scheduler is not None:
self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
else:
self.cyclic_lr = False
def _set_callbacks_and_metrics(self, callbacks, metrics):
self.callbacks: List = [History(), LRShedulerCallback()]
if callbacks is not None:
for callback in callbacks:
if isinstance(callback, type):
callback = callback()
self.callbacks.append(callback)
if metrics is not None:
self.metric = MultipleMetrics(metrics)
self.callbacks += [MetricCallback(self.metric)]
else:
self.metric = None
self.callback_container = CallbackContainer(self.callbacks)
self.callback_container.set_model(self.model)
self.callback_container.set_trainer(self)
@staticmethod
def _set_device_and_num_workers(**kwargs):
# Important note for Mac users: Since python 3.8, the multiprocessing
# library start method changed from 'fork' to 'spawn'. This affects the
# data-loaders, which will not run in parallel.
default_num_workers = (
0
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
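The removed `_set_device_and_num_workers` helper above reads the device and the number of data-loader workers from `**kwargs`, defaulting to CUDA if available and to `os.cpu_count()` workers (0 on macOS with Python 3.8+). Assuming that interface survives the move into `BaseBayesianTrainer`, overriding the defaults would look roughly like this:

```python
# Assumption: the device / num_workers kwargs keep working after the refactor.
trainer = BayesianTrainer(
    model,
    objective="binary",
    device="cpu",    # picked up via kwargs by _set_device_and_num_workers
    num_workers=0,   # disable multi-process data loading
)
```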
import os
import sys
import json
import warnings
from inspect import signature
@@ -13,39 +11,25 @@ from torch import nn
from scipy.sparse import csc_matrix
from torchmetrics import Metric as TorchMetric
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_widedeep.losses import ZILNLoss
from pytorch_widedeep.metrics import Metric, MultipleMetrics
from pytorch_widedeep.metrics import Metric
from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.callbacks import (
History,
Callback,
MetricCallback,
CallbackContainer,
LRShedulerCallback,
)
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.dataloaders import DataLoaderDefault
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
from pytorch_widedeep.initializers import Initializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training._base_trainers import BaseTrainer
from pytorch_widedeep.training._trainer_utils import (
alias_to_loss,
save_epoch_logs,
wd_train_val_split,
print_loss_and_metric,
)
from pytorch_widedeep.models.tabular.tabnet._utils import create_explain_matrix
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod
from pytorch_widedeep.training._multiple_lr_scheduler import (
MultipleLRScheduler,
)
class Trainer:
class Trainer(BaseTrainer):
r"""Class to set the of attributes that will be used during the
training process.
@@ -221,34 +205,20 @@ class Trainer:
seed: int = 1,
**kwargs,
):
self._check_inputs(
model, objective, optimizers, lr_schedulers, custom_loss_function
super().__init__(
model=model,
objective=objective,
custom_loss_function=custom_loss_function,
optimizers=optimizers,
lr_schedulers=lr_schedulers,
initializers=initializers,
transforms=transforms,
callbacks=callbacks,
metrics=metrics,
verbose=verbose,
seed=seed,
**kwargs,
)
self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)
# initialize early_stop. If EarlyStopping Callback is used it will
# take care of it
self.early_stop = False
self.verbose = verbose
self.seed = seed
self.model = model
if self.model.is_tabnet:
self.lambda_sparse = kwargs.get("lambda_sparse", 1e-3)
self.reducing_matrix = create_explain_matrix(self.model)
self.model.to(self.device)
self.model.wd_device = self.device
self.objective = objective
self.method = _ObjectiveToMethod.get(objective)
self._initialize(initializers)
self.loss_fn = self._set_loss_fn(objective, custom_loss_function, **kwargs)
self.optimizer = self._set_optimizer(optimizers)
self.lr_scheduler = self._set_lr_scheduler(lr_schedulers, **kwargs)
self.transforms = self._set_transforms(transforms)
self._set_callbacks_and_metrics(callbacks, metrics)
@Alias("finetune", "warmup")
def fit( # noqa: C901
@@ -695,64 +665,6 @@ class Trainer:
if self.method == "multiclass":
return np.vstack(preds_l)
def get_embeddings(
self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
) -> Dict[str, np.ndarray]: # pragma: no cover
r"""Returns the learned embeddings for the categorical features passed through
``deeptabular``.
.. note:: This function will be deprecated in the next release. Please consider
using ``Tab2Vec`` instead.
This method is designed to take an encoding dictionary in the same
format as that of the :obj:`LabelEncoder` Attribute in the class
:obj:`TabPreprocessor`. See
:class:`pytorch_widedeep.preprocessing.TabPreprocessor` and
:class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.
Parameters
----------
col_name: str,
Column name of the feature we want to get the embeddings for
cat_encoding_dict: Dict
Dictionary where the keys are the name of the column for which we
want to retrieve the embeddings and the values are also of type
Dict. These Dict values have keys that are the categories for that
column and the values are the corresponding numerical encodings
e.g.: {'column': {'cat_0': 1, 'cat_1': 2, ...}}
Examples
--------
For a series of comprehensive examples, please see the `Examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
folder in the repo
For completeness, here we include a `"fabricated"` example, i.e.
assuming we have already trained the model, that we have the
categorical encodings in a dictionary named ``encoding_dict``, and that
there is a column called `'education'`:
.. code-block:: python
trainer.get_embeddings(col_name="education", cat_encoding_dict=encoding_dict)
"""
warnings.warn(
"'get_embeddings' will be deprecated in the next release. "
"Please consider using 'Tab2vec' instead",
DeprecationWarning,
)
for n, p in self.model.named_parameters():
if "embed_layers" in n and col_name in n:
embed_mtx = p.cpu().data.numpy()
encoding_dict = cat_encoding_dict[col_name]
inv_encoding_dict = {v: k for k, v in encoding_dict.items()}
cat_embed_dict = {}
for idx, value in inv_encoding_dict.items():
cat_embed_dict[value] = embed_mtx[idx]
return cat_embed_dict
def explain(self, X_tab: np.ndarray, save_step_masks: bool = False):
"""
if the ``deeptabular`` component is a :obj:`Tabnet` model, returns the
@@ -876,35 +788,6 @@ class Trainer:
with open(save_dir / "feature_importance.json", "w") as fi:
json.dump(self.feature_importance, fi)
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
if already_restored:
pass
else:
for callback in self.callback_container.callbacks:
if callback.__class__.__name__ == "ModelCheckpoint":
if callback.save_best_only:
if self.verbose:
print(
f"Model weights restored to best epoch: {callback.best_epoch + 1}"
)
self.model.load_state_dict(callback.best_state_dict)
else:
if self.verbose:
print(
"Model weights after training corresponds to the those of the "
"final epoch which might not be the best performing weights. Use"
"the 'ModelCheckpoint' Callback to restore the best epoch weights."
)
@Alias("n_epochs", ["finetune_epochs", "warmup_epochs"])
@Alias("max_lr", ["finetune_max_lr", "warmup_max_lr"])
def _finetune(
@@ -1287,195 +1170,3 @@ class Trainer:
finetune_args[k] = v
return lds_args, dataloader_args, finetune_args
def _initialize(self, initializers):
if initializers is not None:
if isinstance(initializers, Dict):
self.initializer = MultipleInitializer(
initializers, verbose=self.verbose
)
self.initializer.apply(self.model)
elif isinstance(initializers, type):
self.initializer = initializers()
self.initializer(self.model)
elif isinstance(initializers, Initializer):
self.initializer = initializers
self.initializer(self.model)
def _set_loss_fn(self, objective, custom_loss_function, **kwargs):
class_weight = (
torch.tensor(kwargs["class_weight"]).to(self.device)
if "class_weight" in kwargs
else None
)
if custom_loss_function is not None:
return custom_loss_function
elif (
self.method not in ["regression", "qregression"]
and "focal_loss" not in objective
):
return alias_to_loss(objective, weight=class_weight)
elif "focal_loss" in objective:
alpha = kwargs.get("alpha", 0.25)
gamma = kwargs.get("gamma", 2.0)
return alias_to_loss(objective, alpha=alpha, gamma=gamma)
else:
return alias_to_loss(objective)
def _set_optimizer(self, optimizers):
if optimizers is not None:
if isinstance(optimizers, Optimizer):
optimizer: Union[Optimizer, MultipleOptimizer] = optimizers
elif isinstance(optimizers, Dict):
opt_names = list(optimizers.keys())
mod_names = [n for n, c in self.model.named_children()]
# if with_fds - the prediction layer is part of the model and
# should be optimized with the rest of deeptabular
# component/model
if self.model.with_fds:
mod_names.remove("enf_pos")
mod_names.remove("fds_layer")
optimizers["deeptabular"].add_param_group(
{"params": self.model.fds_layer.pred_layer.parameters()}
)
for mn in mod_names:
assert mn in opt_names, "No optimizer found for {}".format(mn)
optimizer = MultipleOptimizer(optimizers)
else:
optimizer = torch.optim.Adam(self.model.parameters()) # type: ignore
return optimizer
def _set_reduce_on_plateau_criterion(
self, lr_schedulers, reducelronplateau_criterion
):
self.reducelronplateau = False
if isinstance(lr_schedulers, Dict):
for _, scheduler in lr_schedulers.items():
if isinstance(scheduler, ReduceLROnPlateau):
self.reducelronplateau = True
elif isinstance(lr_schedulers, ReduceLROnPlateau):
self.reducelronplateau = True
if self.reducelronplateau and not reducelronplateau_criterion:
UserWarning(
"The learning rate scheduler of at least one of the model components is of type "
"ReduceLROnPlateau. The step method in this scheduler requires a 'metrics' param "
"that can be either the validation loss or the validation metric. Please, when "
"instantiating the Trainer, specify which quantity will be tracked using "
"reducelronplateau_criterion = 'loss' (default) or reducelronplateau_criterion = 'metric'"
)
self.reducelronplateau_criterion = "loss"
else:
self.reducelronplateau_criterion = reducelronplateau_criterion
def _set_lr_scheduler(self, lr_schedulers, **kwargs):
# ReduceLROnPlateau is special
reducelronplateau_criterion = kwargs.get("reducelronplateau_criterion", None)
self._set_reduce_on_plateau_criterion(
lr_schedulers, reducelronplateau_criterion
)
if lr_schedulers is not None:
if isinstance(lr_schedulers, LRScheduler) or isinstance(
lr_schedulers, ReduceLROnPlateau
):
lr_scheduler = lr_schedulers
cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
else:
lr_scheduler = MultipleLRScheduler(lr_schedulers)
scheduler_names = [
sc.__class__.__name__.lower()
for _, sc in lr_scheduler._schedulers.items()
]
cyclic_lr = any(["cycl" in sn for sn in scheduler_names])
else:
lr_scheduler, cyclic_lr = None, False
self.cyclic_lr = cyclic_lr
return lr_scheduler
@staticmethod
def _set_transforms(transforms):
if transforms is not None:
return MultipleTransforms(transforms)()
else:
return None
def _set_callbacks_and_metrics(self, callbacks, metrics):
self.callbacks: List = [History(), LRShedulerCallback()]
if callbacks is not None:
for callback in callbacks:
if isinstance(callback, type):
callback = callback()
self.callbacks.append(callback)
if metrics is not None:
self.metric = MultipleMetrics(metrics)
self.callbacks += [MetricCallback(self.metric)]
else:
self.metric = None
self.callback_container = CallbackContainer(self.callbacks)
self.callback_container.set_model(self.model)
self.callback_container.set_trainer(self)
@staticmethod
def _check_inputs(
model,
objective,
optimizers,
lr_schedulers,
custom_loss_function,
):
if model.with_fds and _ObjectiveToMethod.get(objective) != "regression":
raise ValueError(
"Feature Distribution Smooting can be used only for regression"
)
if _ObjectiveToMethod.get(objective) == "multiclass" and model.pred_dim == 1:
raise ValueError(
"This is a multiclass classification problem but the size of the output layer"
" is set to 1. Please, set the 'pred_dim' param equal to the number of classes "
" when instantiating the 'WideDeep' class"
)
if isinstance(optimizers, Dict):
if lr_schedulers is not None and not isinstance(lr_schedulers, Dict):
raise ValueError(
"''optimizers' and 'lr_schedulers' must have consistent type: "
"(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]) "
"Please, read the documentation or see the examples for more details"
)
if custom_loss_function is not None and objective not in [
"binary",
"multiclass",
"regression",
]:
raise ValueError(
"If 'custom_loss_function' is not None, 'objective' must be 'binary' "
"'multiclass' or 'regression', consistent with the loss function"
)
@staticmethod
def _set_device_and_num_workers(**kwargs):
# Important note for Mac users: Since python 3.8, the multiprocessing
# library start method changed from 'fork' to 'spawn'. This affects the
# data-loaders, which will not run in parallel.
default_num_workers = (
0
if sys.platform == "darwin" and sys.version_info.minor > 7
else os.cpu_count()
)
default_device = "cuda" if torch.cuda.is_available() else "cpu"
device = kwargs.get("device", default_device)
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers
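The deprecation note in the removed `get_embeddings` method above points to `Tab2Vec` as the replacement. A minimal sketch of that route, assuming a trained `WideDeep` model and the fitted `TabPreprocessor` used to build `X_tab` (the `Tab2Vec` signature shown is an assumption, not part of this diff):

```python
from pytorch_widedeep.tab2vec import Tab2Vec

# `model` and `tab_preprocessor` are assumed to be a trained WideDeep instance
# and the fitted TabPreprocessor used to preprocess the tabular data.
t2v = Tab2Vec(model=model, tab_preprocessor=tab_preprocessor)
# per-row vectors: learned categorical embeddings + (scaled) continuous columns
X_vec = t2v.transform(df)
```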
@@ -347,11 +347,6 @@ def test_save_load_and_predict():
assert preds.shape[0] == X_tab.shape[0]
###############################################################################
# test get_embeddings DeprecationWarning
###############################################################################
def create_test_dataset(input_type, input_type_2=None):
df = pd.DataFrame()
col1 = list(np.random.choice(input_type, 32))
@@ -371,40 +366,6 @@ df["col4"] = np.round(np.random.rand(32), 3)
df["target"] = np.random.choice(2, 32)
def test_get_embeddings_deprecation_warning():
embed_cols = [("col1", 5), ("col2", 5)]
continuous_cols = ["col3", "col4"]
tab_preprocessor = TabPreprocessor(
cat_embed_cols=embed_cols, continuous_cols=continuous_cols
)
X_tab = tab_preprocessor.fit_transform(df)
target = df.target.values
tabmlp = TabMlp(
mlp_hidden_dims=[32, 16],
mlp_dropout=[0.5, 0.5],
column_idx={k: v for v, k in enumerate(df.columns)},
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
)
model = WideDeep(deeptabular=tabmlp)
trainer = Trainer(model, objective="binary", verbose=0)
trainer.fit(
X_tab=X_tab,
target=target,
batch_size=16,
)
with pytest.warns(DeprecationWarning):
trainer.get_embeddings(
col_name="col1",
cat_encoding_dict=tab_preprocessor.label_encoder.encoding_dict,
)
###############################################################################
# test test_handle_columns_with_dots
###############################################################################