wide_deep.py 49.0 KB
Newer Older
1
import os
2 3

import numpy as np
4
import torch
5 6
import torch.nn as nn
import torch.nn.functional as F
7 8 9
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
10

11
from ..losses import FocalLoss
12 13 14 15 16
from ._warmup import WarmUp
from ..metrics import Metric, MetricCallback, MultipleMetrics
from ..wdtypes import *
from ..callbacks import History, Callback, CallbackContainer
from .deep_dense import dense_layer
17
from ._wd_dataset import WideDeepDataset
18
from ..initializers import Initializer, MultipleInitializer
19 20
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_transforms import MultipleTransforms
21
from ._multiple_lr_scheduler import MultipleLRScheduler
J
jrzaurin 已提交
22

23 24 25
# import warnings


26
# Number of worker processes handed to the DataLoaders built in ``fit``.
# ``os.cpu_count()`` returns ``None`` when the count cannot be determined, so
# fall back to a single worker instead of passing ``None`` as ``num_workers``.
n_cpus = os.cpu_count() or 1
# Whether a CUDA device is available; ``compile`` moves the model to the GPU
# when this is True.
use_cuda = torch.cuda.is_available()


class WideDeep(nn.Module):
31 32 33
    r"""Main collector class that combines all ``Wide``, ``DeepDense``,
    ``DeepText`` and ``DeepImage`` models.

34 35
    There are two options to combine these models that correspond to the two
    architectures that ``pytorch-widedeep`` can build.
36 37 38 39 40 41

        - Directly connecting the output of the model components to an output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
          This FC-Head will combine the output from the ``DeepDense``, ``DeepText`` and
          ``DeepImage`` and will be then connected to the output neuron(s).
42 43 44

    Parameters
    ----------
45
    wide: nn.Module
46 47 48 49
        Wide model. We recommend using the ``Wide`` class in this package.
        However, it is possible to use a custom model as long as is consistent
        with the required architecture, see
        :class:`pytorch_widedeep.models.wide.Wide`
50
    deepdense: nn.Module
51 52 53
        `Deep dense` model comprised by the embeddings for the categorical
        features combined with numerical (also referred as continuous)
        features. We recommend using the ``DeepDense`` class in this package.
54
        However, it is possible to use a custom model as long as it is consistent with the required
55
        architecture. See :class:`pytorch_widedeep.models.deep_dense.DeepDense`.
56
    deeptext: nn.Module, Optional
57 58 59 60
        `Deep text` model for the text input. Must be an object of class
        ``DeepText`` or a custom model as long as is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepText`
61
    deepimage: nn.Module, Optional
62 63 64 65
        `Deep Image` model for the images input. Must be an object of class
        ``DeepImage`` or a custom model as long as is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepImage`
66
    deephead: nn.Module, Optional
67
        `Dense` model consisting in a stack of dense layers. The FC-Head.
68
    head_layers: List, Optional
69
        Alternatively, we can use ``head_layers`` to specify the sizes of the
70
        stacked dense layers in the fc-head e.g: ``[128, 64]``
71
    head_dropout: List, Optional
72
        Dropout between the layers in ``head_layers``. e.g: ``[0.5, 0.5]``
73
    head_batchnorm: bool, Optional
74
        Specifies if batch normalization should be included in the dense layers
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
    pred_dim: int
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or `number
        of classes` for multiclass classification.


    .. note:: With the exception of ``cyclic``, all attributes are direct assignations of
        the corresponding parameters used when calling ``compile``.  Therefore,
        see the parameters at
        :class:`pytorch_widedeep.models.wide_deep.WideDeep.compile` for a full
        list of the attributes of an instance of
        :class:`pytorch_widedeep.models.wide_deep.WideDeep`


    Attributes
    ----------
    cyclic: :obj:`bool`
        Attribute that indicates if any of the lr_schedulers is cyclic (i.e. ``CyclicLR`` or
        ``OneCycleLR``). See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
94 95 96 97 98 99 100 101 102 103 104


    .. note:: While I recommend using the ``Wide`` and ``DeepDense`` classes within
        this package when building the corresponding model components, it is very
        likely that the user will want to use custom text and image models. That
        is perfectly possible. Simply, build them and pass them as the
        corresponding parameters. Note that the custom models MUST return a last
        layer of activations (i.e. not the final prediction) so that  these
        activations are collected by ``WideDeep`` and combined accordingly. In
        addition, the models MUST also contain an attribute ``output_dim`` with
        the size of these last layers of activations. See for example
105 106 107
        :class:`pytorch_widedeep.models.deep_dense.DeepDense`

    """
J
jrzaurin 已提交
108

109
    def __init__(  # noqa: C901
        self,
        wide: Optional[nn.Module] = None,
        deepdense: Optional[nn.Module] = None,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_layers: Optional[List[int]] = None,
        head_dropout: Optional[List] = None,
        head_batchnorm: Optional[bool] = None,
        pred_dim: int = 1,
    ):
        """Store the model components and attach the output layer(s).

        Every layer that maps activations to the ``pred_dim`` output neuron(s)
        is created here, once, so that it is a registered sub-module: its
        parameters are returned by ``parameters()``, follow the model across
        devices and can be picked up by the optimizers.

        Three configurations are handled:

        - a custom ``deephead`` is passed: an ``nn.Linear`` projecting to
          ``pred_dim`` is stacked on top of it. The input width of that layer
          is the head's ``output_dim`` attribute when present, otherwise the
          ``out_features`` of the last ``nn.Linear`` registered in the head.
          If neither is available the head is used as is and must already
          return ``pred_dim`` columns.
        - ``head_layers`` is passed: a stack of dense layers followed by an
          output ``nn.Linear`` is built and stored as ``self.deephead``.
        - neither is passed: each deep component gets its own ``nn.Linear``
          projecting its ``output_dim`` to ``pred_dim``.
        """
        super().__init__()

        self._check_params(
            deepdense, deeptext, deepimage, deephead, head_layers, head_dropout
        )

        # read by ``forward`` to size the placeholder output and to validate
        # the width of the head output
        self.pred_dim = pred_dim

        # The main 5 components of the wide and deep assemble
        self.wide = wide
        self.deepdense = deepdense
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deephead is not None:
            # Custom head: find the width of its activations so the output
            # layer can be built (and registered) here rather than on every
            # forward pass.
            head_dim = getattr(self.deephead, "output_dim", None)
            if head_dim is None:
                linear_layers = [
                    m for m in self.deephead.modules() if isinstance(m, nn.Linear)
                ]
                if linear_layers:
                    head_dim = linear_layers[-1].out_features
            if head_dim is not None:
                self.deephead = nn.Sequential(
                    self.deephead, nn.Linear(head_dim, pred_dim)
                )
        elif head_layers is not None:
            # The FC-Head receives the concatenated activations of all the
            # deep components that are present.
            input_dim = 0
            for component in (self.deepdense, self.deeptext, self.deepimage):
                if component is not None:
                    input_dim += component.output_dim  # type: ignore
            layer_sizes = [input_dim] + list(head_layers)
            if not head_dropout:
                head_dropout = [0.0] * (len(layer_sizes) - 1)
            self.deephead = nn.Sequential()
            for idx, (n_in, n_out) in enumerate(
                zip(layer_sizes[:-1], layer_sizes[1:])
            ):
                self.deephead.add_module(
                    "head_layer_{}".format(idx),
                    dense_layer(n_in, n_out, head_dropout[idx], head_batchnorm),
                )
            self.deephead.add_module(
                "head_out", nn.Linear(layer_sizes[-1], pred_dim)
            )
        else:
            # No head: connect every deep component directly to the output
            # neuron(s).
            for name in ("deepdense", "deeptext", "deepimage"):
                component = getattr(self, name)
                if component is not None:
                    setattr(
                        self,
                        name,
                        nn.Sequential(
                            component,
                            nn.Linear(component.output_dim, pred_dim),  # type: ignore
                        ),
                    )

    def forward(self, X: Dict[str, Tensor]) -> Tensor:  # type: ignore  # noqa: C901
        """Run all the components present and add up their outputs.

        Parameters
        ----------
        X: Dict[str, Tensor]
            Inputs keyed by component name: ``"wide"``, ``"deepdense"``,
            ``"deeptext"`` and/or ``"deepimage"``. Only the keys of the
            components that are not ``None`` are read.

        Returns
        -------
        Tensor
            Sum of the wide output and the deep output.

        Raises
        ------
        ValueError
            If a head is used and its output does not have ``pred_dim``
            columns.
        """
        # Wide output: direct connection to the output neuron(s)
        if self.wide is not None:
            out = self.wide(X["wide"])
        else:
            # Zero placeholder created on the same device as the inputs so it
            # can be added to the deep outputs of a model living on the GPU.
            reference = next(iter(X.values()))
            out = torch.zeros(
                reference.size(0), self.pred_dim, device=reference.device
            )

        deep_outputs = [
            getattr(self, name)(X[name])
            for name in ("deepdense", "deeptext", "deepimage")
            if getattr(self, name) is not None
        ]

        # Deep output: either passed through the head first or connected
        # directly to the output neuron(s)
        if self.deephead is not None:
            deephead_out = self.deephead(torch.cat(deep_outputs, dim=1))
            # Explicit check: a plain addition would silently broadcast a
            # head output of the wrong width against a single-column output.
            if deephead_out.size(1) != self.pred_dim:
                raise ValueError(
                    "The output of 'deephead' has {} columns but 'pred_dim' is {}. "
                    "A custom 'deephead' must either expose an 'output_dim' attribute "
                    "or return 'pred_dim' columns".format(
                        deephead_out.size(1), self.pred_dim
                    )
                )
            return out + deephead_out

        for deep_out in deep_outputs:
            out = out + deep_out
        return out

212
    def compile(  # noqa: C901
        self,
        method: str,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        initializers: Optional[Dict[str, Initializer]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        with_focal_loss: bool = False,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        r"""Method to set the attributes that will be used during the
        training process.

        Parameters
        ----------
        method: str
            One of `regression`, `binary` or `multiclass`. The default when
            performing a `regression`, a `binary` classification or a
            `multiclass` classification is the `mean squared error
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.mse_loss>`_
            (MSE), `Binary Cross Entropy
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.binary_cross_entropy>`_
            (BCE) and `Cross Entropy
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.cross_entropy>`_
            (CE) respectively.
        optimizers: Union[Optimizer, Dict[str, Optimizer]], Optional, Default=AdamW
            - An instance of ``pytorch``'s ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
            - a dictionary where the keys are the model components (i.e.
              `'wide'`, `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`)  and
              the values are the corresponding optimizers. If multiple optimizers are used
              the  dictionary MUST contain an optimizer per model component.

            See `Pytorch optimizers <https://pytorch.org/docs/stable/optim.html>`_.
        lr_schedulers: Union[LRScheduler, Dict[str, LRScheduler]], Optional, Default=None
            - An instance of ``pytorch``'s ``LRScheduler`` object (e.g
              :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
              values are the corresponding learning rate schedulers.

            When ``optimizers`` is a dictionary, ``lr_schedulers`` must be
            either ``None`` or a dictionary as well.

            See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
        initializers: Dict[str, Initializer], Optional. Default=None
            Dict where the keys are the model components (i.e. `'wide'`,
            `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
            values are the corresponding initializers.
            See `Pytorch initializers <https://pytorch.org/docs/stable/nn.init.html>`_.
        transforms: List[Transforms], Optional. Default=None
            ``Transforms`` is a custom type. See
            :obj:`pytorch_widedeep.wdtypes`. List with
            :obj:`torchvision.transforms` to be applied to the image component
            of the model (i.e. ``deepimage``) See `torchvision transforms
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
        callbacks: List[Callback], Optional. Default=None
            Callbacks available are: ``ModelCheckpoint``, ``EarlyStopping``,
            and ``LRHistory``. The ``History`` callback is used by default.
            Callback classes (as opposed to instances) are instantiated with
            no arguments. See the ``Callbacks`` section in this documentation
            or :obj:`pytorch_widedeep.callbacks`
        metrics: List[Metric], Optional. Default=None
            Metrics available are: ``Accuracy``, ``Precision``, ``Recall``,
            ``FBetaScore`` and ``F1Score``.  See the ``Metrics`` section in
            this documentation or :obj:`pytorch_widedeep.metrics`
        class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
            - float indicating the weight of the minority class in binary classification
              problems (e.g. 9.)
            - a list or tuple with weights for the different classes in multiclass
              classification problems  (e.g. [1., 2., 3.]). The weights do
              not necessarily need to be normalised. If your loss function
              uses reduction='mean', the loss will be normalized by the sum
              of the corresponding weights for each element. If you are
              using reduction='none', you would have to take care of the
              normalization yourself. See `this discussion
              <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
        with_focal_loss: bool, Optional. Default=False
            Boolean indicating whether to use the Focal Loss for highly imbalanced problems.
            For details on the focal loss see the `original paper
            <https://arxiv.org/pdf/1708.02002.pdf>`_.
        alpha: float. Default=0.25
            Focal Loss alpha parameter.
        gamma: float. Default=2
            Focal Loss gamma parameter.
        verbose: int
            Setting it to 0 will print nothing during training.
        seed: int, Default=1
            Random seed to be used throughout all the methods

        Raises
        ------
        ValueError
            If ``optimizers`` is a dictionary and ``lr_schedulers`` is a
            single scheduler.
        TypeError
            If ``optimizers`` is neither ``None``, an ``Optimizer`` nor a
            dictionary.
        AssertionError
            If ``optimizers`` is a dictionary that lacks an entry for one of
            the model components.

        Example
        --------
        >>> import torch
        >>> from torchvision.transforms import ToTensor
        >>>
        >>> from pytorch_widedeep.callbacks import EarlyStopping, LRHistory
        >>> from pytorch_widedeep.initializers import KaimingNormal, KaimingUniform, Normal, Uniform
        >>> from pytorch_widedeep.models import DeepDenseResnet, DeepImage, DeepText, Wide, WideDeep
        >>> from pytorch_widedeep.optim import RAdam
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> deep_column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deepdense = DeepDenseResnet(blocks=[8, 4], deep_column_idx=deep_column_idx, embed_input=embed_input)
        >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
        >>> deepimage = DeepImage(pretrained=False)
        >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)
        >>>
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deepdense.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt = RAdam(model.deepimage.parameters())
        >>>
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {"wide": wide_opt, "deepdense": deep_opt, "deeptext": text_opt, "deepimage": img_opt}
        >>> schedulers = {"wide": wide_sch, "deepdense": deep_sch, "deeptext": text_sch, "deepimage": img_sch}
        >>> initializers = {"wide": Uniform, "deepdense": Normal, "deeptext": KaimingNormal, "deepimage": KaimingUniform}
        >>> transforms = [ToTensor]
        >>> callbacks = [LRHistory(n_epochs=4), EarlyStopping]
        >>> model.compile(method="regression", initializers=initializers, optimizers=optimizers,
        ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """

        # A dict of optimizers may be combined with no scheduler at all or
        # with a dict of schedulers, but not with a single scheduler.
        if (
            isinstance(optimizers, dict)
            and lr_schedulers is not None
            and not isinstance(lr_schedulers, dict)
        ):
            raise ValueError(
                "'parameters 'optimizers' and 'lr_schedulers' must have consistent type. "
                "(Optimizer, LRScheduler) or (Dict[str, Optimizer], Dict[str, LRScheduler]) "
                "Please, read the Documentation for more details"
            )

        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            # a single float is expanded into a two-element weight tensor
            self.class_weight = torch.tensor([1.0 - class_weight, class_weight])
        elif isinstance(class_weight, (tuple, list)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        if optimizers is None:
            self.optimizer: Union[Optimizer, MultipleOptimizer] = torch.optim.AdamW(
                self.parameters()
            )  # type: ignore
        elif isinstance(optimizers, Optimizer):
            self.optimizer = optimizers
        elif isinstance(optimizers, dict):
            # every registered model component needs its own optimizer
            for component_name, _ in self.named_children():
                assert (
                    component_name in optimizers
                ), "No optimizer found for {}".format(component_name)
            self.optimizer = MultipleOptimizer(optimizers)
        else:
            # fail here rather than leaving ``self.optimizer`` undefined
            raise TypeError(
                "'optimizers' must be None, an Optimizer or a Dict[str, Optimizer]. "
                "Got {}".format(type(optimizers).__name__)
            )

        if lr_schedulers is None:
            self.lr_scheduler, self.cyclic = None, False
        elif isinstance(lr_schedulers, LRScheduler):
            self.lr_scheduler: Union[LRScheduler, MultipleLRScheduler] = lr_schedulers
            # cyclic schedulers are recognised by their class name
            self.cyclic = "cycl" in self.lr_scheduler.__class__.__name__.lower()
        else:
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            self.cyclic = any(
                "cycl" in sc.__class__.__name__.lower()
                for sc in self.lr_scheduler._schedulers.values()
            )

        if transforms is not None:
            self.transforms: MultipleTransforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        # ``History`` is always the first callback
        self.history = History()
        self.callbacks: List = [self.history]
        if callbacks is not None:
            for callback in callbacks:
                # callbacks may be passed as classes; instantiate those
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

        if use_cuda:
            self.cuda()

418
    def fit(  # noqa: C901
J
jrzaurin 已提交
419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        patience: int = 10,
        warm_up: bool = False,
        warm_epochs: int = 4,
        warm_max_lr: float = 0.01,
        warm_deeptext_gradual: bool = False,
        warm_deeptext_max_lr: float = 0.01,
        warm_deeptext_layers: Optional[List[nn.Module]] = None,
        warm_deepimage_gradual: bool = False,
        warm_deepimage_max_lr: float = 0.01,
        warm_deepimage_layers: Optional[List[nn.Module]] = None,
        warm_routine: str = "howard",
    ):
443
        r"""Fit method. Must run after calling ``compile``
444 445 446

        Parameters
        ----------
447
        X_wide: np.ndarray, Optional. Default=None
448 449
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
450
        X_deep: np.ndarray, Optional. Default=None
451 452
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
453
        X_text: np.ndarray, Optional. Default=None
454 455
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
456
        X_img : np.ndarray, Optional. Default=None
457 458 459 460
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict[str, np.ndarray], Optional. Default=None
            Training dataset for the different model components. Keys are
461
            `X_wide`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
462
            the corresponding matrices.
463
        X_val: Dict, Optional. Default=None
464
            Validation dataset for the different model component. Keys are
465 466
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
467 468
        val_split: float, Optional. Default=None
            train/val split fraction
469 470
        target: np.ndarray, Optional. Default=None
            target values
471 472 473 474 475 476
        n_epochs: int, Default=1
            number of epochs
        validation_freq: int, Default=1
            epochs validation frequency
        batch_size: int, Default=32
        patience: int, Default=10
477 478
            Number of epochs without improving the target metric before
            the fit process stops
479
        warm_up: bool, Default=False
480 481 482
            warm up model components individually before the joined training
            starts.

483 484 485 486 487
            ``pytorch_widedeep`` implements 3 warm up routines.

            - Warm up all trainable layers at once. This routine is is
              inspired by the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
488 489 490 491 492 493
              Slanted Triangular learing (see `Leslie N. Smith paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``AdamW``.
494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509

            and two gradual warm up routines, where only certain layers are
            warmed up at each warm up step.

            - The so called `Felbo` gradual warm up rourine, inpired by the Felbo et al., 2017
              `DeepEmoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine based on the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the examples
            section in this documentation and the corresponding repo.
        warm_epochs: int, Default=4
            Number of warm up epochs for those model components that will NOT
            be gradually warmed up. Those components with gradual warm up
            follow their corresponding specific routine.
        warm_max_lr: float, Default=0.01
510
            Maximum learning rate during the Triangular Learning rate cycle
511 512
            for those model componenst that will NOT be gradually warmed up
        warm_deeptext_gradual: bool, Default=False
513 514
            Boolean indicating if the deeptext component will be warmed
            up gradually
515
        warm_deeptext_max_lr: float, Default=0.01
516 517
            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
518
        warm_deeptext_layers: List, Optional, Default=None
519
            List of :obj:`nn.Modules` that will be warmed up gradually.
520 521 522 523 524

            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

        warm_deepimage_gradual: bool, Default=False
525 526 527 528 529
            Boolean indicating if the deepimage component will be warmed
            up gradually
        warm_deepimage_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deepimage component
530
        warm_deepimage_layers: List, Optional, Default=None
531
            List of :obj:`nn.Modules` that will be warmed up gradually.
532

533 534 535
            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

536
        warm_routine: str, Default=`felbo`
537 538 539 540 541
            Warm up routine. On of `felbo` or `howard`. See the examples
            section in this documentation and the corresponding repo for
            details on how to use warm up routines

        Examples
542
        --------
543 544 545 546 547 548 549

        For a series of comprehensive examples please, see the `example
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_.
        folder in the repo

        For completion, here we include some `"fabricated"` examples, i.e. these assume
        you have already built and compiled the model
550

551 552

        >>> # Ex 1. using train input arrays directly and no validation
553
        >>> # model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256)
554

555 556

        >>> # Ex 2: using train input arrays directly and validation with val_split
557
        >>> # model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)
558

559 560

        >>> # Ex 3: using train dict and val_split
561 562
        >>> # X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': y}
        >>> # model.fit(X_train, n_epochs=10, batch_size=256, val_split=0.2)
563

564 565

        >>> # Ex 4: validation using training and validation dicts
566 567 568
        >>> # X_train = {'X_wide': X_wide_tr, 'X_deep': X_deep_tr, 'target': y_tr}
        >>> # X_val = {'X_wide': X_wide_val, 'X_deep': X_deep_val, 'target': y_val}
        >>> # model.fit(X_train=X_train, X_val=X_val n_epochs=10, batch_size=256)
569

570
        .. note:: :obj:`WideDeep` assumes that `X_wide`, `X_deep` and `target` ALWAYS exist, while
571 572 573 574 575
            `X_text` and `X_img` are optional

        .. note:: Either `X_train` or the three `X_wide`, `X_deep` and `target` must be passed to the
            fit method

576
        """
577

578
        self.batch_size = batch_size
J
jrzaurin 已提交
579 580 581 582 583 584
        train_set, eval_set = self._train_val_split(
            X_wide, X_deep, X_text, X_img, X_train, X_val, val_split, target
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
585 586
        if warm_up:
            # warm up...
J
jrzaurin 已提交
587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604
            self._warm_up(
                train_loader,
                warm_epochs,
                warm_max_lr,
                warm_deeptext_gradual,
                warm_deeptext_layers,
                warm_deeptext_max_lr,
                warm_deepimage_gradual,
                warm_deepimage_layers,
                warm_deepimage_max_lr,
                warm_routine,
            )
        train_steps = len(train_loader)
        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        if self.verbose:
            print("Training")
605
        for epoch in range(n_epochs):
606
            # train step...
J
jrzaurin 已提交
607
            epoch_logs: Dict[str, float] = {}
608
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
J
jrzaurin 已提交
609
            self.train_running_loss = 0.0
610
            with trange(train_steps, disable=self.verbose != 1) as t:
J
jrzaurin 已提交
611 612
                for batch_idx, (data, target) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
613 614 615 616 617 618
                    score, train_loss = self._training_step(data, target, batch_idx)
                    if score is not None:
                        t.set_postfix(
                            metrics={k: np.round(v, 4) for k, v in score.items()},
                            loss=train_loss,
                        )
619
                    else:
620
                        t.set_postfix(loss=train_loss)
J
jrzaurin 已提交
621 622
                    if self.lr_scheduler:
                        self._lr_scheduler_step(step_location="on_batch_end")
623
                    self.callback_container.on_batch_end(batch=batch_idx)
J
jrzaurin 已提交
624
            epoch_logs["train_loss"] = train_loss
625 626 627 628
            if score is not None:
                for k, v in score.items():
                    log_k = "_".join(["train", k])
                    epoch_logs[log_k] = v
629
            # eval step...
J
jrzaurin 已提交
630
            if epoch % validation_freq == (validation_freq - 1):
631
                if eval_set is not None:
J
jrzaurin 已提交
632 633 634 635 636 637 638 639
                    eval_loader = DataLoader(
                        dataset=eval_set,
                        batch_size=batch_size,
                        num_workers=n_cpus,
                        shuffle=False,
                    )
                    eval_steps = len(eval_loader)
                    self.valid_running_loss = 0.0
640
                    with trange(eval_steps, disable=self.verbose != 1) as v:
J
jrzaurin 已提交
641 642
                        for i, (data, target) in zip(v, eval_loader):
                            v.set_description("valid")
643 644 645 646 647 648 649 650
                            score, val_loss = self._validation_step(data, target, i)
                            if score is not None:
                                v.set_postfix(
                                    metrics={
                                        k: np.round(v, 4) for k, v in score.items()
                                    },
                                    loss=val_loss,
                                )
651
                            else:
652
                                v.set_postfix(loss=val_loss)
J
jrzaurin 已提交
653
                    epoch_logs["val_loss"] = val_loss
654 655 656 657
                    if score is not None:
                        for k, v in score.items():
                            log_k = "_".join(["val", k])
                            epoch_logs[log_k] = v
J
jrzaurin 已提交
658 659 660
            if self.lr_scheduler:
                self._lr_scheduler_step(step_location="on_epoch_end")
            #  log and check if early_stop...
661
            self.callback_container.on_epoch_end(epoch, epoch_logs)
662
            if self.early_stop:
J
jrzaurin 已提交
663
                self.callback_container.on_train_end(epoch_logs)
664
                break
J
jrzaurin 已提交
665
            self.callback_container.on_train_end(epoch_logs)
666 667
        self.train()

J
jrzaurin 已提交
668 669
    def predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predictions for the input data.

        The per-batch outputs of ``_predict`` are stacked and post-processed
        according to ``self.method``:

            - ``'regression'``: the stacked values, as a 1-dim array.

            - ``'binary'``: ``1`` where the stacked value is greater than
              0.5, ``0`` otherwise.

            - ``'multiclass'``: the index of the largest value in each row.

        For any other value of ``self.method`` nothing is returned.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_deep: np.ndarray, Optional. Default=None
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict[str, np.ndarray], Optional. Default=None
            Testing dataset for the different model components. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'` the
            values are the corresponding matrices.

        .. note:: WideDeep assumes that `X_wide`, `X_deep` and `target` ALWAYS exist,
            while `X_text` and `X_img` are optional.
        """
        batch_preds = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        task = self.method
        if task == "multiclass":
            return np.argmax(np.vstack(batch_preds), 1)
        if task in ("regression", "binary"):
            # both tasks produce a single output column: drop it
            flat = np.vstack(batch_preds).squeeze(1)
            if task == "regression":
                return flat
            return (flat > 0.5).astype("int")
711

J
jrzaurin 已提交
712 713
    def predict_proba(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for binary
        and multiclass methods.

        For ``'binary'`` the result has two columns: column 1 holds the
        stacked output of ``_predict`` and column 0 its complement. For
        ``'multiclass'`` the stacked output of ``_predict`` is returned
        unchanged. For any other method nothing is returned.

        The parameters are the same as those of the ``predict`` method.
        """
        batch_preds = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "multiclass":
            return np.vstack(batch_preds)
        if self.method == "binary":
            positive = np.vstack(batch_preds).squeeze(1)
            # default (float64) array, filled column by column
            probs = np.empty([positive.shape[0], 2])
            probs[:, 1] = positive
            probs[:, 0] = 1 - positive
            return probs
732

J
jrzaurin 已提交
733 734
    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:  # pragma: no cover
        r"""Returns the learned embeddings for the categorical features passed through
        ``deepdense``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` Attribute of the class
        :obj:`DensePreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.DensePreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict[str, Dict[str, int]]
            Dictionary containing the categorical encodings. It is indexed
            first by column name and then by category, and the innermost
            values are the row indexes of the embedding matrix, e.g.
            ``{'education': {'HS-grad': 1, 'Bachelors': 2}}``

        Returns
        -------
        Dict[str, np.ndarray]
            Dictionary mapping each category of ``col_name`` to its embedding
            vector

        Raises
        ------
        ValueError
            If the model has no parameter whose name contains both
            ``'embed_layers'`` and ``col_name``
        KeyError
            If ``col_name`` is not a key of ``cat_encoding_dict``

        Examples
        --------

        For a series of comprehensive examples please, see the `example
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_.
        folder in the repo

        For completion, here we include a `"fabricated"` example, i.e.
        assuming we have already trained the model, that we have the
        categorical encodings in a dictionary name ``encoding_dict``, and that
        there is a column called `'education'`:

        >>> # model.get_embeddings(col_name='education', cat_encoding_dict=encoding_dict)
        """
        embed_mtx = None
        # NOTE(review): the lookup is a substring match, so a column whose name
        # is contained in another column's name can match several parameters.
        # As before, the last match wins -- confirm this is the intended rule.
        for n, p in self.named_parameters():
            if "embed_layers" in n and col_name in n:
                embed_mtx = p.cpu().data.numpy()
        if embed_mtx is None:
            # without this guard the lookup below fails with an
            # UnboundLocalError that says nothing about the actual problem
            raise ValueError(
                "no embedding layer found for column '{}'".format(col_name)
            )
        encoding_dict = cat_encoding_dict[col_name]
        return {category: embed_mtx[idx] for category, idx in encoding_dict.items()}

J
jrzaurin 已提交
776
    def _loss_fn(self, y_pred: Tensor, y_true: Tensor) -> Tensor:  # type: ignore
777 778
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
J
jrzaurin 已提交
779
        if self.method == "regression":
780
            return F.mse_loss(y_pred, y_true.view(-1, 1))
J
jrzaurin 已提交
781
        if self.method == "binary":
782
            return F.binary_cross_entropy_with_logits(
J
jrzaurin 已提交
783 784 785
                y_pred, y_true.view(-1, 1), weight=self.class_weight
            )
        if self.method == "multiclass":
786 787
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

788
    def _train_val_split(  # noqa: C901
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
    ):
        r"""Builds the training dataset and, when requested, the evaluation
        dataset.

        Three situations are handled, in this order:

            - ``X_val`` is passed: ``X_train`` and ``X_val`` are used as they
              are.

            - ``val_split`` is passed: the training data is split internally
              using :obj:`train_test_split`. The split is stratified on the
              target unless the method is ``'regression'``.

            - Otherwise: all the data is used for training and there is no
              evaluation dataset.

        When ``X_train`` is missing or empty, the training dictionary is built
        from ``X_wide``, ``X_deep``, ``X_text``, ``X_img`` and ``target``. For
        parameter information, please, see the .fit() method documentation

        Returns
        -------
        train_set: WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`
        eval_set : WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`, or ``None`` if neither
            ``X_val`` nor ``val_split`` is passed. See
            :class:`pytorch_widedeep.models._wd_dataset`
        """
        # 1. an explicit validation set: nothing to split
        if X_val is not None:
            assert (
                X_train is not None
            ), "if the validation set is passed as a dictionary, the training set must also be a dictionary"
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
            return train_set, eval_set

        if not X_train:
            X_train = self._build_train_dict(X_wide, X_deep, X_text, X_img, target)

        # 2. no validation at all
        if val_split is None:
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            return train_set, None

        # 3. internal split: split the row indexes along with the target so
        # that every input component can be sliced consistently
        labels = X_train["target"]
        y_tr, y_val, idx_tr, idx_val = train_test_split(
            labels,
            np.arange(len(labels)),
            test_size=val_split,
            stratify=labels if self.method != "regression" else None,
        )
        tr_dict, val_dict = {"target": y_tr}, {"target": y_val}
        for key in ("X_wide", "X_deep", "X_text", "X_img"):
            if key in X_train:
                component = X_train[key]
                tr_dict[key], val_dict[key] = component[idx_tr], component[idx_val]
        train_set = WideDeepDataset(**tr_dict, transforms=self.transforms)  # type: ignore
        eval_set = WideDeepDataset(**val_dict, transforms=self.transforms)  # type: ignore
        return train_set, eval_set

J
jrzaurin 已提交
863 864 865 866 867 868 869 870 871 872 873 874
    def _warm_up(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):  # pragma: no cover
        r"""
        Simple wrapper to warm up the model components one at a time.

        ``wide`` and ``deepdense`` are always warmed up as a whole. ``deeptext``
        and ``deepimage``, when present, are warmed up either gradually (using
        their own layers, maximum learning rate and ``routine``) or as a whole,
        depending on their ``*_gradual`` flag.

        Raises
        ------
        ValueError
            If the model has a ``deephead``
        """
        if self.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
        # Not the most elegant solution, but a compromise between a non
        # elegant one and re-factoring the whole code
        warmer = WarmUp(self._loss_fn, self.metric, self.method, self.verbose)
        for name in ("wide", "deepdense"):
            warmer.warm_all(getattr(self, name), name, loader, n_epochs, max_lr)

        optional_components = (
            ("deeptext", deeptext_gradual, deeptext_layers, deeptext_max_lr),
            ("deepimage", deepimage_gradual, deepimage_layers, deepimage_max_lr),
        )
        for name, gradual, layers, component_max_lr in optional_components:
            component = getattr(self, name)
            if not component:
                continue
            if gradual:
                warmer.warm_gradual(
                    component, name, loader, component_max_lr, layers, routine
                )
            else:
                warmer.warm_all(component, name, loader, n_epochs, max_lr)
912

913
    def _lr_scheduler_step(self, step_location: str):  # noqa: C901
914 915 916
        r"""
        Function to execute the learning rate schedulers steps.
        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
917
        must happen after training each bach durig training. On the other
918 919 920 921 922 923 924 925
        hand, if the  scheduler is not Cyclic, is expected to be called after
        validation.

        Parameters
        ----------
        step_location: Str
            Indicates where to run the lr_scheduler step
        """
J
jrzaurin 已提交
926 927 928 929 930 931 932 933 934 935 936 937
        if (
            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
            and self.cyclic
        ):
            if step_location == "on_batch_end":
                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
            elif step_location == "on_epoch_end":
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" not in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
938
        elif self.cyclic:
J
jrzaurin 已提交
939 940 941 942 943 944 945 946 947 948 949 950 951 952 953
            if step_location == "on_batch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
            if step_location == "on_epoch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif step_location == "on_epoch_end":
            self.lr_scheduler.step()  # type: ignore
        else:
            pass

    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Runs one optimization step on a batch.

        The model is set in training mode, the batch is moved to the GPU if
        one is available, and a forward pass, a backward pass and an
        optimizer step are run. The batch loss is added to
        ``self.train_running_loss``.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Inputs for the model components, passed as they are to ``forward``
        target: Tensor
            Target values. They are cast to float unless the method is
            ``'multiclass'``
        batch_idx: int
            Zero-based index of the batch within the epoch, used to average
            the running loss

        Returns
        -------
        score
            Output of ``self.metric`` for this batch, or ``None`` if there is
            no metric or the method is neither ``'binary'`` nor
            ``'multiclass'``
        avg_loss: float
            Running loss divided by the number of batches processed so far
        """
        self.train()
        X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
        y = target.float() if self.method != "multiclass" else target
        y = y.cuda() if use_cuda else y

        self.optimizer.zero_grad()
        y_pred = self.forward(X)
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        # ``score`` must be bound on every path: with a metric and a method
        # other than binary/multiclass (e.g. regression) it used to be
        # undefined when reaching the return statement
        score = None
        if self.metric is not None:
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            elif self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
        return score, avg_loss

J
jrzaurin 已提交
977
    def _validation_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Evaluates the model on a batch without computing gradients.

        The model is set in evaluation mode, the batch is moved to the GPU if
        one is available and a forward pass is run. The batch loss is added
        to ``self.valid_running_loss``.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Inputs for the model components, passed as they are to ``forward``
        target: Tensor
            Target values. They are cast to float unless the method is
            ``'multiclass'``
        batch_idx: int
            Zero-based index of the batch within the evaluation loop, used to
            average the running loss

        Returns
        -------
        score
            Output of ``self.metric`` for this batch, or ``None`` if there is
            no metric or the method is neither ``'binary'`` nor
            ``'multiclass'``
        avg_loss: float
            Running loss divided by the number of batches processed so far
        """
        self.eval()
        with torch.no_grad():
            X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
            y = target.float() if self.method != "multiclass" else target
            y = y.cuda() if use_cuda else y

            y_pred = self.forward(X)
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        # ``score`` must be bound on every path: with a metric and a method
        # other than binary/multiclass (e.g. regression) it used to be
        # undefined when reaching the return statement
        score = None
        if self.metric is not None:
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            elif self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
        return score, avg_loss

J
jrzaurin 已提交
999 1000
    def _predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""Hidden method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation

        The data is loaded in batches of ``self.batch_size`` (set by the
        ``fit`` method) and passed through the model in evaluation mode and
        without computing gradients. A sigmoid is applied to the outputs if
        the method is ``'binary'`` and a softmax if it is ``'multiclass'``.
        The model is set back in training mode before returning.

        Returns
        -------
        List
            One numpy array of predictions per batch
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {"X_wide": X_wide, "X_deep": X_deep}
            if X_text is not None:
                load_dict.update({"X_text": X_text})
            if X_img is not None:
                load_dict.update({"X_img": X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        # the loader knows its exact number of batches. ``len(dataset) //
        # batch_size + 1`` counts one batch too many whenever the dataset size
        # is a multiple of the batch size, leaving the progress bar unfinished
        test_steps = len(test_loader)

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                # iterating over ``t`` is what advances the progress bar
                for _, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = self.forward(X)
                    if self.method == "binary":
                        preds = torch.sigmoid(preds)
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds = preds.cpu().data.numpy()
                    preds_l.append(preds)
        self.train()
        return preds_l
1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058

    @staticmethod
    def _build_train_dict(X_wide, X_deep, X_text, X_img, target):
        X_train = {"target": target}
        if X_wide is not None:
            X_train["X_wide"] = X_wide
        if X_deep is not None:
            X_train["X_deep"] = X_deep
        if X_text is not None:
            X_train["X_text"] = X_text
        if X_img is not None:
            X_train["X_img"] = X_img
        return X_train

J
jrzaurin 已提交
1059
    @staticmethod  # noqa: C901
1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090
    def _check_params(
        deepdense, deeptext, deepimage, deephead, head_layers, head_dropout
    ):

        if deepdense is not None and not hasattr(deepdense, "output_dim"):
            raise AttributeError(
                "deepdense model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deephead is not None and head_layers is not None:
            raise ValueError(
                "both 'deephead' and 'head_layers' are not None. Use one of the other, but not both"
            )
        if head_layers is not None and not deepdense and not deeptext and not deepimage:
            raise ValueError(
                "if 'head_layers' is not None, at least one deep component must be used"
            )
        if head_layers is not None and head_dropout is not None:
            assert len(head_layers) == len(
                head_dropout
            ), "'head_layers' and 'head_dropout' must have the same length"
J
jrzaurin 已提交
1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105
        if deephead is not None:
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = 0
            if deepdense is not None:
                output_dim += deepdense.output_dim
            if deeptext is not None:
                output_dim += deeptext.output_dim
            if deepimage is not None:
                output_dim += deepimage.output_dim
            assert deephead_inp_feat == output_dim, (
                "if a custom 'deephead' is used its input features ({}) must be equal to "
                "the output features of the deep component ({})".format(
                    deephead_inp_feat, output_dim
                )
            )