"""
During the development of the package I realised that there is a typing
inconsistency. The input components of a Wide and Deep model are of type
nn.Module. These change type internally to nn.Sequential. While nn.Sequential
is an instance of nn.Module the opposite is, of course, not true. This does
not affect any functionality of the package, but it is something that needs
fixing. However, while fixing is simple (simply define new attributes that
are the nn.Sequential objects), its implications are quite wide within the
package (involves changing a number of tests and tutorials). Therefore, I
will introduce that fix when I do a major release. For now, we live with it.
"""

import warnings

import torch
import torch.nn as nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models._get_activation_fn import get_activation_fn
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.tabnet.tab_net import TabNetPredLayer

warnings.filterwarnings("default", category=UserWarning)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class WideDeep(nn.Module):
    r"""Main collector class that combines all ``wide``, ``deeptabular``
    (which can be a number of architectures), ``deeptext`` and
    ``deepimage`` models.

    There are two options to combine these models that correspond to the
    two main architectures that ``pytorch-widedeep`` can build.

        - Directly connecting the output of the model components to the output neuron(s).

        - Adding a `Fully-Connected Head` (FC-Head) on top of the deep models.
          This FC-Head will combine the output from the ``deeptabular``, ``deeptext`` and
          ``deepimage`` components and will then be connected to the output neuron(s).

    Parameters
    ----------
    wide: ``nn.Module``, Optional, default = None
        ``Wide`` model. I recommend using the ``Wide`` class in this
        package. However, it is possible to use a custom model as long as
        it is consistent with the required architecture, see
        :class:`pytorch_widedeep.models.wide.Wide`
    deeptabular: ``nn.Module``, Optional, default = None
        Currently ``pytorch-widedeep`` implements a number of possible
        architectures for the ``deeptabular`` component. See the
        documentation of the package. I recommend using the ``deeptabular``
        components in this package. However, it is possible to use a custom
        model as long as it is consistent with the required architecture.
    deeptext: ``nn.Module``, Optional, default = None
        Model for the text input. Must be an object of class ``DeepText``
        or a custom model, as long as it is consistent with the required
        architecture. See
        :class:`pytorch_widedeep.models.deep_text.DeepText`
    deepimage: ``nn.Module``, Optional, default = None
        Model for the image input. Must be an object of class
        ``DeepImage`` or a custom model, as long as it is consistent with the
        required architecture. See
        :class:`pytorch_widedeep.models.deep_image.DeepImage`
    deephead: ``nn.Module``, Optional, default = None
        Custom model, defined by the user, that will receive the output of
        the deep component. Typically an FC-Head (MLP)
    head_hidden_dims: List, Optional, default = None
        Alternatively, the ``head_hidden_dims`` param can be used to
        specify the sizes of the stacked dense layers in the FC-Head, e.g.:
        ``[128, 64]``. Use ``deephead`` or ``head_hidden_dims``, but not
        both.
    head_dropout: float, default = 0.1
        If ``head_hidden_dims`` is not None, dropout between the layers in
        ``head_hidden_dims``
    head_activation: str, default = "relu"
        If ``head_hidden_dims`` is not None, activation function of the head
        layers. One of ``tanh``, ``relu``, ``gelu`` or ``leaky_relu``
    head_batchnorm: bool, default = False
        If ``head_hidden_dims`` is not None, specifies if batch
        normalization should be included in the head layers
    head_batchnorm_last: bool, default = False
        If ``head_hidden_dims`` is not None, boolean indicating whether or
        not to apply batch normalization to the last of the dense layers
    head_linear_first: bool, default = False
        If ``head_hidden_dims`` is not None, boolean indicating the order of
        the operations in the dense layer. If ``True``: ``[LIN -> ACT -> BN ->
        DP]``. If ``False``: ``[BN -> DP -> LIN -> ACT]``
    enforce_positive: bool, default = False
        Whether or not to apply a positive-enforcing activation function to
        the final layer. Important if you are using loss functions with
        non-negative input restrictions, e.g. RMSLE, or if you know your
        predictions are limited to [0, inf)
    enforce_positive_activation: str, default = "softplus"
        Activation function used to enforce a positive output from the final
        layer. Use "softplus" or "relu".
    pred_dim: int, default = 1
        Size of the final wide and deep output layer containing the
        predictions. `1` for regression and binary classification or number
        of classes for multiclass classification.

    Examples
    --------

    >>> from pytorch_widedeep.models import TabResnet, Vision, AttentiveRNN, Wide, WideDeep
    >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"], [4] * 3, [8] * 3)]
    >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
    >>> wide = Wide(10, 1)
    >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)
    >>> deeptext = AttentiveRNN(vocab_size=10, embed_dim=4, padding_idx=0)
    >>> deepimage = Vision()
    >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)
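
    A sketch of the second architecture described above, where an FC-Head is
    specified via ``head_hidden_dims`` (the particular sizes, and the use of
    ``enforce_positive``, are illustrative choices only):

    >>> model_with_head = WideDeep(
    ...     wide=wide,
    ...     deeptabular=deeptabular,
    ...     deeptext=deeptext,
    ...     deepimage=deepimage,
    ...     head_hidden_dims=[128, 64],
    ...     enforce_positive=True,
    ... )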


    .. note:: While I recommend using the ``wide`` and ``deeptabular`` components
        within this package when building the corresponding model components,
        it is very likely that the user will want to use custom text and image
        models. That is perfectly possible. Simply build them and pass them
        as the corresponding parameters. Note that the custom models MUST
        return a last layer of activations (i.e. not the final prediction) so
        that these activations are collected by ``WideDeep`` and combined
        accordingly. In addition, the models MUST also contain an attribute
        ``output_dim`` with the size of these last layers of activations. See
        for example :class:`pytorch_widedeep.models.tab_mlp.TabMlp`
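
    As a minimal, purely illustrative sketch of such a custom component (the
    class ``MyTextEncoder`` below is hypothetical and not part of the
    package), note the ``output_dim`` attribute and that ``forward`` returns
    activations rather than predictions:

    .. code-block:: python

        import torch
        import torch.nn as nn

        class MyTextEncoder(nn.Module):
            def __init__(self, vocab_size: int, embed_dim: int = 16):
                super().__init__()
                self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
                self.rnn = nn.GRU(embed_dim, 32, batch_first=True)
                # required by WideDeep: size of the last layer of activations
                self.output_dim = 32

            def forward(self, X: torch.Tensor) -> torch.Tensor:
                _, h = self.rnn(self.embed(X))
                return h[-1]  # (batch_size, output_dim) activations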

    """

    def __init__(
        self,
        wide: Optional[nn.Module] = None,
        deeptabular: Optional[nn.Module] = None,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_hidden_dims: Optional[List[int]] = None,
        head_activation: str = "relu",
        head_dropout: float = 0.1,
        head_batchnorm: bool = False,
        head_batchnorm_last: bool = False,
        head_linear_first: bool = False,
        enforce_positive: bool = False,
        enforce_positive_activation: str = "softplus",
        pred_dim: int = 1,
    ):
        super(WideDeep, self).__init__()

        self._check_model_components(
            wide,
            deeptabular,
            deeptext,
            deepimage,
            deephead,
            head_hidden_dims,
            pred_dim,
        )

        # required as attribute just in case we pass a deephead
        self.pred_dim = pred_dim

        # The main 5 components of the wide and deep assembly
        self.wide = wide
        self.deeptabular = deeptabular
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead
        self.enforce_positive = enforce_positive

        if self.deeptabular is not None:
            self.is_tabnet = deeptabular.__class__.__name__ == "TabNet"
        else:
            self.is_tabnet = False

        if self.deephead is None:
            if head_hidden_dims is not None:
                self._build_deephead(
                    head_hidden_dims,
                    head_activation,
                    head_dropout,
                    head_batchnorm,
                    head_batchnorm_last,
                    head_linear_first,
                )
            else:
                self._add_pred_layer()

        if self.enforce_positive:
            self.enf_pos = get_activation_fn(enforce_positive_activation)

    def forward(self, X: Dict[str, Tensor]):
        wide_out = self._forward_wide(X)
        if self.deephead:
            deep = self._forward_deephead(X, wide_out)
        else:
            deep = self._forward_deep(X, wide_out)
        if self.enforce_positive:
            return self.enf_pos(deep)
        else:
            return deep

    def _build_deephead(
        self,
        head_hidden_dims,
        head_activation,
        head_dropout,
        head_batchnorm,
        head_batchnorm_last,
        head_linear_first,
    ):
        deep_dim = 0
        if self.deeptabular is not None:
            deep_dim += self.deeptabular.output_dim
        if self.deeptext is not None:
            deep_dim += self.deeptext.output_dim
        if self.deepimage is not None:
            deep_dim += self.deepimage.output_dim

        head_hidden_dims = [deep_dim] + head_hidden_dims
        self.deephead = MLP(
            head_hidden_dims,
            head_activation,
            head_dropout,
            head_batchnorm,
            head_batchnorm_last,
            head_linear_first,
        )
        self.deephead.add_module(
            "head_out", nn.Linear(head_hidden_dims[-1], self.pred_dim)
        )

    def _add_pred_layer(self):
        if self.deeptabular is not None:
            if self.is_tabnet:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    TabNetPredLayer(self.deeptabular.output_dim, self.pred_dim),
                )
            else:
                self.deeptabular = nn.Sequential(
                    self.deeptabular,
                    nn.Linear(self.deeptabular.output_dim, self.pred_dim),
                )
        if self.deeptext is not None:
            self.deeptext = nn.Sequential(
                self.deeptext, nn.Linear(self.deeptext.output_dim, self.pred_dim)
            )
        if self.deepimage is not None:
            self.deepimage = nn.Sequential(
                self.deepimage, nn.Linear(self.deepimage.output_dim, self.pred_dim)
            )

    def _forward_wide(self, X):
        if self.wide is not None:
            out = self.wide(X["wide"])
        else:
            batch_size = X[list(X.keys())[0]].size(0)
            out = torch.zeros(batch_size, self.pred_dim).to(device)

        return out

    def _forward_deephead(self, X, wide_out):
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out = self.deeptabular(X["deeptabular"])
                deepside, M_loss = tab_out[0], tab_out[1]
            else:
                deepside = self.deeptabular(X["deeptabular"])
        else:
            deepside = torch.FloatTensor().to(device)
        if self.deeptext is not None:
            deepside = torch.cat([deepside, self.deeptext(X["deeptext"])], axis=1)
        if self.deepimage is not None:
            deepside = torch.cat([deepside, self.deepimage(X["deepimage"])], axis=1)

        deephead_out = self.deephead(deepside)
        deepside_out = nn.Linear(deephead_out.size(1), self.pred_dim).to(device)

        if self.is_tabnet:
            res = (wide_out.add_(deepside_out(deephead_out)), M_loss)
        else:
            res = wide_out.add_(deepside_out(deephead_out))

        return res

    def _forward_deep(self, X, wide_out):
        if self.deeptabular is not None:
            if self.is_tabnet:
                tab_out, M_loss = self.deeptabular(X["deeptabular"])
                wide_out.add_(tab_out)
            else:
                wide_out.add_(self.deeptabular(X["deeptabular"]))
        if self.deeptext is not None:
            wide_out.add_(self.deeptext(X["deeptext"]))
        if self.deepimage is not None:
            wide_out.add_(self.deepimage(X["deepimage"]))

        if self.is_tabnet:
            res = (wide_out, M_loss)
        else:
            res = wide_out

        return res

    @staticmethod  # noqa: C901
    def _check_model_components(  # noqa: C901
        wide,
        deeptabular,
        deeptext,
        deepimage,
        deephead,
        head_hidden_dims,
        pred_dim,
    ):

        if wide is not None:
            assert wide.wide_linear.weight.size(1) == pred_dim, (
                "the 'pred_dim' of the wide component ({}) must be equal to the 'pred_dim' "
                "of the deep component and the overall model itself ({})".format(
                    wide.wide_linear.weight.size(1), pred_dim
                )
            )
        if deeptabular is not None and not hasattr(deeptabular, "output_dim"):
            raise AttributeError(
                "deeptabular model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_text.DeepText"
            )
        if deeptabular is not None:
            is_tabnet = deeptabular.__class__.__name__ == "TabNet"
            has_wide_text_or_image = (
                wide is not None or deeptext is not None or deepimage is not None
            )
            if is_tabnet and has_wide_text_or_image:
                warnings.warn(
                    "'WideDeep' is a model comprised by multiple components and the 'deeptabular'"
                    " component is 'TabNet'. We recommend using 'TabNet' in isolation."
336
                    " The reasons are: i)'TabNet' uses sparse regularization which partially losses"
                    " its purpose when used in combination with other components."
                    " If you still want to use a multiple component model with 'TabNet',"
                    " consider setting 'lambda_sparse' to 0 during training. ii) The feature"
                    " importances will be computed only for TabNet but the model will comprise multiple"
                    " components. Therefore, such importances will partially lose their 'meaning'.",
                    UserWarning,
                )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_text.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_text.DeepText"
            )
        if deephead is not None and head_hidden_dims is not None:
            raise ValueError(
                "both 'deephead' and 'head_hidden_dims' are not None. Use one of the other, but not both"
            )
        if (
            head_hidden_dims is not None
            and not deeptabular
            and not deeptext
            and not deepimage
        ):
            raise ValueError(
                "if 'head_hidden_dims' is not None, at least one deep component must be used"
            )
        if deephead is not None:
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = 0
            if deeptabular is not None:
                output_dim += deeptabular.output_dim
            if deeptext is not None:
                output_dim += deeptext.output_dim
            if deepimage is not None:
                output_dim += deepimage.output_dim
            assert deephead_inp_feat == output_dim, (
                "if a custom 'deephead' is used its input features ({}) must be equal to "
                "the output features of the deep component ({})".format(
                    deephead_inp_feat, output_dim
                )
            )