wide_deep.py 12.9 KB
Newer Older
1
import warnings
2

3
import torch
4 5
import torch.nn as nn

6 7
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
J
jrzaurin 已提交
8

9 10
# Make DeprecationWarnings visible. NOTE: this changes the process-wide
# warning filters, not only the ones that apply to this module.
warnings.filterwarnings("default", category=DeprecationWarning)

# Module-level default device: CUDA when available, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class WideDeep(nn.Module):
    def __init__(
        self,
        wide: Optional[nn.Module] = None,
        deeptabular: Optional[nn.Module] = None,
        deeptext: Optional[nn.Module] = None,
        deepimage: Optional[nn.Module] = None,
        deephead: Optional[nn.Module] = None,
        head_layers_dim: Optional[List[int]] = None,
        head_activation: Optional[str] = "relu",
        head_dropout: Optional[float] = None,
        head_batchnorm: Optional[bool] = False,
        head_batchnorm_last: Optional[bool] = False,
        head_linear_first: Optional[bool] = False,
        pred_dim: int = 1,
    ):
        r"""Main collector class that combines all ``wide``, ``deeptabular``
        (which can be a number of architectures), ``deeptext`` and
        ``deepimage`` models.

        There are two options to combine these models that correspond to the
        two architectures that ``pytorch-widedeep`` can build.

            - Directly connecting the output of the model components to an
              output neuron(s).

            - Adding a `Fully-Connected Head` (FC-Head) on top of the deep
              models. This FC-Head combines the output from the
              ``deeptabular``, ``deeptext`` and ``deepimage`` components and
              is then connected to the output neuron(s).

        Parameters
        ----------
        wide: ``nn.Module``, Optional
            ``Wide`` model. I recommend using the ``Wide`` class in this
            package. However, it is possible to use a custom model as long as
            it is consistent with the required architecture: it must have a
            ``wide_linear`` attribute whose ``weight.size(1)`` equals
            ``pred_dim``. See :class:`pytorch_widedeep.models.wide.Wide`
        deeptabular: ``nn.Module``, Optional
            Model for the tabular input. Currently ``pytorch-widedeep``
            implements three possible architectures for the ``deeptabular``
            component. These are: ``TabMlp``, ``TabResnet`` and
            ``TabTransformer``.

            1. ``TabMlp`` is simply an embedding layer encoding the categorical
            features that are then concatenated and passed through a series of
            dense (hidden) layers (i.e. an MLP).
            See: :class:`pytorch_widedeep.models.tab_mlp.TabMlp`

            2. ``TabResnet`` is an embedding layer encoding the categorical
            features that are then concatenated and passed through a series of
            ResNet blocks formed by dense layers.
            See: ``pytorch_widedeep.models.TabResnet``

            3. ``TabTransformer`` is detailed in `TabTransformer: Tabular Data
            Modeling Using Contextual Embeddings
            <https://arxiv.org/pdf/2012.06678.pdf>`_. See
            ``pytorch_widedeep.models.tab_transformer.TabTransformer``

            I recommend using one of these as ``deeptabular``. However, a
            custom model can be used as long as it is consistent with the
            required architecture (in particular, it must have an
            ``output_dim`` attribute).
        deeptext: ``nn.Module``, Optional
            Model for the text input. Must be an object of class ``DeepText``
            or a custom model, as long as it is consistent with the required
            architecture (it must have an ``output_dim`` attribute). See
            ``pytorch_widedeep.models.DeepText``
        deepimage: ``nn.Module``, Optional
            Model for the image input. Must be an object of class
            ``DeepImage`` or a custom model, as long as it is consistent with
            the required architecture (it must have an ``output_dim``
            attribute). See ``pytorch_widedeep.models.DeepImage``
        deephead: ``nn.Module``, Optional
            Custom model by the user that will receive the output of the deep
            components, concatenated in the order ``deeptabular``,
            ``deeptext``, ``deepimage``. Typically a FC-Head (MLP). Its input
            features must equal the sum of the ``output_dim`` of the deep
            components. Its output is connected to the output neuron(s)
            through a linear layer (``deephead_out``). The input size of that
            layer is read from ``deephead.output_dim`` if present or,
            otherwise, from the ``out_features`` of the last ``nn.Linear``
            registered inside ``deephead``.
        head_layers_dim: List, Optional
            Alternatively, the ``head_layers_dim`` param can be used to
            specify the sizes of the stacked dense layers in the FC-Head, e.g.
            ``[128, 64]``. It cannot be used together with ``deephead``.
        head_dropout: float, Optional
            Dropout between the layers in ``head_layers_dim``
        head_activation: str, default = "relu"
            Activation function of the head layers. One of "relu", "gelu" or
            "leaky_relu"
        head_batchnorm: bool, Optional
            Specifies if batch normalization should be included in the head
            layers
        head_batchnorm_last: bool, default = False
            Boolean indicating whether or not to apply batch normalization to
            the last of the dense layers
        head_linear_first: bool, default = False
            Boolean indicating the order of the operations in the dense
            layer. If ``True``: [LIN -> ACT -> BN -> DP]. If ``False``: [BN ->
            DP -> LIN -> ACT]
        pred_dim: int
            Size of the final wide and deep output layer containing the
            predictions. `1` for regression and binary classification or number
            of classes for multiclass classification.

        Attributes
        ----------
        pred_dim: int
            Size of the output layer (see the ``pred_dim`` parameter).
        deephead: ``nn.Module`` or None
            The user-supplied ``deephead``, or the ``MLP`` built from
            ``head_layers_dim`` (ending in a ``head_out`` linear layer of
            output size ``pred_dim``), or ``None``.
        deephead_out: ``nn.Linear`` or None
            Linear layer connecting a user-supplied ``deephead`` to the output
            neuron(s); ``None`` otherwise.

        Examples
        --------

        >>> from pytorch_widedeep.models import TabResnet, DeepImage, DeepText, Wide, WideDeep
        >>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
        >>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
        >>> wide = Wide(10, 1)
        >>> deeptabular = TabResnet(blocks_dims=[8, 4], column_idx=column_idx, embed_input=embed_input)
        >>> deeptext = DeepText(vocab_size=10, embed_dim=4, padding_idx=0)
        >>> deepimage = DeepImage(pretrained=False)
        >>> model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=deeptext, deepimage=deepimage)


        .. note:: While I recommend using the ``wide`` and ``deeptabular`` components
            within this package when building the corresponding model components,
            it is very likely that the user will want to use custom text and image
            models. That is perfectly possible. Simply build them and pass them
            as the corresponding parameters. Note that the custom models MUST
            return a last layer of activations (i.e. not the final prediction) so
            that these activations are collected by ``WideDeep`` and combined
            accordingly. In addition, the models MUST also contain an attribute
            ``output_dim`` with the size of these last layers of activations. See
            for example :class:`pytorch_widedeep.models.tab_mlp.TabMlp`

        """
        super(WideDeep, self).__init__()

        self._check_model_components(
            wide,
            deeptabular,
            deeptext,
            deepimage,
            deephead,
            head_layers_dim,
            pred_dim,
        )

        # required as attribute just in case we pass a deephead
        self.pred_dim = pred_dim

        # The main 5 components of the wide and deep ensemble
        self.wide = wide
        self.deeptabular = deeptabular
        self.deeptext = deeptext
        self.deepimage = deepimage
        self.deephead = deephead

        if self.deephead is not None:
            # A user-supplied head does not necessarily output 'pred_dim'
            # values. Connect it to the output neuron(s) with a linear layer
            # built once here, so that it is registered as a submodule (and
            # therefore trained, moved with .to() and saved in the state_dict)
            # instead of being re-created with random weights on every
            # forward pass.
            self.deephead_out = nn.Linear(
                self._deephead_output_dim(self.deephead), pred_dim
            )
        else:
            self.deephead_out = None
            if head_layers_dim is not None:
                self._build_deephead(
                    head_layers_dim,
                    head_activation,
                    head_dropout,
                    head_batchnorm,
                    head_batchnorm_last,
                    head_linear_first,
                    pred_dim,
                )
            else:
                self._add_pred_layer(pred_dim)

    def forward(self, X: Dict[str, Tensor]) -> Tensor:  # type: ignore  # noqa: C901
        """Combine the outputs of the model components into the predictions.

        Parameters
        ----------
        X: Dict[str, Tensor]
            Inputs keyed by component name (``"wide"``, ``"deeptabular"``,
            ``"deeptext"``, ``"deepimage"``). Only the keys of the components
            that are present in the model are read.

        Returns
        -------
        Tensor
            Sum of the wide output (zeros when there is no ``wide``) and the
            deep output(s), of size ``(batch_size, pred_dim)``.
        """
        # Wide output: direct connection to the output neuron(s)
        if self.wide is not None:
            out = self.wide(X["wide"])
        else:
            # Start from zeros placed on the same device as the inputs (the
            # module-level default device may differ from the model's).
            first_input = next(iter(X.values()))
            out = torch.zeros(
                first_input.size(0), self.pred_dim, device=first_input.device
            )

        # Deep output: either passed through a head first or, without a head,
        # each deep component already ends in a Linear(output_dim, pred_dim).
        # Out-of-place additions avoid modifying autograd outputs in place.
        if self.deephead is not None:
            return out + self._forward_deephead(X)
        if self.deeptabular is not None:
            out = out + self.deeptabular(X["deeptabular"])
        if self.deeptext is not None:
            out = out + self.deeptext(X["deeptext"])
        if self.deepimage is not None:
            out = out + self.deepimage(X["deepimage"])
        return out

    def _forward_deephead(self, X: Dict[str, Tensor]) -> Tensor:
        """Concatenate the deep outputs and pass them through the head."""
        deep_outputs = []
        if self.deeptabular is not None:
            deep_outputs.append(self.deeptabular(X["deeptabular"]))
        if self.deeptext is not None:
            deep_outputs.append(self.deeptext(X["deeptext"]))
        if self.deepimage is not None:
            deep_outputs.append(self.deepimage(X["deepimage"]))
        head_out = self.deephead(torch.cat(deep_outputs, dim=1))  # type: ignore
        # Only a user-supplied head needs the extra projection: the MLP built
        # from 'head_layers_dim' already ends in a 'head_out' Linear layer.
        if self.deephead_out is not None:
            head_out = self.deephead_out(head_out)
        return head_out

    def _build_deephead(
        self,
        head_layers_dim: List[int],
        head_activation: Optional[str],
        head_dropout: Optional[float],
        head_batchnorm: Optional[bool],
        head_batchnorm_last: Optional[bool],
        head_linear_first: Optional[bool],
        pred_dim: int,
    ):
        """Build the MLP FC-Head on top of the deep components."""
        deep_dim = sum(
            component.output_dim  # type: ignore
            for component in (self.deeptabular, self.deeptext, self.deepimage)
            if component is not None
        )
        # New list: the caller's 'head_layers_dim' is not mutated
        layers_dim = [deep_dim] + head_layers_dim
        self.deephead = MLP(
            layers_dim,
            head_activation,
            head_dropout,
            head_batchnorm,
            head_batchnorm_last,
            head_linear_first,
        )
        self.deephead.add_module("head_out", nn.Linear(layers_dim[-1], pred_dim))

    def _add_pred_layer(self, pred_dim: int):
        """Append a Linear(output_dim, pred_dim) layer to every deep component
        so that each one connects directly to the output neuron(s)."""
        if self.deeptabular is not None:
            self.deeptabular = nn.Sequential(
                self.deeptabular, nn.Linear(self.deeptabular.output_dim, pred_dim)  # type: ignore
            )
        if self.deeptext is not None:
            self.deeptext = nn.Sequential(
                self.deeptext, nn.Linear(self.deeptext.output_dim, pred_dim)  # type: ignore
            )
        if self.deepimage is not None:
            self.deepimage = nn.Sequential(
                self.deepimage, nn.Linear(self.deepimage.output_dim, pred_dim)  # type: ignore
            )

    @staticmethod
    def _deephead_output_dim(deephead: nn.Module) -> int:
        """Return the size of the output of a user-supplied ``deephead``.

        Uses ``deephead.output_dim`` when available. Otherwise it falls back to
        the ``out_features`` of the last ``nn.Linear`` registered in
        ``deephead``, which assumes that this layer is the last one applied in
        its forward pass.

        Raises
        ------
        AttributeError
            if neither source of information is available
        """
        output_dim = getattr(deephead, "output_dim", None)
        if output_dim is not None:
            return output_dim
        linear_layers = [m for m in deephead.modules() if isinstance(m, nn.Linear)]
        if linear_layers:
            return linear_layers[-1].out_features
        raise AttributeError(
            "a custom 'deephead' must either have an 'output_dim' attribute or "
            "contain at least one 'nn.Linear' layer, so that its output can be "
            "connected to the output neuron(s)"
        )

    @staticmethod  # noqa: C901
    def _check_model_components(
        wide,
        deeptabular,
        deeptext,
        deepimage,
        deephead,
        head_layers_dim,
        pred_dim,
    ):
        """Validate that the given components can be combined.

        The two consistency checks raise ``AssertionError`` explicitly (rather
        than through ``assert``) so that they still run under ``python -O``.

        Raises
        ------
        AssertionError
            if the wide component's output size differs from ``pred_dim``, or
            if the input size of a custom ``deephead`` does not match the
            summed ``output_dim`` of the deep components
        AttributeError
            if a deep component lacks an ``output_dim`` attribute
        ValueError
            if both ``deephead`` and ``head_layers_dim`` are given, or if
            ``head_layers_dim`` is given without any deep component
        """
        if wide is not None and wide.wide_linear.weight.size(1) != pred_dim:
            raise AssertionError(
                "the 'pred_dim' of the wide component ({}) must be equal to the 'pred_dim' "
                "of the deep component and the overall model itself ({})".format(
                    wide.wide_linear.weight.size(1), pred_dim
                )
            )
        if deeptabular is not None and not hasattr(deeptabular, "output_dim"):
            raise AttributeError(
                "deeptabular model must have an 'output_dim' attribute. "
                "See pytorch_widedeep.models.tab_mlp.TabMlp"
            )
        if deeptext is not None and not hasattr(deeptext, "output_dim"):
            raise AttributeError(
                "deeptext model must have an 'output_dim' attribute. "
                "See pytorch-widedeep.models.deep_dense.DeepText"
            )
        if deepimage is not None and not hasattr(deepimage, "output_dim"):
            raise AttributeError(
                "deepimage model must have an 'output_dim' attribute. "
                "See pytorch_widedeep.models.DeepImage"
            )
        if deephead is not None and head_layers_dim is not None:
            raise ValueError(
                "both 'deephead' and 'head_layers_dim' are not None. Use one of the other, but not both"
            )
        deep_components = [
            c for c in (deeptabular, deeptext, deepimage) if c is not None
        ]
        if head_layers_dim is not None and not deep_components:
            raise ValueError(
                "if 'head_layers_dim' is not None, at least one deep component must be used"
            )
        if deephead is not None:
            # assumes the first parameter of the head is the weight of a
            # Linear-like layer, i.e. of shape (out_features, in_features)
            deephead_inp_feat = next(deephead.parameters()).size(1)
            output_dim = sum(c.output_dim for c in deep_components)
            if deephead_inp_feat != output_dim:
                raise AssertionError(
                    "if a custom 'deephead' is used its input features ({}) must be equal to "
                    "the output features of the deep component ({})".format(
                        deephead_inp_feat, output_dim
                    )
                )