wide_deep.py 48.0 KB
Newer Older
1
import os
2
import warnings
3 4

import numpy as np
5
import torch
6 7
import torch.nn as nn
import torch.nn.functional as F
8 9 10
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
11

12
from ..losses import FocalLoss
13 14 15 16 17
from ._warmup import WarmUp
from ..metrics import Metric, MetricCallback, MultipleMetrics
from ..wdtypes import *
from ..callbacks import History, Callback, CallbackContainer
from .deep_dense import dense_layer
18
from ._wd_dataset import WideDeepDataset
19
from ..initializers import Initializer, MultipleInitializer
20 21
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_transforms import MultipleTransforms
22
from ._multiple_lr_scheduler import MultipleLRScheduler
J
jrzaurin 已提交
23

24
# CPU count reported by ``os.cpu_count()``; ``fit`` passes it as ``num_workers``
# to the training and validation ``DataLoader`` objects.
n_cpus = os.cpu_count()
25 26 27 28
# True when ``torch`` reports an available CUDA device; ``compile`` calls
# ``self.cuda()`` on the model in that case.
use_cuda = torch.cuda.is_available()


class WideDeep(nn.Module):
29 30 31
    r"""Main collector class that combines all ``Wide``, ``DeepDense``,
    ``DeepText`` and ``DeepImage`` models.

32 33
    There are two options to combine these models that correspond to the two
    architectures that ``pytorch-widedeep`` can build.
34 35 36 37 38 39

        - Directly connecting the output of the model components to an output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
          This FC-Head will combine the output from the ``DeepDense``, ``DeepText`` and
          ``DeepImage`` and will be then connected to the output neuron(s).
40 41 42

    Parameters
    ----------
43
    wide: nn.Module
44 45 46 47
        Wide model. We recommend using the ``Wide`` class in this package.
        However, it is possible to use a custom model as long as is consistent
        with the required architecture, see
        :class:`pytorch_widedeep.models.wide.Wide`
48
    deepdense: nn.Module
49 50 51
        `Deep dense` model comprised by the embeddings for the categorical
        features combined with numerical (also referred as continuous)
        features. We recommend using the ``DeepDense`` class in this package.
52
        However, it is possible to use a custom model as long as it is
        consistent with the required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepDense`.
54
    deeptext: nn.Module, Optional
55 56 57 58
        `Deep text` model for the text input. Must be an object of class
        ``DeepText`` or a custom model as long as is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepText`
59
    deepimage: nn.Module, Optional
60 61 62 63
        `Deep Image` model for the images input. Must be an object of class
        ``DeepImage`` or a custom model as long as is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepImage`
64
    deephead: nn.Module, Optional
65
        `Dense` model consisting in a stack of dense layers. The FC-Head.
66
    head_layers: List, Optional
67
        Alternatively, we can use ``head_layers`` to specify the sizes of the
68
        stacked dense layers in the fc-head e.g: ``[128, 64]``
69
    head_dropout: List, Optional
70
        Dropout between the layers in ``head_layers``. e.g: ``[0.5, 0.5]``
71
    head_batchnorm: bool, Optional
72
        Specifies if batch normalization should be included in the dense layers
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
    pred_dim: int
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or `number
        of classes` for multiclass classification.


    .. note:: With the exception of ``cyclic``, all attributes are direct assignations of
        the corresponding parameters used when calling ``compile``.  Therefore,
        see the parameters at
        :class:`pytorch_widedeep.models.wide_deep.WideDeep.compile` for a full
        list of the attributes of an instance of
        :class:`pytorch_widedeep.models.wide_deep.WideDeep`


    Attributes
    ----------
    cyclic: :obj:`bool`
        Attribute that indicates if any of the lr_schedulers is cyclic (i.e. ``CyclicLR`` or
        ``OneCycleLR``). See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
92 93 94 95 96 97 98 99 100 101 102


    .. note:: While I recommend using the ``Wide`` and ``DeepDense`` classes within
        this package when building the corresponding model components, it is very
        likely that the user will want to use custom text and image models. That
        is perfectly possible. Simply, build them and pass them as the
        corresponding parameters. Note that the custom models MUST return a last
        layer of activations (i.e. not the final prediction) so that  these
        activations are collected by ``WideDeep`` and combined accordingly. In
        addition, the models MUST also contain an attribute ``output_dim`` with
        the size of these last layers of activations. See for example
103 104 105
        :class:`pytorch_widedeep.models.deep_dense.DeepDense`

    """
J
jrzaurin 已提交
106 107 108 109 110

    def __init__(
        self,
        wide: nn.Module,
        deepdense: nn.Module,
111
        pred_dim: int = 1,
J
jrzaurin 已提交
112 113 114 115 116 117 118
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_layers: Optional[List[int]] = None,
        head_dropout: Optional[List] = None,
        head_batchnorm: Optional[bool] = None,
    ):
119

120
        super(WideDeep, self).__init__()
121

122
        # check that model components have the required output_dim attribute
J
jrzaurin 已提交
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
        if not hasattr(deepdense, "output_dim"):
            raise AttributeError(
                "deepdense model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepDense"
            )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
138

139
        # The main 5 components of the wide and deep assemble
140 141
        self.wide = wide
        self.deepdense = deepdense
J
jrzaurin 已提交
142
        self.deeptext = deeptext
143
        self.deepimage = deepimage
144 145
        self.deephead = deephead

146 147 148 149 150 151 152
        if deephead is not None and head_layers is not None:
            warnings.warn(
                "both 'deephead' and 'head_layers' are not None."
                "'deephead' takes priority and will be used",
                UserWarning,
            )

153 154
        if self.deephead is None:
            if head_layers is not None:
J
jrzaurin 已提交
155
                input_dim: int = self.deepdense.output_dim  # type:ignore
M
Minjin Choi 已提交
156
                if self.deeptext is not None:
157
                    input_dim += self.deeptext.output_dim  # type:ignore
M
Minjin Choi 已提交
158
                if self.deepimage is not None:
159
                    input_dim += self.deepimage.output_dim  # type:ignore
160
                head_layers = [input_dim] + head_layers
J
jrzaurin 已提交
161 162
                if not head_dropout:
                    head_dropout = [0.0] * (len(head_layers) - 1)
163 164 165
                self.deephead = nn.Sequential()
                for i in range(1, len(head_layers)):
                    self.deephead.add_module(
J
jrzaurin 已提交
166 167 168 169 170 171 172 173 174
                        "head_layer_{}".format(i - 1),
                        dense_layer(
                            head_layers[i - 1],
                            head_layers[i],
                            head_dropout[i - 1],
                            head_batchnorm,
                        ),
                    )
                self.deephead.add_module(
175
                    "head_out", nn.Linear(head_layers[-1], pred_dim)
J
jrzaurin 已提交
176
                )
177 178
            else:
                self.deepdense = nn.Sequential(
179
                    self.deepdense, nn.Linear(self.deepdense.output_dim, pred_dim)  # type: ignore
J
jrzaurin 已提交
180
                )
181 182
                if self.deeptext is not None:
                    self.deeptext = nn.Sequential(
183
                        self.deeptext, nn.Linear(self.deeptext.output_dim, pred_dim)  # type: ignore
J
jrzaurin 已提交
184
                    )
185 186
                if self.deepimage is not None:
                    self.deepimage = nn.Sequential(
187
                        self.deepimage, nn.Linear(self.deepimage.output_dim, pred_dim)  # type: ignore
J
jrzaurin 已提交
188
                    )
189

J
jrzaurin 已提交
190
    def forward(self, X: Dict[str, Tensor]) -> Tensor:  # type: ignore
        r"""Combine the wide and deep components into the final output.

        Parameters
        ----------
        X: Dict[str, Tensor]
            Inputs keyed by component name. ``'wide'`` and ``'deepdense'``
            are always read; ``'deeptext'`` and ``'deepimage'`` are read only
            when the corresponding component is not ``None``.

        Returns
        -------
        Tensor
            The wide output plus either the output of the FC-Head applied to
            the concatenated deep activations, or the sum of the outputs of
            each deep component when there is no head.
        """
        # Wide output: direct connection to the output neuron(s)
        out = self.wide(X["wide"])

        # Deep output: either connected directly to the output neuron(s) or
        # passed through a head first. Compare against None explicitly, since
        # an empty nn.Sequential is falsy.
        if self.deephead is not None:
            deepside = self.deepdense(X["deepdense"])
            if self.deeptext is not None:
                deepside = torch.cat([deepside, self.deeptext(X["deeptext"])], dim=1)
            if self.deepimage is not None:
                deepside = torch.cat(
                    [deepside, self.deepimage(X["deepimage"])], dim=1
                )
            return out + self.deephead(deepside)

        # ``Tensor.add`` is not in-place, so the sum must be re-bound;
        # otherwise the deep components would never reach the output.
        out = out + self.deepdense(X["deepdense"])
        if self.deeptext is not None:
            out = out + self.deeptext(X["deeptext"])
        if self.deepimage is not None:
            out = out + self.deepimage(X["deepimage"])
        return out

J
jrzaurin 已提交
213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
    def compile(
        self,
        method: str,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        initializers: Optional[Dict[str, Initializer]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        with_focal_loss: bool = False,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        r"""Set the attributes that will be used during the training process.

        Parameters
        ----------
        method: str
            One of `regression`, `binary` or `multiclass`. The default when
            performing a `regression`, a `binary` classification or a
            `multiclass` classification is the `mean squared error
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.mse_loss>`_
            (MSE), `Binary Cross Entropy
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.binary_cross_entropy>`_
            (BCE) and `Cross Entropy
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.cross_entropy>`_
            (CE) respectively.
        optimizers: Union[Optimizer, Dict[str, Optimizer]], Optional, Default=AdamW
            - An instance of ``pytorch``'s ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
            - a dictionary where the keys are the model components (i.e.
              `'wide'`, `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and
              the values are the corresponding optimizers. If a dictionary is used
              it MUST contain an optimizer per model component.

            See `Pytorch optimizers <https://pytorch.org/docs/stable/optim.html>`_.
        lr_schedulers: Union[LRScheduler, Dict[str, LRScheduler]], Optional, Default=None
            - An instance of ``pytorch``'s ``LRScheduler`` object (e.g
              :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
              values are the corresponding learning rate schedulers. An empty
              dictionary is treated as ``None``.

            See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
        initializers: Dict[str, Initializer], Optional. Default=None
            Dict where the keys are the model components (i.e. `'wide'`,
            `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
            values are the corresponding initializers.
            See `Pytorch initializers <https://pytorch.org/docs/stable/nn.init.html>`_.
        transforms: List[Transforms], Optional. Default=None
            ``Transforms`` is a custom type. See
            :obj:`pytorch_widedeep.wdtypes`. List with
            :obj:`torchvision.transforms` to be applied to the image component
            of the model (i.e. ``deepimage``) See `torchvision transforms
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
        callbacks: List[Callback], Optional. Default=None
            Callbacks available are: ``ModelCheckpoint``, ``EarlyStopping``,
            and ``LRHistory``. The ``History`` callback is used by default.
            Callback classes (as opposed to instances) are instantiated with
            no arguments. See the ``Callbacks`` section in this documentation
            or :obj:`pytorch_widedeep.callbacks`
        metrics: List[Metric], Optional. Default=None
            Metrics available are: ``Accuracy``, ``Precision``, ``Recall``,
            ``FBetaScore`` and ``F1Score``.  See the ``Metrics`` section in
            this documentation or :obj:`pytorch_widedeep.metrics`
        class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
            - float indicating the weight of the minority class in binary classification
              problems (e.g. 9.)
            - a list or tuple with weights for the different classes in multiclass
              classification problems  (e.g. [1., 2., 3.]). The weights do
              not necessarily need to be normalised. If your loss function
              uses reduction='mean', the loss will be normalized by the sum
              of the corresponding weights for each element. If you are
              using reduction='none', you would have to take care of the
              normalization yourself. See `this discussion
              <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.

            Any other type (including ``int``) is ignored and results in no
            class weighting.
        with_focal_loss: bool, Optional. Default=False
            Boolean indicating whether to use the Focal Loss for highly imbalanced problems.
            For details on the focal loss see the `original paper
            <https://arxiv.org/pdf/1708.02002.pdf>`_.
        alpha: float. Default=0.25
            Focal Loss alpha parameter.
        gamma: float. Default=2
            Focal Loss gamma parameter.
        verbose: int
            Setting it to 0 will print nothing during training.
        seed: int, Default=1
            Random seed to be used throughout all the methods

        Raises
        ------
        ValueError
            If ``optimizers`` is a dictionary that lacks an entry for one of
            the model components.

        Example
        --------
        Assuming you have already built the model components (wide, deepdense, etc...)

        >>> from pytorch_widedeep.models import WideDeep
        >>> from pytorch_widedeep.initializers import *
        >>> from pytorch_widedeep.callbacks import *
        >>> from pytorch_widedeep.optim import RAdam
        >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deepdense.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt  = RAdam(model.deepimage.parameters())
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch  = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {'wide': wide_opt, 'deepdense':deep_opt, 'deeptext':text_opt, 'deepimage': img_opt}
        >>> schedulers = {'wide': wide_sch, 'deepdense':deep_sch, 'deeptext':text_sch, 'deepimage': img_sch}
        >>> initializers = {'wide': Uniform, 'deepdense':Normal, 'deeptext':KaimingNormal,
        ... 'deepimage':KaimingUniform}
        >>> transforms = [ToTensor, Normalize(mean=mean, std=std)]
        >>> callbacks = [LRHistory, EarlyStopping, ModelCheckpoint(filepath='model_weights/wd_out.pt')]
        >>> model.compile(method='regression', initializers=initializers, optimizers=optimizers,
        ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """
        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        if self.with_focal_loss:
            # alpha/gamma are only stored when the focal loss is requested
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            self.class_weight = torch.tensor([1.0 - class_weight, class_weight])
        elif isinstance(class_weight, (tuple, list)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        self.optimizer: Union[Optimizer, MultipleOptimizer]
        if optimizers is None:
            self.optimizer = torch.optim.AdamW(self.parameters())  # type: ignore
        elif isinstance(optimizers, Optimizer):
            self.optimizer = optimizers
        else:
            # Any dictionary, whatever its length, takes this path so that
            # ``self.optimizer`` is always set or a clear error is raised.
            for mod_name, _ in self.named_children():
                if mod_name not in optimizers:
                    raise ValueError("No optimizer found for {}".format(mod_name))
            self.optimizer = MultipleOptimizer(optimizers)

        self.lr_scheduler: Optional[Union[LRScheduler, MultipleLRScheduler]]
        if lr_schedulers is None or (
            isinstance(lr_schedulers, dict) and not lr_schedulers
        ):
            self.lr_scheduler, self.cyclic = None, False
        elif isinstance(lr_schedulers, LRScheduler):
            self.lr_scheduler = lr_schedulers
            # "cycl" matches both ``CyclicLR`` and ``OneCycleLR``
            self.cyclic = "cycl" in self.lr_scheduler.__class__.__name__.lower()
        else:
            # Any non-empty dictionary, including a single-entry one
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            scheduler_names = [
                sc.__class__.__name__.lower()
                for sc in self.lr_scheduler._schedulers.values()
            ]
            self.cyclic = any("cycl" in sn for sn in scheduler_names)

        if transforms is not None:
            self.transforms: MultipleTransforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        # ``History`` is always the first callback
        self.history = History()
        self.callbacks: List = [self.history]
        if callbacks is not None:
            for callback in callbacks:
                # Accept callback classes as well as instances
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

        if use_cuda:
            self.cuda()

    def fit(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        patience: int = 10,
        warm_up: bool = False,
        warm_epochs: int = 4,
        warm_max_lr: float = 0.01,
        warm_deeptext_gradual: bool = False,
        warm_deeptext_max_lr: float = 0.01,
        warm_deeptext_layers: Optional[List[nn.Module]] = None,
        warm_deepimage_gradual: bool = False,
        warm_deepimage_max_lr: float = 0.01,
        warm_deepimage_layers: Optional[List[nn.Module]] = None,
        warm_routine: str = "howard",
    ):
427
        r"""Fit method. Must run after calling ``compile``
428 429 430

        Parameters
        ----------
431
        X_wide: np.ndarray, Optional. Default=None
432 433
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
434
        X_deep: np.ndarray, Optional. Default=None
435 436
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
437
        X_text: np.ndarray, Optional. Default=None
438 439
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
440
        X_img : np.ndarray, Optional. Default=None
441 442 443 444
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict[str, np.ndarray], Optional. Default=None
            Training dataset for the different model components. Keys are
445
            `X_wide`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
446
            the corresponding matrices.
447
        X_val: Dict, Optional. Default=None
448
            Validation dataset for the different model component. Keys are
449 450
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
451 452
        val_split: float, Optional. Default=None
            train/val split fraction
453 454
        target: np.ndarray, Optional. Default=None
            target values
455 456 457 458 459 460
        n_epochs: int, Default=1
            number of epochs
        validation_freq: int, Default=1
            epochs validation frequency
        batch_size: int, Default=32
        patience: int, Default=10
461 462
            Number of epochs without improving the target metric before
            the fit process stops
463
        warm_up: bool, Default=False
464 465 466
            warm up model components individually before the joined training
            starts.

467 468 469 470 471
            ``pytorch_widedeep`` implements 3 warm up routines.

            - Warm up all trainable layers at once. This routine is is
              inspired by the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
472 473 474 475 476 477
              Slanted Triangular learing (see `Leslie N. Smith paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``AdamW``.
478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493

            and two gradual warm up routines, where only certain layers are
            warmed up at each warm up step.

            - The so called `Felbo` gradual warm up rourine, inpired by the Felbo et al., 2017
              `DeepEmoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine based on the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the examples
            section in this documentation and the corresponding repo.
        warm_epochs: int, Default=4
            Number of warm up epochs for those model components that will NOT
            be gradually warmed up. Those components with gradual warm up
            follow their corresponding specific routine.
        warm_max_lr: float, Default=0.01
494
            Maximum learning rate during the Triangular Learning rate cycle
495 496
            for those model componenst that will NOT be gradually warmed up
        warm_deeptext_gradual: bool, Default=False
497 498
            Boolean indicating if the deeptext component will be warmed
            up gradually
499
        warm_deeptext_max_lr: float, Default=0.01
500 501
            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
502
        warm_deeptext_layers: List, Optional, Default=None
503
            List of :obj:`nn.Modules` that will be warmed up gradually.
504 505 506 507 508

            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

        warm_deepimage_gradual: bool, Default=False
509 510 511 512 513
            Boolean indicating if the deepimage component will be warmed
            up gradually
        warm_deepimage_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deepimage component
514
        warm_deepimage_layers: List, Optional, Default=None
515
            List of :obj:`nn.Modules` that will be warmed up gradually.
516

517 518 519
            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

520
        warm_routine: str, Default=`felbo`
521 522 523 524 525
            Warm up routine. On of `felbo` or `howard`. See the examples
            section in this documentation and the corresponding repo for
            details on how to use warm up routines

        Examples
526 527 528
        --------
        Assuming you have already built and compiled the model

529 530

        >>> # Ex 1. using train input arrays directly and no validation
531 532
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256)

533 534

        >>> # Ex 2: using train input arrays directly and validation with val_split
535 536
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)

537 538

        >>> # Ex 3: using train dict and val_split
539 540 541
        >>> X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': y}
        >>> model.fit(X_train, n_epochs=10, batch_size=256, val_split=0.2)

542 543

        >>> # Ex 4: validation using training and validation dicts
544 545 546
        >>> X_train = {'X_wide': X_wide_tr, 'X_deep': X_deep_tr, 'target': y_tr}
        >>> X_val = {'X_wide': X_wide_val, 'X_deep': X_deep_val, 'target': y_val}
        >>> model.fit(X_train=X_train, X_val=X_val n_epochs=10, batch_size=256)
547

548
        .. note:: :obj:`WideDeep` assumes that `X_wide`, `X_deep` and `target` ALWAYS exist, while
549 550 551 552 553
            `X_text` and `X_img` are optional

        .. note:: Either `X_train` or the three `X_wide`, `X_deep` and `target` must be passed to the
            fit method

554
        """
555 556 557

        if X_train is None and (X_wide is None or X_deep is None or target is None):
            raise ValueError(
558 559
                "Training data is missing. Either a dictionary (X_train) with "
                "the training dataset or at least 3 arrays (X_wide, X_deep, "
J
jrzaurin 已提交
560 561
                "target) must be passed to the fit method"
            )
562 563

        self.batch_size = batch_size
J
jrzaurin 已提交
564 565 566 567 568 569
        train_set, eval_set = self._train_val_split(
            X_wide, X_deep, X_text, X_img, X_train, X_val, val_split, target
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
570 571
        if warm_up:
            # warm up...
J
jrzaurin 已提交
572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589
            self._warm_up(
                train_loader,
                warm_epochs,
                warm_max_lr,
                warm_deeptext_gradual,
                warm_deeptext_layers,
                warm_deeptext_max_lr,
                warm_deepimage_gradual,
                warm_deepimage_layers,
                warm_deepimage_max_lr,
                warm_routine,
            )
        train_steps = len(train_loader)
        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        if self.verbose:
            print("Training")
590
        for epoch in range(n_epochs):
591
            # train step...
J
jrzaurin 已提交
592
            epoch_logs: Dict[str, float] = {}
593
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
J
jrzaurin 已提交
594
            self.train_running_loss = 0.0
595
            with trange(train_steps, disable=self.verbose != 1) as t:
J
jrzaurin 已提交
596 597
                for batch_idx, (data, target) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
598 599 600 601 602 603
                    score, train_loss = self._training_step(data, target, batch_idx)
                    if score is not None:
                        t.set_postfix(
                            metrics={k: np.round(v, 4) for k, v in score.items()},
                            loss=train_loss,
                        )
604
                    else:
605
                        t.set_postfix(loss=train_loss)
J
jrzaurin 已提交
606 607
                    if self.lr_scheduler:
                        self._lr_scheduler_step(step_location="on_batch_end")
608
                    self.callback_container.on_batch_end(batch=batch_idx)
J
jrzaurin 已提交
609
            epoch_logs["train_loss"] = train_loss
610 611 612 613
            if score is not None:
                for k, v in score.items():
                    log_k = "_".join(["train", k])
                    epoch_logs[log_k] = v
614
            # eval step...
J
jrzaurin 已提交
615
            if epoch % validation_freq == (validation_freq - 1):
616
                if eval_set is not None:
J
jrzaurin 已提交
617 618 619 620 621 622 623 624
                    eval_loader = DataLoader(
                        dataset=eval_set,
                        batch_size=batch_size,
                        num_workers=n_cpus,
                        shuffle=False,
                    )
                    eval_steps = len(eval_loader)
                    self.valid_running_loss = 0.0
625
                    with trange(eval_steps, disable=self.verbose != 1) as v:
J
jrzaurin 已提交
626 627
                        for i, (data, target) in zip(v, eval_loader):
                            v.set_description("valid")
628 629 630 631 632 633 634 635
                            score, val_loss = self._validation_step(data, target, i)
                            if score is not None:
                                v.set_postfix(
                                    metrics={
                                        k: np.round(v, 4) for k, v in score.items()
                                    },
                                    loss=val_loss,
                                )
636
                            else:
637
                                v.set_postfix(loss=val_loss)
J
jrzaurin 已提交
638
                    epoch_logs["val_loss"] = val_loss
639 640 641 642
                    if score is not None:
                        for k, v in score.items():
                            log_k = "_".join(["val", k])
                            epoch_logs[log_k] = v
J
jrzaurin 已提交
643 644 645
            if self.lr_scheduler:
                self._lr_scheduler_step(step_location="on_epoch_end")
            #  log and check if early_stop...
646
            self.callback_container.on_epoch_end(epoch, epoch_logs)
647
            if self.early_stop:
J
jrzaurin 已提交
648
                self.callback_container.on_train_end(epoch_logs)
649
                break
J
jrzaurin 已提交
650
            self.callback_container.on_train_end(epoch_logs)
651 652
        self.train()

J
jrzaurin 已提交
653 654 655 656 657 658 659 660
    def predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predictions

        The inputs can be passed either as individual arrays (``X_wide``,
        ``X_deep`` and optionally ``X_text`` and ``X_img``) or packed in the
        ``X_test`` dictionary. When ``X_test`` is passed the individual
        arrays are not needed.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_deep: np.ndarray, Optional. Default=None
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict[str, np.ndarray], Optional. Default=None
            Testing dataset for the different model components. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'` the values are
            the corresponding matrices.

        .. note:: WideDeep assumes that `X_wide`, `X_deep` and `target` ALWAYS exist,
            while `X_text` and `X_img` are optional.

        Returns
        -------
        np.ndarray
            - ``regression``: the stacked raw predictions, with the second
              axis squeezed out.
            - ``binary``: 0/1 integer labels, thresholding the predicted
              probability at 0.5.
            - ``multiclass``: the index of the highest scoring class.

            For any other value of ``self.method`` nothing is returned
            (i.e. ``None``).

        Raises
        ------
        ValueError
            If ``X_test`` is not passed and ``X_wide`` or ``X_deep`` is
            missing.
        """
        if X_test is None and (X_wide is None or X_deep is None):
            raise ValueError(
                "Either 'X_test' or both 'X_wide' and 'X_deep' must be passed"
            )
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype("int")
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)
694

J
jrzaurin 已提交
695 696 697 698 699 700 701 702
    def predict_proba(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for binary
        and multiclass methods

        The inputs can be passed either as individual arrays (``X_wide``,
        ``X_deep`` and optionally ``X_text`` and ``X_img``) or packed in the
        ``X_test`` dictionary, with keys `'X_wide'`, `'X_deep'`, `'X_text'`
        and `'X_img'`.

        Returns
        -------
        np.ndarray
            - ``binary``: array with two columns, the first one holding
              ``1 - p`` and the second one ``p``, where ``p`` is the
              predicted probability.
            - ``multiclass``: the stacked per-class probabilities.

            For any other value of ``self.method`` (e.g. ``regression``)
            nothing is returned (i.e. ``None``).

        Raises
        ------
        ValueError
            If ``X_test`` is not passed and ``X_wide`` or ``X_deep`` is
            missing.
        """
        if X_test is None and (X_wide is None or X_deep is None):
            raise ValueError(
                "Either 'X_test' or both 'X_wide' and 'X_deep' must be passed"
            )
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            # column 0 -> probability of class 0, column 1 -> probability of class 1
            probs = np.zeros([preds.shape[0], 2])
            probs[:, 0] = 1 - preds
            probs[:, 1] = preds
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)
715

J
jrzaurin 已提交
716 717 718
    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:
        r"""Returns the learned embeddings for the categorical features passed through
        ``deepdense``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` Attribute of the class
        :obj:`DensePreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.DensePreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict[str, Dict[str, int]]
            Dictionary containing the categorical encodings, e.g:

            Examples
            --------
            >>> cat_encoding_dict['education']
            {'11th': 0, 'HS-grad': 1, 'Assoc-acdm': 2, 'Some-college': 3, '10th': 4, 'Prof-school': 5,
            '7th-8th': 6, 'Bachelors': 7, 'Masters': 8, 'Doctorate': 9, '5th-6th': 10, 'Assoc-voc': 11,
            '9th': 12, '12th': 13, '1st-4th': 14, 'Preschool': 15}

        Returns
        -------
        Dict[str, np.ndarray]
            Dictionary mapping each category of ``col_name`` to the row of
            the embedding matrix indexed by its encoding.

        Raises
        ------
        ValueError
            If the model has no embedding parameter for ``col_name``.
        KeyError
            If ``col_name`` is not a key of ``cat_encoding_dict``.

        Examples
        --------

        Assuming we have already trained the model and that we have the
        categorical encodings in a dictionary named ``encoding_dict``:

        >>> model.get_embeddings(col_name='education', cat_encoding_dict=encoding_dict)
        {'11th': array([-0.42739448, -0.22282735,  0.36969638,  0.4445322 ,  0.2562272 ,
        0.11572784, -0.01648579,  0.09027119,  0.0457597 , -0.28337458], dtype=float32),
         'HS-grad': array([-0.10600474, -0.48775527,  0.3444158 ,  0.13818645, -0.16547225,
        0.27409762, -0.05006042, -0.0668492 , -0.11047247,  0.3280354 ], dtype=float32),
        ...
        }
        """
        # NOTE: the lookup is a substring match on the parameter name, so a
        # column whose name is contained in another column's name can match
        # more than one parameter. As before, the last match is the one used.
        embed_param = None
        for n, p in self.named_parameters():
            if "embed_layers" in n and col_name in n:
                embed_param = p
        if embed_param is None:
            raise ValueError(
                "No embedding layer found for column '{}'".format(col_name)
            )
        embed_mtx = embed_param.detach().cpu().numpy()
        encoding_dict = cat_encoding_dict[col_name]
        return {category: embed_mtx[idx] for category, idx in encoding_dict.items()}

J
jrzaurin 已提交
766
    def _loss_fn(self, y_pred: Tensor, y_true: Tensor) -> Tensor:  # type: ignore
767 768
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
J
jrzaurin 已提交
769
        if self.method == "regression":
770
            return F.mse_loss(y_pred, y_true.view(-1, 1))
J
jrzaurin 已提交
771
        if self.method == "binary":
772
            return F.binary_cross_entropy_with_logits(
J
jrzaurin 已提交
773 774 775
                y_pred, y_true.view(-1, 1), weight=self.class_weight
            )
        if self.method == "multiclass":
776 777
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

J
jrzaurin 已提交
778 779 780 781 782 783 784 785 786 787 788
    def _train_val_split(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
    ):
        r"""
        Builds the train and, if required, validation datasets.

        A number of options are allowed in terms of data inputs:

            - neither ``X_val`` nor ``val_split`` are passed: only the train
              dataset is built.
            - ``X_val`` is passed: it is used as the validation set, and the
              train set is either ``X_train`` or built from the input arrays.
            - only ``val_split`` is passed: the train/val split happens
              internally, stratified by ``target`` unless the method is
              ``regression``.

        When a train dictionary (``X_train``) is passed, its ``'X_wide'``,
        ``'X_deep'`` and ``'target'`` entries are required, while ``'X_text'``
        and ``'X_img'`` are optional. For parameter information, please, see
        the .fit() method documentation

        Returns
        -------
        train_set: WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`
        eval_set : WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`. ``None`` if no
            validation is required.
        """
        if X_train is not None and X_val is not None:
            # Both dictionaries were passed: they are used as they are
            train_dict = X_train
        else:
            # Either a train dictionary or the individual arrays must be
            # passed. The dictionary, if present, takes precedence.
            if X_train is not None:
                X_wide, X_deep, target = (
                    X_train["X_wide"],
                    X_train["X_deep"],
                    X_train["target"],
                )
                X_text = X_train.get("X_text", X_text)
                X_img = X_train.get("X_img", X_img)
            train_dict = {"X_wide": X_wide, "X_deep": X_deep, "target": target}
            # The optional components are only included when they exist
            if X_text is not None:
                train_dict["X_text"] = X_text
            if X_img is not None:
                train_dict["X_img"] = X_img

        #  Without validation
        if X_val is None and val_split is None:
            train_set = WideDeepDataset(**train_dict, transforms=self.transforms)  # type: ignore
            return train_set, None

        #  With validation
        if X_val is None:
            # All the components are split in one single call so that the
            # train and validation rows are guaranteed to be aligned across
            # components, whatever the value of the seed. An invalid input
            # (e.g. arrays of different lengths) raises instead of silently
            # dropping a component.
            component_names = list(train_dict.keys())
            splits = train_test_split(
                *train_dict.values(),
                test_size=val_split,
                random_state=self.seed,
                stratify=target if self.method != "regression" else None,
            )
            # train_test_split returns [train_0, val_0, train_1, val_1, ...]
            train_dict = dict(zip(component_names, splits[0::2]))
            X_val = dict(zip(component_names, splits[1::2]))

        # At this point the train and validation dictionaries have been built
        train_set = WideDeepDataset(**train_dict, transforms=self.transforms)  # type: ignore
        eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
        return train_set, eval_set

J
jrzaurin 已提交
902 903 904 905 906 907 908 909 910 911 912 913 914
    def _warm_up(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):
915 916 917
        r"""
        Simple wrappup to individually warm up model components
        """
918 919
        if self.deephead is not None:
            raise ValueError(
J
jrzaurin 已提交
920 921
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
922 923
        # This is not the most elegant solution, but is a soluton "in-between"
        # a non elegant one and re-factoring the whole code
924
        warmer = WarmUp(self._loss_fn, self.metric, self.method, self.verbose)
J
jrzaurin 已提交
925 926
        warmer.warm_all(self.wide, "wide", loader, n_epochs, max_lr)
        warmer.warm_all(self.deepdense, "deepdense", loader, n_epochs, max_lr)
927 928
        if self.deeptext:
            if deeptext_gradual:
J
jrzaurin 已提交
929 930 931 932 933 934 935 936 937 938
                warmer.warm_gradual(
                    self.deeptext,
                    "deeptext",
                    loader,
                    deeptext_max_lr,
                    deeptext_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deeptext, "deeptext", loader, n_epochs, max_lr)
939 940
        if self.deepimage:
            if deepimage_gradual:
J
jrzaurin 已提交
941 942 943 944 945 946 947 948 949 950
                warmer.warm_gradual(
                    self.deepimage,
                    "deepimage",
                    loader,
                    deepimage_max_lr,
                    deepimage_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deepimage, "deepimage", loader, n_epochs, max_lr)
951

J
jrzaurin 已提交
952
    def _lr_scheduler_step(self, step_location: str):
953 954 955
        r"""
        Function to execute the learning rate schedulers steps.
        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
956
        must happen after training each bach durig training. On the other
957 958 959 960 961 962 963 964
        hand, if the  scheduler is not Cyclic, is expected to be called after
        validation.

        Parameters
        ----------
        step_location: Str
            Indicates where to run the lr_scheduler step
        """
J
jrzaurin 已提交
965 966 967 968 969 970 971 972 973 974 975 976
        if (
            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
            and self.cyclic
        ):
            if step_location == "on_batch_end":
                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
            elif step_location == "on_epoch_end":
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" not in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
977
        elif self.cyclic:
J
jrzaurin 已提交
978 979 980 981 982 983 984 985 986 987 988 989 990 991 992
            if step_location == "on_batch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
            if step_location == "on_epoch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif step_location == "on_epoch_end":
            self.lr_scheduler.step()  # type: ignore
        else:
            pass

    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Runs one optimisation step on a batch.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Batch of inputs for the model components
        target: Tensor
            Batch of targets. It is cast to float unless the method is
            ``multiclass``
        batch_idx: int
            Index of the batch within the epoch, used to compute the running
            average of the loss

        Returns
        -------
        score
            Result of calling ``self.metric`` on the predictions (after a
            sigmoid for ``binary`` or a softmax for ``multiclass``). ``None``
            if there is no metric or the method is neither ``binary`` nor
            ``multiclass``.
        avg_loss: float
            Average of the training loss over the batches seen so far in the
            epoch
        """
        self.train()
        X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
        y = target.float() if self.method != "multiclass" else target
        y = y.cuda() if use_cuda else y

        self.optimizer.zero_grad()
        y_pred = self.forward(X)
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        # score stays None when the metric does not apply (e.g. regression),
        # instead of being left undefined
        score = None
        if self.metric is not None:
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            elif self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
        return score, avg_loss

J
jrzaurin 已提交
1016
    def _validation_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Evaluates the model on a validation batch, without tracking
        gradients.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Batch of inputs for the model components
        target: Tensor
            Batch of targets. It is cast to float unless the method is
            ``multiclass``
        batch_idx: int
            Index of the batch within the validation loop, used to compute
            the running average of the loss

        Returns
        -------
        score
            Result of calling ``self.metric`` on the predictions (after a
            sigmoid for ``binary`` or a softmax for ``multiclass``). ``None``
            if there is no metric or the method is neither ``binary`` nor
            ``multiclass``.
        avg_loss: float
            Average of the validation loss over the batches seen so far
        """
        self.eval()
        with torch.no_grad():
            X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
            y = target.float() if self.method != "multiclass" else target
            y = y.cuda() if use_cuda else y

            y_pred = self.forward(X)
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        # score stays None when the metric does not apply (e.g. regression),
        # instead of being left undefined
        score = None
        if self.metric is not None:
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            elif self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
        return score, avg_loss

J
jrzaurin 已提交
1038 1039 1040 1041 1042 1043 1044 1045
    def _predict(
        self,
        X_wide: np.ndarray,
        X_deep: np.ndarray,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""Hidden method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation

        Returns
        -------
        List
            List with one array of predictions per batch. A sigmoid
            (``binary``) or a softmax (``multiclass``) is applied to the
            output of the model before storing it.
        """
        if X_test is not None:
            # The target, if present, plays no role when predicting. It is
            # dropped because a dataset built with a target is expected to
            # yield (data, target) pairs (that is how the training loaders
            # are consumed), while the loop below expects the data alone.
            test_dict = {k: v for k, v in X_test.items() if k != "target"}
            test_set = WideDeepDataset(**test_dict)
        else:
            load_dict = {"X_wide": X_wide, "X_deep": X_deep}
            if X_text is not None:
                load_dict.update({"X_text": X_text})
            if X_img is not None:
                load_dict.update({"X_img": X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        # The length of the loader is the exact number of batches, also when
        # the dataset size is a multiple of the batch size
        test_steps = len(test_loader)

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = self.forward(X)
                    if self.method == "binary":
                        preds = torch.sigmoid(preds)
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds_l.append(preds.detach().cpu().numpy())
        # NOTE: the model is always left in training mode after predicting
        self.train()
        return preds_l