callbacks.py 21.7 KB
Newer Older
J
jrzaurin 已提交
1
"""
2 3 4
Code here is mostly based on the code from the torchsample and Keras packages

CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
J
jrzaurin 已提交
5
"""
6 7 8
import copy
import datetime
import os
import warnings

import numpy as np
import torch

from pytorch_widedeep.wdtypes import *  # noqa: F403
14

15 16 17 18

def _get_current_time():
    """Return the current local time formatted as month name, day, year and
    12-hour clock time (``"%B %d, %Y - %I:%M%p"``)."""
    timestamp_format = "%B %d, %Y - %I:%M%p"
    now = datetime.datetime.now()
    return now.strftime(timestamp_format)

19

20 21 22 23
class CallbackContainer(object):
    """
    Container holding a list of callbacks.

    Every ``set_*`` / ``on_*`` call made on the container is forwarded, in
    order, to each callback it holds.

    Parameters
    ----------
    callbacks: List, Optional. default=None
        Callbacks to manage. Each item may be a callback instance or a
        callback class; classes are instantiated with no arguments.
    queue_length: int, default=10
        Stored as ``self.queue_length``. It is not used by the container
        itself.
    """

    def __init__(self, callbacks: Optional[List] = None, queue_length: int = 10):
        candidates = [] if callbacks is None else callbacks
        # Accept both ``MyCallback`` and ``MyCallback()``.
        self.callbacks = [
            callback() if isinstance(callback, type) else callback
            for callback in candidates
        ]
        self.queue_length = queue_length

    def set_params(self, params):
        """Forward ``params`` to every callback."""
        for callback in self.callbacks:
            callback.set_params(params)

    def set_model(self, model: Any):
        """Store ``model`` on the container and forward it to every callback."""
        self.model = model
        for callback in self.callbacks:
            callback.set_model(model)

    def set_trainer(self, trainer: Any):
        """Store ``trainer`` on the container and forward it to every callback."""
        self.trainer = trainer
        for callback in self.callbacks:
            callback.set_trainer(trainer)

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        """Call ``on_epoch_begin`` on every callback (``logs`` defaults to ``{}``)."""
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        """Call ``on_epoch_end`` on every callback (``logs`` defaults to ``{}``)."""
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)

    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        """Call ``on_batch_begin`` on every callback (``logs`` defaults to ``{}``)."""
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_begin(batch, logs)

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        """Call ``on_batch_end`` on every callback (``logs`` defaults to ``{}``)."""
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_end(batch, logs)

    def on_train_begin(self, logs: Optional[Dict] = None):
        """Add a ``"start_time"`` entry to ``logs`` and call ``on_train_begin``
        on every callback."""
        logs = logs or {}
        logs["start_time"] = _get_current_time()
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs: Optional[Dict] = None):
        """Call ``on_train_end`` on every callback (``logs`` defaults to ``{}``)."""
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_end(logs)


class Callback(object):
    """
    Base class used to build new callbacks.

    The ``set_*`` methods store their argument on the instance
    (``self.params``, ``self.model``, ``self.trainer``). Every ``on_*`` hook
    is a no-op here; subclasses override the hooks they need.
    """

    def __init__(self):
        pass

    def set_params(self, params):
        """Store ``params`` as ``self.params``."""
        self.params = params

    def set_model(self, model: Any):
        """Store ``model`` as ``self.model``."""
        self.model = model

    def set_trainer(self, trainer: Any):
        """Store ``trainer`` as ``self.trainer``."""
        self.trainer = trainer

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        """Hook for the start of an epoch. Does nothing by default."""
        pass

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        """Hook for the end of an epoch. Does nothing by default."""
        pass

    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        """Hook for the start of a batch. Does nothing by default."""
        pass

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        """Hook for the end of a batch. Does nothing by default."""
        pass

    def on_train_begin(self, logs: Optional[Dict] = None):
        """Hook for the start of training. Does nothing by default."""
        pass

    def on_train_end(self, logs: Optional[Dict] = None):
        """Hook for the end of training. Does nothing by default."""
        pass


class History(Callback):
    r"""Callback that records metrics to a ``history`` attribute.

    The record is stored on the trainer (``trainer.history``) as a dict that
    maps each key found in the epoch ``logs`` to the list of values seen so
    far, one per epoch.

    This callback runs by default within :obj:`Trainer`, therefore, it should
    not be passed to the ``Trainer``. It is included here just for
    completeness.
    """

    def on_train_begin(self, logs: Optional[Dict] = None):
        # Start every training run with an empty record.
        self.trainer.history = {}

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for k, v in logs.items():
            self.trainer.history.setdefault(k, []).append(v)
135 136


137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
class LRShedulerCallback(Callback):
    r"""Callback that steps the trainer's learning rate scheduler(s).

    Schedulers whose class name contains ``"cycl"`` (or, for a single
    scheduler, when ``trainer.cyclic_lr`` is set) are stepped at the end of
    every batch; all the others are stepped at the end of every epoch.

    This callback runs by default within :obj:`Trainer`, therefore, should not
    be passed to the ``Trainer``. Is included here just for completion.
    """

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        lr_scheduler = self.trainer.lr_scheduler
        if lr_scheduler is None:
            return
        if self._multiple_scheduler():
            # One scheduler per model component: step only the cyclic ones.
            for name, component_scheduler in lr_scheduler._schedulers.items():
                if self._is_cyclic(name):
                    component_scheduler.step()
        elif self.trainer.cyclic_lr:
            lr_scheduler.step()

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        lr_scheduler = self.trainer.lr_scheduler
        if lr_scheduler is None:
            return
        if self._multiple_scheduler():
            # One scheduler per model component: step only the non-cyclic ones.
            for name, component_scheduler in lr_scheduler._schedulers.items():
                if not self._is_cyclic(name):
                    component_scheduler.step()
        elif not self.trainer.cyclic_lr:
            lr_scheduler.step()

    def _multiple_scheduler(self):
        # Compared by class name, so the scheduler class need not be imported.
        scheduler_class_name = self.trainer.lr_scheduler.__class__.__name__
        return scheduler_class_name == "MultipleLRScheduler"

    def _is_cyclic(self, model_name: str):
        if not self._has_scheduler(model_name):
            return False
        component_scheduler = self.trainer.lr_scheduler._schedulers[model_name]
        return "cycl" in component_scheduler.__class__.__name__.lower()

    def _has_scheduler(self, model_name: str):
        registered = self.trainer.lr_scheduler._schedulers
        return model_name in registered


184
class LRHistory(Callback):
    def __init__(self, n_epochs: int):
        r"""Saves the learning rates during training to a ``lr_history`` attribute.

        The record is stored on the trainer (``trainer.lr_history``) as a
        dict with one list per optimizer parameter group. Keys are
        ``lr_{group_idx}`` for a single optimizer and
        ``lr_{model_name}_{group_idx}`` when there is one optimizer per model
        component. Nothing is recorded if the trainer has no learning rate
        scheduler.

        Callbacks are passed as input parameters to the ``Trainer`` class. See
        :class:`pytorch_widedeep.trainer.Trainer`

        Parameters
        ----------
        n_epochs: int
            number of epochs during training

        Examples
        --------
        >>> from pytorch_widedeep.callbacks import LRHistory
        >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
        >>> from pytorch_widedeep.training import Trainer
        >>>
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> model = WideDeep(wide, deep)
        >>> trainer = Trainer(model, objective="regression", callbacks=[LRHistory(n_epochs=10)])
        """
        super(LRHistory, self).__init__()
        self.n_epochs = n_epochs

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        # Only the very first epoch records here: it resets the history and
        # stores the initial learning rates.
        if epoch == 0 and self.trainer.lr_scheduler is not None:
            self.trainer.lr_history = {}
            if self._multiple_scheduler():
                self._save_group_lr_mulitple_scheduler(step_location="on_epoch_begin")
            else:
                self._save_group_lr(self.trainer.optimizer)

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        # Cyclic schedulers are recorded after every batch.
        if self.trainer.lr_scheduler is not None:
            if self._multiple_scheduler():
                self._save_group_lr_mulitple_scheduler(step_location="on_batch_end")
            elif self.trainer.cyclic_lr:
                self._save_group_lr(self.trainer.optimizer)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        # Non-cyclic schedulers are recorded after every epoch except the
        # last one (``epoch`` is compared against ``n_epochs - 1``).
        if epoch != (self.n_epochs - 1) and self.trainer.lr_scheduler is not None:
            if self._multiple_scheduler():
                self._save_group_lr_mulitple_scheduler(step_location="on_epoch_end")
            elif not self.trainer.cyclic_lr:
                self._save_group_lr(self.trainer.optimizer)

    def _save_group_lr_mulitple_scheduler(self, step_location: str):
        """Record the learning rates of the per-component optimizers that are
        due at ``step_location`` (``"on_epoch_begin"``, ``"on_batch_end"`` or
        ``"on_epoch_end"``). Any other value records nothing."""
        for model_name, opt in self.trainer.optimizer._optimizers.items():
            if step_location == "on_epoch_begin":
                due = True
            elif step_location == "on_batch_end":
                due = self._is_cyclic(model_name)
            elif step_location == "on_epoch_end":
                due = not self._is_cyclic(model_name)
            else:
                due = False
            if due:
                self._save_group_lr(opt, model_name)

    def _save_group_lr(self, opt: Optimizer, model_name: Optional[str] = None):
        """Append the current ``lr`` of every parameter group of ``opt`` to
        ``trainer.lr_history``."""
        for group_idx, group in enumerate(opt.param_groups):
            if model_name is not None:
                group_name = "_".join(["lr", model_name, str(group_idx)])
            else:
                group_name = "_".join(["lr", str(group_idx)])
            self.trainer.lr_history.setdefault(group_name, []).append(group["lr"])

    def _multiple_scheduler(self):
        return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"

    def _is_cyclic(self, model_name: str):
        # A component is cyclic when it has a scheduler whose class name
        # contains "cycl" (case-insensitive).
        return (
            self._has_scheduler(model_name)
            and "cycl"
            in self.trainer.lr_scheduler._schedulers[
                model_name
            ].__class__.__name__.lower()
        )

    def _has_scheduler(self, model_name: str):
        return model_name in self.trainer.lr_scheduler._schedulers
267 268


269
class ModelCheckpoint(Callback):
    def __init__(
        self,
        filepath: str,
        monitor: str = "val_loss",
        verbose: int = 0,
        save_best_only: bool = False,
        mode: str = "auto",
        period: int = 1,
        max_save: int = -1,
    ):
        r"""Saves the model after every epoch.

        This class is almost identical to the corresponding keras class.
        Therefore, **credit** to the Keras Team.

        Callbacks are passed as input parameters to the ``Trainer`` class. See
        :class:`pytorch_widedeep.trainer.Trainer`

        Parameters
        ----------
        filepath: str
            Full path to save the output weights. It must contain only the root of
            the filenames. Epoch number and ``.p`` extension will be added.
            e.g. ``filepath="path/to/output_weights/weights_out"`` And the
            saved files in that directory will be named: ``weights_out_1.p,
            weights_out_2.p, ...``. The directory is created if it does not
            exist.
        monitor: str, default="val_loss"
            quantity to monitor. :obj:`ModelCheckpoint` will infer if this is a
            loss (i.e. contains the str `'loss'`) or a metric (i.e. contains the
            str `'acc'` or starts with `'fmeasure'`).
        verbose:int, default=0,
            verbosity mode
        save_best_only: bool, default=False,
            If ``True``, a checkpoint is written only when the monitored
            quantity improves on the best value seen so far.
        mode: str, default="auto",
            If ``save_best_only=True``, the decision to overwrite the current save
            file is made based on either the maximization or the minimization of
            the monitored quantity. For `'val_acc'`, this should be `'max'`, for
            `'val_loss'` this should be `'min'`, etc. In `'auto'` mode, the
            direction is automatically inferred from the name of the monitored
            quantity.
        period: int, default=1,
            Interval (number of epochs) between checkpoints.
        max_save: int, default=-1
            Maximum number of outputs to keep. Once reached, the oldest file
            written by this callback is removed before a new one is tracked.
            If -1 will save all outputs

        Raises
        ------
        ValueError
            If ``filepath`` has no directory component.

        Examples
        --------
        >>> from pytorch_widedeep.callbacks import ModelCheckpoint
        >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
        >>> from pytorch_widedeep.training import Trainer
        >>>
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> model = WideDeep(wide, deep)
        >>> trainer = Trainer(model, objective="regression", callbacks=[ModelCheckpoint(filepath='checkpoints/weights_out')])
        """
        super(ModelCheckpoint, self).__init__()

        self.filepath = filepath
        self.monitor = monitor
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.mode = mode
        self.period = period
        self.max_save = max_save

        self.epochs_since_last_save = 0

        # os.path handles the platform separator and root-level paths such as
        # "/weights_out", which a manual split on "/" does not.
        root_dir = os.path.dirname(self.filepath)
        if not root_dir:
            raise ValueError(
                "'filepath' must be the full path to save the output weights,"
                " including the root of the filenames. e.g. 'checkpoints/weights_out'"
            )
        os.makedirs(root_dir, exist_ok=True)

        # Paths of the checkpoints written so far, oldest first. Only
        # consulted when ``max_save > 0``.
        self.old_files: List[str] = []

        if self.mode not in ["auto", "min", "max"]:
            warnings.warn(
                "ModelCheckpoint mode %s is unknown, "
                "fallback to auto mode." % (self.mode),
                RuntimeWarning,
            )
            self.mode = "auto"

        if self.mode == "min":
            maximise = False
        elif self.mode == "max":
            maximise = True
        else:
            # 'auto': accuracy-like and fmeasure metrics are maximised,
            # anything else is treated as a loss.
            maximise = "acc" in self.monitor or self.monitor.startswith("fmeasure")

        # np.inf rather than np.Inf: the capitalised alias was removed in
        # NumPy 2.0.
        if maximise:
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            self.monitor_op = np.less
            self.best = np.inf

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        """Write a checkpoint if ``period`` epochs have passed since the last
        one and, when ``save_best_only`` is set, the monitored value improved."""
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        filepath = "{}_{}.p".format(self.filepath, epoch + 1)

        if not self.save_best_only:
            if self.verbose > 0:
                print("\nEpoch %05d: saving model to %s" % (epoch + 1, filepath))
            self._save_checkpoint(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                "Can save best model only with %s available, "
                "skipping." % (self.monitor),
                RuntimeWarning,
            )
            return

        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                print(
                    "\nEpoch %05d: %s improved from %0.5f to %0.5f,"
                    " saving model to %s"
                    % (
                        epoch + 1,
                        self.monitor,
                        self.best,
                        current,
                        filepath,
                    )
                )
            self.best = current
            self._save_checkpoint(filepath)
        elif self.verbose > 0:
            print(
                "\nEpoch %05d: %s did not improve from %0.5f"
                % (epoch + 1, self.monitor, self.best)
            )

    def _save_checkpoint(self, filepath: str):
        """Save the model's ``state_dict`` to ``filepath`` and, when
        ``max_save > 0``, drop the oldest tracked checkpoint once the limit
        has been reached."""
        torch.save(self.model.state_dict(), filepath)
        if self.max_save > 0:
            if len(self.old_files) == self.max_save:
                try:
                    os.remove(self.old_files[0])
                except FileNotFoundError:
                    # Best effort: the file may already have been removed by
                    # someone else.
                    pass
                self.old_files = self.old_files[1:]
            self.old_files.append(filepath)
432 433 434


class EarlyStopping(Callback):
    def __init__(
        self,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        patience: int = 10,
        verbose: int = 0,
        mode: str = "auto",
        baseline: Optional[float] = None,
        restore_best_weights: bool = False,
    ):
        r"""Stop training when a monitored quantity has stopped improving.

        This class is almost identical to the corresponding keras class.
        Therefore, **credit** to the Keras Team.

        Callbacks are passed as input parameters to the ``Trainer`` class. See
        :class:`pytorch_widedeep.trainer.Trainer`

        Parameters
        -----------
        monitor: str, default='val_loss'.
            Quantity to be monitored.
        min_delta: float, default=0.
            minimum change in the monitored quantity to qualify as an
            improvement, i.e. an absolute change of less than min_delta, will
            count as no improvement.
        patience: int, default=10.
            Number of epochs that produced the monitored quantity with no
            improvement after which training will be stopped.
        verbose: int.
            verbosity mode.
        mode: str, default='auto'
            one of {'`auto`', '`min`', '`max`'}. In `'min'` mode, training will
            stop when the quantity monitored has stopped decreasing; in `'max'`
            mode it will stop when the quantity monitored has stopped increasing;
            in `'auto'` mode, the direction is automatically inferred from the
            name of the monitored quantity: names containing `'acc'` or
            starting with `'fmeasure'` are maximised, anything else is
            minimised.
        baseline: float, Optional. default=None.
            Baseline value for the monitored quantity to reach. Training will
            stop if the model does not show improvement over the baseline.
        restore_best_weights: bool, default=False
            Whether to restore model weights from the epoch with the best
            value of the monitored quantity. If ``False``, the model weights
            obtained at the last step of training are used.

        Examples
        --------
        >>> from pytorch_widedeep.callbacks import EarlyStopping
        >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
        >>> from pytorch_widedeep.training import Trainer
        >>>
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> model = WideDeep(wide, deep)
        >>> trainer = Trainer(model, objective="regression", callbacks=[EarlyStopping(patience=10)])
        """
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.mode = mode
        self.baseline = baseline
        self.restore_best_weights = restore_best_weights

        self.wait = 0
        self.stopped_epoch = 0
        self.state_dict = None

        if self.mode not in ["auto", "min", "max"]:
            warnings.warn(
                "EarlyStopping mode %s is unknown, "
                "fallback to auto mode." % self.mode,
                RuntimeWarning,
            )
            self.mode = "auto"

        if self.mode == "min":
            self.monitor_op = np.less
        elif self.mode == "max":
            self.monitor_op = np.greater
        elif "acc" in self.monitor or self.monitor.startswith("fmeasure"):
            # Same 'auto' rule as ModelCheckpoint: these metrics are maximised.
            self.monitor_op = np.greater
        else:
            self.monitor_op = np.less

        # When minimising, flip the sign so that ``current - min_delta`` in
        # ``on_epoch_end`` makes the improvement test stricter by min_delta.
        if self.monitor_op is np.less:
            self.min_delta *= -1

    def on_train_begin(self, logs: Optional[Dict] = None):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        # Drop any snapshot left over from a previous run.
        self.state_dict = None
        if self.baseline is not None:
            self.best = self.baseline
        else:
            # np.inf rather than np.Inf: the capitalised alias was removed in
            # NumPy 2.0.
            self.best = np.inf if self.monitor_op is np.less else -np.inf

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        current = self.get_monitor_value(logs)
        if current is None:
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                # ``state_dict()`` returns references to the live parameters,
                # which keep changing as training continues; snapshot them.
                self.state_dict = copy.deepcopy(self.model.state_dict())
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.trainer.early_stop = True
                # No snapshot exists if no epoch ever improved on the
                # baseline; in that case keep the current weights.
                if self.restore_best_weights and self.state_dict is not None:
                    if self.verbose > 0:
                        print("Restoring model weights from the end of the best epoch")
                    self.model.load_state_dict(self.state_dict)

    def on_train_end(self, logs: Optional[Dict] = None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

    def get_monitor_value(self, logs):
        """Return ``logs[self.monitor]``, or ``None`` (with a
        ``RuntimeWarning``) when the monitored quantity is not in ``logs``."""
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            warnings.warn(
                "Early stopping conditioned on metric `%s` "
                "which is not available. Available metrics are: %s"
                % (self.monitor, ",".join(list(logs.keys()))),
                RuntimeWarning,
            )
        return monitor_value