wide_deep.py 46.9 KB
Newer Older
1
import os
2
import warnings
3 4

import numpy as np
5
import torch
6 7
import torch.nn as nn
import torch.nn.functional as F
8 9 10
from tqdm import trange
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
11

12
from ..losses import FocalLoss
13 14 15 16 17
from ._warmup import WarmUp
from ..metrics import Metric, MetricCallback, MultipleMetrics
from ..wdtypes import *
from ..callbacks import History, Callback, CallbackContainer
from .deep_dense import dense_layer
18
from ._wd_dataset import WideDeepDataset
19
from ..initializers import Initializer, MultipleInitializer
20 21
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_transforms import MultipleTransforms
22
from ._multiple_lr_scheduler import MultipleLRScheduler
J
jrzaurin 已提交
23

24
# Number of worker processes handed to the ``DataLoader``s below.
# ``os.cpu_count()`` returns ``None`` when the count cannot be determined, so
# fall back to 1 rather than passing ``None`` as ``num_workers``.
n_cpus = os.cpu_count() or 1
# Whether a CUDA device is available; ``compile`` moves the model to it if so.
use_cuda = torch.cuda.is_available()


class WideDeep(nn.Module):
29 30 31
    r"""Main collector class that combines all ``Wide``, ``DeepDense``,
    ``DeepText`` and ``DeepImage`` models.

32 33
    There are two options to combine these models that correspond to the two
    architectures that ``pytorch-widedeep`` can build.
34 35 36 37 38 39

        - Directly connecting the output of the model components to an output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
          This FC-Head will combine the output from the ``DeepDense``, ``DeepText`` and
          ``DeepImage`` and will be then connected to the output neuron(s).
40 41 42

    Parameters
    ----------
43
    wide: nn.Module
44 45 46 47
        Wide model. We recommend using the ``Wide`` class in this package.
        However, it is possible to use a custom model as long as is consistent
        with the required architecture, see
        :class:`pytorch_widedeep.models.wide.Wide`
48
    deepdense: nn.Module
49 50 51
        `Deep dense` model comprised by the embeddings for the categorical
        features combined with numerical (also referred as continuous)
        features. We recommend using the ``DeepDense`` class in this package.
52
        However, it is possible to use a custom model as long as it is consistent with the required
53
        architecture. See :class:`pytorch_widedeep.models.deep_dense.DeepDense`.
54
    deeptext: nn.Module, Optional
55 56 57 58
        `Deep text` model for the text input. Must be an object of class
        ``DeepText`` or a custom model as long as is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepText`
59
    deepimage: nn.Module, Optional
60 61 62 63
        `Deep Image` model for the images input. Must be an object of class
        ``DeepImage`` or a custom model as long as is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_dense.DeepImage`
64
    deephead: nn.Module, Optional
65
        `Dense` model consisting in a stack of dense layers. The FC-Head.
66
    head_layers: List, Optional
67
        Alternatively, we can use ``head_layers`` to specify the sizes of the
68
        stacked dense layers in the fc-head e.g: ``[128, 64]``
69
    head_dropout: List, Optional
70
        Dropout between the layers in ``head_layers``. e.g: ``[0.5, 0.5]``
71
    head_batchnorm: bool, Optional
72
        Specifies if batch normalization should be included in the dense layers
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
    pred_dim: int
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or `number
        of classes` for multiclass classification.


    .. note:: With the exception of ``cyclic``, all attributes are direct assignations of
        the corresponding parameters used when calling ``compile``.  Therefore,
        see the parameters at
        :class:`pytorch_widedeep.models.wide_deep.WideDeep.compile` for a full
        list of the attributes of an instance of
        :class:`pytorch_widedeep.models.wide_deep.WideDeep`


    Attributes
    ----------
    cyclic: :obj:`bool`
        Attribute that indicates if any of the lr_schedulers is cyclic (i.e. ``CyclicLR`` or
        ``OneCycleLR``). See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
92 93 94 95 96 97 98 99 100 101 102


    .. note:: While I recommend using the ``Wide`` and ``DeepDense`` classes within
        this package when building the corresponding model components, it is very
        likely that the user will want to use custom text and image models. That
        is perfectly possible. Simply, build them and pass them as the
        corresponding parameters. Note that the custom models MUST return a last
        layer of activations (i.e. not the final prediction) so that  these
        activations are collected by ``WideDeep`` and combined accordingly. In
        addition, the models MUST also contain an attribute ``output_dim`` with
        the size of these last layers of activations. See for example
103 104 105
        :class:`pytorch_widedeep.models.deep_dense.DeepDense`

    """
J
jrzaurin 已提交
106 107 108 109 110

    def __init__(
        self,
        wide: nn.Module,
        deepdense: nn.Module,
111
        pred_dim: int = 1,
J
jrzaurin 已提交
112 113 114 115 116 117 118
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_layers: Optional[List[int]] = None,
        head_dropout: Optional[List] = None,
        head_batchnorm: Optional[bool] = None,
    ):
119

120
        super(WideDeep, self).__init__()
121

122
        # check that model components have the required output_dim attribute
J
jrzaurin 已提交
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
        if not hasattr(deepdense, "output_dim"):
            raise AttributeError(
                "deepdense model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepDense"
            )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
138

139
        # The main 5 components of the wide and deep assemble
140 141
        self.wide = wide
        self.deepdense = deepdense
J
jrzaurin 已提交
142
        self.deeptext = deeptext
143
        self.deepimage = deepimage
144 145
        self.deephead = deephead

146 147 148 149 150 151 152
        if deephead is not None and head_layers is not None:
            warnings.warn(
                "both 'deephead' and 'head_layers' are not None."
                "'deephead' takes priority and will be used",
                UserWarning,
            )

153 154
        if self.deephead is None:
            if head_layers is not None:
J
jrzaurin 已提交
155
                input_dim: int = self.deepdense.output_dim  # type:ignore
M
Minjin Choi 已提交
156
                if self.deeptext is not None:
157
                    input_dim += self.deeptext.output_dim  # type:ignore
M
Minjin Choi 已提交
158
                if self.deepimage is not None:
159
                    input_dim += self.deepimage.output_dim  # type:ignore
160
                head_layers = [input_dim] + head_layers
J
jrzaurin 已提交
161 162
                if not head_dropout:
                    head_dropout = [0.0] * (len(head_layers) - 1)
163 164 165
                self.deephead = nn.Sequential()
                for i in range(1, len(head_layers)):
                    self.deephead.add_module(
J
jrzaurin 已提交
166 167 168 169 170 171 172 173 174
                        "head_layer_{}".format(i - 1),
                        dense_layer(
                            head_layers[i - 1],
                            head_layers[i],
                            head_dropout[i - 1],
                            head_batchnorm,
                        ),
                    )
                self.deephead.add_module(
175
                    "head_out", nn.Linear(head_layers[-1], pred_dim)
J
jrzaurin 已提交
176
                )
177 178
            else:
                self.deepdense = nn.Sequential(
179
                    self.deepdense, nn.Linear(self.deepdense.output_dim, pred_dim)  # type: ignore
J
jrzaurin 已提交
180
                )
181 182
                if self.deeptext is not None:
                    self.deeptext = nn.Sequential(
183
                        self.deeptext, nn.Linear(self.deeptext.output_dim, pred_dim)  # type: ignore
J
jrzaurin 已提交
184
                    )
185 186
                if self.deepimage is not None:
                    self.deepimage = nn.Sequential(
187
                        self.deepimage, nn.Linear(self.deepimage.output_dim, pred_dim)  # type: ignore
J
jrzaurin 已提交
188
                    )
189

J
jrzaurin 已提交
190
    def forward(self, X: Dict[str, Tensor]) -> Tensor:  # type: ignore
        r"""Combine the wide output with the output of the deep components.

        Parameters
        ----------
        X: Dict[str, Tensor]
            Inputs keyed by component name. ``'wide'`` and ``'deepdense'`` are
            always read; ``'deeptext'`` and ``'deepimage'`` are read only when
            the corresponding component is not ``None``.

        Returns
        -------
        Tensor
            The wide output plus either the output of the head applied to the
            concatenated deep outputs (when ``deephead`` is set) or the sum of
            the individual deep outputs.
        """
        # Wide output: direct connection to the output neuron(s)
        out = self.wide(X["wide"])

        # Deep output: either passed through a head first...
        if self.deephead is not None:
            deep_outputs = [self.deepdense(X["deepdense"])]
            if self.deeptext is not None:
                deep_outputs.append(self.deeptext(X["deeptext"]))
            if self.deepimage is not None:
                deep_outputs.append(self.deepimage(X["deepimage"]))
            return out + self.deephead(torch.cat(deep_outputs, dim=1))

        # ...or connected directly to the output neuron(s). ``Tensor.add`` is
        # not in-place, so the sum must be re-bound or the deep components
        # would be silently dropped from the prediction.
        out = out + self.deepdense(X["deepdense"])
        if self.deeptext is not None:
            out = out + self.deeptext(X["deeptext"])
        if self.deepimage is not None:
            out = out + self.deepimage(X["deepimage"])
        return out

J
jrzaurin 已提交
213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
    def compile(
        self,
        method: str,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        initializers: Optional[Dict[str, Initializer]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        with_focal_loss: bool = False,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        r"""Method to set the attributes that will be used during the
        training process.

        Parameters
        ----------
        method: str
            One of `regression`, `binary` or `multiclass`
        optimizers: Union[Optimizer, Dict[str, Optimizer]], Optional, Default=AdamW
            - An instance of ``pytorch``'s ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or
            - a dictionary where the keys are the model components (i.e.
              `'wide'`, `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`)  and
              the values are the corresponding optimizers. If a dictionary is used
              it MUST contain an optimizer per model component.

            See `Pytorch optimizers <https://pytorch.org/docs/stable/optim.html>`_.
        lr_schedulers: Union[LRScheduler, Dict[str, LRScheduler]], Optional, Default=None
            - An instance of ``pytorch``'s ``LRScheduler`` object (e.g
              :obj:`torch.optim.lr_scheduler.StepLR(opt, step_size=5)`) or
            - a dictionary where the keys are the model components (i.e. `'wide'`,
              `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
              values are the corresponding learning rate schedulers.

            See `Pytorch schedulers <https://pytorch.org/docs/stable/optim.html>`_.
        initializers: Dict[str, Initializer], Optional. Default=None
            Dict where the keys are the model components (i.e. `'wide'`,
            `'deepdense'`, `'deeptext'`, `'deepimage'` and/or `'deephead'`) and the
            values are the corresponding initializers.
            See `Pytorch initializers <https://pytorch.org/docs/stable/nn.init.html>`_.
        transforms: List[Transforms], Optional. Default=None
            ``Transforms`` is a custom type. See
            :obj:`pytorch_widedeep.wdtypes`. List with
            :obj:`torchvision.transforms` to be applied to the image component
            of the model (i.e. ``deepimage``) See `torchvision transforms
            <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.
        callbacks: List[Callback], Optional. Default=None
            Callbacks available are: ``ModelCheckpoint``, ``EarlyStopping``,
            and ``LRHistory``. The ``History`` callback is used by default.
            Callback classes (rather than instances) are instantiated with no
            arguments. See the ``Callbacks`` section in this documentation or
            :obj:`pytorch_widedeep.callbacks`
        metrics: List[Metric], Optional. Default=None
            Metrics available are: ``BinaryAccuracy`` and
            ``CategoricalAccuracy`` See the ``Metrics`` section in this
            documentation or :obj:`pytorch_widedeep.metrics`
        class_weight: Union[float, List[float], Tuple[float]]. Optional. Default=None
            - float indicating the weight of the minority class in binary classification
              problems (e.g. 9.)
            - a list or tuple with weights for the different classes in multiclass
              classification problems  (e.g. [1., 2., 3.]). The weights do
              not necessarily need to be normalised. If your loss function
              uses reduction='mean', the loss will be normalized by the sum
              of the corresponding weights for each element. If you are
              using reduction='none', you would have to take care of the
              normalization yourself. See `this discussion
              <https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
        with_focal_loss: bool, Optional. Default=False
            Boolean indicating whether to use the Focal Loss for highly imbalanced problems.
            For details on the focal loss see the `original paper
            <https://arxiv.org/pdf/1708.02002.pdf>`_.
        alpha: float. Default=0.25
            Focal Loss alpha parameter.
        gamma: float. Default=2
            Focal Loss gamma parameter.
        verbose: int
            Setting it to 0 will print nothing during training.
        seed: int, Default=1
            Random seed to be used throughout all the methods

        Example
        --------
        Assuming you have already built the model components (wide, deepdense, etc...)

        >>> from pytorch_widedeep.models import WideDeep
        >>> from pytorch_widedeep.initializers import *
        >>> from pytorch_widedeep.callbacks import *
        >>> from pytorch_widedeep.optim import RAdam
        >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deepdense.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt  = RAdam(model.deepimage.parameters())
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch  = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {'wide': wide_opt, 'deepdense':deep_opt, 'deeptext':text_opt, 'deepimage': img_opt}
        >>> schedulers = {'wide': wide_sch, 'deepdense':deep_sch, 'deeptext':text_sch, 'deepimage': img_sch}
        >>> initializers = {'wide': Uniform, 'deepdense':Normal, 'deeptext':KaimingNormal,
        >>> ... 'deepimage':KaimingUniform}
        >>> transforms = [ToTensor, Normalize(mean=mean, std=std)]
        >>> callbacks = [LRHistory, EarlyStopping, ModelCheckpoint(filepath='model_weights/wd_out.pt')]
        >>> model.compile(method='regression', initializers=initializers, optimizers=optimizers,
        >>> ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """
        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            self.class_weight = torch.tensor([1.0 - class_weight, class_weight])
        elif isinstance(class_weight, (tuple, list)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        self.optimizer: Union[Optimizer, MultipleOptimizer]
        if optimizers is None:
            self.optimizer = torch.optim.AdamW(self.parameters())  # type: ignore
        elif isinstance(optimizers, Optimizer):
            self.optimizer = optimizers
        else:
            # Dictionary of optimizers, of any size. A dictionary with a single
            # entry used to fall through and leave ``self.optimizer`` unset;
            # it now goes through the same per-component check.
            opt_names = list(optimizers.keys())
            mod_names = [n for n, _ in self.named_children()]
            for mn in mod_names:
                assert mn in opt_names, "No optimizer found for {}".format(mn)
            self.optimizer = MultipleOptimizer(optimizers)

        self.lr_scheduler: Optional[Union[LRScheduler, MultipleLRScheduler]]
        if lr_schedulers is None:
            self.lr_scheduler, self.cyclic = None, False
        elif isinstance(lr_schedulers, LRScheduler):
            self.lr_scheduler = lr_schedulers
            self.cyclic = "cycl" in self.lr_scheduler.__class__.__name__.lower()
        else:
            # Dictionary of schedulers, of any size (a single-entry dictionary
            # used to leave ``self.lr_scheduler`` and ``self.cyclic`` unset).
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            scheduler_names = [
                sc.__class__.__name__.lower()
                for _, sc in self.lr_scheduler._schedulers.items()
            ]
            self.cyclic = any("cycl" in sn for sn in scheduler_names)

        if transforms is not None:
            self.transforms: MultipleTransforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        # ``History`` is always the first callback.
        self.history = History()
        self.callbacks: List = [self.history]
        if callbacks is not None:
            for callback in callbacks:
                # Accept callback classes as well as instances.
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

        if use_cuda:
            self.cuda()

    def fit(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        patience: int = 10,
        warm_up: bool = False,
        warm_epochs: int = 4,
        warm_max_lr: float = 0.01,
        warm_deeptext_gradual: bool = False,
        warm_deeptext_max_lr: float = 0.01,
        warm_deeptext_layers: Optional[List[nn.Module]] = None,
        warm_deepimage_gradual: bool = False,
        warm_deepimage_max_lr: float = 0.01,
        warm_deepimage_layers: Optional[List[nn.Module]] = None,
        warm_routine: str = "howard",
    ):
        r"""Fit method. Must run after calling ``compile``

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_deep: np.ndarray, Optional. Default=None
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_train: Dict[str, np.ndarray], Optional. Default=None
            Training dataset for the different model components. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
        X_val: Dict, Optional. Default=None
            Validation dataset for the different model component. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'`. Values are
            the corresponding matrices.
        val_split: float, Optional. Default=None
            train/val split fraction
        target: np.ndarray, Optional. Default=None
            target values
        n_epochs: int, Default=1
            number of epochs
        validation_freq: int, Default=1
            epochs validation frequency
        batch_size: int, Default=32
        patience: int, Default=10
            Number of epochs without improving the target metric before
            the fit process stops
        warm_up: bool, Default=False
            warm up model components individually before the joined training
            starts.

            ``pytorch_widedeep`` implements 3 warm up routines.

            - Warm up all trainable layers at once. This routine is
              inspired by the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_. Using a
              Slanted Triangular learning rate (see `Leslie N. Smith paper
              <https://arxiv.org/pdf/1506.01186.pdf>`_), the process is the
              following: `i`) the learning rate will gradually increase for
              10% of the training steps from max_lr/10 to max_lr. `ii`) It
              will then gradually decrease to max_lr/10 for the remaining 90%
              of the steps. The optimizer used in the process is ``AdamW``.

            and two gradual warm up routines, where only certain layers are
            warmed up at each warm up step.

            - The so called `Felbo` gradual warm up routine, inspired by the Felbo et al., 2017
              `DeepEmoji paper <https://arxiv.org/abs/1708.00524>`_.
            - The `Howard` routine based on the work of Howard & Sebastian Ruder 2018 in their
              `ULMfit paper <https://arxiv.org/abs/1801.06146>`_.

            For details on how these routines work, please see the examples
            section in this documentation and the corresponding repo.
        warm_epochs: int, Default=4
            Number of warm up epochs for those model components that will NOT
            be gradually warmed up. Those components with gradual warm up
            follow their corresponding specific routine.
        warm_max_lr: float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for those model components that will NOT be gradually warmed up
        warm_deeptext_gradual: bool, Default=False
            Boolean indicating if the deeptext component will be warmed
            up gradually
        warm_deeptext_max_lr: float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
        warm_deeptext_layers: List, Optional, Default=None
            List of :obj:`nn.Modules` that will be warmed up gradually.

            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

        warm_deepimage_gradual: bool, Default=False
            Boolean indicating if the deepimage component will be warmed
            up gradually
        warm_deepimage_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deepimage component
        warm_deepimage_layers: List, Optional, Default=None
            List of :obj:`nn.Modules` that will be warmed up gradually.

            .. note:: These have to be in `warm-up-order`: the layers or blocks
                close to the output neuron(s) first

        warm_routine: str, Default=`howard`
            Warm up routine. One of `felbo` or `howard`. See the examples
            section in this documentation and the corresponding repo for
            details on how to use warm up routines

        Raises
        ------
        ValueError
            If neither ``X_train`` nor all three of ``X_wide``, ``X_deep`` and
            ``target`` are passed.

        Examples
        --------
        Assuming you have already built and compiled the model


        >>> # Ex 1. using train input arrays directly and no validation
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256)


        >>> # Ex 2: using train input arrays directly and validation with val_split
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)


        >>> # Ex 3: using train dict and val_split
        >>> X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': y}
        >>> model.fit(X_train=X_train, n_epochs=10, batch_size=256, val_split=0.2)


        >>> # Ex 4: validation using training and validation dicts
        >>> X_train = {'X_wide': X_wide_tr, 'X_deep': X_deep_tr, 'target': y_tr}
        >>> X_val = {'X_wide': X_wide_val, 'X_deep': X_deep_val, 'target': y_val}
        >>> model.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=256)

        .. note:: :obj:`WideDeep` assumes that `X_wide`, `X_deep` and `target` ALWAYS exist, while
            `X_text` and `X_img` are optional

        .. note:: Either `X_train` or the three `X_wide`, `X_deep` and `target` must be passed to the
            fit method

        """

        if X_train is None and (X_wide is None or X_deep is None or target is None):
            raise ValueError(
                "Training data is missing. Either a dictionary (X_train) with "
                "the training dataset or at least 3 arrays (X_wide, X_deep, "
                "target) must be passed to the fit method"
            )

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(
            X_wide, X_deep, X_text, X_img, X_train, X_val, val_split, target
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
        # The validation loader does not change between epochs: build it once.
        eval_loader = None
        if eval_set is not None:
            eval_loader = DataLoader(
                dataset=eval_set,
                batch_size=batch_size,
                num_workers=n_cpus,
                shuffle=False,
            )
        if warm_up:
            # warm up...
            self._warm_up(
                train_loader,
                warm_epochs,
                warm_max_lr,
                warm_deeptext_gradual,
                warm_deeptext_layers,
                warm_deeptext_max_lr,
                warm_deepimage_gradual,
                warm_deepimage_layers,
                warm_deepimage_max_lr,
                warm_routine,
            )
        train_steps = len(train_loader)
        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        if self.verbose:
            print("Training")
        # Defined before the loop so ``on_train_end`` below always has logs to
        # receive, even when ``n_epochs`` is 0.
        epoch_logs: Dict[str, float] = {}
        for epoch in range(n_epochs):
            # train step...
            epoch_logs = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
                # ``batch_target`` avoids shadowing the ``target`` parameter.
                for batch_idx, (data, batch_target) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
                    acc, train_loss = self._training_step(data, batch_target, batch_idx)
                    if acc is not None:
                        t.set_postfix(metrics=acc, loss=train_loss)
                    else:
                        t.set_postfix(loss=np.sqrt(train_loss))
                    if self.lr_scheduler:
                        self._lr_scheduler_step(step_location="on_batch_end")
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs["train_loss"] = train_loss
            if acc is not None:
                epoch_logs["train_acc"] = acc["acc"]
            # eval step...
            if epoch % validation_freq == (validation_freq - 1):
                if eval_loader is not None:
                    eval_steps = len(eval_loader)
                    self.valid_running_loss = 0.0
                    with trange(eval_steps, disable=self.verbose != 1) as v:
                        for i, (data, batch_target) in zip(v, eval_loader):
                            v.set_description("valid")
                            acc, val_loss = self._validation_step(data, batch_target, i)
                            if acc is not None:
                                v.set_postfix(metrics=acc, loss=val_loss)
                            else:
                                v.set_postfix(loss=np.sqrt(val_loss))
                    epoch_logs["val_loss"] = val_loss
                    if acc is not None:
                        epoch_logs["val_acc"] = acc["acc"]
            if self.lr_scheduler:
                self._lr_scheduler_step(step_location="on_epoch_end")
            #  log and check if early_stop...
            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                break
        # Signal the end of training exactly once, whether the loop ran to
        # completion or was interrupted by early stopping.
        self.callback_container.on_train_end(epoch_logs)
        self.train()

J
jrzaurin 已提交
633 634 635 636 637 638 639 640
    def predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predictions

        Either ``X_test`` or both ``X_wide`` and ``X_deep`` must be provided.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            Input for the ``wide`` model component.
            See :class:`pytorch_widedeep.preprocessing.WidePreprocessor`
        X_deep: np.ndarray, Optional. Default=None
            Input for the ``deepdense`` model component.
            See :class:`pytorch_widedeep.preprocessing.DensePreprocessor`
        X_text: np.ndarray, Optional. Default=None
            Input for the ``deeptext`` model component.
            See :class:`pytorch_widedeep.preprocessing.TextPreprocessor`
        X_img : np.ndarray, Optional. Default=None
            Input for the ``deepimage`` model component.
            See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor`
        X_test: Dict[str, np.ndarray], Optional. Default=None
            Testing dataset for the different model components. Keys are
            `'X_wide'`, `'X_deep'`, `'X_text'`, `'X_img'` and `'target'` the values are
            the corresponding matrices.

        .. note:: WideDeep assumes that `X_wide`, `X_deep` and `target` ALWAYS exist,
            while `X_text` and `X_img` are optional.

        Returns
        -------
        np.ndarray
            - ``regression``: the stacked model outputs with the second axis
              squeezed out.
            - ``binary``: integer array, 1 where the predicted probability is
              greater than 0.5 and 0 otherwise.
            - ``multiclass``: index of the highest-scoring class per row.

        Raises
        ------
        ValueError
            If neither ``X_test`` nor both ``X_wide`` and ``X_deep`` are passed.
        """
        if X_test is None and (X_wide is None or X_deep is None):
            raise ValueError(
                "Either 'X_test' or both 'X_wide' and 'X_deep' must be provided"
            )
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            # _predict's per-batch outputs are stacked and thresholded at 0.5
            probs = np.vstack(preds_l).squeeze(1)
            return (probs > 0.5).astype("int")
        if self.method == "multiclass":
            return np.argmax(np.vstack(preds_l), 1)
674

J
jrzaurin 已提交
675 676 677 678 679 680 681 682
    def predict_proba(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""Returns the predicted probabilities for the test dataset for  binary
        and multiclass methods

        Either ``X_test`` or both ``X_wide`` and ``X_deep`` must be provided.
        The parameters are the same as those of the ``predict`` method.

        Returns
        -------
        np.ndarray
            - ``binary``: array with two columns, ``1 - p`` in column 0 and
              ``p`` in column 1, where ``p`` is the model's output probability.
            - ``multiclass``: the stacked per-class probabilities.

            For any other method nothing is returned (``None``).

        Raises
        ------
        ValueError
            If neither ``X_test`` nor both ``X_wide`` and ``X_deep`` are passed.
        """
        if X_test is None and (X_wide is None or X_deep is None):
            raise ValueError(
                "Either 'X_test' or both 'X_wide' and 'X_deep' must be provided"
            )
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "binary":
            positive = np.vstack(preds_l).squeeze(1)
            # column 0 holds the complement, column 1 the model's output
            probs = np.zeros([positive.shape[0], 2])
            probs[:, 0] = 1 - positive
            probs[:, 1] = positive
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)
695

J
jrzaurin 已提交
696 697 698
    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:
        r"""Returns the learned embeddings for the categorical features passed through
        ``deepdense``.

        This method is designed to take an encoding dictionary in the same
        format as that of the :obj:`LabelEncoder` Attribute of the class
        :obj:`DensePreprocessor`. See
        :class:`pytorch_widedeep.preprocessing.DensePreprocessor` and
        :class:`pytorch_widedeep.utils.dense_utils.LabelEncoder`.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict[str, Dict[str, int]]
            Dictionary containing the categorical encodings, e.g:

            Examples
            --------
            >>> cat_encoding_dict['education']
            {'11th': 0, 'HS-grad': 1, 'Assoc-acdm': 2, 'Some-college': 3, '10th': 4, 'Prof-school': 5,
            '7th-8th': 6, 'Bachelors': 7, 'Masters': 8, 'Doctorate': 9, '5th-6th': 10, 'Assoc-voc': 11,
            '9th': 12, '12th': 13, '1st-4th': 14, 'Preschool': 15}

        Returns
        -------
        Dict[str, np.ndarray]
            Mapping from each category in ``cat_encoding_dict[col_name]`` to
            the row of the embedding matrix at that category's index.

        Raises
        ------
        ValueError
            If no parameter whose name contains both ``'embed_layers'`` and
            ``col_name`` is found.
        KeyError
            If ``col_name`` is not a key of ``cat_encoding_dict``.

        Examples
        --------

        Assuming we have already train the model and that we have the
        categorical encodings in a dictionary name ``encoding_dict``:

        >>> model.get_embeddings(col_name='education', cat_encoding_dict=encoding_dict)
        {'11th': array([-0.42739448, -0.22282735,  0.36969638,  0.4445322 ,  0.2562272 ,
        0.11572784, -0.01648579,  0.09027119,  0.0457597 , -0.28337458], dtype=float32),
         'HS-grad': array([-0.10600474, -0.48775527,  0.3444158 ,  0.13818645, -0.16547225,
        0.27409762, -0.05006042, -0.0668492 , -0.11047247,  0.3280354 ], dtype=float32),
        ...
        }
        """
        candidates = [
            (name, param)
            for name, param in self.named_parameters()
            if "embed_layers" in name and col_name in name
        ]
        if not candidates:
            raise ValueError(
                "No embedding parameters found for column '{}'".format(col_name)
            )
        # A plain substring match is ambiguous when one column name is a prefix
        # of another (e.g. 'education' vs 'education_num'). Prefer parameters
        # whose owning module path ends with ``col_name``; otherwise fall back
        # to the last substring match, which is what the loop used to keep.
        exact = [
            pair for pair in candidates if pair[0].rsplit(".", 1)[0].endswith(col_name)
        ]
        _, chosen = (exact or candidates)[-1]
        embed_mtx = chosen.detach().cpu().numpy()

        encoding_dict = cat_encoding_dict[col_name]
        return {category: embed_mtx[idx] for category, idx in encoding_dict.items()}

J
jrzaurin 已提交
746
    def _loss_fn(self, y_pred: Tensor, y_true: Tensor) -> Tensor:  # type: ignore
747 748
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
J
jrzaurin 已提交
749
        if self.method == "regression":
750
            return F.mse_loss(y_pred, y_true.view(-1, 1))
J
jrzaurin 已提交
751
        if self.method == "binary":
752
            return F.binary_cross_entropy_with_logits(
J
jrzaurin 已提交
753 754 755
                y_pred, y_true.view(-1, 1), weight=self.class_weight
            )
        if self.method == "multiclass":
756 757
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

J
jrzaurin 已提交
758 759 760 761 762 763 764 765 766 767 768
    def _train_val_split(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
    ):
        r"""
        If a validation set (X_val) is passed to the fit method, or val_split
        is specified, the train/val split will happen internally. A number of
        options are allowed in terms of data inputs. For parameter
        information, please, see the .fit() method documentation

        When both ``X_val`` and ``val_split`` are passed, ``X_val`` is used
        and no internal split is performed.

        Returns
        -------
        train_set: WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`
        eval_set : WideDeepDataset
            :obj:`WideDeepDataset` object that will be loaded through
            :obj:`torch.utils.data.DataLoader`. See
            :class:`pytorch_widedeep.models._wd_dataset`. ``None`` when
            neither ``X_val`` nor ``val_split`` is given.
        """
        #  Without validation
        if X_val is None and val_split is None:
            # if a train dictionary is passed, its entries take precedence
            # over the individual array arguments
            if X_train is not None:
                X_wide, X_deep, target = (
                    X_train["X_wide"],
                    X_train["X_deep"],
                    X_train["target"],
                )
                X_text = X_train.get("X_text", X_text)
                X_img = X_train.get("X_img", X_img)
            # The optional components are always forwarded, possibly as None
            # (unchanged behaviour) -- presumably WideDeepDataset treats None
            # as "component absent"; TODO confirm.
            X_train = {
                "X_wide": X_wide,
                "X_deep": X_deep,
                "target": target,
                "X_text": X_text,
                "X_img": X_img,
            }
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
            return train_set, None

        #  With validation
        if X_val is not None:
            # if a validation dictionary is passed, then if not train
            # dictionary is passed we build it with the input arrays
            # (either the dictionary or the arrays must be passed)
            if X_train is None:
                X_train = {"X_wide": X_wide, "X_deep": X_deep, "target": target}
                if X_text is not None:
                    X_train["X_text"] = X_text
                if X_img is not None:
                    X_train["X_img"] = X_img
        else:
            # if a train dictionary is passed, check if text and image
            # datasets are present. The train/val split using val_split
            if X_train is not None:
                X_wide, X_deep, target = (
                    X_train["X_wide"],
                    X_train["X_deep"],
                    X_train["target"],
                )
                X_text = X_train.get("X_text", X_text)
                X_img = X_train.get("X_img", X_img)

            # The same settings are used for every call so that all
            # components are split with identical arguments.
            split_kwargs = dict(
                test_size=val_split,
                random_state=self.seed,
                stratify=target if self.method != "regression" else None,
            )
            (
                X_tr_wide,
                X_val_wide,
                X_tr_deep,
                X_val_deep,
                y_tr,
                y_val,
            ) = train_test_split(X_wide, X_deep, target, **split_kwargs)
            X_train = {"X_wide": X_tr_wide, "X_deep": X_tr_deep, "target": y_tr}
            X_val = {"X_wide": X_val_wide, "X_deep": X_val_deep, "target": y_val}

            # Only absent (None) components are skipped. Genuine errors, e.g.
            # a text/image array whose length does not match the target, now
            # surface instead of silently dropping that component.
            for key, component in (("X_text", X_text), ("X_img", X_img)):
                if component is None:
                    continue
                tr_part, val_part = train_test_split(component, **split_kwargs)
                X_train[key] = tr_part
                X_val[key] = val_part

        # At this point the X_train and X_val dictionaries have been built
        train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
        eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
        return train_set, eval_set

J
jrzaurin 已提交
882 883 884 885 886 887 888 889 890 891 892 893 894
    def _warm_up(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):
895 896 897
        r"""
        Simple wrappup to individually warm up model components
        """
898 899
        if self.deephead is not None:
            raise ValueError(
J
jrzaurin 已提交
900 901
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
902 903
        # This is not the most elegant solution, but is a soluton "in-between"
        # a non elegant one and re-factoring the whole code
904
        warmer = WarmUp(self._loss_fn, self.metric, self.method, self.verbose)
J
jrzaurin 已提交
905 906
        warmer.warm_all(self.wide, "wide", loader, n_epochs, max_lr)
        warmer.warm_all(self.deepdense, "deepdense", loader, n_epochs, max_lr)
907 908
        if self.deeptext:
            if deeptext_gradual:
J
jrzaurin 已提交
909 910 911 912 913 914 915 916 917 918
                warmer.warm_gradual(
                    self.deeptext,
                    "deeptext",
                    loader,
                    deeptext_max_lr,
                    deeptext_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deeptext, "deeptext", loader, n_epochs, max_lr)
919 920
        if self.deepimage:
            if deepimage_gradual:
J
jrzaurin 已提交
921 922 923 924 925 926 927 928 929 930
                warmer.warm_gradual(
                    self.deepimage,
                    "deepimage",
                    loader,
                    deepimage_max_lr,
                    deepimage_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deepimage, "deepimage", loader, n_epochs, max_lr)
931

J
jrzaurin 已提交
932
    def _lr_scheduler_step(self, step_location: str):
933 934 935
        r"""
        Function to execute the learning rate schedulers steps.
        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
936
        must happen after training each bach durig training. On the other
937 938 939 940 941 942 943 944
        hand, if the  scheduler is not Cyclic, is expected to be called after
        validation.

        Parameters
        ----------
        step_location: Str
            Indicates where to run the lr_scheduler step
        """
J
jrzaurin 已提交
945 946 947 948 949 950 951 952 953 954 955 956
        if (
            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
            and self.cyclic
        ):
            if step_location == "on_batch_end":
                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
            elif step_location == "on_epoch_end":
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" not in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
957
        elif self.cyclic:
J
jrzaurin 已提交
958 959 960 961 962 963 964 965 966 967 968 969 970 971 972
            if step_location == "on_batch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
            if step_location == "on_epoch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif step_location == "on_epoch_end":
            self.lr_scheduler.step()  # type: ignore
        else:
            pass

    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Runs one optimisation step on a single batch.

        The model is put in training mode, the batch is moved to the GPU when
        ``use_cuda`` is set, and a forward pass, loss computation, backward
        pass and optimizer step are performed.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Batch inputs, passed as a dictionary to ``self.forward``
        target: Tensor
            Batch targets. Cast to float unless the method is ``multiclass``
        batch_idx: int
            Zero-based index of the batch within the epoch, used to average
            the running loss

        Returns
        -------
        Tuple
            ``(acc, avg_loss)``. ``acc`` is the output of ``self.metric`` for
            the ``binary`` and ``multiclass`` methods, and ``None`` when no
            metric is set or the method is ``regression``. ``avg_loss`` is the
            running training loss divided by ``batch_idx + 1``.
        """
        self.train()
        X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
        y = target.float() if self.method != "multiclass" else target
        y = y.cuda() if use_cuda else y

        self.optimizer.zero_grad()
        y_pred = self.forward(X)
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        # Default to None so that a metric combined with 'regression' (for
        # which no score is computed here) does not hit an unbound variable.
        acc = None
        if self.metric is not None:
            if self.method == "binary":
                acc = self.metric(torch.sigmoid(y_pred), y)
            elif self.method == "multiclass":
                acc = self.metric(F.softmax(y_pred, dim=1), y)
        return acc, avg_loss

J
jrzaurin 已提交
996
    def _validation_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""Evaluates the model on a single validation batch.

        The model is put in evaluation mode and the forward pass and loss are
        computed with gradients disabled. The batch is moved to the GPU when
        ``use_cuda`` is set.

        Parameters
        ----------
        data: Dict[str, Tensor]
            Batch inputs, passed as a dictionary to ``self.forward``
        target: Tensor
            Batch targets. Cast to float unless the method is ``multiclass``
        batch_idx: int
            Zero-based index of the batch within the validation pass, used to
            average the running loss

        Returns
        -------
        Tuple
            ``(acc, avg_loss)``. ``acc`` is the output of ``self.metric`` for
            the ``binary`` and ``multiclass`` methods, and ``None`` when no
            metric is set or the method is ``regression``. ``avg_loss`` is the
            running validation loss divided by ``batch_idx + 1``.
        """
        self.eval()
        with torch.no_grad():
            X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
            y = target.float() if self.method != "multiclass" else target
            y = y.cuda() if use_cuda else y

            y_pred = self.forward(X)
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        # Default to None so that a metric combined with 'regression' (for
        # which no score is computed here) does not hit an unbound variable.
        acc = None
        if self.metric is not None:
            if self.method == "binary":
                acc = self.metric(torch.sigmoid(y_pred), y)
            elif self.method == "multiclass":
                acc = self.metric(F.softmax(y_pred, dim=1), y)
        return acc, avg_loss

J
jrzaurin 已提交
1018 1019 1020 1021 1022 1023 1024 1025
    def _predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""Hidden method to avoid code repetition in predict and
        predict_proba. For parameter information, please, see the .predict()
        method documentation

        When ``X_test`` is passed it is used as is and the individual array
        arguments are ignored.

        Returns
        -------
        List
            One numpy array of model outputs per batch. A sigmoid is applied
            for the ``binary`` method and a softmax over dimension 1 for the
            ``multiclass`` method; outputs are left untouched otherwise.
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {"X_wide": X_wide, "X_deep": X_deep}
            if X_text is not None:
                load_dict["X_text"] = X_text
            if X_img is not None:
                load_dict["X_img"] = X_img
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        # len(loader) is the actual number of batches. The previous
        # ``len(dataset) // batch_size + 1`` over-counted by one whenever the
        # dataset size was an exact multiple of the batch size.
        test_steps = len(test_loader)

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for _, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = self.forward(X)
                    if self.method == "binary":
                        preds = torch.sigmoid(preds)
                    if self.method == "multiclass":
                        preds = F.softmax(preds, dim=1)
                    preds_l.append(preds.detach().cpu().numpy())
        # NOTE: the model is unconditionally switched back to training mode
        self.train()
        return preds_l