wide_deep.py 38.8 KB
Newer Older
1
import numpy as np
2
import os
3
import warnings
4
import torch
5 6 7
import torch.nn as nn
import torch.nn.functional as F

8 9 10 11 12 13 14 15 16 17 18
from ..wdtypes import *

from ..initializers import Initializer, MultipleInitializer
from ..callbacks import Callback, History, CallbackContainer
from ..metrics import Metric, MultipleMetrics, MetricCallback
from ..losses import FocalLoss

from ._wd_dataset import WideDeepDataset
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_lr_scheduler import MultipleLRScheduler
from ._multiple_transforms import MultipleTransforms
19
from ._wdmodel_type import WDModel
20
from .deep_dense import dense_layer
21

22 23
from tqdm import tqdm,trange
from sklearn.model_selection import train_test_split
24
from torch.utils.data import DataLoader
25

26
# Number of worker processes handed to every DataLoader built in this module.
# os.cpu_count() returns None when the count cannot be determined; fall back to
# 0 (load data in the main process) so DataLoader always receives an int.
n_cpus = os.cpu_count() or 0
27 28 29 30
# True when a CUDA device is available; WideDeep.compile() moves the model to
# the GPU in that case.
use_cuda = torch.cuda.is_available()


class WideDeep(nn.Module):
31 32 33 34 35 36 37 38 39
    r""" Main collector class to combine all Wide, DeepDense, DeepText and
    DeepImage models. There are two options to combine these models.
    1) Directly connecting the output of the models to an output neuron(s).
    2) Adding a FC-Head on top of the deep models. This FC-Head will combine
    the output from the DeepDense, DeepText and DeepImage and will be then
    connected to the output neuron(s)

    Parameters
    ----------
40 41 42 43 44 45
    wide: nn.Module
        Wide model. I recommend using the Wide class in this package. However,
        it can be a custom model as long as it is consistent with the required
        architecture.
    deepdense: nn.Module
        'Deep dense' model consisting in a series of categorical features
46 47 48 49
        represented by embeddings combined with numerical (aka continuous)
        features. I recommend using the DeepDense class in this package.
        However, it can be a custom model as long as it is consistent with the
        required architecture.
50 51
    deeptext: nn.Module, Optional
        'Deep text' model for the text input. Must be an object of class
52 53
        DeepText or a custom model as long as is consistent with the required
        architecture.
54 55
    deepimage: nn.Module, Optional
        'Deep Image' model for the images input. Must be an object of class
56 57
        DeepImage or a custom model as long as is consistent with the required
        architecture.
58 59 60 61 62 63 64 65 66 67 68 69
    deephead: nn.Module, Optional
        Dense model consisting in a stack of dense layers. The FC-Head
    head_layers: List, Optional
        Sizes of the stacked dense layers in the fc-head e.g: [128, 64]
    head_dropout: List, Optional
        Dropout between the dense layers. e.g: [0.5, 0.5]
    head_batchnorm: Boolean, Optional
        Whether or not to include batch normalization in the dense layers that
        form the FC-Head
    output_dim: Int
        Size of the final layer. 1 for regression and binary classification or
        'n_class' for multiclass classification
70 71 72

    ** While I recommend using the Wide and DeepDense classes within this
    package when building the corresponding model components, it is very likely
73 74 75 76 77 78
    that the user will want to use custom text and image models. That is perfectly
    possible. Simply, build them and pass them as the corresponding parameters.
    Note that the custom models MUST return a last layer of activations (i.e. not
    the final prediction) so that  these activations are collected by WideDeep and
    combined accordingly. In  addition, the models MUST also contain an attribute
    'output_dim' with the size of these last layers of activations.
79 80 81

    Attributes
    ----------
82 83 84
    deephead: nn.Sequential
        stack of dense layers comprising the FC-Head (aka imagehead) can be
        custom designed
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117

    ** The remaining attributes that will be set as we compile and run the model are
        discussed within the corresponding methods.

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import Wide, DeepDense, DeepText, DeepImage, WideDeep
    >>>
    >>> X_wide = torch.empty(5, 5).random_(2)
    >>> wide = Wide(wide_dim=X_wide.size(0), output_dim=1)
    >>>
    >>> X_deep = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
    >>> deep_column_idx = {k:v for v,k in enumerate(colnames)}
    >>> deepdense = DeepDense(hidden_layers=[8,4], deep_column_idx=deep_column_idx, embed_input=embed_input)
    >>>
    >>> X_text = torch.cat((torch.zeros([5,1]), torch.empty(5, 4).random_(1,4)), axis=1)
    >>> deeptext = DeepText(vocab_size=4, hidden_dim=4, n_layers=1, padding_idx=0, embed_dim=4)
    >>>
    >>> X_img = torch.rand((5,3,224,224))
    >>> deepimage = DeepImage(head_layers=[512, 64, 8])
    >>>
    >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage, output_dim=1)
    >>> input_dict = {'wide':X_wide, 'deepdense':X_deep, 'deeptext':X_text, 'deepimage':X_img}
    >>> model(X=input_dict)
    tensor([[-0.3779],
            [-0.5247],
            [-0.2773],
            [-0.2888],
            [-0.2010]], grad_fn=<AddBackward0>)
    """
118
    def __init__(self,
        wide:nn.Module,
        deepdense:nn.Module,
        output_dim:int=1,
        deeptext:Optional[nn.Module]=None,
        deepimage:Optional[nn.Module]=None,
        deephead:Optional[nn.Module]=None,
        head_layers:Optional[List]=None,
        head_dropout:Optional[List]=None,
        head_batchnorm:Optional[bool]=None):
        r"""
        Store the model components and wire the deep side to the output.

        Three configurations are possible:

        1. ``deephead`` is passed: it is stored and used as is.
        2. ``deephead`` is None and ``head_layers`` is passed: a FC-Head
           (nn.Sequential of ``dense_layer`` blocks followed by a final
           nn.Linear of size ``output_dim``) is built. Its input size is the
           sum of the ``output_dim`` attributes of the deep components that
           are present.
        3. Neither is passed: every deep component that is present is wrapped
           in an nn.Sequential that appends an nn.Linear mapping its
           ``output_dim`` to ``output_dim``.
        """

        super(WideDeep, self).__init__()

        # The main 5 components of the wide and deep assemble
        self.wide = wide
        self.deepdense = deepdense
        self.deeptext  = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deephead is not None:
            # custom head supplied by the user: nothing else to build
            return

        if head_layers is not None:
            # deeptext and deepimage are optional, so only the components
            # that were actually passed contribute to the head's input size
            deep_components = (self.deepdense, self.deeptext, self.deepimage)
            input_dim = sum(
                component.output_dim
                for component in deep_components
                if component is not None)
            head_layers = [input_dim] + head_layers
            if not head_dropout: head_dropout = [0.] * (len(head_layers)-1)
            self.deephead = nn.Sequential()
            for i in range(1, len(head_layers)):
                self.deephead.add_module(
                    'head_layer_{}'.format(i-1),
                    dense_layer(head_layers[i-1], head_layers[i], head_dropout[i-1], head_batchnorm))
            self.deephead.add_module('head_out', nn.Linear(head_layers[-1], output_dim))
        else:
            # No head: connect each deep component directly to the output
            # neuron(s)
            self.deepdense = nn.Sequential(
                self.deepdense,
                nn.Linear(self.deepdense.output_dim, output_dim))
            if self.deeptext is not None:
                self.deeptext = nn.Sequential(
                    self.deeptext,
                    nn.Linear(self.deeptext.output_dim, output_dim))
            if self.deepimage is not None:
                self.deepimage = nn.Sequential(
                    self.deepimage,
                    nn.Linear(self.deepimage.output_dim, output_dim))
161

162
    def forward(self, X:Dict[str,Tensor])->Tensor:
        r"""
        Combine the output of the wide and deep components.

        Parameters
        ----------
        X: Dict
            Dict where the keys are the model names ('wide', 'deepdense',
            'deeptext' and 'deepimage') and the values are the corresponding
            Tensors. 'deeptext' and 'deepimage' are only read when the
            corresponding component is not None.

        Returns
        -------
        Tensor with the sum of the wide output and the deep output(s).
        """
        # Wide output: direct connection to the output neuron(s)
        out = self.wide(X['wide'])

        # Deep output: either connected directly to the output neuron(s) or
        # passed through a head first. The check is 'is not None' rather than
        # truthiness because an empty nn.Sequential evaluates to False.
        if self.deephead is not None:
            deep_parts = [self.deepdense(X['deepdense'])]
            if self.deeptext is not None:
                deep_parts.append(self.deeptext(X['deeptext']))
            if self.deepimage is not None:
                deep_parts.append(self.deepimage(X['deepimage']))
            deepside_out = self.deephead(torch.cat(deep_parts, dim=1))
            # out-of-place addition: the wide output is never mutated
            return out + deepside_out

        out = out + self.deepdense(X['deepdense'])
        if self.deeptext is not None:
            out = out + self.deeptext(X['deeptext'])
        if self.deepimage is not None:
            out = out + self.deepimage(X['deepimage'])
        return out

    def compile(self,
        method:str,
        optimizers:Optional[Union[Optimizer,Dict[str,Optimizer]]]=None,
        lr_schedulers:Optional[Union[LRScheduler,Dict[str,LRScheduler]]]=None,
        initializers:Optional[Dict[str,Initializer]]=None,
        transforms:Optional[List[Transforms]]=None,
        callbacks:Optional[List[Callback]]=None,
        metrics:Optional[List[Metric]]=None,
        class_weight:Optional[Union[float,List[float],Tuple[float]]]=None,
        with_focal_loss:bool=False,
        alpha:float=0.25,
        gamma:float=2,
        verbose:int=1,
        seed:int=1):
        r"""
        Function to set a number of attributes that are used during the training process.

        Parameters
        ----------
        method: Str
             One of ('regression', 'binary' or 'multiclass')
        optimizers: Optimizer, Dict. Optional, Default=AdamW
            Either an optimizer object (e.g. torch.optim.Adam()) or a
            dictionary where the keys are the model's children (i.e. wide,
            deepdense, deeptext, deepimage and/or deephead) and the values
            are the corresponding optimizers. If a dictionary is used it MUST
            contain an optimizer per child.
        lr_schedulers: LRScheduler, Dict. Optional. Default=None
            Either a LRScheduler object (e.g
            torch.optim.lr_scheduler.StepLR(opt, step_size=5)) or a dictionary
            where the keys are the model's children (i.e. wide, deepdense,
            deeptext, deepimage and/or deephead) and the values are the
            corresponding learning rate schedulers.
        initializers: Dict, Optional. Default=None
            Dict where the keys are the model's children (i.e. wide,
            deepdense, deeptext, deepimage and/or deephead) and the values are
            the corresponding initializers.
        transforms: List, Optional. Default=None
            List with torchvision.transforms to be applied to the image
            component of the model
        callbacks: List, Optional. Default=None
            Callbacks available are: ModelCheckpoint, EarlyStopping, and
            LRHistory. The History callback is used by default. Callback
            classes (as opposed to instances) are instantiated with no
            arguments.
        metrics: List, Optional. Default=None
            Metrics available are: BinaryAccuracy and CategoricalAccuracy
        class_weight: List, Tuple, Float. Optional. Default=None
            Can be one of: float indicating the weight of the minority class
            in binary classification problems (e.g. 9.) or a list or tuple
            with weights for the different classes in multiclass
            classification problems  (e.g. [1., 2., 3.]). The weights do not
            necessarily need to be normalised. If your loss function uses
            reduction='mean', the loss will be normalized by the sum of the
            corresponding weights for each element. If you are using
            reduction='none', you would have to take care of the normalization
            yourself. See here:
            https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10
        with_focal_loss: Boolean, Optional. Default=False
            Whether or not to use the Focal Loss. https://arxiv.org/pdf/1708.02002.pdf
        alpha, gamma: Float. Default=0.25, 2
            Focal Loss parameters. See: https://arxiv.org/pdf/1708.02002.pdf
        verbose: Int
            Setting it to 0 will print nothing during training.
        seed: Int, Default=1
            Random seed to be used throughout all the methods

        Attributes
        ----------
        Attributes that are not direct assignations of parameters

        self.cyclic: Boolean
            Indicates if any of the lr_schedulers is CyclicLR or OneCycleLR

        Example
        --------
        Assuming you have already built the model components (wide, deepdense, etc...)

        >>> from pytorch_widedeep.models import WideDeep
        >>> from pytorch_widedeep.initializers import *
        >>> from pytorch_widedeep.callbacks import *
        >>> from pytorch_widedeep.optim import RAdam
        >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deepdense.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt  = RAdam(model.deepimage.parameters())
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch  = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {'wide': wide_opt, 'deepdense':deep_opt, 'deeptext':text_opt, 'deepimage': img_opt}
        >>> schedulers = {'wide': wide_sch, 'deepdense':deep_sch, 'deeptext':text_sch, 'deepimage': img_sch}
        >>> initializers = {'wide': Uniform, 'deepdense':Normal, 'deeptext':KaimingNormal,
        ... 'deepimage':KaimingUniform}
        >>> transforms = [ToTensor, Normalize(mean=mean, std=std)]
        >>> callbacks = [LRHistory, EarlyStopping, ModelCheckpoint(filepath='model_weights/wd_out.pt')]
        >>> model.compile(method='regression', initializers=initializers, optimizers=optimizers,
        ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """
        self.verbose = verbose
        self.seed = seed
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            self.class_weight = torch.tensor([1.-class_weight, class_weight])
        elif isinstance(class_weight, (list, tuple)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        if optimizers is None:
            self.optimizer = torch.optim.AdamW(self.parameters())
        elif isinstance(optimizers, Optimizer):
            self.optimizer = optimizers
        else:
            # Dictionary of optimizers, whatever its length: every child of
            # the model must have one, otherwise its parameters would never
            # be updated.
            opt_names = list(optimizers.keys())
            mod_names = [n for n, c in self.named_children()]
            for mn in mod_names: assert mn in opt_names, "No optimizer found for {}".format(mn)
            self.optimizer = MultipleOptimizer(optimizers)

        if lr_schedulers is None:
            self.lr_scheduler, self.cyclic = None, False
        elif isinstance(lr_schedulers, LRScheduler):
            self.lr_scheduler = lr_schedulers
            self.cyclic = 'cycl' in self.lr_scheduler.__class__.__name__.lower()
        else:
            # Dictionary of schedulers. A single-entry dictionary is valid too
            # (only one component is scheduled), so there is no length check.
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            scheduler_names = [sc.__class__.__name__.lower() for _,sc in self.lr_scheduler._schedulers.items()]
            self.cyclic = any('cycl' in sn for sn in scheduler_names)

        if transforms is not None:
            self.transforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        # The History callback is always present and always first
        self.history = History()
        self.callbacks = [self.history]
        if callbacks is not None:
            for callback in callbacks:
                # accept both callback classes and callback instances
                if isinstance(callback, type): callback = callback()
                self.callbacks.append(callback)

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

        if use_cuda:
            self.cuda()
            # class_weight is a plain attribute (not a parameter or buffer),
            # so self.cuda() does not move it; keep it on the same device as
            # the model so it can be used by the loss.
            if self.class_weight is not None:
                self.class_weight = self.class_weight.cuda()

355 356 357 358 359
    def fit(self,
        X_wide:Optional[np.ndarray]=None,
        X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None,
        X_train:Optional[Dict[str,np.ndarray]]=None,
        X_val:Optional[Dict[str,np.ndarray]]=None,
        val_split:Optional[float]=None,
        target:Optional[np.ndarray]=None,
        n_epochs:int=1,
        validation_freq:int=1,
        batch_size:int=32,
        patience:int=10,
        warm_up:bool=False,
        warm_epochs:int=4,
        warm_max_lr:float=0.01):
        r"""
        fit method that must run after calling 'compile'

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            One hot encoded wide input.
        X_deep: np.ndarray, Optional. Default=None
            Input for the deepdense model
        X_text: np.ndarray, Optional. Default=None
            Input for the deeptext model
        X_img : np.ndarray, Optional. Default=None
            Input for the deepimage model
        X_train: Dict, Optional. Default=None
            Training dataset for the different model branches.  Keys are
            'X_wide', 'X_deep', 'X_text', 'X_img' and 'target' the values are
            the corresponding matrices e.g X_train = {'X_wide': X_wide,
            'X_deep': X_deep, 'X_text': X_text, 'X_img': X_img,
            'target': target}
        X_val: Dict, Optional. Default=None
            Validation dataset for the different model branches.  Keys are
            'X_wide', 'X_deep', 'X_text', 'X_img' and 'target' the values are
            the corresponding matrices e.g X_val = {'X_wide': X_wide,
            'X_deep': X_deep, 'X_text': X_text, 'X_img': X_img,
            'target': target}
        val_split: Float, Optional. Default=None
            train/val split
        target: np.ndarray, Optional. Default=None
            target values
        n_epochs: Int, Default=1
        validation_freq: Int, Default=1
            Validation runs on the epochs where
            epoch % validation_freq == validation_freq - 1
        batch_size: Int, Default=32
        patience: Int, Default=10
            Number of epochs without improving the target metric before we
            stop the fit. NOTE: this argument is not read by this method.
        warm_up: Boolean, Default=False
            Warm up the models individually
        warm_epochs: Int, Default=4
            Number of warm up epochs
        warm_max_lr: Float, Default=0.01
            Warming up will happen using a slanted triangular learning rates
            (https://arxiv.org/pdf/1801.06146.pdf). warm_max_lr indicates the
            maximum learning rate that will be used during the cycle

        **WideDeep assumes that X_wide, X_deep and target ALWAYS exist, while
        X_text and X_img are optional
        **Either X_train or X_wide, X_deep and target must be passed to the
        fit method

        Raises
        ------
        ValueError
            If neither X_train nor the three arrays X_wide, X_deep and target
            are passed.

        Example
        --------
        Assuming you have already built and compiled the model

        Ex 1. using train input arrays directly and no validation
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256)

        Ex 2: using train input arrays directly and validation with val_split
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)

        Ex 3: using train dict and val_split
        >>> X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': y}
        >>> model.fit(X_train=X_train, n_epochs=10, batch_size=256, val_split=0.2)

        Ex 4: validation using training and validation dicts
        >>> X_train = {'X_wide': X_wide_tr, 'X_deep': X_deep_tr, 'target': y_tr}
        >>> X_val = {'X_wide': X_wide_val, 'X_deep': X_deep_val, 'target': y_val}
        >>> model.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=256)
        """

        if X_train is None and (X_wide is None or X_deep is None or target is None):
            raise ValueError(
                "Training data is missing. Either a dictionary (X_train) with "
                "the training dataset or at least 3 arrays (X_wide, X_deep, "
                "target) must be passed to the fit method")

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(X_wide, X_deep, X_text, X_img,
            X_train, X_val, val_split, target)
        train_loader = DataLoader(dataset=train_set, batch_size=batch_size, num_workers=n_cpus)
        # len(loader) is the exact number of batches; 'len(dataset)//batch_size + 1'
        # overcounts by one when the dataset size is a multiple of batch_size
        train_steps = len(train_loader)

        # The validation loader does not change between epochs: build it once
        eval_loader, eval_steps = None, 0
        if eval_set is not None:
            eval_loader = DataLoader(dataset=eval_set, batch_size=batch_size, num_workers=n_cpus,
                shuffle=False)
            eval_steps = len(eval_loader)

        if warm_up: self._warm_up(train_loader, warm_epochs, warm_max_lr)
        self.callback_container.on_train_begin({'batch_size': batch_size,
            'train_steps': train_steps, 'n_epochs': n_epochs})

        if self.verbose: print('Training')
        last_epoch = None
        for epoch in range(n_epochs):
            last_epoch = epoch
            # train step...
            epoch_logs={}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
            self.train_running_loss = 0.
            with trange(train_steps, disable=self.verbose != 1) as t:
                for batch_idx, (data,batch_target) in zip(t, train_loader):
                    t.set_description('epoch %i' % (epoch+1))
                    acc, train_loss = self._training_step(data, batch_target, batch_idx)
                    if acc is not None:
                        t.set_postfix(metrics=acc, loss=train_loss)
                    else:
                        t.set_postfix(loss=np.sqrt(train_loss))
                    if self.lr_scheduler: self._lr_scheduler_step(step_location='on_batch_end')
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs['train_loss'] = train_loss
            if acc is not None: epoch_logs['train_acc'] = acc['acc']
            # eval step...
            if epoch % validation_freq  == (validation_freq - 1):
                if eval_loader is not None:
                    self.valid_running_loss = 0.
                    with trange(eval_steps, disable=self.verbose != 1) as v:
                        for i, (data,batch_target) in zip(v, eval_loader):
                            v.set_description('valid')
                            acc, val_loss = self._validation_step(data, batch_target, i)
                            if acc is not None:
                                v.set_postfix(metrics=acc, loss=val_loss)
                            else:
                                v.set_postfix(loss=np.sqrt(val_loss))
                    epoch_logs['val_loss'] = val_loss
                    if acc is not None: epoch_logs['val_acc'] = acc['acc']
            if self.lr_scheduler: self._lr_scheduler_step(step_location='on_epoch_end')
            # log and check if early_stop...
            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                break
        # on_train_end fires exactly once, after the last epoch that ran
        # (whether training finished or stopped early)
        if last_epoch is not None:
            self.callback_container.on_train_end(last_epoch)
        self.train()

    def predict(self, X_wide:Optional[np.ndarray]=None, X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
        X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
        r"""
        Predict the target for a test dataset. Must run after calling
        'compile' since it relies on the 'method' set there.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            One hot encoded wide input.
        X_deep: np.ndarray, Optional. Default=None
            Input for the deepdense model
        X_text: np.ndarray, Optional. Default=None
            Input for the deeptext model
        X_img : np.ndarray, Optional. Default=None
            Input for the deepimage model
        X_test: Dict, Optional. Default=None
            Testing dataset for the different model branches.  Keys are
            'X_wide', 'X_deep', 'X_text' and 'X_img' the values are the
            corresponding matrices e.g X_test = {'X_wide': X_wide,
            'X_deep': X_deep, 'X_text': X_text, 'X_img': X_img}

        **WideDeep assumes that X_wide and X_deep ALWAYS exist (passed either
        directly or within X_test), while X_text and X_img are optional

        Returns
        -------
        preds: np.array with the predicted target for the test dataset:
            the raw predictions for 'regression', the 0/1 class obtained
            by thresholding at 0.5 for 'binary' and the index of the highest
            scoring class for 'multiclass'. For any other 'method' nothing
            is returned (None).
        """
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype('int')
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)
534

535 536 537 538 539
    def predict_proba(self, X_wide:Optional[np.ndarray]=None, X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
        X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
        r"""
        Predict class probabilities for a test dataset.

        Parameters
        ----------
        X_wide: np.ndarray, Optional. Default=None
            One hot encoded wide input.
        X_deep: np.ndarray, Optional. Default=None
            Input for the deepdense model
        X_text: np.ndarray, Optional. Default=None
            Input for the deeptext model
        X_img : np.ndarray, Optional. Default=None
            Input for the deepimage model
        X_test: Dict, Optional. Default=None
            Testing dataset for the different model branches, as an
            alternative to passing the arrays individually.

        Returns
        -------
        preds: np.ndarray
            Predicted probabilities of target for the test dataset for  binary
            and multiclass methods. For 'binary' the array has two columns:
            column 0 holds 1-p and column 1 holds p. For any other 'method'
            (e.g. 'regression') nothing is returned (None).
        """
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            probs = np.zeros([preds.shape[0],2])
            probs[:,0] = 1-preds
            probs[:,1] = preds
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)
553

554
    def get_embeddings(self, col_name:str,
        cat_encoding_dict:Dict[str,Dict[str,int]]) -> Dict[str,np.ndarray]:
        """
        Get the learned embeddings for the categorical features passed through deepdense.

        Parameters
        ----------
        col_name: str,
            Column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict
            Categorical encodings. The function is designed to take the
            'encoding_dict' attribute from the DeepPreprocessor class. Any
            Dict with the same structure can be used

        Returns
        -------
        cat_embed_dict: Dict
            Categorical levels of the col_name feature and the corresponding
            embeddings

        Raises
        ------
        KeyError
            If col_name is not a key of cat_encoding_dict
        ValueError
            If the model has no parameter whose name contains both
            'embed_layers' and col_name

        Example:
        -------
        Assuming we have already train the model:

        >>> model.get_embeddings(col_name='education', cat_encoding_dict=deep_preprocessor.encoding_dict)
        {'11th': array([-0.42739448, -0.22282735,  0.36969638,  0.4445322 ,  0.2562272 ,
        0.11572784, -0.01648579,  0.09027119,  0.0457597 , -0.28337458], dtype=float32),
         'HS-grad': array([-0.10600474, -0.48775527,  0.3444158 ,  0.13818645, -0.16547225,
        0.27409762, -0.05006042, -0.0668492 , -0.11047247,  0.3280354 ], dtype=float32),
        ...
        }

        where:

        >>> deep_preprocessor.encoding_dict['education']
        {'11th': 0, 'HS-grad': 1, 'Assoc-acdm': 2, 'Some-college': 3, '10th': 4, 'Prof-school': 5,
        '7th-8th': 6, 'Bachelors': 7, 'Masters': 8, 'Doctorate': 9, '5th-6th': 10, 'Assoc-voc': 11,
        '9th': 12, '12th': 13, '1st-4th': 14, 'Preschool': 15}
        """
        # Parameters are matched by name; if several match, the last one is
        # used. Only that one is copied to numpy.
        embed_param = None
        for n,p in self.named_parameters():
            if 'embed_layers' in n and col_name in n:
                embed_param = p
        encoding_dict = cat_encoding_dict[col_name]
        if embed_param is None:
            raise ValueError(
                "No embedding layer found for column '{}'".format(col_name))
        embed_mtx = embed_param.detach().cpu().numpy()
        # each level's integer encoding is its row in the embedding matrix
        return {level: embed_mtx[idx] for level,idx in encoding_dict.items()}

    def _activation_fn(self, inp:Tensor) -> Tensor:
        if self.method == 'regression':
            return inp
        if self.method == 'binary':
            return torch.sigmoid(inp)
        if self.method == 'multiclass':
            return F.softmax(inp, dim=1)

    def _loss_fn(self, y_pred:Tensor, y_true:Tensor) -> Tensor:
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
        if self.method == 'regression':
            return F.mse_loss(y_pred, y_true.view(-1, 1))
        if self.method == 'binary':
            return F.binary_cross_entropy(y_pred, y_true.view(-1, 1), weight=self.class_weight)
        if self.method == 'multiclass':
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

    def _train_val_split(self,
        X_wide:Optional[np.ndarray]=None,
        X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None,
        X_train:Optional[Dict[str,np.ndarray]]=None,
        X_val:Optional[Dict[str,np.ndarray]]=None,
        val_split:Optional[float]=None,
        target:Optional[np.ndarray]=None):
        r"""
        Build the training and, optionally, the validation WideDeepDataset
        from the inputs. A number of options are allowed in terms of data
        inputs. For parameter information, please, see the .fit() method
        documentation.

        Three situations are handled:

        1) Neither X_val nor val_split are passed: all the data is used for
           training and the returned eval_set is None.
        2) X_val is passed: it is used as it is for validation. The training
           set is X_train if passed, otherwise it is built from the input
           arrays.
        3) Only val_split is passed: the training data (taken from X_train if
           passed, otherwise from the input arrays) is split with
           sklearn's train_test_split using val_split as test_size and
           self.seed as random_state.

        When X_train is passed, its 'X_text' and 'X_img' entries, if present,
        take precedence over the X_text and X_img arguments.

        Returns
        -------
        train_set: WideDeepDataset
            WideDeepDataset object that will be loaded through
            torch.utils.data.DataLoader
        eval_set : WideDeepDataset
            WideDeepDataset object that will be loaded through
            torch.utils.data.DataLoader. None if no validation is requested
        """
        # Without validation
        if X_val is None and val_split is None:
            # if a train dictionary is passed, check if text and image datasets
            # are present and instantiate the WideDeepDataset class
            if X_train is not None:
                X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                if 'X_text' in X_train: X_text = X_train['X_text']
                if 'X_img' in X_train: X_img = X_train['X_img']
            # In this branch X_text and X_img are always forwarded to the
            # dataset, also when they are None
            X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': target,
                'X_text': X_text, 'X_img': X_img}
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)
            eval_set = None
        # With validation
        else:
            if X_val is not None:
                # if a validation dictionary is passed, then if not train
                # dictionary is passed we build it with the input arrays
                # (either the dictionary or the arrays must be passed)
                if X_train is None:
                    X_train = {'X_wide':X_wide, 'X_deep': X_deep, 'target': target}
                    if X_text is not None: X_train.update({'X_text': X_text})
                    if X_img is not None:  X_train.update({'X_img': X_img})
            else:
                # if a train dictionary is passed, check if text and image
                # datasets are present. The train/val split using val_split
                if X_train is not None:
                    X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                    if 'X_text' in X_train: X_text = X_train['X_text']
                    if 'X_img' in X_train: X_img = X_train['X_img']
                X_tr_wide, X_val_wide, X_tr_deep, X_val_deep, y_tr, y_val = train_test_split(X_wide,
                    X_deep, target, test_size=val_split, random_state=self.seed)
                X_train = {'X_wide':X_tr_wide, 'X_deep': X_tr_deep, 'target': y_tr}
                X_val = {'X_wide':X_val_wide, 'X_deep': X_val_deep, 'target': y_val}
                # Text and image data are optional: split them only when they
                # are present, so that any error raised while splitting real
                # data is not silently discarded. The same test_size and
                # random_state are used as for the split above.
                if X_text is not None:
                    X_tr_text, X_val_text = train_test_split(X_text, test_size=val_split,
                        random_state=self.seed)
                    X_train['X_text'] = X_tr_text
                    X_val['X_text'] = X_val_text
                if X_img is not None:
                    X_tr_img, X_val_img = train_test_split(X_img, test_size=val_split,
                        random_state=self.seed)
                    X_train['X_img'] = X_tr_img
                    X_val['X_img'] = X_val_img
            # At this point the X_train and X_val dictionaries have been built
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)
            eval_set = WideDeepDataset(**X_val, transforms=self.transforms)
        return train_set, eval_set

    def _warm_model(self, model:WDModel, model_name:str, loader:DataLoader, n_epochs:int,
        max_lr:float):
        r"""
        Warm up one of the models that comprise WideDeep on its own. A
        triangular learning rate schedule is used, with one single cycle over
        all the warm up steps (n_epochs * len(loader)). The cycle will go
        from max_lr/10. to max_lr during (roughly) the first 10% of the steps
        and back to max_lr/10. during the remaining ones.

        Parameters
        ----------
        model: WDModel
            Model component to warm up
        model_name: str
            Key used to select the inputs of this component from the data
            dictionary yielded by the loader
        loader: DataLoader
            DataLoader yielding (data, target) tuples
        n_epochs: int
            Number of warm up epochs
        max_lr: float
            Maximum learning rate of the cycle
        """
        if self.verbose: print('Warming up {} for {} epochs'.format(model_name, n_epochs))

        model.train()

        optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr/10.)
        steps = len(loader)
        total_steps = steps*n_epochs
        # With very few steps round(total_steps * 0.1) is 0, and CyclicLR
        # divides by the fraction of increasing steps. Always go up for at
        # least one step so that short warm ups do not fail.
        step_size_up = max(1, round(total_steps * 0.1))
        step_size_down = max(0, total_steps - step_size_up)
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=max_lr/10.,
            max_lr=max_lr, step_size_up=step_size_up, step_size_down=step_size_down,
            cycle_momentum=False)

        for epoch in range(n_epochs):
            running_loss=0.
            with trange(steps, disable=self.verbose != 1) as t:
                for batch_idx, (data, target) in zip(t, loader):
                    t.set_description('epoch %i' % (epoch+1))
                    X = data[model_name].cuda() if use_cuda else data[model_name]
                    y = target.float() if self.method != 'multiclass' else target
                    y = y.cuda() if use_cuda else y

                    optimizer.zero_grad()
                    y_pred = self._activation_fn(model(X))
                    loss   = self._loss_fn(y_pred, y)
                    loss.backward()
                    optimizer.step()
                    # the schedule is defined in steps, so it advances per batch
                    scheduler.step()

                    running_loss += loss.item()
                    avg_loss = running_loss/(batch_idx+1)

                    if self.metric is not None:
                        acc = self.metric(y_pred, y)
                        t.set_postfix(metrics=acc, loss=avg_loss)
                    else:
                        # without a metric the square root of the running
                        # loss is displayed
                        # NOTE(review): presumably meant as RMSE for
                        # regression -- confirm
                        t.set_postfix(loss=np.sqrt(avg_loss))

    def _warm_up(self, loader:DataLoader, n_epochs:int, max_lr:float):

        if self.deephead is not None:
            raise ValueError(
                "Currently warming up is only supported without a fully connected 'DeepHead'")

        self._warm_model(self.wide, 'wide', loader, n_epochs, max_lr)
        self._warm_model(self.deepdense, 'deepdense', loader, n_epochs, max_lr)
        if self.deeptext is not None:
            self._warm_model(self.deeptext, 'deeptext', loader, n_epochs, max_lr)
        if self.deepimage is not None:
            self._warm_model(self.deepimage, 'deepimage', loader, n_epochs, max_lr)

    def _lr_scheduler_step(self, step_location:str):
        r"""
        Function to execute the learning rate schedulers steps.
        If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step
        must  happen after training each bach durig training. On the other
        hand, if the  scheduler is not Cyclic, is expected to be called after
        validation.

        Parameters
        ----------
        step_location: Str
            Indicates where to run the lr_scheduler step
        """
        if self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler' and self.cyclic:
            if step_location == 'on_batch_end':
                for model_name, scheduler in self.lr_scheduler._schedulers.items():
                    if 'cycl' in scheduler.__class__.__name__.lower(): scheduler.step()
            elif step_location == 'on_epoch_end':
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():
                    if 'cycl' not in scheduler.__class__.__name__.lower(): scheduler.step()
        elif self.cyclic:
            if step_location == 'on_batch_end': self.lr_scheduler.step()
            else: pass
        elif self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler':
            if step_location == 'on_epoch_end': self.lr_scheduler.step()
            else: pass
        elif step_location == 'on_epoch_end': self.lr_scheduler.step()
        else: pass

    def _training_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):
        r"""
        Run one optimisation step on a single batch: forward pass, loss,
        backward pass and optimizer step. The loss of the batch is accumulated
        in ``self.train_running_loss``.

        Parameters
        ----------
        data: Dict
            Dictionary of input tensors that is passed to ``self.forward``
        target: Tensor
            Target values of the batch. Cast to float unless the method is
            'multiclass'
        batch_idx: int
            Zero-based index of the batch, used to average the running loss

        Returns
        -------
        acc, avg_loss
            Value returned by ``self.metric`` (None if there is no metric) and
            the running loss divided by the number of batches seen so far
        """
        self.train()

        labels = target if self.method == 'multiclass' else target.float()
        if use_cuda:
            inputs = {name: tensor.cuda() for name, tensor in data.items()}
            labels = labels.cuda()
        else:
            inputs = data

        self.optimizer.zero_grad()
        raw_output = self.forward(inputs)
        y_pred = self._activation_fn(raw_output)
        loss = self._loss_fn(y_pred, labels)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss/(batch_idx+1)

        acc = None
        if self.metric is not None:
            acc = self.metric(y_pred, labels)
        return acc, avg_loss

    def _validation_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):
        r"""
        Evaluate a single batch in eval mode and with gradients disabled. The
        loss of the batch is accumulated in ``self.valid_running_loss``.

        Parameters
        ----------
        data: Dict
            Dictionary of input tensors that is passed to ``self.forward``
        target: Tensor
            Target values of the batch. Cast to float unless the method is
            'multiclass'
        batch_idx: int
            Zero-based index of the batch, used to average the running loss

        Returns
        -------
        acc, avg_loss
            Value returned by ``self.metric`` (None if there is no metric) and
            the running loss divided by the number of batches seen so far
        """
        self.eval()
        with torch.no_grad():
            labels = target if self.method == 'multiclass' else target.float()
            if use_cuda:
                inputs = {name: tensor.cuda() for name, tensor in data.items()}
                labels = labels.cuda()
            else:
                inputs = data

            raw_output = self.forward(inputs)
            y_pred = self._activation_fn(raw_output)
            loss = self._loss_fn(y_pred, labels)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss/(batch_idx+1)

        # the metric is computed once the no_grad context has been left
        acc = None
        if self.metric is not None:
            acc = self.metric(y_pred, labels)
        return acc, avg_loss

    def _predict(self, X_wide:np.ndarray, X_deep:np.ndarray, X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None, X_test:Optional[Dict[str, np.ndarray]]=None)->List:
        r"""
        Hidden method to avoid code repetition in predict and predict_proba.
        For parameter information, please, see the .predict() method
        documentation

        If X_test is passed, the dataset is built from that dictionary and
        the remaining input arrays are ignored. Otherwise it is built from
        X_wide and X_deep plus X_text and X_img when they are not None.

        Note that the model is set back to train mode before returning.

        Returns
        -------
        preds_l: List
            List with one numpy array per batch, containing the output of
            ``self._activation_fn`` for that batch
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {'X_wide': X_wide, 'X_deep': X_deep}
            if X_text is not None: load_dict.update({'X_text': X_text})
            if X_img is not None:  load_dict.update({'X_img': X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size, num_workers=n_cpus,
            shuffle=False)
        # len(test_loader) is the exact number of batches. Computing it as
        # len(dataset) // batch_size + 1 counts one batch too many when the
        # dataset size is a multiple of the batch size, leaving the progress
        # bar with a wrong total.
        test_steps = len(test_loader)

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description('predict')
                    X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
                    # move to cpu before converting to numpy
                    preds = self._activation_fn(self.forward(X)).cpu().data.numpy()
                    preds_l.append(preds)
        self.train()
        return preds_l