tab_perceiver.py 15.6 KB
Newer Older
1 2 3 4 5
import torch
import einops
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
6
from pytorch_widedeep.models.tabular.mlp._layers import MLP
7 8
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
9
)
10 11
from pytorch_widedeep.models.tabular.transformers._encoders import (
    PerceiverEncoder,
12 13 14
)


15
class TabPerceiver(BaseTabularModelWithAttention):
    r"""Defines an adaptation of a [Perceiver](https://arxiv.org/abs/2103.03206)
     that can be used as the `deeptabular` component of a Wide & Deep model
     or independently by itself.

    :information_source: **NOTE**: while there are scientific publications for
     the `TabTransformer`, `SAINT` and `FTTransformer`, the `TabPerceiver`
     and the `TabFastFormer` are our own adaptations of the
     [Perceiver](https://arxiv.org/abs/2103.03206) and the
     [FastFormer](https://arxiv.org/abs/2108.09084) for tabular data.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        _{'education': 0, 'relationship': 1, 'workclass': 2, ...}_
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical component e.g. _[(education, 11), ...]_
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False,
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None,
        Activation function for the categorical embeddings, if any. _'tanh'_,
        _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        `pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If `full_embed_dropout = True`, `cat_embed_dropout` is ignored.
    shared_embed: bool, default = False
        The idea behind `shared_embed` is described in the Appendix A in the
        [TabTransformer paper](https://arxiv.org/abs/2012.06678): the
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns. In other
        words, the idea is to let the model learn which column is embedded
        at the time.
    add_shared_embed: bool, default = False,
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) to replace the first
        `frac_shared_embed` with the shared embeddings.
        See `pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if `add_shared_embed
        = False`) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: Optional, str, default = None
        Type of normalization layer applied to the continuous features. Options
        are: _'layernorm'_, _'batchnorm'_ or None.
    cont_embed_dropout: float, default = 0.1,
        Continuous embeddings dropout
    use_cont_bias: bool, default = True,
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: Optional, str, default = None
        Activation function to be applied to the continuous embeddings, if
        any. _'tanh'_, _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*, i.e. the dimension of the
        embeddings used to encode the categorical and/or continuous columns.
    n_cross_attns: int, default = 1
        Number of times each perceiver block will cross attend to the input
        data (i.e. number of cross attention components per perceiver block).
        This should normally be 1. However, in the paper they describe some
        architectures (normally computer vision-related problems) where the
        Perceiver attends multiple times to the input array. Therefore, maybe
        multiple cross attention to the input array is also useful in some
        cases for tabular data :shrug: .
    n_cross_attn_heads: int, default = 4
        Number of attention heads for the cross attention component
    n_latents: int, default = 16
        Number of latents. This is the $N$ parameter in the paper. As
        indicated in the paper, this number should be significantly lower
        than $M$ (the number of columns in the dataset). Setting $N$ closer
        to $M$ defies the main purpose of the Perceiver, which is to overcome
        the transformer quadratic bottleneck.
    latent_dim: int, default = 128
        Latent dimension.
    n_latent_heads: int, default = 4
        Number of attention heads per Latent Transformer
    n_latent_blocks: int, default = 4
        Number of transformer encoder blocks (normalised MHA + normalised FF)
        per Latent Transformer
    n_perceiver_blocks: int, default = 4
        Number of Perceiver blocks defined as [Cross Attention + Latent
        Transformer]
    share_weights: bool, default = False
        Boolean indicating if the weights will be shared between Perceiver
        blocks
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "geglu"
        Transformer Encoder activation function. _'tanh'_, _'relu'_,
        _'leaky_relu'_, _'gelu'_, _'geglu'_ and _'reglu'_ are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. The MLP input dimension is `latent_dim`. If
        not provided, no MLP is added and the model outputs the latents
        averaged along the latent axis (i.e. a vector of size `latent_dim`).
    mlp_activation: str, default = "relu"
        MLP activation function. _'tanh'_, _'relu'_, _'leaky_relu'_ and
        _'gelu'_ are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If `True: [LIN -> ACT -> BN -> DP]`. If `False: [BN -> DP ->
        LIN -> ACT]`

    Attributes
    ----------
    cat_and_cont_embed: nn.Module
        This is the module that processes the categorical and continuous columns
    encoder: nn.ModuleDict
        ModuleDict with the Perceiver blocks, keyed `'perceiver_block0'`,
        `'perceiver_block1'`, ... When `share_weights = True` every key maps
        to the same block
    latents: nn.Parameter
        Latents that will be used for prediction, of shape
        `(n_latents, latent_dim)`
    mlp: nn.Module, Optional
        MLP component in the model. `None` if `mlp_hidden_dims` is `None`
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the `WideDeep` class

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabPerceiver
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabPerceiver(column_idx=column_idx, cat_embed_input=cat_embed_input,
    ... continuous_cols=continuous_cols, n_latents=2, latent_dim=16,
    ... n_perceiver_blocks=2)
    >>> out = model(X_tab)
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        n_cross_attns: int = 1,
        n_cross_attn_heads: int = 4,
        n_latents: int = 16,
        latent_dim: int = 128,
        n_latent_heads: int = 4,
        n_latent_blocks: int = 4,
        n_perceiver_blocks: int = 4,
        share_weights: bool = False,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
        transformer_activation: str = "geglu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(TabPerceiver, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        # These attributes must be set before `_build_perceiver_block` is
        # called, since that method reads them from `self`
        self.n_cross_attns = n_cross_attns
        self.n_cross_attn_heads = n_cross_attn_heads
        self.n_latents = n_latents
        self.latent_dim = latent_dim
        self.n_latent_heads = n_latent_heads
        self.n_latent_blocks = n_latent_blocks
        self.n_perceiver_blocks = n_perceiver_blocks
        self.share_weights = share_weights
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        # Embeddings are instantiated at the base model
        # Transformer blocks
        # learnable latent array, shape (n_latents, latent_dim), initialised
        # in place with a truncated normal distribution
        self.latents = nn.init.trunc_normal_(
            nn.Parameter(torch.empty(n_latents, latent_dim))
        )

        self.encoder = nn.ModuleDict()
        first_perceiver_block = self._build_perceiver_block()
        self.encoder["perceiver_block0"] = first_perceiver_block
        for n in range(1, n_perceiver_blocks):
            # with shared weights every key refers to the very same module, so
            # all Perceiver blocks use the same parameters
            self.encoder["perceiver_block" + str(n)] = (
                first_perceiver_block
                if share_weights
                else self._build_perceiver_block()
            )

        # the latents are averaged over the latent axis in `forward`, so the
        # MLP (if any) receives vectors of size `latent_dim`
        self.mlp_first_hidden_dim = self.latent_dim

        # Mlp
        if mlp_hidden_dims is not None:
            self.mlp = MLP(
                # `list(...)` so that tuples are accepted as well as lists
                [self.mlp_first_hidden_dim] + list(mlp_hidden_dims),
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
        else:
            self.mlp = None

    def forward(self, X: Tensor) -> Tensor:
        """Run the Perceiver over a batch of tabular data.

        `X` is sliced into columns via `column_idx` and embedded by the base
        class. One copy of the latent array is then created per sample,
        refined by every Perceiver block, averaged over the latent axis and,
        if defined, passed through the MLP.
        """
        x_emb = self._get_embeddings(X)

        # one copy of the learnable latents per sample: (b, n_latents, latent_dim)
        x = einops.repeat(self.latents, "n d -> b n d", b=X.shape[0])

        for n in range(self.n_perceiver_blocks):
            cross_attns, latent_transformer = self._get_perceiver_block(n)
            # the latents `x` cross attend to the column embeddings `x_emb`
            for cross_attn in cross_attns:
                x = cross_attn(x, x_emb)
            x = latent_transformer(x)

        # average along the latent index axis
        x = x.mean(dim=1)

        if self.mlp is not None:
            x = self.mlp(x)

        return x

    @property
    def output_dim(self) -> int:
        return (
            self.mlp_hidden_dims[-1]
            if self.mlp_hidden_dims is not None
            else self.mlp_first_hidden_dim
        )

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. If the weights are not shared
        between perceiver blocks each element of the list will be a list
        itself containing the Cross Attention and Latent Transformer
        attention weights respectively. If the weights are shared, a single
        flat list with the weights stored in the (one) shared block is
        returned (presumably those of its last call in the forward pass).

        The shape of the attention weights is:

        - Cross Attention: $(N, C, L, F)$

        - Latent Attention: $(N, T, L, L)$

        Where $N$ is the batch size, $C$ is the number of Cross Attention
        heads, $L$ is the number of Latents, $F$ is the number of
        features/columns in the dataset and $T$ is the number of Latent
        Attention heads
        """
        if self.share_weights:
            return self._extract_attn_weights(*self._get_perceiver_block(0))
        return [
            self._extract_attn_weights(*self._get_perceiver_block(n))
            for n in range(self.n_perceiver_blocks)
        ]

    def _get_perceiver_block(self, n: int) -> Tuple[nn.ModuleList, nn.Sequential]:
        # Return the (cross attentions, latent transformer) pair of the n-th
        # Perceiver block stored in `self.encoder`
        block = self.encoder["perceiver_block" + str(n)]
        return block["cross_attns"], block["latent_transformer"]

    def _build_perceiver_block(self) -> nn.ModuleDict:
        # Build one Perceiver block: `n_cross_attns` cross attention encoders
        # (input_dim -> latent queries) followed by a Latent Transformer of
        # `n_latent_blocks` self attention encoders over the latents
        perceiver_block = nn.ModuleDict()

        # Cross Attention
        cross_attns = nn.ModuleList()
        for _ in range(self.n_cross_attns):
            cross_attns.append(
                PerceiverEncoder(
                    self.input_dim,
                    self.n_cross_attn_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.transformer_activation,
                    self.latent_dim,  # q_dim,
                ),
            )
        perceiver_block["cross_attns"] = cross_attns

        # Latent Transformer
        latent_transformer = nn.Sequential()
        for i in range(self.n_latent_blocks):
            latent_transformer.add_module(
                "latent_block" + str(i),
                PerceiverEncoder(
                    self.latent_dim,  # input_dim
                    self.n_latent_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.transformer_activation,
                ),
            )
        perceiver_block["latent_transformer"] = latent_transformer

        return perceiver_block

    @staticmethod
    def _extract_attn_weights(cross_attns, latent_transformer) -> List:
        # Flat list: the weights of every cross attention layer first, then
        # those of every latent block, in order
        attention_weights = []
        for cross_attn in cross_attns:
            attention_weights.append(cross_attn.attn.attn_weights)
        for latent_block in latent_transformer:
            attention_weights.append(latent_block.attn.attn_weights)
        return attention_weights