# wide_deep.py
import warnings

import torch
import torch.nn as nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.tabnet.tab_net import TabNetPredLayer
# Use the "default" action for UserWarning: show the first occurrence of a
# warning per code location. NOTE(review): this edits the process-wide
# warning filters as a side effect of importing this module.
warnings.filterwarnings("default", category=UserWarning)

# Module-level device: CUDA when torch reports it is available, CPU otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class WideDeep(nn.Module):
    def __init__(
        self,
        wide: Optional[nn.Module] = None,
        deeptabular: Optional[nn.Module] = None,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_hidden_dims: Optional[List[int]] = None,
        head_activation: str = "relu",
        head_dropout: float = 0.1,
        head_batchnorm: bool = False,
        head_batchnorm_last: bool = False,
        head_linear_first: bool = False,
        pred_dim: int = 1,
    ):
        r"""Main collector class that combines all ``wide``, ``deeptabular``
        (which can be a number of architectures), ``deeptext`` and
        ``deepimage`` models.

        There are two options to combine these models that correspond to the
        two main architectures (there is a higher number of
        "sub-architectures") that ``pytorch-widedeep`` can build.

            - Directly connecting the output of the model components to an
              output neuron(s).

            - Adding a `Fully-Connected Head` (FC-Head) on top of the deep
              models. This FC-Head will combine the output from the
              ``deeptabular``, ``deeptext`` and ``deepimage`` components and
              will then be connected to the output neuron(s).

        Parameters
        ----------
        wide: ``nn.Module``, Optional, default = None
            ``Wide`` model. I recommend using the ``Wide`` class in this
            package. However, it is possible to use a custom model as long as
            it is consistent with the required architecture (it must expose a
            ``wide_linear`` layer whose ``weight.size(1)`` equals
            ``pred_dim``), see :class:`pytorch_widedeep.models.wide.Wide`
        deeptabular: ``nn.Module``, Optional, default = None
            currently ``pytorch-widedeep`` implements three possible
            architectures for the `deeptabular` component. These are:
            ``TabMlp``, ``TabResnet`` and ``TabTransformer``.

            1. ``TabMlp`` is simply an embedding layer encoding the categorical
            features that are then concatenated and passed through a series of
            dense (hidden) layers (i.e. an MLP).
            See: ``pytorch_widedeep.models.deep_dense.TabMlp``

            2. ``TabResnet`` is an embedding layer encoding the categorical
            features that are then concatenated and passed through a series of
            ResNet blocks formed by dense layers.
            See ``pytorch_widedeep.models.deep_dense_resnet.TabResnet``

            3. ``TabTransformer`` is detailed in `TabTransformer: Tabular Data
            Modeling Using Contextual Embeddings
            <https://arxiv.org/pdf/2012.06678.pdf>`_. See
            ``pytorch_widedeep.models.tab_transformer.TabTransformer``

            A ``TabNet`` model
            (``pytorch_widedeep.models.tabnet.tab_net.TabNet``) is also
            handled: in that case the forward pass returns a
            ``(predictions, M_loss)`` tuple.

            I recommend using one of these as ``deeptabular``. However, it is
            possible to use a custom model as long as it is consistent with
            the required architecture. See
            :class:`pytorch_widedeep.models.deep_dense.TabTransformer`.

        deeptext: ``nn.Module``, Optional, default = None
            Model for the text input. Must be an object of class ``DeepText``
            or a custom model as long as it is consistent with the required
            architecture. See
            :class:`pytorch_widedeep.models.deep_dense.DeepText`
        deepimage: ``nn.Module``, Optional, default = None
            Model for the images input. Must be an object of class
            ``DeepImage`` or a custom model as long as it is consistent with
            the required architecture. See
            :class:`pytorch_widedeep.models.deep_dense.DeepImage`
        deephead: ``nn.Module``, Optional, default = None
            Custom model by the user that will receive the output of the deep
            components, concatenated along the feature dimension (typically
            a FC-Head, i.e. an MLP). Its input features must equal the sum of
            the ``output_dim`` of the deep components. A trainable linear
            layer mapping its output to ``pred_dim`` is appended to it. The
            size of that output is read from its ``output_dim`` attribute or,
            if absent, from the ``out_features`` of its last registered
            ``nn.Linear`` layer.
        head_hidden_dims: List, Optional, default = None
            Alternatively, the ``head_hidden_dims`` param can be used to
            specify the sizes of the stacked dense layers in the fc-head e.g:
            ``[128, 64]``. Use ``deephead`` or ``head_hidden_dims``, but not
            both.
        head_dropout: float, default = 0.1
            If ``head_hidden_dims`` is not None, dropout between the layers in
            ``head_hidden_dims``
        head_activation: str, default = "relu"
            If ``head_hidden_dims`` is not None, activation function of the
            head layers. One of "relu", "gelu" or "leaky_relu"
        head_batchnorm: bool, default = False
            If ``head_hidden_dims`` is not None, specifies if batch
            normalization should be included in the head layers
        head_batchnorm_last: bool, default = False
            If ``head_hidden_dims`` is not None, boolean indicating whether or
            not to apply batch normalization to the last of the dense layers
        head_linear_first: bool, default = False
            If ``head_hidden_dims`` is not None, boolean indicating the order
            of the operations in the dense layer. If ``True``:
            ``[LIN -> ACT -> BN -> DP]``. If ``False``: ``[BN -> DP -> LIN ->
            ACT]``
        pred_dim: int, default = 1
            Size of the final wide and deep output layer containing the
            predictions. `1` for regression and binary classification or number
            of classes for multiclass classification.

        Raises
        ------
        AssertionError
            If the ``wide`` output size differs from ``pred_dim`` or a custom
            ``deephead`` does not accept the concatenated deep outputs.
        AttributeError
            If a deep component has no ``output_dim`` attribute, or the output
            size of a custom ``deephead`` cannot be determined.
        ValueError
            If both ``deephead`` and ``head_hidden_dims`` are passed, or
            ``head_hidden_dims`` is passed without any deep component.

        Examples
        --------

        >>> from pytorch_widedeep.models import TabResnet, DeepImage, DeepText, Wide, WideDeep
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
        >>> deepimage = DeepImage(pretrained=False)
        >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)


        .. note:: While I recommend using the ``wide`` and ``deeptabular`` components
            within this package when building the corresponding model components,
            it is very likely that the user will want to use custom text and image
            models. That is perfectly possible. Simply, build them and pass them
            as the corresponding parameters. Note that the custom models MUST
            return a last layer of activations (i.e. not the final prediction) so
            that these activations are collected by ``WideDeep`` and combined
            accordingly. In addition, the models MUST also contain an attribute
            ``output_dim`` with the size of these last layers of activations. See
            for example :class:`pytorch_widedeep.models.tab_mlp.TabMlp`

        """
        super(WideDeep, self).__init__()

        self._check_model_components(
            wide,
            deeptabular,
            deeptext,
            deepimage,
            deephead,
            head_hidden_dims,
            pred_dim,
        )

        # used to size the prediction layer(s) built below
        self.pred_dim = pred_dim

        # The main 5 components of the wide and deep assemble
        self.wide = wide
        self.deeptabular = deeptabular
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        # Always defined (not only when 'deeptabular' is passed) because both
        # forward paths read it to decide whether to return the TabNet
        # sparsity loss together with the predictions.
        self.is_tabnet = (
            deeptabular is not None and deeptabular.__class__.__name__ == "TabNet"
        )

        if self.deephead is not None:
            self._add_deephead_pred_layer()
        elif head_hidden_dims is not None:
            self._build_deephead(
                head_hidden_dims,
                head_activation,
                head_dropout,
                head_batchnorm,
                head_batchnorm_last,
                head_linear_first,
            )
        else:
            self._add_pred_layer()

    def forward(self, X: Dict[str, Tensor]):
        """Run every component and combine their outputs.

        Parameters
        ----------
        X: Dict[str, Tensor]
            Inputs keyed by component name: ``"wide"``, ``"deeptabular"``,
            ``"deeptext"`` and/or ``"deepimage"`` (only the keys of the
            components in use are read).

        Returns
        -------
        The predictions, of shape ``(batch_size, pred_dim)``. When
        ``deeptabular`` is ``TabNet``, a ``(predictions, M_loss)`` tuple.
        """
        wide_out = self._forward_wide(X)
        if self.deephead is not None:
            return self._forward_deephead(X, wide_out)
        return self._forward_deep(X, wide_out)

    def _build_deephead(
        self,
        head_hidden_dims,
        head_activation="relu",
        head_dropout=0.1,
        head_batchnorm=False,
        head_batchnorm_last=False,
        head_linear_first=False,
    ):
        """Build the FC-Head from ``head_hidden_dims`` plus a ``pred_dim`` output layer.

        The MLP input size is the sum of the ``output_dim`` of the deep
        components, concatenated in the same order used by
        ``_forward_deephead`` (tabular, text, image).
        """
        deep_dim = sum(
            component.output_dim
            for component in (self.deeptabular, self.deeptext, self.deepimage)
            if component is not None
        )

        mlp_dims = [deep_dim] + list(head_hidden_dims)
        mlp = MLP(
            mlp_dims,
            head_activation,
            head_dropout,
            head_batchnorm,
            head_batchnorm_last,
            head_linear_first,
        )
        # An explicit Sequential guarantees the output layer is applied in
        # forward (a child merely registered with ``add_module`` only runs if
        # the parent's own ``forward`` calls it) and that it is trained.
        self.deephead = nn.Sequential(mlp, nn.Linear(mlp_dims[-1], self.pred_dim))

    def _add_deephead_pred_layer(self):
        """Append a registered, trainable ``pred_dim`` projection to a custom ``deephead``.

        Built once here rather than in ``forward`` so that its weights are
        part of the model's parameters (optimised and saved) and stay fixed
        between calls.

        Raises
        ------
        AttributeError
            If the output size of the custom ``deephead`` cannot be determined.
        """
        head_out_dim = getattr(self.deephead, "output_dim", None)
        if head_out_dim is None:
            # heuristic: the last registered nn.Linear is assumed to be the
            # last layer applied by the custom head
            linear_layers = [
                module
                for module in self.deephead.modules()
                if isinstance(module, nn.Linear)
            ]
            if not linear_layers:
                raise AttributeError(
                    "a custom 'deephead' must have an 'output_dim' attribute or contain "
                    "at least one 'nn.Linear' layer so that its output can be projected "
                    "to 'pred_dim'"
                )
            head_out_dim = linear_layers[-1].out_features

        self.deephead = nn.Sequential(
            self.deephead, nn.Linear(head_out_dim, self.pred_dim)
        )

    def _add_pred_layer(self):
        """Append a prediction layer of size ``pred_dim`` to every deep component.

        Used when there is no ``deephead``: each deep component then produces
        its own predictions, which ``_forward_deep`` adds to the wide output.
        ``TabNet`` gets a ``TabNetPredLayer`` instead of a plain ``nn.Linear``
        (NOTE(review): presumably because TabNet returns an
        ``(output, M_loss)`` tuple; confirm against ``TabNetPredLayer``).
        """
        if self.deeptabular is not None:
            if self.is_tabnet:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    TabNetPredLayer(self.deeptabular.output_dim, self.pred_dim),
                )
            else:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    nn.Linear(self.deeptabular.output_dim, self.pred_dim),
                )
        if self.deeptext is not None:
            self.deeptext = nn.Sequential(
                self.deeptext, nn.Linear(self.deeptext.output_dim, self.pred_dim)
            )
        if self.deepimage is not None:
            self.deepimage = nn.Sequential(
                self.deepimage, nn.Linear(self.deepimage.output_dim, self.pred_dim)
            )

    def _forward_wide(self, X):
        """Return the wide output, or zeros of shape ``(batch_size, pred_dim)`` without a wide component."""
        if self.wide is not None:
            return self.wide(X["wide"])
        # Create the zeros on the inputs' device rather than on the
        # module-level 'device', which is CUDA whenever CUDA is available,
        # even for a model (and inputs) that live on the CPU.
        first_input = next(iter(X.values()))
        return torch.zeros(
            first_input.size(0), self.pred_dim, device=first_input.device
        )

    def _forward_deephead(self, X, wide_out):
        """Concatenate the deep outputs, pass them through ``deephead`` and add the wide output."""
        M_loss = None
        deep_outs = []
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out = self.deeptabular(X["deeptabular"])
                deep_outs.append(tab_out[0])
                M_loss = tab_out[1]
            else:
                deep_outs.append(self.deeptabular(X["deeptabular"]))
        if self.deeptext is not None:
            deep_outs.append(self.deeptext(X["deeptext"]))
        if self.deepimage is not None:
            deep_outs.append(self.deepimage(X["deepimage"]))

        # same order (tabular, text, image) used to size the head's input
        deepside = torch.cat(deep_outs, dim=1)
        # deephead ends with a registered projection to 'pred_dim'.
        # Out-of-place addition leaves the wide component's output untouched.
        out = wide_out + self.deephead(deepside)

        return (out, M_loss) if self.is_tabnet else out

    def _forward_deep(self, X, wide_out):
        """Add the ``pred_dim``-sized output of every deep component to the wide output."""
        M_loss = None
        # Out-of-place additions: in-place updates of a tensor autograd has
        # saved for backward can make the backward pass fail.
        out = wide_out
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out, M_loss = self.deeptabular(X["deeptabular"])
            else:
                tab_out = self.deeptabular(X["deeptabular"])
            out = out + tab_out
        if self.deeptext is not None:
            out = out + self.deeptext(X["deeptext"])
        if self.deepimage is not None:
            out = out + self.deepimage(X["deepimage"])

        return (out, M_loss) if self.is_tabnet else out

    @staticmethod  # noqa: C901
    def _check_model_components(
        wide,
        deeptabular,
        deeptext,
        deepimage,
        deephead,
        head_hidden_dims,
        pred_dim,
    ):
        """Validate that the given components can be assembled into one model.

        Checks are raised explicitly (not with ``assert``) so they also run
        under ``python -O``. The exception types are unchanged.

        Raises
        ------
        AssertionError
            If the ``wide`` output size differs from ``pred_dim`` or a custom
            ``deephead`` does not accept the concatenated deep outputs.
        AttributeError
            If a deep component has no ``output_dim`` attribute.
        ValueError
            If both ``deephead`` and ``head_hidden_dims`` are passed, or
            ``head_hidden_dims`` is passed without any deep component.
        """
        if wide is not None:
            wide_pred_dim = wide.wide_linear.weight.size(1)
            if wide_pred_dim != pred_dim:
                raise AssertionError(
                    "the 'pred_dim' of the wide component ({}) must be equal to the 'pred_dim' "
                    "of the deep component and the overall model itself ({})".format(
                        wide_pred_dim, pred_dim
                    )
                )
        if deeptabular is not None and not hasattr(deeptabular, "output_dim"):
            raise AttributeError(
                "deeptabular model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.tab_mlp.TabMlp"
            )
        if deeptabular is not None:
            is_tabnet = deeptabular.__class__.__name__ == "TabNet"
            has_wide_text_or_image = (
                wide is not None or deeptext is not None or deepimage is not None
            )
            if is_tabnet and has_wide_text_or_image:
                warnings.warn(
                    "'WideDeep' is a model comprised by multiple components and the 'deeptabular'"
                    " component is 'TabNet'. We recommend using 'TabNet' in isolation."
                    " This is because 'TabNet' uses sparse regularization which partially losses"
                    " its purpose when used in combination with other components."
                    " If you still want to use a multiple component model with 'TabNet',"
                    " consider setting 'lambda_sparse' to 0 during training",
                    UserWarning,
                )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepImage"
            )
        if deephead is not None and head_hidden_dims is not None:
            raise ValueError(
                "both 'deephead' and 'head_hidden_dims' are not None. Use one of the other, but not both"
            )
        deep_components = (deeptabular, deeptext, deepimage)
        if head_hidden_dims is not None and all(c is None for c in deep_components):
            raise ValueError(
                "if 'head_hidden_dims' is not None, at least one deep component must be used"
            )
        if deephead is not None:
            # assumes the first parameter is the weight of an input linear
            # layer, i.e. shaped (out_features, in_features)
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = sum(c.output_dim for c in deep_components if c is not None)
            if deephead_inp_feat != output_dim:
                raise AssertionError(
                    "if a custom 'deephead' is used its input features ({}) must be equal to "
                    "the output features of the deep component ({})".format(
                        deephead_inp_feat, output_dim
                    )
                )