from torch import nn

from pytorch_widedeep.wdtypes import Dict, List, Tuple, Tensor, Optional
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import SaintEncoder


class SAINT(BaseTabularModelWithAttention):
    r"""Defines a [SAINT model](https://arxiv.org/abs/2106.01342) that
    can be used as the `deeptabular` component of a Wide & Deep model or
    independently by itself.

    :information_source: **NOTE**: This is a slightly modified and enhanced
     version of the model described in the paper.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        _{'education': 0, 'relationship': 1, 'workclass': 2, ...}_
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical column. e.g. _[(education, 11), ...]_
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None
        Activation function for the categorical embeddings, if any. _'tanh'_,
        _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        `pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If `full_embed_dropout = True`, `cat_embed_dropout` is ignored.
    shared_embed: bool, default = False
        The idea behind `shared_embed` is described in the Appendix A in the
        [TabTransformer paper](https://arxiv.org/abs/2012.06678): the
        goal of having column embeddings is to enable the model to distinguish
        the classes in one column from those in the other columns. In other
        words, the idea is to let the model learn which column is being
        embedded at a given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first
        `frac_shared_embed` fraction of the embeddings with the shared embeddings.
        See `pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if `add_shared_embed
        = False`) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features. Options
        are: _'layernorm'_, _'batchnorm'_ or None.
    cont_embed_dropout: float, default = 0.1
        Continuous embeddings dropout
    use_cont_bias: bool, default = True
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, default = None
        Activation function to be applied to the continuous embeddings, if
        any. _'tanh'_, _'relu'_, _'leaky_relu'_ and _'gelu'_ are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. It is the dimension of the
        embeddings used to encode the categorical and/or continuous columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_qkv_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers
    n_blocks: int, default = 2
        Number of SAINT-Transformer blocks.
    attn_dropout: float, default = 0.2
        Dropout that will be applied to the Multi-Head Attention column and
        row layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    ff_factor: int, default = 4
        Multiplicative factor applied to the first layer of the FF network in
        each Transformer block. This is normally set to 4.
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. _'tanh'_, _'relu'_,
        _'leaky_relu'_, _'gelu'_, _'geglu'_ and _'reglu'_ are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided, no MLP will be used on top of
        the SAINT-Transformer blocks
    mlp_activation: str, default = "relu"
        MLP activation function. _'tanh'_, _'relu'_, _'leaky_relu'_ and
        _'gelu'_ are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If `True: [LIN -> ACT -> BN -> DP]`. If `False: [BN -> DP ->
        LIN -> ACT]`

    Attributes
    ----------
    cat_and_cont_embed: nn.Module
        This is the module that processes the categorical and continuous columns
    encoder: nn.Module
        Sequence of SAINT-Transformer blocks
    mlp: nn.Module
        MLP component in the model

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import SAINT
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = SAINT(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
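    >>> # with the settings above there is no 'cls_token' column and no MLP head,
    >>> # so `out` is simply the flattened encoder output, i.e. a tensor of
    >>> # shape (batch_size, n_features * input_dim)
    >>> out.shape
    torch.Size([5, 160])
    >>> # per-block (column, row) attention weights are also accessible
    >>> col_attn, row_attn = model.attention_weights[0]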
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        use_qkv_bias: bool = False,
        n_heads: int = 8,
        n_blocks: int = 2,
        attn_dropout: float = 0.2,
        ff_dropout: float = 0.1,
        ff_factor: int = 4,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
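        # `embed_continuous=True`: as in the SAINT paper, the continuous
        # columns are also embedded (projected to `input_dim`) so that they
        # can be attended to together with the categorical columns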
        super(SAINT, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.use_qkv_bias = use_qkv_bias
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.ff_factor = ff_factor
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        self.with_cls_token = "cls_token" in column_idx
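        # when the preprocessor adds a 'cls_token' column, only its final
        # representation is passed to the (optional) MLP head; otherwise the
        # embeddings of all features are flattened (see `forward`)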
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
        self.n_feats = self.n_cat + self.n_cont

        # Embeddings are instantiated at the base model
        # Transformer blocks
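        # each SaintEncoder block applies column (between-features) self-attention
        # followed by row (between-samples) attention across the batch, which is
        # why the number of features (`self.n_feats`) is passed to each block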
        self.encoder = nn.Sequential()
        for i in range(n_blocks):
            self.encoder.add_module(
                "saint_block" + str(i),
                SaintEncoder(
                    input_dim,
                    n_heads,
                    use_qkv_bias,
                    attn_dropout,
                    ff_dropout,
                    ff_factor,
                    transformer_activation,
                    self.n_feats,
                ),
            )

        self.mlp_first_hidden_dim = (
            self.input_dim if self.with_cls_token else (self.n_feats * self.input_dim)
        )

        if mlp_hidden_dims is not None:
            self.mlp = MLP(
                [self.mlp_first_hidden_dim] + mlp_hidden_dims,
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
        else:
            self.mlp = None

    def forward(self, X: Tensor) -> Tensor:
        # (batch_size, n_features, input_dim) after embedding the categorical
        # and continuous columns
        x = self._get_embeddings(X)
        x = self.encoder(x)
        if self.with_cls_token:
            # keep only the 'cls_token' representation
            x = x[:, 0, :]
        else:
            # flatten to (batch_size, n_features * input_dim)
            x = x.flatten(1)
        if self.mlp is not None:
            x = self.mlp(x)
        return x

    @property
    def output_dim(self) -> int:
        r"""The output dimension of the model. This is a required property
        necessary to build the `WideDeep` class
        """
        return (
            self.mlp_hidden_dims[-1]
            if self.mlp_hidden_dims is not None
            else self.mlp_first_hidden_dim
        )

    @property
    def attention_weights(self) -> List[Tuple[Tensor, Tensor]]:
        r"""List with the attention weights. Each element of the list is a tuple
        where the first and the second elements are the column and row
        attention weights, respectively.

        The shape of the attention weights is:

        - column attention: $(N, H, F, F)$

        - row attention: $(1, H, N, N)$

        where $N$ is the batch size, $H$ is the number of heads and $F$ is the
        number of features/columns in the dataset
        """
        # one (column_attention, row_attention) tuple per SAINT block
        attention_weights = []
        for blk in self.encoder:
            attention_weights.append(
                (blk.col_attn.attn_weights, blk.row_attn.attn_weights)
            )
        return attention_weights