import warnings

import torch
import torch.nn as nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.tabnet.tab_net import TabNetPredLayer

warnings.filterwarnings("default", category=UserWarning)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class WideDeep(nn.Module):
    def __init__(
        self,
        wide: Optional[nn.Module] = None,
        deeptabular: Optional[nn.Module] = None,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_hidden_dims: Optional[List[int]] = None,
        head_activation: str = "relu",
        head_dropout: float = 0.1,
        head_batchnorm: bool = False,
        head_batchnorm_last: bool = False,
        head_linear_first: bool = False,
        pred_dim: int = 1,
    ):
        r"""Main collector class that combines all ``wide``, ``deeptabular``
        (which can be a number of architectures), ``deeptext`` and
        ``deepimage`` models.

        There are two options to combine these models, corresponding to the
        two main architectures that ``pytorch-widedeep`` can build (each with
        a number of "sub-architectures").

            - Directly connecting the output of the model components to the output neuron(s).

            - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
              This FC-Head will combine the output from the ``deeptabular``, ``deeptext`` and
              ``deepimage`` components and will then be connected to the output neuron(s).

        Parameters
        ----------
        wide: ``nn.Module``, Optional, default = None
            ``Wide`` model. I recommend using the ``Wide`` class in this
            package. However, it is possible to use a custom model as long as
            it is consistent with the required architecture, see
            :class:`pytorch_widedeep.models.wide.Wide`
        deeptabular: ``nn.Module``, Optional, default = None
            Currently ``pytorch-widedeep`` implements four possible
            architectures for the `deeptabular` component. These are:
            ``TabMlp``, ``TabResnet``, ``TabNet`` and ``TabTransformer``.

            1. ``TabMlp`` is simply an embedding layer encoding the categorical
            features that are then concatenated and passed through a series of
            dense (hidden) layers (i.e. an MLP).
            See: ``pytorch_widedeep.models.deep_dense.TabMlp``

            2. ``TabResnet`` is an embedding layer encoding the categorical
            features that are then concatenated and passed through a series of
            ResNet blocks formed by dense layers.
            See ``pytorch_widedeep.models.deep_dense_resnet.TabResnet``

            3. ``TabNet`` is detailed in `TabNet: Attentive Interpretable Tabular
            Learning <https://arxiv.org/abs/1908.07442>`_. See
            ``pytorch_widedeep.models.tabnet.tab_net.TabNet``

            4. ``TabTransformer`` is detailed in `TabTransformer: Tabular Data
            Modeling Using Contextual Embeddings
            <https://arxiv.org/abs/2012.06678>`_. See
            ``pytorch_widedeep.models.tab_transformer.TabTransformer``

            I recommend using one of these as ``deeptabular``. However, it is
            possible to use a custom model as long as it is consistent with the
            required architecture. See
            :class:`pytorch_widedeep.models.deep_dense.TabTransformer`.

        deeptext: ``nn.Module``, Optional, default = None
            Model for the text input. Must be an object of class ``DeepText``
            or a custom model as long as it is consistent with the required
            architecture. See
            :class:`pytorch_widedeep.models.deep_dense.DeepText`
        deepimage: ``nn.Module``, Optional, default = None
            Model for the image input. Must be an object of class
            ``DeepImage`` or a custom model as long as it is consistent with
            the required architecture. See
            :class:`pytorch_widedeep.models.deep_dense.DeepImage`
        deephead: ``nn.Module``, Optional, default = None
            Custom user-defined model that will receive the output of the deep
            component. Typically an FC-Head (MLP)
        head_hidden_dims: List, Optional, default = None
            Alternatively, the ``head_hidden_dims`` param can be used to
            specify the sizes of the stacked dense layers in the FC-Head, e.g.:
            ``[128, 64]``. Use ``deephead`` or ``head_hidden_dims``, but not
            both.
        head_dropout: float, default = 0.1
            If ``head_hidden_dims`` is not None, dropout between the layers in
            ``head_hidden_dims``
        head_activation: str, default = "relu"
            If ``head_hidden_dims`` is not None, activation function of the
            head layers. One of "relu", "gelu" or "leaky_relu"
        head_batchnorm: bool, default = False
            If ``head_hidden_dims`` is not None, specifies if batch
            normalization should be included in the head layers
        head_batchnorm_last: bool, default = False
            If ``head_hidden_dims`` is not None, boolean indicating whether or
            not to apply batch normalization to the last of the dense layers
        head_linear_first: bool, default = False
            If ``head_hidden_dims`` is not None, boolean indicating the order
            of the operations in the dense layer. If ``True``:
            ``[LIN -> ACT -> BN -> DP]``. If ``False``: ``[BN -> DP -> LIN ->
            ACT]``
        pred_dim: int, default = 1
            Size of the final wide and deep output layer containing the
            predictions. `1` for regression and binary classification or number
            of classes for multiclass classification.

        Examples
        --------

        >>> from pytorch_widedeep.models import TabResnet, DeepImage, DeepText, Wide, WideDeep
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
        >>> deepimage = DeepImage(pretrained=False)
        >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
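
        A minimal sketch of the second architecture option, reusing the same
        components but adding an FC-Head via ``head_hidden_dims`` (the layer
        sizes below are purely illustrative):

        >>> model_fc_head = WideDeep(
        ...     wide=wide,
        ...     deeptabular=deeptabular,
        ...     deeptext=deeptext,
        ...     deepimage=deepimage,
        ...     head_hidden_dims=[128, 64],
        ...     head_dropout=0.1,
        ... )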


        .. note:: While I recommend using the ``wide`` and ``deeptabular`` components
            within this package when building the corresponding model components,
            it is very likely that the user will want to use custom text and image
            models. That is perfectly possible. Simply build them and pass them
            as the corresponding parameters. Note that the custom models MUST
            return a last layer of activations (i.e. not the final prediction) so
            that these activations are collected by ``WideDeep`` and combined
            accordingly. In addition, the models MUST also contain an attribute
            ``output_dim`` with the size of these last layers of activations. See
            for example :class:`pytorch_widedeep.models.tab_mlp.TabMlp`
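
        As a minimal, hypothetical sketch of such a custom component (the
        class name and sizes are illustrative, not part of the library), a
        custom image model only needs to return activations and expose
        ``output_dim``:

        >>> import torch.nn as nn
        >>> class MyImageEncoder(nn.Module):
        ...     def __init__(self, in_features=3 * 224 * 224, out_features=16):
        ...         super().__init__()
        ...         self.output_dim = out_features  # attribute required by WideDeep
        ...         self.encoder = nn.Sequential(
        ...             nn.Flatten(), nn.Linear(in_features, out_features), nn.ReLU()
        ...         )
        ...     def forward(self, X):
        ...         return self.encoder(X)  # last layer of activations, not predictions
        >>> model_custom = WideDeep(wide=wide, deepimage=MyImageEncoder())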

        """
        super(WideDeep, self).__init__()

        self._check_model_components(
            wide,
            deeptabular,
            deeptext,
            deepimage,
            deephead,
            head_hidden_dims,
            pred_dim,
        )

        # required as attribute just in case we pass a deephead
        self.pred_dim = pred_dim

        # The five main components of the wide and deep assembly
        self.wide = wide
        self.deeptabular = deeptabular
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deeptabular is not None:
            self.is_tabnet = deeptabular.__class__.__name__ == "TabNet"
        else:
            self.is_tabnet = False

        if self.deephead is None:
            if head_hidden_dims is not None:
                self._build_deephead(
                    head_hidden_dims,
                    head_activation,
                    head_dropout,
                    head_batchnorm,
                    head_batchnorm_last,
                    head_linear_first,
                )
            else:
                self._add_pred_layer()

    def forward(self, X: Dict[str, Tensor]):
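        # the wide output acts as the base prediction; the deep component outputs
        # are added to it, either through an FC-Head or directly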
        wide_out = self._forward_wide(X)
        if self.deephead:
            return self._forward_deephead(X, wide_out)
        else:
            return self._forward_deep(X, wide_out)

    def _build_deephead(
        self,
        head_hidden_dims,
        head_activation,
        head_dropout,
        head_batchnorm,
        head_batchnorm_last,
        head_linear_first,
    ):
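        # the FC-Head receives the concatenation of the deep components' outputs,
        # so its input size is the sum of their output_dim attributes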
        deep_dim = 0
        if self.deeptabular is not None:
            deep_dim += self.deeptabular.output_dim
        if self.deeptext is not None:
            deep_dim += self.deeptext.output_dim
        if self.deepimage is not None:
            deep_dim += self.deepimage.output_dim

        head_hidden_dims = [deep_dim] + head_hidden_dims
        self.deephead = MLP(
            head_hidden_dims,
            head_activation,
            head_dropout,
            head_batchnorm,
            head_batchnorm_last,
            head_linear_first,
        )

        self.deephead.add_module(
            "head_out", nn.Linear(head_hidden_dims[-1], self.pred_dim)
        )

    def _add_pred_layer(self):
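        # without an FC-Head, each deep component gets its own linear layer mapping
        # its activations to pred_dim; those outputs are later summed with the wide output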
        if self.deeptabular is not None:
            if self.is_tabnet:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    TabNetPredLayer(self.deeptabular.output_dim, self.pred_dim),
                )
            else:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    nn.Linear(self.deeptabular.output_dim, self.pred_dim),
                )
        if self.deeptext is not None:
            self.deeptext = nn.Sequential(
                self.deeptext, nn.Linear(self.deeptext.output_dim, self.pred_dim)
            )
        if self.deepimage is not None:
            self.deepimage = nn.Sequential(
                self.deepimage, nn.Linear(self.deepimage.output_dim, self.pred_dim)
            )

    def _forward_wide(self, X):
        if self.wide is not None:
            out = self.wide(X["wide"])
        else:
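            # with no wide component, use a zero tensor as the base so the deep
            # outputs can simply be added to it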
            batch_size = X[list(X.keys())[0]].size(0)
            out = torch.zeros(batch_size, self.pred_dim).to(device)

        return out

    def _forward_deephead(self, X, wide_out):
        if self.deeptabular is not None:
            if self.is_tabnet:
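                # TabNet returns a tuple: the activations and the sparsity regularization loss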
                tab_out = self.deeptabular(X["deeptabular"])
                deepside, M_loss = tab_out[0], tab_out[1]
            else:
                deepside = self.deeptabular(X["deeptabular"])
        else:
            deepside = torch.FloatTensor().to(device)
        if self.deeptext is not None:
            deepside = torch.cat([deepside, self.deeptext(X["deeptext"])], axis=1)
        if self.deepimage is not None:
            deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], axis=1)

        deephead_out = self.deephead(deepside)
        deepside_out = nn.Linear(deephead_out.size(1), self.pred_dim).to(device)

        if self.is_tabnet:
            res = (wide_out.add_(deepside_out(deephead_out)), M_loss)
        else:
            res = wide_out.add_(deepside_out(deephead_out))

        return res

    def _forward_deep(self, X, wide_out):
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out, M_loss = self.deeptabular(X["deeptabular"])
                wide_out.add_(tab_out)
            else:
                wide_out.add_(self.deeptabular(X["deeptabular"]))
        if self.deeptext is not None:
            wide_out.add_(self.deeptext(X["deeptext"]))
        if self.deepimage is not None:
            wide_out.add_(self.deepimage(X["deepimage"]))

        if self.is_tabnet:
            res = (wide_out, M_loss)
        else:
            res = wide_out

        return res

    @staticmethod  # noqa: C901
    def _check_model_components(
        wide,
        deeptabular,
        deeptext,
        deepimage,
        deephead,
        head_hidden_dims,
        pred_dim,
    ):

        if wide is not None:
            assert wide.wide_linear.weight.size(1) == pred_dim, (
                "the 'pred_dim' of the wide component ({}) must be equal to the 'pred_dim' "
                "of the deep component and the overall model itself ({})".format(
                    wide.wide_linear.weight.size(1), pred_dim
                )
            )
        if deeptabular is not None and not hasattr(deeptabular, "output_dim"):
            raise AttributeError(
                "deeptabular model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.TabMlp"
            )
        if deeptabular is not None:
            is_tabnet = deeptabular.__class__.__name__ == "TabNet"
            has_wide_text_or_image = (
                wide is not None or deeptext is not None or deepimage is not None
            )
            if is_tabnet and has_wide_text_or_image:
                warnings.warn(
                    "'WideDeep' is a model comprised of multiple components and the 'deeptabular'"
                    " component is 'TabNet'. We recommend using 'TabNet' in isolation."
                    " The reasons are: i) 'TabNet' uses sparse regularization which partially loses"
                    " its purpose when used in combination with other components."
                    " If you still want to use a multiple component model with 'TabNet',"
                    " consider setting 'lambda_sparse' to 0 during training. ii) The feature"
                    " importances will be computed only for TabNet but the model will comprise multiple"
                    " components. Therefore, such importances will partially lose their 'meaning'.",
                    UserWarning,
                )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepImage"
            )
        if deephead is not None and head_hidden_dims is not None:
            raise ValueError(
                "both 'deephead' and 'head_hidden_dims' are not None. Use one or the other, but not both"
            )
        if (
            head_hidden_dims is not None
            and not deeptabular
            and not deeptext
            and not deepimage
        ):
            raise ValueError(
                "if 'head_hidden_dims' is not None, at least one deep component must be used"
            )
        if deephead is not None:
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = 0
            if deeptabular is not None:
                output_dim += deeptabular.output_dim
            if deeptext is not None:
                output_dim += deeptext.output_dim
            if deepimage is not None:
                output_dim += deepimage.output_dim
            assert deephead_inp_feat == output_dim, (
                "if a custom 'deephead' is used its input features ({}) must be equal to "
                "the output features of the deep component ({})".format(
                    deephead_inp_feat, output_dim
                )
            )