from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import SaintEncoder


class SAINT(BaseTabularModelWithAttention):
    r"""Defines a ``SAINT`` model
    (`arXiv:2106.01342 <https://arxiv.org/abs/2106.01342>`_) that can be used
    as the ``deeptabular`` component of a Wide & Deep model.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the ``SAINT`` model. Required to slice the tensors. e.g. {'education':
        0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical column. e.g. [(education, 11), ...]
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None
        Activation function for the categorical embeddings
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in Appendix A of the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is embedded
        at any given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first ``frac_shared_embed``
        fraction of each embedding with the shared embeddings.
        See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features. Options
        are: 'layernorm', 'batchnorm' or None.
    cont_embed_dropout: float, default = 0.1
        Continuous embeddings dropout
    use_cont_bias: bool, default = True
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, default = None
        String indicating the activation function to be applied to the
        continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. In general, this is the size
        of the embeddings used to encode the categorical and/or continuous
        columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers
    n_blocks: int, default = 2
        Number of SAINT-Transformer blocks. 1 in the paper.
    attn_dropout: float, default = 0.2
        Dropout that will be applied to the Multi-Head Attention column and
        row layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. ``tanh``, ``relu``,
        ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
        LIN -> ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    saint_blks: ``nn.Sequential``
        Sequence of SAINT-Transformer blocks
    saint_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the WideDeep class

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import SAINT
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = SAINT(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
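
    A minimal, illustrative follow-up (assuming the forward pass above has
    been run): the column and row attention weights of each SAINT block can
    be retrieved via the ``attention_weights`` property

    >>> col_attn_weights, row_attn_weights = model.attention_weights[0]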
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        use_bias: bool = False,
        n_heads: int = 8,
        n_blocks: int = 2,
        attn_dropout: float = 0.2,
        ff_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(SAINT, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.cat_embed_dropout = cat_embed_dropout
        self.full_embed_dropout = full_embed_dropout
        self.shared_embed = shared_embed
        self.add_shared_embed = add_shared_embed
        self.frac_shared_embed = frac_shared_embed

        self.continuous_cols = continuous_cols
        self.cont_embed_activation = cont_embed_activation
        self.cont_embed_dropout = cont_embed_dropout
        self.cont_norm_layer = cont_norm_layer

        self.input_dim = input_dim
        self.use_bias = use_bias
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
        self.n_feats = self.n_cat + self.n_cont

        # Embeddings are instantiated in the base model
        # Transformer blocks
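        # each SaintEncoder block applies self-attention over the columns of
        # each sample followed by intersample (row) attention over the batch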
        self.saint_blks = nn.Sequential()
        for i in range(n_blocks):
            self.saint_blks.add_module(
                "saint_block" + str(i),
                SaintEncoder(
                    input_dim,
                    n_heads,
                    use_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                    self.n_feats,
                ),
            )

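        # if a 'cls_token' column is present only its embedding is fed to the
        # MLP, otherwise the embeddings of all features are flattened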
        attn_output_dim = (
            self.input_dim if self.with_cls_token else self.n_feats * self.input_dim
        )

        # MLP
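        # if no MLP hidden dims are passed, default to [l, 4*l, 2*l] with
        # l = attn_output_dim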
        if not mlp_hidden_dims:
            mlp_hidden_dims = [
                attn_output_dim,
                attn_output_dim * 4,
                attn_output_dim * 2,
            ]
        else:
            mlp_hidden_dims = [attn_output_dim] + mlp_hidden_dims

        self.saint_mlp = MLP(
            mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim = mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:
        x = self._get_embeddings(X)
        x = self.saint_blks(x)
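        # pool with the [CLS] token representation if present, otherwise
        # flatten the per-feature representations before the MLP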
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)
        return self.saint_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. Each element of the list is a tuple
        where the first and the second elements are the column and row
        attention weights respectively

        The shape of the attention weights is:

            - column attention: :math:`(N, H, F, F)`

            - row attention: :math:`(1, H, N, N)`

        where *N* is the batch size, *H* is the number of heads and *F* is the
        number of features/columns in the dataset
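
        For instance, with the model built in the class ``Example`` (5 rows,
        8 heads, 5 features and no ``cls_token``) one would expect, per the
        shapes above:

        >>> # col_attn_weights.shape -> torch.Size([5, 8, 5, 5])
        >>> # row_attn_weights.shape -> torch.Size([1, 8, 5, 5])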
        """
        attention_weights = []
        for blk in self.saint_blks:
            attention_weights.append(
                (blk.col_attn.attn_weights, blk.row_attn.attn_weights)
            )
        return attention_weights