import torch
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.embeddings_layers import (
    SameSizeCatAndContEmbeddings,
)
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.transformers._encoders import (
    TransformerEncoder,
)


class TabTransformer(nn.Module):
    r"""Defines a ``TabTransformer`` model
    (`arXiv:2012.06678 <https://arxiv.org/abs/2012.06678>`_) that can be used
    as the ``deeptabular`` component of a Wide & Deep model.

    Note that this is an enhanced adaptation of the model described in the
    original publication, containing a series of additional features.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List
        List of Tuples with the column name and number of unique values
        e.g. [('education', 11), ...]
    cat_embed_dropout: float, default = 0.1
        Dropout to be applied to the embeddings matrix
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in Appendix A of the paper:
        `'The goal of having column embedding is to enable the model to distinguish the
        classes in one column from those in the other columns'`. In other words, the idea
        is to let the model learn which column is embedded at any given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings to the
        column embeddings or 2) replace the first ``frac_shared_embed`` with the shared
        embeddings. See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    embed_continuous: bool, default = False
        Boolean indicating if the continuous features will be "embedded". See
        ``pytorch_widedeep.models.transformers._layers.ContinuousEmbeddings``.
        Note that setting this to ``True`` is similar (but not identical) to the
        so-called `FT-Transformer <https://arxiv.org/abs/2106.11959>`_
        (Feature Tokenizer + Transformer).
        See :obj:`pytorch_widedeep.models.transformers.ft_transformer.FTTransformer`
        for details on the dedicated implementation available in this
        library.
    embed_continuous_activation: str, default = None
        String indicating the activation function to be applied to the
        continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported.
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features before
        they are passed to the network. Options are: ``layernorm``,
        ``batchnorm`` or ``None``.
    input_dim: int, default = 32
        The so-called *dimension of the model*. In general, this is the size of
        the embeddings used to encode the categorical and/or continuous columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers.
    n_blocks: int, default = 4
        Number of Transformer blocks
    attn_dropout: float, default = 0.2
        Dropout that will be applied to the Multi-Head Attention layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. ``tanh``, ``relu``,
        ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense layer. If
        ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP -> LIN ->
        ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    transformer_blks: ``nn.Sequential``
        Sequence of Transformer blocks
    transformer_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the ``WideDeep`` class

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabTransformer
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabTransformer(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        embed_continuous: bool = False,
        embed_continuous_activation: str = None,
        cont_embed_dropout: float = 0.0,
        cont_embed_activation: str = None,
        cont_norm_layer: str = None,
        input_dim: int = 32,
        n_heads: int = 8,
        use_bias: bool = False,
        n_blocks: int = 4,
        attn_dropout: float = 0.2,
        ff_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(TabTransformer, self).__init__()

        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.cat_embed_dropout = cat_embed_dropout
        self.full_embed_dropout = full_embed_dropout
        self.shared_embed = shared_embed
        self.add_shared_embed = add_shared_embed
        self.frac_shared_embed = frac_shared_embed

        self.continuous_cols = continuous_cols
        self.embed_continuous = embed_continuous
        self.embed_continuous_activation = embed_continuous_activation
        self.cont_embed_dropout = cont_embed_dropout
        self.cont_embed_activation = cont_embed_activation
        self.cont_norm_layer = cont_norm_layer

        self.input_dim = input_dim
        self.n_heads = n_heads
        self.use_bias = use_bias
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

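        # if the preprocessor prepended a special 'cls_token' column, its
        # output embedding will be used as the aggregate representation of the
        # row (see forward)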
        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0

        if self.n_cont and not self.n_cat and not self.embed_continuous:
            raise ValueError(
                "If only continuous features are used 'embed_continuous' must be set to 'True'"
            )

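        # categorical and (optionally) continuous cols are embedded to vectors
        # of the same size ('input_dim') so that they can be processed jointly
        # by the transformer blocks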
        self.cat_and_cont_embed = SameSizeCatAndContEmbeddings(
            input_dim,
            column_idx,
            cat_embed_input,
            cat_embed_dropout,
            full_embed_dropout,
            shared_embed,
            add_shared_embed,
            frac_shared_embed,
            False,  # use_embed_bias
            continuous_cols,
            embed_continuous,
            cont_embed_dropout,
            embed_continuous_activation,
            True,  # use_cont_bias
            cont_norm_layer,
        )

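        # stack of 'n_blocks' transformer encoder blocks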
        self.transformer_blks = nn.Sequential()
        for i in range(n_blocks):
            self.transformer_blks.add_module(
                "transformer_block" + str(i),
                TransformerEncoder(
                    input_dim,
                    n_heads,
                    use_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                ),
            )

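        # MLP head. Its input dimension depends on whether a [CLS] token is
        # used and on whether the continuous cols are embedded (see
        # '_compute_attn_output_dim'). If no hidden dims are passed, they
        # default to [l, 4*l, 2*l], where l is the attention output dimension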
        attn_output_dim = self._compute_attn_output_dim()
        if not mlp_hidden_dims:
            mlp_hidden_dims = [
                attn_output_dim,
                attn_output_dim * 4,
                attn_output_dim * 2,
            ]
        else:
            mlp_hidden_dims = [attn_output_dim] + mlp_hidden_dims

        self.transformer_mlp = MLP(
            mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim = mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:

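        # x_cat: embedded categorical cols; x_cont: embedded continuous cols if
        # 'embed_continuous' is True, otherwise the (possibly normalised) raw
        # continuous cols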
        x_cat, x_cont = self.cat_and_cont_embed(X)

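        # the transformer input contains the categorical embeddings plus, if
        # 'embed_continuous' is True, the continuous embeddings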
        if x_cat is not None:
            x = x_cat
        if x_cont is not None and self.embed_continuous:
            x = torch.cat([x, x_cont], 1) if x_cat is not None else x_cont

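        # run the sequence of column embeddings through the stack of
        # transformer blocks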
        x = self.transformer_blks(x)

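        # if a [CLS] token is present keep only its output embedding, otherwise
        # flatten the embeddings of all columns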
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)

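        # continuous cols that were not embedded bypass the transformer and are
        # concatenated here, right before the MLP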
        if x_cont is not None and not self.embed_continuous:
            x = torch.cat([x, x_cont], 1)

        return self.transformer_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights

        The shape of the attention weights is:

        :math:`(N, H, F, F)`

        Where *N* is the batch size, *H* is the number of attention heads
        and *F* is the number of features/columns in the dataset
        """
        return [blk.attn.attn_weights for blk in self.transformer_blks]

    def _compute_attn_output_dim(self) -> int:

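        # the MLP input dimension depends on whether a [CLS] token is used
        # (only its 'input_dim'-sized vector is kept) and on whether the
        # continuous cols are embedded (and therefore go through the attention
        # blocks) or are concatenated 'raw' after them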
        if self.with_cls_token:
            if self.embed_continuous:
                attn_output_dim = self.input_dim
            else:
                attn_output_dim = self.input_dim + self.n_cont
        elif self.embed_continuous:
            attn_output_dim = (self.n_cat + self.n_cont) * self.input_dim
        else:
            attn_output_dim = self.n_cat * self.input_dim + self.n_cont

        return attn_output_dim