import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from pytorch_widedeep.losses import MSLELoss, RMSELoss, FocalLoss, RMSLELoss
from pytorch_widedeep.models import WideDeep
from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import History, Callback, CallbackContainer
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.utils.general_utils import Alias
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
from pytorch_widedeep.training._loss_and_obj_aliases import (
    _LossAliases,
    _ObjectiveToMethod,
)
from pytorch_widedeep.training._multiple_lr_scheduler import (
    MultipleLRScheduler,
)

n_cpus = os.cpu_count()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class Trainer:
    @Alias(  # noqa: C901
        "objective",
        ["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"],
    )
    @Alias("model", ["model_path"])
    def __init__(
        self,
        model: Union[str, WideDeep],
        objective: str,
        custom_loss_function: Optional[Module] = None,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        initializers: Optional[Union[Initializer, Dict[str, Initializer]]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        r"""Method to set the of attributes that will be used during the
        training process.

        Parameters
        ----------
        model: str or ``WideDeep``
            param alias: ``model_path``

            An object of class ``WideDeep`` or a string with the full path to
            the model, in which case you might want to use the alias
            ``model_path``.

        objective: str
            Defines the objective, loss or cost function.

            param aliases: ``loss_function``, ``loss_fn``, ``loss``,
            ``cost_function``, ``cost_fn``, ``cost``

            Possible values are:

            - ``binary``, aliases: ``logistic``, ``binary_logloss``, ``binary_cross_entropy``

            - ``binary_focal_loss``

            - ``multiclass``, aliases: ``multi_logloss``, ``cross_entropy``, ``categorical_cross_entropy``

            - ``multiclass_focal_loss``

            - ``regression``, aliases: ``mse``, ``l2``, ``mean_squared_error``

            - ``mean_absolute_error``, aliases: ``mae``, ``l1``

            - ``mean_squared_log_error``, aliases: ``msle``

            - ``root_mean_squared_error``, aliases:  ``rmse``

            - ``root_mean_squared_log_error``, aliases: ``rmsle``
        custom_loss_function: ``nn.Module``, Optional, default = None
            Object of class ``nn.Module``. If none of the loss functions
            available suits the user, it is possible to pass a custom loss
            function. See for example
            :class:`pytorch_widedeep.losses.FocalLoss` for the required
            structure of the object or the `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            folder in the repo.
        optimizers: ``Optimizer`` or Dict, Optional, default= ``AdamW``
            - An instance of Pytorch's ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
            - a dictionary where the keys are the model components (i.e.
              `'wide'`, `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and
              the values are the corresponding optimizers. If multiple optimizers are used
              the dictionary **MUST** contain an optimizer per model component.
        lr_schedulers: ``LRScheduler`` or Dict, Optional, default=None
            - An instance of Pytorch's ``LRScheduler`` object (e.g.
              :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
              values are the corresponding learning rate schedulers.
        initializers: ``Initializer`` or Dict, Optional. default=None
            - An instance of an ``Initializer`` object (see ``pytorch_widedeep.initializers``) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`)
              and the values are the corresponding initializers.
        transforms: List, Optional, default=None
            List with :obj:`torchvision.transforms` to be applied to the image
            component of the model (i.e. ``deepimage``). See `torchvision
            transforms
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
        callbacks: List, Optional, default=None
            List with ``Callback`` objects. The four callbacks available in
            ``pytorch-widedeep`` are: ``History``, ``ModelCheckpoint``,
            ``EarlyStopping`` and ``LRHistory``. The ``History`` callback is
            used by default.
            This can also be a custom callback as long as the object is of type
            ``Callback``. See ``pytorch_widedeep.callbacks.Callback`` or the
            `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            folder in the repo
        metrics: List, Optional, default=None
            List of objects of type ``Metric``. Metrics available are:
            ``Accuracy``, ``Precision``, ``Recall``, ``FBetaScore``,
            ``F1Score`` and ``R2Score``. This can also be a custom metric as
            long as it is an object of type ``Metric``. See
            ``pytorch_widedeep.metrics.Metric`` or the `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            folder in the repo
        class_weight: float, List or Tuple. Optional. default=None
            - float indicating the weight of the minority class in binary classification
              problems (e.g. 9.)
            - a list or tuple with weights for the different classes in multiclass
              classification problems  (e.g. [1., 2., 3.]). The weights do
              not necessarily need to be normalised. If your loss function
              uses reduction='mean', the loss will be normalized by the sum
              of the corresponding weights for each element. If you are
              using reduction='none', you would have to take care of the
              normalization yourself. See `this discussion
              <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
        alpha: float. default=0.25
            If ``objective`` is ``binary_focal_loss`` or
            ``multiclass_focal_loss``, the Focal Loss alpha and gamma
            parameters can be set directly in the ``Trainer`` via the
            ``alpha`` and ``gamma`` parameters
        gamma: float. default=2
            Focal Loss gamma parameter
        verbose: int, default=1
            Setting it to 0 will print nothing during training.
        seed: int, default=1
            Random seed to be used throughout all the methods

        Examples
        --------
        >>> import torch
        >>> from torchvision.transforms import ToTensor
        >>>
        >>> # wide deep imports
        >>> from pytorch_widedeep.callbacks import EarlyStopping, LRHistory
        >>> from pytorch_widedeep.initializers import KaimingNormal, KaimingUniform, Normal, Uniform
        >>> from pytorch_widedeep.models import TabResnet, DeepImage, DeepText, Wide, WideDeep
        >>> from pytorch_widedeep import Trainer
        >>> from pytorch_widedeep.optim import RAdam
        >>>
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>>
        >>> # build the model
        >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
        >>> deepimage = DeepImage(pretrained=False)
        >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
        >>>
        >>> # set optimizers and schedulers
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deeptabular.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt = RAdam(model.deepimage.parameters())
        >>>
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>>
        >>> optimizers = {"wide": wide_opt, "deeptabular": deep_opt, "deeptext": text_opt, "deepimage": img_opt}
        >>> schedulers = {"wide": wide_sch, "deeptabular": deep_sch, "deeptext": text_sch, "deepimage": img_sch}
        >>>
        >>> # set initializers and callbacks
        >>> initializers = {"wide": Uniform, "deeptabular": Normal, "deeptext": KaimingNormal, "deepimage": KaimingUniform}
        >>> transforms = [ToTensor]
        >>> callbacks = [LRHistory(n_epochs=4), EarlyStopping]
        >>>
        >>> # set the trainer
        >>> trainer = Trainer(model, objective="regression", initializers=initializers, optimizers=optimizers,
        ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """

        if isinstance(optimizers, Dict) and not isinstance(lr_schedulers, Dict):
            raise ValueError(
                "''optimizers' and 'lr_schedulers' must have consistent type: "
                "(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]) "
                "Please, read the documentation or see the examples for more details"
            )

        if custom_loss_function is not None and objective not in [
            "binary",
            "multiclass",
            "regression",
        ]:
            raise ValueError(
                "If 'custom_loss_function' is not None, 'objective' might be 'binary' "
                "'multiclass' or 'regression', consistent with the loss function"
            )

        if isinstance(model, str):
            self.model = torch.load(model)
        else:
            self.model = model
        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        self.objective = objective
        self.method = _ObjectiveToMethod.get(objective)

        self.loss_fn = self._get_loss_fn(
            objective, class_weight, custom_loss_function, alpha, gamma
        )
        self._initialize(initializers)
        self.optimizer = self._get_optimizer(optimizers)
        self.lr_scheduler, self.cyclic_lr = self._get_lr_scheduler(lr_schedulers)
        self.transforms = self._get_transforms(transforms)
        self._set_callbacks_and_metrics(callbacks, metrics)

        self.model.to(device)

    @Alias("finetune", ["warmup"])  # noqa: C901
    @Alias("finetune_epochs", ["warmup_epochs"])
    @Alias("finetune_max_lr", ["warmup_max_lr"])
    @Alias("finetune_deeptabular_gradual", ["warmup_deeptabular_gradual"])
    @Alias("finetune_deeptabular_max_lr", ["warmup_deeptabular_max_lr"])
    @Alias("finetune_deeptabular_layers", ["warmup_deeptabular_layers"])
    @Alias("finetune_deeptext_gradual", ["warmup_deeptext_gradual"])
    @Alias("finetune_deeptext_max_lr", ["warmup_deeptext_max_lr"])
    @Alias("finetune_deeptext_layers", ["warmup_deeptext_layers"])
    @Alias("finetune_deepimage_gradual", ["warmup_deepimage_gradual"])
    @Alias("finetune_deepimage_max_lr", ["warmup_deepimage_max_lr"])
    @Alias("finetune_deepimage_layers", ["warmup_deepimage_layers"])
    @Alias("finetune_routine", ["warmup_routine"])
    def fit(  # noqa: C901
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        patience: int = 10,
        finetune: bool = False,
        finetune_epochs: int = 5,
        finetune_max_lr: float = 0.01,
        finetune_deeptabular_gradual: bool = False,
        finetune_deeptabular_max_lr: float = 0.01,
        finetune_deeptabular_layers: Optional[List[nn.Module]] = None,
        finetune_deeptext_gradual: bool = False,
        finetune_deeptext_max_lr: float = 0.01,
        finetune_deeptext_layers: Optional[List[nn.Module]] = None,
        finetune_deepimage_gradual: bool = False,
        finetune_deepimage_max_lr: float = 0.01,
        finetune_deepimage_layers: Optional[List[nn.Module]] = None,
        finetune_routine: str = "howard",
        stop_after_finetuning: bool = False,
    ):
        r"""Fit method.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict, Optional. default=None
            Training dataset for the different model components. Keys are
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
        X_val: Dict, Optional. default=None
            Validation dataset for the different model components. Keys are
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
        val_split: float, Optional. default=None
            train/val split fraction
        target: np.ndarray, Optional. default=None
            target values
        n_epochs: int, default=1
            number of epochs
        validation_freq: int, default=1
            epochs validation frequency
        batch_size: int, default=32
        patience: int, default=10
            Number of epochs without improving the target metric or loss
            before the fit process stops
        finetune: bool, default=False
            param alias: ``warmup``

            fine-tune individual model components.

            .. note:: This functionality can also be used to 'warm-up'
               individual components before the joint training starts, and hence
               its alias. See the Examples folder in the repo for more details

            ``pytorch_widedeep`` implements 3 fine-tune routines.

            - fine-tune all trainable layers at once. This routine is
              inspired by the work of Howard and Ruder, 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
              Slanted Triangular learning rate (see `Leslie N. Smith paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``AdamW``.

            and two gradual fine-tune routines, where only certain layers are
            trained at a time.

            - The so called `Felbo` gradual fine-tune routine, based on the
              Felbo et al., 2017 `DeepMoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine, based on the work of Howard and Ruder, 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the Examples
            section in this documentation and the `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
            folder in the repo.
        finetune_epochs: int, default=5
            param alias: ``warmup_epochs``

            Number of fine-tune epochs for those model components that will
            *NOT* be gradually fine-tuned. Those components with gradual
            fine-tune follow their corresponding specific routine.
        finetune_max_lr: float, default=0.01
            param alias: ``warmup_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for those model components that will *NOT* be gradually fine-tuned
        finetune_deeptabular_gradual: bool, default=False
            param alias: ``warmup_deeptabular_gradual``

            Boolean indicating if the ``deeptabular`` component will be
            fine-tuned gradually
        finetune_deeptabular_max_lr: float, default=0.01
            param alias: ``warmup_deeptabular_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptabular component
        finetune_deeptabular_layers: List, Optional, default=None
            param alias: ``warmup_deeptabular_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_deeptext_gradual: bool, default=False
            param alias: ``warmup_deeptext_gradual``

            Boolean indicating if the ``deeptext`` component will be
            fine-tuned gradually
        finetune_deeptext_max_lr: float, default=0.01
            param alias: ``warmup_deeptext_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
        finetune_deeptext_layers: List, Optional, default=None
            param alias: ``warmup_deeptext_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_deepimage_gradual: bool, default=False
            param alias: ``warmup_deepimage_gradual``

            Boolean indicating if the ``deepimage`` component will be
            fine-tuned gradually
        finetune_deepimage_max_lr: float, default=0.01
            param alias: ``warmup_deepimage_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the ``deepimage`` component
        finetune_deepimage_layers: List, Optional, default=None
            param alias: ``warmup_deepimage_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_routine: str, default=`howard`
            param alias: ``warmup_routine``

            Fine-tune (warm up) routine. One of `felbo` or `howard`. See the examples
            section in this documentation and the corresponding repo for
            details on how to use fine-tune routines

        Examples
        --------

        For a series of comprehensive examples please, see the `Examples
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
        folder in the repo

        For completion, here we include some `"fabricated"` examples, i.e.
        they assume you have already built a model and instantiated a
        ``Trainer`` that is ready to fit

        .. code-block:: python

            # Ex 1. using train input arrays directly and no validation
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256)


        .. code-block:: python

            # Ex 2: using train input arrays directly and validation with val_split
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256, val_split=0.2)


        .. code-block:: python

            # Ex 3: using train dict and val_split
            X_train = {'X_wide': X_wide, 'X_tab': X_tab, 'target': y}
            trainer.fit(X_train, n_epochs=10, batch_size=256, val_split=0.2)


        .. code-block:: python

            # Ex 4: validation using training and validation dicts
            X_train = {'X_wide': X_wide_tr, 'X_tab': X_tab_tr, 'target': y_tr}
            X_val = {'X_wide': X_wide_val, 'X_tab': X_tab_val, 'target': y_val}
            trainer.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=256)
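
        The same applies to the fine-tune (aka warm-up) arguments; e.g. (a
        fabricated example, as those above):

        .. code-block:: python

            # Ex 5: fine-tune/warm up the individual components before the joint training
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10,
                batch_size=256, finetune=True, finetune_epochs=2)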
        """

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(
            X_wide, X_tab, X_text, X_img, X_train, X_val, val_split, target
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
        train_steps = len(train_loader)
        if eval_set is not None:
            eval_loader = DataLoader(
                dataset=eval_set,
                batch_size=batch_size,
                num_workers=n_cpus,
                shuffle=False,
            )
            eval_steps = len(eval_loader)

        if finetune:
            self._finetune(
                train_loader,
                finetune_epochs,
                finetune_max_lr,
                finetune_deeptabular_gradual,
                finetune_deeptabular_layers,
                finetune_deeptabular_max_lr,
                finetune_deeptext_gradual,
                finetune_deeptext_layers,
                finetune_deeptext_max_lr,
                finetune_deepimage_gradual,
                finetune_deepimage_layers,
                finetune_deepimage_max_lr,
                finetune_routine,
            )
            if stop_after_finetuning:
                print("Fine-tuning finished")
                return
            else:
                if self.verbose:
                    print(
                        "Fine-tuning of individual components completed. "
                        "Training the whole model for {} epochs".format(n_epochs)
                    )

        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        for epoch in range(n_epochs):
            epoch_logs: Dict[str, float] = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
                for batch_idx, (data, targett) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
                    score, train_loss = self._training_step(data, targett, batch_idx)
                    if score is not None:
                        t.set_postfix(
                            metrics={k: np.round(v, 4) for k, v in score.items()},
                            loss=train_loss,
                        )
                    else:
                        t.set_postfix(loss=train_loss)
                    if self.lr_scheduler:
                        self._lr_scheduler_step(step_location="on_batch_end")
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs["train_loss"] = train_loss
            if score is not None:
                for k, v in score.items():
                    log_k = "_".join(["train", k])
                    epoch_logs[log_k] = v
            if eval_set is not None and epoch % validation_freq == (
                validation_freq - 1
            ):
                self.valid_running_loss = 0.0
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for i, (data, targett) in zip(v, eval_loader):
                        v.set_description("valid")
                        score, val_loss = self._validation_step(data, targett, i)
                        if score is not None:
                            v.set_postfix(
                                metrics={k: np.round(v, 4) for k, v in score.items()},
                                loss=val_loss,
                            )
                        else:
                            v.set_postfix(loss=val_loss)
                epoch_logs["val_loss"] = val_loss
                if score is not None:
                    for k, v in score.items():
                        log_k = "_".join(["val", k])
                        epoch_logs[log_k] = v
            if self.lr_scheduler:
                self._lr_scheduler_step(step_location="on_epoch_end")
            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                self.callback_container.on_train_end(epoch_logs)
                break
        self.callback_container.on_train_end(epoch_logs)
        self.model.train()

    def predict(  # type: ignore[return]
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predictions

        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict, Optional. default=None
            Dictionary with the testing dataset for the different model
            components. Keys are `'X_wide'`, `'X_tab'`, `'X_text'` and
            `'X_img'` and the values are the corresponding matrices.
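
        Examples
        --------

        For completion, here we include a `"fabricated"` example, i.e. one
        that assumes the ``Trainer`` has already been fitted and that the
        test arrays (illustrative names) were preprocessed with the same
        preprocessors used for training:

        .. code-block:: python

            preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)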
        """

        preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype("int")
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)  # type: ignore[return-value]

    def predict_proba(  # type: ignore[return]
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for  binary
        and multiclass methods

        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict, Optional. default=None
            Dictionary with the testing dataset for the different model
            components. Keys are `'X_wide'`, `'X_tab'`, `'X_text'` and
            `'X_img'` and the values are the corresponding matrices.
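
        Examples
        --------

        A `"fabricated"` example, assuming a ``Trainer`` fitted with a binary
        objective and illustrative test arrays:

        .. code-block:: python

            probs = trainer.predict_proba(X_wide=X_wide_te, X_tab=X_tab_te)
            # probs has shape (n_samples, 2): [P(class 0), P(class 1)]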
        """

        preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            probs = np.zeros([preds.shape[0], 2])
            probs[:, 0] = 1 - preds
            probs[:, 1] = preds
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)

    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:  # pragma: no cover
        r"""Returns the learned embeddings for the categorical features passed through
        ``deeptabular``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` attribute in the class
        :obj:`TabPreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.TabPreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict
            Dictionary where the keys are the name of the column for which we
            want to retrieve the embeddings and the values are also of type
            Dict. These Dict values have keys that are the categories for that
            column and the values are the corresponding numerical encodings

            e.g.: {'column': {'cat_0': 1, 'cat_1': 2, ...}}

        Examples
        --------

        For a series of comprehensive examples please, see the `Examples
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
        folder in the repo

        For completion, here we include a `"fabricated"` example, i.e.
        assuming we have already trained the model, that we have the
        categorical encodings in a dictionary named ``encoding_dict``, and that
        there is a column called `'education'`:

        .. code-block:: python

            trainer.get_embeddings(col_name="education", cat_encoding_dict=encoding_dict)
        """
        for n, p in self.model.named_parameters():
            if "embed_layers" in n and col_name in n:
                embed_mtx = p.cpu().data.numpy()
        encoding_dict = cat_encoding_dict[col_name]
        inv_encoding_dict = {v: k for k, v in encoding_dict.items()}
        cat_embed_dict = {}
        for idx, value in inv_encoding_dict.items():
            cat_embed_dict[value] = embed_mtx[idx]
        return cat_embed_dict

    def save_model(self, path: str):
        """Saves the model to disk

        Parameters
        ----------
        path: str
            full path where the model will be saved, including the model
            filename, e.g. 'model_dir/model.t'
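
        Examples
        --------

        A minimal, illustrative example (the path is just an example; it must
        include the model filename):

        .. code-block:: python

            trainer.save_model("model_dir/model.t")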
        """
        self._makedir_if_not_exist(path)
        torch.save(self.model, path)

    @staticmethod
    def load_model(path: str) -> nn.Module:
        """Loads the model from disk

        Parameters
        ----------
        path: str
            full path to the file from where the model will be loaded, e.g.
            'model_dir/model.t'
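
        Examples
        --------

        A minimal, illustrative example, assuming a model previously saved
        with :obj:`Trainer.save_model`:

        .. code-block:: python

            model = Trainer.load_model("model_dir/model.t")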
        """
        return torch.load(path)

    def save_model_state_dict(self, path: str):
        """Saves the state dictionary to disk

        Parameters
        ----------
        path: str
            full path where the model's state dictionary will be saved,
            including the filename
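
        Examples
        --------

        A minimal, illustrative example (the path must include the filename):

        .. code-block:: python

            trainer.save_model_state_dict("model_dir/wd_state_dict.pt")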
        """
        self._makedir_if_not_exist(path)
        torch.save(self.model.state_dict(), path)

    def load_model_state_dict(self, path: str):
        """Saves the state dictionary to disk

        Parameters
        ----------
        path: str
            full path to the file from where the model's state dictionary
            will be loaded
        """
        self.model.load_state_dict(torch.load(path))

    def _train_val_split(  # noqa: C901
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
    ):
        r"""
        If a validation set (X_val) is passed to the fit method, or val_split
        is specified, the train/val split will happen internally. A number of
        options are allowed in terms of data inputs. For parameter
        information, please, see the .fit() method documentation

        Returns
        -------
        train_set: WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`
        eval_set : WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`
        """

        if X_val is not None:
            assert (
                X_train is not None
            ), "if the validation set is passed as a dictionary, the training set must also be a dictionary"
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
        elif val_split is not None:
            if not X_train:
                X_train = self._build_train_dict(X_wide, X_tab, X_text, X_img, target)
            y_tr, y_val, idx_tr, idx_val = train_test_split(
                X_train["target"],
                np.arange(len(X_train["target"])),
                test_size=val_split,
                stratify=X_train["target"] if self.method != "regression" else None,
            )
            X_tr, X_val = {"target": y_tr}, {"target": y_val}
            if "X_wide" in X_train.keys():
                X_tr["X_wide"], X_val["X_wide"] = (
                    X_train["X_wide"][idx_tr],
                    X_train["X_wide"][idx_val],
                )
            if "X_tab" in X_train.keys():
                X_tr["X_tab"], X_val["X_tab"] = (
                    X_train["X_tab"][idx_tr],
                    X_train["X_tab"][idx_val],
                )
            if "X_text" in X_train.keys():
                X_tr["X_text"], X_val["X_text"] = (
                    X_train["X_text"][idx_tr],
                    X_train["X_text"][idx_val],
                )
            if "X_img" in X_train.keys():
                X_tr["X_img"], X_val["X_img"] = (
                    X_train["X_img"][idx_tr],
                    X_train["X_img"][idx_val],
                )
            train_set = WideDeepDataset(**X_tr, transforms=self.transforms)  # type: ignore
            eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
        else:
            if not X_train:
                X_train = self._build_train_dict(X_wide, X_tab, X_text, X_img, target)
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            eval_set = None

        return train_set, eval_set

    @staticmethod
    def _build_train_dict(X_wide, X_tab, X_text, X_img, target):
        X_train = {"target": target}
        if X_wide is not None:
            X_train["X_wide"] = X_wide
        if X_tab is not None:
            X_train["X_tab"] = X_tab
        if X_text is not None:
            X_train["X_text"] = X_text
        if X_img is not None:
            X_train["X_img"] = X_img
        return X_train

    def _finetune(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptabular_gradual: bool,
        deeptabular_layers: List[nn.Module],
        deeptabular_max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):  # pragma: no cover
        r"""
        Simple wrapper to individually fine-tune model components
        """
        if self.model.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
        # This is not the most elegant solution, but it is a solution "in
        # between" a non-elegant one and refactoring the whole code
        finetuner = FineTune(self.loss_fn, self.metric, self.method, self.verbose)
        if self.model.wide:
            finetuner.finetune_all(self.model.wide, "wide", loader, n_epochs, max_lr)
        if self.model.deeptabular:
            if deeptabular_gradual:
                finetuner.finetune_gradual(
                    self.model.deeptabular,
                    "deeptabular",
                    loader,
                    deeptabular_max_lr,
                    deeptabular_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deeptabular, "deeptabular", loader, n_epochs, max_lr
                )
        if self.model.deeptext:
            if deeptext_gradual:
                finetuner.finetune_gradual(
                    self.model.deeptext,
                    "deeptext",
                    loader,
                    deeptext_max_lr,
                    deeptext_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deeptext, "deeptext", loader, n_epochs, max_lr
                )
        if self.model.deepimage:
            if deepimage_gradual:
                finetuner.finetune_gradual(
                    self.model.deepimage,
                    "deepimage",
                    loader,
                    deepimage_max_lr,
                    deepimage_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deepimage, "deepimage", loader, n_epochs, max_lr
                )

    def _lr_scheduler_step(self, step_location: str):  # noqa: C901
        r"""
        Function to execute the learning rate scheduler steps.
        If the lr_scheduler is cyclic (i.e. CyclicLR or OneCycleLR), the step
        must happen after training each batch. On the other hand, if the
        scheduler is not cyclic, it is expected to be called at the end of
        each epoch, after validation.

        Parameters
        ----------
        step_location: str
            Indicates where to run the lr_scheduler step
        """
        if (
            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
            and self.cyclic_lr
        ):
            if step_location == "on_batch_end":
                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
            elif step_location == "on_epoch_end":
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" not in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
        elif self.cyclic_lr:
            if step_location == "on_batch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
            if step_location == "on_epoch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif step_location == "on_epoch_end":
            self.lr_scheduler.step()  # type: ignore
        else:
            pass

    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        self.model.train()
        X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
        y = target.view(-1, 1).float() if self.method != "multiclass" else target
        y = y.to(device)

        self.optimizer.zero_grad()
        y_pred = self.model(X)
        loss = self.loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        return self._get_score(y_pred, y), avg_loss

    def _validation_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):

        self.model.eval()
        with torch.no_grad():
            X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
            y = target.view(-1, 1).float() if self.method != "multiclass" else target
            y = y.to(device)

            y_pred = self.model(X)
            loss = self.loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        return self._get_score(y_pred, y), avg_loss

    def _get_score(self, y_pred, y):
        if self.metric is not None:
            if self.method == "regression":
                score = self.metric(y_pred, y)
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            if self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
            return score
        else:
            return None

    def _predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""Private method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {}
            if X_wide is not None:
                load_dict = {"X_wide": X_wide}
            if X_tab is not None:
                load_dict.update({"X_tab": X_tab})
            if X_text is not None:
                load_dict.update({"X_text": X_text})
            if X_img is not None:
                load_dict.update({"X_img": X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1  # type: ignore[arg-type]

        self.model.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = self.model(X)
                    if self.method == "binary":
                        preds = torch.sigmoid(preds)
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds = preds.cpu().data.numpy()
                    preds_l.append(preds)
        self.model.train()
        return preds_l

    @staticmethod
    def _makedir_if_not_exist(path):
        if len(path.split("/")[:-1]) == 0:
            raise ValueError(
                "'path' must be the full path to save the model, including"
                " the root of the filenames. e.g. 'model/model.t'"
            )
        root_dir = ("/").join(path.split("/")[:-1])
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

    @staticmethod
    def _alias_to_loss(loss_fn: str, **kwargs):
        if loss_fn not in _ObjectiveToMethod.keys():
            raise ValueError(
                "objective or loss function is not supported. Please consider passing a callable "
                "directly to the compile method (see docs) or use one of the supported objectives "
                "or loss functions: {}".format(", ".join(_ObjectiveToMethod.keys()))
            )
        if loss_fn in _LossAliases.get("binary"):
            return nn.BCEWithLogitsLoss(weight=kwargs["weight"])
        if loss_fn in _LossAliases.get("multiclass"):
            return nn.CrossEntropyLoss(weight=kwargs["weight"])
        if loss_fn in _LossAliases.get("regression"):
            return nn.MSELoss()
        if loss_fn in _LossAliases.get("mean_absolute_error"):
            return nn.L1Loss()
        if loss_fn in _LossAliases.get("mean_squared_log_error"):
            return MSLELoss()
        if loss_fn in _LossAliases.get("root_mean_squared_error"):
            return RMSELoss()
        if loss_fn in _LossAliases.get("root_mean_squared_log_error"):
            return RMSLELoss()
        if "focal_loss" in loss_fn:
            return FocalLoss(**kwargs)

    def _get_loss_fn(self, objective, class_weight, custom_loss_function, alpha, gamma):
        if isinstance(class_weight, float):
            class_weight = torch.tensor([1.0 - class_weight, class_weight])
        elif isinstance(class_weight, (tuple, list)):
            class_weight = torch.tensor(class_weight)
        else:
            class_weight = None
        if custom_loss_function is not None:
            return custom_loss_function
        elif self.method != "regression" and "focal_loss" not in objective:
            return self._alias_to_loss(objective, weight=class_weight)
        elif "focal_loss" in objective:
            return self._alias_to_loss(objective, alpha=alpha, gamma=gamma)
        else:
            return self._alias_to_loss(objective)

    def _initialize(self, initializers):
        if initializers is not None:
            if isinstance(initializers, Dict):
                self.initializer = MultipleInitializer(
                    initializers, verbose=self.verbose
                )
                self.initializer.apply(self.model)
            elif isinstance(initializers, type):
                self.initializer = initializers()
                self.initializer(self.model)
            elif isinstance(initializers, Initializer):
                self.initializer = initializers
                self.initializer(self.model)

    def _get_optimizer(self, optimizers):
        if optimizers is not None:
            if isinstance(optimizers, Optimizer):
                optimizer: Union[Optimizer, MultipleOptimizer] = optimizers
            elif isinstance(optimizers, Dict):
                opt_names = list(optimizers.keys())
                mod_names = [n for n, c in self.model.named_children()]
                for mn in mod_names:
                    assert mn in opt_names, "No optimizer found for {}".format(mn)
                optimizer = MultipleOptimizer(optimizers)
        else:
            optimizer = torch.optim.AdamW(self.model.parameters())  # type: ignore
        return optimizer

    @staticmethod
    def _get_lr_scheduler(lr_schedulers):
        if lr_schedulers is not None:
            if isinstance(lr_schedulers, LRScheduler):
                lr_scheduler: Union[
                    LRScheduler,
                    MultipleLRScheduler,
                ] = lr_schedulers
                cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
            else:
                lr_scheduler = MultipleLRScheduler(lr_schedulers)
                scheduler_names = [
                    sc.__class__.__name__.lower()
                    for _, sc in lr_scheduler._schedulers.items()
                ]
                cyclic_lr = any(["cycl" in sn for sn in scheduler_names])
        else:
            lr_scheduler, cyclic_lr = None, False
        return lr_scheduler, cyclic_lr

    @staticmethod
    def _get_transforms(transforms):
        if transforms is not None:
            return MultipleTransforms(transforms)()
        else:
            return None

    def _set_callbacks_and_metrics(self, callbacks, metrics):
        self.callbacks: List = [History()]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)
        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None
        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self.model)
        self.callback_container.set_trainer(self)