from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import SaintEncoder


class SAINT(BaseTabularModelWithAttention):
    r"""Defines a `SAINT model <https://arxiv.org/abs/2106.01342>`_ that
    can be used as the ``deeptabular`` component of a Wide & Deep model or
    independently by itself.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical column. e.g. [(education, 11), ...]
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False,
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None,
        Activation function for the categorical embeddings, if any. `'tanh'`,
        `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in the Appendix A in the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is embedded
        at any given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first ``frac_shared_embed``
        fraction of the embedding dimensions with the shared embeddings.
        See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features. Options
        are: 'layernorm', 'batchnorm' or None.
    cont_embed_dropout: float, default = 0.1
        Continuous embeddings dropout
    use_cont_bias: bool, default = True
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, default = None
        Activation function to be applied to the continuous embeddings, if
        any. `'tanh'`, `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. In general, this is the size
        of the embeddings used to encode the categorical and/or continuous
        columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_qkv_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers
    n_blocks: int, default = 2
        Number of SAINT-Transformer blocks. 1 in the paper.
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention column and
        row layers
    ff_dropout: float, default = 0.2
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. `'tanh'`, `'relu'`,
        `'leaky_relu'`, `'gelu'`, `'geglu'` and `'reglu'` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP's input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. `'tanh'`, `'relu'`, `'leaky_relu'` and
        `'gelu'` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
        LIN -> ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    saint_blks: ``nn.Sequential``
        Sequence of SAINT-Transformer blocks
    saint_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the ``WideDeep`` class

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import SAINT
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = SAINT(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
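
    Since there is no ``'cls_token'`` column here, the attention output is
    flattened and passed through the MLP, so ``out`` has shape
    ``(batch_size, output_dim)``. A quick, illustrative check using only the
    names defined above:

    >>> out.shape == (X_tab.shape[0], model.output_dim)
    True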
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: str = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        use_qkv_bias: bool = False,
        n_heads: int = 8,
        n_blocks: int = 2,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.2,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(SAINT, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.cat_embed_dropout = cat_embed_dropout
        self.full_embed_dropout = full_embed_dropout
        self.shared_embed = shared_embed
        self.add_shared_embed = add_shared_embed
        self.frac_shared_embed = frac_shared_embed

        self.continuous_cols = continuous_cols
        self.cont_embed_activation = cont_embed_activation
        self.cont_embed_dropout = cont_embed_dropout
        self.cont_norm_layer = cont_norm_layer

        self.input_dim = input_dim
        self.use_qkv_bias = use_qkv_bias
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
        self.n_feats = self.n_cat + self.n_cont

        # Embeddings are instantiated in the base model
        # Transformer blocks
        self.saint_blks = nn.Sequential()
        for i in range(n_blocks):
            self.saint_blks.add_module(
                "saint_block" + str(i),
                SaintEncoder(
                    input_dim,
                    n_heads,
                    use_qkv_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                    self.n_feats,
                ),
            )

        attn_output_dim = (
            self.input_dim if self.with_cls_token else self.n_feats * self.input_dim
        )

        # Mlp
        if not mlp_hidden_dims:
            mlp_hidden_dims = [
                attn_output_dim,
                attn_output_dim * 4,
                attn_output_dim * 2,
            ]
        else:
            mlp_hidden_dims = [attn_output_dim] + mlp_hidden_dims

        self.saint_mlp = MLP(
            mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim: int = mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:
        x = self._get_embeddings(X)
        x = self.saint_blks(x)
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)
        return self.saint_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. Each element of the list is a tuple
        where the first and the second elements are the column and row
        attention weights respectively

        The shape of the attention weights is:

            - column attention: :math:`(N, H, F, F)`

            - row attention: :math:`(1, H, N, N)`

        where *N* is the batch size, *H* is the number of heads and *F* is the
        number of features/columns in the dataset
        """
        attention_weights = []
        for blk in self.saint_blks:
            attention_weights.append(
                (blk.col_attn.attn_weights, blk.row_attn.attn_weights)
            )
        return attention_weights