diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py
index ee5b81ddd615ffa23f75657575eea09b0bc8d1ac..af895c5a66936e525cce7e768f86d5ecd8338c0a 100644
--- a/pytorch_widedeep/callbacks.py
+++ b/pytorch_widedeep/callbacks.py
@@ -9,6 +9,7 @@ import warnings
 import numpy as np
 import torch
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 from pytorch_widedeep.wdtypes import *  # noqa: F403
 
 
@@ -17,6 +18,14 @@ def _get_current_time():
     return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p")
 
 
+def _is_metric(monitor: str):
+    # We assume no one will use f3 or more
+    if any([s in monitor for s in ["acc", "prec", "rec", "fscore", "f1", "f2"]]):
+        return True
+    else:
+        return False
+
+
 class CallbackContainer(object):
     """
     Container holding a list of callbacks.
@@ -52,10 +61,12 @@ class CallbackContainer(object):
         for callback in self.callbacks:
             callback.on_epoch_begin(epoch, logs)
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         logs = logs or {}
         for callback in self.callbacks:
-            callback.on_epoch_end(epoch, logs)
+            callback.on_epoch_end(epoch, logs, metric)
 
     def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
         logs = logs or {}
@@ -73,6 +84,12 @@ class CallbackContainer(object):
         for callback in self.callbacks:
             callback.on_train_begin(logs)
 
+    def on_eval_begin(self, logs: Optional[Dict] = None):
+        # at the moment only used to reset metrics before eval
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_eval_begin(logs)
+
     def on_train_end(self, logs: Optional[Dict] = None):
         logs = logs or {}
         # logs['final_loss'] = self.model.history.epoch_losses[-1],
@@ -102,7 +119,9 @@ class Callback(object):
     def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
         pass
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         pass
 
     def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
@@ -114,6 +133,10 @@ class Callback(object):
     def on_train_begin(self, logs: Optional[Dict] = None):
         pass
 
+    def on_eval_begin(self, logs: Optional[Dict] = None):
+        # at the moment only used to reset metrics before eval
+        pass
+
     def on_train_end(self, logs: Optional[Dict] = None):
         pass
 
@@ -128,7 +151,9 @@ class History(Callback):
     def on_train_begin(self, logs: Optional[Dict] = None):
         self.trainer.history = {}
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         logs = logs or {}
         for k, v in logs.items():
             self.trainer.history.setdefault(k, []).append(v)
@@ -153,7 +178,9 @@ class LRShedulerCallback(Callback):
         elif self.trainer.cyclic_lr:
             self.trainer.lr_scheduler.step()
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         if self.trainer.lr_scheduler is not None:
             if self._multiple_scheduler():
                 for (
@@ -161,9 +188,15 @@
                     model_name,
                     scheduler,
                 ) in self.trainer.lr_scheduler._schedulers.items():
                     if not self._is_cyclic(model_name):
-                        scheduler.step()
+                        if isinstance(scheduler, ReduceLROnPlateau):
+                            scheduler.step(metric)
+                        else:
+                            scheduler.step()
             elif not self.trainer.cyclic_lr:
-                self.trainer.lr_scheduler.step()
+                if isinstance(self.trainer.lr_scheduler, ReduceLROnPlateau):
+                    self.trainer.lr_scheduler.step(metric)
+                else:
+                    self.trainer.lr_scheduler.step()
 
     def _multiple_scheduler(self):
         return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
@@ -224,7 +257,9 @@ class LRHistory(Callback):
         elif self.trainer.cyclic_lr:
             self._save_group_lr(self.trainer.optimizer)
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         if epoch != (self.n_epochs - 1) and self.trainer.lr_scheduler is not None:
             if self._multiple_scheduler():
                 self._save_group_lr_mulitple_scheduler(step_location="on_epoch_end")
@@ -293,10 +328,8 @@ class ModelCheckpoint(Callback):
         be added. e.g. ``filepath="path/to/output_weights/weights_out"`` And
         the saved files in that directory will be named: ``weights_out_1.pt,
         weights_out_2.pt, ...``
-    monitor: str, default="val_loss"
-        quantity to monitor. :obj:`ModelCheckpoint` will infer if this is a
-        loss (i.e. contains the str `'loss'`) or a metric (i.e. contains the
-        str `'acc'` or starts with `'fmeasure'`).
+    monitor: str, default="loss"
+        quantity to monitor. Typically 'val_loss' or metric name (e.g. 'val_acc')
     verbose:int, default=0,
         verbosity mode
     save_best_only: bool, default=False,
@@ -305,8 +338,8 @@
     mode: str, default="auto",
         If ``save_best_only=True``, the decision to overwrite the current save
         file is made based on either the maximization or the minimization of
-        the monitored quantity. For `'val_acc'`, this should be `'max'`, for
-        `'val_loss'` this should be `'min'`, etc. In `'auto'` mode, the
+        the monitored quantity. For `'acc'`, this should be `'max'`, for
+        `'loss'` this should be `'min'`, etc. In `'auto'` mode, the
         direction is automatically inferred from the name of the monitored
         quantity.
     period: int, default=1,
@@ -366,14 +399,16 @@
             self.monitor_op = np.greater
             self.best = -np.Inf
         else:
-            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
+            if _is_metric(self.monitor):
                 self.monitor_op = np.greater
                 self.best = -np.Inf
             else:
                 self.monitor_op = np.less
                 self.best = np.Inf
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):  # noqa: C901
+    def on_epoch_end(  # noqa: C901
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         logs = logs or {}
         self.epochs_since_last_save += 1
         if self.epochs_since_last_save >= self.period:
@@ -453,7 +488,7 @@ class EarlyStopping(Callback):
     Parameters
     -----------
     monitor: str, default='val_loss'.
-        Quantity to be monitored.
+        Quantity to monitor. Typically 'val_loss' or metric name (e.g. 'val_acc')
     min_delta: float, default=0.
         minimum change in the monitored quantity to qualify as an
         improvement, i.e. an absolute change of less than min_delta, will
@@ -517,7 +552,7 @@
         elif self.mode == "max":
             self.monitor_op = np.greater
         else:
-            if "acc" in self.monitor:
+            if _is_metric(self.monitor):
                 self.monitor_op = np.greater
             else:
                 self.monitor_op = np.less
@@ -536,7 +571,9 @@
         else:
             self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         current = self.get_monitor_value(logs)
         if current is None:
             return
diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py
index e98fe957a1aa3f9ee5542eaf9406eec5eb0c0385..282a61d4e484c245620dda320a8fed780d269469 100644
--- a/pytorch_widedeep/metrics.py
+++ b/pytorch_widedeep/metrics.py
@@ -46,6 +46,9 @@ class MetricCallback(Callback):
     def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
         self.container.reset()
 
+    def on_eval_begin(self, logs: Optional[Dict] = None):
+        self.container.reset()
+
 
 class Accuracy(Metric):
     def __init__(self, top_k: int = 1):
diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py
index c913675fc90460b0b3fe9f7ecc8bca17893d6dbc..9a18c15ac2818aef50626e0ee325fd80f7ed5e18 100644
--- a/pytorch_widedeep/training/trainer.py
+++ b/pytorch_widedeep/training/trainer.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from tqdm import trange
 from scipy.sparse import csc_matrix
 from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics
 from pytorch_widedeep.wdtypes import *  # noqa: F403
@@ -54,6 +55,7 @@ class Trainer:
         custom_loss_function: Optional[Module] = None,
         optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
         lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
+        reduce_on: Optional[str] = "loss",
         initializers: Optional[Union[Initializer, Dict[str, Initializer]]] = None,
         transforms: Optional[List[Transforms]] = None,
         callbacks: Optional[List[Callback]] = None,
@@ -256,6 +258,15 @@
                 "'multiclass' or 'regression', consistent with the loss function"
             )
 
+        self.reducelronplateau = False
+        self.reduce_on = reduce_on
+        if isinstance(lr_schedulers, Dict):
+            for _, scheduler in lr_schedulers.items():
+                if isinstance(scheduler, ReduceLROnPlateau):
+                    self.reducelronplateau = True
+        elif isinstance(lr_schedulers, ReduceLROnPlateau):
+            self.reducelronplateau = True
+
         if isinstance(model, str):
             self.model = torch.load(model)
         else:
@@ -309,7 +320,6 @@
         n_epochs: int = 1,
         validation_freq: int = 1,
         batch_size: int = 32,
-        patience: int = 10,
         finetune: bool = False,
         finetune_epochs: int = 5,
         finetune_max_lr: float = 0.01,
@@ -362,9 +372,7 @@
         validation_freq: int, default=1
            epochs validation frequency
         batch_size: int, default=32
-        patience: int, default=10
-            Number of epochs without improving the target metric or loss
-            before the fit process stops
+            batch size
         finetune: bool, default=False
             param alias: ``warmup``
 
@@ -571,23 +579,32 @@
             with trange(train_steps, disable=self.verbose != 1) as t:
                 for batch_idx, (data, targett) in zip(t, train_loader):
                     t.set_description("epoch %i" % (epoch + 1))
-                    score, train_loss = self._train_step(data, targett, batch_idx)
-                    print_loss_and_metric(t, train_loss, score)
+                    train_score, train_loss = self._train_step(data, targett, batch_idx)
+                    print_loss_and_metric(t, train_loss, train_score)
                     self.callback_container.on_batch_end(batch=batch_idx)
-            epoch_logs = save_epoch_logs(epoch_logs, train_loss, score, "train")
+            epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")
 
+            on_epoch_end_metric = None
             if eval_set is not None and epoch % validation_freq == (
                 validation_freq - 1
             ):
+                self.callback_container.on_eval_begin()
                 self.valid_running_loss = 0.0
                 with trange(eval_steps, disable=self.verbose != 1) as v:
                     for i, (data, targett) in zip(v, eval_loader):
                         v.set_description("valid")
-                        score, val_loss = self._eval_step(data, targett, i)
-                        print_loss_and_metric(v, val_loss, score)
-                epoch_logs = save_epoch_logs(epoch_logs, val_loss, score, "val")
+                        val_score, val_loss = self._eval_step(data, targett, i)
+                        print_loss_and_metric(v, val_loss, val_score)
+                epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val")
+
+                if self.reducelronplateau:
+                    if self.reduce_on == "loss":
+                        on_epoch_end_metric = val_loss
+                    else:
+                        on_epoch_end_metric = val_score[self.reduce_on]
+
+            self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric)
 
-            self.callback_container.on_epoch_end(epoch, epoch_logs)
             if self.early_stop:
                 self.callback_container.on_train_end(epoch_logs)
                 break
@@ -1047,12 +1064,8 @@
         return preds_l
 
     def _set_loss_fn(self, objective, class_weight, custom_loss_function, alpha, gamma):
-        if isinstance(class_weight, float):
-            class_weight = torch.tensor([1.0 - class_weight, class_weight])
-        elif isinstance(class_weight, (tuple, list)):
+        if class_weight is not None:
             class_weight = torch.tensor(class_weight)
-        else:
-            class_weight = None
         if custom_loss_function is not None:
             return custom_loss_function
         elif self.method != "regression" and "focal_loss" not in objective:
@@ -1092,11 +1105,12 @@
 
     def _set_lr_scheduler(self, lr_schedulers):
         if lr_schedulers is not None:
-            if isinstance(lr_schedulers, LRScheduler):
-                lr_scheduler: Union[
-                    LRScheduler,
-                    MultipleLRScheduler,
-                ] = lr_schedulers
+            # ReduceLROnPlateau is special, only scheduler that is 'just' an
+            # object rather than a LRScheduler
+            if isinstance(lr_schedulers, LRScheduler) or isinstance(
+                lr_schedulers, ReduceLROnPlateau
+            ):
+                lr_scheduler = lr_schedulers
                 cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
             else:
                 lr_scheduler = MultipleLRScheduler(lr_schedulers)
diff --git a/pytorch_widedeep/training/trainer_utils.py b/pytorch_widedeep/training/trainer_utils.py
index 5308cad498370756091e4a64778f184319be9910..372217edc7b825792b18e8b54ecd4ce47dd13695 100644
--- a/pytorch_widedeep/training/trainer_utils.py
+++ b/pytorch_widedeep/training/trainer_utils.py
@@ -214,7 +214,7 @@ def alias_to_loss(loss_fn: str, **kwargs):
         "or loss functions: {}".format(", ".join(_ObjectiveToMethod.keys()))
     )
     if loss_fn in _LossAliases.get("binary"):
-        return nn.BCEWithLogitsLoss(weight=kwargs["weight"])
+        return nn.BCEWithLogitsLoss(pos_weight=kwargs["weight"])
     if loss_fn in _LossAliases.get("multiclass"):
         return nn.CrossEntropyLoss(weight=kwargs["weight"])
     if loss_fn in _LossAliases.get("regression"):
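
For review context, a minimal usage sketch of what this patch enables (not part of the diff). It assumes `model` is an already built `WideDeep` model and `X_tab`/`target` come from a `TabPreprocessor`; the optimizer settings, file path and hyperparameters are illustrative only.

```python
# Usage sketch: wiring ReduceLROnPlateau into the Trainer via the new
# `reduce_on` argument. `model`, `X_tab` and `target` are assumed to exist.
import torch

from pytorch_widedeep import Trainer
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# The Trainer now detects ReduceLROnPlateau and, after each validation epoch,
# calls scheduler.step(metric) with either the validation loss or the
# validation metric selected via `reduce_on`
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

trainer = Trainer(
    model,
    objective="binary",
    optimizers=optimizer,
    lr_schedulers=lr_scheduler,
    reduce_on="loss",  # or a metric name, e.g. "acc", to step on val_score["acc"]
    metrics=[Accuracy()],
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=5),
        # with _is_metric, "val_acc" is recognised as a metric -> max mode
        ModelCheckpoint(filepath="model_weights/wd_out", monitor="val_acc"),
    ],
)
trainer.fit(X_tab=X_tab, target=target, n_epochs=10, batch_size=256, val_split=0.2)
```

Note that when `reduce_on` is a metric name the monitored value is taken from the validation metrics dict (`val_score[self.reduce_on]`), so a validation set (via `eval_set` or `val_split`) is needed for the scheduler to step. Also, since `patience` has been removed from `fit`, early-stopping patience is now configured only on the `EarlyStopping` callback.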