wide_deep.py 49.3 KB
Newer Older
1
import os
2 3

import numpy as np
4
import torch
5 6
import torch.nn as nn
import torch.nn.functional as F
7 8 9
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
10

11
from ..losses import FocalLoss
12 13 14 15 16
from ._warmup import WarmUp
from ..metrics import Metric, MetricCallback, MultipleMetrics
from ..wdtypes import *
from ..callbacks import History, Callback, CallbackContainer
from .deep_dense import dense_layer
17
from ._wd_dataset import WideDeepDataset
18
from ..initializers import Initializer, MultipleInitializer
19 20
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_transforms import MultipleTransforms
21
from ._multiple_lr_scheduler import MultipleLRScheduler
J
jrzaurin 已提交
22

23
n_cpus = os.cpu_count()
24

25
use_cuda = torch.cuda.is_available()
26
device = torch.device("cuda" if use_cuda else "cpu")
27 28 29


class WideDeep(nn.Module):
30 31 32
    r"""Main collector class that combines all ``Wide``, ``DeepDense``,
    ``DeepText`` and ``DeepImage`` models.

33 34
    There are two options to combine these models that correspond to the two
    architectures that ``pytorch-widedeep`` can build.
35 36 37 38 39 40

        - Directly connecting the output of the model components to an output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
          This FC-Head will combine the output from the ``DeepDense``, ``DeepText`` and
          ``DeepImage`` and will be then connected to the output neuron(s).
41 42 43

    Parameters
    ----------
44
    wide: nn.Module
45 46 47 48
        Wide model. We recommend using the ``Wide`` class in this package.
        However, it is possible to use a custom model as long as is consistent
        with the required architecture, see
        :class:`pytorch_widedeep.models.wide.Wide`
49
    deepdense: nn.Module
50 51 52
        `Deep dense` model comprised by the embeddings for the categorical
        features combined with numerical (also referred as continuous)
        features. We recommend using the ``DeepDense`` class in this package.
53
        However, a custom model can be used as long as it is consistent with the required
54
        architecture. See :class:`pytorch_widedeep.models.deep_dense.DeepDense`.
55
    deeptext: nn.Module, Optional
56 57 58 59
        `Deep text` model for the text input. Must be an object of class
        ``DeepText`` or a custom model as long as is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepText`
60
    deepimage: nn.Module, Optional
61 62 63 64
        `Deep Image` model for the images input. Must be an object of class
        ``DeepImage`` or a custom model as long as is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepImage`
65
    deephead: nn.Module, Optional
66
        `Dense` model consisting in a stack of dense layers. The FC-Head.
67
    head_layers: List, Optional
68
        Alternatively, we can use ``head_layers`` to specify the sizes of the
69
        stacked dense layers in the fc-head e.g: ``[128, 64]``
70
    head_dropout: List, Optional
71
        Dropout between the layers in ``head_layers``. e.g: ``[0.5, 0.5]``
72
    head_batchnorm: bool, Optional
73
        Specifies if batch normalization should be included in the dense layers
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
    pred_dim: int
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or `number
        of classes` for multiclass classification.


    .. note:: With the exception of ``cyclic``, all attributes are direct assignations of
        the corresponding parameters used when calling ``compile``.  Therefore,
        see the parameters at
        :class:`pytorch_widedeep.models.wide_deep.WideDeep.compile` for a full
        list of the attributes of an instance of
        :class:`pytorch_widedeep.models.wide_deep.Wide`


    Attributes
    ----------
    cyclic: :obj:`bool`
        Attribute that indicates if any of the lr_schedulers is cyclic (i.e. ``CyclicLR`` or
        ``OneCycleLR``). See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
93 94 95 96 97 98 99 100 101 102 103


    .. note:: While I recommend using the ``Wide`` and ``DeepDense`` classes within
        this package when building the corresponding model components, it is very
        likely that the user will want to use custom text and image models. That
        is perfectly possible. Simply, build them and pass them as the
        corresponding parameters. Note that the custom models MUST return a last
        layer of activations (i.e. not the final prediction) so that  these
        activations are collected by ``WideDeep`` and combined accordingly. In
        addition, the models MUST also contain an attribute ``output_dim`` with
        the size of these last layers of activations. See for example
104 105 106
        :class:`pytorch_widedeep.models.deep_dense.DeepDense`

    """
J
jrzaurin 已提交
107

108
    def __init__(  # noqa: C901
J
jrzaurin 已提交
109
        self,
110 111
        wide: Optional[nn.Module] = None,
        deepdense: Optional[nn.Module] = None,
J
jrzaurin 已提交
112 113 114 115 116 117
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_layers: Optional[List[int]] = None,
        head_dropout: Optional[List] = None,
        head_batchnorm: Optional[bool] = None,
118
        pred_dim: int = 1,
J
jrzaurin 已提交
119
    ):
120

121
        super(WideDeep, self).__init__()
122

J
jrzaurin 已提交
123 124 125 126 127 128 129 130 131
        self._check_model_components(
            wide,
            deepdense,
            deeptext,
            deepimage,
            deephead,
            head_layers,
            head_dropout,
            pred_dim,
132
        )
133

134 135 136
        # required as attribute just in case we pass a deephead
        self.pred_dim = pred_dim

137
        # The main 5 components of the wide and deep assemble
138 139
        self.wide = wide
        self.deepdense = deepdense
J
jrzaurin 已提交
140
        self.deeptext = deeptext
141
        self.deepimage = deepimage
142 143 144 145
        self.deephead = deephead

        if self.deephead is None:
            if head_layers is not None:
146 147 148
                input_dim = 0
                if self.deepdense is not None:
                    input_dim += self.deepdense.output_dim  # type:ignore
M
Minjin Choi 已提交
149
                if self.deeptext is not None:
150
                    input_dim += self.deeptext.output_dim  # type:ignore
M
Minjin Choi 已提交
151
                if self.deepimage is not None:
152
                    input_dim += self.deepimage.output_dim  # type:ignore
153
                head_layers = [input_dim] + head_layers
J
jrzaurin 已提交
154 155
                if not head_dropout:
                    head_dropout = [0.0] * (len(head_layers) - 1)
156 157 158
                self.deephead = nn.Sequential()
                for i in range(1, len(head_layers)):
                    self.deephead.add_module(
J
jrzaurin 已提交
159 160 161 162 163 164 165 166 167
                        "head_layer_{}".format(i - 1),
                        dense_layer(
                            head_layers[i - 1],
                            head_layers[i],
                            head_dropout[i - 1],
                            head_batchnorm,
                        ),
                    )
                self.deephead.add_module(
168
                    "head_out", nn.Linear(head_layers[-1], pred_dim)
J
jrzaurin 已提交
169
                )
170
            else:
171 172 173 174
                if self.deepdense is not None:
                    self.deepdense = nn.Sequential(
                        self.deepdense, nn.Linear(self.deepdense.output_dim, pred_dim)  # type: ignore
                    )
175 176
                if self.deeptext is not None:
                    self.deeptext = nn.Sequential(
177
                        self.deeptext, nn.Linear(self.deeptext.output_dim, pred_dim)  # type: ignore
J
jrzaurin 已提交
178
                    )
179 180
                if self.deepimage is not None:
                    self.deepimage = nn.Sequential(
181
                        self.deepimage, nn.Linear(self.deepimage.output_dim, pred_dim)  # type: ignore
J
jrzaurin 已提交
182
                    )
183 184
        # else:
        #     self.deephead
185

186
    def forward(self, X: Dict[str, Tensor]) -> Tensor:  # type: ignore  # noqa: C901
187

188
        # Wide output: direct connection to the output neuron(s)
189 190 191 192
        if self.wide is not None:
            out = self.wide(X["wide"])
        else:
            batch_size = X[list(X.keys())[0]].size(0)
193
            out = torch.zeros(batch_size, self.pred_dim).to(device)
194 195 196 197

        # Deep output: either connected directly to the output neuron(s) or
        # passed through a head first
        if self.deephead:
198 199 200
            if self.deepdense is not None:
                deepside = self.deepdense(X["deepdense"])
            else:
201
                deepside = torch.FloatTensor().to(device)
202
            if self.deeptext is not None:
J
jrzaurin 已提交
203
                deepside = torch.cat([deepside, self.deeptext(X["deeptext"])], axis=1)  # type: ignore
204
            if self.deepimage is not None:
J
jrzaurin 已提交
205
                deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], axis=1)  # type: ignore
206
            deephead_out = self.deephead(deepside)
207 208
            deepside_linear = nn.Linear(deephead_out.size(1), self.pred_dim).to(device)
            return out.add_(deepside_linear(deephead_out))
209
        else:
210 211
            if self.deepdense is not None:
                out.add_(self.deepdense(X["deepdense"]))
212
            if self.deeptext is not None:
213
                out.add_(self.deeptext(X["deeptext"]))
214
            if self.deepimage is not None:
215
                out.add_(self.deepimage(X["deepimage"]))
216 217
            return out

218
    def compile(  # noqa: C901
J
jrzaurin 已提交
219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
        self,
        method: str,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        initializers: Optional[Dict[str, Initializer]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        with_focal_loss: bool = False,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
234
        r"""Method to set the of attributes that will be used during the
235
        training process.
236 237 238

        Parameters
        ----------
239
        method: str
240 241 242 243 244 245 246 247 248
            One of `regression`, `binary` or `multiclass`. The default when
            performing a `regression`, a `binary` classification or a
            `multiclass` classification is the `mean squared error
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.mse_loss>`_
            (MSE), `Binary Cross Entropy
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.binary_cross_entropy>`_
            (BCE) and `Cross Entropy
            <https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.cross_entropy>`_
            (CE) respectively.
249
        optimizers: Union[Optimizer, Dict[str, Optimizer]], Optional, Default=AdamW
250 251 252 253 254
            - An instance of ``pytorch``'s ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
            - a dictionary where there keys are the model components (i.e.
              `'wide'`, `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`)  and
              the values are the corresponding optimizers. If multiple optimizers are used
              the  dictionary MUST contain an optimizer per model component.
255 256 257

            See `Pytorch optimizers <https://pytorch.org/docs/stable/optim.html>`_.
        lr_schedulers: Union[LRScheduler, Dict[str, LRScheduler]], Optional, Default=None
258 259 260 261
            - An instance of ``pytorch``'s ``LRScheduler`` object (e.g
              :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
            - a dictionary where there keys are the model componenst (i.e. `'wide'`,
              `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
262 263 264 265
              values are the corresponding learning rate schedulers.

            See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
        initializers: Dict[str, Initializer], Optional. Default=None
266 267
            Dict where there keys are the model components (i.e. `'wide'`,
            `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
268
            values are the corresponding initializers.
269 270
            See `Pytorch initializers <https://pytorch.org/docs/stable/nn.init.html>`_.
        transforms: List[Transforms], Optional. Default=None
271 272 273 274 275
            ``Transforms`` is a custom type. See
            :obj:`pytorch_widedeep.wdtypes`. List with
            :obj:`torchvision.transforms` to be applied to the image component
            of the model (i.e. ``deepimage``) See `torchvision transforms
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
276
        callbacks: List[Callback], Optional. Default=None
277 278 279 280
            Callbacks available are: ``ModelCheckpoint``, ``EarlyStopping``,
            and ``LRHistory``. The ``History`` callback is used by default.
            See the ``Callbacks`` section in this documentation or
            :obj:`pytorch_widedeep.callbacks`
281
        metrics: List[Metric], Optional. Default=None
282 283 284
            Metrics available are: ``Accuracy``, ``Precision``, ``Recall``,
            ``FBetaScore`` and ``F1Score``.  See the ``Metrics`` section in
            this documentation or :obj:`pytorch_widedeep.metrics`
285 286 287 288 289 290 291 292 293 294 295 296 297
        class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
            - float indicating the weight of the minority class in binary classification
              problems (e.g. 9.)
            - a list or tuple with weights for the different classes in multiclass
              classification problems  (e.g. [1., 2., 3.]). The weights do
              not neccesarily need to be normalised. If your loss function
              uses reduction='mean', the loss will be normalized by the sum
              of the corresponding weights for each element. If you are
              using reduction='none', you would have to take care of the
              normalization yourself. See `this discussion
              <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
        with_focal_loss: bool, Optional. Default=False
            Boolean indicating whether to use the Focal Loss for highly imbalanced problems.
298 299
            For details on the focal loss see the `original paper
            <https://arxiv.org/pdf/1708.02002.pdf>`_.
300 301 302 303 304
        alpha: float. Default=0.25
            Focal Loss alpha parameter.
        gamma: float. Default=2
            Focal Loss gamma parameter.
        verbose: int
305
            Setting it to 0 will print nothing during training.
306
        seed: int, Default=1
307
            Random seed to be used throughout all the methods
308 309 310

        Example
        --------
311 312 313 314 315 316
        >>> import torch
        >>> from torchvision.transforms import ToTensor
        >>>
        >>> from pytorch_widedeep.callbacks import EarlyStopping, LRHistory
        >>> from pytorch_widedeep.initializers import KaimingNormal, KaimingUniform, Normal, Uniform
        >>> from pytorch_widedeep.models import DeepDenseResnet, DeepImage, DeepText, Wide, WideDeep
317
        >>> from pytorch_widedeep.optim import RAdam
318 319 320 321 322 323
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> deep_column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deepdense = DeepDenseResnet(blocks=[8, 4], deep_column_idx=deep_column_idx, embed_input=embed_input)
        >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
        >>> deepimage = DeepImage(pretrained=False)
324
        >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)
325
        >>>
326 327 328
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deepdense.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
329 330
        >>> img_opt = RAdam(model.deepimage.parameters())
        >>>
331 332 333
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
334 335 336 337 338 339 340 341
        >>> img_sch = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {"wide": wide_opt, "deepdense": deep_opt, "deeptext": text_opt, "deepimage": img_opt}
        >>> schedulers = {"wide": wide_sch, "deepdense": deep_sch, "deeptext": text_sch, "deepimage": img_sch}
        >>> initializers = {"wide": Uniform, "deepdense": Normal, "deeptext": KaimingNormal, "deepimage": KaimingUniform}
        >>> transforms = [ToTensor]
        >>> callbacks = [LRHistory(n_epochs=4), EarlyStopping]
        >>> model.compile(method="regression", initializers=initializers, optimizers=optimizers,
        ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
342
        """
343 344 345

        if isinstance(optimizers, Dict) and not isinstance(lr_schedulers, Dict):
            raise ValueError(
J
jrzaurin 已提交
346 347 348
                "''optimizers' and 'lr_schedulers' must have consistent type: "
                "(Optimizer and LRScheduler) or (Dict[str, Optimizer] and Dict[str, LRScheduler]) "
                "Please, read the documentation or see the examples for more details"
349 350
            )

351
        self.verbose = verbose
352
        self.seed = seed
353
        self.early_stop = False
354
        self.method = method
355
        self.with_focal_loss = with_focal_loss
J
jrzaurin 已提交
356 357
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma
358

359
        if isinstance(class_weight, float):
J
jrzaurin 已提交
360 361 362
            self.class_weight = torch.tensor([1.0 - class_weight, class_weight])
        elif isinstance(class_weight, (tuple, list)):
            self.class_weight = torch.tensor(class_weight)
363 364
        else:
            self.class_weight = None
365 366

        if initializers is not None:
367
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
368 369
            self.initializer.apply(self)

370 371
        if optimizers is not None:
            if isinstance(optimizers, Optimizer):
J
jrzaurin 已提交
372
                self.optimizer: Union[Optimizer, MultipleOptimizer] = optimizers
373
            elif isinstance(optimizers, Dict):
374
                opt_names = list(optimizers.keys())
J
jrzaurin 已提交
375 376 377
                mod_names = [n for n, c in self.named_children()]
                for mn in mod_names:
                    assert mn in opt_names, "No optimizer found for {}".format(mn)
378
                self.optimizer = MultipleOptimizer(optimizers)
379
        else:
J
jrzaurin 已提交
380
            self.optimizer = torch.optim.AdamW(self.parameters())  # type: ignore
381

382 383
        if lr_schedulers is not None:
            if isinstance(lr_schedulers, LRScheduler):
J
jrzaurin 已提交
384
                self.lr_scheduler: Union[
385 386
                    LRScheduler,
                    MultipleLRScheduler,
J
jrzaurin 已提交
387 388
                ] = lr_schedulers
                self.cyclic = "cycl" in self.lr_scheduler.__class__.__name__.lower()
389
            else:
390
                self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
J
jrzaurin 已提交
391 392 393 394 395
                scheduler_names = [
                    sc.__class__.__name__.lower()
                    for _, sc in self.lr_scheduler._schedulers.items()
                ]
                self.cyclic = any(["cycl" in sn for sn in scheduler_names])
396
        else:
397
            self.lr_scheduler, self.cyclic = None, False
398

399
        if transforms is not None:
J
jrzaurin 已提交
400
            self.transforms: MultipleTransforms = MultipleTransforms(transforms)()
401 402 403
        else:
            self.transforms = None

404
        self.history = History()
J
jrzaurin 已提交
405
        self.callbacks: List = [self.history]
406
        if callbacks is not None:
407
            for callback in callbacks:
J
jrzaurin 已提交
408 409
                if isinstance(callback, type):
                    callback = callback()
410
                self.callbacks.append(callback)
411 412 413

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
414
            self.callbacks += [MetricCallback(self.metric)]
415 416
        else:
            self.metric = None
417

418 419
        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)
420

421
        self.to(device)
J
jrzaurin 已提交
422

423
    def fit(  # noqa: C901
J
jrzaurin 已提交
424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        patience: int = 10,
        warm_up: bool = False,
        warm_epochs: int = 4,
        warm_max_lr: float = 0.01,
        warm_deeptext_gradual: bool = False,
        warm_deeptext_max_lr: float = 0.01,
        warm_deeptext_layers: Optional[List[nn.Module]] = None,
        warm_deepimage_gradual: bool = False,
        warm_deepimage_max_lr: float = 0.01,
        warm_deepimage_layers: Optional[List[nn.Module]] = None,
        warm_routine: str = "howard",
    ):
448
        r"""Fit method. Must run after calling ``compile``
449 450 451

        Parameters
        ----------
452
        X_wide: np.ndarray, Optional. Default=None
453 454
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
455
        X_deep: np.ndarray, Optional. Default=None
456 457
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
458
        X_text: np.ndarray, Optional. Default=None
459 460
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
461
        X_img : np.ndarray, Optional. Default=None
462 463 464 465
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict[str, np.ndarray], Optional. Default=None
            Training dataset for the different model components. Keys are
466
            `X_wide`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
467
            the corresponding matrices.
468
        X_val: Dict, Optional. Default=None
469
            Validation dataset for the different model component. Keys are
470 471
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
472 473
        val_split: float, Optional. Default=None
            train/val split fraction
474 475
        target: np.ndarray, Optional. Default=None
            target values
476 477 478 479 480 481
        n_epochs: int, Default=1
            number of epochs
        validation_freq: int, Default=1
            epochs validation frequency
        batch_size: int, Default=32
        patience: int, Default=10
482 483
            Number of epochs without improving the target metric before
            the fit process stops
484
        warm_up: bool, Default=False
485 486 487
            warm up model components individually before the joined training
            starts.

488 489 490 491 492
            ``pytorch_widedeep`` implements 3 warm up routines.

            - Warm up all trainable layers at once. This routine is is
              inspired by the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
493 494 495 496 497 498
              Slanted Triangular learing (see `Leslie N. Smith paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``AdamW``.
499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514

            and two gradual warm up routines, where only certain layers are
            warmed up at each warm up step.

            - The so called `Felbo` gradual warm up rourine, inpired by the Felbo et al., 2017
              `DeepEmoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine based on the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the examples
            section in this documentation and the corresponding repo.
        warm_epochs: int, Default=4
            Number of warm up epochs for those model components that will NOT
            be gradually warmed up. Those components with gradual warm up
            follow their corresponding specific routine.
        warm_max_lr: float, Default=0.01
515
            Maximum learning rate during the Triangular Learning rate cycle
516 517
            for those model componenst that will NOT be gradually warmed up
        warm_deeptext_gradual: bool, Default=False
518 519
            Boolean indicating if the deeptext component will be warmed
            up gradually
520
        warm_deeptext_max_lr: float, Default=0.01
521 522
            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
523
        warm_deeptext_layers: List, Optional, Default=None
524
            List of :obj:`nn.Modules` that will be warmed up gradually.
525 526 527 528 529

            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

        warm_deepimage_gradual: bool, Default=False
530 531 532 533 534
            Boolean indicating if the deepimage component will be warmed
            up gradually
        warm_deepimage_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deepimage component
535
        warm_deepimage_layers: List, Optional, Default=None
536
            List of :obj:`nn.Modules` that will be warmed up gradually.
537

538 539 540
            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

541
        warm_routine: str, Default=`felbo`
542 543 544 545 546
            Warm up routine. On of `felbo` or `howard`. See the examples
            section in this documentation and the corresponding repo for
            details on how to use warm up routines

        Examples
547
        --------
548 549 550 551 552 553 554

        For a series of comprehensive examples please, see the `example
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_.
        folder in the repo

        For completion, here we include some `"fabricated"` examples, i.e. these assume
        you have already built and compiled the model
555

556 557

        >>> # Ex 1. using train input arrays directly and no validation
558
        >>> # model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256)
559

560 561

        >>> # Ex 2: using train input arrays directly and validation with val_split
562
        >>> # model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)
563

564 565

        >>> # Ex 3: using train dict and val_split
566 567
        >>> # X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': y}
        >>> # model.fit(X_train, n_epochs=10, batch_size=256, val_split=0.2)
568

569 570

        >>> # Ex 4: validation using training and validation dicts
571 572 573
        >>> # X_train = {'X_wide': X_wide_tr, 'X_deep': X_deep_tr, 'target': y_tr}
        >>> # X_val = {'X_wide': X_wide_val, 'X_deep': X_deep_val, 'target': y_val}
        >>> # model.fit(X_train=X_train, X_val=X_val n_epochs=10, batch_size=256)
574

575
        """
576

577
        self.batch_size = batch_size
J
jrzaurin 已提交
578 579 580 581 582 583
        train_set, eval_set = self._train_val_split(
            X_wide, X_deep, X_text, X_img, X_train, X_val, val_split, target
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
584 585
        if warm_up:
            # warm up...
J
jrzaurin 已提交
586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603
            self._warm_up(
                train_loader,
                warm_epochs,
                warm_max_lr,
                warm_deeptext_gradual,
                warm_deeptext_layers,
                warm_deeptext_max_lr,
                warm_deepimage_gradual,
                warm_deepimage_layers,
                warm_deepimage_max_lr,
                warm_routine,
            )
        train_steps = len(train_loader)
        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        if self.verbose:
            print("Training")
604
        for epoch in range(n_epochs):
605
            # train step...
J
jrzaurin 已提交
606
            epoch_logs: Dict[str, float] = {}
607
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
J
jrzaurin 已提交
608
            self.train_running_loss = 0.0
609
            with trange(train_steps, disable=self.verbose != 1) as t:
J
jrzaurin 已提交
610 611
                for batch_idx, (data, target) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
612 613 614 615 616 617
                    score, train_loss = self._training_step(data, target, batch_idx)
                    if score is not None:
                        t.set_postfix(
                            metrics={k: np.round(v, 4) for k, v in score.items()},
                            loss=train_loss,
                        )
618
                    else:
619
                        t.set_postfix(loss=train_loss)
J
jrzaurin 已提交
620 621
                    if self.lr_scheduler:
                        self._lr_scheduler_step(step_location="on_batch_end")
622
                    self.callback_container.on_batch_end(batch=batch_idx)
J
jrzaurin 已提交
623
            epoch_logs["train_loss"] = train_loss
624 625 626 627
            if score is not None:
                for k, v in score.items():
                    log_k = "_".join(["train", k])
                    epoch_logs[log_k] = v
628
            # eval step...
J
jrzaurin 已提交
629
            if epoch % validation_freq == (validation_freq - 1):
630
                if eval_set is not None:
J
jrzaurin 已提交
631 632 633 634 635 636 637 638
                    eval_loader = DataLoader(
                        dataset=eval_set,
                        batch_size=batch_size,
                        num_workers=n_cpus,
                        shuffle=False,
                    )
                    eval_steps = len(eval_loader)
                    self.valid_running_loss = 0.0
639
                    with trange(eval_steps, disable=self.verbose != 1) as v:
J
jrzaurin 已提交
640 641
                        for i, (data, target) in zip(v, eval_loader):
                            v.set_description("valid")
642 643 644 645 646 647 648 649
                            score, val_loss = self._validation_step(data, target, i)
                            if score is not None:
                                v.set_postfix(
                                    metrics={
                                        k: np.round(v, 4) for k, v in score.items()
                                    },
                                    loss=val_loss,
                                )
650
                            else:
651
                                v.set_postfix(loss=val_loss)
J
jrzaurin 已提交
652
                    epoch_logs["val_loss"] = val_loss
653 654 655 656
                    if score is not None:
                        for k, v in score.items():
                            log_k = "_".join(["val", k])
                            epoch_logs[log_k] = v
J
jrzaurin 已提交
657 658 659
            if self.lr_scheduler:
                self._lr_scheduler_step(step_location="on_epoch_end")
            #  log and check if early_stop...
660
            self.callback_container.on_epoch_end(epoch, epoch_logs)
661
            if self.early_stop:
J
jrzaurin 已提交
662
                self.callback_container.on_train_end(epoch_logs)
663
                break
J
jrzaurin 已提交
664
            self.callback_container.on_train_end(epoch_logs)
665 666
        self.train()

J
jrzaurin 已提交
667 668
    def predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predictions

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_deep: np.ndarray, Optional. Default=None
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict[str, np.ndarray], Optional. Default=None
            Testing dataset for the different model components. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'` the values are
            the corresponding matrices.

        Returns
        -------
        np.ndarray
            - ``'regression'``: the stacked raw outputs with the second axis
              squeezed out.
            - ``'binary'``: integer labels obtained by thresholding the
              predicted probabilities at 0.5.
            - ``'multiclass'``: the index of the highest probability per row.

            For any other value of ``self.method`` no branch matches and the
            method returns ``None``.
        """
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            # _predict already applied the sigmoid, so these are probabilities
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype("int")
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)
706

J
jrzaurin 已提交
707 708
    def predict_proba(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for binary
        and multiclass methods

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
        X_deep: np.ndarray, Optional. Default=None
            Input for the ``deepdense`` model component.
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
        X_test: Dict[str, np.ndarray], Optional. Default=None
            Testing dataset for the different model components, keyed by
            component name. Used instead of the individual arrays when passed.

        Returns
        -------
        np.ndarray
            - ``'binary'``: a two-column float64 array where column 0 is
              ``1 - p`` and column 1 is ``p``, the predicted probability.
            - ``'multiclass'``: the stacked per-class probabilities.

            For any other value of ``self.method`` (e.g. ``'regression'``) no
            branch matches and the method returns ``None``.
        """
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            # cast to float64 so the dtype does not depend on the model output
            return np.column_stack((1 - preds, preds)).astype(np.float64)
        if self.method == "multiclass":
            return np.vstack(preds_l)
727

J
jrzaurin 已提交
728 729
    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:  # pragma: no cover
        r"""Returns the learned embeddings for the categorical features passed through
        ``deepdense``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` Attribute of the class
        :obj:`DensePreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.DensePreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict[str, Dict[str, int]]
            Dictionary containing the categorical encodings: for every
            column, a mapping from category value to its integer index, e.g:
            ``{'education': {'HS-grad': 1, 'Bachelors': 2}}``

        Returns
        -------
        Dict[str, np.ndarray]
            Mapping from each category value of ``col_name`` to its row in
            the embedding matrix.

        Raises
        ------
        ValueError
            If no model parameter whose name contains both ``'embed_layers'``
            and ``col_name`` is found.
        KeyError
            If ``col_name`` is not a key of ``cat_encoding_dict``.

        Examples
        --------

        For a series of comprehensive examples please, see the `example
        <https://github.com/jrzaurin/pytorch-widedeep/tree/master/examples>`_.
        folder in the repo

        For completion, here we include a `"fabricated"` example, i.e.
        assuming we have already trained the model, that we have the
        categorical encodings in a dictionary name ``encoding_dict``, and that
        there is a column called `'education'`:

        >>> # model.get_embeddings(col_name='education', cat_encoding_dict=encoding_dict)
        """
        # Substring matching on the parameter name: if several parameters
        # match (e.g. one column name is contained in another) the last one
        # wins.
        embed_param = None
        for n, p in self.named_parameters():
            if "embed_layers" in n and col_name in n:
                embed_param = p
        if embed_param is None:
            raise ValueError(
                "no embedding layer found for column '{}'".format(col_name)
            )
        embed_mtx = embed_param.detach().cpu().numpy()
        encoding_dict = cat_encoding_dict[col_name]
        return {value: embed_mtx[idx] for value, idx in encoding_dict.items()}

J
jrzaurin 已提交
771
    def _loss_fn(self, y_pred: Tensor, y_true: Tensor) -> Tensor:  # type: ignore
        r"""Computes the loss for the configured ``self.method``.

        If ``self.with_focal_loss`` is set, :obj:`FocalLoss` built with
        ``self.alpha`` and ``self.gamma`` takes precedence over the method.
        Otherwise: mean squared error for ``'regression'``, binary cross
        entropy with logits for ``'binary'`` and cross entropy for
        ``'multiclass'``. The last two are weighted by ``self.class_weight``.

        For ``'regression'`` and ``'binary'`` the target is reshaped to a
        column (``view(-1, 1)``) before being compared with ``y_pred``.

        Returns ``None`` if ``self.method`` is none of the three values above.
        """
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
        if self.method == "regression":
            return F.mse_loss(y_pred, y_true.view(-1, 1))
        if self.method == "binary":
            return F.binary_cross_entropy_with_logits(
                y_pred, y_true.view(-1, 1), weight=self.class_weight
            )
        if self.method == "multiclass":
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

783
    def _train_val_split(  # noqa: C901
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
    ):
        r"""
        Builds the training and (optionally) the validation datasets.

        Three cases are handled, in this order:

        1. ``X_val`` is passed: ``X_train`` must also be passed and both
           dictionaries are wrapped as they are.
        2. ``val_split`` is passed: the training data (``X_train`` or, if
           empty, the dictionary built from the individual arrays and
           ``target``) is split internally. The split is stratified on the
           target unless ``self.method`` is ``'regression'``.
        3. Otherwise all the data is used for training and ``eval_set`` is
           ``None``.

        For parameter information, please, see the .fit() method
        documentation

        Returns
        -------
        train_set: WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`
        eval_set : WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`. ``None`` when
            neither ``X_val`` nor ``val_split`` is passed.
        """

        if X_val is not None:
            assert (
                X_train is not None
            ), "if the validation set is passed as a dictionary, the training set must also be a dictionary"
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
            return train_set, eval_set

        if not X_train:
            X_train = self._build_train_dict(X_wide, X_deep, X_text, X_img, target)

        if val_split is None:
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            return train_set, None

        # Split the row indices together with the target so that every
        # component can be sliced consistently afterwards.
        y_tr, y_val, idx_tr, idx_val = train_test_split(
            X_train["target"],
            np.arange(len(X_train["target"])),
            test_size=val_split,
            stratify=X_train["target"] if self.method != "regression" else None,
        )
        X_tr, X_val = {"target": y_tr}, {"target": y_val}
        for component in ("X_wide", "X_deep", "X_text", "X_img"):
            if component in X_train:
                X_tr[component] = X_train[component][idx_tr]
                X_val[component] = X_train[component][idx_val]
        train_set = WideDeepDataset(**X_tr, transforms=self.transforms)  # type: ignore
        eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
        return train_set, eval_set

J
jrzaurin 已提交
858 859 860 861 862 863 864 865 866 867 868 869
    def _warm_up(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):  # pragma: no cover
        r"""
        Simple wrapper to individually warm up the model components.

        Every component that is present is warmed up via ``WarmUp.warm_all``
        for ``n_epochs`` with ``max_lr``. ``deeptext`` and ``deepimage`` are
        instead warmed up via ``WarmUp.warm_gradual`` (using their own layers,
        maximum learning rate and ``routine``) when the corresponding
        ``*_gradual`` flag is set.

        Raises
        ------
        ValueError
            If the model has a ``deephead``, since warming up is not
            supported in that case.
        """
        if self.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
        # This is not the most elegant solution, but is a solution "in-between"
        # a non elegant one and re-factoring the whole code
        warmer = WarmUp(self._loss_fn, self.metric, self.method, self.verbose)
        # Components are optional, so only warm up those that are present
        if self.wide is not None:
            warmer.warm_all(self.wide, "wide", loader, n_epochs, max_lr)
        if self.deepdense is not None:
            warmer.warm_all(self.deepdense, "deepdense", loader, n_epochs, max_lr)
        if self.deeptext is not None:
            if deeptext_gradual:
                warmer.warm_gradual(
                    self.deeptext,
                    "deeptext",
                    loader,
                    deeptext_max_lr,
                    deeptext_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deeptext, "deeptext", loader, n_epochs, max_lr)
        if self.deepimage is not None:
            if deepimage_gradual:
                warmer.warm_gradual(
                    self.deepimage,
                    "deepimage",
                    loader,
                    deepimage_max_lr,
                    deepimage_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deepimage, "deepimage", loader, n_epochs, max_lr)
907

908
    def _lr_scheduler_step(self, step_location: str):  # noqa: C901
        r"""
        Function to execute the learning rate schedulers steps.
        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
        must happen after training each batch during training. On the other
        hand, if the scheduler is not Cyclic, it is expected to be called
        after validation.

        When ``self.lr_scheduler`` is a ``MultipleLRScheduler`` and
        ``self.cyclic`` is set, each of its schedulers is stepped
        individually: those whose class name contains ``'cycl'`` on
        ``'on_batch_end'`` and the remaining ones on ``'on_epoch_end'``.

        Parameters
        ----------
        step_location: Str
            Indicates where to run the lr_scheduler step. Either
            ``'on_batch_end'`` or ``'on_epoch_end'``; any other value is a
            no-op.
        """
        # Compared by name rather than with isinstance, as the original code
        is_multiple = self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"

        if is_multiple and self.cyclic:
            if step_location not in ("on_batch_end", "on_epoch_end"):
                return
            step_cyclic = step_location == "on_batch_end"
            for scheduler in self.lr_scheduler._schedulers.values():  # type: ignore
                is_cyclic = "cycl" in scheduler.__class__.__name__.lower()
                if is_cyclic == step_cyclic:
                    scheduler.step()  # type: ignore
        elif self.cyclic:
            if step_location == "on_batch_end":
                self.lr_scheduler.step()  # type: ignore
        elif step_location == "on_epoch_end":
            # non-cyclic schedulers, single or multiple, step once per epoch
            self.lr_scheduler.step()  # type: ignore

    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Runs one optimisation step on a single batch.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Batch of inputs, as yielded by the training ``DataLoader``
        target: Tensor
            Batch of targets. Cast to float unless ``self.method`` is
            ``'multiclass'``
        batch_idx: int
            0-based index of the batch within the epoch, used to average the
            running loss

        Returns
        -------
        score
            Result of ``self.metric`` for ``'binary'`` and ``'multiclass'``
            methods. ``None`` if there is no metric or if the method is
            neither of those two (e.g. ``'regression'``).
        avg_loss: float
            Running training loss averaged over the batches seen so far in
            the epoch
        """
        self.train()
        X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
        y = target.float() if self.method != "multiclass" else target
        y = y.to(device)

        self.optimizer.zero_grad()
        y_pred = self.forward(X)
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        # Default to None so that a metric combined with a method other than
        # binary/multiclass does not leave 'score' unbound.
        score = None
        if self.metric is not None:
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            elif self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
        return score, avg_loss
969 970 971
        else:
            return None, avg_loss

J
jrzaurin 已提交
972
    def _validation_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Evaluates the model on a single validation batch, without
        computing gradients.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Batch of inputs, as yielded by the validation ``DataLoader``
        target: Tensor
            Batch of targets. Cast to float unless ``self.method`` is
            ``'multiclass'``
        batch_idx: int
            0-based index of the batch within the validation pass, used to
            average the running loss

        Returns
        -------
        score
            Result of ``self.metric`` for ``'binary'`` and ``'multiclass'``
            methods. ``None`` if there is no metric or if the method is
            neither of those two (e.g. ``'regression'``).
        avg_loss: float
            Running validation loss averaged over the batches seen so far
        """

        self.eval()
        with torch.no_grad():
            X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
            y = target.float() if self.method != "multiclass" else target
            y = y.to(device)

            y_pred = self.forward(X)
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        # Default to None so that a metric combined with a method other than
        # binary/multiclass does not leave 'score' unbound.
        score = None
        if self.metric is not None:
            if self.method == "binary":
                score = self.metric(torch.sigmoid(y_pred), y)
            elif self.method == "multiclass":
                score = self.metric(F.softmax(y_pred, dim=1), y)
        return score, avg_loss

J
jrzaurin 已提交
994 995
    def _predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""Hidden method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation

        Returns
        -------
        List
            One numpy array of predictions per batch. A sigmoid is applied
            for the ``'binary'`` method and a softmax for ``'multiclass'``;
            the raw model output is returned otherwise.

        Notes
        -----
        Uses ``self.batch_size``, which is set in ``fit``. The model is put
        back in training mode before returning.
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            components = {
                "X_wide": X_wide,
                "X_deep": X_deep,
                "X_text": X_text,
                "X_img": X_img,
            }
            load_dict = {k: v for k, v in components.items() if v is not None}
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        # len(loader) is the exact number of batches; 'len // batch_size + 1'
        # over-counts by one when the dataset size is a multiple of the batch
        # size
        test_steps = len(test_loader)

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = self.forward(X)
                    if self.method == "binary":
                        preds = torch.sigmoid(preds)
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds = preds.cpu().data.numpy()
                    preds_l.append(preds)
        self.train()
        return preds_l
1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057

    @staticmethod
    def _build_train_dict(X_wide, X_deep, X_text, X_img, target):
        X_train = {"target": target}
        if X_wide is not None:
            X_train["X_wide"] = X_wide
        if X_deep is not None:
            X_train["X_deep"] = X_deep
        if X_text is not None:
            X_train["X_text"] = X_text
        if X_img is not None:
            X_train["X_img"] = X_img
        return X_train

J
jrzaurin 已提交
1058
    @staticmethod  # noqa: C901
J
jrzaurin 已提交
1059 1060 1061 1062 1063 1064 1065 1066 1067
    def _check_model_components(
        wide,
        deepdense,
        deeptext,
        deepimage,
        deephead,
        head_layers,
        head_dropout,
        pred_dim,
1068 1069
    ):

J
jrzaurin 已提交
1070 1071 1072 1073 1074 1075 1076
        if wide is not None:
            assert wide.wide_linear.weight.size(1) == pred_dim, (
                "the 'pred_dim' of the wide component ({}) must be equal to the 'pred_dim' "
                "of the deep component and the overall model itself ({})".format(
                    wide.wide_linear.weight.size(1), pred_dim
                )
            )
1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103
        if deepdense is not None and not hasattr(deepdense, "output_dim"):
            raise AttributeError(
                "deepdense model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deephead is not None and head_layers is not None:
            raise ValueError(
                "both 'deephead' and 'head_layers' are not None. Use one of the other, but not both"
            )
        if head_layers is not None and not deepdense and not deeptext and not deepimage:
            raise ValueError(
                "if 'head_layers' is not None, at least one deep component must be used"
            )
        if head_layers is not None and head_dropout is not None:
            assert len(head_layers) == len(
                head_dropout
            ), "'head_layers' and 'head_dropout' must have the same length"
J
jrzaurin 已提交
1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118
        if deephead is not None:
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = 0
            if deepdense is not None:
                output_dim += deepdense.output_dim
            if deeptext is not None:
                output_dim += deeptext.output_dim
            if deepimage is not None:
                output_dim += deepimage.output_dim
            assert deephead_inp_feat == output_dim, (
                "if a custom 'deephead' is used its input features ({}) must be equal to "
                "the output features of the deep component ({})".format(
                    deephead_inp_feat, output_dim
                )
            )