wide_deep.py 51.2 KB
Newer Older
1
import os
2 3
import warnings
import functools
4 5

import numpy as np
6
import torch
7 8
import torch.nn as nn
import torch.nn.functional as F
9 10 11
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
12

13
from ..losses import FocalLoss
14 15
from ._warmup import WarmUp
from ..metrics import Metric, MetricCallback, MultipleMetrics
16
from ..wdtypes import *  # noqa: F403
17 18
from ..callbacks import History, Callback, CallbackContainer
from .deep_dense import dense_layer
19
from ._wd_dataset import WideDeepDataset
20
from ..initializers import Initializer, MultipleInitializer
21 22
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_transforms import MultipleTransforms
23
from ._multiple_lr_scheduler import MultipleLRScheduler
J
jrzaurin 已提交
24

25 26
# Make sure DeprecationWarnings (e.g. the ones emitted for deprecated keyword
# arguments by ``rename_kwargs``) are displayed to the user
warnings.filterwarnings("default", category=DeprecationWarning)

# Used as ``num_workers`` for the ``DataLoader``s. ``os.cpu_count()`` returns
# None when the number of CPUs cannot be determined, which ``DataLoader``
# rejects, so fall back to a single worker in that case.
n_cpus = os.cpu_count() or 1

# Run on the GPU when one is available
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
31 32


33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
def deprecated_alias(**aliases):
    """Build a decorator that accepts deprecated keyword-argument names.

    ``aliases`` maps each deprecated keyword name to its replacement. Every
    call to the decorated function first goes through ``rename_kwargs``,
    which renames (and warns about) any deprecated keyword found in the
    call, and is then forwarded unchanged to the original function.
    """

    def decorator(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            rename_kwargs(func.__name__, kwargs, aliases)
            return func(*args, **kwargs)

        return inner

    return decorator


def rename_kwargs(func_name, kwargs, aliases):
    """Rename deprecated keyword arguments in place.

    Parameters
    ----------
    func_name: str
        Name of the decorated function, used in the error message.
    kwargs: dict
        Keyword arguments of the call. Modified in place: every deprecated
        key found in it is replaced by its new name (the value is kept).
    aliases: dict
        Mapping ``{deprecated_name: new_name}``.

    Raises
    ------
    TypeError
        If a call passes both the deprecated and the new name.
    """
    for alias, new in aliases.items():
        if alias not in kwargs:
            continue
        if new in kwargs:
            raise TypeError(
                "{} received both {} and {}".format(func_name, alias, new)
            )
        # stacklevel=3 attributes the warning to the user's call site:
        # rename_kwargs -> wrapper built by ``deprecated_alias`` -> caller
        warnings.warn(
            "'{}' is deprecated; use '{}' instead".format(alias, new),
            DeprecationWarning,
            stacklevel=3,
        )
        kwargs[new] = kwargs.pop(alias)


59
class WideDeep(nn.Module):
60 61
    r"""Main collector class that combines all ``wide``, ``deeptabular``
    (which can be a number of architectures), ``deeptext`` and ``deepimage`` models.
62

63 64
    There are two options to combine these models that correspond to the two
    architectures that ``pytorch-widedeep`` can build.
65 66 67 68

        - Directly connecting the output of the model components to an output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
69 70
          This FC-Head will combine the output from the ``deeptabular``, ``deeptext`` and
          ``deepimage`` and will then be connected to the output neuron(s).
71 72 73

    Parameters
    ----------
74
    wide: nn.Module
75 76 77 78
        Wide model. We recommend using the ``Wide`` class in this package.
        However, it is possible to use a custom model as long as it is consistent
        with the required architecture, see
        :class:`pytorch_widedeep.models.wide.Wide`
79 80 81
    deeptabular: nn.Module
        Currently we offer three possible architectures for the `deeptabular` component
        implemented in this package. These are: ``DeepDense``, ``DeepDenseResnet`` and
        ``TabTransformer``.

        1. ``DeepDense`` is simply an embedding layer encoding the categorical
        features that are then concatenated and passed through a series of
        dense layers.
        See: ``pytorch_widedeep.models.deep_dense.DeepDense``

        2. ``DeepDenseResnet`` is an embedding layer encoding the categorical
        features that are then concatenated and passed through a series of
        `"dense"` ResNet blocks.
        See ``pytorch_widedeep.models.deep_dense_resnet.DeepDenseResnet``

        3. ``TabTransformer`` is detailed in `TabTransformer: Tabular Data Modeling
        Using Contextual Embeddings <https://arxiv.org/pdf/2012.06678.pdf>`_.
        See ``pytorch_widedeep.models.tab_transformer.TabTransformer``
97 98
        We recommend using one of these as `deeptabular`. However, it is
        possible to use a custom model as long as it is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.tab_transformer.TabTransformer`.
100
    deeptext: nn.Module, Optional
101 102 103 104
        `Deep text` model for the text input. Must be an object of class
        ``DeepText`` or a custom model as long as it is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepText`
105
    deepimage: nn.Module, Optional
106 107 108 109
        `Deep Image` model for the images input. Must be an object of class
        ``DeepImage`` or a custom model as long as it is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepImage`
110
    deephead: nn.Module, Optional
111
        `Dense` model consisting in a stack of dense layers. The FC-Head.
112
    head_layers: List, Optional
113
        Alternatively, we can use ``head_layers`` to specify the sizes of the
114
        stacked dense layers in the fc-head e.g: ``[128, 64]``
115
    head_dropout: List, Optional
116
        Dropout between the layers in ``head_layers``. e.g: ``[0.5, 0.5]``
117
    head_batchnorm: bool, Optional
118
        Specifies if batch normalization should be included in the dense layers
119 120 121 122 123 124
    pred_dim: int
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or `number
        of classes` for multiclass classification.


125
    .. note:: With the exception of ``cyclic_lr``, all attributes are direct assignments of
        the corresponding parameters used when calling ``compile``.  Therefore,
        see the parameters at
        :class:`pytorch_widedeep.models.wide_deep.WideDeep.compile` for a full
        list of the attributes of an instance of
        :class:`pytorch_widedeep.models.wide_deep.WideDeep`


    Attributes
    ----------
135 136
    cyclic_lr: :obj:`bool`
        Attribute that indicates if any of the lr_schedulers is cyclic_lr (i.e. ``CyclicLR`` or
137
        ``OneCycleLR``). See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
138 139


140 141 142 143 144 145 146 147 148 149
    .. note:: While I recommend using the ``wide`` and ``deeptabular`` components
        within this package when building the corresponding model components,
        it is very likely that the user will want to use custom text and image
        models. That is perfectly possible. Simply, build them and pass them
        as the corresponding parameters. Note that the custom models MUST
        return a last layer of activations (i.e. not the final prediction) so
        that  these activations are collected by ``WideDeep`` and combined
        accordingly. In addition, the models MUST also contain an attribute
        ``output_dim`` with the size of these last layers of activations. See
        for example :class:`pytorch_widedeep.models.deep_dense.DeepDense`
150 151

    """
J
jrzaurin 已提交
152

153 154
    @deprecated_alias(deepdense="deeptabular")  # noqa: C901
    def __init__(
        self,
        wide: Optional[nn.Module] = None,
        deeptabular: Optional[nn.Module] = None,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_layers: Optional[List[int]] = None,
        head_dropout: Optional[List] = None,
        head_batchnorm: Optional[bool] = None,
        pred_dim: int = 1,
    ):
        """Assemble the model components (see the class docstring for the
        meaning of each parameter).

        Every layer that maps the deep side to the ``pred_dim`` output
        neuron(s) is created and registered here, so that its parameters are
        part of ``self.parameters()`` and are trained by the optimizer(s)
        set up in ``compile``.
        """
        super().__init__()

        self._check_model_components(
            wide,
            deeptabular,
            deeptext,
            deepimage,
            deephead,
            head_layers,
            head_dropout,
            pred_dim,
        )

        # required as attribute: ``forward`` uses it to build the output
        # when there is no wide component
        self.pred_dim = pred_dim

        # The main 5 components of the wide and deep assemble
        self.wide = wide
        self.deeptabular = deeptabular
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deephead is not None:
            # Custom FC-head: append its output layer here (rather than
            # creating a new, untrained one on every call to ``forward``)
            head_output_dim = getattr(self.deephead, "output_dim", None)
            if head_output_dim is not None:
                self.deephead = nn.Sequential(
                    self.deephead, nn.Linear(head_output_dim, pred_dim)
                )
            else:
                warnings.warn(
                    "The custom 'deephead' has no 'output_dim' attribute, so no "
                    "output layer can be added to it. It will be used as is and "
                    "its output is expected to be of size 'pred_dim'",
                    UserWarning,
                )
        elif head_layers is not None:
            # FC-head built from ``head_layers``. Its input is the
            # concatenation of the outputs of all the deep components
            input_dim = sum(
                component.output_dim  # type: ignore
                for component in (self.deeptabular, self.deeptext, self.deepimage)
                if component is not None
            )
            head_dims = [input_dim] + head_layers
            if not head_dropout:
                head_dropout = [0.0] * (len(head_dims) - 1)
            self.deephead = nn.Sequential()
            for i in range(1, len(head_dims)):
                self.deephead.add_module(
                    "head_layer_{}".format(i - 1),
                    dense_layer(
                        head_dims[i - 1],
                        head_dims[i],
                        head_dropout[i - 1],
                        head_batchnorm,
                    ),
                )
            # the head's last layer already maps to the output neuron(s)
            self.deephead.add_module(
                "head_out", nn.Linear(head_dims[-1], pred_dim)
            )
        else:
            # No FC-head: each deep component is connected directly to the
            # output neuron(s) through its own linear layer
            if self.deeptabular is not None:
                self.deeptabular = nn.Sequential(
                    self.deeptabular, nn.Linear(self.deeptabular.output_dim, pred_dim)  # type: ignore
                )
            if self.deeptext is not None:
                self.deeptext = nn.Sequential(
                    self.deeptext, nn.Linear(self.deeptext.output_dim, pred_dim)  # type: ignore
                )
            if self.deepimage is not None:
                self.deepimage = nn.Sequential(
                    self.deepimage, nn.Linear(self.deepimage.output_dim, pred_dim)  # type: ignore
                )

    def forward(self, X: Dict[str, Tensor]) -> Tensor:  # type: ignore  # noqa: C901
        """Forward pass: add up the outputs of all the model components.

        Parameters
        ----------
        X: Dict[str, Tensor]
            Inputs keyed by component: ``"wide"``, ``"deeptabular"``,
            ``"deeptext"`` and/or ``"deepimage"``. Only the keys of the
            components present in the model are read.

        Returns
        -------
        Tensor
            Sum of the wide output (zeros of shape ``(batch_size, pred_dim)``
            if there is no wide component) and the deep output(s).
        """
        # Wide output: direct connection to the output neuron(s)
        if self.wide is not None:
            out = self.wide(X["wide"])
        else:
            batch_size = next(iter(X.values())).size(0)
            out = torch.zeros(batch_size, self.pred_dim, device=device)

        if self.deephead is not None:
            # Deep output through the FC-head: concatenate the activations of
            # the deep components and pass them through the head, whose
            # output layer is set up once in ``__init__``
            deep_features = [
                component(X[name])
                for name, component in (
                    ("deeptabular", self.deeptabular),
                    ("deeptext", self.deeptext),
                    ("deepimage", self.deepimage),
                )
                if component is not None
            ]
            deepside = torch.cat(deep_features, dim=1)
            return out.add_(self.deephead(deepside))

        # Deep output connected directly: each component already ends in a
        # linear layer with ``pred_dim`` outputs
        if self.deeptabular is not None:
            out.add_(self.deeptabular(X["deeptabular"]))
        if self.deeptext is not None:
            out.add_(self.deeptext(X["deeptext"]))
        if self.deepimage is not None:
            out.add_(self.deepimage(X["deepimage"]))
        return out

264
    def compile(  # noqa: C901
        self,
        method: str,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        initializers: Optional[Dict[str, Initializer]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        with_focal_loss: bool = False,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        r"""Method to set the attributes that will be used during the
        training process.

        Parameters
        ----------
        method: str
            One of `regression`, `binary` or `multiclass`. The default when
            performing a `regression`, a `binary` classification or a
            `multiclass` classification is the `mean squared error
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.mse_loss>`_
            (MSE), `Binary Cross Entropy
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.binary_cross_entropy>`_
            (BCE) and `Cross Entropy
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.cross_entropy>`_
            (CE) respectively.
        optimizers: Union[Optimizer, Dict[str, Optimizer]], Optional, Default=AdamW
            - An instance of ``pytorch``'s ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
            - a dictionary where the keys are the model components (i.e.
              `'wide'`, `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and
              the values are the corresponding optimizers. If multiple optimizers are used
              the dictionary MUST contain an optimizer per model component.

            See `Pytorch optimizers <https://pytorch.org/docs/stable/optim.html>`_.
        lr_schedulers: Union[LRScheduler, Dict[str, LRScheduler]], Optional, Default=None
            - An instance of ``pytorch``'s ``LRScheduler`` object (e.g.
              :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
              values are the corresponding learning rate schedulers.

            If ``optimizers`` is a dictionary, ``lr_schedulers`` must be either
            ``None`` or a dictionary as well.

            See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
        initializers: Dict[str, Initializer], Optional. Default=None
            Dict where the keys are the model components (i.e. `'wide'`,
            `'deeptabular'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
            values are the corresponding initializers.
            See `Pytorch initializers <https://pytorch.org/docs/stable/nn.init.html>`_.
        transforms: List[Transforms], Optional. Default=None
            ``Transforms`` is a custom type. See
            :obj:`pytorch_widedeep.wdtypes`. List with
            :obj:`torchvision.transforms` to be applied to the image component
            of the model (i.e. ``deepimage``). See `torchvision transforms
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
        callbacks: List[Callback], Optional. Default=None
            Callbacks available are: ``ModelCheckpoint``, ``EarlyStopping``,
            and ``LRHistory``. The ``History`` callback is used by default.
            See the ``Callbacks`` section in this documentation or
            :obj:`pytorch_widedeep.callbacks`
        metrics: List[Metric], Optional. Default=None
            Metrics available are: ``Accuracy``, ``Precision``, ``Recall``,
            ``FBetaScore`` and ``F1Score``.  See the ``Metrics`` section in
            this documentation or :obj:`pytorch_widedeep.metrics`
        class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
            - float indicating the weight of the minority class in binary classification
              problems (e.g. 9.)
            - a list or tuple with weights for the different classes in multiclass
              classification problems  (e.g. [1., 2., 3.]). The weights do
              not necessarily need to be normalised. If your loss function
              uses reduction='mean', the loss will be normalized by the sum
              of the corresponding weights for each element. If you are
              using reduction='none', you would have to take care of the
              normalization yourself. See `this discussion
              <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
        with_focal_loss: bool, Optional. Default=False
            Boolean indicating whether to use the Focal Loss for highly imbalanced problems.
            For details on the focal loss see the `original paper
            <https://arxiv.org/pdf/1708.02002.pdf>`_.
        alpha: float. Default=0.25
            Focal Loss alpha parameter.
        gamma: float. Default=2
            Focal Loss gamma parameter.
        verbose: int
            Setting it to 0 will print nothing during training.
        seed: int, Default=1
            Random seed to be used throughout all the methods

        Raises
        ------
        ValueError
            If ``optimizers`` is a dictionary and ``lr_schedulers`` is
            neither ``None`` nor a dictionary.

        Example
        --------
        >>> import torch
        >>> from torchvision.transforms import ToTensor
        >>>
        >>> from pytorch_widedeep.callbacks import EarlyStopping, LRHistory
        >>> from pytorch_widedeep.initializers import KaimingNormal, KaimingUniform, Normal, Uniform
        >>> from pytorch_widedeep.models import DeepDenseResnet, DeepImage, DeepText, Wide, WideDeep
        >>> from pytorch_widedeep.optim import RAdam
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> deep_column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deeptabular = DeepDenseResnet(blocks=[8, 4], deep_column_idx=deep_column_idx, embed_input=embed_input)
        >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
        >>> deepimage = DeepImage(pretrained=False)
        >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
        >>>
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deeptabular.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt = RAdam(model.deepimage.parameters())
        >>>
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {"wide": wide_opt, "deeptabular": deep_opt, "deeptext": text_opt, "deepimage": img_opt}
        >>> schedulers = {"wide": wide_sch, "deeptabular": deep_sch, "deeptext": text_sch, "deepimage": img_sch}
        >>> initializers = {"wide": Uniform, "deeptabular": Normal, "deeptext": KaimingNormal, "deepimage": KaimingUniform}
        >>> transforms = [ToTensor]
        >>> callbacks = [LRHistory(n_epochs=4), EarlyStopping]
        >>> model.compile(method="regression", initializers=initializers, optimizers=optimizers,
        ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """
        # A dict of optimizers may be used without any scheduler; only a
        # *single* scheduler combined with a dict of optimizers is invalid
        if (
            isinstance(optimizers, Dict)
            and lr_schedulers is not None
            and not isinstance(lr_schedulers, Dict)
        ):
            raise ValueError(
                "'optimizers' and 'lr_schedulers' must have consistent type: "
                "(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]). "
                "Please, read the documentation or see the examples for more details"
            )

        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma

        # NOTE(review): for a float w this builds the weights [1 - w, w]. With
        # the docstring example (9.) the weight of class 0 would be negative;
        # confirm the intended range of w against the loss function using it.
        if isinstance(class_weight, float):
            self.class_weight = torch.tensor([1.0 - class_weight, class_weight])
        elif isinstance(class_weight, (tuple, list)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        if optimizers is not None:
            if isinstance(optimizers, Optimizer):
                self.optimizer: Union[Optimizer, MultipleOptimizer] = optimizers
            elif isinstance(optimizers, Dict):
                # every model component (child module) needs its own optimizer
                for mod_name, _ in self.named_children():
                    assert mod_name in optimizers, "No optimizer found for {}".format(
                        mod_name
                    )
                self.optimizer = MultipleOptimizer(optimizers)
        else:
            self.optimizer = torch.optim.AdamW(self.parameters())  # type: ignore

        if lr_schedulers is not None:
            if isinstance(lr_schedulers, LRScheduler):
                self.lr_scheduler: Union[
                    LRScheduler,
                    MultipleLRScheduler,
                ] = lr_schedulers
                self.cyclic_lr = "cycl" in self.lr_scheduler.__class__.__name__.lower()
            else:
                self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
                scheduler_names = [
                    sc.__class__.__name__.lower()
                    for _, sc in self.lr_scheduler._schedulers.items()
                ]
                # cyclic schedulers (e.g. CyclicLR, OneCycleLR) are flagged
                self.cyclic_lr = any("cycl" in sn for sn in scheduler_names)
        else:
            self.lr_scheduler, self.cyclic_lr = None, False

        if transforms is not None:
            self.transforms: MultipleTransforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        # History is always recorded; user callbacks may be passed as classes
        # or instances
        self.callbacks: List = [History()]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

        self.to(device)
J
jrzaurin 已提交
467

468 469
    @deprecated_alias(X_deep="X_tab")  # noqa: C901
    def fit(
J
jrzaurin 已提交
470 471
        self,
        X_wide: Optional[np.ndarray] = None,
472
        X_tab: Optional[np.ndarray] = None,
J
jrzaurin 已提交
473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        patience: int = 10,
        warm_up: bool = False,
        warm_epochs: int = 4,
        warm_max_lr: float = 0.01,
        warm_deeptext_gradual: bool = False,
        warm_deeptext_max_lr: float = 0.01,
        warm_deeptext_layers: Optional[List[nn.Module]] = None,
        warm_deepimage_gradual: bool = False,
        warm_deepimage_max_lr: float = 0.01,
        warm_deepimage_layers: Optional[List[nn.Module]] = None,
        warm_routine: str = "howard",
    ):
494
        r"""Fit method. Must run after calling ``compile``
495 496 497

        Parameters
        ----------
498
        X_wide: np.ndarray, Optional. Default=None
499 500
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
501 502 503
        X_tab: np.ndarray, Optional. Default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
504
        X_text: np.ndarray, Optional. Default=None
505 506
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
507
        X_img : np.ndarray, Optional. Default=None
508 509 510 511
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict[str, np.ndarray], Optional. Default=None
            Training dataset for the different model components. Keys are
512
            `X_wide`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values are
513
            the corresponding matrices.
514
        X_val: Dict, Optional. Default=None
515
            Validation dataset for the different model component. Keys are
516
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values are
517
            the corresponding matrices.
518 519
        val_split: float, Optional. Default=None
            train/val split fraction
520 521
        target: np.ndarray, Optional. Default=None
            target values
522 523 524 525 526 527
        n_epochs: int, Default=1
            number of epochs
        validation_freq: int, Default=1
            epochs validation frequency
        batch_size: int, Default=32
        patience: int, Default=10
528 529
            Number of epochs without improving the target metric before
            the fit process stops
530
        warm_up: bool, Default=False
531 532 533
            warm up model components individually before the joined training
            starts.

534 535 536 537 538
            ``pytorch_widedeep`` implements 3 warm up routines.

            - Warm up all trainable layers at once. This routine is is
              inspired by the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
539 540 541 542 543 544
              Slanted Triangular learing (see `Leslie N. Smith paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``AdamW``.
545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560

            and two gradual warm up routines, where only certain layers are
            warmed up at each warm up step.

            - The so called `Felbo` gradual warm up rourine, inpired by the Felbo et al., 2017
              `DeepEmoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine based on the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the examples
            section in this documentation and the corresponding repo.
        warm_epochs: int, Default=4
            Number of warm up epochs for those model components that will NOT
            be gradually warmed up. Those components with gradual warm up
            follow their corresponding specific routine.
        warm_max_lr: float, Default=0.01
561
            Maximum learning rate during the Triangular Learning rate cycle
562 563
            for those model componenst that will NOT be gradually warmed up
        warm_deeptext_gradual: bool, Default=False
564 565
            Boolean indicating if the deeptext component will be warmed
            up gradually
566
        warm_deeptext_max_lr: float, Default=0.01
567 568
            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
569
        warm_deeptext_layers: List, Optional, Default=None
570
            List of :obj:`nn.Modules` that will be warmed up gradually.
571 572 573 574 575

            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

        warm_deepimage_gradual: bool, Default=False
576 577 578 579 580
            Boolean indicating if the deepimage component will be warmed
            up gradually
        warm_deepimage_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deepimage component
581
        warm_deepimage_layers: List, Optional, Default=None
582
            List of :obj:`nn.Modules` that will be warmed up gradually.
583

584 585 586
            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

587
        warm_routine: str, Default=`felbo`
588 589 590 591 592
            Warm up routine. On of `felbo` or `howard`. See the examples
            section in this documentation and the corresponding repo for
            details on how to use warm up routines

        Examples
593
        --------
594 595 596 597 598 599 600

        For a series of comprehensive examples please, see the `example
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_.
        folder in the repo

        For completion, here we include some `"fabricated"` examples, i.e. these assume
        you have already built and compiled the model
601

602 603

        >>> # Ex 1. using train input arrays directly and no validation
604
        >>> # model.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256)
605

606 607

        >>> # Ex 2: using train input arrays directly and validation with val_split
608
        >>> # model.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=10, batch_size=256, val_split=0.2)
609

610 611

        >>> # Ex 3: using train dict and val_split
612
        >>> # X_train = {'X_wide': X_wide, 'X_tab': X_tab, 'target': y}
613
        >>> # model.fit(X_train, n_epochs=10, batch_size=256, val_split=0.2)
614

615 616

        >>> # Ex 4: validation using training and validation dicts
617 618
        >>> # X_train = {'X_wide': X_wide_tr, 'X_tab': X_tab_tr, 'target': y_tr}
        >>> # X_val = {'X_wide': X_wide_val, 'X_tab': X_tab_val, 'target': y_val}
619
        >>> # model.fit(X_train=X_train, X_val=X_val n_epochs=10, batch_size=256)
620

621
        """
622

623
        self.batch_size = batch_size
J
jrzaurin 已提交
624
        train_set, eval_set = self._train_val_split(
625
            X_wide, X_tab, X_text, X_img, X_train, X_val, val_split, target
J
jrzaurin 已提交
626 627 628 629
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
630 631
        if warm_up:
            # warm up...
J
jrzaurin 已提交
632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649
            self._warm_up(
                train_loader,
                warm_epochs,
                warm_max_lr,
                warm_deeptext_gradual,
                warm_deeptext_layers,
                warm_deeptext_max_lr,
                warm_deepimage_gradual,
                warm_deepimage_layers,
                warm_deepimage_max_lr,
                warm_routine,
            )
        train_steps = len(train_loader)
        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        if self.verbose:
            print("Training")
650
        for epoch in range(n_epochs):
651
            # train step...
J
jrzaurin 已提交
652
            epoch_logs: Dict[str, float] = {}
653
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
J
jrzaurin 已提交
654
            self.train_running_loss = 0.0
655
            with trange(train_steps, disable=self.verbose != 1) as t:
J
jrzaurin 已提交
656 657
                for batch_idx, (data, target) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
658 659 660 661 662 663
                    score, train_loss = self._training_step(data, target, batch_idx)
                    if score is not None:
                        t.set_postfix(
                            metrics={k: np.round(v, 4) for k, v in score.items()},
                            loss=train_loss,
                        )
664
                    else:
665
                        t.set_postfix(loss=train_loss)
J
jrzaurin 已提交
666 667
                    if self.lr_scheduler:
                        self._lr_scheduler_step(step_location="on_batch_end")
668
                    self.callback_container.on_batch_end(batch=batch_idx)
J
jrzaurin 已提交
669
            epoch_logs["train_loss"] = train_loss
670 671 672 673
            if score is not None:
                for k, v in score.items():
                    log_k = "_".join(["train", k])
                    epoch_logs[log_k] = v
674
            # eval step...
J
jrzaurin 已提交
675
            if epoch % validation_freq == (validation_freq - 1):
676
                if eval_set is not None:
J
jrzaurin 已提交
677 678 679 680 681 682 683 684
                    eval_loader = DataLoader(
                        dataset=eval_set,
                        batch_size=batch_size,
                        num_workers=n_cpus,
                        shuffle=False,
                    )
                    eval_steps = len(eval_loader)
                    self.valid_running_loss = 0.0
685
                    with trange(eval_steps, disable=self.verbose != 1) as v:
J
jrzaurin 已提交
686 687
                        for i, (data, target) in zip(v, eval_loader):
                            v.set_description("valid")
688 689 690 691 692 693 694 695
                            score, val_loss = self._validation_step(data, target, i)
                            if score is not None:
                                v.set_postfix(
                                    metrics={
                                        k: np.round(v, 4) for k, v in score.items()
                                    },
                                    loss=val_loss,
                                )
696
                            else:
697
                                v.set_postfix(loss=val_loss)
J
jrzaurin 已提交
698
                    epoch_logs["val_loss"] = val_loss
699 700 701 702
                    if score is not None:
                        for k, v in score.items():
                            log_k = "_".join(["val", k])
                            epoch_logs[log_k] = v
J
jrzaurin 已提交
703 704 705
            if self.lr_scheduler:
                self._lr_scheduler_step(step_location="on_epoch_end")
            #  log and check if early_stop...
706
            self.callback_container.on_epoch_end(epoch, epoch_logs)
707
            if self.early_stop:
J
jrzaurin 已提交
708
                self.callback_container.on_train_end(epoch_logs)
709
                break
J
jrzaurin 已提交
710
            self.callback_container.on_train_end(epoch_logs)
711 712
        self.train()

713
    @deprecated_alias(X_deep="X_tab")  # noqa: C901
    def predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predictions

        The inputs can be passed either as individual arrays (``X_wide``,
        ``X_tab``, ``X_text``, ``X_img``) or bundled in ``X_test``. If
        ``X_test`` is given, the individual arrays are ignored.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_tab: np.ndarray, Optional. Default=None
            Input for the ``deeptabular`` model component.
            See :class:`pytorch_widedeep.preprocessing.TabPreprocessor`
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict[str, np.ndarray], Optional. Default=None
            Testing dataset for the different model components. Keys are
            `'X_wide'`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`, and
            the values are the corresponding matrices.

        Returns
        -------
        np.ndarray
            ``regression``: the predicted values. ``binary``: the predicted
            class (0/1), obtained by thresholding the predicted probability at
            0.5. ``multiclass``: the index of the most likely class. ``None``
            is returned for any other ``method``.
        """
        preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test)
        method = self.method
        if method == "multiclass":
            # one row of class probabilities per sample -> most likely class
            return np.argmax(np.vstack(preds_l), 1)
        if method not in ("regression", "binary"):
            return None
        # regression / binary outputs have a single column: drop it
        flat_preds = np.vstack(preds_l).squeeze(1)
        if method == "binary":
            return (flat_preds > 0.5).astype("int")
        return flat_preds

    @deprecated_alias(X_deep="X_tab")
    def predict_proba(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for binary
        and multiclass methods

        The parameters are the same as those of the ``.predict()`` method.

        Returns
        -------
        np.ndarray
            ``binary``: an array with two columns, the probability of class 0
            and the probability of class 1. ``multiclass``: one column of
            probabilities per class. ``None`` is returned for any other
            ``method``.
        """
        preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test)
        if self.method == "multiclass":
            return np.vstack(preds_l)
        if self.method != "binary":
            return None
        pos_probs = np.vstack(preds_l).squeeze(1)
        # np.empty defaults to float64, so the output dtype is float64
        probs = np.empty((pos_probs.shape[0], 2))
        probs[:, 1] = pos_probs
        probs[:, 0] = 1 - pos_probs
        return probs

    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:  # pragma: no cover
        r"""Returns the learned embeddings for the categorical features passed through
        ``deeptabular``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` attribute of the class
        :obj:`TabPreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.TabPreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict[str, Dict[str, int]]
            Dictionary containing the categorical encodings. For each column
            name, it holds a dictionary that maps every category to the
            integer index of its row in the embedding matrix, e.g.
            ``{'education': {'Bachelors': 1, 'HS-grad': 2}}``

        Returns
        -------
        cat_embed_dict: Dict[str, np.ndarray]
            Dictionary mapping every category of ``col_name`` to its learned
            embedding vector

        Raises
        ------
        ValueError
            If the model has no embedding parameter matching ``col_name``

        Examples
        --------

        For a series of comprehensive examples please, see the `example
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_
        folder in the repo

        For completion, here we include a `"fabricated"` example, i.e.
        assuming we have already trained the model, that we have the
        categorical encodings in a dictionary named ``encoding_dict``, and that
        there is a column called `'education'`:

        >>> # model.get_embeddings(col_name='education', cat_encoding_dict=encoding_dict)
        """
        embed_mtx = None
        # NOTE(review): parameters are matched by substring. If one column name
        # is contained in another (e.g. 'age' and 'wage'), the wrong matrix may
        # be picked up, since the last matching parameter wins.
        for n, p in self.named_parameters():
            if "embed_layers" in n and col_name in n:
                embed_mtx = p.cpu().data.numpy()
        if embed_mtx is None:
            # without this check the lookup below fails with UnboundLocalError
            raise ValueError(
                "no embedding layer found for column '{}'".format(col_name)
            )
        encoding_dict = cat_encoding_dict[col_name]
        # invert to index -> category so that each embedding row gets its label
        inv_encoding_dict = {v: k for k, v in encoding_dict.items()}
        return {value: embed_mtx[idx] for idx, value in inv_encoding_dict.items()}

    def _loss_fn(self, y_pred: Tensor, y_true: Tensor) -> Tensor:  # type: ignore
822 823
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
J
jrzaurin 已提交
824
        if self.method == "regression":
825
            return F.mse_loss(y_pred, y_true.view(-1, 1))
J
jrzaurin 已提交
826
        if self.method == "binary":
827
            return F.binary_cross_entropy_with_logits(
J
jrzaurin 已提交
828 829 830
                y_pred, y_true.view(-1, 1), weight=self.class_weight
            )
        if self.method == "multiclass":
831 832
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

833
    def _train_val_split(  # noqa: C901
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
    ):
        r"""
        Build the training and (optional) evaluation datasets from the inputs
        given to the fit method. Three cases are handled, in this order of
        precedence:

        1. ``X_val`` is given: ``X_train`` and ``X_val`` are used as they are
           (``X_train`` must then also be a dictionary).
        2. ``val_split`` is given: the training data (``X_train`` or, if that
           is empty, the individual arrays plus ``target``) is split
           internally. The split is stratified by target unless ``method`` is
           ``'regression'``.
        3. Otherwise all data is used for training and there is no evaluation
           set.

        For parameter information, please, see the .fit() method documentation

        Returns
        -------
        train_set: WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`
        eval_set : WideDeepDataset or None
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`, or ``None`` when there is no
            validation data (case 3). See
            :class:`pytorch_widedeep.models._wd_dataset`
        """
        if X_val is not None:
            assert (
                X_train is not None
            ), "if the validation set is passed as a dictionary, the training set must also be a dictionary"
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
            return train_set, eval_set

        # the training data may still arrive as separate arrays: bundle them
        if not X_train:
            X_train = self._build_train_dict(X_wide, X_tab, X_text, X_img, target)

        if val_split is None:
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            return train_set, None

        y_all = X_train["target"]
        # split row indices alongside the target so every input can be sliced
        y_tr, y_val, idx_tr, idx_val = train_test_split(
            y_all,
            np.arange(len(y_all)),
            test_size=val_split,
            stratify=y_all if self.method != "regression" else None,
        )
        tr_parts, val_parts = {"target": y_tr}, {"target": y_val}
        for key in ("X_wide", "X_tab", "X_text", "X_img"):
            if key in X_train:
                component = X_train[key]
                tr_parts[key] = component[idx_tr]
                val_parts[key] = component[idx_val]
        train_set = WideDeepDataset(**tr_parts, transforms=self.transforms)  # type: ignore
        eval_set = WideDeepDataset(**val_parts, transforms=self.transforms)  # type: ignore
        return train_set, eval_set

    def _warm_up(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: Optional[List[nn.Module]],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: Optional[List[nn.Module]],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):  # pragma: no cover
        r"""
        Simple wrapper to individually warm up the model components.

        ``wide`` and ``deeptabular`` are warmed up with ``WarmUp.warm_all``.
        ``deeptext`` and ``deepimage`` use ``warm_all`` as well, or
        ``WarmUp.warm_gradual`` over the given layers when the corresponding
        ``*_gradual`` flag is set. Components the model was built without
        (``None``) are skipped.

        Parameters
        ----------
        loader: DataLoader
            Training data loader used for the warm up
        n_epochs: int
            Number of warm up epochs for the ``warm_all`` calls
        max_lr: float
            Maximum learning rate for the ``warm_all`` calls
        deeptext_gradual / deepimage_gradual: bool
            Whether to warm up the text / image component gradually
        deeptext_layers / deepimage_layers: List[nn.Module], Optional
            Layers to warm up gradually, in warm-up order
        deeptext_max_lr / deepimage_max_lr: float
            Maximum learning rate for the gradual warm up
        routine: str, Default='felbo'
            Gradual warm up routine, one of 'felbo' or 'howard'

        Raises
        ------
        ValueError
            If the model has a custom ``deephead``
        """
        if self.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
        # This is not the most elegant solution, but is a solution "in-between"
        # a non elegant one and re-factoring the whole code
        warmer = WarmUp(self._loss_fn, self.metric, self.method, self.verbose)

        # every component is optional, so only warm up the ones that exist
        if self.wide is not None:
            warmer.warm_all(self.wide, "wide", loader, n_epochs, max_lr)
        if self.deeptabular is not None:
            warmer.warm_all(self.deeptabular, "deeptabular", loader, n_epochs, max_lr)
        if self.deeptext is not None:
            if deeptext_gradual:
                warmer.warm_gradual(
                    self.deeptext,
                    "deeptext",
                    loader,
                    deeptext_max_lr,
                    deeptext_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deeptext, "deeptext", loader, n_epochs, max_lr)
        if self.deepimage is not None:
            if deepimage_gradual:
                warmer.warm_gradual(
                    self.deepimage,
                    "deepimage",
                    loader,
                    deepimage_max_lr,
                    deepimage_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deepimage, "deepimage", loader, n_epochs, max_lr)

    def _lr_scheduler_step(self, step_location: str):  # noqa: C901
959 960 961
        r"""
        Function to execute the learning rate schedulers steps.
        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
962
        must happen after training each bach durig training. On the other
963 964 965 966 967 968 969 970
        hand, if the  scheduler is not Cyclic, is expected to be called after
        validation.

        Parameters
        ----------
        step_location: Str
            Indicates where to run the lr_scheduler step
        """
J
jrzaurin 已提交
971 972
        if (
            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
973
            and self.cyclic_lr
J
jrzaurin 已提交
974 975 976 977 978 979 980 981 982
        ):
            if step_location == "on_batch_end":
                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
            elif step_location == "on_epoch_end":
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" not in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
983
        elif self.cyclic_lr:
J
jrzaurin 已提交
984 985 986 987 988 989 990 991 992 993 994 995 996 997 998
            if step_location == "on_batch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
            if step_location == "on_epoch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif step_location == "on_epoch_end":
            self.lr_scheduler.step()  # type: ignore
        else:
            pass

    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Run one optimisation step on a single training batch.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Batch of model inputs, keyed by input name, as yielded by the
            training ``DataLoader``
        target: Tensor
            Batch of targets
        batch_idx: int
            0-based position of the batch within the epoch. It is used to turn
            the accumulated loss into a running average.

        Returns
        -------
        score
            Output of ``self.metric`` for this batch, or ``None`` if no metric
            was set
        avg_loss: float
            Average training loss over the batches seen so far in the epoch
        """
        self.train()
        X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
        # cross_entropy (multiclass) takes the class indices as they come; the
        # other losses need float targets
        y = target.float() if self.method != "multiclass" else target
        y = y.to(device)

        self.optimizer.zero_grad()
        y_pred = self.forward(X)
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        # the running sum is reset by fit at the start of every epoch
        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        if self.metric is None:
            return None, avg_loss
        if self.method == "binary":
            score = self.metric(torch.sigmoid(y_pred), y)
        elif self.method == "multiclass":
            score = self.metric(F.softmax(y_pred, dim=1), y)
        else:
            # regression: the raw outputs are the predictions. Without this
            # branch 'score' would be unbound.
            score = self.metric(y_pred, y)
        return score, avg_loss

    def _validation_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Evaluate the model on a single validation batch (no gradients).

        Parameters
        ----------
        data: Dict[str, Tensor]
            Batch of model inputs, keyed by input name, as yielded by the
            validation ``DataLoader``
        target: Tensor
            Batch of targets
        batch_idx: int
            0-based position of the batch within the validation pass. It is
            used to turn the accumulated loss into a running average.

        Returns
        -------
        score
            Output of ``self.metric`` for this batch, or ``None`` if no metric
            was set
        avg_loss: float
            Average validation loss over the batches seen so far
        """
        self.eval()
        with torch.no_grad():
            X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
            # cross_entropy (multiclass) takes the class indices as they come;
            # the other losses need float targets
            y = target.float() if self.method != "multiclass" else target
            y = y.to(device)

            y_pred = self.forward(X)
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        if self.metric is None:
            return None, avg_loss
        if self.method == "binary":
            score = self.metric(torch.sigmoid(y_pred), y)
        elif self.method == "multiclass":
            score = self.metric(F.softmax(y_pred, dim=1), y)
        else:
            # regression: the raw outputs are the predictions. Without this
            # branch 'score' would be unbound.
            score = self.metric(y_pred, y)
        return score, avg_loss

    def _predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_tab: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""Hidden method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation

        NOTE(review): uses ``self.batch_size``, which is set by ``.fit()``.

        Returns
        -------
        preds_l: List[np.ndarray]
            One array of outputs per batch: sigmoid probabilities for
            ``binary``, softmax probabilities for ``multiclass``, and the raw
            outputs otherwise. The model is put back in training mode
            afterwards.
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            inputs = (
                ("X_wide", X_wide),
                ("X_tab", X_tab),
                ("X_text", X_text),
                ("X_img", X_img),
            )
            load_dict = {name: arr for name, arr in inputs if arr is not None}
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        # len(DataLoader) is the exact number of batches, including a final
        # partial one. `len(dataset) // batch_size + 1` over-counts by one when
        # the dataset size is a multiple of the batch size.
        test_steps = len(test_loader)

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for _, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = self.forward(X)
                    if self.method == "binary":
                        preds = torch.sigmoid(preds)
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds_l.append(preds.cpu().data.numpy())
        self.train()
        return preds_l

    @staticmethod
1096
    def _build_train_dict(X_wide, X_tab, X_text, X_img, target):
1097 1098 1099
        X_train = {"target": target}
        if X_wide is not None:
            X_train["X_wide"] = X_wide
1100 1101
        if X_tab is not None:
            X_train["X_tab"] = X_tab
1102 1103 1104 1105 1106 1107
        if X_text is not None:
            X_train["X_text"] = X_text
        if X_img is not None:
            X_train["X_img"] = X_img
        return X_train

J
jrzaurin 已提交
1108
    @staticmethod  # noqa: C901
J
jrzaurin 已提交
1109 1110
    def _check_model_components(
        wide,
1111
        deeptabular,
J
jrzaurin 已提交
1112 1113 1114 1115 1116 1117
        deeptext,
        deepimage,
        deephead,
        head_layers,
        head_dropout,
        pred_dim,
1118 1119
    ):

J
jrzaurin 已提交
1120 1121 1122 1123 1124 1125 1126
        if wide is not None:
            assert wide.wide_linear.weight.size(1) == pred_dim, (
                "the 'pred_dim' of the wide component ({}) must be equal to the 'pred_dim' "
                "of the deep component and the overall model itself ({})".format(
                    wide.wide_linear.weight.size(1), pred_dim
                )
            )
1127
        if deeptabular is not None and not hasattr(deeptabular, "output_dim"):
1128
            raise AttributeError(
1129
                "deeptabular model must have an 'output_dim' attribute. "
1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deephead is not None and head_layers is not None:
            raise ValueError(
                "both 'deephead' and 'head_layers' are not None. Use one of the other, but not both"
            )
1146 1147 1148 1149 1150 1151
        if (
            head_layers is not None
            and not deeptabular
            and not deeptext
            and not deepimage
        ):
1152 1153 1154 1155 1156 1157 1158
            raise ValueError(
                "if 'head_layers' is not None, at least one deep component must be used"
            )
        if head_layers is not None and head_dropout is not None:
            assert len(head_layers) == len(
                head_dropout
            ), "'head_layers' and 'head_dropout' must have the same length"
J
jrzaurin 已提交
1159 1160 1161
        if deephead is not None:
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = 0
1162 1163
            if deeptabular is not None:
                output_dim += deeptabular.output_dim
J
jrzaurin 已提交
1164 1165 1166 1167 1168 1169 1170 1171 1172 1173
            if deeptext is not None:
                output_dim += deeptext.output_dim
            if deepimage is not None:
                output_dim += deepimage.output_dim
            assert deephead_inp_feat == output_dim, (
                "if a custom 'deephead' is used its input features ({}) must be equal to "
                "the output features of the deep component ({})".format(
                    deephead_inp_feat, output_dim
                )
            )