import torch
import einops
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import (
    PerceiverEncoder,
)


class TabPerceiver(BaseTabularModelWithAttention):
    r"""Defines an adaptation of a `Perceiver model
    <https://arxiv.org/abs/2103.03206>`_ that can be used as the
    ``deeptabular`` component of a Wide & Deep model or independently by
    itself.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical component e.g. [(education, 11), ...]
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False,
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None,
        Activation function for the categorical embeddings, if any. `'tanh'`,
        `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in Appendix A of the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is being
        embedded at any given time.
    add_shared_embed: bool, default = False,
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first
        ``frac_shared_embed`` with the shared embeddings.
        See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features. Options
        are: 'layernorm', 'batchnorm' or None.
    cont_embed_dropout: float, default = 0.1,
        Continuous embeddings dropout
    use_cont_bias: bool, default = True,
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, default = None
        Activation function to be applied to the continuous embeddings, if
        any. `'tanh'`, `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. This is the dimension of the
        embeddings used to encode the categorical and/or continuous columns.
    n_cross_attns: int, default = 1
        Number of times each perceiver block will cross attend to the input
        data (i.e. number of cross attention components per perceiver block).
        This should normally be 1. However, in the paper they describe some
        architectures (normally computer vision-related problems) where the
        Perceiver attends multiple times to the input array. Therefore,
        multiple cross attentions to the input array might also be useful in
        some cases for tabular data
    n_cross_attn_heads: int, default = 4
        Number of attention heads for the cross attention component
    n_latents: int, default = 16
        Number of latents. This is the *N* parameter in the paper. As
        indicated in the paper, this number should be significantly lower
        than *M* (the number of columns in the dataset). Setting *N* closer
        to *M* defies the main purpose of the Perceiver, which is to overcome
        the transformer quadratic bottleneck
    latent_dim: int, default = 128
        Latent dimension.
    n_latent_heads: int, default = 4
        Number of attention heads per Latent Transformer
    n_latent_blocks: int, default = 4
        Number of transformer encoder blocks (normalised MHA + normalised FF)
        per Latent Transformer
    n_perceiver_blocks: int, default = 4
        Number of Perceiver blocks defined as [Cross Attention + Latent
        Transformer]
    share_weights: Boolean, default = False
        Boolean indicating if the weights will be shared between Perceiver
        blocks
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "geglu"
        Transformer Encoder activation function. `'tanh'`, `'relu'`,
        `'leaky_relu'`, `'gelu'`, `'geglu'` and `'reglu'` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP's input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. `'tanh'`, `'relu'`, `'leaky_relu'` and
        `'gelu'` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
        LIN -> ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    perceiver_blks: ``nn.ModuleDict``
        ModuleDict with the Perceiver blocks
    latents: ``nn.Parameter``
        Latents that will be used for prediction
    perceiver_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the ``WideDeep`` class

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabPerceiver
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabPerceiver(column_idx=column_idx, cat_embed_input=cat_embed_input,
    ... continuous_cols=continuous_cols, n_latents=2, latent_dim=16,
    ... n_perceiver_blocks=2)
    >>> out = model(X_tab)
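    >>> # after the forward pass, the attention weights for each Perceiver
    >>> # block can be inspected via the attention_weights property
    >>> attn_weights = model.attention_weights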
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        n_cross_attns: int = 1,
        n_cross_attn_heads: int = 4,
        n_latents: int = 16,
        latent_dim: int = 128,
        n_latent_heads: int = 4,
        n_latent_blocks: int = 4,
        n_perceiver_blocks: int = 4,
        share_weights: bool = False,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
        transformer_activation: str = "geglu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(TabPerceiver, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )
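        # note that the continuous columns, if any, are always embedded here
        # (embed_continuous=True), so they can be attended to alongside the
        # categorical embeddings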

        self.n_cross_attns = n_cross_attns
        self.n_cross_attn_heads = n_cross_attn_heads
        self.n_latents = n_latents
        self.latent_dim = latent_dim
        self.n_latent_heads = n_latent_heads
        self.n_latent_blocks = n_latent_blocks
        self.n_perceiver_blocks = n_perceiver_blocks
        self.share_weights = share_weights
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        # Embeddings are instantiated at the base model
        # Transformer blocks
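        # the latents are a learnable array of shape (n_latents, latent_dim)
        # that is expanded along the batch dimension in the forward pass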
        self.latents = nn.init.trunc_normal_(
            nn.Parameter(torch.empty(n_latents, latent_dim))
        )

        self.perceiver_blks = nn.ModuleDict()
        first_perceiver_block = self._build_perceiver_block()
        self.perceiver_blks["perceiver_block0"] = first_perceiver_block

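        # if share_weights is True, every perceiver block re-uses the modules
        # (and therefore the parameters) of the first block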
        if share_weights:
            for n in range(1, n_perceiver_blocks):
                self.perceiver_blks["perceiver_block" + str(n)] = first_perceiver_block
        else:
            for n in range(1, n_perceiver_blocks):
                self.perceiver_blks[
                    "perceiver_block" + str(n)
                ] = self._build_perceiver_block()

        # Mlp
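        # the MLP input dimension is latent_dim (the latents are averaged over
        # the latent index axis before the MLP), so latent_dim leads the
        # hidden dimensions in both cases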
        if not mlp_hidden_dims:
            self.mlp_hidden_dims = [latent_dim, latent_dim * 4, latent_dim * 2]
        else:
            self.mlp_hidden_dims = [latent_dim] + mlp_hidden_dims

        self.perceiver_mlp = MLP(
            self.mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim: int = self.mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:

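        # categorical and continuous columns are embedded into a tensor of
        # shape (batch_size, n_columns, input_dim)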
        x_emb = self._get_embeddings(X)

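        # expand the learnable latents so that each observation in the batch
        # gets its own copy: (batch_size, n_latents, latent_dim)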
        x = einops.repeat(self.latents, "n d -> b n d", b=X.shape[0])

        for n in range(self.n_perceiver_blocks):
            cross_attns = self.perceiver_blks["perceiver_block" + str(n)]["cross_attns"]
            latent_transformer = self.perceiver_blks["perceiver_block" + str(n)][
                "latent_transformer"
            ]
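            # the latents (queries) cross attend to the embedded input
            # (keys/values) and then self-attend within the latent transformer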
            for cross_attn in cross_attns:
                x = cross_attn(x, x_emb)
            x = latent_transformer(x)

        # average along the latent index axis
        x = x.mean(dim=1)
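        # the pooled tensor has shape (batch_size, latent_dim) and goes
        # through the final MLP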

        return self.perceiver_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. If the weights are not shared
        between perceiver blocks each element of the list will be a list
        itself containing the Cross Attention and Latent Transformer
        attention weights respectively
293 294 295 296 297 298 299 300 301 302

        The shape of the attention weights is:

            - Cross Attention: :math:`(N, C, L, F)`
            - Latent Attention: :math:`(N, T, L, L)`

        Where *N* is the batch size, *C* is the number of Cross Attention
        heads, *L* is the number of Latents, *F* is the number of
        features/columns in the dataset and *T* is the number of Latent
        Attention heads
        """
        if self.share_weights:
            cross_attns = self.perceiver_blks["perceiver_block0"]["cross_attns"]
            latent_transformer = self.perceiver_blks["perceiver_block0"][
                "latent_transformer"
            ]
            attention_weights = self._extract_attn_weights(
                cross_attns, latent_transformer
            )
        else:
            attention_weights = []
            for n in range(self.n_perceiver_blocks):
                cross_attns = self.perceiver_blks["perceiver_block" + str(n)][
                    "cross_attns"
                ]
                latent_transformer = self.perceiver_blks["perceiver_block" + str(n)][
                    "latent_transformer"
                ]
                attention_weights.append(
                    self._extract_attn_weights(cross_attns, latent_transformer)
                )
        return attention_weights

    def _build_perceiver_block(self) -> nn.ModuleDict:

        perceiver_block = nn.ModuleDict()

        # Cross Attention
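        # the cross attention encoder uses latent_dim for the queries (the
        # latents) and input_dim for the keys and values (the column embeddings)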
        cross_attns = nn.ModuleList()
        for _ in range(self.n_cross_attns):
            cross_attns.append(
                PerceiverEncoder(
                    self.input_dim,
                    self.n_cross_attn_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.transformer_activation,
                    self.latent_dim,  # q_dim,
                ),
            )
        perceiver_block["cross_attns"] = cross_attns

        # Latent Transformer
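        # stack of n_latent_blocks encoder blocks (normalised MHA + FF) that
        # operate entirely in the latent space, independently of the number
        # of input columns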
        latent_transformer = nn.Sequential()
        for i in range(self.n_latent_blocks):
            latent_transformer.add_module(
                "latent_block" + str(i),
                PerceiverEncoder(
                    self.latent_dim,  # input_dim
                    self.n_latent_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.transformer_activation,
                ),
            )
        perceiver_block["latent_transformer"] = latent_transformer

        return perceiver_block

    @staticmethod
    def _extract_attn_weights(cross_attns, latent_transformer) -> List:
        attention_weights = []
        for cross_attn in cross_attns:
            attention_weights.append(cross_attn.attn.attn_weights)
        for latent_block in latent_transformer:
            attention_weights.append(latent_block.attn.attn_weights)
        return attention_weights