import torch
import einops
from torch import nn

from pytorch_widedeep.wdtypes import Dict, List, Tuple, Tensor, Optional
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import (
    PerceiverEncoder,
)


class TabPerceiver(BaseTabularModelWithAttention):
    r"""Defines an adaptation of a [Perceiver](https://arxiv.org/abs/2103.03206)
     that can be used as the `deeptabular` component of a Wide & Deep model
     or independently by itself.

    :information_source: **NOTE**: while there are scientific publications for
     the `TabTransformer`, `SAINT` and `FTTransformer`, the `TabPerceiver`
     and the `TabFastFormer` are our own adaptations of the
     [Perceiver](https://arxiv.org/abs/2103.03206) and the
     [FastFormer](https://arxiv.org/abs/2108.09084) for tabular data.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        _{'education': 0, 'relationship': 1, 'workclass': 2, ...}_
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical component e.g. _[(education, 11), ...]_
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False,
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None,
        Activation function for the categorical embeddings, if any. _'tanh'_,
        _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        `pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If `full_embed_dropout = True`, `cat_embed_dropout` is ignored.
    shared_embed: bool, default = False
        The idea behind `shared_embed` is described in the Appendix A in the
        [TabTransformer paper](https://arxiv.org/abs/2012.06678): the
        goal of having column embeddings is to enable the model to distinguish
        the classes in one column from those in the other columns. In other
        words, the idea is to let the model learn which column is being
        embedded at any given time.
    add_shared_embed: bool, default = False,
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first
        `frac_shared_embed` with the shared embeddings.
        See `pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if `add_shared_embed
        = False`) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features. Options
        are: _'layernorm'_, _'batchnorm'_ or None.
    cont_embed_dropout: float, default = 0.1,
        Continuous embeddings dropout
    use_cont_bias: bool, default = True,
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, default = None
        Activation function to be applied to the continuous embeddings, if
        any. _'tanh'_, _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. It is the number of embeddings
        used to encode the categorical and/or continuous columns.
    n_cross_attns: int, default = 1
        Number of times each perceiver block will cross attend to the input
        data (i.e. number of cross attention components per perceiver block).
        This should normally be 1. However, in the paper they describe some
        architectures (normally computer vision-related problems) where the
        Perceiver attends multiple times to the input array. Therefore, maybe
        multiple cross attention to the input array is also useful in some
        cases for tabular data :shrug: .
    n_cross_attn_heads: int, default = 4
        Number of attention heads for the cross attention component
    n_latents: int, default = 16
        Number of latents. This is the $N$ parameter in the paper. As
        indicated in the paper, this number should be significantly lower
        than $M$ (the number of columns in the dataset). Setting $N$ closer
        to $M$ defies the main purpose of the Perceiver, which is to overcome
        the transformer quadratic bottleneck
    latent_dim: int, default = 128
        Latent dimension.
    n_latent_heads: int, default = 4
        Number of attention heads per Latent Transformer
    n_latent_blocks: int, default = 4
        Number of transformer encoder blocks (normalised MHA + normalised FF)
        per Latent Transformer
    n_perceiver_blocks: int, default = 4
        Number of Perceiver blocks defined as [Cross Attention + Latent
        Transformer]
    share_weights: Boolean, default = False
        Boolean indicating if the weights will be shared between Perceiver
        blocks
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    ff_factor: int, default = 4
        Multiplicative factor applied to the first layer of the FF network in
        each Transformer block. This is normally set to 4.
    transformer_activation: str, default = "geglu"
        Transformer Encoder activation function. _'tanh'_, _'relu'_,
        _'leaky_relu'_, _'gelu'_, _'geglu'_ and _'reglu'_ are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided, no MLP on top of the final
        Perceiver block will be used
    mlp_activation: str, default = "relu"
        MLP activation function. _'tanh'_, _'relu'_, _'leaky_relu'_ and
        _'gelu'_ are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If `True: [LIN -> ACT -> BN -> DP]`. If `False: [BN -> DP ->
        LIN -> ACT]`

    Attributes
    ----------
    cat_and_cont_embed: nn.Module
        This is the module that processes the categorical and continuous columns
    encoder: nn.ModuleDict
        ModuleDict with the Perceiver blocks
    latents: nn.Parameter
        Latents that will be used for prediction
    mlp: nn.Module
        MLP component in the model

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabPerceiver
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabPerceiver(column_idx=column_idx, cat_embed_input=cat_embed_input,
    ... continuous_cols=continuous_cols, n_latents=2, latent_dim=16,
    ... n_perceiver_blocks=2)
    >>> out = model(X_tab)
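    >>> # with the default mlp_hidden_dims=None there is no MLP on top of the
    >>> # encoder, so the output is the latent representation averaged over
    >>> # the latents, i.e. a tensor of shape (batch_size, latent_dim)
    >>> out.shape
    torch.Size([5, 16])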
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: str = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        n_cross_attns: int = 1,
        n_cross_attn_heads: int = 4,
        n_latents: int = 16,
        latent_dim: int = 128,
        n_latent_heads: int = 4,
        n_latent_blocks: int = 4,
        n_perceiver_blocks: int = 4,
        share_weights: bool = False,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
        ff_factor: int = 4,
        transformer_activation: str = "geglu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(TabPerceiver, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.n_cross_attns = n_cross_attns
        self.n_cross_attn_heads = n_cross_attn_heads
        self.n_latents = n_latents
        self.latent_dim = latent_dim
        self.n_latent_heads = n_latent_heads
        self.n_latent_blocks = n_latent_blocks
        self.n_perceiver_blocks = n_perceiver_blocks
        self.share_weights = share_weights
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.ff_factor = ff_factor
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        # Embeddings are instantiated at the base model
        # Transformer blocks
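        # learnable latent array of shape (n_latents, latent_dim), initialised
        # from a truncated normal distribution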
        self.latents = nn.init.trunc_normal_(
            nn.Parameter(torch.empty(n_latents, latent_dim))
        )

        self.encoder = nn.ModuleDict()
        first_perceiver_block = self._build_perceiver_block()
        self.encoder["perceiver_block0"] = first_perceiver_block

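        # if the weights are shared, the very same block (and hence the same
        # parameters) is registered under every perceiver_block key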
        if share_weights:
            for n in range(1, n_perceiver_blocks):
                self.encoder["perceiver_block" + str(n)] = first_perceiver_block
        else:
            for n in range(1, n_perceiver_blocks):
                self.encoder["perceiver_block" + str(n)] = self._build_perceiver_block()

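        # the encoder output is averaged over the latents in forward, so the
        # (optional) MLP receives vectors of size latent_dim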
254 255
        self.mlp_first_hidden_dim = self.latent_dim

        # MLP
        if mlp_hidden_dims is not None:
            self.mlp = MLP(
                [self.mlp_first_hidden_dim] + mlp_hidden_dims,
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
        else:
            self.mlp = None

    def forward(self, X: Tensor) -> Tensor:
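        # each column is embedded by the base model, so x_emb has shape
        # (batch_size, n_features, input_dim)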
        x_emb = self._get_embeddings(X)

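        # broadcast the learnable latents to the batch dimension:
        # (n_latents, latent_dim) -> (batch_size, n_latents, latent_dim)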
        x = einops.repeat(self.latents, "n d -> b n d", b=X.shape[0])

        for n in range(self.n_perceiver_blocks):
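            # each perceiver block: cross attention(s) from the latents to the
            # column embeddings, followed by a Latent Transformer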
            cross_attns = self.encoder["perceiver_block" + str(n)]["cross_attns"]
            latent_transformer = self.encoder["perceiver_block" + str(n)][
                "latent_transformer"
            ]
            for cross_attn in cross_attns:
                x = cross_attn(x, x_emb)
            x = latent_transformer(x)

        # average along the latent index axis
        x = x.mean(dim=1)

        if self.mlp is not None:
            x = self.mlp(x)

        return x

    @property
    def output_dim(self) -> int:
        r"""The output dimension of the model. This is a required property
        necessary to build the `WideDeep` class
        """
        return (
            self.mlp_hidden_dims[-1]
            if self.mlp_hidden_dims is not None
            else self.mlp_first_hidden_dim
        )

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. If the weights are not shared
        between perceiver blocks each element of the list will be a list
        itself containing the Cross Attention and Latent Transformer
        attention weights respectively

        The shape of the attention weights is:

        - Cross Attention: $(N, C, L, F)$

        - Latent Attention: $(N, T, L, L)$

        where $N$ is the batch size, $C$ is the number of Cross Attention
        heads, $L$ is the number of Latents, $F$ is the number of
        features/columns in the dataset and $T$ is the number of Latent
        Attention heads
        """
        if self.share_weights:
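            # all perceiver blocks are the same module when the weights are
            # shared, so 'perceiver_block0' is representative of every block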
            cross_attns = self.encoder["perceiver_block0"]["cross_attns"]
            latent_transformer = self.encoder["perceiver_block0"]["latent_transformer"]
            attention_weights = self._extract_attn_weights(
                cross_attns, latent_transformer
            )
        else:
            attention_weights = []
            for n in range(self.n_perceiver_blocks):
                cross_attns = self.encoder["perceiver_block" + str(n)]["cross_attns"]
                latent_transformer = self.encoder["perceiver_block" + str(n)][
                    "latent_transformer"
                ]
                attention_weights.append(
                    self._extract_attn_weights(cross_attns, latent_transformer)
                )
        return attention_weights

    def _build_perceiver_block(self) -> nn.ModuleDict:
        perceiver_block = nn.ModuleDict()

        # Cross Attention
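        # the queries come from the latent array (latent_dim) while the keys
        # and values come from the column embeddings (input_dim)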
        cross_attns = nn.ModuleList()
        for _ in range(self.n_cross_attns):
            cross_attns.append(
                PerceiverEncoder(
                    self.input_dim,
                    self.n_cross_attn_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.ff_factor,
                    self.transformer_activation,
                    self.latent_dim,  # q_dim
                ),
            )
        perceiver_block["cross_attns"] = cross_attns

        # Latent Transformer
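        # stack of standard transformer encoder blocks operating entirely in
        # the latent space (latent_dim)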
        latent_transformer = nn.Sequential()
        for i in range(self.n_latent_blocks):
            latent_transformer.add_module(
                "latent_block" + str(i),
                PerceiverEncoder(
                    self.latent_dim,  # input_dim
                    self.n_latent_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.ff_factor,
                    self.transformer_activation,
                ),
            )
        perceiver_block["latent_transformer"] = latent_transformer

        return perceiver_block

    @staticmethod
    def _extract_attn_weights(cross_attns, latent_transformer) -> List:
        attention_weights = []
        for cross_attn in cross_attns:
            attention_weights.append(cross_attn.attn.attn_weights)
        for latent_block in latent_transformer:
            attention_weights.append(latent_block.attn.attn_weights)
        return attention_weights