# wide_deep.py
import os
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from ..losses import FocalLoss
from ._warmup import WarmUp
from ..metrics import Metric, MetricCallback, MultipleMetrics
from ..wdtypes import *
from ..callbacks import History, Callback, CallbackContainer
from .deep_dense import dense_layer
from ._wd_dataset import WideDeepDataset
from ..initializers import Initializer, MultipleInitializer
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_transforms import MultipleTransforms
from ._multiple_lr_scheduler import MultipleLRScheduler

# Number of worker processes used by the DataLoaders in ``fit``.
# ``os.cpu_count()`` returns None when the count cannot be determined, and
# ``DataLoader`` rejects ``num_workers=None``, so fall back to one worker.
n_cpus = os.cpu_count() or 1
# Whether ``compile`` moves the model to the GPU.
use_cuda = torch.cuda.is_available()


class WideDeep(nn.Module):
    r"""Main collector class that combines all ``Wide``, ``DeepDense``,
    ``DeepText`` and ``DeepImage`` models.

    There are two options to combine these models that correspond to the two
    architectures that ``pytorch-widedeep`` can build.

        - Directly connecting the output of the model components to an output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
          This FC-Head will combine the output from the ``DeepDense``, ``DeepText`` and
          ``DeepImage`` and will then be connected to the output neuron(s).

    Parameters
    ----------
    wide: nn.Module
        Wide model. We recommend using the ``Wide`` class in this package.
        However, it is possible to use a custom model as long as it is
        consistent with the required architecture, see
        :class:`pytorch_widedeep.models.wide.Wide`
    deepdense: nn.Module
        `Deep dense` model comprised of the embeddings for the categorical
        features combined with numerical (also referred to as continuous)
        features. We recommend using the ``DeepDense`` class in this package.
        However, it is possible to use a custom model as long as it is
        consistent with the required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepDense`.
    deeptext: nn.Module, Optional
        `Deep text` model for the text input. Must be an object of class
        ``DeepText`` or a custom model as long as it is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepText`
    deepimage: nn.Module, Optional
        `Deep Image` model for the images input. Must be an object of class
        ``DeepImage`` or a custom model as long as it is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepImage`
    deephead: nn.Module, Optional
        `Dense` model consisting of a stack of dense layers. The FC-Head.
        If both ``deephead`` and ``head_layers`` are passed, ``deephead``
        takes priority.
    head_layers: List, Optional
        Alternatively, we can use ``head_layers`` to specify the sizes of the
        stacked dense layers in the FC-Head e.g: ``[128, 64]``
    head_dropout: List, Optional
        Dropout between the layers in ``head_layers``. e.g: ``[0.5, 0.5]``.
        If not passed, a dropout of ``0.0`` is used for every layer.
    head_batchnorm: bool, Optional
        Specifies if batch normalization should be included in the dense layers
    pred_dim: int, Default=1
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or `number
        of classes` for multiclass classification.


    .. note:: With the exception of ``cyclic``, all attributes are direct assignments of
        the corresponding parameters used when calling ``compile``. Therefore,
        see the parameters at
        :class:`pytorch_widedeep.models.wide_deep.WideDeep.compile` for a full
        list of the attributes of an instance of
        :class:`pytorch_widedeep.models.wide_deep.WideDeep`


    Attributes
    ----------
    cyclic: :obj:`bool`
        Attribute that indicates if any of the lr_schedulers is cyclic (i.e. ``CyclicLR`` or
        ``OneCycleLR``). See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.


    .. note:: While I recommend using the ``Wide`` and ``DeepDense`` classes within
        this package when building the corresponding model components, it is very
        likely that the user will want to use custom text and image models. That
        is perfectly possible. Simply build them and pass them as the
        corresponding parameters. Note that the custom models MUST return a last
        layer of activations (i.e. not the final prediction) so that these
        activations are collected by ``WideDeep`` and combined accordingly. In
        addition, the models MUST also contain an attribute ``output_dim`` with
        the size of these last layers of activations. See for example
        :class:`pytorch_widedeep.models.deep_dense.DeepDense`

    """

    def __init__(
        self,
        wide: nn.Module,
        deepdense: nn.Module,
        pred_dim: int = 1,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_layers: Optional[List[int]] = None,
        head_dropout: Optional[List] = None,
        head_batchnorm: Optional[bool] = None,
    ):
        """Register the model components and, when needed, build either the
        FC-Head or one output layer per deep component.

        All parameters are described in the class docstring.

        Raises
        ------
        AttributeError
            If ``deepdense``, or ``deeptext``/``deepimage`` when passed, has
            no ``output_dim`` attribute.

        Warns
        -----
        UserWarning
            If both ``deephead`` and ``head_layers`` are passed. In that case
            ``deephead`` is used and ``head_layers`` is ignored.
        """
        super(WideDeep, self).__init__()

        # The deep components must expose the size of their last layer of
        # activations. It is needed below to size the head / output layers.
        if not hasattr(deepdense, "output_dim"):
            raise AttributeError(
                "deepdense model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepDense"
            )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepImage"
            )

        # The main 5 components of the wide and deep assemble
        self.wide = wide
        self.deepdense = deepdense
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if deephead is not None and head_layers is not None:
            warnings.warn(
                "both 'deephead' and 'head_layers' are not None. "
                "'deephead' takes priority and will be used",
                UserWarning,
            )

        if self.deephead is None:
            if head_layers is not None:
                # FC-Head: its input is the concatenation of the last
                # activations of every deep component (see ``forward``)
                input_dim: int = self.deepdense.output_dim  # type:ignore
                if self.deeptext is not None:
                    input_dim += self.deeptext.output_dim  # type:ignore
                if self.deepimage is not None:
                    input_dim += self.deepimage.output_dim  # type:ignore
                # rebinding (not mutating) keeps the caller's list untouched
                head_layers = [input_dim] + head_layers
                if not head_dropout:
                    head_dropout = [0.0] * (len(head_layers) - 1)
                self.deephead = nn.Sequential()
                for i in range(1, len(head_layers)):
                    # module names become part of the state_dict keys, so
                    # they are kept stable ("head_layer_0", "head_layer_1"...)
                    self.deephead.add_module(
                        "head_layer_{}".format(i - 1),
                        dense_layer(
                            head_layers[i - 1],
                            head_layers[i],
                            head_dropout[i - 1],
                            head_batchnorm,
                        ),
                    )
                self.deephead.add_module(
                    "head_out", nn.Linear(head_layers[-1], pred_dim)
                )
            else:
                # No head: every deep component is directly connected to the
                # output neuron(s) through its own linear layer
                self.deepdense = nn.Sequential(
                    self.deepdense, nn.Linear(self.deepdense.output_dim, pred_dim)  # type: ignore
                )
                if self.deeptext is not None:
                    self.deeptext = nn.Sequential(
                        self.deeptext, nn.Linear(self.deeptext.output_dim, pred_dim)  # type: ignore
                    )
                if self.deepimage is not None:
                    self.deepimage = nn.Sequential(
                        self.deepimage, nn.Linear(self.deepimage.output_dim, pred_dim)  # type: ignore
                    )

    def forward(self, X: Dict[str, Tensor]) -> Tensor:  # type: ignore
        """Forward pass of the assembled wide and deep model.

        Parameters
        ----------
        X: Dict[str, Tensor]
            Inputs per model component. The keys read are ``"wide"`` and
            ``"deepdense"`` and, when the corresponding components exist,
            ``"deeptext"`` and ``"deepimage"``.

        Returns
        -------
        Tensor
            The output of the wide component plus the output of the deep
            side, either through the FC-Head or summed per component.
        """
        # Wide output: direct connection to the output neuron(s)
        out = self.wide(X["wide"])

        # Out-of-place additions are used throughout so that the tensor
        # returned by ``wide`` is never modified in place. Autograd raises
        # at backward time if an in-place op overwrites a tensor it saved.
        if self.deephead is not None:
            # FC-Head architecture: concatenate the deep activations along
            # dim 1 and pass them through the head
            deepside = self.deepdense(X["deepdense"])
            if self.deeptext is not None:
                deepside = torch.cat([deepside, self.deeptext(X["deeptext"])], dim=1)
            if self.deepimage is not None:
                deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], dim=1)
            return out + self.deephead(deepside)

        # No head: each deep component already ends in its own output layer
        out = out + self.deepdense(X["deepdense"])
        if self.deeptext is not None:
            out = out + self.deeptext(X["deeptext"])
        if self.deepimage is not None:
            out = out + self.deepimage(X["deepimage"])
        return out

    def compile(
        self,
        method: str,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        initializers: Optional[Dict[str, Initializer]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        with_focal_loss: bool = False,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        r"""Method to set the attributes that will be used during the
        training process.

        Parameters
        ----------
        method: str
            One of `regression`, `binary` or `multiclass`
        optimizers: Union[Optimizer, Dict[str, Optimizer]], Optional, Default=AdamW
            - An instance of ``pytorch``'s ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
            - a dictionary where the keys are the model components (i.e.
              `'wide'`, `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and
              the values are the corresponding optimizers. If a dictionary is
              used it MUST contain an optimizer per model component.

            See `Pytorch optimizers <https://pytorch.org/docs/stable/optim.html>`_.
        lr_schedulers: Union[LRScheduler, Dict[str, LRScheduler]], Optional, Default=None
            - An instance of ``pytorch``'s ``LRScheduler`` object (e.g
              :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
              values are the corresponding learning rate schedulers.

            See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
        initializers: Dict[str, Initializer], Optional. Default=None
            Dict where the keys are the model components (i.e. `'wide'`,
            `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
            values are the corresponding initializers.
            See `Pytorch initializers <https://pytorch.org/docs/stable/nn.init.html>`_.
        transforms: List[Transforms], Optional. Default=None
            ``Transforms`` is a custom type. See
            :obj:`pytorch_widedeep.wdtypes`. List with
            :obj:`torchvision.transforms` to be applied to the image component
            of the model (i.e. ``deepimage``) See `torchvision transforms
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
        callbacks: List[Callback], Optional. Default=None
            Callbacks available are: ``ModelCheckpoint``, ``EarlyStopping``,
            and ``LRHistory``. The ``History`` callback is used by default.
            Callback classes are instantiated with no arguments; instances
            are used as given. See the ``Callbacks`` section in this
            documentation or :obj:`pytorch_widedeep.callbacks`
        metrics: List[Metric], Optional. Default=None
            Metrics available are: ``BinaryAccuracy`` and
            ``CategoricalAccuracy`` See the ``Metrics`` section in this
            documentation or :obj:`pytorch_widedeep.metrics`
        class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
            - a number in ``[0, 1]`` for binary classification problems. The
              weight tensor is built as ``[1 - class_weight, class_weight]``
              (e.g. ``0.9`` gives weights of roughly ``[0.1, 0.9]``)
            - a list or tuple with weights for the different classes in multiclass
              classification problems (e.g. [1., 2., 3.]). The weights do
              not necessarily need to be normalised. If your loss function
              uses reduction='mean', the loss will be normalized by the sum
              of the corresponding weights for each element. If you are
              using reduction='none', you would have to take care of the
              normalization yourself. See `this discussion
              <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
        with_focal_loss: bool, Optional. Default=False
            Boolean indicating whether to use the Focal Loss for highly imbalanced problems.
            For details on the focal loss see the `original paper
            <https://arxiv.org/pdf/1708.02002.pdf>`_.
        alpha: float. Default=0.25
            Focal Loss alpha parameter. Only stored if ``with_focal_loss``.
        gamma: float. Default=2
            Focal Loss gamma parameter. Only stored if ``with_focal_loss``.
        verbose: int
            Setting it to 0 will print nothing during training.
        seed: int, Default=1
            Random seed to be used throughout all the methods

        Raises
        ------
        ValueError
            If ``optimizers`` is a dictionary without an optimizer for one of
            the model components, or if a numeric ``class_weight`` lies
            outside ``[0, 1]``.
        TypeError
            If ``class_weight`` is not ``None``, a number, a list or a tuple.

        Example
        --------
        Assuming you have already built the model components (wide, deepdense, etc...)

        >>> from pytorch_widedeep.models import WideDeep
        >>> from pytorch_widedeep.initializers import *
        >>> from pytorch_widedeep.callbacks import *
        >>> from pytorch_widedeep.optim import RAdam
        >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deepdense.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt  = RAdam(model.deepimage.parameters())
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch  = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {'wide': wide_opt, 'deepdense':deep_opt, 'deeptext':text_opt, 'deepimage': img_opt}
        >>> schedulers = {'wide': wide_sch, 'deepdense':deep_sch, 'deeptext':text_sch, 'deepimage': img_sch}
        >>> initializers = {'wide': Uniform, 'deepdense':Normal, 'deeptext':KaimingNormal,
        ...                 'deepimage':KaimingUniform}
        >>> transforms = [ToTensor, Normalize(mean=mean, std=std)]
        >>> callbacks = [LRHistory, EarlyStopping, ModelCheckpoint(filepath='model_weights/wd_out.pt')]
        >>> model.compile(method='regression', initializers=initializers, optimizers=optimizers,
        ...               lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """
        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        # alpha and gamma are only stored when the focal loss is requested
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if class_weight is None:
            self.class_weight = None
        elif isinstance(class_weight, (int, float)) and not isinstance(
            class_weight, bool
        ):
            # binary case: the weights are [1 - w, w], so w must lie in [0, 1]
            # for both of them to be non-negative
            if not 0.0 <= class_weight <= 1.0:
                raise ValueError(
                    "a numeric class_weight must be in [0, 1], got {}".format(
                        class_weight
                    )
                )
            self.class_weight = torch.tensor([1.0 - class_weight, class_weight])
        elif isinstance(class_weight, (tuple, list)):
            # force a floating point tensor: integer weights (e.g. [1, 2, 3])
            # would otherwise produce an int64 tensor that losses reject
            self.class_weight = torch.tensor(class_weight, dtype=torch.float)
        else:
            raise TypeError(
                "class_weight must be None, a number, a list or a tuple, "
                "got {}".format(type(class_weight).__name__)
            )

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        if optimizers is None:
            self.optimizer: Union[Optimizer, MultipleOptimizer] = torch.optim.AdamW(
                self.parameters()
            )  # type: ignore
        elif isinstance(optimizers, Optimizer):
            self.optimizer = optimizers
        else:
            # a dict of optimizers (of any size) must cover every registered
            # model component; previously a one-entry dict left
            # ``self.optimizer`` unset
            for mn, _ in self.named_children():
                if mn not in optimizers:
                    raise ValueError("No optimizer found for {}".format(mn))
            self.optimizer = MultipleOptimizer(optimizers)

        if lr_schedulers is None:
            self.lr_scheduler, self.cyclic = None, False
        elif isinstance(lr_schedulers, LRScheduler):
            self.lr_scheduler: Union[
                LRScheduler, MultipleLRScheduler
            ] = lr_schedulers
            self.cyclic = "cycl" in self.lr_scheduler.__class__.__name__.lower()
        else:
            # dict of schedulers (of any size, including a single entry)
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            self.cyclic = any(
                "cycl" in sc.__class__.__name__.lower()
                for sc in self.lr_scheduler._schedulers.values()
            )

        if transforms is not None:
            self.transforms: MultipleTransforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        # History is always the first callback; classes passed by the user
        # are instantiated with no arguments, instances are used as given
        self.history = History()
        self.callbacks: List = [self.history]
        if callbacks is not None:
            self.callbacks += [
                callback() if isinstance(callback, type) else callback
                for callback in callbacks
            ]

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

        if use_cuda:
            self.cuda()

    def fit(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        patience: int = 10,
        warm_up: bool = False,
        warm_epochs: int = 4,
        warm_max_lr: float = 0.01,
        warm_deeptext_gradual: bool = False,
        warm_deeptext_max_lr: float = 0.01,
        warm_deeptext_layers: Optional[List[nn.Module]] = None,
        warm_deepimage_gradual: bool = False,
        warm_deepimage_max_lr: float = 0.01,
        warm_deepimage_layers: Optional[List[nn.Module]] = None,
        warm_routine: str = "howard",
    ):
        r"""Fit method. Must run after calling ``compile``

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_deep: np.ndarray, Optional. Default=None
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict[str, np.ndarray], Optional. Default=None
            Training dataset for the different model components. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
        X_val: Dict, Optional. Default=None
            Validation dataset for the different model components. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
        val_split: float, Optional. Default=None
            train/val split fraction
        target: np.ndarray, Optional. Default=None
            target values
        n_epochs: int, Default=1
            number of epochs
        validation_freq: int, Default=1
            Validation is run every ``validation_freq`` epochs, when a
            validation set is available.
        batch_size: int, Default=32
            Number of samples per batch, for both training and validation.
        patience: int, Default=10
            Number of epochs without improving the target metric before
            the fit process stops. NOTE: this method does not read it.
            Training stops early when ``self.early_stop`` is set,
            presumably by a callback such as ``EarlyStopping``.
        warm_up: bool, Default=False
            warm up model components individually before the joined training
            starts.

            ``pytorch_widedeep`` implements 3 warm up routines.

            - Warm up all trainable layers at once. This routine is
              inspired by the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
              Slanted Triangular learning rate (see `Leslie N. Smith paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``AdamW``.

            and two gradual warm up routines, where only certain layers are
            warmed up at each warm up step.

            - The so called `Felbo` gradual warm up routine, inspired by the Felbo et al., 2017
              `DeepEmoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine based on the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the examples
            section in this documentation and the corresponding repo.
        warm_epochs: int, Default=4
            Number of warm up epochs for those model components that will NOT
            be gradually warmed up. Those components with gradual warm up
            follow their corresponding specific routine.
        warm_max_lr: float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for those model components that will NOT be gradually warmed up
        warm_deeptext_gradual: bool, Default=False
            Boolean indicating if the deeptext component will be warmed
            up gradually
        warm_deeptext_max_lr: float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
        warm_deeptext_layers: List, Optional, Default=None
            List of :obj:`nn.Modules` that will be warmed up gradually.

            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

        warm_deepimage_gradual: bool, Default=False
            Boolean indicating if the deepimage component will be warmed
            up gradually
        warm_deepimage_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deepimage component
        warm_deepimage_layers: List, Optional, Default=None
            List of :obj:`nn.Modules` that will be warmed up gradually.

            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

        warm_routine: str, Default=`howard`
            Warm up routine. One of `felbo` or `howard`. See the examples
            section in this documentation and the corresponding repo for
            details on how to use warm up routines

        Raises
        ------
        ValueError
            If neither ``X_train`` nor all of ``X_wide``, ``X_deep`` and
            ``target`` are passed.

        Examples
        --------
        Assuming you have already built and compiled the model

        >>> # Ex 1. using train input arrays directly and no validation
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256)

        >>> # Ex 2: using train input arrays directly and validation with val_split
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)

        >>> # Ex 3: using train dict and val_split
        >>> X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': y}
        >>> model.fit(X_train=X_train, n_epochs=10, batch_size=256, val_split=0.2)

        >>> # Ex 4: validation using training and validation dicts
        >>> X_train = {'X_wide': X_wide_tr, 'X_deep': X_deep_tr, 'target': y_tr}
        >>> X_val = {'X_wide': X_wide_val, 'X_deep': X_deep_val, 'target': y_val}
        >>> model.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=256)

        .. note:: :obj:`WideDeep` assumes that `X_wide`, `X_deep` and `target` ALWAYS exist, while
            `X_text` and `X_img` are optional

        .. note:: Either `X_train` or the three `X_wide`, `X_deep` and `target` must be passed to the
            fit method. ``X_train`` is not the first positional parameter, so
            pass it by keyword.

        """

        if X_train is None and (X_wide is None or X_deep is None or target is None):
            raise ValueError(
                "Training data is missing. Either a dictionary (X_train) with "
                "the training dataset or at least 3 arrays (X_wide, X_deep, "
                "target) must be passed to the fit method"
            )

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(
            X_wide, X_deep, X_text, X_img, X_train, X_val, val_split, target
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
        # the validation loader does not change between epochs: build it once
        eval_loader = None
        if eval_set is not None:
            eval_loader = DataLoader(
                dataset=eval_set,
                batch_size=batch_size,
                num_workers=n_cpus,
                shuffle=False,
            )

        if warm_up:
            # warm up...
            self._warm_up(
                train_loader,
                warm_epochs,
                warm_max_lr,
                warm_deeptext_gradual,
                warm_deeptext_layers,
                warm_deeptext_max_lr,
                warm_deepimage_gradual,
                warm_deepimage_layers,
                warm_deepimage_max_lr,
                warm_routine,
            )
        train_steps = len(train_loader)
        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        if self.verbose:
            print("Training")

        # defined before the loop so on_train_end has logs even if n_epochs=0
        epoch_logs: Dict[str, float] = {}
        for epoch in range(n_epochs):
            # train step...
            epoch_logs = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
                for batch_idx, (data, target_batch) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
                    acc, train_loss = self._training_step(data, target_batch, batch_idx)
                    if acc is not None:
                        t.set_postfix(metrics=acc, loss=train_loss)
                    else:
                        t.set_postfix(loss=np.sqrt(train_loss))
                    if self.lr_scheduler:
                        self._lr_scheduler_step(step_location="on_batch_end")
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs["train_loss"] = train_loss
            if acc is not None:
                epoch_logs["train_acc"] = acc["acc"]

            # eval step: every ``validation_freq`` epochs (1-indexed)
            if eval_loader is not None and epoch % validation_freq == (
                validation_freq - 1
            ):
                eval_steps = len(eval_loader)
                self.valid_running_loss = 0.0
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for i, (data, target_batch) in zip(v, eval_loader):
                        v.set_description("valid")
                        acc, val_loss = self._validation_step(data, target_batch, i)
                        if acc is not None:
                            v.set_postfix(metrics=acc, loss=val_loss)
                        else:
                            v.set_postfix(loss=np.sqrt(val_loss))
                epoch_logs["val_loss"] = val_loss
                if acc is not None:
                    epoch_logs["val_acc"] = acc["acc"]

            if self.lr_scheduler:
                self._lr_scheduler_step(step_location="on_epoch_end")
            #  log and check if early_stop...
            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                break

        # training is over (all epochs done or stopped early): notify the
        # callbacks exactly once, instead of at the end of every epoch
        self.callback_container.on_train_end(epoch_logs)
        self.train()

    def predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predictions

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_deep: np.ndarray, Optional. Default=None
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict[str, np.ndarray], Optional. Default=None
            Testing dataset for the different model components. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'` the values are
            the corresponding matrices. When passed, the individual arrays
            above are ignored.

        .. note:: WideDeep assumes that `X_wide`, `X_deep` and `target` ALWAYS exist,
            while `X_text` and `X_img` are optional. Either pass ``X_test`` or
            (at least) ``X_wide`` and ``X_deep``.

        Returns
        -------
        np.ndarray
            - ``regression``: the predicted values
            - ``binary``: the predicted class (1 if the sigmoid output is
              greater than 0.5, 0 otherwise)
            - ``multiclass``: the index of the class with the highest softmax
              probability
        """
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype("int")
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)
668

J
jrzaurin 已提交
669 670 671 672 673 674 675 676
    def predict_proba(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for binary
        and multiclass methods

        Either pass ``X_test`` or (at least) ``X_wide`` and ``X_deep``. For
        parameter information, please, see the .predict() method
        documentation.

        Returns
        -------
        np.ndarray
            - ``binary``: a 2-column array where column 0 is ``1 - p`` and
              column 1 is ``p``, ``p`` being the sigmoid output of the model
            - ``multiclass``: the softmax probabilities, one column per class
            - ``regression``: nothing is computed and ``None`` is returned
        """
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            probs = np.zeros([preds.shape[0], 2])
            probs[:, 0] = 1 - preds
            probs[:, 1] = preds
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)
689

J
jrzaurin 已提交
690 691 692
    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:
        r"""Returns the learned embeddings for the categorical features passed through
        ``deepdense``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` Attribute of the class
        :obj:`DensePreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.DensePreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict[str, Dict[str, int]]
            Dictionary containing the categorical encodings, e.g:
            ``cat_encoding_dict['education']`` ->
            ``{'11th': 0, 'HS-grad': 1, 'Assoc-acdm': 2, ...}``

        Returns
        -------
        Dict[str, np.ndarray]
            Mapping from each category label to its row of the embedding matrix

        Raises
        ------
        ValueError
            If no embedding parameter is found for ``col_name``

        Examples
        --------

        Assuming we have already trained the model and that we have the
        categorical encodings in a dictionary named ``encoding_dict``:

        >>> model.get_embeddings(col_name='education', cat_encoding_dict=encoding_dict)
        {'11th': array([-0.42739448, -0.22282735,  0.36969638,  0.4445322 ,  0.2562272 ,
        0.11572784, -0.01648579,  0.09027119,  0.0457597 , -0.28337458], dtype=float32),
         'HS-grad': array([-0.10600474, -0.48775527,  0.3444158 ,  0.13818645, -0.16547225,
        0.27409762, -0.05006042, -0.0668492 , -0.11047247,  0.3280354 ], dtype=float32),
        ...
        }
        """
        # every embedding parameter whose name mentions the column
        candidates = [
            (name, param)
            for name, param in self.named_parameters()
            if "embed_layers" in name and col_name in name
        ]
        if not candidates:
            raise ValueError(
                f"No embedding parameters found for column '{col_name}'"
            )
        # A plain substring test also matches longer column names that contain
        # col_name (e.g. 'education' within 'education_num'). Prefer parameters
        # whose module path (the dotted name without the final parameter
        # attribute) has a component ending with col_name.
        preferred = [
            (name, param)
            for name, param in candidates
            if any(part.endswith(col_name) for part in name.split(".")[:-1])
        ]
        _, embed_param = (preferred or candidates)[-1]
        embed_mtx = embed_param.cpu().data.numpy()

        encoding_dict = cat_encoding_dict[col_name]
        return {label: embed_mtx[idx] for label, idx in encoding_dict.items()}

J
jrzaurin 已提交
740 741
    def _activation_fn(self, inp: Tensor) -> Tensor:
        if self.method == "binary":
742
            return torch.sigmoid(inp)
743 744 745 746
        else:
            # F.cross_entropy will apply logSoftmax to the preds in the case
            # of 'multiclass'
            return inp
747

J
jrzaurin 已提交
748
    def _loss_fn(self, y_pred: Tensor, y_true: Tensor) -> Tensor:  # type: ignore
749 750
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
J
jrzaurin 已提交
751
        if self.method == "regression":
752
            return F.mse_loss(y_pred, y_true.view(-1, 1))
J
jrzaurin 已提交
753 754 755 756 757
        if self.method == "binary":
            return F.binary_cross_entropy(
                y_pred, y_true.view(-1, 1), weight=self.class_weight
            )
        if self.method == "multiclass":
758 759
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

J
jrzaurin 已提交
760 761 762 763 764 765 766 767 768 769 770
    def _train_val_split(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
    ):
        r"""
        Build the training and (optional) validation datasets.

        If a validation set (X_val) is passed to the fit method it is used as
        is. Otherwise, if val_split is specified, the train/val split will
        happen internally. If neither is passed there is no validation set. A
        number of options are allowed in terms of data inputs. For parameter
        information, please, see the .fit() method documentation

        Returns
        -------
        train_set: WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`
        eval_set : WideDeepDataset or None
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`. ``None`` when
            neither ``X_val`` nor ``val_split`` is passed
        """
        # Without a validation dictionary, the training inputs come either
        # from the X_train dictionary (if passed) or from the individual arrays
        if X_val is None and X_train is not None:
            X_wide, X_deep, target = (
                X_train["X_wide"],
                X_train["X_deep"],
                X_train["target"],
            )
            X_text = X_train.get("X_text", X_text)
            X_img = X_train.get("X_img", X_img)

        #  Without validation
        if X_val is None and val_split is None:
            train_set = WideDeepDataset(  # type: ignore
                X_wide=X_wide,
                X_deep=X_deep,
                target=target,
                X_text=X_text,
                X_img=X_img,
                transforms=self.transforms,
            )
            return train_set, None

        #  With validation
        if X_val is not None:
            # if a validation dictionary is passed, then if no train
            # dictionary is passed we build it with the input arrays
            # (either the dictionary or the arrays must be passed)
            if X_train is None:
                X_train = {"X_wide": X_wide, "X_deep": X_deep, "target": target}
                if X_text is not None:
                    X_train["X_text"] = X_text
                if X_img is not None:
                    X_train["X_img"] = X_img
        else:
            # Split every available input in a single call so that all of
            # them share exactly the same row indices. Separate calls only
            # agree when self.seed makes the split reproducible (e.g. not when
            # it is None).
            names = ["X_wide", "X_deep", "target"]
            arrays = [X_wide, X_deep, target]
            for name, array in (("X_text", X_text), ("X_img", X_img)):
                if array is not None:
                    names.append(name)
                    arrays.append(array)
            splits = train_test_split(
                *arrays,
                test_size=val_split,
                random_state=self.seed,
                stratify=target if self.method != "regression" else None,
            )
            # train_test_split returns [a_train, a_val, b_train, b_val, ...]
            X_train = dict(zip(names, splits[0::2]))
            X_val = dict(zip(names, splits[1::2]))

        # At this point the X_train and X_val dictionaries have been built
        train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
        eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
        return train_set, eval_set

J
jrzaurin 已提交
884 885 886 887 888 889 890 891 892 893 894 895 896
    def _warm_up(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):
        r"""
        Simple wrapper to individually warm up the model components

        ``wide`` and ``deepdense`` are always warmed up with
        ``WarmUp.warm_all``. ``deeptext`` and ``deepimage``, when present
        (not ``None``), are warmed up with ``WarmUp.warm_gradual`` if their
        ``*_gradual`` flag is set, and with ``WarmUp.warm_all`` otherwise.

        Parameters
        ----------
        loader: DataLoader
            DataLoader passed on to the ``WarmUp`` routines
        n_epochs: int
            Number of epochs passed to ``warm_all``
        max_lr: float
            Maximum learning rate passed to ``warm_all``
        deeptext_gradual: bool
            Whether ``deeptext`` is warmed up with ``warm_gradual``
        deeptext_layers: List[nn.Module]
            Layers of ``deeptext`` passed to ``warm_gradual``
        deeptext_max_lr: float
            Maximum learning rate passed to ``warm_gradual`` for ``deeptext``
        deepimage_gradual: bool
            Whether ``deepimage`` is warmed up with ``warm_gradual``
        deepimage_layers: List[nn.Module]
            Layers of ``deepimage`` passed to ``warm_gradual``
        deepimage_max_lr: float
            Maximum learning rate passed to ``warm_gradual`` for ``deepimage``
        routine: str, default="felbo"
            Gradual warm up routine passed to ``warm_gradual``

        Raises
        ------
        ValueError
            If the model has a fully connected ``deephead``; warming up is
            currently only supported without it
        """
        if self.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )

        # This is not the most elegant solution, but it is a solution
        # "in-between" a non-elegant one and re-factoring the whole code
        warmer = WarmUp(
            self._activation_fn, self._loss_fn, self.metric, self.method, self.verbose
        )
        warmer.warm_all(self.wide, "wide", loader, n_epochs, max_lr)
        warmer.warm_all(self.deepdense, "deepdense", loader, n_epochs, max_lr)

        # deeptext and deepimage are optional and share the same warm up logic.
        # Explicit ``is not None`` checks (as for deephead above) avoid relying
        # on nn.Module truthiness, which calls __len__ on containers.
        optional_components = (
            ("deeptext", self.deeptext, deeptext_gradual, deeptext_layers, deeptext_max_lr),
            ("deepimage", self.deepimage, deepimage_gradual, deepimage_layers, deepimage_max_lr),
        )
        for name, component, gradual, layers, component_max_lr in optional_components:
            if component is None:
                continue
            if gradual:
                warmer.warm_gradual(
                    component,
                    name,
                    loader,
                    component_max_lr,
                    layers,
                    routine,
                )
            else:
                warmer.warm_all(component, name, loader, n_epochs, max_lr)
935

J
jrzaurin 已提交
936
    def _lr_scheduler_step(self, step_location: str):
937 938 939
        r"""
        Function to execute the learning rate schedulers steps.
        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
940
        must happen after training each bach durig training. On the other
941 942 943 944 945 946 947 948
        hand, if the  scheduler is not Cyclic, is expected to be called after
        validation.

        Parameters
        ----------
        step_location: Str
            Indicates where to run the lr_scheduler step
        """
J
jrzaurin 已提交
949 950 951 952 953 954 955 956 957 958 959 960
        if (
            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
            and self.cyclic
        ):
            if step_location == "on_batch_end":
                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
            elif step_location == "on_epoch_end":
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" not in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
961
        elif self.cyclic:
J
jrzaurin 已提交
962 963 964 965 966 967 968 969 970 971 972 973 974 975 976
            if step_location == "on_batch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
            if step_location == "on_epoch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif step_location == "on_epoch_end":
            self.lr_scheduler.step()  # type: ignore
        else:
            pass

    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Run one optimisation step on a single training batch.

        Inputs (and the target) are moved to the GPU when CUDA is available.
        The target is cast to float unless the method is ``multiclass``.

        Returns
        -------
        Tuple
            ``(metric_output, avg_loss)``. ``metric_output`` is the result of
            ``self.metric(y_pred, y)``, or ``None`` when no metric is set.
            ``avg_loss`` is ``self.train_running_loss`` divided by
            ``batch_idx + 1``, i.e. the mean loss over the batches seen so far
            (``fit`` resets the running loss at the start of every epoch).
        """
        self.train()

        if use_cuda:
            X = {name: tensor.cuda() for name, tensor in data.items()}
        else:
            X = data
        y = target if self.method == "multiclass" else target.float()
        if use_cuda:
            y = y.cuda()

        self.optimizer.zero_grad()
        y_pred = self._activation_fn(self.forward(X))
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        acc = None if self.metric is None else self.metric(y_pred, y)
        return acc, avg_loss

J
jrzaurin 已提交
997
    def _validation_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Evaluate the model on a single validation batch without gradients.

        Inputs (and the target) are moved to the GPU when CUDA is available.
        The target is cast to float unless the method is ``multiclass``.

        Returns
        -------
        Tuple
            ``(metric_output, avg_loss)``. ``metric_output`` is the result of
            ``self.metric(y_pred, y)``, or ``None`` when no metric is set.
            ``avg_loss`` is ``self.valid_running_loss`` divided by
            ``batch_idx + 1``, i.e. the mean loss over the validation batches
            seen so far (``fit`` resets it before each validation pass).
        """
        self.eval()
        with torch.no_grad():
            if use_cuda:
                X = {name: tensor.cuda() for name, tensor in data.items()}
            else:
                X = data
            y = target if self.method == "multiclass" else target.float()
            if use_cuda:
                y = y.cuda()

            y_pred = self._activation_fn(self.forward(X))
            self.valid_running_loss += self._loss_fn(y_pred, y).item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        acc = None if self.metric is None else self.metric(y_pred, y)
        return acc, avg_loss

J
jrzaurin 已提交
1016 1017 1018 1019 1020 1021 1022 1023
    def _predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""Hidden method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation

        Returns
        -------
        List
            One numpy array per batch with the model outputs after
            ``_activation_fn`` (sigmoid for ``binary``), plus a softmax for
            ``multiclass``. The model is put back in train mode afterwards.
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {"X_wide": X_wide, "X_deep": X_deep}
            if X_text is not None:
                load_dict.update({"X_text": X_text})
            if X_img is not None:
                load_dict.update({"X_img": X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        # len(DataLoader) is the number of batches (rounded up). The previous
        # ``len(dataset) // batch_size + 1`` overcounted by one whenever the
        # dataset size was a multiple of the batch size.
        test_steps = len(test_loader)

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = self._activation_fn(self.forward(X))
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds = preds.cpu().data.numpy()
                    preds_l.append(preds)
        self.train()
        return preds_l