"""
Code here is mostly based on the code from the torchsample and Keras packages

CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
"""
import os
import copy
import datetime
import warnings

import numpy as np
import torch

from pytorch_widedeep.wdtypes import *  # noqa: F403


def _get_current_time():
    return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p")


class CallbackContainer(object):
    """
    Container holding a list of callbacks.
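
    It instantiates any callback passed in as a class (rather than as an
    instance) and forwards the ``on_*`` events to every callback it holds. A
    minimal, illustrative sketch using the base :obj:`Callback`:

    >>> from pytorch_widedeep.callbacks import Callback, CallbackContainer
    >>> container = CallbackContainer(callbacks=[Callback, Callback()])
    >>> len(container.callbacks)
    2
    >>> container.on_epoch_begin(0)  # forwards the event to every callback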
    """

    def __init__(self, callbacks: Optional[List] = None, queue_length: int = 10):
        instantiated_callbacks = []
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    instantiated_callbacks.append(callback())
                else:
                    instantiated_callbacks.append(callback)
        self.callbacks = instantiated_callbacks
        self.queue_length = queue_length

    def set_params(self, params):
        for callback in self.callbacks:
            callback.set_params(params)

    def set_model(self, model: Any):
        self.model = model
        for callback in self.callbacks:
            callback.set_model(model)

    def set_trainer(self, trainer: Any):
        self.trainer = trainer
        for callback in self.callbacks:
            callback.set_trainer(trainer)

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)

    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_begin(batch, logs)

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_end(batch, logs)

    def on_train_begin(self, logs: Optional[Dict] = None):
        logs = logs or {}
        logs["start_time"] = _get_current_time()
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs: Optional[Dict] = None):
        logs = logs or {}
        # logs['final_loss'] = self.model.history.epoch_losses[-1],
        # logs['best_loss'] = min(self.model.history.epoch_losses),
        # logs['stop_time'] = _get_current_time()
        for callback in self.callbacks:
            callback.on_train_end(logs)


class Callback(object):
    """
    Base class used to build new callbacks.
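
    A minimal sketch of a custom callback (the ``LossAlarm`` class below is
    hypothetical and shown only for illustration): subclass :obj:`Callback` and
    override any of the ``on_*`` methods. The class (or an instance of it) can
    then be passed to the ``Trainer`` via its ``callbacks`` argument.

    >>> from pytorch_widedeep.callbacks import Callback
    >>>
    >>> class LossAlarm(Callback):
    ...     # print a warning when the loss reported in ``logs`` looks too high
    ...     def on_epoch_end(self, epoch, logs=None):
    ...         logs = logs or {}
    ...         if logs.get("train_loss", 0.0) > 10.0:
    ...             print("the training loss looks suspiciously high")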
    """

    def __init__(self):
        pass

    def set_params(self, params):
        self.params = params

    def set_model(self, model: Any):
        self.model = model

    def set_trainer(self, trainer: Any):
        self.trainer = trainer

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        pass

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        pass

    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        pass

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        pass

    def on_train_begin(self, logs: Optional[Dict] = None):
        pass

    def on_train_end(self, logs: Optional[Dict] = None):
        pass


class History(Callback):
    r"""Callback that records metrics to a ``history`` attribute.

    This callback runs by default within :obj:`Trainer`. Callbacks are passed
    as input parameters to the ``Trainer`` class. See
    :class:`pytorch_widedeep.trainer.Trainer`. Documentation is included here
    for completeness.
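
    A short, illustrative sketch of how the recorded history would typically be
    inspected (the ``trainer.fit(...)`` call and its arguments are assumed and
    not shown):

    >>> # trainer.fit(...)                # fills trainer.history epoch by epoch
    >>> # trainer.history["train_loss"]   # list with one value per epoch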
    """

    def on_train_begin(self, logs: Optional[Dict] = None):
        self.trainer.history = {}

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for k, v in logs.items():
            self.trainer.history.setdefault(k, []).append(v)


class LRHistory(Callback):
    def __init__(self, n_epochs: int):
        r"""Saves the learning rates during training to a ``lr_history`` attribute.

        Callbacks are passed as input parameters to the ``Trainer`` class. See
        :class:`pytorch_widedeep.trainer.Trainer`

        Parameters
        ----------
        n_epochs: int
            number of epochs during training

        Examples
        --------
        >>> from pytorch_widedeep.callbacks import LRHistory
        >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
        >>> from pytorch_widedeep.training import Trainer
        >>>
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> model = WideDeep(wide, deep)
        >>> trainer = Trainer(model, objective="regression", callbacks=[LRHistory(n_epochs=10)])
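        >>>
        >>> # a sketch of what gets recorded: after ``trainer.fit(...)`` is run together
        >>> # with a learning rate scheduler, ``trainer.lr_history`` maps group names
        >>> # such as "lr_0" to the list of learning rates used for that parameter group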
        """
        super(LRHistory, self).__init__()
        self.n_epochs = n_epochs

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        if epoch == 0 and self.trainer.lr_scheduler is not None:
            self.trainer.lr_history = {}
            if self._multiple_scheduler():
                self._save_group_lr_multiple_scheduler(step_location="on_epoch_begin")
            else:
                self._save_group_lr(self.trainer.optimizer)

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        if self.trainer.lr_scheduler is not None:
            if self._multiple_scheduler():
                self._save_group_lr_multiple_scheduler(step_location="on_batch_end")
            elif self.trainer.cyclic_lr:
                self._save_group_lr(self.trainer.optimizer)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        if epoch != (self.n_epochs - 1) and self.trainer.lr_scheduler is not None:
            if self._multiple_scheduler():
                self._save_group_lr_multiple_scheduler(step_location="on_epoch_end")
            elif not self.trainer.cyclic_lr:
                self._save_group_lr(self.trainer.optimizer)

    def _save_group_lr_multiple_scheduler(self, step_location: str):
        for model_name, opt in self.trainer.optimizer._optimizers.items():
            if step_location == "on_epoch_begin":
                self._save_group_lr(opt, model_name)
            if step_location == "on_batch_end":
                if self._is_cyclic(model_name):
                    self._save_group_lr(opt, model_name)
            if step_location == "on_epoch_end":
                if not self._is_cyclic(model_name):
                    self._save_group_lr(opt, model_name)

    def _save_group_lr(self, opt: Optimizer, model_name: Optional[str] = None):
        for group_idx, group in enumerate(opt.param_groups):
            if model_name is not None:
                group_name = ("_").join(["lr", model_name, str(group_idx)])
            else:
                group_name = ("_").join(["lr", str(group_idx)])
            self.trainer.lr_history.setdefault(group_name, []).append(group["lr"])

    def _multiple_scheduler(self):
        return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"

    def _is_cyclic(self, model_name: str):
        return (
            self._has_scheduler(model_name)
            and "cycl"
            in self.trainer.lr_scheduler._schedulers[
                model_name
            ].__class__.__name__.lower()
        )

    def _has_scheduler(self, model_name: str):
        return model_name in self.trainer.lr_scheduler._schedulers


class ModelCheckpoint(Callback):
    def __init__(
        self,
        filepath: str,
        monitor: str = "val_loss",
        verbose: int = 0,
        save_best_only: bool = False,
        mode: str = "auto",
        period: int = 1,
        max_save: int = -1,
    ):
        r"""Saves the model after every epoch.

        This class is almost identical to the corresponding keras class.
        Therefore, **credit** to the Keras Team.

        Callbacks are passed as input parameters to the ``Trainer`` class. See
        :class:`pytorch_widedeep.trainer.Trainer`

        Parameters
        ----------
        filepath: str
            Full path to save the output weights. It must contain only the root
            of the filenames; the epoch number and the ``.p`` extension are
            added automatically. e.g. with
            ``filepath="path/to/output_weights/weights_out"`` the saved files
            will be named ``weights_out_1.p``, ``weights_out_2.p``, ...
        monitor: str, default="val_loss"
            quantity to monitor. :obj:`ModelCheckpoint` will infer if this is a
            loss (i.e. contains the str `'loss'`) or a metric (i.e. contains the
            str `'acc'` or starts with `'fmeasure'`).
        verbose: int, default=0
            verbosity mode
        save_best_only: bool, default=False
            if ``True``, the latest best model according to the quantity
            monitored will not be overwritten.
        mode: str, default="auto"
            If ``save_best_only=True``, the decision to overwrite the current
            save file is made based on either the maximization or the
            minimization of the monitored quantity. For `'val_acc'` this should
            be `'max'`, for `'val_loss'` this should be `'min'`, etc. In
            `'auto'` mode, the direction is automatically inferred from the name
            of the monitored quantity.
        period: int, default=1
            Interval (number of epochs) between checkpoints.
        max_save: int, default=-1
            Maximum number of checkpoints to keep on disk. If -1, all
            checkpoints are kept.

        Examples
        --------
        >>> from pytorch_widedeep.callbacks import ModelCheckpoint
        >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
        >>> from pytorch_widedeep.training import Trainer
        >>>
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> model = WideDeep(wide, deep)
        >>> trainer = Trainer(model, objective="regression", callbacks=[ModelCheckpoint(filepath='checkpoints/weights_out')])
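        >>>
        >>> # a sketch of how the saved weights would be re-loaded afterwards: files
        >>> # follow the naming scheme described above (e.g. "checkpoints/weights_out_3.p"
        >>> # after the third epoch) and hold a state_dict, so they can be restored with
        >>> # ``model.load_state_dict(torch.load("checkpoints/weights_out_3.p"))``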
        """
        super(ModelCheckpoint, self).__init__()

        self.filepath = filepath
        self.monitor = monitor
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.mode = mode
        self.period = period
        self.max_save = max_save

        self.epochs_since_last_save = 0

        if len(self.filepath.split("/")[:-1]) == 0:
            raise ValueError(
                "'filepath' must be the full path to save the output weights,"
                " including the root of the filenames. e.g. 'checkpoints/weights_out'"
            )

        root_dir = ("/").join(self.filepath.split("/")[:-1])
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

        if self.max_save > 0:
            self.old_files: List[str] = []

        if self.mode not in ["auto", "min", "max"]:
            warnings.warn(
                "ModelCheckpoint mode %s is unknown, "
                "fallback to auto mode." % (self.mode),
                RuntimeWarning,
            )
            self.mode = "auto"
        if self.mode == "min":
            self.monitor_op = np.less
            self.best = np.inf
        elif self.mode == "max":
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):  # noqa: C901
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = "{}_{}.p".format(self.filepath, epoch + 1)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn(
                        "Can save best model only with %s available, "
                        "skipping." % (self.monitor),
                        RuntimeWarning,
                    )
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print(
                                "\nEpoch %05d: %s improved from %0.5f to %0.5f,"
                                " saving model to %s"
                                % (
                                    epoch + 1,
                                    self.monitor,
                                    self.best,
                                    current,
                                    filepath,
                                )
                            )
                        self.best = current
                        torch.save(self.model.state_dict(), filepath)
                        if self.max_save > 0:
                            if len(self.old_files) == self.max_save:
                                try:
                                    os.remove(self.old_files[0])
                                except FileNotFoundError:
                                    pass
                                self.old_files = self.old_files[1:]
                            self.old_files.append(filepath)
                    else:
                        if self.verbose > 0:
                            print(
                                "\nEpoch %05d: %s did not improve from %0.5f"
                                % (epoch + 1, self.monitor, self.best)
                            )
            else:
                if self.verbose > 0:
                    print("\nEpoch %05d: saving model to %s" % (epoch + 1, filepath))
                torch.save(self.model.state_dict(), filepath)
                if self.max_save > 0:
                    if len(self.old_files) == self.max_save:
                        try:
                            os.remove(self.old_files[0])
                        except FileNotFoundError:
                            pass
                        self.old_files = self.old_files[1:]
                    self.old_files.append(filepath)


class EarlyStopping(Callback):
    def __init__(
        self,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        patience: int = 10,
        verbose: int = 0,
        mode: str = "auto",
        baseline: Optional[float] = None,
        restore_best_weights: bool = False,
    ):
        r"""Stop training when a monitored quantity has stopped improving.

        This class is almost identical to the corresponding keras class.
        Therefore, **credit** to the Keras Team.

        Callbacks are passed as input parameters to the ``Trainer`` class. See
        :class:`pytorch_widedeep.trainer.Trainer`

        Parameters
        ----------
        monitor: str, default='val_loss'.
            Quantity to be monitored.
        min_delta: float, default=0.
            minimum change in the monitored quantity to qualify as an
            improvement, i.e. an absolute change of less than ``min_delta`` will
            count as no improvement.
        patience: int, default=10.
            Number of epochs with no improvement in the monitored quantity
            after which training will be stopped.
        verbose: int.
            verbosity mode.
        mode: str, default='auto'
            one of {'`auto`', '`min`', '`max`'}. In `'min'` mode, training will
            stop when the quantity monitored has stopped decreasing; in `'max'`
            mode it will stop when the quantity monitored has stopped increasing;
            in `'auto'` mode, the direction is automatically inferred from the
            name of the monitored quantity.
        baseline: float, Optional. default=None.
            Baseline value for the monitored quantity to reach. Training will
            stop if the model does not show improvement over the baseline.
        restore_best_weights: bool, default=False
            Whether to restore model weights from the epoch with the best
            value of the monitored quantity. If ``False``, the model weights
            obtained at the last step of training are used.

        Examples
        --------
        >>> from pytorch_widedeep.callbacks import EarlyStopping
        >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
        >>> from pytorch_widedeep.training import Trainer
        >>>
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> model = WideDeep(wide, deep)
        >>> trainer = Trainer(model, objective="regression", callbacks=[EarlyStopping(patience=10)])
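        >>>
        >>> # note: the default ``monitor="val_loss"`` is only present in ``logs`` when
        >>> # the Trainer is run with validation data; if the monitored metric is missing,
        >>> # a RuntimeWarning is raised and early stopping is effectively skipped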
        """
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.mode = mode
        self.baseline = baseline
        self.restore_best_weights = restore_best_weights

        self.wait = 0
        self.stopped_epoch = 0
        self.state_dict = None

        if self.mode not in ["auto", "min", "max"]:
            warnings.warn(
                "EarlyStopping mode %s is unknown, "
                "fallback to auto mode." % self.mode,
                RuntimeWarning,
            )
            self.mode = "auto"

        if self.mode == "min":
            self.monitor_op = np.less
        elif self.mode == "max":
            self.monitor_op = np.greater
        else:
            if "acc" in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs: Optional[Dict] = None):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        current = self.get_monitor_value(logs)
        if current is None:
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                # store a deep copy so the "best" weights are not mutated by
                # subsequent training steps (state_dict holds references)
                self.state_dict = copy.deepcopy(self.model.state_dict())
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.trainer.early_stop = True
                if self.restore_best_weights:
                    if self.verbose > 0:
                        print(
                            "Restoring model weights from the end of the best epoch"
                        )
                    self.model.load_state_dict(self.state_dict)

    def on_train_end(self, logs: Optional[Dict] = None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

    def get_monitor_value(self, logs):
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            warnings.warn(
                "Early stopping conditioned on metric `%s` "
                "which is not available. Available metrics are: %s"
                % (self.monitor, ",".join(list(logs.keys()))),
                RuntimeWarning,
            )
        return monitor_value