"""
During the development of the package I realised that there is a typing
inconsistency. The input components of a Wide and Deep model are typed as
nn.Module, but internally they may be re-assigned to nn.Sequential objects
(when a prediction layer is appended to them). While nn.Sequential is a
subclass of nn.Module, the opposite is, of course, not true. This does not
affect any functionality of the package, but it is something that needs
fixing. However, while the fix is simple (define new attributes that hold
the nn.Sequential objects), its implications within the package are quite
wide (it involves changing a number of tests and tutorials). Therefore, I
will introduce that fix when I do a major release. For now, we live with it.
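
A minimal illustration of the inconsistency, assuming a plain nn.Linear as a
stand-in for a user-supplied component (the stand-in and its sizes are
illustrative, not package code). It re-creates by hand the wrapping that
happens internally when a prediction layer is appended to a component:

```python
import torch.nn as nn

# Stand-in for a user-supplied component, typed as a plain nn.Module
component = nn.Linear(8, 4)
component.output_dim = 4  # mirrors the 'output_dim' attribute components expose

# The attribute holding the component is re-assigned to an nn.Sequential
# that appends a prediction layer (here with pred_dim = 1)
wrapped = nn.Sequential(component, nn.Linear(component.output_dim, 1))

print(isinstance(wrapped, nn.Module))  # True: nn.Sequential subclasses nn.Module
print(isinstance(component, nn.Sequential))  # False: the reverse does not hold
print(wrapped[0] is component)  # True: the original module is the first child
```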
"""

import warnings

import torch
import torch.nn as nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.tabnet.tab_net import TabNetPredLayer

warnings.filterwarnings("default", category=UserWarning)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class WideDeep(nn.Module):
    r"""Main collector class that combines all ``wide``, ``deeptabular``
    (which can be a number of architectures), ``deeptext`` and
    ``deepimage`` models.

    There are two options to combine these models that correspond to the
    two main architectures that ``pytorch-widedeep`` can build.

        - Directly connecting the output of the model components to the
          output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
          This FC-Head will combine the outputs from the ``deeptabular``,
          ``deeptext`` and ``deepimage`` components and will then be
          connected to the output neuron(s).
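
    The two options can be sketched with random tensors standing in for the
    last layer of activations of each deep component. This is an
    illustration under assumed sizes (``output_dim`` of 16 and 8,
    ``pred_dim`` of 1), not the code ``WideDeep`` runs; it only mirrors how
    the outputs are combined in each case:

    ```python
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    batch_size, pred_dim = 4, 1

    # Assumed last activations of two deep components and the wide output
    tab_act = torch.randn(batch_size, 16)  # e.g. deeptabular.output_dim == 16
    text_act = torch.randn(batch_size, 8)  # e.g. deeptext.output_dim == 8
    wide_out = torch.randn(batch_size, pred_dim)

    # Option 1: each deep component gets its own linear layer to pred_dim and
    # all outputs are summed with the wide output
    tab_pred = nn.Linear(16, pred_dim)
    text_pred = nn.Linear(8, pred_dim)
    out_direct = wide_out + tab_pred(tab_act) + text_pred(text_act)

    # Option 2: the deep activations are concatenated and passed through an
    # FC-Head whose last layer maps to pred_dim; the result is added to wide_out
    head = nn.Sequential(nn.Linear(16 + 8, 32), nn.ReLU(), nn.Linear(32, pred_dim))
    out_head = wide_out + head(torch.cat([tab_act, text_act], dim=1))

    print(out_direct.shape, out_head.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
    ```

    In both cases the wide component already outputs ``pred_dim`` values, so
    it is simply added to the deep side.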

    Parameters
    ----------
    wide: ``nn.Module``, Optional, default = None
        ``Wide`` model. I recommend using the :obj:`Wide` class in this
        package. However, it is possible to use a custom model as long as
        it is consistent with the required architecture, see
        :class:`pytorch_widedeep.models.wide.Wide`
    deeptabular: ``nn.Module``, Optional, default = None
        Currently ``pytorch-widedeep`` implements five possible
        architectures for the ``deeptabular`` component. These are:
        TabMlp, TabResnet, TabNet, TabTransformer and SAINT.

        1. TabMlp is simply an embedding layer encoding the categorical
        features that are then concatenated and passed through a series of
        dense (hidden) layers (i.e. an MLP).
        See :obj:`pytorch_widedeep.models.tab_mlp.TabMlp`

        2. TabResnet is an embedding layer encoding the categorical
        features that are then concatenated and passed through a series of
        ResNet blocks formed by dense layers.
        See :obj:`pytorch_widedeep.models.tab_resnet.TabResnet`

        3. TabNet is detailed in `TabNet: Attentive Interpretable Tabular
        Learning <https://arxiv.org/abs/1908.07442>`_. The TabNet
        implementation in ``pytorch-widedeep`` is an adaptation of the
        `dreamquark-ai <https://github.com/dreamquark-ai/tabnet>`_
        implementation. See
        :obj:`pytorch_widedeep.models.tabnet.tab_net.TabNet`

        4. TabTransformer is detailed in `TabTransformer: Tabular Data
        Modeling Using Contextual Embeddings
        <https://arxiv.org/abs/2012.06678>`_. The TabTransformer
        implementation in ``pytorch-widedeep`` is an adaptation of the
        original implementation. See
        :obj:`pytorch_widedeep.models.transformers.tab_transformer.TabTransformer`.

        5. SAINT is detailed in `SAINT: Improved Neural Networks for Tabular
        Data via Row Attention and Contrastive Pre-Training
        <https://arxiv.org/abs/2106.01342>`_. The SAINT implementation in
        ``pytorch-widedeep`` is an adaptation of the original implementation.
        See :obj:`pytorch_widedeep.models.transformers.saint.SAINT`.

        I recommend using one of these as ``deeptabular``. However, it is
        possible to use a custom model as long as it is consistent with the
        required architecture.

    deeptext: ``nn.Module``, Optional, default = None
        Model for the text input. Must be an object of class ``DeepText``
        or a custom model as long as it is consistent with the required
        architecture. See
        :class:`pytorch_widedeep.models.deep_text.DeepText`
    deepimage: ``nn.Module``, Optional, default = None
        Model for the images input. Must be an object of class
        ``DeepImage`` or a custom model as long as it is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_image.DeepImage`
    deephead: ``nn.Module``, Optional, default = None
        Custom model defined by the user that will receive the concatenated
        outputs of the deep components (``deeptabular``, ``deeptext`` and
        ``deepimage``). Typically an FC-Head (MLP). Its number of input
        features must be equal to the sum of the ``output_dim`` of the deep
        components in use
    head_hidden_dims: List, Optional, default = None
        Alternatively, the ``head_hidden_dims`` param can be used to
        specify the sizes of the stacked dense layers in the FC-head, e.g.
        ``[128, 64]``. Use ``deephead`` or ``head_hidden_dims``, but not
        both.
    head_dropout: float, default = 0.1
        If ``head_hidden_dims`` is not None, dropout between the layers in
        ``head_hidden_dims``
    head_activation: str, default = "relu"
        If ``head_hidden_dims`` is not None, activation function of the
        head layers. One of "relu", "gelu" or "leaky_relu"
    head_batchnorm: bool, default = False
        If ``head_hidden_dims`` is not None, specifies if batch
        normalization should be included in the head layers
    head_batchnorm_last: bool, default = False
        If ``head_hidden_dims`` is not None, boolean indicating whether or
        not to apply batch normalization to the last of the dense layers
    head_linear_first: bool, default = False
        If ``head_hidden_dims`` is not None, boolean indicating the order of
        the operations in the dense layer. If ``True``:
        ``[LIN -> ACT -> BN -> DP]``. If ``False``:
        ``[BN -> DP -> LIN -> ACT]``
    pred_dim: int, default = 1
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or number
        of classes for multiclass classification.

    Examples
    --------

    >>> from pytorch_widedeep.models import TabResnet, DeepImage, DeepText, Wide, WideDeep
    >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
    >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
    >>> wide = Wide(10, 1)
    >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
    >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
    >>> deepimage = DeepImage(pretrained=False)
    >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
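
    When a custom ``deephead`` is passed, its input size must equal the sum
    of the ``output_dim`` of the deep components in use. A small sketch with
    assumed values (16 for ``deeptabular`` and 8 for ``deeptext``, chosen only
    for illustration) that reproduces the check ``WideDeep`` performs on the
    head's first parameter:

    ```python
    import torch
    import torch.nn as nn

    # Assumed 'output_dim' values of the deep components (illustrative only)
    deeptabular_dim, deeptext_dim = 16, 8
    deep_dim = deeptabular_dim + deeptext_dim

    # The head must take deep_dim input features; its output width is free
    deephead = nn.Sequential(nn.Linear(deep_dim, 32), nn.ReLU(), nn.Linear(32, 16))

    # The first parameter is taken to be the weight of the head's first
    # nn.Linear, whose shape is (out_features, in_features)
    print(next(deephead.parameters()).size(1) == deep_dim)  # True
    ```

    Such a head would be passed as ``WideDeep(deeptabular=..., deeptext=...,
    deephead=deephead)``. Passing ``head_hidden_dims`` as well raises a
    ``ValueError``.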

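
    When ``head_hidden_dims`` is used instead of a custom ``deephead``, the
    FC-head is built internally. The sketch below pictures what the
    ``head_*`` parameters control. ``dense_layer`` and ``fc_head`` are
    hypothetical helpers written for illustration, not the package's ``MLP``;
    they reproduce the two operation orders documented for
    ``head_linear_first`` and end with a projection to ``pred_dim``:

    ```python
    import torch
    import torch.nn as nn

    # Hypothetical helpers for illustration; not the package's MLP implementation
    ACTIVATIONS = {"relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, "gelu": nn.GELU}


    def dense_layer(inp, out, activation, dropout, batchnorm, linear_first):
        act = ACTIVATIONS[activation]()
        if linear_first:
            # [LIN -> ACT -> BN -> DP]: batch norm is applied to the layer's outputs
            layers = [nn.Linear(inp, out), act]
            if batchnorm:
                layers.append(nn.BatchNorm1d(out))
            layers.append(nn.Dropout(dropout))
        else:
            # [BN -> DP -> LIN -> ACT]: batch norm is applied to the layer's inputs
            layers = [nn.BatchNorm1d(inp)] if batchnorm else []
            layers += [nn.Dropout(dropout), nn.Linear(inp, out), act]
        return nn.Sequential(*layers)


    def fc_head(deep_dim, hidden_dims, pred_dim=1, activation="relu", dropout=0.1,
                batchnorm=False, batchnorm_last=False, linear_first=False):
        dims = [deep_dim] + hidden_dims  # e.g. [24, 128, 64]
        n_dense = len(dims) - 1
        layers = []
        for i in range(n_dense):
            is_last = i == n_dense - 1
            # batch norm on the last dense layer only if batchnorm_last is True
            use_bn = batchnorm and (not is_last or batchnorm_last)
            layers.append(
                dense_layer(dims[i], dims[i + 1], activation, dropout, use_bn, linear_first)
            )
        # final projection to the prediction size (cf. the 'head_out' layer)
        layers.append(nn.Linear(dims[-1], pred_dim))
        return nn.Sequential(*layers)


    head = fc_head(24, [128, 64], batchnorm=True, linear_first=True)
    print(head(torch.randn(5, 24)).shape)  # torch.Size([5, 1])
    ```

    With ``linear_first=False`` the first block would instead be
    ``BatchNorm1d -> Dropout -> Linear -> ReLU``.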

    .. note:: While I recommend using the ``wide`` and ``deeptabular`` components
        within this package when building the corresponding model components,
        it is very likely that the user will want to use custom text and image
        models. That is perfectly possible. Simply build them and pass them
        as the corresponding parameters. Note that the custom models MUST
        return a last layer of activations (i.e. not the final prediction) so
        that these activations are collected by ``WideDeep`` and combined
        accordingly. In addition, the models MUST also contain an attribute
        ``output_dim`` with the size of these last layers of activations. See
        for example :class:`pytorch_widedeep.models.tab_mlp.TabMlp`
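
    As an illustration of such a custom component, the sketch below defines a
    hypothetical ``deeptext`` model (``MeanEmbeddingText`` is not part of the
    package) that averages token embeddings, returns a layer of activations
    and exposes the required ``output_dim`` attribute:

    ```python
    import torch
    import torch.nn as nn


    class MeanEmbeddingText(nn.Module):
        def __init__(self, vocab_size, embed_dim, hidden_dim, padding_idx=0):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
            self.dense = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.ReLU())
            # required by WideDeep: the size of the last layer of activations
            self.output_dim = hidden_dim

        def forward(self, X):
            # X: (batch_size, seq_len) tensor of token indices. The embeddings
            # are averaged over the sequence (padding positions included, for
            # simplicity) and activations, not predictions, are returned
            return self.dense(self.embed(X).mean(dim=1))


    text_model = MeanEmbeddingText(vocab_size=100, embed_dim=16, hidden_dim=8)
    tokens = torch.randint(1, 100, (4, 10))
    print(text_model(tokens).shape, text_model.output_dim)  # torch.Size([4, 8]) 8
    ```

    Passed as ``WideDeep(deeptext=text_model)`` without an FC-head,
    ``WideDeep`` appends ``nn.Linear(text_model.output_dim, pred_dim)`` to it.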

    """

    def __init__(
        self,
        wide: Optional[nn.Module] = None,
        deeptabular: Optional[nn.Module] = None,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_hidden_dims: Optional[List[int]] = None,
        head_activation: str = "relu",
        head_dropout: float = 0.1,
        head_batchnorm: bool = False,
        head_batchnorm_last: bool = False,
        head_linear_first: bool = False,
        pred_dim: int = 1,
    ):
        super(WideDeep, self).__init__()

        self._check_model_components(
            wide,
            deeptabular,
            deeptext,
            deepimage,
            deephead,
            head_hidden_dims,
            pred_dim,
        )

        # required as attribute just in case we pass a deephead
        self.pred_dim = pred_dim

        # The five main components of the wide and deep model
        self.wide = wide
        self.deeptabular = deeptabular
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deeptabular is not None:
            self.is_tabnet = deeptabular.__class__.__name__ == "TabNet"
        else:
            self.is_tabnet = False

        if self.deephead is None:
            if head_hidden_dims is not None:
                self._build_deephead(
                    head_hidden_dims,
                    head_activation,
                    head_dropout,
                    head_batchnorm,
                    head_batchnorm_last,
                    head_linear_first,
                )
            else:
                self._add_pred_layer()
        else:
            # a custom deephead gets its output layer here, once, so that the
            # layer is registered as a submodule and trained with the model
            self._add_deephead_pred_layer()

    def forward(self, X: Dict[str, Tensor]):
        wide_out = self._forward_wide(X)
        if self.deephead is not None:
            return self._forward_deephead(X, wide_out)
        else:
            return self._forward_deep(X, wide_out)

    def _deep_output_dim(self):
        # sum of the 'output_dim' of the deep components in use
        deep_dim = 0
        if self.deeptabular is not None:
            deep_dim += self.deeptabular.output_dim
        if self.deeptext is not None:
            deep_dim += self.deeptext.output_dim
        if self.deepimage is not None:
            deep_dim += self.deepimage.output_dim
        return deep_dim

    def _build_deephead(
        self,
        head_hidden_dims,
        head_activation,
        head_dropout,
        head_batchnorm,
        head_batchnorm_last,
        head_linear_first,
    ):
        head_hidden_dims = [self._deep_output_dim()] + head_hidden_dims
        self.deephead = MLP(
            head_hidden_dims,
            head_activation,
            head_dropout,
            head_batchnorm,
            head_batchnorm_last,
            head_linear_first,
        )

        self.deephead.add_module(
            "head_out", nn.Linear(head_hidden_dims[-1], self.pred_dim)
        )

    def _add_deephead_pred_layer(self):
        # the output size of a custom deephead is not known in advance, so it
        # is read from a dummy forward pass (eval mode, no gradients)
        head_param = next(self.deephead.parameters())
        dummy = torch.zeros(2, self._deep_output_dim(), device=head_param.device)
        was_training = self.deephead.training
        self.deephead.eval()
        with torch.no_grad():
            head_out_dim = self.deephead(dummy).size(1)
        self.deephead.train(was_training)
        self.deephead = nn.Sequential(
            self.deephead, nn.Linear(head_out_dim, self.pred_dim)
        )

    def _add_pred_layer(self):
        if self.deeptabular is not None:
            if self.is_tabnet:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    TabNetPredLayer(self.deeptabular.output_dim, self.pred_dim),
                )
            else:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    nn.Linear(self.deeptabular.output_dim, self.pred_dim),
                )
        if self.deeptext is not None:
            self.deeptext = nn.Sequential(
                self.deeptext, nn.Linear(self.deeptext.output_dim, self.pred_dim)
            )
        if self.deepimage is not None:
            self.deepimage = nn.Sequential(
                self.deepimage, nn.Linear(self.deepimage.output_dim, self.pred_dim)
            )

    def _forward_wide(self, X):
        if self.wide is not None:
            out = self.wide(X["wide"])
        else:
            batch_size = X[list(X.keys())[0]].size(0)
            out = torch.zeros(batch_size, self.pred_dim).to(device)

        return out

    def _forward_deephead(self, X, wide_out):
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out = self.deeptabular(X["deeptabular"])
                deepside, M_loss = tab_out[0], tab_out[1]
            else:
                deepside = self.deeptabular(X["deeptabular"])
        else:
            deepside = torch.FloatTensor().to(device)
        if self.deeptext is not None:
            deepside = torch.cat([deepside, self.deeptext(X["deeptext"])], dim=1)
        if self.deepimage is not None:
            deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], dim=1)

        # self.deephead already ends in a linear layer of size pred_dim, added
        # once in __init__ (rather than a new, untrained nn.Linear per call)
        deepside_out = self.deephead(deepside)

        if self.is_tabnet:
            res = (wide_out.add_(deepside_out), M_loss)
        else:
            res = wide_out.add_(deepside_out)

        return res

    def _forward_deep(self, X, wide_out):
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out, M_loss = self.deeptabular(X["deeptabular"])
                wide_out.add_(tab_out)
            else:
                wide_out.add_(self.deeptabular(X["deeptabular"]))
        if self.deeptext is not None:
            wide_out.add_(self.deeptext(X["deeptext"]))
        if self.deepimage is not None:
            wide_out.add_(self.deepimage(X["deepimage"]))

        if self.is_tabnet:
            res = (wide_out, M_loss)
        else:
            res = wide_out

        return res

    @staticmethod  # noqa: C901
    def _check_model_components(
        wide,
        deeptabular,
        deeptext,
        deepimage,
        deephead,
        head_hidden_dims,
        pred_dim,
    ):

        if wide is not None:
            assert wide.wide_linear.weight.size(1) == pred_dim, (
                "the 'pred_dim' of the wide component ({}) must be equal to the 'pred_dim' "
                "of the deep component and the overall model itself ({})".format(
                    wide.wide_linear.weight.size(1), pred_dim
                )
            )
        if deeptabular is not None and not hasattr(deeptabular, "output_dim"):
            raise AttributeError(
                "deeptabular model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.tab_mlp.TabMlp"
            )
        if deeptabular is not None:
            is_tabnet = deeptabular.__class__.__name__ == "TabNet"
            has_wide_text_or_image = (
                wide is not None or deeptext is not None or deepimage is not None
            )
            if is_tabnet and has_wide_text_or_image:
                warnings.warn(
                    "'WideDeep' is a model comprised of multiple components and the 'deeptabular'"
                    " component is 'TabNet'. We recommend using 'TabNet' in isolation."
                    " The reasons are: i) 'TabNet' uses sparse regularization, which partially loses"
                    " its purpose when used in combination with other components."
                    " If you still want to use a multiple component model with 'TabNet',"
                    " consider setting 'lambda_sparse' to 0 during training. ii) The feature"
                    " importances will be computed only for TabNet, but the model will comprise multiple"
                    " components. Therefore, such importances will partially lose their 'meaning'.",
                    UserWarning,
                )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_text.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_image.DeepImage"
            )
        if deephead is not None and head_hidden_dims is not None:
            raise ValueError(
                "both 'deephead' and 'head_hidden_dims' are not None. Use one or the other, but not both"
            )
        if (
            head_hidden_dims is not None
            and not deeptabular
            and not deeptext
            and not deepimage
        ):
            raise ValueError(
                "if 'head_hidden_dims' is not None, at least one deep component must be used"
            )
        if deephead is not None:
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = 0
            if deeptabular is not None:
                output_dim += deeptabular.output_dim
            if deeptext is not None:
                output_dim += deeptext.output_dim
            if deepimage is not None:
                output_dim += deepimage.output_dim
            assert deephead_inp_feat == output_dim, (
                "if a custom 'deephead' is used its input features ({}) must be equal to "
                "the output features of the deep component ({})".format(
                    deephead_inp_feat, output_dim
                )
            )