import os
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange
from scipy.sparse import csc_matrix
from torchmetrics import Metric as TorchMetric
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_widedeep.metrics import Metric, MultipleMetrics
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import (
    History,
    Callback,
    MetricCallback,
    CallbackContainer,
    LRShedulerCallback,
)
from pytorch_widedeep.dataloaders import DataLoaderDefault
from pytorch_widedeep.initializers import Initializer, MultipleInitializer
from pytorch_widedeep.training._finetune import FineTune
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training.trainer_utils import (
    Alias,
    alias_to_loss,
    save_epoch_logs,
    wd_train_val_split,
    print_loss_and_metric,
)
from pytorch_widedeep.models.tabnet.tab_net_utils import create_explain_matrix
from pytorch_widedeep.training._multiple_optimizer import MultipleOptimizer
from pytorch_widedeep.training._multiple_transforms import MultipleTransforms
from pytorch_widedeep.training._loss_and_obj_aliases import _ObjectiveToMethod
from pytorch_widedeep.training._multiple_lr_scheduler import (
    MultipleLRScheduler,
)

n_cpus = os.cpu_count()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class Trainer:
    r"""Method to set the of attributes that will be used during the
    training process.

    Parameters
    ----------
    model: ``WideDeep``
        An object of class ``WideDeep``
    objective: str
        Defines the objective, loss or cost function.

        Param aliases: ``loss_function``, ``loss_fn``, ``loss``,
        ``cost_function``, ``cost_fn``, ``cost``

        Possible values are:

        - ``binary``, aliases: ``logistic``, ``binary_logloss``, ``binary_cross_entropy``

        - ``binary_focal_loss``

        - ``multiclass``, aliases: ``multi_logloss``, ``cross_entropy``, ``categorical_cross_entropy``

        - ``multiclass_focal_loss``

        - ``regression``, aliases: ``mse``, ``l2``, ``mean_squared_error``

        - ``mean_absolute_error``, aliases: ``mae``, ``l1``

        - ``mean_squared_log_error``, aliases: ``msle``

        - ``root_mean_squared_error``, aliases: ``rmse``

        - ``root_mean_squared_log_error``, aliases: ``rmsle``
    custom_loss_function: ``nn.Module``, optional, default = None
        object of class ``nn.Module``. If none of the available loss
        functions suits the user, it is possible to pass a custom loss
        function. See for example
        :class:`pytorch_widedeep.losses.FocalLoss` for the required
        structure of the object, the sketch at the end of the Examples
        section below, or the `Examples
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
        folder in the repo.

        .. note:: If ``custom_loss_function`` is not None, ``objective`` must be
            'binary', 'multiclass' or 'regression', consistent with the loss
            function

    optimizers: ``Optimizer`` or dict, optional, default=None
        - An instance of Pytorch's ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
        - a dictionary where the keys are the model components (i.e.
          `'wide'`, `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and
          the values are the corresponding optimizers. If multiple optimizers are used,
          the dictionary **MUST** contain an optimizer per model component.

        If no optimizers are passed, it will default to ``Adam`` for all
        Wide and Deep components.
    lr_schedulers: ``LRScheduler`` or dict, optional, default=None
        - An instance of Pytorch's ``LRScheduler`` object (e.g.
          :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
        - a dictionary where the keys are the model components (i.e. `'wide'`,
          `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
          values are the corresponding learning rate schedulers.
    reducelronplateau_criterion: str, optional. default="loss"
        Quantity to be monitored during training if using the
        :obj:`ReduceLROnPlateau` learning rate scheduler. Possible values
        are: 'loss' or 'metric'.
    initializers: ``Initializer`` or dict, optional, default=None
        - An instance of an ``Initializer`` object (see :obj:`pytorch-widedeep.initializers`) or
        - a dictionary where the keys are the model components (i.e. `'wide'`,
          `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`)
          and the values are the corresponding initializers.
    transforms: List, optional, default=None
        List with :obj:`torchvision.transforms` to be applied to the image
        component of the model (i.e. ``deepimage``). See `torchvision
        transforms
        <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
    callbacks: List, optional, default=None
        List with :obj:`Callback` objects. The three callbacks available in
        ``pytorch-widedeep`` are: ``LRHistory``, ``ModelCheckpoint`` and
        ``EarlyStopping``. The ``History`` and the ``LRShedulerCallback``
        callbacks are used by default. This can also be a custom callback as
        long as the object is of type ``Callback``. See
        :obj:`pytorch_widedeep.callbacks.Callback` or the `Examples
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
        folder in the repo
    metrics: List, optional, default=None
        - List of objects of type :obj:`Metric`. Metrics available are:
          ``Accuracy``, ``Precision``, ``Recall``, ``FBetaScore``,
          ``F1Score`` and ``R2Score``. This can also be a custom metric as
          long as it is an object of type :obj:`Metric`. See
          :obj:`pytorch_widedeep.metrics.Metric` or the `Examples
          <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
          folder in the repo
        - List of objects of type :obj:`torchmetrics.Metric`. This can be any
          metric from the torchmetrics library `Examples
          <https://torchmetrics.readthedocs.io/en/latest/references/modules.html#
          classification-metrics>`_. This can also be a custom metric as
          long as it is an object of type :obj:`Metric`. See `the instructions
          <https://torchmetrics.readthedocs.io/en/latest/>`_.
    class_weight: float, List or Tuple, optional. default=None
        - float indicating the weight of the minority class in binary classification
          problems (e.g. 9.)
        - a list or tuple with weights for the different classes in multiclass
          classification problems  (e.g. [1., 2., 3.]). The weights do not
          need to be normalised. See `this discussion
          <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
    lambda_sparse: float. default=1e-3
        Tabnet sparse regularization factor
    alpha: float. default=0.25
        if ``objective`` is ``binary_focal_loss`` or
        ``multiclass_focal_loss``, the Focal Loss alpha and gamma
        parameters can be set directly in the ``Trainer`` via the
        ``alpha`` and ``gamma`` parameters
    gamma: float. default=2
        Focal Loss gamma parameter
    verbose: int, default=1
        Setting it to 0 will print nothing during training.
    seed: int, default=1
        Random seed to be used internally for train_test_split

    Attributes
    ----------
    cyclic_lr: bool
        Attribute that indicates if any of the lr_schedulers is cyclic_lr (i.e. ``CyclicLR`` or
        ``OneCycleLR``). See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
    feature_importance: dict
        dict where the keys are the column names and the values are the
        corresponding feature importances. This attribute will only exist
        if the ``deeptabular`` component is a Tabnet model

    Examples
    --------
    >>> import torch
    >>> from torchvision.transforms import ToTensor
    >>>
    >>> # wide deep imports
    >>> from pytorch_widedeep.callbacks import EarlyStopping, LRHistory
    >>> from pytorch_widedeep.initializers import KaimingNormal, KaimingUniform, Normal, Uniform
    >>> from pytorch_widedeep.models import TabResnet, DeepImage, DeepText, Wide, WideDeep
    >>> from pytorch_widedeep import Trainer
    >>> from pytorch_widedeep.optim import RAdam
    >>>
    >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
    >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
    >>> wide = Wide(10, 1)
    >>>
    >>> # build the model
    >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
    >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
    >>> deepimage = DeepImage(pretrained=False)
    >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
    >>>
    >>> # set optimizers and schedulers
    >>> wide_opt = torch.optim.Adam(model.wide.parameters())
    >>> deep_opt = torch.optim.Adam(model.deeptabular.parameters())
    >>> text_opt = RAdam(model.deeptext.parameters())
    >>> img_opt = RAdam(model.deepimage.parameters())
    >>>
    >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
    >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
    >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
    >>> img_sch = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
    >>>
    >>> optimizers = {"wide": wide_opt, "deeptabular": deep_opt, "deeptext": text_opt, "deepimage": img_opt}
    >>> schedulers = {"wide": wide_sch, "deeptabular": deep_sch, "deeptext": text_sch, "deepimage": img_sch}
    >>>
    >>> # set initializers and callbacks
    >>> initializers = {"wide": Uniform, "deeptabular": Normal, "deeptext": KaimingNormal, "deepimage": KaimingUniform}
    >>> transforms = [ToTensor]
    >>> callbacks = [LRHistory(n_epochs=4), EarlyStopping]
    >>>
    >>> # set the trainer
    >>> trainer = Trainer(model, objective="regression", initializers=initializers, optimizers=optimizers,
    ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
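
    If none of the built-in loss functions fits, a custom loss can be passed
    through ``custom_loss_function``. A minimal sketch, where the class
    ``CustomMSLELoss`` and its name are purely illustrative:

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        class CustomMSLELoss(nn.Module):
            # mean squared log error; inputs and targets are assumed to be
            # non-negative, as in a regression objective
            def forward(self, input, target):
                return F.mse_loss(torch.log1p(input), torch.log1p(target))

        trainer = Trainer(model, objective="regression", custom_loss_function=CustomMSLELoss())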
    """

    @Alias(  # noqa: C901
        "objective",
        ["loss_function", "loss_fn", "loss", "cost_function", "cost_fn", "cost"],
    )
    def __init__(
        self,
        model: WideDeep,
        objective: str,
        custom_loss_function: Optional[Module] = None,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        reducelronplateau_criterion: Optional[str] = "loss",
        initializers: Optional[Union[Initializer, Dict[str, Initializer]]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[Union[List[Metric], List[TorchMetric]]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        lambda_sparse: float = 1e-3,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        if isinstance(optimizers, Dict):
            if lr_schedulers is not None and not isinstance(lr_schedulers, Dict):
                raise ValueError(
                    "''optimizers' and 'lr_schedulers' must have consistent type: "
                    "(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]) "
                    "Please, read the documentation or see the examples for more details"
                )

        if custom_loss_function is not None and objective not in [
            "binary",
            "multiclass",
            "regression",
        ]:
            raise ValueError(
                "If 'custom_loss_function' is not None, 'objective' must be 'binary', "
                "'multiclass' or 'regression', consistent with the loss function"
            )

        self.reducelronplateau = False
        self.reducelronplateau_criterion = reducelronplateau_criterion
        if isinstance(lr_schedulers, Dict):
            for _, scheduler in lr_schedulers.items():
                if isinstance(scheduler, ReduceLROnPlateau):
                    self.reducelronplateau = True
        elif isinstance(lr_schedulers, ReduceLROnPlateau):
            self.reducelronplateau = True

        self.model = model

        # Tabnet related set ups
        if self.model.is_tabnet:
            self.lambda_sparse = lambda_sparse
            self.reducing_matrix = create_explain_matrix(self.model)

        self.verbose = verbose
        self.seed = seed
        self.objective = objective
        self.method = _ObjectiveToMethod.get(objective)

        # initialize early_stop. If EarlyStopping Callback is used it will
        # take care of it
        self.early_stop = False

        self.loss_fn = self._set_loss_fn(
            objective, class_weight, custom_loss_function, alpha, gamma
        )
        self._initialize(initializers)
        self.optimizer = self._set_optimizer(optimizers)
        self.lr_scheduler = self._set_lr_scheduler(lr_schedulers)
        self.transforms = self._set_transforms(transforms)
        self._set_callbacks_and_metrics(callbacks, metrics)

        self.model.to(device)

    @Alias("finetune", "warmup")  # noqa: C901
    @Alias("finetune_epochs", "warmup_epochs")
    @Alias("finetune_max_lr", "warmup_max_lr")
    @Alias("finetune_deeptabular_gradual", "warmup_deeptabular_gradual")
    @Alias("finetune_deeptabular_max_lr", "warmup_deeptabular_max_lr")
    @Alias("finetune_deeptabular_layers", "warmup_deeptabular_layers")
    @Alias("finetune_deeptext_gradual", "warmup_deeptext_gradual")
    @Alias("finetune_deeptext_max_lr", "warmup_deeptext_max_lr")
    @Alias("finetune_deeptext_layers", "warmup_deeptext_layers")
    @Alias("finetune_deepimage_gradual", "warmup_deepimage_gradual")
    @Alias("finetune_deepimage_max_lr", "warmup_deepimage_max_lr")
    @Alias("finetune_deepimage_layers", "warmup_deepimage_layers")
    @Alias("finetune_routine", "warmup_routine")
    def fit(  # noqa: C901
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        custom_dataloader: Union[DataLoader, None] = None,
        finetune: bool = False,
        finetune_epochs: int = 5,
        finetune_max_lr: float = 0.01,
        finetune_deeptabular_gradual: bool = False,
        finetune_deeptabular_max_lr: float = 0.01,
        finetune_deeptabular_layers: Optional[List[nn.Module]] = None,
        finetune_deeptext_gradual: bool = False,
        finetune_deeptext_max_lr: float = 0.01,
        finetune_deeptext_layers: Optional[List[nn.Module]] = None,
        finetune_deepimage_gradual: bool = False,
        finetune_deepimage_max_lr: float = 0.01,
        finetune_deepimage_layers: Optional[List[nn.Module]] = None,
        finetune_routine: str = "howard",
        stop_after_finetuning: bool = False,
        **kwargs,
    ):
        r"""Fit method.

        The input datasets can be passed either directly via numpy arrays
        (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
        dictionaries (``X_train`` or ``X_val``).

        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img: np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict, Optional. default=None
            The training dataset can also be passed in a dictionary. Keys are
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values
            are the corresponding matrices.
        X_val: Dict, Optional. default=None
            The validation dataset can also be passed in a dictionary. Keys
            are `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`.
            Values are the corresponding matrices.
        val_split: float, Optional. default=None
            train/val split fraction
        target: np.ndarray, Optional. default=None
            target values
        n_epochs: int, default=1
            number of epochs
        validation_freq: int, default=1
            frequency (in epochs) at which validation will be run
        batch_size: int, default=32
            batch size
        custom_dataloader: ``DataLoader``, Optional, default=None
            object of class ``torch.utils.data.DataLoader``. Available
            predefined dataloaders are in ``pytorch-widedeep.dataloaders``. If
            ``None``, a standard torch ``DataLoader`` is used.
        finetune: bool, default=False
            param alias: ``warmup``

            Fine-tune individual model components.

            .. note:: This functionality can also be used to 'warm-up'
               individual components before the joint training starts, and hence
               its alias. See the Examples folder in the repo for more details

            ``pytorch_widedeep`` implements 3 fine-tune routines.

            - fine-tune all trainable layers at once. This routine is
              inspired by the work of Howard and Ruder, 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
              Slanted Triangular learning rate (see `Leslie N. Smith's paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``Adam``.

            and two gradual fine-tune routines, where only certain layers are
            trained at a time.

            - The so-called `Felbo` gradual fine-tune routine, based on the
              Felbo et al., 2017 `DeepMoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine based on the work of Howard and Ruder, 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the Examples
            section in this documentation and the `Examples
            <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
            folder in the repo.
        finetune_epochs: int, default=5
            param alias: ``warmup_epochs``

            Number of fine-tune epochs for those model components that will
            *NOT* be gradually fine-tuned. Those components with gradual
            fine-tune follow their corresponding specific routine.
        finetune_max_lr: float, default=0.01
            param alias: ``warmup_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for those model components that will *NOT* be gradually fine-tuned
        finetune_deeptabular_gradual: bool, default=False
            param alias: ``warmup_deeptabular_gradual``

            Boolean indicating if the ``deeptabular`` component will be
            fine-tuned gradually
        finetune_deeptabular_max_lr: float, default=0.01
            param alias: ``warmup_deeptabular_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the ``deeptabular`` component
        finetune_deeptabular_layers: List, Optional, default=None
            param alias: ``warmup_deeptabular_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_deeptext_gradual: bool, default=False
            param alias: ``warmup_deeptext_gradual``

            Boolean indicating if the ``deeptext`` component will be
            fine-tuned gradually
        finetune_deeptext_max_lr: float, default=0.01
            param alias: ``warmup_deeptext_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the ``deeptext`` component
        finetune_deeptext_layers: List, Optional, default=None
            param alias: ``warmup_deeptext_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_deepimage_gradual: bool, default=False
            param alias: ``warmup_deepimage_gradual``

            Boolean indicating if the ``deepimage`` component will be
            fine-tuned gradually
        finetune_deepimage_max_lr: float, default=0.01
            param alias: ``warmup_deepimage_max_lr``

            Maximum learning rate during the Triangular Learning rate cycle
            for the ``deepimage`` component
        finetune_deepimage_layers: List, Optional, default=None
            param alias: ``warmup_deepimage_layers``

            List of :obj:`nn.Modules` that will be fine-tuned gradually.

            .. note:: These have to be in `fine-tune-order`: the layers or blocks
                close to the output neuron(s) first

        finetune_routine: str, default = "howard"
            param alias: ``warmup_routine``

            Warm-up routine. One of "felbo" or "howard". See the examples
            section in this documentation and the corresponding repo for
            details on how to use the fine-tune routines
        stop_after_finetuning: bool, default=False
            Stop the training process after the fine-tuning routine.

        Examples
        --------

        For a series of comprehensive examples please, see the `Examples
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
        folder in the repo

        For completion, here we include some `"fabricated"` examples, i.e.
        these assume you have already built a model and instantiated a
        ``Trainer`` that is ready to fit

        .. code-block:: python

            # Ex 1. using train input arrays directly and no validation
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256)


        .. code-block:: python

            # Ex 2: using train input arrays directly and validation with val_split
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256, val_split=0.2)


        .. code-block:: python

            # Ex 3: using train dict and val_split
            X_train = {'X_wide': X_wide, 'X_tab': X_tab, 'target': y}
            trainer.fit(X_train, n_epochs=10, batch_size=256, val_split=0.2)


        .. code-block:: python

            # Ex 4: validation using training and validation dicts
            X_train = {'X_wide': X_wide_tr, 'X_tab': X_tab_tr, 'target': y_tr}
            X_val = {'X_wide': X_wide_val, 'X_tab': X_tab_val, 'target': y_val}
            trainer.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=256)
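
        And, as a sketch, the fine-tune (warm-up) functionality described
        above; the parameter values are illustrative:

        .. code-block:: python

            # Ex 5: fine-tune the individual components for 2 epochs before
            # the joint training starts
            trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10,
                batch_size=256, finetune=True, finetune_epochs=2)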
        """

        self.batch_size = batch_size
        train_set, eval_set = wd_train_val_split(
            self.seed,
            self.method,
            X_wide,
            X_tab,
            X_text,
            X_img,
            X_train,
            X_val,
            val_split,
            target,
        )
        if isinstance(custom_dataloader, type):
            if issubclass(custom_dataloader, DataLoader):
                train_loader = custom_dataloader(
                    dataset=train_set,
                    batch_size=batch_size,
                    num_workers=n_cpus,
                    **kwargs,
                )
            else:
                raise NotImplementedError(
                    "Custom DataLoader must be a subclass of "
                    "torch.utils.data.DataLoader, please see the "
                    "pytorch documentation or examples in "
                    "pytorch_widedeep.dataloaders"
                )
        else:
            train_loader = DataLoaderDefault(
                dataset=train_set, batch_size=batch_size, num_workers=n_cpus
            )
        train_steps = len(train_loader)
        if eval_set is not None:
            eval_loader = DataLoader(
                dataset=eval_set,
                batch_size=batch_size,
                num_workers=n_cpus,
                shuffle=False,
            )
            eval_steps = len(eval_loader)

        if finetune:
            self._finetune(
                train_loader,
                finetune_epochs,
                finetune_max_lr,
                finetune_deeptabular_gradual,
                finetune_deeptabular_layers,
                finetune_deeptabular_max_lr,
                finetune_deeptext_gradual,
                finetune_deeptext_layers,
                finetune_deeptext_max_lr,
                finetune_deepimage_gradual,
                finetune_deepimage_layers,
                finetune_deepimage_max_lr,
                finetune_routine,
            )
            if stop_after_finetuning:
                print("Fine-tuning finished")
                return
            else:
                if self.verbose:
                    print(
                        "Fine-tuning of individual components completed. "
                        "Training the whole model for {} epochs".format(n_epochs)
                    )

        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        for epoch in range(n_epochs):
            epoch_logs: Dict[str, float] = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)

            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
                for batch_idx, (data, targett) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
                    train_score, train_loss = self._train_step(data, targett, batch_idx)
                    print_loss_and_metric(t, train_loss, train_score)
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")

            on_epoch_end_metric = None
            if eval_set is not None and epoch % validation_freq == (
                validation_freq - 1
            ):
                self.callback_container.on_eval_begin()
                self.valid_running_loss = 0.0
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for i, (data, targett) in zip(v, eval_loader):
                        v.set_description("valid")
                        val_score, val_loss = self._eval_step(data, targett, i)
                        print_loss_and_metric(v, val_loss, val_score)
                epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val")

                if self.reducelronplateau:
                    if self.reducelronplateau_criterion == "loss":
                        on_epoch_end_metric = val_loss
                    else:
                        on_epoch_end_metric = val_score[
                            self.reducelronplateau_criterion
                        ]

            self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric)

            if self.early_stop:
                self.callback_container.on_train_end(epoch_logs)
                break

        self.callback_container.on_train_end(epoch_logs)
        if self.model.is_tabnet:
            self._compute_feature_importance(train_loader)
        self._restore_best_weights()
        self.model.train()

    def predict(  # type: ignore[return]
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
        batch_size: int = 256,
    ) -> np.ndarray:
        r"""Returns the predictions

        The input datasets can be passed either directly via numpy arrays
        (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
        a dictionary (``X_test``)


        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img: np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict, Optional. default=None
            The test dataset can also be passed in a dictionary. Keys are
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values
            are the corresponding matrices.
        batch_size: int, default = 256
            If a trainer is used to predict after having trained a model, the
            ``batch_size`` needs to be defined here as it will not have been
            set when the :obj:`Trainer` was instantiated
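
        Examples
        --------

        A minimal usage sketch; the array names are illustrative and assume
        the corresponding preprocessors were used to produce them:

        .. code-block:: python

            preds = trainer.predict(X_wide=X_wide, X_tab=X_tab, batch_size=256)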
        """

        preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test, batch_size)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype("int")
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)  # type: ignore[return-value]

    def predict_proba(  # type: ignore[return]
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
        batch_size: int = 256,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for  binary
        and multiclass methods

        The input datasets can be passed either directly via numpy arrays
        (``X_wide``, ``X_tab``, ``X_text`` or ``X_img``) or alternatively, in
        a dictionary (``X_test``)

        Parameters
        ----------
        X_wide: np.ndarray, Optional. default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict, Optional. default=None
            The test dataset can also be passed in a dictionary. Keys are
            `X_wide`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values
            are the corresponding matrices.
        batch_size: int, default = 256
            If a trainer is used to predict after having trained a model, the
            ``batch_size`` needs to be defined here as it will not have been
            set when the :obj:`Trainer` was instantiated
        """

        preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test, batch_size)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            probs = np.zeros([preds.shape[0], 2])
            probs[:, 0] = 1 - preds
            probs[:, 1] = preds
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)

    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:  # pragma: no cover
        r"""Returns the learned embeddings for the categorical features passed through
        ``deeptabular``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` attribute in the class
        :obj:`TabPreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.TabPreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict
            Dictionary where the keys are the name of the column for which we
            want to retrieve the embeddings and the values are also of type
            Dict. These Dict values have keys that are the categories for that
            column and the values are the corresponding numerical encodings

            e.g.: {'column': {'cat_0': 1, 'cat_1': 2, ...}}

        Examples
        --------

        For a series of comprehensive examples please, see the `Examples
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`__
        folder in the repo

        For completion, here we include a `"fabricated"` example, i.e.
        assuming we have already trained the model, that we have the
        categorical encodings in a dictionary named ``encoding_dict``, and that
        there is a column called `'education'`:

        .. code-block:: python

            trainer.get_embeddings(col_name="education", cat_encoding_dict=encoding_dict)
        """
        for n, p in self.model.named_parameters():
            if "embed_layers" in n and col_name in n:
                embed_mtx = p.cpu().data.numpy()
        encoding_dict = cat_encoding_dict[col_name]
        inv_encoding_dict = {v: k for k, v in encoding_dict.items()}
        cat_embed_dict = {}
        for idx, value in inv_encoding_dict.items():
            cat_embed_dict[value] = embed_mtx[idx]
        return cat_embed_dict

    def explain(self, X_tab: np.ndarray, save_step_masks: bool = False):
        """
        Returns the aggregated feature importance for each instance (or
        observation) in the ``X_tab`` array. If ``save_step_masks`` is set to
        ``True``, the masks per step will also be returned.

        Parameters
        ----------
        X_tab: np.ndarray
            Input array corresponding **only** to the deeptabular component
        save_step_masks: bool
            Boolean indicating if the masks per step will be returned

        Returns
        -------
        res: np.ndarray, Tuple
            Array or Tuple of two arrays with the corresponding aggregated
            feature importance and the masks per step if ``save_step_masks``
            is set to ``True``
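
        Examples
        --------

        A minimal usage sketch, assuming the ``deeptabular`` component is a
        Tabnet model and ``X_tab`` was produced by the ``TabPreprocessor``:

        .. code-block:: python

            feat_imp = trainer.explain(X_tab)
            feat_imp, step_masks = trainer.explain(X_tab, save_step_masks=True)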
        """
        loader = DataLoader(
            dataset=WideDeepDataset(**{"X_tab": X_tab}),
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )

        self.model.eval()
        tabnet_backbone = list(self.model.deeptabular.children())[0]

        m_explain_l = []
        for batch_nb, data in enumerate(loader):
            X = data["deeptabular"].to(device)
            M_explain, masks = tabnet_backbone.forward_masks(X)  # type: ignore[operator]
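            # map the embedding-level importances back to the original input
            # columns, using the sparse reducing matrix built by create_explain_matrix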
            m_explain_l.append(
                csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
            )
            if save_step_masks:
                for key, value in masks.items():
                    masks[key] = csc_matrix.dot(
                        value.cpu().detach().numpy(), self.reducing_matrix
                    )
                if batch_nb == 0:
                    m_explain_step = masks
                else:
                    for key, value in masks.items():
                        m_explain_step[key] = np.vstack([m_explain_step[key], value])

        m_explain_agg = np.vstack(m_explain_l)
        m_explain_agg_norm = m_explain_agg / m_explain_agg.sum(axis=1)[:, np.newaxis]

        res = (
            (m_explain_agg_norm, m_explain_step)
            if save_step_masks
            else np.vstack(m_explain_agg_norm)
        )

        return res

    def save(
        self,
        path: str,
        save_state_dict: bool = False,
        model_filename: str = "wd_model.pt",
    ):
        r"""Saves the model, training and evaluation history, and the
        ``feature_importance`` attribute (if the ``deeptabular`` component is a
        Tabnet model) to disk

        The ``Trainer`` class is built so that it 'just' trains a model. With
        that in mind, all the torch related parameters (such as optimizers,
        learning rate schedulers, initializers, etc) have to be defined
        externally and then passed to the ``Trainer``. As a result, the
        ``Trainer`` does not generate any attribute or additional data
        products that need to be saved other than the ``model`` object itself,
        which can be saved as any other torch model (e.g. ``torch.save(model,
        path)``).

        The exception is Tabnet. If the ``deeptabular`` component is a Tabnet
        model, an attribute (a dict) called ``feature_importance`` will be
        created at the end of the training process. Therefore, a ``save``
        method was created that will save both the feature importance
        dictionary to a json file and, since we are here, the model weights,
        training history and learning rate history.

        Parameters
        ----------
        path: str
            path to the directory where the model and the feature importance
            attribute will be saved.
        save_state_dict: bool, default = False
            Boolean indicating whether to save the model directly or the
            model's state dictionary
        model_filename: str, Optional, default = "wd_model.pt"
            filename where the model weights will be stored
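
        Examples
        --------

        A minimal usage sketch, once the model has been trained; the
        directory name is illustrative:

        .. code-block:: python

            trainer.save(path="model_weights", save_state_dict=True)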
        """

        save_dir = Path(path)
        history_dir = save_dir / "history"
        history_dir.mkdir(exist_ok=True, parents=True)

        # the trainer is run with the History Callback by default
        with open(history_dir / "train_eval_history.json", "w") as teh:
            json.dump(self.history, teh)  # type: ignore[attr-defined]

        has_lr_history = any(
            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
        )
        if self.lr_scheduler is not None and has_lr_history:
            with open(history_dir / "lr_history.json", "w") as lrh:
                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]

        model_path = save_dir / model_filename
        if save_state_dict:
            torch.save(self.model.state_dict(), model_path)
        else:
            torch.save(self.model, model_path)

        if self.model.is_tabnet:
            with open(save_dir / "feature_importance.json", "w") as fi:
                json.dump(self.feature_importance, fi)

    def _restore_best_weights(self):
        already_restored = any(
            [
                (
                    callback.__class__.__name__ == "EarlyStopping"
                    and callback.restore_best_weights
                )
                for callback in self.callback_container.callbacks
            ]
        )
        if already_restored:
            pass
        else:
            for callback in self.callback_container.callbacks:
                if callback.__class__.__name__ == "ModelCheckpoint":
                    if callback.save_best_only:
                        filepath = "{}_{}.p".format(
                            callback.filepath, callback.best_epoch + 1
                        )
                        if self.verbose:
                            print(
                                f"Model weights restored to best epoch: {callback.best_epoch + 1}"
                            )
                        self.model.load_state_dict(torch.load(filepath))
                    else:
                        if self.verbose:
                            print(
                                "Model weights after training corresponds to the those of the "
946 947
                                "final epoch which might not be the best performing weights. Use"
                                "the 'ModelCheckpoint' Callback to restore the best epoch weights."
                            )

    def _finetune(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptabular_gradual: bool,
        deeptabular_layers: List[nn.Module],
        deeptabular_max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):  # pragma: no cover
        r"""
        Simple wrapper to individually fine-tune model components
        """
        if self.model.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
        # This is not the most elegant solution, but is a solution "in-between"
        # a non-elegant one and re-factoring the whole code
        finetuner = FineTune(self.loss_fn, self.metric, self.method, self.verbose)
        if self.model.wide:
            finetuner.finetune_all(self.model.wide, "wide", loader, n_epochs, max_lr)
        if self.model.deeptabular:
            if deeptabular_gradual:
                finetuner.finetune_gradual(
                    self.model.deeptabular,
                    "deeptabular",
                    loader,
                    deeptabular_max_lr,
                    deeptabular_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deeptabular, "deeptabular", loader, n_epochs, max_lr
                )
992 993
        if self.model.deeptext:
            if deeptext_gradual:
994
                finetuner.finetune_gradual(
995 996 997 998 999 1000 1001 1002
                    self.model.deeptext,
                    "deeptext",
                    loader,
                    deeptext_max_lr,
                    deeptext_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deeptext, "deeptext", loader, n_epochs, max_lr
                )
        if self.model.deepimage:
            if deepimage_gradual:
                finetuner.finetune_gradual(
                    self.model.deepimage,
                    "deepimage",
                    loader,
                    deepimage_max_lr,
                    deepimage_layers,
                    routine,
                )
            else:
                finetuner.finetune_all(
                    self.model.deepimage, "deepimage", loader, n_epochs, max_lr
                )

    def _train_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        self.model.train()
        X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
        y = target.view(-1, 1).float() if self.method != "multiclass" else target
        y = y.to(device)

        self.optimizer.zero_grad()
        y_pred = self.model(X)
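        # TabNet returns a tuple (predictions, sparsity loss); the sparsity
        # regularization term is subtracted from the loss, weighted by lambda_sparse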
        if self.model.is_tabnet:
            loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
            score = self._get_score(y_pred[0], y)
        else:
            loss = self.loss_fn(y_pred, y)
            score = self._get_score(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        return score, avg_loss

    def _eval_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):

        self.model.eval()
        with torch.no_grad():
            X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
            y = target.view(-1, 1).float() if self.method != "multiclass" else target
            y = y.to(device)

            y_pred = self.model(X)
            if self.model.is_tabnet:
                loss = self.loss_fn(y_pred[0], y) - self.lambda_sparse * y_pred[1]
                score = self._get_score(y_pred[0], y)
            else:
                score = self._get_score(y_pred, y)
                loss = self.loss_fn(y_pred, y)

            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        return score, avg_loss

    def _get_score(self, y_pred, y):
        if self.metric is not None:
            if self.method == "regression":
                score = self.metric(y_pred, y)
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            if self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
            return score
        else:
            return None

    def _compute_feature_importance(self, loader: DataLoader):
        self.model.eval()
        tabnet_backbone = list(self.model.deeptabular.children())[0]
        feat_imp = np.zeros((tabnet_backbone.embed_and_cont_dim))  # type: ignore[arg-type]
        for data, target in loader:
            X = data["deeptabular"].to(device)
            y = target.view(-1, 1).float() if self.method != "multiclass" else target
            y = y.to(device)
            M_explain, masks = tabnet_backbone.forward_masks(X)  # type: ignore[operator]
            feat_imp += M_explain.sum(dim=0).cpu().detach().numpy()

        feat_imp = csc_matrix.dot(feat_imp, self.reducing_matrix)
        feat_imp = feat_imp / np.sum(feat_imp)

        self.feature_importance = {
            k: v for k, v in zip(tabnet_backbone.column_idx.keys(), feat_imp)  # type: ignore[operator, union-attr]
        }

    def _predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
        batch_size: int = 256,
    ) -> List:
        r"""Private method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {}
            if X_wide is not None:
                load_dict = {"X_wide": X_wide}
            if X_tab is not None:
                load_dict.update({"X_tab": X_tab})
            if X_text is not None:
                load_dict.update({"X_text": X_text})
            if X_img is not None:
                load_dict.update({"X_img": X_img})
            test_set = WideDeepDataset(**load_dict)

        if not hasattr(self, "batch_size"):
            self.batch_size = batch_size

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1  # type: ignore[arg-type]

        self.model.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = (
                        self.model(X) if not self.model.is_tabnet else self.model(X)[0]
                    )
                    if self.method == "binary":
                        preds = torch.sigmoid(preds)
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds = preds.cpu().data.numpy()
                    preds_l.append(preds)
        self.model.train()
        return preds_l

    def _set_loss_fn(self, objective, class_weight, custom_loss_function, alpha, gamma):
        if class_weight is not None:
            class_weight = torch.tensor(class_weight).to(device)
        if custom_loss_function is not None:
            return custom_loss_function
        elif self.method != "regression" and "focal_loss" not in objective:
            return alias_to_loss(objective, weight=class_weight)
        elif "focal_loss" in objective:
            return alias_to_loss(objective, alpha=alpha, gamma=gamma)
        else:
            return alias_to_loss(objective)

    def _initialize(self, initializers):
        if initializers is not None:
            if isinstance(initializers, Dict):
                self.initializer = MultipleInitializer(
                    initializers, verbose=self.verbose
                )
                self.initializer.apply(self.model)
            elif isinstance(initializers, type):
                self.initializer = initializers()
                self.initializer(self.model)
            elif isinstance(initializers, Initializer):
                self.initializer = initializers
                self.initializer(self.model)

    def _set_optimizer(self, optimizers):
        if optimizers is not None:
            if isinstance(optimizers, Optimizer):
                optimizer: Union[Optimizer, MultipleOptimizer] = optimizers
            elif isinstance(optimizers, Dict):
                opt_names = list(optimizers.keys())
                mod_names = [n for n, c in self.model.named_children()]
                for mn in mod_names:
                    assert mn in opt_names, "No optimizer found for {}".format(mn)
                optimizer = MultipleOptimizer(optimizers)
        else:
            optimizer = torch.optim.Adam(self.model.parameters())  # type: ignore
        return optimizer

    def _set_lr_scheduler(self, lr_schedulers):
        if lr_schedulers is not None:
            # ReduceLROnPlateau is special: it is the only scheduler that is
            # 'just' an object rather than an LRScheduler
            if isinstance(lr_schedulers, LRScheduler) or isinstance(
                lr_schedulers, ReduceLROnPlateau
            ):
                lr_scheduler = lr_schedulers
                cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
            else:
                lr_scheduler = MultipleLRScheduler(lr_schedulers)
                scheduler_names = [
                    sc.__class__.__name__.lower()
                    for _, sc in lr_scheduler._schedulers.items()
                ]
                cyclic_lr = any(["cycl" in sn for sn in scheduler_names])
        else:
            lr_scheduler, cyclic_lr = None, False
        self.cyclic_lr = cyclic_lr
        return lr_scheduler

    @staticmethod
    def _set_transforms(transforms):
        if transforms is not None:
            return MultipleTransforms(transforms)()
        else:
            return None

    def _set_callbacks_and_metrics(self, callbacks, metrics):
        self.callbacks: List = [History(), LRShedulerCallback()]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)
        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None
        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self.model)
        self.callback_container.set_trainer(self)