import os
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange
from scipy.sparse import csc_matrix
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import (
    History,
    Callback,
    CallbackContainer,
    LRShedulerCallback,
)
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training.trainer_utils import (
    Alias,
    alias_to_loss,
    save_epoch_logs,
    wd_train_val_split,
    print_loss_and_metric,
)
from pytorch_widedeep.models.tabnet.tab_net_utils import create_explain_matrix
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod
from pytorch_widedeep.training._multiple_lr_scheduler import (
    MultipleLRScheduler,
)

n_cpus = os.cpu_count()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class Trainer:
    @Alias(  # noqa: C901
        "objective",
        ["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"],
    )
    @Alias("model", ["model_path"])
    def __init__(
        self,
        model: Union[str, WideDeep],
        objective: str,
        custom_loss_function: Optional[Module] = None,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        reducelronplateau_criterion: Optional[str] = "loss",
        initializers: Optional[Union[Initializer, Dict[str, Initializer]]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        lambda_sparse: float = 1e-3,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        r"""Method to set the attributes that will be used during the
        training process.

        Parameters
        ----------
        model: str or ``WideDeep``
            param alias: ``model_path``

            An object of class ``WideDeep`` or a string with the full path to
            the model, in which case you might want to use the alias
            ``model_path``.

        objective: str
            Defines the objective, loss or cost function.

            param aliases: ``loss_function``, ``loss_fn``, ``loss``,
            ``cost_function``, ``cost_fn``, ``cost``

            Possible values are:

            - ``binary``, aliases: ``logistic``, ``binary_logloss``, ``binary_cross_entropy``

            - ``binary_focal_loss``

            - ``multiclass``, aliases: ``multi_logloss``, ``cross_entropy``, ``categorical_cross_entropy``

            - ``multiclass_focal_loss``

            - ``regression``, aliases: ``mse``, ``l2``, ``mean_squared_error``

            - ``mean_absolute_error``, aliases: ``mae``, ``l1``

            - ``mean_squared_log_error``, aliases: ``msle``

            - ``root_mean_squared_error``, aliases:  ``rmse``

            - ``root_mean_squared_log_error``, aliases: ``rmsle``
        custom_loss_function: ``nn.Module``, Optional, default = None
            object of class ``nn.Module``. If none of the loss functions
            available suits the user, it is possible to pass a custom loss
            function. See for example
            :class:`pytorch_widedeep.losses.FocalLoss` for the required
            structure of the object or the `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            folder in the repo.

            .. note:: If ``custom_loss_function`` is not None, ``objective`` must be
                'binary', 'multiclass' or 'regression', consistent with the loss
                function
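
            As a minimal sketch (``MyMSELoss`` below is just an illustrative
            name, not part of the library), a custom loss is simply an
            ``nn.Module`` whose ``forward`` method takes the model output and
            the target:

            .. code-block:: python

                import torch
                from torch import nn

                class MyMSELoss(nn.Module):
                    def forward(self, input, target):
                        # plain mean squared error, to be used with objective='regression'
                        return torch.mean((input - target) ** 2)

                trainer = Trainer(model, objective="regression", custom_loss_function=MyMSELoss())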

        optimizers: ``Optimizer`` or Dict, Optional, default=None
            - An instance of Pytorch's ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
            - a dictionary where the keys are the model components (i.e.
              `'wide'`, `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and
              the values are the corresponding optimizers. If multiple optimizers are used
              the dictionary **MUST** contain an optimizer per model component.

            If no optimizers are passed, it will default to ``AdamW`` for all
            Wide and Deep components.
        lr_schedulers: ``LRScheduler`` or Dict, Optional, default=None
            - An instance of Pytorch's ``LRScheduler`` object (e.g.
              :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
              values are the corresponding learning rate schedulers.
        reducelronplateau_criterion: str, Optional, default='loss'
            Quantity monitored when a ``ReduceLROnPlateau`` scheduler is used:
            either `'loss'` (the validation loss) or the name of one of the
            metrics passed via ``metrics``.
        initializers: ``Initializer`` or Dict, Optional, default=None
            - An instance of an ``Initializer`` object (see ``pytorch-widedeep.initializers``) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`)
              and the values are the corresponding initializers.
        transforms: List, Optional, default=None
            List with :obj:`torchvision.transforms` to be applied to the image
            component of the model (i.e. ``deepimage``). See `torchvision
            transforms
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
        callbacks: List, Optional, default=None
            List with ``Callback`` objects. The four callbacks available in
            ``pytorch-widedeep`` are: ``History``, ``ModelCheckpoint``,
            ``EarlyStopping``, and ``LRHistory``. The ``History`` callback is
            used by default. This can also be a custom callback as long as it
            is an object of type ``Callback``. See
            ``pytorch_widedeep.callbacks.Callback`` or the `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            folder in the repo
        metrics: List, Optional, default=None
            List of objects of type ``Metric``. Metrics available are:
            ``Accuracy``, ``Precision``, ``Recall``, ``FBetaScore``,
            ``F1Score`` and ``R2Score``. This can also be a custom metric as
            long as it is an object of type ``Metric``. See
            ``pytorch_widedeep.metrics.Metric`` or the `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            folder in the repo
        class_weight: float, List or Tuple. Optional. default=None
            - float indicating the weight of the minority class in binary classification
              problems (e.g. 9.)
            - a list or tuple with weights for the different classes in multiclass
              classification problems  (e.g. [1., 2., 3.]). The weights do not
              need to be normalised. See `this discussion
              <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
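
            As a quick sketch, the weight(s) are simply passed when
            instantiating the ``Trainer`` (the values below are illustrative
            only):

            .. code-block:: python

                # weight the minority class 9 times more in a binary problem
                trainer = Trainer(model, objective="binary", class_weight=9.0)
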
        lambda_sparse: float. default=1e-3
            Tabnet sparse regularization factor
        alpha: float. default=0.25
            if ``objective`` is ``binary_focal_loss`` or
            ``multiclass_focal_loss``, the Focal Loss alpha and gamma
            parameters can be set directly in the ``Trainer`` via the
            ``alpha`` and ``gamma`` parameters
        gamma: float. default=2
            Focal Loss gamma parameter
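
            As a quick sketch, both Focal Loss parameters are set when
            instantiating the ``Trainer`` (the values below are illustrative
            only):

            .. code-block:: python

                # binary classification trained with the Focal Loss
                trainer = Trainer(model, objective="binary_focal_loss", alpha=0.25, gamma=2.0)
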
        verbose: int, default=1
            Setting it to 0 will print nothing during training.
        seed: int, default=1
            Random seed to be used internally for train_test_split

        Attributes
        ----------
        cyclic_lr: bool
            Attribute that indicates if any of the lr_schedulers is cyclic_lr (i.e. ``CyclicLR`` or
            ``OneCycleLR``). See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
        feature_importance: Dict
            Dict where the keys are the column names and the values are the
            corresponding feature importances. This attribute will only exist
            if the ``deeptabular`` component is a Tabnet model

        Examples
        --------
        >>> import torch
        >>> from torchvision.transforms import ToTensor
        >>>
        >>> # wide deep imports
        >>> from pytorch_widedeep.callbacks import EarlyStopping, LRHistory
        >>> from pytorch_widedeep.initializers import KaimingNormal, KaimingUniform, Normal, Uniform
        >>> from pytorch_widedeep.models import TabResnet, DeepImage, DeepText, Wide, WideDeep
        >>> from pytorch_widedeep import Trainer
        >>> from pytorch_widedeep.optim import RAdam
        >>>
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>>
        >>> # build the model
        >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
        >>> deepimage = DeepImage(pretrained=False)
        >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
        >>>
        >>> # set optimizers and schedulers
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deeptabular.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt = RAdam(model.deepimage.parameters())
        >>>
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>>
        >>> optimizers = {"wide": wide_opt, "deeptabular": deep_opt, "deeptext": text_opt, "deepimage": img_opt}
        >>> schedulers = {"wide": wide_sch, "deeptabular": deep_sch, "deeptext": text_sch, "deepimage": img_sch}
        >>>
        >>> # set initializers and callbacks
        >>> initializers = {"wide": Uniform, "deeptabular": Normal, "deeptext": KaimingNormal, "deepimage": KaimingUniform}
        >>> transforms = [ToTensor]
        >>> callbacks = [LRHistory(n_epochs=4), EarlyStopping]
        >>>
        >>> # set the trainer
        >>> trainer = Trainer(model, objective="regression", initializers=initializers, optimizers=optimizers,
        ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """

        if isinstance(optimizers, Dict):
            if lr_schedulers is not None and not isinstance(lr_schedulers, Dict):
                raise ValueError(
                    "'optimizers' and 'lr_schedulers' must have consistent type: "
                    "(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]) "
                    "Please, read the documentation or see the examples for more details"
                )

        if custom_loss_function is not None and objective not in [
            "binary",
            "multiclass",
            "regression",
        ]:
            raise ValueError(
                "If 'custom_loss_function' is not None, 'objective' must be 'binary' "
                "'multiclass' or 'regression', consistent with the loss function"
            )

        self.reducelronplateau = False
        self.reducelronplateau_criterion = reducelronplateau_criterion
        if isinstance(lr_schedulers, Dict):
            for _, scheduler in lr_schedulers.items():
                if isinstance(scheduler, ReduceLROnPlateau):
                    self.reducelronplateau = True
        elif isinstance(lr_schedulers, ReduceLROnPlateau):
            self.reducelronplateau = True

        if isinstance(model, str):
            self.model = torch.load(model)
        else:
            self.model = model

        # Tabnet-related setup
        if self.model.is_tabnet:
            self.lambda_sparse = lambda_sparse
            self.reducing_matrix = create_explain_matrix(self.model)

        self.verbose = verbose
        self.seed = seed
        self.objective = objective
        self.method = _ObjectiveToMethod.get(objective)

        # initialize early_stop. If EarlyStopping Callback is used it will
        # take care of it
        self.early_stop = False

        self.loss_fn = self._set_loss_fn(
            objective, class_weight, custom_loss_function, alpha, gamma
        )
        self._initialize(initializers)
        self.optimizer = self._set_optimizer(optimizers)
        self.lr_scheduler = self._set_lr_scheduler(lr_schedulers)
        self.transforms = self._set_transforms(transforms)
        self._set_callbacks_and_metrics(callbacks, metrics)

        self.model.to(device)

    @Alias("finetune", "warmup")  # noqa: C901
    @Alias("finetune_epochs", "warmup_epochs")
    @Alias("finetune_max_lr", "warmup_max_lr")
    @Alias("finetune_deeptabular_gradual", "warmup_deeptabular_gradual")
    @Alias("finetune_deeptabular_max_lr", "warmup_deeptabular_max_lr")
    @Alias("finetune_deeptabular_layers", "warmup_deeptabular_layers")
    @Alias("finetune_deeptext_gradual", "warmup_deeptext_gradual")
    @Alias("finetune_deeptext_max_lr", "warmup_deeptext_max_lr")
    @Alias("finetune_deeptext_layers", "warmup_deeptext_layers")
    @Alias("finetune_deepimage_gradual", "warmup_deepimage_gradual")
    @Alias("finetune_deepimage_max_lr", "warmup_deepimage_max_lr")
    @Alias("finetune_deepimage_layers", "warmup_deepimage_layers")
    @Alias("finetune_routine", "warmup_routine")
    def fit(  # noqa: C901
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        finetune: bool = False,
        finetune_epochs: int = 5,
        finetune_max_lr: float = 0.01,
        finetune_deeptabular_gradual: bool = False,
        finetune_deeptabular_max_lr: float = 0.01,
        finetune_deeptabular_layers: Optional[List[nn.Module]] = None,
        finetune_deeptext_gradual: bool = False,
        finetune_deeptext_max_lr: float = 0.01,
        finetune_deeptext_layers: Optional[List[nn.Module]] = None,
        finetune_deepimage_gradual: bool = False,
        finetune_deepimage_max_lr: float = 0.01,
        finetune_deepimage_layers: Optional[List[nn.Module]] = None,
        finetune_routine: str = "howard",
        stop_after_finetuning: bool = False,
    ):
        r"""Fit method.

        The input datasets can be passed either directly via numpy arrays
        (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
        dictionaries (``X_train`` or ``X_val``).

        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict, Optional. default=None
            The training dataset can also be passed in a dictionary. Keys are
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values
            are the corresponding matrices.
        X_val: Dict, Optional. default=None
            The validation dataset can also be passed in a dictionary. Keys
            are `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`.
            Values are the corresponding matrices.
        val_split: float, Optional. default=None
            train/val split fraction
        target: np.ndarray, Optional. default=None
            target values
        n_epochs: int, default=1
            number of epochs
        validation_freq: int, default=1
            validation frequency, in epochs
        batch_size: int, default=32
            batch size
        finetune: bool, default=False
            param alias: ``warmup``

            fine-tune individual model components.

            .. note:: This functionality can also be used to 'warm-up'
               individual components before the joint training starts, and hence
               its alias. See the Examples folder in the repo for more details

            ``pytorch_widedeep`` implements 3 fine-tune routines.

            - fine-tune all trainable layers at once. This routine is
              inspired by the work of Jeremy Howard & Sebastian Ruder, 2018, in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
              Slanted Triangular learning rate (see `Leslie N. Smith's paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``AdamW``.

            and two gradual fine-tune routines, where only certain layers are
            trained at a time.

            - The so-called `Felbo` gradual fine-tune routine, based on the
              Felbo et al., 2017 `DeepMoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine, based on the work of Jeremy Howard & Sebastian Ruder, 2018,
              in their `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the Examples
            section in this documentation and the `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            folder in the repo.
        finetune_epochs: int, default=5
            param alias: ``warmup_epochs``

            Number of fine-tune epochs for those model components that will
            *NOT* be gradually fine-tuned. Those components with gradual
            fine-tune follow their corresponding specific routine.
        finetune_max_lr: float, default=0.01
            param alias: ``warmup_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for those model components that will *NOT* be gradually fine-tuned
        finetune_deeptabular_gradual: bool, default=False
            param alias: ``warmup_deeptabular_gradual``

            Boolean indicating if the ``deeptabular`` component will be
            fine-tuned gradually
        finetune_deeptabular_max_lr: float, default=0.01
            param alias: ``warmup_deeptabular_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptabular component
        finetune_deeptabular_layers: List, Optional, default=None
            param alias: ``warmup_deeptabular_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_deeptext_gradual: bool, default=False
            param alias: ``warmup_deeptext_gradual``

            Boolean indicating if the ``deeptext`` component will be
            fine-tuned gradually
        finetune_deeptext_max_lr: float, default=0.01
            param alias: ``warmup_deeptext_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
        finetune_deeptext_layers: List, Optional, default=None
            param alias: ``warmup_deeptext_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_deepimage_gradual: bool, default=False
            param alias: ``warmup_deepimage_gradual``

            Boolean indicating if the ``deepimage`` component will be
            fine-tuned gradually
        finetune_deepimage_max_lr: float, default=0.01
            param alias: ``warmup_deepimage_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the ``deepimage`` component
        finetune_deepimage_layers: List, Optional, default=None
            param alias: ``warmup_deepimage_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_routine: str, default = "howard"
            param alias: ``warmup_routine``

            Warm-up or fine-tune routine. One of "felbo" or "howard". See the
            examples section in this documentation and the corresponding repo
            for details on how to use fine-tune routines.

        Examples
        --------

        For a series of comprehensive examples please, see the `Examples
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
        folder in the repo

        For completeness, here we include some `"fabricated"` examples, i.e.
        they assume you have already built a model and instantiated a
        ``Trainer`` that is ready to fit

        .. code-block:: python

            # Ex 1. using train input arrays directly and no validation
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256)


        .. code-block:: python

            # Ex 2: using train input arrays directly and validation with val_split
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256, val_split=0.2)


        .. code-block:: python

            # Ex 3: using train dict and val_split
            X_train = {'X_wide': X_wide, 'X_tab': X_tab, 'target': y}
            trainer.fit(X_train, n_epochs=10, batch_size=256, val_split=0.2)


        .. code-block:: python

            # Ex 4: validation using training and validation dicts
            X_train = {'X_wide': X_wide_tr, 'X_tab': X_tab_tr, 'target': y_tr}
            X_val = {'X_wide': X_wide_val, 'X_tab': X_tab_val, 'target': y_val}
            trainer.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=256)
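
        As one more `"fabricated"` sketch, the fine-tune (warm-up) routine can
        be switched on directly from ``fit``; the argument values below are
        illustrative only:

        .. code-block:: python

            # Ex 5: 'warm up' each model component for 2 epochs before the joint training
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10,
                        batch_size=256, finetune=True, finetune_epochs=2)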
        """

        self.batch_size = batch_size
        train_set, eval_set = wd_train_val_split(
            self.seed,
            self.method,
            X_wide,
            X_tab,
            X_text,
            X_img,
            X_train,
            X_val,
            val_split,
            target,
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
        train_steps = len(train_loader)
        if eval_set is not None:
            eval_loader = DataLoader(
                dataset=eval_set,
                batch_size=batch_size,
                num_workers=n_cpus,
                shuffle=False,
            )
            eval_steps = len(eval_loader)

        if finetune:
            self._finetune(
                train_loader,
                finetune_epochs,
                finetune_max_lr,
                finetune_deeptabular_gradual,
                finetune_deeptabular_layers,
                finetune_deeptabular_max_lr,
                finetune_deeptext_gradual,
                finetune_deeptext_layers,
                finetune_deeptext_max_lr,
                finetune_deepimage_gradual,
                finetune_deepimage_layers,
                finetune_deepimage_max_lr,
                finetune_routine,
            )
            if stop_after_finetuning:
                print("Fine-tuning finished")
                return
            else:
                if self.verbose:
                    print(
                        "Fine-tuning of individual components completed. "
                        "Training the whole model for {} epochs".format(n_epochs)
                    )

        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        for epoch in range(n_epochs):
            epoch_logs: Dict[str, float] = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)

            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
                for batch_idx, (data, targett) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
                    train_score, train_loss = self._train_step(data, targett, batch_idx)
                    print_loss_and_metric(t, train_loss, train_score)
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")

            on_epoch_end_metric = None
            if eval_set is not None and epoch % validation_freq == (
                validation_freq - 1
            ):
                self.callback_container.on_eval_begin()
                self.valid_running_loss = 0.0
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for i, (data, targett) in zip(v, eval_loader):
                        v.set_description("valid")
                        val_score, val_loss = self._eval_step(data, targett, i)
                        print_loss_and_metric(v, val_loss, val_score)
                epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val")

                if self.reducelronplateau:
                    if self.reducelronplateau_criterion == "loss":
                        on_epoch_end_metric = val_loss
                    else:
                        on_epoch_end_metric = val_score[
                            self.reducelronplateau_criterion
                        ]

            self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric)

            if self.early_stop:
                self.callback_container.on_train_end(epoch_logs)
                break

        self.callback_container.on_train_end(epoch_logs)
        if self.model.is_tabnet:
            self._compute_feature_importance(train_loader)
        self._restore_best_weights()
        self.model.train()

    def predict(  # type: ignore[return]
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predictions

        The input datasets can be passed either directly via numpy arrays
        (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
        a dictionary (``X_test``)


        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict, Optional. default=None
            The test dataset can also be passed in a dictionary. Keys are
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values
            are the corresponding matrices.
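
        Examples
        --------

        A minimal sketch, assuming a ``Trainer`` already fitted on a binary
        problem and the corresponding test arrays (the names below are
        placeholders):

        .. code-block:: python

            # for binary problems an array of 0/1 labels is returned (probability > 0.5 -> 1)
            preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)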
        """

        preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype("int")
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)  # type: ignore[return-value]

    def predict_proba(  # type: ignore[return]
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for binary
        and multiclass methods

        The input datasets can be passed either directly via numpy arrays
        (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
        a dictionary (``X_test``)

        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict, Optional. default=None
            The test dataset can also be passed in a dictionary. Keys are
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values
            are the corresponding matrices.
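
        Examples
        --------

        A minimal sketch, assuming a ``Trainer`` already fitted on a binary
        problem and the corresponding test arrays (the names below are
        placeholders):

        .. code-block:: python

            # for binary problems an array of shape (n_samples, 2) is returned, with the
            # probabilities of the negative (column 0) and positive (column 1) class
            probs = trainer.predict_proba(X_wide=X_wide_te, X_tab=X_tab_te)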
        """

        preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            probs = np.zeros([preds.shape[0], 2])
            probs[:, 0] = 1 - preds
            probs[:, 1] = preds
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)

    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:  # pragma: no cover
        r"""Returns the learned embeddings for the categorical features passed through
        ``deeptabular``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` Attribute in the class
        :obj:`TabPreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.TabPreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict
            Dictionary where the keys are the name of the column for which we
            want to retrieve the embeddings and the values are also of type
            Dict. These Dict values have keys that are the categories for that
            column and the values are the corresponding numerical encodings

            e.g.: {'column': {'cat_0': 1, 'cat_1': 2, ...}}

        Examples
        --------

        For a series of comprehensive examples please, see the `Examples
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
        folder in the repo

        For completeness, here we include a `"fabricated"` example, i.e.
        assuming we have already trained the model, that we have the
        categorical encodings in a dictionary named ``encoding_dict``, and that
        there is a column called `'education'`:

        .. code-block:: python

            trainer.get_embeddings(col_name="education", cat_encoding_dict=encoding_dict)
        """
        for n, p in self.model.named_parameters():
            if "embed_layers" in n and col_name in n:
                embed_mtx = p.cpu().data.numpy()
        encoding_dict = cat_encoding_dict[col_name]
        inv_encoding_dict = {v: k for k, v in encoding_dict.items()}
        cat_embed_dict = {}
        for idx, value in inv_encoding_dict.items():
            cat_embed_dict[value] = embed_mtx[idx]
        return cat_embed_dict

    def explain(self, X_tab: np.ndarray, save_step_masks: bool = False):
        """
        Returns the aggregated feature importance for each instance (or
        observation) in the ``X_tab`` array. If ``save_step_masks`` is set to
        ``True``, the masks per step will also be returned.

        Parameters
        ----------
        X_tab: np.ndarray
            Input array corresponding **only** to the deeptabular component
        save_step_masks: bool
            Boolean indicating if the masks per step will be returned

        Returns
        -------
        res: np.ndarray, Tuple
            Array or Tuple of two arrays with the corresponding aggregated
            feature importance and the masks per step if ``save_step_masks``
            is set to ``True``
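
        Examples
        --------

        A minimal sketch, assuming the ``deeptabular`` component is a Tabnet
        model and ``fit`` has already been run (``X_tab_te`` is just a
        placeholder name for the tabular test array):

        .. code-block:: python

            # aggregated feature importance per observation plus the masks per step
            feat_imp, step_masks = trainer.explain(X_tab_te, save_step_masks=True)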
        """
        loader = DataLoader(
            dataset=WideDeepDataset(**{"X_tab": X_tab}),
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )

        self.model.eval()
        tabnet_backbone = list(self.model.deeptabular.children())[0]

        m_explain_l = []
        for batch_nb, data in enumerate(loader):
            X = data["deeptabular"].to(device)
            M_explain, masks = tabnet_backbone.forward_masks(X)
            m_explain_l.append(
                csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
            )
            if save_step_masks:
                for key, value in masks.items():
                    masks[key] = csc_matrix.dot(
                        value.cpu().detach().numpy(), self.reducing_matrix
                    )
                if batch_nb == 0:
                    m_explain_step = masks
                else:
                    for key, value in masks.items():
                        m_explain_step[key] = np.vstack([m_explain_step[key], value])

        m_explain_agg = np.vstack(m_explain_l)
        m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis]

        res = (
            (m_explain_agg_norm, m_explain_step)
            if save_step_masks
            else np.vstack(m_explain_agg_norm)
        )

        return res

    def save(
        self,
        path: str,
        save_state_dict: bool = False,
        model_filename: str = "wd_model.pt",
    ):
        """Saves the model, training and evaluation history, and the
        feature_importance attribute (if the ``deeptabular`` component is a
        Tabnet model) to disk

        The ``Trainer`` class is built so that it 'just' trains a model. With
        that in mind, all the torch related parameters (such as optimizers,
        learning rate schedulers, initializers, etc) have to be defined
        externally and then passed to the ``Trainer``. As a result, the
        ``Trainer`` does not generate any attribute or additional data
        products that need to be saved other than the ``model`` object itself,
        which can be saved as any other torch model (e.g. ``torch.save(model,
        path)``).

        The exception is Tabnet. If the ``deeptabular`` component is a Tabnet
        model, an attribute (a Dict) called ``feature_importance`` will be
        created at the end of the training process. Therefore, a ``save``
        method was created that will save both the feature importance
        dictionary to a json file and, since we are here, the model weights.

        Parameters
        ----------
        path: str
847 848 849 850 851 852 853
            path to the directory where the model and the feature importance
            attribute will be saved.
        save_state_dict: bool, default = False
            Boolean indicating whether to save directly the model or the
            model's state dictionary
        model_filename: str, Optional, default = "wd_model.pt"
            filename where the model weights will be stored
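
        Examples
        --------

        A minimal sketch; the path and filename below are illustrative only:

        .. code-block:: python

            # saves the training history and the full model to 'model_dir/'
            trainer.save(path="model_dir", model_filename="wd_model.pt")

            # the saved model can later be loaded back with plain torch
            model = torch.load("model_dir/wd_model.pt")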
        """

        save_dir = Path(path)
        history_dir = save_dir / "history"
        history_dir.mkdir(exist_ok=True, parents=True)

        # the trainer is run with the History Callback by default
        with open(history_dir / "train_eval_history.json", "w") as teh:
            json.dump(self.history, teh)  # type: ignore[attr-defined]

        has_lr_history = any(
            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
        )
        if has_lr_history:
            with open(history_dir / "lr_history.json", "w") as lrh:
                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]

        model_path = save_dir / model_filename
        if save_state_dict:
            torch.save(self.model.state_dict(), model_path)
        else:
            torch.save(self.model, model_path)

        if self.model.is_tabnet:
            with open(save_dir / "feature_importance.json", "w") as fi:
                json.dump(self.feature_importance, fi)

    def _restore_best_weights(self):
        already_restored = any(
            [
                (
                    callback.__class__.__name__ == "EarlyStopping"
                    and callback.restore_best_weights
                )
                for callback in self.callback_container.callbacks
            ]
        )
        if already_restored:
            pass
        else:
            for callback in self.callback_container.callbacks:
                if callback.__class__.__name__ == "ModelCheckpoint":
                    if callback.save_best_only:
                        filepath = "{}_{}.p".format(
                            callback.filepath, callback.best_epoch + 1
                        )
                        if self.verbose:
                            print(
                                f"Model weights restored to best epoch: {callback.best_epoch + 1}"
                            )
                        self.model.load_state_dict(torch.load(filepath))
                    else:
                        if self.verbose:
                            print(
                                "Model weights after training correspond to those of the "
                                "final epoch, which might not be the best performing weights. Use "
                                "the 'ModelCheckpoint' Callback to restore the best epoch weights."
                            )

    def _finetune(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptabular_gradual: bool,
        deeptabular_layers: List[nn.Module],
        deeptabular_max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):  # pragma: no cover
        r"""
        Simple wrapper to individually fine-tune model components
        """
        if self.model.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
        # This is not the most elegant solution, but it is a solution "in-between"
        # a non-elegant one and re-factoring the whole code
        finetuner = FineTune(self.loss_fn, self.metric, self.method, self.verbose)
        if self.model.wide:
            finetuner.finetune_all(self.model.wide, "wide", loader, n_epochs, max_lr)
        if self.model.deeptabular:
            if deeptabular_gradual:
                finetuner.finetune_gradual(
                    self.model.deeptabular,
                    "deeptabular",
                    loader,
                    deeptabular_max_lr,
                    deeptabular_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deeptabular, "deeptabular", loader, n_epochs, max_lr
                )
        if self.model.deeptext:
            if deeptext_gradual:
                finetuner.finetune_gradual(
                    self.model.deeptext,
                    "deeptext",
                    loader,
                    deeptext_max_lr,
                    deeptext_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deeptext, "deeptext", loader, n_epochs, max_lr
                )
        if self.model.deepimage:
            if deepimage_gradual:
                finetuner.finetune_gradual(
                    self.model.deepimage,
                    "deepimage",
                    loader,
                    deepimage_max_lr,
                    deepimage_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deepimage, "deepimage", loader, n_epochs, max_lr
                )

    def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        self.model.train()
        X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
        y = target.view(-1, 1).float() if self.method != "multiclass" else target
        y = y.to(device)

        self.optimizer.zero_grad()
        y_pred = self.model(X)
        if self.model.is_tabnet:
            loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
            score = self._get_score(y_pred[0], y)
        else:
            loss = self.loss_fn(y_pred, y)
            score = self._get_score(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        return score, avg_loss

    def _eval_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):

        self.model.eval()
        with torch.no_grad():
            X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
            y = target.view(-1, 1).float() if self.method != "multiclass" else target
            y = y.to(device)

            y_pred = self.model(X)
            if self.model.is_tabnet:
                loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
                score = self._get_score(y_pred[0], y)
            else:
                score = self._get_score(y_pred, y)
                loss = self.loss_fn(y_pred, y)

            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        return score, avg_loss

    def _get_score(self, y_pred, y):
        if self.metric is not None:
            if self.method == "regression":
                score = self.metric(y_pred, y)
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            if self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
            return score
        else:
            return None

    def _compute_feature_importance(self, loader: DataLoader):
        self.model.eval()
        tabnet_backbone = list(self.model.deeptabular.children())[0]
        feat_imp = np.zeros((tabnet_backbone.embed_and_cont_dim))
        for data, target in loader:
            X = data["deeptabular"].to(device)
            y = target.view(-1, 1).float() if self.method != "multiclass" else target
            y = y.to(device)
            M_explain, masks = tabnet_backbone.forward_masks(X)
            feat_imp += M_explain.sum(dim=0).cpu().detach().numpy()

        feat_imp = csc_matrix.dot(feat_imp, self.reducing_matrix)
        feat_imp = feat_imp / np.sum(feat_imp)

        self.feature_importance = {
            k: v for k, v in zip(tabnet_backbone.column_idx.keys(), feat_imp)
        }

    def _predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""Private method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {}
            if X_wide is not None:
                load_dict = {"X_wide": X_wide}
            if X_tab is not None:
                load_dict.update({"X_tab": X_tab})
            if X_text is not None:
                load_dict.update({"X_text": X_text})
            if X_img is not None:
                load_dict.update({"X_img": X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1  # type: ignore[arg-type]

        self.model.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = (
                        self.model(X) if not self.model.is_tabnet else self.model(X)[0]
                    )
                    if self.method == "binary":
                        preds = torch.sigmoid(preds)
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds = preds.cpu().data.numpy()
                    preds_l.append(preds)
        self.model.train()
        return preds_l

    def _set_loss_fn(self, objective, class_weight, custom_loss_function, alpha, gamma):
        if class_weight is not None:
            class_weight = torch.tensor(class_weight).to(device)
        if custom_loss_function is not None:
            return custom_loss_function
        elif self.method != "regression" and "focal_loss" not in objective:
            return alias_to_loss(objective, weight=class_weight)
        elif "focal_loss" in objective:
            return alias_to_loss(objective, alpha=alpha, gamma=gamma)
        else:
            return alias_to_loss(objective)

    def _initialize(self, initializers):
        if initializers is not None:
            if isinstance(initializers, Dict):
                self.initializer = MultipleInitializer(
                    initializers, verbose=self.verbose
                )
                self.initializer.apply(self.model)
            elif isinstance(initializers, type):
                self.initializer = initializers()
                self.initializer(self.model)
            elif isinstance(initializers, Initializer):
                self.initializer = initializers
                self.initializer(self.model)

    def _set_optimizer(self, optimizers):
        if optimizers is not None:
            if isinstance(optimizers, Optimizer):
                optimizer: Union[Optimizer, MultipleOptimizer] = optimizers
            elif isinstance(optimizers, Dict):
                opt_names = list(optimizers.keys())
                mod_names = [n for n, c in self.model.named_children()]
                for mn in mod_names:
                    assert mn in opt_names, "No optimizer found for {}".format(mn)
                optimizer = MultipleOptimizer(optimizers)
        else:
            optimizer = torch.optim.AdamW(self.model.parameters())  # type: ignore
        return optimizer

    def _set_lr_scheduler(self, lr_schedulers):
        if lr_schedulers is not None:
            # ReduceLROnPlateau is special: it is the only scheduler that is
            # 'just' an object rather than an LRScheduler
            if isinstance(lr_schedulers, LRScheduler) or isinstance(
                lr_schedulers, ReduceLROnPlateau
            ):
                lr_scheduler = lr_schedulers
                cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
            else:
                lr_scheduler = MultipleLRScheduler(lr_schedulers)
                scheduler_names = [
                    sc.__class__.__name__.lower()
                    for _, sc in lr_scheduler._schedulers.items()
                ]
                cyclic_lr = any(["cycl" in sn for sn in scheduler_names])
        else:
            lr_scheduler, cyclic_lr = None, False
        self.cyclic_lr = cyclic_lr
        return lr_scheduler

    @staticmethod
    def _set_transforms(transforms):
        if transforms is not None:
            return MultipleTransforms(transforms)()
        else:
            return None

    def _set_callbacks_and_metrics(self, callbacks, metrics):
        self.callbacks: List = [History(), LRShedulerCallback()]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)
        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None
        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self.model)
        self.callback_container.set_trainer(self)