wide_deep.py
import numpy as np
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..wdtypes import *

from ..initializers import Initializer, MultipleInitializer
from ..callbacks import Callback, History, CallbackContainer
from ..metrics import Metric, MultipleMetrics, MetricCallback
from ..losses import FocalLoss

from ._wd_dataset import WideDeepDataset
from ._multiple_optimizer import MultipleOptimizer
from ._multiple_lr_scheduler import MultipleLRScheduler
from ._multiple_transforms import MultipleTransforms

from .deep_dense import dense_layer

from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader


use_cuda = torch.cuda.is_available()


class WideDeep(nn.Module):
    r""" Main collector class to combine all Wide, DeepDense, DeepText and
    DeepImage models. There are two options to combine these models.
    1) Directly connecting the output of the models to an ouput neuron(s).
    2) Adding a FC-Head on top of the deep models. This FC-Head will combine
    the output form the DeepDense, DeepText and DeepImage and will be then
    connected to the output neuron(s)

    Parameters
    ----------
    wide: wide model. I recommend using the Wide class in this package.
        However, a custom model can be used as long as it is consistent with
        the required architecture.
    deepdense: deep dense model consisting of a series of categorical features
        represented by embeddings combined with numerical (aka continuous)
        features. I recommend using the DeepDense class in this package.
        However, a custom model can be used as long as it is consistent with
        the required architecture.
    deeptext: optional model for the text input. Must be an object of class
        DeepText or a custom model, as long as it is consistent with the
        required architecture.
    deepimage: optional model for the images input. Must be an object of class
        DeepImage or a custom model, as long as it is consistent with the
        required architecture.
    deephead: optional dense model consisting of a stack of dense layers.
    head_layers: optional list with the sizes of the stacked dense layers in
        the fc-head e.g: [128, 64]
    head_dropout: optional list with the dropout between the dense layers.
        e.g: [0.5, 0.5]
    head_batchnorm: optional boolean indicating whether or not to include batch
        normalization in the dense layers that form the fc-head
    output_dim: int size of the final layer. 1 for regression and binary
        classification or 'n_class' for multiclass classification

    ** While I recommend using the Wide and DeepDense classes within this
    package when building the corresponding model components, it is very likely
    that the user will want to use custom text and image models. Simply build
    them and pass them as the corresponding parameters. Note that the custom
    models must return a last layer of activations (i.e. not the final
    prediction) so that these activations can be collected by WideDeep and
    combined accordingly. In addition, the classes/models must also contain an
    attribute 'output_dim' with the size of this last layer of activations.
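
    For example, a minimal sketch of a compliant custom text component (the
    name and architecture here are illustrative, not part of the package):

    >>> class MyDeepText(nn.Module):
    ...     def __init__(self, vocab_size, embed_dim=8):
    ...         super(MyDeepText, self).__init__()
    ...         self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
    ...         self.output_dim = embed_dim  # required attribute
    ...     def forward(self, X):
    ...         # return the last layer of activations, not predictions
    ...         return self.embed(X.long()).mean(dim=1)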

    Attributes
    ----------
    deephead: Sequential stack of dense layers comprising the FC-Head, which
        can be custom designed

    ** The remaining attributes that will be set as we compile and run the model are
        discussed within the corresponding methods.

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import Wide, DeepDense, DeepText, DeepImage, WideDeep
    >>>
    >>> X_wide = torch.empty(5, 5).random_(2)
    >>> wide = Wide(wide_dim=X_wide.size(0), output_dim=1)
    >>>
    >>> X_deep = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
    >>> deep_column_idx = {k:v for v,k in enumerate(colnames)}
    >>> deepdense = DeepDense(hidden_layers=[8,4], deep_column_idx=deep_column_idx, embed_input=embed_input)
    >>>
    >>> X_text = torch.cat((torch.zeros([5,1]), torch.empty(5, 4).random_(1,4)), axis=1)
    >>> deeptext = DeepText(vocab_size=4, hidden_dim=4, n_layers=1, padding_idx=0, embed_dim=4)
    >>>
    >>> X_img = torch.rand((5,3,224,224))
    >>> deepimage = DeepImage(head_layers=[512, 64, 8])
    >>>
    >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage, output_dim=1)
    >>> input_dict = {'wide':X_wide, 'deepdense':X_deep, 'deeptext':X_text, 'deepimage':X_img}
    >>> model(X=input_dict)
    tensor([[-0.3779],
            [-0.5247],
            [-0.2773],
            [-0.2888],
            [-0.2010]], grad_fn=<AddBackward0>)
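
    A variant using option 2), i.e. adding a FC-Head on top of the deep
    components (the layer sizes here are illustrative):

    >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext,
    ...     deepimage=deepimage, head_layers=[32, 16], output_dim=1)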
    """
    def __init__(self,
        wide:nn.Module,
        deepdense:nn.Module,
        output_dim:int=1,
        deeptext:Optional[nn.Module]=None,
        deepimage:Optional[nn.Module]=None,
        deephead:Optional[nn.Module]=None,
        head_layers:Optional[List]=None,
        head_dropout:Optional[List]=None,
        head_batchnorm:Optional[bool]=None):

        super(WideDeep, self).__init__()

        # The main 5 components of the wide and deep assembly
        self.wide = wide
        self.deepdense = deepdense
        self.deeptext  = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deephead is None:
            if head_layers is not None:
                # the head receives the activations of whichever deep
                # components are present
                input_dim = self.deepdense.output_dim
                if self.deeptext is not None: input_dim += self.deeptext.output_dim
                if self.deepimage is not None: input_dim += self.deepimage.output_dim
                head_layers = [input_dim] + head_layers
                if not head_dropout: head_dropout = [0.] * (len(head_layers)-1)
                self.deephead = nn.Sequential()
                for i in range(1, len(head_layers)):
                    self.deephead.add_module(
                        'head_layer_{}'.format(i-1),
                        dense_layer(head_layers[i-1], head_layers[i], head_dropout[i-1], head_batchnorm))
                self.deephead.add_module('head_out', nn.Linear(head_layers[-1], output_dim))
            else:
                self.deepdense = nn.Sequential(
                    self.deepdense,
                    nn.Linear(self.deepdense.output_dim, output_dim))
                if self.deeptext is not None:
                    self.deeptext = nn.Sequential(
                        self.deeptext,
                        nn.Linear(self.deeptext.output_dim, output_dim))
                if self.deepimage is not None:
                    self.deepimage = nn.Sequential(
                        self.deepimage,
                        nn.Linear(self.deepimage.output_dim, output_dim))

    def forward(self, X:Dict[str,Tensor])->Tensor:
        r"""
        Parameters
        ----------
        X: Dict where the keys are the model names (wide, deepdense, deeptext
            and deepimage) and the values are the corresponding Tensors
        """
        # Wide output: direct connection to the output neuron(s)
        out = self.wide(X['wide'])

        # Deep output: either connected directly to the output neuron(s) or
        # passed through a head first
        if self.deephead:
            deepside = self.deepdense(X['deepdense'])
            if self.deeptext is not None:
                deepside = torch.cat([deepside, self.deeptext(X['deeptext'])], dim=1)
            if self.deepimage is not None:
                deepside = torch.cat([deepside, self.deepimage(X['deepimage'])], dim=1)
            deepside_out = self.deephead(deepside)
            return out.add_(deepside_out)
        else:
            out.add_(self.deepdense(X['deepdense']))
            if self.deeptext is not None:
                out.add_(self.deeptext(X['deeptext']))
            if self.deepimage is not None:
                out.add_(self.deepimage(X['deepimage']))
            return out

    def compile(self,
        method:str,
        optimizers:Union[Optimizer,Dict[str,Optimizer]],
        lr_schedulers:Optional[Union[LRScheduler,Dict[str,LRScheduler]]]=None,
        initializers:Optional[Dict[str,Initializer]]=None,
        transforms:Optional[List[Transforms]]=None,
        callbacks:Optional[List[Callback]]=None,
        metrics:Optional[List[Metric]]=None,
        class_weight:Optional[Union[float,List[float],Tuple[float]]]=None,
        with_focal_loss:bool=False,
        alpha:float=0.25,
        gamma:float=1,
        verbose:int=1):
        r"""
        Function to set a number of attributes that are used during the training process.

        Parameters
        ----------
        method: required parameter. One of ('regression', 'binary' or 'multiclass')
        optimizers: either an Optimizer object (e.g. torch.optim.Adam()) or a
            dictionary where the keys are the model's children (i.e. wide,
            deepdense, deeptext, deepimage and/or deephead) and the values
            are the corresponding optimizers. If multiple optimizers are used
            the dictionary **must** contain an optimizer per child.
            Defaults to Adam.
        lr_schedulers: optional parameter with a LRScheduler object (e.g.
            torch.optim.lr_scheduler.StepLR(opt, step_size=5)) or a dictionary
            where the keys are the model's children (i.e. wide, deepdense,
            deeptext, deepimage and/or deephead) and the values are the
            corresponding learning rate schedulers.
        initializers: optional dictionary where the keys are the model's
            children (i.e. wide, deepdense, deeptext, deepimage and/or deephead)
            and the values are the corresponding initializers.
        transforms: optional List with torchvision.transforms to be applied to
            the image component of the model
        callbacks: optional List with callbacks. Callbacks available are:
            ModelCheckpoint, EarlyStopping, and LRHistory. The History callback is
            used by default.
        metrics: optional List of metrics. Metrics available are: BinaryAccuracy
            and CategoricalAccuracy
        class_weight: optional parameter that can be one of: a float
            indicating the weight of the minority class in binary classification
            problems (e.g. 9.) or a list or tuple with weights for the different
            classes in multiclass classification problems (e.g. [1., 2., 3.]).
            The weights do not necessarily need to be normalized. If your loss
            function uses reduction='mean', the loss will be normalized by the sum
            of the corresponding weights for each element. If you are using
            reduction='none', you would have to take care of the normalization
            yourself. See here:
            https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10
        with_focal_loss: optional boolean indicating whether or not to use the
            Focal Loss. https://arxiv.org/pdf/1708.02002.pdf
        alpha, gamma: Focal Loss parameters. See:
            https://arxiv.org/pdf/1708.02002.pdf
        verbose: int indicating the level of verbosity. Setting it to 0 will
            print nothing during training.

        Attributes
        ----------
        Attributes that are not direct assignations of parameters

        self.cyclic: boolean indicating if any of the lr_schedulers is
            CyclicLR or OneCycleLR

        Example
        --------
        Assuming you have already built the model components (wide, deepdense, etc...)

        >>> from pytorch_widedeep.models import WideDeep
        >>> from pytorch_widedeep.initializers import *
        >>> from pytorch_widedeep.callbacks import *
        >>> from pytorch_widedeep.optim import RAdam
        >>> model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)
        >>> wide_opt = torch.optim.Adam(model.wide.parameters())
        >>> deep_opt = torch.optim.Adam(model.deepdense.parameters())
        >>> text_opt = RAdam(model.deeptext.parameters())
        >>> img_opt  = RAdam(model.deepimage.parameters())
        >>> wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
        >>> text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)
        >>> img_sch  = torch.optim.lr_scheduler.StepLR(img_opt, step_size=3)
        >>> optimizers = {'wide': wide_opt, 'deepdense':deep_opt, 'deeptext':text_opt, 'deepimage': img_opt}
        >>> schedulers = {'wide': wide_sch, 'deepdense':deep_sch, 'deeptext':text_sch, 'deepimage': img_sch}
        >>> initializers = {'wide': Uniform, 'deepdense':Normal, 'deeptext':KaimingNormal,
        >>> ... 'deepimage':KaimingUniform}
        >>> transforms = [ToTensor, Normalize(mean=mean, std=std)]
        >>> callbacks = [LRHistory, EarlyStopping, ModelCheckpoint(filepath='model_weights/wd_out.pt')]
        >>> model.compile(method='regression', initializers=initializers, optimizers=optimizers,
        >>> ... lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)
        """
        self.verbose = verbose
        self.early_stop = False
        self.method = method
        self.with_focal_loss = with_focal_loss
        if self.with_focal_loss:
            self.alpha, self.gamma = alpha, gamma

        if isinstance(class_weight, float):
            self.class_weight = torch.tensor([1.-class_weight, class_weight])
        elif isinstance(class_weight, (list, tuple)):
            self.class_weight = torch.tensor(class_weight)
        else:
            self.class_weight = None

        if initializers is not None:
            self.initializer = MultipleInitializer(initializers, verbose=self.verbose)
            self.initializer.apply(self)

        if isinstance(optimizers, Optimizer):
            self.optimizer = optimizers
        elif isinstance(optimizers, dict):
            opt_names = list(optimizers.keys())
            mod_names = [n for n, c in self.named_children()]
            for mn in mod_names: assert mn in opt_names, "No optimizer found for {}".format(mn)
            self.optimizer = MultipleOptimizer(optimizers)
        else:
            self.optimizer = torch.optim.Adam(self.parameters())

        if isinstance(lr_schedulers, LRScheduler):
            self.lr_scheduler = lr_schedulers
            self.cyclic = 'cycl' in self.lr_scheduler.__class__.__name__.lower()
        elif isinstance(lr_schedulers, dict):
            self.lr_scheduler = MultipleLRScheduler(lr_schedulers)
            scheduler_names = [sc.__class__.__name__.lower() for _, sc in self.lr_scheduler._schedulers.items()]
            self.cyclic = any(['cycl' in sn for sn in scheduler_names])
        else:
            self.lr_scheduler, self.cyclic = None, False

        if transforms is not None:
            self.transforms = MultipleTransforms(transforms)()
        else:
            self.transforms = None

        self.history = History()
        self.callbacks = [self.history]
        if callbacks is not None:
            for callback in callbacks:
                # callbacks can be passed as classes or instances
                if isinstance(callback, type): callback = callback()
                self.callbacks.append(callback)

        if metrics is not None:
            self.metric = MultipleMetrics(metrics)
            self.callbacks += [MetricCallback(self.metric)]
        else:
            self.metric = None

        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self)

    def _activation_fn(self, inp:Tensor) -> Tensor:
        if self.method == 'regression':
            return inp
        if self.method == 'binary':
            return torch.sigmoid(inp)
        if self.method == 'multiclass':
            return F.softmax(inp, dim=1)

    def _loss_fn(self, y_pred:Tensor, y_true:Tensor) -> Tensor:
        if self.with_focal_loss:
            return FocalLoss(self.alpha, self.gamma)(y_pred, y_true)
        if self.method == 'regression':
            return F.mse_loss(y_pred, y_true.view(-1, 1))
        if self.method == 'binary':
            # map the per-class weights to per-sample weights, which is what
            # F.binary_cross_entropy expects
            weight = self.class_weight[y_true.long()].view(-1, 1) if self.class_weight is not None else None
            return F.binary_cross_entropy(y_pred, y_true.view(-1, 1), weight=weight)
        if self.method == 'multiclass':
            return F.cross_entropy(y_pred, y_true, weight=self.class_weight)

    def _training_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):
        self.train()
        X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
        y = target.float() if self.method != 'multiclass' else target
        y = y.cuda() if use_cuda else y

        self.optimizer.zero_grad()
        y_pred = self._activation_fn(self.forward(X))
        loss = self._loss_fn(y_pred, y)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss/(batch_idx+1)

        if self.metric is not None:
            acc = self.metric(y_pred, y)
            return acc, avg_loss
        else:
            return None, avg_loss

    def _validation_step(self, data:Dict[str, Tensor], target:Tensor, batch_idx:int):

        self.eval()
        with torch.no_grad():
            X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
            y = target.float() if self.method != 'multiclass' else target
            y = y.cuda() if use_cuda else y

            y_pred = self._activation_fn(self.forward(X))
            loss = self._loss_fn(y_pred, y)
            self.valid_running_loss += loss.item()
            avg_loss = self.valid_running_loss/(batch_idx+1)

        if self.metric is not None:
            acc = self.metric(y_pred, y)
            return acc, avg_loss
        else:
            return None, avg_loss

    def _lr_scheduler_step(self, step_location:str):
        r"""
        Function to execute the learning rate scheduler steps.
        If the lr_scheduler is cyclic (i.e. CyclicLR or OneCycleLR), the step
        must happen after training each batch. On the other hand, if the
        scheduler is not cyclic, it is expected to be called after validation.

        Parameters
        ----------
        step_location: string indicating where to run the lr_scheduler step
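
        Example
        -------
        A sketch of a mixed setup (the optimizers here are illustrative). With

        >>> wide_sch = torch.optim.lr_scheduler.CyclicLR(wide_opt, base_lr=0.001, max_lr=0.01)
        >>> deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)

        wide_sch would step 'on_batch_end' and deep_sch 'on_epoch_end'.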
        """
        if self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler' and self.cyclic:
            if step_location == 'on_batch_end':
                for model_name, scheduler in self.lr_scheduler._schedulers.items():
                    if 'cycl' in scheduler.__class__.__name__.lower(): scheduler.step()
            elif step_location == 'on_epoch_end':
                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():
                    if 'cycl' not in scheduler.__class__.__name__.lower(): scheduler.step()
        elif self.cyclic:
            if step_location == 'on_batch_end': self.lr_scheduler.step()
            else: pass
        elif self.lr_scheduler.__class__.__name__ == 'MultipleLRScheduler':
            if step_location == 'on_epoch_end': self.lr_scheduler.step()
            else: pass
        elif step_location == 'on_epoch_end': self.lr_scheduler.step()
        else: pass

    def _train_val_split(self,
        X_wide:Optional[np.ndarray]=None,
        X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None,
        X_train:Optional[Dict[str,np.ndarray]]=None,
        X_val:Optional[Dict[str,np.ndarray]]=None,
        val_split:Optional[float]=None,
        target:Optional[np.ndarray]=None,
        seed:int=1):
        r"""
        If a validation set (X_val) is passed to the fit method, or val_split
        is specified, the train/val split will happen internally. A number of
        options are allowed in terms of data inputs. For parameter
        information, please see the .fit() method documentation

        Returns
        -------
        train_set: WideDeepDataset object that will be loaded through
            torch.utils.data.DataLoader
        eval_set : WideDeepDataset object that will be loaded through
            torch.utils.data.DataLoader
        """
        # Without validation
        if X_val is None and val_split is None:
            # if a train dictionary is passed, check if text and image datasets
            # are present and instantiate the WideDeepDataset class
            if X_train is not None:
                X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                if 'X_text' in X_train.keys(): X_text = X_train['X_text']
                if 'X_img' in X_train.keys(): X_img = X_train['X_img']
            X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': target}
            if X_text is not None: X_train.update({'X_text': X_text})
            if X_img is not None: X_train.update({'X_img': X_img})
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)
            eval_set = None
        # With validation
        else:
            if X_val is not None:
                # if a validation dictionary is passed and no train dictionary
                # is passed, we build the latter with the input arrays (either
                # the dictionary or the arrays must be passed)
                if X_train is None:
                    X_train = {'X_wide':X_wide, 'X_deep': X_deep, 'target': target}
                    if X_text is not None: X_train.update({'X_text': X_text})
                    if X_img is not None: X_train.update({'X_img': X_img})
            else:
                # if a train dictionary is passed, check if text and image
                # datasets are present. Then run the train/val split using
                # val_split
                if X_train is not None:
                    X_wide, X_deep, target = X_train['X_wide'], X_train['X_deep'], X_train['target']
                    if 'X_text' in X_train.keys(): X_text = X_train['X_text']
                    if 'X_img' in X_train.keys(): X_img = X_train['X_img']
                X_tr_wide, X_val_wide, X_tr_deep, X_val_deep, y_tr, y_val = train_test_split(X_wide,
                    X_deep, target, test_size=val_split, random_state=seed)
                X_train = {'X_wide':X_tr_wide, 'X_deep': X_tr_deep, 'target': y_tr}
                X_val = {'X_wide':X_val_wide, 'X_deep': X_val_deep, 'target': y_val}
                if X_text is not None:
                    X_tr_text, X_val_text = train_test_split(X_text, test_size=val_split,
                        random_state=seed)
                    X_train.update({'X_text': X_tr_text})
                    X_val.update({'X_text': X_val_text})
                if X_img is not None:
                    X_tr_img, X_val_img = train_test_split(X_img, test_size=val_split,
                        random_state=seed)
                    X_train.update({'X_img': X_tr_img})
                    X_val.update({'X_img': X_val_img})
            # At this point the X_train and X_val dictionaries have been built
            train_set = WideDeepDataset(**X_train, transforms=self.transforms)
            eval_set = WideDeepDataset(**X_val, transforms=self.transforms)
        return train_set, eval_set

    def fit(self,
        X_wide:Optional[np.ndarray]=None,
        X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None,
        X_img:Optional[np.ndarray]=None,
        X_train:Optional[Dict[str,np.ndarray]]=None,
        X_val:Optional[Dict[str,np.ndarray]]=None,
        val_split:Optional[float]=None,
        target:Optional[np.ndarray]=None,
        n_epochs:int=1,
        batch_size:int=32,
        patience:int=10,
        seed:int=1):
        r"""
        fit method that must run after calling 'compile'

        Parameters
        ----------
        X_wide: optional np.array with the one hot encoded wide input.
        X_deep: optional np.array with the input for the deepdense model
        X_text: optional np.array with the input for the deeptext model
        X_img : optional np.array with the input for the deepimage model
        X_train: optional Dict with the training dataset for the different
            model branches. Keys are 'X_wide', 'X_deep', 'X_text', 'X_img' and
            'target'; the values are the corresponding matrices.
            e.g X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'X_text': X_text, 'X_img': X_img, 'target': y}
        X_val: optional Dict with the validation dataset for the different
            model branches. Keys are 'X_wide', 'X_deep', 'X_text', 'X_img' and
            'target'; the values are the corresponding matrices.
            e.g X_val = {'X_wide': X_wide, 'X_deep': X_deep, 'X_text': X_text, 'X_img': X_img, 'target': y}
        val_split: optional float specifying the train/val split
        target: optional np.array with the target values
        n_epochs: number of epochs
        batch_size: batch size
        patience: number of epochs without improving the target metric before
            we stop the fit
        seed: random seed for the train/val split

        **WideDeep assumes that X_wide, X_deep and target ALWAYS exist, while
        X_text and X_img are optional
        **Either X_train or X_wide, X_deep and target must be passed to the
        fit method

        Example
        --------
        Assuming you have already built and compiled the model

        Ex 1: using train input arrays directly and no validation
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256)

        Ex 2: using train input arrays directly and validation with val_split
        >>> model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)

        Ex 3: using train dict and val_split
        >>> X_train = {'X_wide': X_wide, 'X_deep': X_deep, 'target': y}
        >>> model.fit(X_train, n_epochs=10, batch_size=256, val_split=0.2)

        Ex 4: validation using training and validation dicts
        >>> X_train = {'X_wide': X_wide_tr, 'X_deep': X_deep_tr, 'target': y_tr}
        >>> X_val = {'X_wide': X_wide_val, 'X_deep': X_deep_val, 'target': y_val}
        >>> model.fit(X_train=X_train, X_val=X_val, n_epochs=10, batch_size=256)
        """

        if X_train is None and (X_wide is None or X_deep is None or target is None):
            raise ValueError(
                "Training data is missing. Either a dictionary (X_train) with "
                "the training dataset or at least 3 arrays (X_wide, X_deep, "
                "target) must be passed to the fit method")

        self.batch_size = batch_size
        train_set, eval_set = self._train_val_split(X_wide, X_deep, X_text, X_img,
            X_train, X_val, val_split, target, seed)
        train_loader = DataLoader(dataset=train_set, batch_size=batch_size, num_workers=8)
        train_steps = (len(train_loader.dataset) // batch_size) + 1
        self.callback_container.on_train_begin({'batch_size': batch_size,
            'train_steps': train_steps, 'n_epochs': n_epochs})

        for epoch in range(n_epochs):
            # train step...
            epoch_logs = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)
            self.train_running_loss = 0.
            with trange(train_steps, disable=self.verbose != 1) as t:
                for batch_idx, (data, target) in zip(t, train_loader):
                    t.set_description('epoch %i' % (epoch+1))
                    acc, train_loss = self._training_step(data, target, batch_idx)
                    if acc is not None:
                        t.set_postfix(metrics=acc, loss=train_loss)
                    else:
                        t.set_postfix(loss=np.sqrt(train_loss))
                    if self.lr_scheduler: self._lr_scheduler_step(step_location='on_batch_end')
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs['train_loss'] = train_loss
            if acc is not None: epoch_logs['train_acc'] = acc['acc']
            # eval step...
            if eval_set is not None:
                eval_loader = DataLoader(dataset=eval_set, batch_size=batch_size, num_workers=8,
                    shuffle=False)
                eval_steps = (len(eval_loader.dataset) // batch_size) + 1
                self.valid_running_loss = 0.
                with trange(eval_steps, disable=self.verbose != 1) as v:
                    for i, (data, target) in zip(v, eval_loader):
                        v.set_description('valid')
                        acc, val_loss = self._validation_step(data, target, i)
                        if acc is not None:
                            v.set_postfix(metrics=acc, loss=val_loss)
                        else:
                            v.set_postfix(loss=np.sqrt(val_loss))
                epoch_logs['val_loss'] = val_loss
                if acc is not None: epoch_logs['val_acc'] = acc['acc']
            if self.lr_scheduler: self._lr_scheduler_step(step_location='on_epoch_end')
            # log and check if early_stop...
            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                break
        # on_train_end is called once, after the training loop finishes or
        # early stopping breaks out of it
        self.callback_container.on_train_end(epoch)
        self.train()

    def _predict(self, X_wide:Optional[np.ndarray]=None, X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
        X_test:Optional[Dict[str, np.ndarray]]=None)->List:
        r"""
        Hidden method to avoid code repetition in predict and predict_proba.
        For parameter information, please see the .predict() method
        documentation
        """
        if X_test is not None:
            test_set = WideDeepDataset(**X_test)
        else:
            load_dict = {'X_wide': X_wide, 'X_deep': X_deep}
            if X_text is not None: load_dict.update({'X_text': X_text})
            if X_img is not None: load_dict.update({'X_img': X_img})
            test_set = WideDeepDataset(**load_dict)

        test_loader = DataLoader(dataset=test_set, batch_size=self.batch_size,
            shuffle=False)
        test_steps = (len(test_loader.dataset) // test_loader.batch_size) + 1

        self.eval()
        preds_l = []
        with torch.no_grad():
            with trange(test_steps, disable=self.verbose != 1) as t:
                for i, data in zip(t, test_loader):
                    t.set_description('predict')
                    X = {k:v.cuda() for k,v in data.items()} if use_cuda else data
                    preds = self._activation_fn(self.forward(X)).cpu().data.numpy()
                    preds_l.append(preds)
        self.train()
        return preds_l

    def predict(self, X_wide:Optional[np.ndarray]=None, X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
        X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
        r"""
        Returns the predicted target given the test inputs. Must run after
        calling 'fit'.

        Parameters
        ----------
        X_wide: optional np.array with the one hot encoded wide input.
        X_deep: optional np.array with the input for the deepdense model
        X_text: optional np.array with the input for the deeptext model
        X_img : optional np.array with the input for the deepimage model
        X_test: optional Dict with the test dataset for the different model
            branches. Keys are 'X_wide', 'X_deep', 'X_text' and 'X_img'; the
            values are the corresponding matrices.
            e.g X_test = {'X_wide': X_wide_te, 'X_deep': X_deep_te, 'X_text': X_text_te, 'X_img': X_img_te}

        **WideDeep assumes that X_wide and X_deep ALWAYS exist, while X_text
        and X_img are optional

        Returns
        -------
        preds: np.array with the predicted target for the test dataset.
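
        Example
        -------
        A sketch using input arrays (the test arrays are illustrative):

        >>> preds = model.predict(X_wide=X_wide_te, X_deep=X_deep_te)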
        """
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "regression":
            return np.vstack(preds_l).squeeze(1)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            return (preds > 0.5).astype('int')
        if self.method == "multiclass":
            preds = np.vstack(preds_l)
            return np.argmax(preds, 1)

    def predict_proba(self, X_wide:Optional[np.ndarray]=None, X_deep:Optional[np.ndarray]=None,
        X_text:Optional[np.ndarray]=None, X_img:Optional[np.ndarray]=None,
        X_test:Optional[Dict[str, np.ndarray]]=None)->np.ndarray:
        r"""
        Returns
        -------
        preds: np.array with the predicted probabilities of the target for the
            test dataset, for binary and multiclass methods
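
        Example
        -------
        A sketch for a binary problem (the test arrays are illustrative):

        >>> probs = model.predict_proba(X_wide=X_wide_te, X_deep=X_deep_te)  # shape: (n_samples, 2)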
        """
        preds_l = self._predict(X_wide, X_deep, X_text, X_img, X_test)
        if self.method == "binary":
            preds = np.vstack(preds_l).squeeze(1)
            probs = np.zeros([preds.shape[0], 2])
            probs[:,0] = 1-preds
            probs[:,1] = preds
            return probs
        if self.method == "multiclass":
            return np.vstack(preds_l)

    def get_embeddings(self, col_name:str,
        cat_encoding_dict:Dict[str,Dict[str,int]]) -> Dict[str,np.ndarray]:
        """
        Get the learned embeddings for the categorical features passed through deepdense.

        Parameters
        ----------
        col_name: column name of the feature we want to get the embeddings for
        cat_encoding_dict: Dict with the categorical encodings. The function
            is designed to take the 'encoding_dict' attribute from the
            DeepPreprocessor class. Any Dict with the same structure can be used

        Returns
        -------
        cat_embed_dict: Dict with the categorical levels of the col_name
            feature and the corresponding embeddings

        Example
        -------
        Assuming we have already trained the model:

        >>> model.get_embeddings(col_name='education', cat_encoding_dict=deep_preprocessor.encoding_dict)
        {'11th': array([-0.42739448, -0.22282735,  0.36969638,  0.4445322 ,  0.2562272 ,
        0.11572784, -0.01648579,  0.09027119,  0.0457597 , -0.28337458], dtype=float32),
         'HS-grad': array([-0.10600474, -0.48775527,  0.3444158 ,  0.13818645, -0.16547225,
        0.27409762, -0.05006042, -0.0668492 , -0.11047247,  0.3280354 ], dtype=float32),
        ...
        }

        where:

        >>> deep_preprocessor.encoding_dict['education']
        {'11th': 0, 'HS-grad': 1, 'Assoc-acdm': 2, 'Some-college': 3, '10th': 4, 'Prof-school': 5,
        '7th-8th': 6, 'Bachelors': 7, 'Masters': 8, 'Doctorate': 9, '5th-6th': 10, 'Assoc-voc': 11,
        '9th': 12, '12th': 13, '1st-4th': 14, 'Preschool': 15}
        """
        for n,p in self.named_parameters():
            if 'embed_layers' in n and col_name in n:
                embed_mtx = p.cpu().data.numpy()
        encoding_dict = cat_encoding_dict[col_name]
        inv_encoding_dict = {v:k for k,v in encoding_dict.items()}
        cat_embed_dict = {}
        for idx,value in inv_encoding_dict.items():
            cat_embed_dict[value] = embed_mtx[idx]
        return cat_embed_dict