import torch
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import (
    TransformerEncoder,
)


class TabTransformer(BaseTabularModelWithAttention):
    r"""Defines a ``TabTransformer`` model
    (`arXiv:2012.06678 <https://arxiv.org/abs/2012.06678>`_) that can be used
    as the ``deeptabular`` component of a Wide & Deep model.

    Note that this is an enhanced adaptation of the model described in the
    original publication, containing a series of additional features.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the ``TabTransformer`` model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name, number of unique values and
        embedding dimension. e.g. [(education, 11, 32), ...]
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False,
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: str, Optional, default = None,
        Activation function for the categorical embeddings
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in the Appendix A in the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is being
        embedded at any given time.
    add_shared_embed: bool, default = False,
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) to replace the first
        ``frac_shared_embed`` with the shared embeddings.
        See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, Optional, default = None
        Type of normalization layer applied to the continuous features. Options
        are: 'layernorm', 'batchnorm' or None.
    embed_continuous: bool, default = False
        Boolean indicating if the continuous features will be "embedded". See
        ``pytorch_widedeep.models.transformers._layers.ContinuousEmbeddings``.
        Note that setting this to ``True`` is similar (but not identical) to the
        so-called `FT-Transformer <https://arxiv.org/abs/2106.11959>`_
        (Feature Tokenizer + Transformer).
        See :obj:`pytorch_widedeep.models.transformers.ft_transformer.FTTransformer`
        for details on the dedicated implementation available in this
        library.
    cont_embed_dropout: float, default = 0.1,
        Continuous embeddings dropout
    use_cont_bias: bool, default = True,
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, Optional, default = None
        String indicating the activation function to be applied to the
        continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. In general, this is the
        dimension of the embeddings used to encode the categorical and/or
        continuous columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers.
    n_blocks: int, default = 4
        Number of Transformer blocks
    attn_dropout: float, default = 0.2
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. ``tanh``, ``relu``,
        ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
        LIN -> ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    transformer_blks: ``nn.Sequential``
        Sequence of Transformer blocks
    transformer_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the ``WideDeep`` class

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabTransformer
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabTransformer(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
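    >>> # the continuous column(s) can also be embedded and sent through the
    >>> # attention blocks (an FT-Transformer-like set up, see the
    >>> # ``embed_continuous`` parameter above)
    >>> ft_like = TabTransformer(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols, embed_continuous=True)
    >>> out = ft_like(X_tab)
    >>> # one set of attention weights is stored per Transformer block
    >>> len(ft_like.attention_weights) == ft_like.n_blocks
    True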
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        embed_continuous: bool = False,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        n_heads: int = 8,
        use_bias: bool = False,
        n_blocks: int = 4,
        attn_dropout: float = 0.2,
        ff_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(TabTransformer, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=embed_continuous,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.n_heads = n_heads
        self.use_bias = use_bias
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

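        # if a 'cls_token' column is present in column_idx (e.g. prepended by
        # the preprocessing step) its representation, assumed to be the first
        # token, will be used as the aggregate output of the attention blocks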
        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0

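        # continuous-only inputs must be embedded, otherwise there would be
        # nothing to pass through the attention blocks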
        if self.n_cont and not self.n_cat and not self.embed_continuous:
            raise ValueError(
                "If only continuous features are used 'embed_continuous' must be set to 'True'"
            )

        # Embeddings are instantiated in the base model
        # Transformer blocks
        self.transformer_blks = nn.Sequential()
        for i in range(n_blocks):
            self.transformer_blks.add_module(
                "transformer_block" + str(i),
                TransformerEncoder(
                    input_dim,
                    n_heads,
                    use_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                ),
            )

        # Mlp
        attn_output_dim = self._compute_attn_output_dim()
        if not mlp_hidden_dims:
            mlp_hidden_dims = [
                attn_output_dim,
                attn_output_dim * 4,
                attn_output_dim * 2,
            ]
        else:
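            # prepend the attention output size so that the first dense layer
            # of the MLP matches the (flattened) output of the Transformer
            # blocks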
            mlp_hidden_dims = [attn_output_dim] + mlp_hidden_dims

        self.transformer_mlp = MLP(
            mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim = mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:

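        # when the continuous features are not embedded, only the categorical
        # embeddings go through the attention blocks; the continuous features
        # (if any) are concatenated later, right before the final MLP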
        if not self.embed_continuous:
            x_cat, x_cont = self.cat_and_cont_embed(X)
            if x_cat is not None:
                x = (
                    self.cat_embed_act_fn(x_cat)
                    if self.cat_embed_act_fn is not None
                    else x_cat
                )
        else:
            x = self._get_embeddings(X)
            x_cont = None

        x = self.transformer_blks(x)
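        # take either the [CLS] token representation or the flattened output
        # of all the attention blocks as input to the MLP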
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)

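        # append the continuous features that bypassed the attention blocks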
        if x_cont is not None and not self.embed_continuous:
            x = torch.cat([x, x_cont], 1)

        return self.transformer_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights

        The shape of the attention weights is:

        :math:`(N, H, F, F)`

        Where *N* is the batch size, *H* is the number of attention heads
        and *F* is the number of features/columns that are passed through
        the attention blocks
        """
        return [blk.attn.attn_weights for blk in self.transformer_blks]

    def _compute_attn_output_dim(self) -> int:

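        # the MLP input size depends on whether a [CLS] token is used and on
        # whether the continuous features are embedded (and therefore go
        # through the attention blocks) or are simply concatenated at the end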
        if self.with_cls_token:
            if self.embed_continuous:
                attn_output_dim = self.input_dim
            else:
                attn_output_dim = self.input_dim + self.n_cont
        elif self.embed_continuous:
            attn_output_dim = (self.n_cat + self.n_cont) * self.input_dim
        else:
            attn_output_dim = self.n_cat * self.input_dim + self.n_cont

        return attn_output_dim