"""
Code here is mostly based on the code from the torchsample and Keras packages

CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
"""
import os
import datetime
import warnings

import numpy as np
import torch
from ray import tune
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_widedeep.metrics import MultipleMetrics
from pytorch_widedeep.wdtypes import *  # noqa: F403


def _get_current_time():
    # e.g. "June 01, 2024 - 09:30AM"
    return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p")


def _is_metric(monitor: str):
    # We assume no one will use f3 or more
    return any(s in monitor for s in ["acc", "prec", "rec", "fscore", "f1", "f2"])


class CallbackContainer(object):
    """
    Container holding a list of callbacks.
    """

    def __init__(self, callbacks: Optional[List] = None, queue_length: int = 10):
        instantiated_callbacks = []
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    instantiated_callbacks.append(callback())
                else:
                    instantiated_callbacks.append(callback)
        self.callbacks = instantiated_callbacks
        self.queue_length = queue_length

    def set_params(self, params):
        for callback in self.callbacks:
            callback.set_params(params)

    def set_model(self, model: Any):
        self.model = model
        for callback in self.callbacks:
            callback.set_model(model)

    def set_trainer(self, trainer: Any):
        self.trainer = trainer
        for callback in self.callbacks:
            callback.set_trainer(trainer)

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs, metric)

    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_begin(batch, logs)

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_batch_end(batch, logs)

    def on_train_begin(self, logs: Optional[Dict] = None):
        logs = logs or {}
        logs["start_time"] = _get_current_time()
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs: Optional[Dict] = None):
        logs = logs or {}
        # logs['final_loss'] = self.model.history.epoch_losses[-1],
        # logs['best_loss'] = min(self.model.history.epoch_losses),
        # logs['stop_time'] = _get_current_time()
        for callback in self.callbacks:
            callback.on_train_end(logs)

    def on_eval_begin(self, logs: Optional[Dict] = None):
        # at the moment only used to reset metrics before eval
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_eval_begin(logs)


class Callback(object):
    """
    Base class used to build new callbacks.
    """

    def __init__(self):
        pass

    def set_params(self, params):
        self.params = params

    def set_model(self, model: Any):
        self.model = model

    def set_trainer(self, trainer: Any):
        self.trainer = trainer

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        pass

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        pass

    def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
        pass

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        pass

    def on_train_begin(self, logs: Optional[Dict] = None):
        pass

    def on_train_end(self, logs: Optional[Dict] = None):
        pass

    def on_eval_begin(self, logs: Optional[Dict] = None):
        # at the moment only used to reset metrics before eval
        pass


class History(Callback):
    r"""Callback that records metrics to a ``history`` attribute.

    This callback runs by default within the :obj:`Trainer` and therefore
    should not be passed to the :obj:`Trainer`. It is included here just for
    completeness.
    """

    def on_train_begin(self, logs: Optional[Dict] = None):
        self.trainer.history = {}

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        logs = logs or {}
        for k, v in logs.items():
            if isinstance(v, np.ndarray):
                v = v.tolist()
            if isinstance(v, list) and len(v) > 1:
                # multi-valued entries are stored under indexed keys:
                # "<name>_0", "<name>_1", ...
                for i in range(len(v)):
                    self.trainer.history.setdefault(k + "_" + str(i), []).append(v[i])
            else:
                self.trainer.history.setdefault(k, []).append(v)


class LRShedulerCallback(Callback):
    r"""Callback for the learning rate schedulers to take a step

    This callback runs by default within the :obj:`Trainer` and therefore
    should not be passed to the :obj:`Trainer`. It is included here just for
    completeness.
    """

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        if self.trainer.lr_scheduler is not None:
            if self._multiple_scheduler():
                for (
                    model_name,
                    scheduler,
                ) in self.trainer.lr_scheduler._schedulers.items():
                    if self._is_cyclic(model_name):
                        scheduler.step()
            elif self.trainer.cyclic_lr:
                self.trainer.lr_scheduler.step()

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        if self.trainer.lr_scheduler is not None:
            if self._multiple_scheduler():
                for (
                    model_name,
                    scheduler,
                ) in self.trainer.lr_scheduler._schedulers.items():
                    if not self._is_cyclic(model_name):
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(metric)
                        else:
                            scheduler.step()
            elif not self.trainer.cyclic_lr:
                if isinstance(self.trainer.lr_scheduler, ReduceLROnPlateau):
                    self.trainer.lr_scheduler.step(metric)
                else:
                    self.trainer.lr_scheduler.step()

    def _multiple_scheduler(self):
        return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"

    def _is_cyclic(self, model_name: str):
        return (
            self._has_scheduler(model_name)
            and "cycl"
            in self.trainer.lr_scheduler._schedulers[
                model_name
            ].__class__.__name__.lower()
        )

    def _has_scheduler(self, model_name: str):
        return model_name in self.trainer.lr_scheduler._schedulers


class MetricCallback(Callback):
    r"""Callback that resets the metrics (if any metric is used)

    This callback runs by default within the :obj:`Trainer` and therefore
    should not be passed to the :obj:`Trainer`. It is included here just for
    completeness.
    """

    def __init__(self, container: MultipleMetrics):
        self.container = container

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        self.container.reset()

    def on_eval_begin(self, logs: Optional[Dict] = None):
        self.container.reset()


class LRHistory(Callback):
    r"""Saves the learning rates during training to a ``lr_history`` attribute.

    Callbacks are passed as input parameters to the :obj:`Trainer` class. See
    :class:`pytorch_widedeep.trainer.Trainer`

    Parameters
    ----------
    n_epochs: int
        number of epochs during training

    Examples
    --------
    >>> from pytorch_widedeep.callbacks import LRHistory
    >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
    >>> from pytorch_widedeep.training import Trainer
    >>>
    >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
    >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
    >>> wide = Wide(10, 1)
    >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
    >>> model = WideDeep(wide, deep)
    >>> trainer = Trainer(model, objective="regression", callbacks=[LRHistory(n_epochs=10)])
    """

    def __init__(self, n_epochs: int):
        super(LRHistory, self).__init__()
        self.n_epochs = n_epochs

    def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
        if epoch == 0 and self.trainer.lr_scheduler is not None:
            self.trainer.lr_history = {}
            if self._multiple_scheduler():
                self._save_group_lr_mulitple_scheduler(step_location="on_epoch_begin")
            else:
                self._save_group_lr(self.trainer.optimizer)

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        if self.trainer.lr_scheduler is not None:
            if self._multiple_scheduler():
                self._save_group_lr_mulitple_scheduler(step_location="on_batch_end")
            elif self.trainer.cyclic_lr:
                self._save_group_lr(self.trainer.optimizer)

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        if epoch != (self.n_epochs - 1) and self.trainer.lr_scheduler is not None:
            if self._multiple_scheduler():
                self._save_group_lr_mulitple_scheduler(step_location="on_epoch_end")
            elif not self.trainer.cyclic_lr:
                self._save_group_lr(self.trainer.optimizer)

    def _save_group_lr_mulitple_scheduler(self, step_location: str):
        for model_name, opt in self.trainer.optimizer._optimizers.items():
            if step_location == "on_epoch_begin":
                self._save_group_lr(opt, model_name)
            if step_location == "on_batch_end":
                if self._is_cyclic(model_name):
                    self._save_group_lr(opt, model_name)
            if step_location == "on_epoch_end":
                if not self._is_cyclic(model_name):
                    self._save_group_lr(opt, model_name)

    def _save_group_lr(self, opt: Optimizer, model_name: Optional[str] = None):
        for group_idx, group in enumerate(opt.param_groups):
            if model_name is not None:
                group_name = ("_").join(["lr", model_name, str(group_idx)])
            else:
                group_name = ("_").join(["lr", str(group_idx)])
            self.trainer.lr_history.setdefault(group_name, []).append(group["lr"])

    def _multiple_scheduler(self):
        return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"

    def _is_cyclic(self, model_name: str):
        return (
            self._has_scheduler(model_name)
            and "cycl"
            in self.trainer.lr_scheduler._schedulers[
                model_name
            ].__class__.__name__.lower()
        )

    def _has_scheduler(self, model_name: str):
        return model_name in self.trainer.lr_scheduler._schedulers


class ModelCheckpoint(Callback):
    r"""Saves the model after every epoch.

    This class is almost identical to the corresponding keras class.
    Therefore, **credit** to the Keras Team.

    Callbacks are passed as input parameters to the :obj:`Trainer` class. See
    :class:`pytorch_widedeep.trainer.Trainer`

    Parameters
    ----------
    filepath: str, default=None
        Full path to save the output weights. It must contain only the root of
        the filenames. Epoch number and ``.pt`` extension (for pytorch) will
        be added, e.g. with ``filepath="path/to/output_weights/weights_out"``
        the saved files in that directory will be named ``weights_out_1.pt,
        weights_out_2.pt, ...``. If set to ``None`` the class just reports the
        best metric and best epoch.
    monitor: str, default="val_loss"
        quantity to monitor. Typically 'val_loss' or a metric name (e.g. 'val_acc')
    verbose: int, default=0
        verbosity mode
    save_best_only: bool, default=False
        if ``True``, the latest best model according to the quantity
        monitored will not be overwritten.
    mode: str, default="auto"
        If ``save_best_only=True``, the decision to overwrite the current save
        file is made based on either the maximization or the minimization of
        the monitored quantity. For `'acc'`, this should be `'max'`, for
        `'loss'` this should be `'min'`, etc. In `'auto'` mode, the
        direction is automatically inferred from the name of the monitored
        quantity.
    period: int, default=1
        Interval (number of epochs) between checkpoints.
    max_save: int, default=-1
        Maximum number of outputs to save. If -1 will save all outputs
    wb: obj, Optional, default=None
        Weights&Biases API interface used to report the single best result,
        which makes it possible to compare multiple parameter combinations
        with e.g. parallel coordinates:
        https://docs.wandb.ai/ref/app/features/panels/parallel-coordinates.
        E.g. the W&B summary report ``wandb.run.summary["best"]``. If an
        external early-stopping scheduler (e.g. from RayTune) is used together
        with W&B, that scheduler stops the training function, so the summary
        log is not sent if it is only defined after training via e.g.
        ``wandb.run.summary["best"] = model_checkpoint.best``.

    Attributes
    ----------
    best: float
        best metric
    best_epoch: int
        best epoch
    best_state_dict: dict
        best model state dictionary; restore the model to its best state with
        ``trainer.model.load_state_dict(model_checkpoint.best_state_dict)``

    Examples
    --------
    >>> from pytorch_widedeep.callbacks import ModelCheckpoint
    >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
    >>> from pytorch_widedeep.training import Trainer
    >>>
    >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
    >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
    >>> wide = Wide(10, 1)
    >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
    >>> model = WideDeep(wide, deep)
    >>> trainer = Trainer(model, objective="regression", callbacks=[ModelCheckpoint(filepath='checkpoints/weights_out')])
    """

    def __init__(
        self,
        filepath: Optional[str] = None,
        monitor: str = "val_loss",
        verbose: int = 0,
        save_best_only: bool = False,
        mode: str = "auto",
        period: int = 1,
        max_save: int = -1,
        wb: Optional[object] = None,
    ):
        super(ModelCheckpoint, self).__init__()

        self.filepath = filepath
        self.monitor = monitor
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.mode = mode
        self.period = period
        self.max_save = max_save
        self.wb = wb

        self.epochs_since_last_save = 0

        if self.filepath is not None:
            if len(self.filepath.split("/")[:-1]) == 0:
                raise ValueError(
                    "'filepath' must be the full path to save the output weights,"
                    " including the root of the filenames. e.g. 'checkpoints/weights_out'"
                )

            root_dir = ("/").join(self.filepath.split("/")[:-1])
            if not os.path.exists(root_dir):
                os.makedirs(root_dir)

        if self.max_save > 0:
            self.old_files: List[str] = []

        if self.mode not in ["auto", "min", "max"]:
            warnings.warn(
                "ModelCheckpoint mode %s is unknown, "
                "fallback to auto mode." % (self.mode),
                RuntimeWarning,
            )
            self.mode = "auto"
        if self.mode == "min":
            self.monitor_op = np.less
            self.best = np.Inf
        elif self.mode == "max":
            self.monitor_op = np.greater  # type: ignore[assignment]
            self.best = -np.Inf
        else:
            if _is_metric(self.monitor):
                self.monitor_op = np.greater  # type: ignore[assignment]
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(  # noqa: C901
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            # e.g. "path/to/output_weights/weights_out_1.pt"
            filepath = "{}_{}.pt".format(self.filepath, epoch + 1)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn(
                        "Can save best model only with %s available, "
                        "skipping." % (self.monitor),
                        RuntimeWarning,
                    )
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            if self.filepath is not None:
                                print(
                                    "\nEpoch %05d: %s improved from %0.5f to %0.5f,"
                                    " saving model to %s"
                                    % (
                                        epoch + 1,
                                        self.monitor,
                                        self.best,
                                        current,
                                        filepath,
                                    )
                                )
                            else:
                                print(
                                    "\nEpoch %05d: %s improved from %0.5f to %0.5f"
                                    % (
                                        epoch + 1,
                                        self.monitor,
                                        self.best,
                                        current,
                                    )
                                )
                        if self.wb is not None:
                            self.wb.run.summary["best"] = current
                        self.best = current
                        self.best_epoch = epoch
                        self.best_state_dict = self.model.state_dict()
                        if self.filepath is not None:
                            torch.save(self.best_state_dict, filepath)
                            if self.max_save > 0:
                                if len(self.old_files) == self.max_save:
                                    try:
                                        os.remove(self.old_files[0])
                                    except FileNotFoundError:
                                        pass
                                    self.old_files = self.old_files[1:]
                                self.old_files.append(filepath)
                    else:
                        if self.verbose > 0:
                            print(
                                "\nEpoch %05d: %s did not improve from %0.5f"
                                % (epoch + 1, self.monitor, self.best)
                            )
            if not self.save_best_only and self.filepath is not None:
                if self.verbose > 0:
                    print("\nEpoch %05d: saving model to %s" % (epoch + 1, filepath))
                torch.save(self.model.state_dict(), filepath)
                if self.max_save > 0:
                    if len(self.old_files) == self.max_save:
                        try:
                            os.remove(self.old_files[0])
                        except FileNotFoundError:
                            pass
                        self.old_files = self.old_files[1:]
                    self.old_files.append(filepath)

    def __getstate__(self):
        # exclude non-picklable references to the trainer and model
        d = self.__dict__
        self_dict = {k: d[k] for k in d if k not in ["trainer", "model"]}
        return self_dict

    def __setstate__(self, state):
        self.__dict__ = state


class EarlyStopping(Callback):
    r"""Stop training when a monitored quantity has stopped improving.

    This class is almost identical to the corresponding keras class.
    Therefore, **credit** to the Keras Team.

    Callbacks are passed as input parameters to the :obj:`Trainer` class. See
    :class:`pytorch_widedeep.trainer.Trainer`

    Parameters
    -----------
    monitor: str, default='val_loss'.
        Quantity to monitor. Typically 'val_loss' or a metric name (e.g. 'val_acc')
    min_delta: float, default=0.
        minimum change in the monitored quantity to qualify as an
        improvement, i.e. an absolute change of less than min_delta will
        count as no improvement.
    patience: int, default=10.
        Number of epochs with no improvement in the monitored quantity
        after which training will be stopped.
    verbose: int, default=0.
        verbosity mode.
    mode: str, default='auto'
        one of {'`auto`', '`min`', '`max`'}. In `'min'` mode, training will
        stop when the quantity monitored has stopped decreasing; in `'max'`
        mode it will stop when the quantity monitored has stopped increasing;
        in `'auto'` mode, the direction is automatically inferred from the
        name of the monitored quantity.
    baseline: float, Optional. default=None.
        Baseline value for the monitored quantity to reach. Training will
        stop if the model does not show improvement over the baseline.
    restore_best_weights: bool, default=False
        Whether to restore model weights from the epoch with the best
        value of the monitored quantity. If ``False``, the model weights
        obtained at the last step of training are used.

    Attributes
    ----------
    best: float
        best metric
    stopped_epoch: int
        epoch when the training stopped

    Examples
    --------
    >>> from pytorch_widedeep.callbacks import EarlyStopping
    >>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
    >>> from pytorch_widedeep.training import Trainer
    >>>
    >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
    >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
    >>> wide = Wide(10, 1)
    >>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
    >>> model = WideDeep(wide, deep)
    >>> trainer = Trainer(model, objective="regression", callbacks=[EarlyStopping(patience=10)])
    """

    def __init__(
        self,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        patience: int = 10,
        verbose: int = 0,
        mode: str = "auto",
        baseline: Optional[float] = None,
        restore_best_weights: bool = False,
    ):
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.mode = mode
        self.baseline = baseline
        self.restore_best_weights = restore_best_weights

        self.wait = 0
        self.stopped_epoch = 0
        self.state_dict = None

        if self.mode not in ["auto", "min", "max"]:
            warnings.warn(
                "EarlyStopping mode %s is unknown, "
                "fallback to auto mode." % self.mode,
                RuntimeWarning,
            )
            self.mode = "auto"

        if self.mode == "min":
            self.monitor_op = np.less
        elif self.mode == "max":
            self.monitor_op = np.greater  # type: ignore[assignment]
        else:
            if _is_metric(self.monitor):
                self.monitor_op = np.greater  # type: ignore[assignment]
            else:
                self.monitor_op = np.less

        # the sign of min_delta follows the direction of improvement:
        # positive when the metric should increase, negative when it should
        # decrease
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs: Optional[Dict] = None):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        current = self.get_monitor_value(logs)
        if current is None:
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                self.state_dict = self.model.state_dict()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.trainer.early_stop = True
                if self.restore_best_weights:
                    if self.verbose > 0:
                        print("Restoring model weights from the end of the best epoch")
                    self.model.load_state_dict(self.state_dict)

    def on_train_end(self, logs: Optional[Dict] = None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

    def get_monitor_value(self, logs):
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            warnings.warn(
                "Early stopping conditioned on metric `%s` "
                "which is not available. Available metrics are: %s"
                % (self.monitor, ",".join(list(logs.keys()))),
                RuntimeWarning,
            )
        return monitor_value

    def __getstate__(self):
        # exclude non-picklable references to the trainer and model
        d = self.__dict__
        self_dict = {k: d[k] for k in d if k not in ["trainer", "model"]}
        return self_dict

    def __setstate__(self, state):
        self.__dict__ = state


class RayTuneReporter(Callback):
    r"""Callback that allows reporting ``history`` and ``lr_history`` values
    to RayTune during hyperparameter tuning

    Callbacks are passed as input parameters to the :obj:`Trainer` class. See
    :class:`pytorch_widedeep.trainer.Trainer`

    Examples
    --------
    see /examples/12_HyperParameter_tuning_w_RayTune.ipynb
    """

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        report_dict = {}
        for k, v in self.trainer.history.items():
            report_dict.update({k: v[-1]})
        if hasattr(self.trainer, "lr_history"):
            for k, v in self.trainer.lr_history.items():
                report_dict.update({k: v[-1]})
        tune.report(report_dict)
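

# Illustrative only: a typical combination of the callbacks defined above,
# assuming ``model``, ``X_wide``, ``X_tab`` and ``target`` already exist:
#
#   from pytorch_widedeep.training import Trainer
#
#   trainer = Trainer(
#       model,
#       objective="binary",
#       callbacks=[
#           LRHistory(n_epochs=10),
#           EarlyStopping(patience=5, restore_best_weights=True),
#           ModelCheckpoint(filepath="checkpoints/weights_out", save_best_only=True),
#       ],
#   )
#   trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10)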