import torch
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP
from pytorch_widedeep.models.tabular.embeddings_layers import (
    SameSizeCatAndContEmbeddings,
)
from pytorch_widedeep.models.tabular.transformers._encoders import SaintEncoder


class SAINT(nn.Module):
    r"""Defines a ``SAINT`` model
    (`arXiv:2106.01342 <https://arxiv.org/abs/2106.01342>`_) that can be used
    as the ``deeptabular`` component of a Wide & Deep model.

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the model. Required to slice the tensors. e.g.
        {'education': 0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List
        List of Tuples with the column name and number of unique values
        e.g. [('education', 11), ...]
    cat_embed_dropout: float, default = 0.1
        Dropout to be applied to the embeddings matrix
    full_embed_dropout: bool, default = False
        Boolean indicating if an entire embedding (i.e. the representation of
        one column) will be dropped in the batch. See:
        :obj:`pytorch_widedeep.models.transformers._layers.FullEmbeddingDropout`.
        If ``full_embed_dropout = True``, ``cat_embed_dropout`` is ignored.
    shared_embed: bool, default = False
        The idea behind ``shared_embed`` is described in Appendix A of the
        `TabTransformer paper <https://arxiv.org/abs/2012.06678>`_: `'The
        goal of having column embedding is to enable the model to distinguish
        the classes in one column from those in the other columns'`. In other
        words, the idea is to let the model learn which column is being
        embedded at any given time.
    add_shared_embed: bool, default = False
        The two embedding sharing strategies are: 1) add the shared embeddings
        to the column embeddings or 2) replace the first
        ``frac_shared_embed`` with the shared embeddings.
        See :obj:`pytorch_widedeep.models.transformers._layers.SharedEmbeddings`
    frac_shared_embed: float, default = 0.25
        The fraction of embeddings that will be shared (if
        ``add_shared_embed = False``) by all the different categories for one
        particular column.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    embed_continuous_activation: str, default = None
        String indicating the activation function to be applied to the
        continuous embeddings, if any. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported.
    cont_embed_dropout: float, default = 0.0
        Dropout for the continuous embeddings
    cont_embed_activation: str, default = None
        Activation function for the continuous embeddings
    cont_norm_layer: str, default = None
        Type of normalization layer applied to the continuous features before
        they are embedded. Options are: ``layernorm``, ``batchnorm`` or
        ``None``.
    input_dim: int, default = 32
        The so-called *dimension of the model*. In general, this is the size
        of the embeddings used to encode the categorical and/or continuous
        columns
    n_heads: int, default = 8
        Number of attention heads per Transformer block
    use_bias: bool, default = False
        Boolean indicating whether or not to use bias in the Q, K, and V
        projection layers
    n_blocks: int, default = 2
        Number of SAINT-Transformer blocks. 1 in the paper.
    attn_dropout: float, default = 0.2
        Dropout that will be applied to the Multi-Head Attention column and
        row layers
    ff_dropout: float, default = 0.1
        Dropout that will be applied to the FeedForward network
    transformer_activation: str, default = "gelu"
        Transformer Encoder activation function. ``tanh``, ``relu``,
        ``leaky_relu``, ``gelu``, ``geglu`` and ``reglu`` are supported
    mlp_hidden_dims: List, Optional, default = None
        MLP hidden dimensions. If not provided it will default to ``[l, 4*l,
        2*l]`` where ``l`` is the MLP input dimension
    mlp_activation: str, default = "relu"
        MLP activation function. ``tanh``, ``relu``, ``leaky_relu`` and
        ``gelu`` are supported
    mlp_dropout: float, default = 0.1
        Dropout that will be applied to the final MLP
    mlp_batchnorm: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        dense layers
    mlp_batchnorm_last: bool, default = False
        Boolean indicating whether or not to apply batch normalization to the
        last of the dense layers
    mlp_linear_first: bool, default = True
        Boolean indicating the order of the operations in the dense
        layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
        LIN -> ACT]``

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    saint_blks: ``nn.Sequential``
        Sequence of SAINT-Transformer blocks
    saint_mlp: ``nn.Module``
        MLP component in the model
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the ``WideDeep`` class

    Example
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import SAINT
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i) for u,i in zip(colnames[:4], [4]*4)]
    >>> continuous_cols = ['e']
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = SAINT(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols=continuous_cols)
    >>> out = model(X_tab)
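    >>> # after a forward pass, the per-block column and row attention weights
    >>> # can be inspected via the ``attention_weights`` property
    >>> attn_weights = model.attention_weights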
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int]]] = None,
        cat_embed_dropout: float = 0.1,
        full_embed_dropout: bool = False,
        shared_embed: bool = False,
        add_shared_embed: bool = False,
        frac_shared_embed: float = 0.25,
        continuous_cols: Optional[List[str]] = None,
        embed_continuous_activation: str = None,
        cont_embed_dropout: float = 0.0,
        cont_embed_activation: str = None,
        cont_norm_layer: str = None,
        input_dim: int = 32,
        use_bias: bool = False,
        n_heads: int = 8,
        n_blocks: int = 2,
        attn_dropout: float = 0.2,
        ff_dropout: float = 0.1,
        transformer_activation: str = "gelu",
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = True,
    ):
        super(SAINT, self).__init__()

        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.cat_embed_dropout = cat_embed_dropout
        self.full_embed_dropout = full_embed_dropout
        self.shared_embed = shared_embed
        self.add_shared_embed = add_shared_embed
        self.frac_shared_embed = frac_shared_embed

        self.continuous_cols = continuous_cols
        self.embed_continuous_activation = embed_continuous_activation
        self.cont_embed_dropout = cont_embed_dropout
        self.cont_embed_activation = cont_embed_activation
        self.cont_norm_layer = cont_norm_layer

        self.input_dim = input_dim
        self.use_bias = use_bias
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.transformer_activation = transformer_activation

        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_activation = mlp_activation
        self.mlp_dropout = mlp_dropout
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first

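        # bookkeeping: detect the special '[CLS]' token column (if the
        # preprocessor added one) and count the categorical and continuous
        # features. n_feats is passed to the SaintEncoder blocks, where it is
        # used by the intersample (row) attention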
        self.with_cls_token = "cls_token" in column_idx
        self.n_cat = len(cat_embed_input) if cat_embed_input is not None else 0
        self.n_cont = len(continuous_cols) if continuous_cols is not None else 0
        self.n_feats = self.n_cat + self.n_cont

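        # categorical and continuous columns are all encoded as embeddings of
        # size 'input_dim', so they can be processed as a sequence of "tokens"
        # by the transformer blocks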
        self.cat_and_cont_embed = SameSizeCatAndContEmbeddings(
            input_dim,
            column_idx,
            cat_embed_input,
            cat_embed_dropout,
            full_embed_dropout,
            shared_embed,
            add_shared_embed,
            frac_shared_embed,
            False,  # use_embed_bias
            continuous_cols,
            True,  # embed_continuous,
            cont_embed_dropout,
            embed_continuous_activation,
            True,  # use_cont_bias
            cont_norm_layer,
        )

        # Transformer blocks
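        # each SaintEncoder block applies multi-head self-attention over the
        # columns of each sample followed by intersample (row) attention over
        # the samples in the batch, as described in the SAINT paper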
        self.saint_blks = nn.Sequential()
        for i in range(n_blocks):
            self.saint_blks.add_module(
                "saint_block" + str(i),
                SaintEncoder(
                    input_dim,
                    n_heads,
                    use_bias,
                    attn_dropout,
                    ff_dropout,
                    transformer_activation,
                    self.n_feats,
                ),
            )

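        # if a '[CLS]' token is present only its final embedding is passed to
        # the MLP; otherwise all feature embeddings are flattened and
        # concatenated (see the forward method)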
        attn_output_dim = (
            self.input_dim if self.with_cls_token else self.n_feats * self.input_dim
        )

        # MLP
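        # if no hidden dims are provided, default to [l, 4*l, 2*l] with
        # l = attn_output_dim; otherwise prepend attn_output_dim so the MLP
        # input size matches the output of the transformer blocks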
        if not mlp_hidden_dims:
            mlp_hidden_dims = [
                attn_output_dim,
                attn_output_dim * 4,
                attn_output_dim * 2,
            ]
        else:
            mlp_hidden_dims = [attn_output_dim] + mlp_hidden_dims

        self.saint_mlp = MLP(
            mlp_hidden_dims,
            mlp_activation,
            mlp_dropout,
            mlp_batchnorm,
            mlp_batchnorm_last,
            mlp_linear_first,
        )

        # the output_dim attribute will be used as input_dim when "merging" the models
        self.output_dim = mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:

        x_cat, x_cont = self.cat_and_cont_embed(X)

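        # concatenate the categorical and continuous embeddings along the
        # "column" dimension -> (batch_size, n_feats, input_dim)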
        if x_cat is not None:
            x = x_cat
        if x_cont is not None:
            x = torch.cat([x, x_cont], 1) if x_cat is not None else x_cont

        x = self.saint_blks(x)

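        # keep only the '[CLS]' token representation, or flatten all feature
        # embeddings, before the final MLP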
        if self.with_cls_token:
            x = x[:, 0, :]
        else:
            x = x.flatten(1)

        return self.saint_mlp(x)

    @property
    def attention_weights(self) -> List:
        r"""List with the attention weights. Each element of the list is a tuple
        where the first and the second elements are the column and row
        attention weights respectively

        The shape of the attention weights is:

            - column attention: :math:`(N, H, F, F)`

            - row attention: :math:`(1, H, N, N)`

        where *N* is the batch size, *H* is the number of heads and *F* is the
        number of features/columns in the dataset
        """
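        # the attn_weights attributes are populated by the attention layers
        # during the forward pass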
        attention_weights = []
        for blk in self.saint_blks:
            attention_weights.append(
                (blk.col_attn.attn_weights, blk.row_attn.attn_weights)
            )
        return attention_weights