import torch
import einops
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.embeddings_layers import (
    SameSizeCatAndContEmbeddings,
)
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.transformers._encoders import (
    PerceiverEncoder,
)


class TabPerceiver(nn.Module):
    r"""Defines an adaptation of a ``Perceiver`` model
    (`arXiv:2103.03206 <https://arxiv.org/abs/2103.03206>`_) that can be used
    as the ``deeptabular`` component of a Wide & Deep model.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List
        List of Tuples with the column name and number of unique values
        e.g. [('education', 11), ...]
    cat_embed_dropout: float, default = 0.1
        Dropout to be applied to the embeddings matrix
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in Appendix A of the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is being
        embedded at each time.
    add_shared_embed: bool, default = False,
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first ``frac_shared_embed``
        with the shared embeddings. See
        :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    embed_continuous_activation: str, default = None
        String indicating the activation function to be applied to the
        continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported.
    cont_embed_dropout: float, default = 0.0,
        Dropout for the continuous embeddings
    cont_embed_activation: str, default = None,
        Activation function for the continuous embeddings
    cont_norm_layer: str, default = None,
        Type of normalization layer applied to the continuous features before
        they are embedded. Options are: ``layernorm``, ``batchnorm`` or
        ``None``.
    input_dim: int, default = 32
        The so-called *dimension of the model*. In general, this is the size
        of the embeddings used to encode the categorical and/or continuous
        columns.
    n_cross_attns: int, default = 1
        Number of times each perceiver block will cross attend to the input
        data (i.e. number of cross attention components per perceiver block).
        This should normally be 1. However, the paper describes some
        architectures (normally for computer vision-related problems) where
        the Perceiver attends multiple times to the input array. Therefore,
        multiple cross attentions to the input array might also be useful in
        some cases for tabular data.
    n_cross_attn_heads: int, default = 4
        Number of attention heads for the cross attention component
    n_latents: int, default = 16
        Number of latents. This is the *N* parameter in the paper. As
        indicated in the paper, this number should be significantly lower
        than *M* (the number of columns in the dataset). Setting *N* closer
        to *M* defies the main purpose of the Perceiver, which is to overcome
        the transformer quadratic bottleneck
    latent_dim: int, default = 128
        Latent dimension.
    n_latent_heads: int, default = 4
        Number of attention heads per Latent Transformer
    n_latent_blocks: int, default = 4
        Number of transformer encoder blocks (normalised MHA + normalised FF)
        per Latent Transformer
    n_perceiver_blocks: int, default = 4
        Number of Perceiver blocks defined as [Cross Attention + Latent
        Transformer]
    share_weights: bool, default = False
        Boolean indicating if the weights will be shared between Perceiver
        blocks
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "geglu"
        Transformer Encoder activation function. ``tanh``, ``relu``,
        ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
        LIN -> ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    perceiver_blks: ``nn.ModuleDict``
        ModuleDict with the Perceiver blocks
    latents: ``nn.Parameter``
        Latents that will be used for prediction
    perceiver_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        needed to build the ``WideDeep`` class

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabPerceiver
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabPerceiver(column_idx=column_idx, cat_embed_input=cat_embed_input,
    ... continuous_cols=continuous_cols, n_latents=2, latent_dim=16,
    ... n_perceiver_blocks=2)
    >>> out = model(X_tab)
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        embed_continuous_activation: str = None,
        cont_embed_dropout: float = 0.0,
        cont_embed_activation: str = None,
        cont_norm_layer: str = None,
        input_dim: int = 32,
        n_cross_attns: int = 1,
        n_cross_attn_heads: int = 4,
        n_latents: int = 16,
        latent_dim: int = 128,
        n_latent_heads: int = 4,
        n_latent_blocks: int = 4,
        n_perceiver_blocks: int = 4,
        share_weights: bool = False,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.1,
        transformer_activation: str = "geglu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(TabPerceiver, self).__init__()

        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.cat_embed_dropout = cat_embed_dropout
        self.full_embed_dropout = full_embed_dropout
        self.shared_embed = shared_embed
        self.add_shared_embed = add_shared_embed
        self.frac_shared_embed = frac_shared_embed

        self.continuous_cols = continuous_cols
        self.embed_continuous_activation = embed_continuous_activation
        self.cont_embed_dropout = cont_embed_dropout
        self.cont_embed_activation = cont_embed_activation
        self.cont_norm_layer = cont_norm_layer

        self.input_dim = input_dim
        self.n_cross_attns = n_cross_attns
        self.n_cross_attn_heads = n_cross_attn_heads
        self.n_latents = n_latents
        self.latent_dim = latent_dim
        self.n_latent_heads = n_latent_heads
        self.n_latent_blocks = n_latent_blocks
        self.n_perceiver_blocks = n_perceiver_blocks
        self.share_weights = share_weights
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        self.cat_and_cont_embed = SameSizeCatAndContEmbeddings(
            input_dim,
            column_idx,
            cat_embed_input,
            cat_embed_dropout,
            full_embed_dropout,
            shared_embed,
            add_shared_embed,
            frac_shared_embed,
            False,  # use_embed_bias
            continuous_cols,
            True,  # embed_continuous,
            cont_embed_dropout,
            embed_continuous_activation,
            True,  # use_cont_bias
            cont_norm_layer,
        )

        # learnable latent array of shape (n_latents, latent_dim), initialised
        # from a truncated normal; it is repeated along the batch dimension in
        # forward() and refined by the Perceiver blocks
        self.latents = nn.init.trunc_normal_(
            nn.Parameter(torch.empty(n_latents, latent_dim))
        )

        self.perceiver_blks = nn.ModuleDict()
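        # each Perceiver block (built in _build_perceiver_block) is a ModuleDict
        # holding a list of cross attention encoders ("cross_attns") and a stack
        # of latent transformer encoder blocks ("latent_transformer")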
        first_perceiver_block = self._build_perceiver_block()
        self.perceiver_blks["perceiver_block0"] = first_perceiver_block

        if share_weights:
            for n in range(1, n_perceiver_blocks):
                self.perceiver_blks["perceiver_block" + str(n)] = first_perceiver_block
        else:
            for n in range(1, n_perceiver_blocks):
                self.perceiver_blks[
                    "perceiver_block" + str(n)
                ] = self._build_perceiver_block()

        if not mlp_hidden_dims:
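            # e.g. [128, 512, 256] with the default latent_dim of 128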
            self.mlp_hidden_dims = [latent_dim, latent_dim * 4, latent_dim * 2]
        else:
            self.mlp_hidden_dims = [latent_dim] + mlp_hidden_dims

        self.perceiver_mlp = MLP(
            self.mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim = self.mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:

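        # categorical and continuous features are embedded to the same size
        # (input_dim) and concatenated along the column dimension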
        x_cat, x_cont = self.cat_and_cont_embed(X)
        if x_cat is not None:
            x_emb = x_cat
        if x_cont is not None:
            x_emb = torch.cat([x_emb, x_cont], 1) if x_cat is not None else x_cont

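        # repeat the learnable latents along the batch dimension:
        # (n_latents, latent_dim) -> (batch size, n_latents, latent_dim)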
        x = einops.repeat(self.latents, "n d -> b n d", b=X.shape[0])

        for n in range(self.n_perceiver_blocks):
            cross_attns = self.perceiver_blks["perceiver_block" + str(n)]["cross_attns"]
            latent_transformer = self.perceiver_blks["perceiver_block" + str(n)][
                "latent_transformer"
            ]
            for cross_attn in cross_attns:
                x = cross_attn(x, x_emb)
            x = latent_transformer(x)

        # average along the latent index axis
        x = x.mean(dim=1)

        return self.perceiver_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. If the weights are not shared
        between perceiver blocks, each element of the list will itself be a
        list containing the Cross Attention and Latent Transformer attention
        weights respectively.

        The shape of the attention weights is:

            - Cross Attention: :math:`(N, C, L, F)`
            - Latent Attention: :math:`(N, T, L, L)`

        where *N* is the batch size, *C* is the number of Cross Attention
        heads, *L* is the number of Latents, *F* is the number of
        features/columns in the dataset and *T* is the number of Latent
        Attention heads
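
        As a rough illustration (continuing the class-level example: 5 rows,
        5 columns, ``n_latents=2``, ``n_perceiver_blocks=2``, the default
        single cross attention and 4 attention heads per block, and
        ``share_weights=False``), after a forward pass one would expect
        something along these lines::

            attn = model.attention_weights
            # len(attn) == 2, i.e. one inner list per Perceiver block
            # attn[0][0]: Cross Attention weights with shape (5, 4, 2, 5)
            # attn[0][1:]: Latent Transformer weights, each with shape (5, 4, 2, 2)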
        """
        if self.share_weights:
            cross_attns = self.perceiver_blks["perceiver_block0"]["cross_attns"]
            latent_transformer = self.perceiver_blks["perceiver_block0"][
                "latent_transformer"
            ]
            attention_weights = self._extract_attn_weights(
                cross_attns, latent_transformer
            )
        else:
            attention_weights = []
            for n in range(self.n_perceiver_blocks):
                cross_attns = self.perceiver_blks["perceiver_block" + str(n)][
                    "cross_attns"
                ]
                latent_transformer = self.perceiver_blks["perceiver_block" + str(n)][
                    "latent_transformer"
                ]
                attention_weights.append(
                    self._extract_attn_weights(cross_attns, latent_transformer)
                )
        return attention_weights

    def _build_perceiver_block(self) -> nn.ModuleDict:

        perceiver_block = nn.ModuleDict()

        # Cross Attention
        cross_attns = nn.ModuleList()
        for _ in range(self.n_cross_attns):
            cross_attns.append(
                PerceiverEncoder(
                    self.input_dim,
                    self.n_cross_attn_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.transformer_activation,
                    self.latent_dim,  # q_dim,
                ),
            )
        perceiver_block["cross_attns"] = cross_attns

        # Latent Transformer
        latent_transformer = nn.Sequential()
        for i in range(self.n_latent_blocks):
            latent_transformer.add_module(
                "latent_block" + str(i),
                PerceiverEncoder(
                    self.latent_dim,  # input_dim
                    self.n_latent_heads,
                    False,  # use_bias
                    self.attn_dropout,
                    self.ff_dropout,
                    self.transformer_activation,
                ),
            )
        perceiver_block["latent_transformer"] = latent_transformer

        return perceiver_block

    @staticmethod
    def _extract_attn_weights(cross_attns, latent_transformer) -> List:
        attention_weights = []
        for cross_attn in cross_attns:
            attention_weights.append(cross_attn.attn.attn_weights)
        for latent_block in latent_transformer:
            attention_weights.append(latent_block.attn.attn_weights)
        return attention_weights