import torch
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import (
    TransformerEncoder,
)


class TabTransformer(BaseTabularModelWithAttention):
    r"""Defines a `TabTransformer model <https://arxiv.org/abs/2012.06678>`_ that
    can be used as the ``deeptabular`` component of a Wide & Deep model or
    independently by itself.

    Note that this is an enhanced adaptation of the model described in the
    paper, containing a series of additional features.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical component e.g. [(education, 11), ...]
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: str, Optional, default = None
        Activation function for the categorical embeddings, if any. `'tanh'`,
        `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in Appendix A of the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is embedded
        at a given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first ``frac_shared_embed``
        fraction of the embedding with the shared embeddings.
        See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, Optional, default = None
        Type of normalization layer applied to the continuous features. Options
        are: 'layernorm', 'batchnorm' or None.
    embed_continuous: bool, default = False
        Boolean indicating if the continuous columns will be embedded
        (i.e. passed each through a linear layer with or without activation)
    cont_embed_dropout: float, default = 0.1
        Continuous embeddings dropout
    use_cont_bias: bool, default = True
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, Optional, default = None
        Activation function to be applied to the continuous embeddings, if
        any. `'tanh'`, `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. In general, this is the size
        of the embeddings used to encode the categorical and/or continuous
        columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_qkv_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers.
    n_blocks: int, default = 4
        Number of Transformer blocks
    attn_dropout: float, default = 0.2
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. `'tanh'`, `'relu'`,
        `'leaky_relu'`, `'gelu'`, `'geglu'` and `'reglu'` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP's input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. `'tanh'`, `'relu'`, `'leaky_relu'` and
        `'gelu'` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
        LIN -> ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    transformer_blks: ``nn.Sequential``
        Sequence of Transformer blocks
    transformer_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the ``WideDeep`` class

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabTransformer
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabTransformer(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        embed_continuous: bool = False,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        n_heads: int = 8,
        use_qkv_bias: bool = False,
        n_blocks: int = 4,
        attn_dropout: float = 0.2,
        ff_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(TabTransformer, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=embed_continuous,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.n_heads = n_heads
        self.use_qkv_bias = use_qkv_bias
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

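        # a [CLS] token, if used, is expected as a column named 'cls_token' in
        # column_idx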
        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0

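        # with no categorical columns and no continuous embeddings there would
        # be nothing to pass through the transformer blocks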
        if self.n_cont and not self.n_cat and not self.embed_continuous:
            raise ValueError(
                "If only continuous features are used 'embed_continuous' must be set to 'True'"
            )

        # Embeddings are instantiated in the base model
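        # (in the original paper only the categorical embeddings go through the
        # transformer blocks; embedding the continuous cols via
        # 'embed_continuous' is one of the additional features)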
        # Transformer blocks
        self.transformer_blks = nn.Sequential()
        for i in range(n_blocks):
            self.transformer_blks.add_module(
                "transformer_block" + str(i),
                TransformerEncoder(
                    input_dim,
                    n_heads,
                    use_qkv_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                ),
            )

        # MLP
        attn_output_dim = self._compute_attn_output_dim()
        if not mlp_hidden_dims:
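            # default to the [l, 4*l, 2*l] structure mentioned in the
            # docstring, with l = attn_output_dim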
            mlp_hidden_dims = [
                attn_output_dim,
                attn_output_dim * 4,
                attn_output_dim * 2,
            ]
        else:
            mlp_hidden_dims = [attn_output_dim] + mlp_hidden_dims

        self.transformer_mlp = MLP(
            mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim: int = mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:

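        # if the continuous cols are not embedded, only the categorical
        # embeddings go through the transformer blocks and the continuous cols
        # are concatenated to the attention output afterwards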
        if not self.embed_continuous:
            x_cat, x_cont = self.cat_and_cont_embed(X)
            if x_cat is not None:
                x = (
                    self.cat_embed_act_fn(x_cat)
                    if self.cat_embed_act_fn is not None
                    else x_cat
                )
        else:
            x = self._get_embeddings(X)
            x_cont = None

        x = self.transformer_blks(x)
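        # use the [CLS] token embedding as the aggregate representation or
        # flatten all the output embeddings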
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)

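        # concatenate the (non-embedded) continuous cols to the attention
        # output before the final MLP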
        if x_cont is not None and not self.embed_continuous:
            x = torch.cat([x, x_cont], 1)

        return self.transformer_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights per block

        The shape of the attention weights is:

        :math:`(N, H, F, F)`

        Where *N* is the batch size, *H* is the number of attention heads
        and *F* is the number of features/columns in the dataset
        """
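        # each encoder block stores its attention weights once a forward pass
        # has been run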
        return [blk.attn.attn_weights for blk in self.transformer_blks]

    def _compute_attn_output_dim(self) -> int:

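        # the MLP input size depends on whether a [CLS] token is used and on
        # whether the continuous cols go through the attention blocks
        # (embed_continuous=True) or are concatenated to the attention output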
        if self.with_cls_token:
            if self.embed_continuous:
                attn_output_dim = self.input_dim
            else:
                attn_output_dim = self.input_dim + self.n_cont
        elif self.embed_continuous:
            attn_output_dim = (self.n_cat + self.n_cont) * self.input_dim
        else:
            attn_output_dim = self.n_cat * self.input_dim + self.n_cont

        return attn_output_dim