from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithAttention,
)
from pytorch_widedeep.models.tabular.transformers._encoders import SaintEncoder


class SAINT(BaseTabularModelWithAttention):
    r"""Defines a `SAINT model <https://arxiv.org/abs/2106.01342>`_ that
    can be used as the ``deeptabular`` component of a Wide & Deep model or
    independently by itself.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name and number of unique values for
        each categorical column. e.g. [(education, 11), ...]
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: str, Optional, default = None
        Activation function for the categorical embeddings, if any. `'tanh'`,
        `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in Appendix A of the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is being
        embedded at any given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first
        ``frac_shared_embed`` fraction of the embeddings with the shared
        embeddings.
        See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if ``add_shared_embed
        = False``) by all the different categories for one particular
        column.
    continuous_cols: List, Optional, default = None
        List with the names of the numeric (aka continuous) columns
    cont_norm_layer: str, Optional, default = None
        Type of normalization layer applied to the continuous features. Options
        are: 'layernorm', 'batchnorm' or None.
    cont_embed_dropout: float, default = 0.1
        Continuous embeddings dropout
    use_cont_bias: bool, default = True
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: str, Optional, default = None
        Activation function to be applied to the continuous embeddings, if
        any. `'tanh'`, `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    input_dim: int, default = 32
        The so-called *dimension of the model*. This is the dimension of the
        embeddings used to encode the categorical and/or continuous columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_qkv_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers
    n_blocks: int, default = 2
        Number of SAINT-Transformer blocks. The paper uses 1.
    attn_dropout: float, default = 0.1
        Dropout that will be applied to the Multi-Head Attention column and
        row layers
    ff_dropout: float, default = 0.2
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. `'tanh'`, `'relu'`,
        `'leaky_relu'`, `'gelu'`, `'geglu'` and `'reglu'` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided, no MLP will be added on top
        of the Transformer blocks
    mlp_activation: str, default = "relu"
        MLP activation function. `'tanh'`, `'relu'`, `'leaky_relu'` and
        `'gelu'` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense layer. If
        ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP -> LIN ->
        ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    encoder: ``nn.Sequential``
        Sequence of SAINT-Transformer blocks
    mlp: ``nn.Module``
        MLP component in the model. ``None`` if ``mlp_hidden_dims`` is not
        provided
    output_dim: int
        The output dimension of the model. This is a required attribute
        needed to build the ``WideDeep`` class

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import SAINT
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = SAINT(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
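    >>> # After a forward pass, the per-block (column, row) attention weights
    >>> # can be inspected via the ``attention_weights`` property
    >>> attn_weights = model.attention_weights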
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        input_dim: int = 32,
        use_qkv_bias: bool = False,
        n_heads: int = 8,
        n_blocks: int = 2,
        attn_dropout: float = 0.1,
        ff_dropout: float = 0.2,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(SAINT, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            full_embed_dropout=full_embed_dropout,
            shared_embed=shared_embed,
            add_shared_embed=add_shared_embed,
            frac_shared_embed=frac_shared_embed,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=True,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
            input_dim=input_dim,
        )

        self.use_qkv_bias = use_qkv_bias
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
        self.n_feats = self.n_cat + self.n_cont

        # Embeddings are instantiated in the base model
        # Transformer blocks
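        # Each SaintEncoder block applies multi-head self-attention across the
        # columns of every sample (column attention) followed by intersample
        # attention across the rows of the batch (row attention)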
        self.encoder = nn.Sequential()
        for i in range(n_blocks):
            self.encoder.add_module(
                "saint_block" + str(i),
                SaintEncoder(
                    input_dim,
                    n_heads,
                    use_qkv_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                    self.n_feats,
                ),
            )

        if mlp_hidden_dims is not None:
            self.mlp = MLP(
                [self.encoder_output_dim] + mlp_hidden_dims,
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
        else:
            self.mlp = None

    def forward(self, X: Tensor) -> Tensor:
        x = self._get_embeddings(X)
        x = self.encoder(x)
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)
        if self.mlp is not None:
            x = self.mlp(x)
        return x

    @property
    def encoder_output_dim(self) -> int:
        return (
            self.input_dim if self.with_cls_token else (self.n_feats * self.input_dim)
        )

    @property
    def output_dim(self) -> int:
        return (
            self.mlp_hidden_dims[-1]
            if self.mlp_hidden_dims is not None
            else self.encoder_output_dim
        )

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. Each element of the list is a tuple
        where the first and the second elements are the column and row
        attention weights respectively

        The shapes of the attention weights are:

            - column attention: :math:`(N, H, F, F)`

            - row attention: :math:`(1, H, N, N)`

        where *N* is the batch size, *H* is the number of heads and *F* is the
        number of features/columns in the dataset
        """
        attention_weights = []
        for blk in self.encoder:
            attention_weights.append(
                (blk.col_attn.attn_weights, blk.row_attn.attn_weights)
            )
        return attention_weights
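

# A minimal usage sketch (not part of the library API); it mirrors the
# synthetic data used in the docstring example above and simply checks the
# output and attention-weight shapes after a forward pass. Variable names
# here are illustrative only.
if __name__ == "__main__":
    import torch

    _colnames = ["a", "b", "c", "d", "e"]
    # four categorical columns with 4 unique values each and one continuous column
    _X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), dim=1)
    _model = SAINT(
        column_idx={k: v for v, k in enumerate(_colnames)},
        cat_embed_input=[(c, 4) for c in _colnames[:4]],
        continuous_cols=["e"],
    )
    _out = _model(_X_tab)
    # without a cls token the encoder output is flattened:
    # (batch_size, n_feats * input_dim) = (5, 5 * 32)
    print(_out.shape)
    # one (column attention, row attention) tuple per SAINT block
    for _col_attn, _row_attn in _model.attention_weights:
        # column attention: (N, H, F, F); row attention: (1, H, N, N)
        print(_col_attn.shape, _row_attn.shape)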