"""
During the development of the package I realised that there is a typing
inconsistency. The input components of a Wide and Deep model are of type
nn.Module. These change type internally to nn.Sequential. While nn.Sequential
is an instance of nn.Module, the opposite is, of course, not true. This does
not affect any functionality of the package, but it is something that needs
fixing. However, while fixing is simple (simply define new attributes that
are the nn.Sequential objects), its implications are quite wide within the
package (involves changing a number of tests and tutorials). Therefore, I
will introduce that fix when I do a major release. For now, we live with it.
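
A minimal illustration (assuming a 'deeptabular' component, e.g. a TabMlp
instance called 'tab_mlp', has already been built):

    model = WideDeep(deeptabular=tab_mlp)
    isinstance(model.deeptabular, nn.Module)      # True
    isinstance(model.deeptabular, nn.Sequential)  # also True: re-assigned internally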
"""

import warnings

import torch
import torch.nn as nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.fds_layer import FDSLayer
from pytorch_widedeep.models._get_activation_fn import get_activation_fn
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.tabnet.tab_net import TabNetPredLayer

warnings.filterwarnings("default", category=UserWarning)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class WideDeep(nn.Module):
    r"""Main collector class that combines all ``wide``, ``deeptabular``,
    ``deeptext`` and ``deepimage`` models.

    There are two options to combine these models, corresponding to the
    two main architectures that ``pytorch-widedeep`` can build:

        - Directly connecting the output of the model components to the output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
          This FC-Head will combine the output from the ``deeptabular``, ``deeptext`` and
          ``deepimage`` components and will then be connected to the output neuron(s).

    Parameters
    ----------
    wide: ``nn.Module``, Optional, default = None
        ``Wide`` model. This is a linear model where the non-linearities are
        captured via crossed-columns.
    deeptabular: ``nn.Module``, Optional, default = None
        Currently this library implements a number of possible architectures
        for the ``deeptabular`` component. See the documentation of the
        package.
    deeptext: ``nn.Module``, Optional, default = None
        Currently this library implements a number of possible architectures
        for the ``deeptext`` component. See the documentation of the
        package.
    deepimage: ``nn.Module``, Optional, default = None
        Currently this library uses ``torchvision`` and implements a number of
        possible architectures for the ``deepimage`` component. See the
        documentation of the package.
    head_hidden_dims: List, Optional, default = None
        List with the sizes of the dense layers in the head, e.g.: [128, 64]
    head_activation: str, default = "relu"
        Activation function for the dense layers in the head. Currently
        `'tanh'`, `'relu'`, `'leaky_relu'` and `'gelu'` are supported
    head_dropout: float, default = 0.1
        Dropout of the dense layers in the head
    head_batchnorm: bool, default = False
        Boolean indicating whether or not to include batch normalization in
        the dense layers that form the head
    head_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers in the head
    head_linear_first: bool, default = False
        Boolean indicating the order of the operations in the dense
        layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
        LIN -> ACT]``
    deephead: ``nn.Module``, Optional, default = None
        Alternatively, the user can pass a custom model that will receive the
        output of the deep component. If ``deephead`` is not None, all the
        previous fc-head parameters will be ignored
    enforce_positive: bool, default = False
        Boolean indicating if the output from the final layer must be
        positive. This is important if you are using loss functions with
        non-negative input restrictions, e.g. RMSLE, or if you know your
        predictions are bounded between 0 and +inf
    enforce_positive_activation: str, default = "softplus"
        Activation function to enforce that the final layer has a positive
        output. `'softplus'` or `'relu'` are supported.
    pred_dim: int, default = 1
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or number
        of classes for multiclass classification.
    with_fds: bool, default = False
        Boolean indicating whether feature distribution smoothing (FDS) should
        be applied before the final prediction layer. Only available for
        regression problems. See
        `Delving into Deep Imbalanced Regression <https://arxiv.org/abs/2102.09554>`_
        for details.
    fds_config: dict, default = None
        Dictionary with the parameters that will be passed to the ``FDSLayer``
        (see ``pytorch_widedeep.models.fds_layer``)

    Examples
    --------

    >>> from pytorch_widedeep.models import TabResnet, Vision, BasicRNN, Wide, WideDeep
    >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"], [4] * 3, [8] * 3)]
    >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
    >>> wide = Wide(10, 1)
    >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)
    >>> deeptext = BasicRNN(vocab_size=10, embed_dim=4, padding_idx=0)
    >>> deepimage = Vision()
    >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
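
    To combine the components through a fully connected head instead, or to
    apply feature distribution smoothing to a ``deeptabular``-only regression
    model, one could do (the head sizes below are illustrative and the
    ``FDSLayer`` defaults are assumed):

    >>> model_with_head = WideDeep(wide=wide, deeptabular=deeptabular, head_hidden_dims=[16, 8])
    >>> fds_model = WideDeep(deeptabular=deeptabular, with_fds=True)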


    .. note:: It is possible to use custom components to build Wide & Deep models.
        Simply build them and pass them as the corresponding parameters. Note
        that the custom models MUST return a last layer of activations
        (i.e. not the final prediction) so that these activations are
        collected by ``WideDeep`` and combined accordingly. In addition, the
        models MUST also contain an attribute ``output_dim`` with the size of
        these last layers of activations. See for
        example :class:`pytorch_widedeep.models.tab_mlp.TabMlp`
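
    A minimal sketch of such a custom component (the class name and layer
    sizes here are illustrative only, not part of the library):

    >>> import torch.nn as nn
    >>> class MyDeepText(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.word_embed = nn.Embedding(10, 8, padding_idx=0)
    ...         self.rnn = nn.LSTM(8, 16, batch_first=True)
    ...         self.output_dim = 16  # required: size of the last layer of activations
    ...     def forward(self, X):
    ...         _, (h, _) = self.rnn(self.word_embed(X.long()))
    ...         return h[-1]
    >>> custom_model = WideDeep(deeptext=MyDeepText())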
    """

    def __init__(
        self,
        wide: Optional[nn.Module] = None,
        deeptabular: Optional[nn.Module] = None,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_hidden_dims: Optional[List[int]] = None,
        head_activation: str = "relu",
        head_dropout: float = 0.1,
        head_batchnorm: bool = False,
        head_batchnorm_last: bool = False,
        head_linear_first: bool = False,
        enforce_positive: bool = False,
        enforce_positive_activation: str = "softplus",
        pred_dim: int = 1,
        with_fds: bool = False,
        **fds_config,
    ):
        super(WideDeep, self).__init__()

        self._check_inputs(
            wide,
            deeptabular,
            deeptext,
            deepimage,
            deephead,
            head_hidden_dims,
            pred_dim,
            with_fds,
        )

        # required as attribute just in case we pass a deephead
        self.pred_dim = pred_dim

        # The five main components of the Wide and Deep assembly
        self.wide = wide
        self.deeptabular = deeptabular
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead
        self.enforce_positive = enforce_positive
        self.with_fds = with_fds

        if self.deeptabular is not None:
            self.is_tabnet = deeptabular.__class__.__name__ == "TabNet"
        else:
            self.is_tabnet = False

        if self.deephead is None and head_hidden_dims is not None:
            self._build_deephead(
                head_hidden_dims,
                head_activation,
                head_dropout,
                head_batchnorm,
                head_batchnorm_last,
                head_linear_first,
            )
        elif self.deephead is not None:
            pass
        else:
            self._add_pred_layer()

        if self.with_fds:
            self.fds_layer = FDSLayer(feature_dim=self.deeptabular.output_dim, **fds_config)  # type: ignore[arg-type]

        if self.enforce_positive:
            self.enf_pos = get_activation_fn(enforce_positive_activation)

    def forward(
        self,
        X: Dict[str, Tensor],
        y: Optional[Tensor] = None,
        epoch: Optional[int] = None,
    ):

        if self.with_fds:
            return self._forward_deep_with_fds(X, y, epoch)

        wide_out = self._forward_wide(X)
        if self.deephead:
            deep = self._forward_deephead(X, wide_out)
        else:
            deep = self._forward_deep(X, wide_out)

        if self.enforce_positive:
            return self.enf_pos(deep)
        else:
            return deep

    def _build_deephead(
        self,
        head_hidden_dims,
        head_activation,
        head_dropout,
        head_batchnorm,
        head_batchnorm_last,
        head_linear_first,
    ):
        deep_dim = 0
        if self.deeptabular is not None:
            deep_dim += self.deeptabular.output_dim
        if self.deeptext is not None:
            deep_dim += self.deeptext.output_dim
        if self.deepimage is not None:
            deep_dim += self.deepimage.output_dim

        head_hidden_dims = [deep_dim] + head_hidden_dims
        self.deephead = MLP(
            head_hidden_dims,
            head_activation,
            head_dropout,
            head_batchnorm,
            head_batchnorm_last,
            head_linear_first,
        )

        self.deephead.add_module(
            "head_out", nn.Linear(head_hidden_dims[-1], self.pred_dim)
        )

    def _add_pred_layer(self):
        # if FDS is used, the FDSLayer already includes the prediction layer
        if self.deeptabular is not None and not self.with_fds:
            if self.is_tabnet:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    TabNetPredLayer(self.deeptabular.output_dim, self.pred_dim),
                )
            else:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    nn.Linear(self.deeptabular.output_dim, self.pred_dim),
                )
        if self.deeptext is not None:
            self.deeptext = nn.Sequential(
                self.deeptext, nn.Linear(self.deeptext.output_dim, self.pred_dim)
            )
        if self.deepimage is not None:
            self.deepimage = nn.Sequential(
                self.deepimage, nn.Linear(self.deepimage.output_dim, self.pred_dim)
            )

    def _forward_wide(self, X):
        if self.wide is not None:
            out = self.wide(X["wide"])
        else:
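            # no wide component: contribute a tensor of zeros so that the deep
            # side outputs can simply be added to it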
            batch_size = X[list(X.keys())[0]].size(0)
            out = torch.zeros(batch_size, self.pred_dim).to(device)

        return out

    def _forward_deephead(self, X, wide_out):
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out = self.deeptabular(X["deeptabular"])
                deepside, M_loss = tab_out[0], tab_out[1]
            else:
                deepside = self.deeptabular(X["deeptabular"])
        else:
            deepside = torch.FloatTensor().to(device)
        if self.deeptext is not None:
            deepside = torch.cat([deepside, self.deeptext(X["deeptext"])], axis=1)
        if self.deepimage is not None:
            deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], axis=1)

        deephead_out = self.deephead(deepside)
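        # linear layer mapping the deephead output to the prediction dimension
        # (note: instantiated on every forward pass)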
        deepside_out = nn.Linear(deephead_out.size(1), self.pred_dim).to(device)

        if self.is_tabnet:
            res = (wide_out.add_(deepside_out(deephead_out)), M_loss)
        else:
            res = wide_out.add_(deepside_out(deephead_out))

        return res

    def _forward_deep(self, X, wide_out):
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out, M_loss = self.deeptabular(X["deeptabular"])
                wide_out.add_(tab_out)
            else:
                wide_out.add_(self.deeptabular(X["deeptabular"]))
        if self.deeptext is not None:
            wide_out.add_(self.deeptext(X["deeptext"]))
        if self.deepimage is not None:
            wide_out.add_(self.deepimage(X["deepimage"]))

        if self.is_tabnet:
            res = (wide_out, M_loss)
        else:
            res = wide_out

        return res

    def _forward_deep_with_fds(
        self,
        X: Dict[str, Tensor],
        y: Optional[Tensor] = None,
        epoch: Optional[int] = None,
    ):
        res = self.fds_layer(self.deeptabular(X["deeptabular"]), y, epoch)
        if self.enforce_positive:
            if isinstance(res, Tuple):  # type: ignore[arg-type]
                res = (res[0], self.enf_pos(res[1]))
            else:
                res = self.enf_pos(res)
        return res

    @staticmethod  # noqa: C901
    def _check_inputs(  # noqa: C901
        wide,
        deeptabular,
        deeptext,
        deepimage,
        deephead,
        head_hidden_dims,
        pred_dim,
        with_fds,
    ):

        if wide is not None:
            assert wide.wide_linear.weight.size(1) == pred_dim, (
                "the 'pred_dim' of the wide component ({}) must be equal to the 'pred_dim' "
                "of the deep component and the overall model itself ({})".format(
                    wide.wide_linear.weight.size(1), pred_dim
                )
            )
        if deeptabular is not None and not hasattr(deeptabular, "output_dim"):
            raise AttributeError(
                "deeptabular model must have an 'output_dim' attribute. "
                "See, for example, pytorch_widedeep.models.TabMlp"
            )
        if deeptabular is not None:
            is_tabnet = deeptabular.__class__.__name__ == "TabNet"
            has_wide_text_or_image = (
                wide is not None or deeptext is not None or deepimage is not None
            )
            if is_tabnet and has_wide_text_or_image:
                warnings.warn(
                    "'WideDeep' is a model comprised of multiple components and the 'deeptabular'"
                    " component is 'TabNet'. We recommend using 'TabNet' in isolation."
                    " The reasons are: i) 'TabNet' uses sparse regularization which partially loses"
                    " its purpose when used in combination with other components."
                    " If you still want to use a multiple component model with 'TabNet',"
                    " consider setting 'lambda_sparse' to 0 during training. ii) The feature"
                    " importances will be computed only for TabNet but the model will comprise multiple"
                    " components. Therefore, such importances will partially lose their 'meaning'.",
                    UserWarning,
                )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See, for example, pytorch_widedeep.models.BasicRNN"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See, for example, pytorch_widedeep.models.Vision"
            )
        if deephead is not None and head_hidden_dims is not None:
            raise ValueError(
                "both 'deephead' and 'head_hidden_dims' are not None. Use one or the other, but not both"
            )
        if (
            head_hidden_dims is not None
            and not deeptabular
            and not deeptext
            and not deepimage
        ):
            raise ValueError(
                "if 'head_hidden_dims' is not None, at least one deep component must be used"
            )
        if deephead is not None:
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = 0
            if deeptabular is not None:
                output_dim += deeptabular.output_dim
            if deeptext is not None:
                output_dim += deeptext.output_dim
            if deepimage is not None:
                output_dim += deepimage.output_dim
            assert deephead_inp_feat == output_dim, (
                "if a custom 'deephead' is used its input features ({}) must be equal to "
                "the output features of the deep component ({})".format(
                    deephead_inp_feat, output_dim
                )
            )

        if with_fds and (
            (
                wide is not None
                or deeptext is not None
                or deepimage is not None
                or deephead is not None
            )
            or pred_dim != 1
        ):
            raise ValueError(
                "Feature Distribution Smoothing (FDS) is supported when using only a 'deeptabular'"
                " component and for regression problems."
            )