# tab_perceiver.py
import torch
import einops
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models._embeddings_layers import CatAndContEmbeddings
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.transformers._encoders import (
    PerceiverEncoder,
)


class TabPerceiver(nn.Module):
    r"""Defines an adaptation of a ``Perceiver`` model
    (`arXiv:2103.03206 <https://arxiv.org/abs/2103.03206>`_) that can be used
    as the ``deeptabular`` component of a Wide & Deep model.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    embed_input: List
        List of Tuples with the column name and number of unique values
        e.g. [('education', 11), ...]
    embed_dropout: float, default = 0.1
        Dropout to be applied to the embeddings matrix
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in the Appendix A in the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is being
        embedded at a given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first ``frac_shared_embed``
        with the shared
        embeddings. See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    embed_continuous_activation: str, default = None
        String indicating the activation function to be applied to the
        continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported.
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features before
        they are embedded. Options are: ``layernorm``, ``batchnorm`` or
        ``None``.
    input_dim: int, default = 32
        The so-called *dimension of the model*. In general, this is the
        dimension of the embeddings used to encode the categorical and/or
        continuous columns.
    n_cross_attns: int, default = 1
        Number of times each perceiver block will cross attend to the input
        data (i.e. number of cross attention components per perceiver block).
        This should normally be 1. However, in the paper they describe some
        architectures (normally computer vision-related problems) where the
        Perceiver attends multiple times to the input array. Therefore,
        multiple cross attentions to the input array might also be useful in
        some cases for tabular data
    n_cross_attn_heads: int, default = 4
        Number of attention heads for the cross attention component
    n_latents: int, default = 16
        Number of latents. This is the *N* parameter in the paper. As
        indicated in the paper, this number should be significantly lower
        than *M* (the number of columns in the dataset). Setting *N* closer
        to *M* defies the main purpose of the Perceiver, which is to overcome
        the transformer quadratic bottleneck
    latent_dim: int, default = 128
        Latent dimension.
    n_latent_heads: int, default = 4
        Number of attention heads per Latent Transformer
    n_latent_blocks: int, default = 4
        Number of transformer encoder blocks (normalised MHA + normalised FF)
        per Latent Transformer
    n_perceiver_blocks: int, default = 4
        Number of Perceiver blocks defined as [Cross Attention + Latent
        Transformer]
    share_weights: bool, default = False
        Boolean indicating if the weights will be shared between Perceiver
        blocks
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "geglu"
        Transformer Encoder activation function. ``tanh``, ``relu``,
        ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense layer. If
        ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP -> LIN ->
        ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    perceiver_blks: ``nn.ModuleDict``
        ModuleDict with the Perceiver blocks
    latents: ``nn.Parameter``
        Latents that will be used for prediction
    perceiver_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This attribute is required to
        build the WideDeep class

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabPerceiver
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabPerceiver(column_idx=column_idx, embed_input=embed_input,
    ... continuous_cols=continuous_cols, n_latents=2, latent_dim=16,
    ... n_perceiver_blocks=2)
    >>> out = model(X_tab)
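    >>> # after a forward pass, the attention weights of the cross attention
    >>> # and latent transformer blocks can be inspected via the
    >>> # ``attention_weights`` property (a list of Tensors)
    >>> attn_weights = model.attention_weights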
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        embed_input: Optional[List[Tuple[str, int]]] = None,
        embed_dropout: float = 0.1,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        embed_continuous_activation: Optional[str] = None,
        cont_norm_layer: Optional[str] = None,
        input_dim: int = 32,
        n_cross_attns: int = 1,
        n_cross_attn_heads: int = 4,
        n_latents: int = 16,
        latent_dim: int = 128,
        n_latent_heads: int = 4,
        n_latent_blocks: int = 4,
        n_perceiver_blocks: int = 4,
        share_weights: bool = False,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
        transformer_activation: str = "geglu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(TabPerceiver, self).__init__()

        self.column_idx = column_idx
        self.embed_input = embed_input
        self.embed_dropout = embed_dropout
        self.full_embed_dropout = full_embed_dropout
        self.shared_embed = shared_embed
        self.add_shared_embed = add_shared_embed
        self.frac_shared_embed = frac_shared_embed
        self.continuous_cols = continuous_cols
        self.embed_continuous_activation = embed_continuous_activation
        self.cont_norm_layer = cont_norm_layer
        self.input_dim = input_dim
        self.n_cross_attns = n_cross_attns
        self.n_cross_attn_heads = n_cross_attn_heads
        self.n_latents = n_latents
        self.latent_dim = latent_dim
        self.n_latent_heads = n_latent_heads
        self.n_latent_blocks = n_latent_blocks
        self.n_perceiver_blocks = n_perceiver_blocks
        self.share_weights = share_weights
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation
        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        if mlp_hidden_dims is not None:
            assert (
                mlp_hidden_dims[0] == latent_dim
            ), "The first mlp input dim must be equal to 'latent_dim'"

        # Strictly speaking, this module could have a different name here,
        # since the continuous cols will always be embedded for the
        # TabPerceiver. However, it is very convenient for other
        # functionalities to keep the name 'cat_and_cont_embed'
        self.cat_and_cont_embed = CatAndContEmbeddings(
            input_dim,
            column_idx,
            embed_input,
            embed_dropout,
            full_embed_dropout,
            shared_embed,
            add_shared_embed,
            frac_shared_embed,
            False,  # use_embed_bias
            continuous_cols,
            True,  # embed_continuous,
            embed_continuous_activation,
            True,  # use_cont_bias
            cont_norm_layer,
        )

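        # learnable latent array of shape (n_latents, latent_dim), initialised
        # with a truncated normal; it is expanded to the batch size in forward()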
        self.latents = nn.init.trunc_normal_(
            nn.Parameter(torch.empty(n_latents, latent_dim))
        )

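        # each perceiver block is a ModuleDict with one or more cross attention
        # encoders plus a latent transformer; with share_weights=True the first
        # block (and therefore its weights) is reused n_perceiver_blocks times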
        self.perceiver_blks = nn.ModuleDict()
        first_perceiver_block = self._build_perceiver_block()
        self.perceiver_blks["perceiver_block0"] = first_perceiver_block

        if share_weights:
            for n in range(1, n_perceiver_blocks):
                self.perceiver_blks["perceiver_block" + str(n)] = first_perceiver_block
        else:
            for n in range(1, n_perceiver_blocks):
                self.perceiver_blks[
                    "perceiver_block" + str(n)
                ] = self._build_perceiver_block()

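        # MLP head: if no hidden dims are passed it defaults to
        # [latent_dim, 4 * latent_dim, 2 * latent_dim]; otherwise the first
        # hidden dim must match latent_dim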
        if not mlp_hidden_dims:
            self.mlp_hidden_dims = [latent_dim, latent_dim * 4, latent_dim * 2]
        else:
            assert mlp_hidden_dims[0] == latent_dim, (
                f"The input dim of the MLP must be {latent_dim}. "
                f"Got {mlp_hidden_dims[0]} instead"
            )
        self.perceiver_mlp = MLP(
            self.mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim = self.mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:

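        # embed the categorical and continuous columns and concatenate them
        # along the column dimension: (batch_size, n_cols, input_dim)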
        x_cat, x_cont = self.cat_and_cont_embed(X)
        if x_cat is not None:
            x_emb = x_cat
        if x_cont is not None:
            x_emb = torch.cat([x_emb, x_cont], 1) if x_cat is not None else x_cont

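        # expand the learnt latents to one copy per sample in the batch:
        # (n_latents, latent_dim) -> (batch_size, n_latents, latent_dim)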
        x = einops.repeat(self.latents, "n d -> b n d", b=X.shape[0])

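        # in each perceiver block the latents first cross-attend to the
        # embedded input and are then processed by the latent transformer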
        for n in range(self.n_perceiver_blocks):
            cross_attns = self.perceiver_blks["perceiver_block" + str(n)]["cross_attns"]
            latent_transformer = self.perceiver_blks["perceiver_block" + str(n)][
                "latent_transformer"
            ]
            for cross_attn in cross_attns:
                x = cross_attn(x, x_emb)
            x = latent_transformer(x)

        # average along the latent index axis
        x = x.mean(dim=1)

        return self.perceiver_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. If the weights are not shared
        between perceiver blocks each element of the list will be a list
        itself containing the Cross Attention and Latent Transformer
        attention weights respectively

        The shape of the attention weights is:

            - Cross Attention: :math:`(N, C, L, F)`
            - Latent Attention: :math:`(N, T, L, L)`

        Where *N* is the batch size, *C* is the number of Cross Attention
        heads, *L* is the number of Latents, *F* is the number of
        features/columns in the dataset and *T* is the number of Latent
        Attention heads
        """
        if self.share_weights:
            cross_attns = self.perceiver_blks["perceiver_block0"]["cross_attns"]
            latent_transformer = self.perceiver_blks["perceiver_block0"][
                "latent_transformer"
            ]
            attention_weights = self._extract_attn_weights(
                cross_attns, latent_transformer
            )
        else:
            attention_weights = []
            for n in range(self.n_perceiver_blocks):
                cross_attns = self.perceiver_blks["perceiver_block" + str(n)][
                    "cross_attns"
                ]
                latent_transformer = self.perceiver_blks["perceiver_block" + str(n)][
                    "latent_transformer"
                ]
                attention_weights.append(
                    self._extract_attn_weights(cross_attns, latent_transformer)
                )
        return attention_weights

    def _build_perceiver_block(self) -> nn.ModuleDict:

        perceiver_block = nn.ModuleDict()

        # Cross Attention
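        # (queries come from the latents, hence q_dim=latent_dim; keys and
        # values come from the embedded input columns of size input_dim)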
        cross_attns = nn.ModuleList()
        for _ in range(self.n_cross_attns):
            cross_attns.append(
                PerceiverEncoder(
                    self.input_dim,
                    self.n_cross_attn_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.transformer_activation,
                    self.latent_dim,  # q_dim,
                ),
            )
        perceiver_block["cross_attns"] = cross_attns

        # Latent Transformer
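        # (a stack of n_latent_blocks encoder blocks that operates on the
        # latents only, so its cost is independent of the number of columns)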
        latent_transformer = nn.Sequential()
        for i in range(self.n_latent_blocks):
            latent_transformer.add_module(
                "latent_block" + str(i),
                PerceiverEncoder(
                    self.latent_dim,  # input_dim
                    self.n_latent_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.transformer_activation,
                ),
            )
        perceiver_block["latent_transformer"] = latent_transformer

        return perceiver_block

    @staticmethod
    def _extract_attn_weights(cross_attns, latent_transformer) -> List:
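        # returns the cross attention weights first, followed by the attention
        # weights of each block in the latent transformer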
        attention_weights = []
        for cross_attn in cross_attns:
            attention_weights.append(cross_attn.attn.attn_weights)
        for latent_block in latent_transformer:
            attention_weights.append(latent_block.attn.attn_weights)
        return attention_weights