from torch import nn

from pytorch_widedeep.wdtypes import Dict, List, Tuple, Tensor, Optional
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import SaintEncoder


class SAINT(BaseTabularModelWithAttention):
    r"""Defines a [SAINT model](https://arxiv.org/abs/2106.01342) that
    can be used as the `deeptabular` component of a Wide & Deep model or
    independently by itself.

    :information_source: **NOTE**: This is a slightly modified and enhanced
     version of the model described in the paper.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        _{'education': 0, 'relationship': 1, 'workclass': 2, ...}_
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical column, e.g. _[(education, 11), ...]_
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False,
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None,
        Activation function for the categorical embeddings, if any. _'tanh'_,
        _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        `pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If `full_embed_dropout = True`, `cat_embed_dropout` is ignored.
    shared_embed: bool, default = False
        The idea behind `shared_embed` is described in the Appendix A in the
        [TabTransformer paper](https://arxiv.org/abs/2012.06678): the
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns. In other
        words, the idea is to let the model learn which column is embedded
        at any given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first
        `frac_shared_embed` with the shared embeddings.
        See `pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if `add_shared_embed
        = False`) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features. Options
        are: _'layernorm'_, _'batchnorm'_ or None.
    cont_embed_dropout: float, default = 0.1,
        Continuous embeddings dropout
    use_cont_bias: bool, default = True,
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, default = None
        Activation function to be applied to the continuous embeddings, if
        any. _'tanh'_, _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. It is the number of
        embeddings used to encode the categorical and/or continuous columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_qkv_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers
    n_blocks: int, default = 2
        Number of SAINT-Transformer blocks.
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention column and
        row layers
    ff_dropout: float, default = 0.2
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. _'tanh'_, _'relu'_,
        _'leaky_relu'_, _'gelu'_, _'geglu'_ and _'reglu'_ are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided, no MLP will be used on top of
        the stack of SAINT-Transformer blocks
    mlp_activation: str, default = "relu"
        MLP activation function. _'tanh'_, _'relu'_, _'leaky_relu'_ and
        _'gelu'_ are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If `True: [LIN -> ACT -> BN -> DP]`. If `False: [BN -> DP ->
        LIN -> ACT]`

    Attributes
    ----------
    cat_and_cont_embed: nn.Module
        This is the module that processes the categorical and continuous columns
    encoder: nn.Sequential
        Sequence of SAINT-Transformer blocks
    mlp: nn.Module
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the `WideDeep` class

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import SAINT
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = SAINT(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
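
    The attention weights of each SAINT block can be inspected after a
    forward pass. A minimal sketch, reusing `model` and `X_tab` from above
    (each element of `attention_weights` is a tuple with the column and row
    attention weights of one block; see the `attention_weights` property):

    >>> _ = model(X_tab)
    >>> col_attn_weights, row_attn_weights = model.attention_weights[0]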
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: str = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        use_qkv_bias: bool = False,
        n_heads: int = 8,
        n_blocks: int = 2,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.2,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(SAINT, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.use_qkv_bias = use_qkv_bias
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

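        # a "cls_token" entry in column_idx signals that a special [CLS]-style
        # token column is present; in that case only its embedding is used as
        # the final representation (see forward)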
        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
        self.n_feats = self.n_cat + self.n_cont

        # Embeddings are instantiated at the base model
        # Transformer blocks
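        # each SaintEncoder block applies self-attention over the columns
        # followed by intersample (row) attention, hence the number of
        # features is passed to each block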
        self.encoder = nn.Sequential()
        for i in range(n_blocks):
            self.encoder.add_module(
                "saint_block" + str(i),
                SaintEncoder(
                    input_dim,
                    n_heads,
                    use_qkv_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                    self.n_feats,
                ),
            )

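        # the MLP (if any) receives either the [CLS] token embedding or the
        # concatenation of all the feature embeddings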
        self.mlp_first_hidden_dim = (
            self.input_dim if self.with_cls_token else (self.n_feats * self.input_dim)
        )

        if mlp_hidden_dims is not None:
            self.mlp = MLP(
                [self.mlp_first_hidden_dim] + mlp_hidden_dims,
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
        else:
            self.mlp = None

    def forward(self, X: Tensor) -> Tensor:
        x = self._get_embeddings(X)
        x = self.encoder(x)
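        # either take the [CLS] token embedding or flatten all the feature
        # embeddings before the (optional) MLP head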
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)
        if self.mlp is not None:
            x = self.mlp(x)
        return x

    @property
    def output_dim(self) -> int:
        return (
            self.mlp_hidden_dims[-1]
            if self.mlp_hidden_dims is not None
            else self.mlp_first_hidden_dim
        )

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. Each element of the list is a tuple
        where the first and the second elements are the column and row
        attention weights respectively

        The shape of the attention weights is:

        - column attention: $(N, H, F, F)$

        - row attention: $(1, H, N, N)$

        where $N$ is the batch size, $H$ is the number of heads and $F$ is the
        number of features/columns in the dataset
        """
        attention_weights = []
        for blk in self.encoder:
            attention_weights.append(
                (blk.col_attn.attn_weights, blk.row_attn.attn_weights)
            )
        return attention_weights