wide_deep.py 42.9 KB
Newer Older
1
import numpy as np
2
import os
3
import torch
4 5 6
import torch.nn as nn
import torch.nn.functional as F

7 8 9 10 11 12 13 14 15 16 17
from ..wdtypes import *

from ..initializers import Initializer, MultipleInitializer
from ..callbacks import Callback, History, CallbackContainer
from ..metrics import Metric, MultipleMetrics, MetricCallback
from ..losses import FocalLoss

from ._wd_dataset import WideDeepDataset
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_lr_scheduler import MultipleLRScheduler
from ._multiple_transforms import MultipleTransforms
18
from ._warmup import WarmUp
19
from .deep_dense import dense_layer
20

J
jrzaurin 已提交
21
from tqdm import trange
22
from sklearn.model_selection import train_test_split
23
from torch.utils.data import DataLoader
24

J
jrzaurin 已提交
25

26
# Number of worker processes used by the DataLoaders built in WideDeep.fit.
# os.cpu_count() returns None when the number of CPUs cannot be determined,
# and DataLoader requires an integer num_workers, so fall back to 1.
n_cpus = os.cpu_count() or 1
# When True, WideDeep.compile moves the model to the GPU.
use_cuda = torch.cuda.is_available()


class WideDeep(nn.Module):
31 32 33 34 35 36 37 38 39
    r""" Main collector class to combine all Wide, DeepDense, DeepText and
    DeepImage models. There are two options to combine these models:
    1) Directly connecting the output of each model to the output neuron(s).
    2) Adding an FC-Head on top of the deep models. This FC-Head combines
    the outputs from the DeepDense, DeepText and DeepImage models and is
    then connected to the output neuron(s).

    Parameters
    ----------
    wide: nn.Module
        Wide model. I recommend using the Wide class in this package. However,
        a custom model can be used as long as it is consistent with the
        required architecture.
    deepdense: nn.Module
        'Deep dense' model consisting of a series of categorical features
        represented by embeddings combined with numerical (aka continuous)
        features. I recommend using the DeepDense class in this package.
        However, a custom model can be used as long as it is consistent with
        the required architecture.
    deeptext: nn.Module, Optional
        'Deep text' model for the text input. Must be an object of class
        DeepText or a custom model as long as it is consistent with the
        required architecture.
    deepimage: nn.Module, Optional
        'Deep Image' model for the image input. Must be an object of class
        DeepImage or a custom model as long as it is consistent with the
        required architecture.
    deephead: nn.Module, Optional
        Custom FC-Head. It receives the concatenated outputs of the deep
        components (deepdense, plus deeptext and deepimage if present) and
        its output is added to the output of the wide component. If None and
        head_layers is passed, an FC-Head is built from head_layers,
        head_dropout and head_batchnorm.
    head_layers: List, Optional
        Sizes of the stacked dense layers in the FC-Head e.g: [128, 64]
    head_dropout: List, Optional
        Dropout between the dense layers. e.g: [0.5, 0.5]. If not passed, no
        dropout (0.0) is used.
    head_batchnorm: Boolean, Optional
        Specifies if batch normalization should be included in the dense
        layers that form the FC-Head
    output_dim: Int
        Size of the final layer. 1 for regression and binary classification or
        'n_class' for multiclass classification

    ** While I recommend using the Wide and DeepDense classes within this
    package when building the corresponding model components, it is very
    likely that the user will want to use custom text and image models. That
    is perfectly possible: simply build them and pass them as the
    corresponding parameters. Note that the custom models MUST return a last
    layer of activations (i.e. not the final prediction) so that these
    activations are collected by WideDeep and combined accordingly. In
    addition, the models MUST also contain an attribute 'output_dim' with
    the size of these last layers of activations.

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import Wide, DeepDense, DeepText, DeepImage, WideDeep
    >>>
    >>> X_wide = torch.empty(5, 5).random_(2)
    >>> wide = Wide(wide_dim=X_wide.size(0), output_dim=1)
    >>>
    >>> X_deep = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
    >>> deep_column_idx = {k:v for v,k in enumerate(colnames)}
    >>> deepdense = DeepDense(hidden_layers=[8,4], deep_column_idx=deep_column_idx, embed_input=embed_input)
    >>>
    >>> X_text = torch.cat((torch.zeros([5,1]), torch.empty(5, 4).random_(1,4)), axis=1)
    >>> deeptext = DeepText(vocab_size=4, hidden_dim=4, n_layers=1, padding_idx=0, embed_dim=4)
    >>>
    >>> X_img = torch.rand((5,3,224,224))
    >>> deepimage = DeepImage(head_layers=[512, 64, 8])
    >>>
    >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage, output_dim=1)
    >>> input_dict = {'wide':X_wide, 'deepdense':X_deep, 'deeptext':X_text, 'deepimage':X_img}
    >>> model(X=input_dict)
    tensor([[-0.3779],
            [-0.5247],
            [-0.2773],
            [-0.2888],
            [-0.2010]], grad_fn=<AddBackward0>)
    """
J
jrzaurin 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121

    def __init__(
        self,
        wide: nn.Module,
        deepdense: nn.Module,
        output_dim: int = 1,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_layers: Optional[List[int]] = None,
        head_dropout: Optional[List] = None,
        head_batchnorm: Optional[bool] = None,
    ):
        r"""Assemble the wide and deep components.

        Three configurations are possible:

        - ``deephead`` is passed: it is used as-is on top of the concatenated
          outputs of the deep components.
        - ``deephead`` is None and ``head_layers`` is passed: an FC-head is
          built from ``head_layers``, ``head_dropout`` and ``head_batchnorm``,
          ending in an ``nn.Linear`` to ``output_dim``.
        - neither is passed: each deep component is wrapped in an
          ``nn.Sequential`` followed by an ``nn.Linear`` to ``output_dim`` so
          its output can be added directly to the wide output.

        Raises
        ------
        ValueError
            If ``head_dropout`` is passed with fewer values than the number of
            dense layers in the FC-head (``len(head_layers)``).
        """
        super(WideDeep, self).__init__()

        # The main 5 components of the wide and deep assemble
        self.wide = wide
        self.deepdense = deepdense
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deephead is None:
            if head_layers is not None:
                # The head receives the concatenation of the outputs of every
                # deep component, so its input size is the sum of their
                # 'output_dim' attributes
                input_dim: int = self.deepdense.output_dim  # type:ignore
                if self.deeptext is not None:
                    input_dim += self.deeptext.output_dim
                if self.deepimage is not None:
                    input_dim += self.deepimage.output_dim
                head_layers = [input_dim] + head_layers
                n_dense = len(head_layers) - 1
                if not head_dropout:
                    head_dropout = [0.0] * n_dense
                elif len(head_dropout) < n_dense:
                    # fail early with a clear message instead of an
                    # IndexError while building the layers below
                    raise ValueError(
                        "head_dropout must contain one value per dense layer in "
                        "the head: expected {} values, got {}".format(
                            n_dense, len(head_dropout)
                        )
                    )
                self.deephead = nn.Sequential()
                for i in range(1, len(head_layers)):
                    self.deephead.add_module(
                        "head_layer_{}".format(i - 1),
                        dense_layer(
                            head_layers[i - 1],
                            head_layers[i],
                            head_dropout[i - 1],
                            head_batchnorm,
                        ),
                    )
                self.deephead.add_module(
                    "head_out", nn.Linear(head_layers[-1], output_dim)
                )
            else:
                # No head: connect every deep component directly to the
                # output neuron(s)
                self.deepdense = nn.Sequential(
                    self.deepdense, nn.Linear(self.deepdense.output_dim, output_dim)  # type: ignore
                )
                if self.deeptext is not None:
                    self.deeptext = nn.Sequential(
                        self.deeptext, nn.Linear(self.deeptext.output_dim, output_dim)  # type: ignore
                    )
                if self.deepimage is not None:
                    self.deepimage = nn.Sequential(
                        self.deepimage, nn.Linear(self.deepimage.output_dim, output_dim)  # type: ignore
                    )
168

J
jrzaurin 已提交
169
    def forward(self, X: Dict[str, Tensor]) -> Tensor:  # type: ignore
        r"""
        Parameters
        ----------
        X: Dict
            Dict where the keys are the model component names ('wide',
            'deepdense', 'deeptext' and 'deepimage') and the values are the
            corresponding input Tensors. The 'deeptext' and 'deepimage' keys
            are only read when the corresponding component exists.

        Returns
        -------
        Tensor
            The output of the wide component with the deep-side
            contribution(s) added to it
        """
        # Wide output: direct connection to the output neuron(s)
        out = self.wide(X["wide"])

        # Deep output with a head: the activations of all deep components are
        # concatenated along dim 1 and passed through the FC-head
        if self.deephead is not None:
            deep_outputs = [self.deepdense(X["deepdense"])]
            if self.deeptext is not None:
                deep_outputs.append(self.deeptext(X["deeptext"]))
            if self.deepimage is not None:
                deep_outputs.append(self.deepimage(X["deepimage"]))
            deepside = torch.cat(deep_outputs, dim=1)
            return out.add_(self.deephead(deepside))

        # Deep output without a head: every component already ends in a
        # Linear layer to output_dim, so its output is added directly
        out.add_(self.deepdense(X["deepdense"]))
        if self.deeptext is not None:
            out.add_(self.deeptext(X["deeptext"]))
        if self.deepimage is not None:
            out.add_(self.deepimage(X["deepimage"]))
        return out

J
jrzaurin 已提交
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
    def compile(
        self,
        method: str,
        optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
        lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
        initializers: Optional[Dict[str, Initializer]] = None,
        transforms: Optional[List[Transforms]] = None,
        callbacks: Optional[List[Callback]] = None,
        metrics: Optional[List[Metric]] = None,
        class_weight: Optional[Union[float, List[float], Tuple[float]]] = None,
        with_focal_loss: bool = False,
        alpha: float = 0.25,
        gamma: float = 2,
        verbose: int = 1,
        seed: int = 1,
    ):
        r"""
        Function to set a number of attributes that will be used during the
        training process.

        Parameters
        ----------
        method: Str
            One of ('regression', 'binary' or 'multiclass')
        optimizers: Optimizer, Dict. Optional, Default=AdamW
            Either an optimizer object (e.g. torch.optim.Adam()) or a
            dictionary whose keys are the model's children (i.e. 'wide',
            'deepdense', 'deeptext', 'deepimage' and/or 'deephead') and whose
            values are the corresponding optimizers. If a dictionary is
            passed it MUST contain an optimizer per child. If None,
            torch.optim.AdamW is used for all the model parameters.
        lr_schedulers: LRScheduler, Dict. Optional. Default=None
            Either a LRScheduler object (e.g
            torch.optim.lr_scheduler.StepLR(opt, step_size=5)) or a dictionary
            whose keys are the model's children (i.e. 'wide', 'deepdense',
            'deeptext', 'deepimage' and/or 'deephead') and whose values are
            the corresponding learning rate schedulers. Any dictionary
            (including one with a single entry) is wrapped in a
            MultipleLRScheduler.
        initializers: Dict, Optional. Default=None
            Dict whose keys are the model's children (i.e. 'wide',
            'deepdense', 'deeptext', 'deepimage' and/or 'deephead') and whose
            values are the corresponding initializers.
        transforms: List, Optional. Default=None
            List with torchvision.transforms to be applied to the image
            component of the model (i.e. 'deepimage')
        callbacks: List, Optional. Default=None
            Callbacks available are: ModelCheckpoint, EarlyStopping, and
            LRHistory. The History callback is used by default. Callback
            classes (not instances) are instantiated with no arguments.
        metrics: List, Optional. Default=None
            Metrics available are: BinaryAccuracy and CategoricalAccuracy
        class_weight: List, Tuple, Float. Optional. Default=None
            Can be one of: a float in [0, 1] for binary classification
            problems, in which case the class weights are set to
            [1 - class_weight, class_weight] (e.g. 0.9 gives [0.1, 0.9]); or
            a list or tuple with weights for the different classes in
            multiclass classification problems (e.g. [1., 2., 3.]). List or
            tuple weights do not necessarily need to be normalised. If your
            loss function uses reduction='mean', the loss will be normalized
            by the sum of the corresponding weights for each element. If you
            are using reduction='none', you would have to take care of the
            normalization yourself. See here:
            https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10
        with_focal_loss: Boolean, Optional. Default=False
            Use the Focal Loss. https://arxiv.org/pdf/1708.02002.pdf
        alpha, gamma: Float. Default=0.25, 2
            Focal Loss parameters. See: https://arxiv.org/pdf/1708.02002.pdf
            Only stored when with_focal_loss is True.
        verbose: Int
            Setting it to 0 will print nothing during training.
        seed: Int, Default=1
            Random seed to be used throughout all the methods

        Raises
        ------
        ValueError
            If class_weight is a float outside [0, 1]
        AssertionError
            If optimizers is a dictionary without an optimizer for one of the
            model's children

        Attributes
        ----------
        Attributes that are not direct assignations of parameters

        self.cyclic: Boolean
            Indicates if any of the lr_schedulers is cyclic (i.e. CyclicLR or
            OneCycleLR)

        Example
        --------
        Assuming you have already built the model components (wide, deepdense, etc...)

        >>> from pytorch_widedeep.models import WideDeep
        >>> from pytorch_widedeep.initializers import *
        >>> from pytorch_widedeep.callbacks import *
        >>> from pytorch_widedeep.optim import RAdam
        >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deepdense.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt  = RAdam(model.deepimage.parameters())
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch  = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {'wide': wide_opt, 'deepdense':deep_opt, 'deeptext':text_opt, 'deepimage': img_opt}
        >>> schedulers = {'wide': wide_sch, 'deepdense':deep_sch, 'deeptext':text_sch, 'deepimage': img_sch}
        >>> initializers = {'wide': Uniform, 'deepdense':Normal, 'deeptext':KaimingNormal,
        ...     'deepimage':KaimingUniform}
        >>> transforms = [ToTensor, Normalize(mean=mean, std=std)]
        >>> callbacks = [LRHistory, EarlyStopping, ModelCheckpoint(filepath='model_weights/wd_out.pt')]
        >>> model.compile(method='regression', initializers=initializers, optimizers=optimizers,
        ...     lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """
        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            # a negative weight would be produced for values outside [0, 1]
            if not 0.0 <= class_weight <= 1.0:
                raise ValueError(
                    "A float class_weight must be in [0, 1], got {}".format(
                        class_weight
                    )
                )
            self.class_weight = torch.tensor([1.0 - class_weight, class_weight])
        elif isinstance(class_weight, (tuple, list)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        if optimizers is not None:
            if isinstance(optimizers, Optimizer):
                self.optimizer: Union[Optimizer, MultipleOptimizer] = optimizers
            else:
                # Any dict (whatever its size) must provide one optimizer per
                # child; previously a single-entry dict left self.optimizer
                # unset and training failed later with an AttributeError
                opt_names = list(optimizers.keys())
                mod_names = [n for n, c in self.named_children()]
                for mn in mod_names:
                    assert mn in opt_names, "No optimizer found for {}".format(mn)
                self.optimizer = MultipleOptimizer(optimizers)
        else:
            self.optimizer = torch.optim.AdamW(self.parameters())  # type: ignore

        if lr_schedulers is not None:
            if isinstance(lr_schedulers, LRScheduler):
                self.lr_scheduler: Union[
                    LRScheduler, MultipleLRScheduler
                ] = lr_schedulers
                self.cyclic = "cycl" in self.lr_scheduler.__class__.__name__.lower()
            else:
                # Any dict, including a single-entry one, is wrapped so that
                # self.lr_scheduler and self.cyclic are always set
                self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
                scheduler_names = [
                    sc.__class__.__name__.lower()
                    for sc in self.lr_scheduler._schedulers.values()
                ]
                self.cyclic = any("cycl" in sn for sn in scheduler_names)
        else:
            self.lr_scheduler, self.cyclic = None, False

        if transforms is not None:
            self.transforms: MultipleTransforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        # History is always the first callback
        self.history = History()
        self.callbacks: List = [self.history]
        if callbacks is not None:
            for callback in callbacks:
                # callback classes are instantiated with default arguments
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

        if use_cuda:
            self.cuda()

    def fit(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
        n_epochs: int = 1,
        validation_freq: int = 1,
        batch_size: int = 32,
        patience: int = 10,
        warm_up: bool = False,
        warm_epochs: int = 4,
        warm_max_lr: float = 0.01,
        warm_deeptext_gradual: bool = False,
        warm_deeptext_max_lr: float = 0.01,
        warm_deeptext_layers: Optional[List[nn.Module]] = None,
        warm_deepimage_gradual: bool = False,
        warm_deepimage_max_lr: float = 0.01,
        warm_deepimage_layers: Optional[List[nn.Module]] = None,
        warm_routine: str = "howard",
    ):
        r"""
        fit method that must run after calling 'compile'

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            One hot encoded wide input.
        X_deep: np.ndarray, Optional. Default=None
            Input for the deepdense model
        X_text: np.ndarray, Optional. Default=None
            Input for the deeptext model
        X_img : np.ndarray, Optional. Default=None
            Input for the deepimage model
        X_train: Dict, Optional. Default=None
            Training dataset for the different model branches. Keys are
            'X_wide', 'X_deep', 'X_text', 'X_img' and 'target' and the values
            are the corresponding matrices e.g X_train = {'X_wide': X_wide,
            'X_deep': X_deep, 'X_text': X_text, 'X_img': X_img,
            'target': target}
        X_val: Dict, Optional. Default=None
            Validation dataset for the different model branches. Keys are
            'X_wide', 'X_deep', 'X_text', 'X_img' and 'target' and the values
            are the corresponding matrices e.g X_val = {'X_wide': X_wide,
            'X_deep': X_deep, 'X_text': X_text, 'X_img': X_img,
            'target': target}
        val_split: Float, Optional. Default=None
            train/val split
        target: np.ndarray, Optional. Default=None
            target values
        n_epochs: Int, Default=1
            Number of training epochs
        validation_freq: Int, Default=1
            The validation step runs every 'validation_freq' epochs (i.e. at
            epochs validation_freq, 2*validation_freq, ...), provided there
            is a validation set
        batch_size: Int, Default=32
            Batch size used for both the training and the validation loaders
        patience: Int, Default=10
            Not used by this method. Training stops early when a callback
            sets 'early_stop' on the model (presumably the EarlyStopping
            callback passed to 'compile')
        warm_up: Boolean, Default=False
            warm up model components individually before the joined training
        warm_epochs: Int, Default=4
            Number of warm up epochs for those model components that will not
            be gradually warmed up
        warm_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for those model components that will not be gradually warmed up
        warm_deeptext_gradual: Boolean, Default=False
            Boolean indicating if the deeptext component will be warmed
            up gradually
        warm_deeptext_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deeptext component
        warm_deeptext_layers: Optional, List, Default=None
            List of nn.Modules that will be warmed up gradually. These have to
            be in 'warm-up-order': the layers or blocks close to the output
            neuron(s) first
        warm_deepimage_gradual: Boolean, Default=False
            Boolean indicating if the deepimage component will be warmed
            up gradually
        warm_deepimage_max_lr: Float, Default=0.01
            Maximum learning rate during the Triangular Learning rate cycle
            for the deepimage component
        warm_deepimage_layers: Optional, List, Default=None
            List of nn.Modules that will be warmed up gradually. These have to
            be in 'warm-up-order': the layers or blocks close to the output
            neuron(s) first
        warm_routine: Str, Default='howard'
            Warm up routine. One of 'felbo' or 'howard'. See the WarmUp class
            documentation for details

        **WideDeep assumes that X_wide, X_deep and target ALWAYS exist, while
        X_text and X_img are optional
        **Either X_train or X_wide, X_deep and target must be passed to the
        fit method

        Raises
        ------
        ValueError
            If neither X_train nor all of X_wide, X_deep and target are passed

        Example
        --------
        Assuming you have already built and compiled the model

        Ex 1. using train input arrays directly and no validation
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256)

        Ex 2: using train input arrays directly and validation with val_split
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)

        Ex 3: using train dict and val_split
        >>> X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': y}
        >>> model.fit(X_train=X_train, n_epochs=10, batch_size=256, val_split=0.2)

        Ex 4: validation using training and validation dicts
        >>> X_train = {'X_wide': X_wide_tr, 'X_deep': X_deep_tr, 'target': y_tr}
        >>> X_val = {'X_wide': X_wide_val, 'X_deep': X_deep_val, 'target': y_val}
        >>> model.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=256)
        """

        if X_train is None and (X_wide is None or X_deep is None or target is None):
            raise ValueError(
                "Training data is missing. Either a dictionary (X_train) with "
                "the training dataset or at least 3 arrays (X_wide, X_deep, "
                "target) must be passed to the fit method"
            )

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(
            X_wide, X_deep, X_text, X_img, X_train, X_val, val_split, target
        )
        # NOTE(review): 'shuffle' is not set here, so DataLoader's default
        # (shuffle=False) applies and batches come in the same order every
        # epoch. Confirm this is intended.
        train_loader = DataLoader(
            dataset=train_set, batch_size=batch_size, num_workers=n_cpus
        )
        if warm_up:
            # warm up...
            self._warm_up(
                train_loader,
                warm_epochs,
                warm_max_lr,
                warm_deeptext_gradual,
                warm_deeptext_layers,
                warm_deeptext_max_lr,
                warm_deepimage_gradual,
                warm_deepimage_layers,
                warm_deepimage_max_lr,
                warm_routine,
            )

        # The evaluation loader does not depend on the epoch: build it once
        # instead of once per validation epoch
        eval_loader = None
        if eval_set is not None:
            eval_loader = DataLoader(
                dataset=eval_set,
                batch_size=batch_size,
                num_workers=n_cpus,
                shuffle=False,
            )
            eval_steps = len(eval_loader)

        train_steps = len(train_loader)
        self.callback_container.on_train_begin(
            {"batch_size": batch_size, "train_steps": train_steps, "n_epochs": n_epochs}
        )
        if self.verbose:
            print("Training")
        # defined up front so on_train_end receives a dict even if n_epochs=0
        epoch_logs: Dict[str, float] = {}
        for epoch in range(n_epochs):
            # train step...
            epoch_logs = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
                # 'y' (not 'target') so the fit parameter is not shadowed
                for batch_idx, (data, y) in zip(t, train_loader):
                    t.set_description("epoch %i" % (epoch + 1))
                    acc, train_loss = self._training_step(data, y, batch_idx)
                    if acc is not None:
                        t.set_postfix(metrics=acc, loss=train_loss)
                    else:
                        t.set_postfix(loss=np.sqrt(train_loss))
                    if self.lr_scheduler:
                        self._lr_scheduler_step(step_location="on_batch_end")
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs["train_loss"] = train_loss
            if acc is not None:
                epoch_logs["train_acc"] = acc["acc"]
            # eval step (every 'validation_freq' epochs, 1-based)...
            if (
                epoch % validation_freq == (validation_freq - 1)
                and eval_loader is not None
            ):
                self.valid_running_loss = 0.0
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for i, (data, y) in zip(v, eval_loader):
                        v.set_description("valid")
                        acc, val_loss = self._validation_step(data, y, i)
                        if acc is not None:
                            v.set_postfix(metrics=acc, loss=val_loss)
                        else:
                            v.set_postfix(loss=np.sqrt(val_loss))
                epoch_logs["val_loss"] = val_loss
                if acc is not None:
                    epoch_logs["val_acc"] = acc["acc"]
            if self.lr_scheduler:
                self._lr_scheduler_step(step_location="on_epoch_end")
            #  log and check if early_stop...
            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                break
        # Training is over (all epochs run or early stop): notify the
        # callbacks exactly once. Previously this call sat inside the epoch
        # loop and fired at the end of every epoch.
        self.callback_container.on_train_end(epoch_logs)
        self.train()

J
jrzaurin 已提交
573 574 575 576 577 578 579 580
    def predict(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""
        Predict the target for a test dataset.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            One hot encoded wide input.
        X_deep: np.ndarray, Optional. Default=None
            Input for the deepdense model
        X_text: np.ndarray, Optional. Default=None
            Input for the deeptext model
        X_img : np.ndarray, Optional. Default=None
            Input for the deepimage model
        X_test: Dict, Optional. Default=None
            Testing dataset for the different model branches. Keys are
            'X_wide', 'X_deep', 'X_text' and 'X_img' and the values are the
            corresponding matrices e.g X_test = {'X_wide': X_wide,
            'X_deep': X_deep, 'X_text': X_text, 'X_img': X_img}

        **Either X_test or at least X_wide and X_deep must be passed. X_text
        and X_img are optional

        Returns
        -------
        preds: np.ndarray
            Predicted target for the test dataset: the raw predictions for
            'regression', 0/1 labels (threshold 0.5) for 'binary' and the
            argmax over axis 1 for 'multiclass'. None for any other method.

        Raises
        ------
        ValueError
            If neither X_test nor both X_wide and X_deep are passed
        """
        if X_test is None and (X_wide is None or X_deep is None):
            raise ValueError(
                "Test data is missing. Either a dictionary (X_test) with the "
                "test dataset or at least 2 arrays (X_wide, X_deep) must be "
                "passed to the predict method"
            )
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype("int")
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)
616

J
jrzaurin 已提交
617 618 619 620 621 622 623 624
    def predict_proba(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> np.ndarray:
        r"""
        Predict class probabilities for a test dataset.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            One hot encoded wide input.
        X_deep: np.ndarray, Optional. Default=None
            Input for the deepdense model
        X_text: np.ndarray, Optional. Default=None
            Input for the deeptext model
        X_img : np.ndarray, Optional. Default=None
            Input for the deepimage model
        X_test: Dict, Optional. Default=None
            Testing dataset for the different model branches. Keys are
            'X_wide', 'X_deep', 'X_text' and 'X_img' and the values are the
            corresponding matrices

        **Either X_test or at least X_wide and X_deep must be passed

        Returns
        -------
        preds: np.ndarray
            Predicted probabilities of target for the test dataset for binary
            and multiclass methods. For 'binary' an array with two columns:
            [1 - p, p]. None for any other method.

        Raises
        ------
        ValueError
            If neither X_test nor both X_wide and X_deep are passed
        """
        if X_test is None and (X_wide is None or X_deep is None):
            raise ValueError(
                "Test data is missing. Either a dictionary (X_test) with the "
                "test dataset or at least 2 arrays (X_wide, X_deep) must be "
                "passed to the predict_proba method"
            )
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            # column 0: probability of class 0, column 1: of class 1
            probs = np.zeros([preds.shape[0], 2])
            probs[:, 0] = 1 - preds
            probs[:, 1] = preds
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)
641

J
jrzaurin 已提交
642 643 644
    def get_embeddings(
        self, col_name: str, cat_encoding_dict: Dict[str, Dict[str, int]]
    ) -> Dict[str, np.ndarray]:
        r"""
        Get the learned embeddings for the categorical features passed through deepdense.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict
            Categorical encodings. The function is designed to take the
            'encoding_dict' attribute from the DeepPreprocessor class. Any
            Dict with the same structure can be used

        Returns
        -------
        cat_embed_dict: Dict
            Categorical levels of the col_name feature and the corresponding
            embeddings

        Raises
        ------
        ValueError
            If no model parameter has a name containing both 'embed_layers'
            and ``col_name``

        Example:
        -------
        Assuming we have already trained the model:

        >>> model.get_embeddings(col_name='education', cat_encoding_dict=deep_preprocessor.encoding_dict)
        {'11th': array([-0.42739448, -0.22282735,  0.36969638,  0.4445322 ,  0.2562272 ,
        0.11572784, -0.01648579,  0.09027119,  0.0457597 , -0.28337458], dtype=float32),
         'HS-grad': array([-0.10600474, -0.48775527,  0.3444158 ,  0.13818645, -0.16547225,
        0.27409762, -0.05006042, -0.0668492 , -0.11047247,  0.3280354 ], dtype=float32),
        ...
        }

        where:

        >>> deep_preprocessor.encoding_dict['education']
        {'11th': 0, 'HS-grad': 1, 'Assoc-acdm': 2, 'Some-college': 3, '10th': 4, 'Prof-school': 5,
        '7th-8th': 6, 'Bachelors': 7, 'Masters': 8, 'Doctorate': 9, '5th-6th': 10, 'Assoc-voc': 11,
        '9th': 12, '12th': 13, '1st-4th': 14, 'Preschool': 15}
        """
        embed_mtx = None
        for name, param in self.named_parameters():
            # The embedding matrix is located by substring match on the
            # parameter name; if several parameters match, the last one wins.
            # NOTE(review): a column whose name is a substring of another
            # column's name may match the wrong layer - confirm naming scheme
            if "embed_layers" in name and col_name in name:
                embed_mtx = param.detach().cpu().numpy()
        if embed_mtx is None:
            # without this check the lookup below fails with UnboundLocalError
            raise ValueError(f"No embedding layer found for column '{col_name}'")

        encoding_dict = cat_encoding_dict[col_name]
        # invert the encoding so each integer index maps back to its level
        inv_encoding_dict = {idx: level for level, idx in encoding_dict.items()}
        return {level: embed_mtx[idx] for idx, level in inv_encoding_dict.items()}

    def _activation_fn(self, inp: Tensor) -> Tensor:
        if self.method == "binary":
694
            return torch.sigmoid(inp)
695 696 697 698
        else:
            # F.cross_entropy will apply logSoftmax to the preds in the case
            # of 'multiclass'
            return inp
699

J
jrzaurin 已提交
700
    def _loss_fn(self, y_pred: Tensor, y_true: Tensor) -> Tensor:  # type: ignore
701 702
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
J
jrzaurin 已提交
703
        if self.method == "regression":
704
            return F.mse_loss(y_pred, y_true.view(-1, 1))
J
jrzaurin 已提交
705 706 707 708 709
        if self.method == "binary":
            return F.binary_cross_entropy(
                y_pred, y_true.view(-1, 1), weight=self.class_weight
            )
        if self.method == "multiclass":
710 711
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

J
jrzaurin 已提交
712 713 714 715 716 717 718 719 720 721 722
    def _train_val_split(
        self,
        X_wide: Optional[np.ndarray] = None,
        X_deep: Optional[np.ndarray] = None,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_train: Optional[Dict[str, np.ndarray]] = None,
        X_val: Optional[Dict[str, np.ndarray]] = None,
        val_split: Optional[float] = None,
        target: Optional[np.ndarray] = None,
    ):
        r"""
        Build the train (and, optionally, validation) WideDeepDataset objects.

        If a validation set (X_val) is passed to the fit method, or val_split
        is specified, the train/val split will happen internally. A number of
        options are allowed in terms of data inputs. For parameter
        information, please, see the .fit() method documentation

        When splitting with val_split, all present inputs (wide, deep,
        target and, if given, text and image) are split in a single
        train_test_split call, so every model component receives exactly the
        same rows. Inputs with inconsistent numbers of rows raise an error
        instead of being silently dropped.

        Returns
        -------
        train_set: WideDeepDataset
            WideDeepDataset object that will be loaded through
            torch.utils.data.DataLoader
        eval_set : WideDeepDataset or None
            WideDeepDataset object that will be loaded through
            torch.utils.data.DataLoader. None when neither X_val nor
            val_split is given
        """

        def unpack(train_dict, text, img):
            # the text/image entries are optional in the dictionary; fall
            # back on the individually passed arrays when they are absent
            return (
                train_dict["X_wide"],
                train_dict["X_deep"],
                train_dict["target"],
                train_dict.get("X_text", text),
                train_dict.get("X_img", img),
            )

        #  Without validation
        if X_val is None and val_split is None:
            if X_train is not None:
                X_wide, X_deep, target, X_text, X_img = unpack(X_train, X_text, X_img)
            # X_text and X_img are always forwarded, possibly as None
            train_set = WideDeepDataset(  # type: ignore
                X_wide=X_wide,
                X_deep=X_deep,
                target=target,
                X_text=X_text,
                X_img=X_img,
                transforms=self.transforms,
            )
            return train_set, None

        #  With validation
        if X_val is not None:
            # if a validation dictionary is passed but no train dictionary,
            # build the train dictionary with the input arrays
            if X_train is None:
                X_train = {"X_wide": X_wide, "X_deep": X_deep, "target": target}
                if X_text is not None:
                    X_train["X_text"] = X_text
                if X_img is not None:
                    X_train["X_img"] = X_img
        else:
            if X_train is not None:
                X_wide, X_deep, target, X_text, X_img = unpack(X_train, X_text, X_img)
            # Split every available input in ONE call so all components are
            # guaranteed to share the same row indices (also when self.seed
            # is not a plain int) and mismatched lengths raise instead of the
            # text/image data being silently discarded.
            names = ["X_wide", "X_deep", "target"]
            arrays = [X_wide, X_deep, target]
            for name, arr in (("X_text", X_text), ("X_img", X_img)):
                if arr is not None:
                    names.append(name)
                    arrays.append(arr)
            splits = train_test_split(
                *arrays,
                test_size=val_split,
                random_state=self.seed,
                stratify=target if self.method != "regression" else None,
            )
            # train_test_split returns [a_train, a_test, b_train, b_test, ...]
            X_train = dict(zip(names, splits[0::2]))
            X_val = dict(zip(names, splits[1::2]))

        # At this point the X_train and X_val dictionaries have been built
        train_set = WideDeepDataset(**X_train, transforms=self.transforms)  # type: ignore
        eval_set = WideDeepDataset(**X_val, transforms=self.transforms)  # type: ignore
        return train_set, eval_set

    def _warm_up(
        self,
        loader: DataLoader,
        n_epochs: int,
        max_lr: float,
        deeptext_gradual: bool,
        deeptext_layers: List[nn.Module],
        deeptext_max_lr: float,
        deepimage_gradual: bool,
        deepimage_layers: List[nn.Module],
        deepimage_max_lr: float,
        routine: str = "felbo",
    ):
        r"""
        Simple wrapper to individually warm up the model components.

        The wide and deepdense components are always warmed up with
        ``WarmUp.warm_all`` using ``n_epochs`` and ``max_lr``. The deeptext
        and deepimage components, when present (not None), are warmed up
        with ``WarmUp.warm_gradual`` if the corresponding ``*_gradual`` flag
        is set. That call uses their own ``*_max_lr``, ``*_layers`` and
        ``routine``. Otherwise they get ``WarmUp.warm_all`` with ``n_epochs``
        and the global ``max_lr``.

        Parameters
        ----------
        loader: DataLoader
            DataLoader forwarded to every WarmUp call
        n_epochs: int
            Forwarded to every ``WarmUp.warm_all`` call
        max_lr: float
            Forwarded to every ``WarmUp.warm_all`` call
        deeptext_gradual / deepimage_gradual: bool
            Whether to use ``WarmUp.warm_gradual`` for deeptext / deepimage
        deeptext_layers / deepimage_layers: List[nn.Module]
            Layers forwarded to ``WarmUp.warm_gradual``
        deeptext_max_lr / deepimage_max_lr: float
            Learning rates forwarded to ``WarmUp.warm_gradual``
        routine: str, default="felbo"
            Routine name forwarded to ``WarmUp.warm_gradual``

        Raises
        ------
        ValueError
            If the model has a fully connected 'DeepHead'
        """
        if self.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'"
            )
        # This is not the most elegant solution, but is a solution "in-between"
        # a non elegant one and re-factoring the whole code
        warmer = WarmUp(
            self._activation_fn, self._loss_fn, self.metric, self.method, self.verbose
        )
        warmer.warm_all(self.wide, "wide", loader, n_epochs, max_lr)
        warmer.warm_all(self.deepdense, "deepdense", loader, n_epochs, max_lr)
        # Explicit None checks: a Module's truthiness can come from __len__
        # (e.g. nn.Sequential), which says nothing about whether it exists.
        if self.deeptext is not None:
            if deeptext_gradual:
                warmer.warm_gradual(
                    self.deeptext,
                    "deeptext",
                    loader,
                    deeptext_max_lr,
                    deeptext_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deeptext, "deeptext", loader, n_epochs, max_lr)
        if self.deepimage is not None:
            if deepimage_gradual:
                warmer.warm_gradual(
                    self.deepimage,
                    "deepimage",
                    loader,
                    deepimage_max_lr,
                    deepimage_layers,
                    routine,
                )
            else:
                warmer.warm_all(self.deepimage, "deepimage", loader, n_epochs, max_lr)

    def _lr_scheduler_step(self, step_location: str):
887 888 889
        r"""
        Function to execute the learning rate schedulers steps.
        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
890
        must happen after training each bach durig training. On the other
891 892 893 894 895 896 897 898
        hand, if the  scheduler is not Cyclic, is expected to be called after
        validation.

        Parameters
        ----------
        step_location: Str
            Indicates where to run the lr_scheduler step
        """
J
jrzaurin 已提交
899 900 901 902 903 904 905 906 907 908 909 910
        if (
            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
            and self.cyclic
        ):
            if step_location == "on_batch_end":
                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
            elif step_location == "on_epoch_end":
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
                    if "cycl" not in scheduler.__class__.__name__.lower():
                        scheduler.step()  # type: ignore
911
        elif self.cyclic:
J
jrzaurin 已提交
912 913 914 915 916 917 918 919 920 921 922 923 924 925 926
            if step_location == "on_batch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
            if step_location == "on_epoch_end":
                self.lr_scheduler.step()  # type: ignore
            else:
                pass
        elif step_location == "on_epoch_end":
            self.lr_scheduler.step()  # type: ignore
        else:
            pass

    def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""
        Run one optimisation step on a single batch.

        Returns
        -------
        Tuple
            (metric value, or None if no metric is set; running average of
            the training loss over the batches seen so far)
        """
        self.train()

        # every method except 'multiclass' uses a float target
        y = target if self.method == "multiclass" else target.float()
        if use_cuda:
            X = {name: tensor.cuda() for name, tensor in data.items()}
            y = y.cuda()
        else:
            X = data

        self.optimizer.zero_grad()
        y_pred = self._activation_fn(self.forward(X))
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        score = None if self.metric is None else self.metric(y_pred, y)
        return score, avg_loss

    def _validation_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
        r"""
        Evaluate the model (no gradients) on a single validation batch.

        Returns
        -------
        Tuple
            (metric value, or None if no metric is set; running average of
            the validation loss over the batches seen so far)
        """
        self.eval()
        with torch.no_grad():
            # every method except 'multiclass' uses a float target
            y = target if self.method == "multiclass" else target.float()
            if use_cuda:
                X = {name: tensor.cuda() for name, tensor in data.items()}
                y = y.cuda()
            else:
                X = data

            y_pred = self._activation_fn(self.forward(X))
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss / (batch_idx + 1)

        score = None if self.metric is None else self.metric(y_pred, y)
        return score, avg_loss

    def _predict(
        self,
        X_wide: np.ndarray,
        X_deep: np.ndarray,
        X_text: Optional[np.ndarray] = None,
        X_img: Optional[np.ndarray] = None,
        X_test: Optional[Dict[str, np.ndarray]] = None,
    ) -> List:
        r"""
        Hidden method to avoid code repetition in predict and predict_proba.
        For parameter information, please, see the .predict() method
        documentation

        Returns
        -------
        preds_l: List[np.ndarray]
            One array per batch with the activated model outputs: sigmoid for
            'binary', softmax for 'multiclass', the raw output otherwise.
            ``X_test``, when passed, is used instead of the individual arrays.
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {"X_wide": X_wide, "X_deep": X_deep}
            if X_text is not None:
                load_dict["X_text"] = X_text
            if X_img is not None:
                load_dict["X_img"] = X_img
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            num_workers=n_cpus,
            shuffle=False,
        )
        # len(loader) is the exact number of batches (ceil(n / batch_size)),
        # so the progress bar total is right even when n is a multiple of
        # batch_size (`n // batch_size + 1` would overcount by one there)
        test_steps = len(test_loader)

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for _, data in zip(t, test_loader):
                    t.set_description("predict")
                    X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
                    preds = self._activation_fn(self.forward(X))
                    if self.method == "multiclass":
                        # raw outputs are logits for 'multiclass' (see
                        # _activation_fn), so softmax turns them into probs
                        preds = F.softmax(preds, dim=1)
                    preds_l.append(preds.cpu().data.numpy())
        # NOTE: the model is always left in training mode afterwards
        self.train()
        return preds_l