import torch
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.tabnet._layers import (
    TabNetEncoder,
    initialize_non_glu,
)
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithoutAttention,
)


class TabNet(BaseTabularModelWithoutAttention):
    r"""Defines a ``TabNet`` model (https://arxiv.org/abs/1908.07442)
    that can be used as the ``deeptabular`` component of a Wide & Deep
    model.

    The implementation in this library is fully based on that here:
    https://github.com/dreamquark-ai/tabnet, simply adapted so that it can
    work within the ``WideDeep`` framework. Therefore, **all credit to the
    dreamquark-ai team**

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the ``TabNet`` model. Required to slice the tensors. e.g. {'education':
        0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name, number of unique values and
        embedding dimension. e.g. [(education, 11, 32), ...]
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None
        Activation function for the categorical embeddings
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: str, default = "batchnorm"
        Type of normalization layer applied to the continuous features. Options
        are: 'layernorm', 'batchnorm' or None.
    embed_continuous: bool, default = False
        Boolean indicating if the continuous columns will be embedded
        (i.e. passed each through a linear layer with or without activation)
    cont_embed_dim: int, default = 32
        Size of the continuous embeddings
    cont_embed_dropout: float, default = 0.1
        Dropout for the continuous embeddings
    use_cont_bias: bool, default = True
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: Optional, str, default = None
        Activation function for the continuous embeddings
    n_steps: int, default = 3
        Number of decision steps
    step_dim: int, default = 8
        Step's output dimension. This is the output dimension that
        ``WideDeep`` will collect and connect to the output neuron(s). For
        a better understanding of the function of this and the upcoming
        parameters, please see the `paper
        <https://arxiv.org/abs/1908.07442>`_.
    attn_dim: int, default = 8
        Attention dimension
    dropout: float, default = 0.0
        GLU block's internal dropout
    n_glu_step_dependent: int, default = 2
        Number of GLU Blocks [FC -> BN -> GLU] that are step dependent
    n_glu_shared: int, default = 2
        Number of GLU Blocks [FC -> BN -> GLU] that will be shared
        across decision steps
    ghost_bn: bool, default = True
        Boolean indicating if `Ghost Batch Normalization
        <https://arxiv.org/abs/1705.08741>`_ will be used.
    virtual_batch_size: int, default = 128
        Batch size when using Ghost Batch Normalization
    momentum: float, default = 0.02
        Ghost Batch Normalization's momentum. The dreamquark-ai team advises
        very low values. However, high values are used in the original
        publication, and during our tests higher values led to better results
    gamma: float, default = 1.3
        Relaxation parameter in the paper. When gamma = 1, a feature is
        enforced to be used only at one decision step. As gamma increases,
        more flexibility is provided to use a feature at multiple decision
        steps
    epsilon: float, default = 1e-15
        Float to avoid log(0). Always keep low
    mask_type: str, default = "sparsemax"
        Mask function to use. Either "sparsemax" or "entmax"

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    tabnet_encoder: ``nn.Module``
        ``Module`` containing the TabNet encoder. See the `paper
        <https://arxiv.org/abs/1908.07442>`_.
    output_dim: int
        The output dimension of the model. This attribute is required to
        build the ``WideDeep`` class

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.models import TabNet
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabNet(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols = ['e'])
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        embed_continuous: bool = False,
        cont_embed_dim: int = 32,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        n_steps: int = 3,
        step_dim: int = 8,
        attn_dim: int = 8,
        dropout: float = 0.0,
        n_glu_step_dependent: int = 2,
        n_glu_shared: int = 2,
        ghost_bn: bool = True,
        virtual_batch_size: int = 128,
        momentum: float = 0.02,
        gamma: float = 1.3,
        epsilon: float = 1e-15,
        mask_type: str = "sparsemax",
    ):
        super(TabNet, self).__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=embed_continuous,
            cont_embed_dim=cont_embed_dim,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
        )

        self.n_steps = n_steps
        self.step_dim = step_dim
        self.attn_dim = attn_dim
        self.dropout = dropout
        self.n_glu_step_dependent = n_glu_step_dependent
        self.n_glu_shared = n_glu_shared
        self.ghost_bn = ghost_bn
        self.virtual_batch_size = virtual_batch_size
        self.momentum = momentum
        self.gamma = gamma
        self.epsilon = epsilon
        self.mask_type = mask_type

        # Embeddings are instantiated in the base model
        self.embed_out_dim = self.cat_and_cont_embed.output_dim

        # TabNet
        self.tabnet_encoder = TabNetEncoder(
            self.embed_out_dim,
            n_steps,
            step_dim,
            attn_dim,
            dropout,
            n_glu_step_dependent,
            n_glu_shared,
            ghost_bn,
            virtual_batch_size,
            momentum,
            gamma,
            epsilon,
            mask_type,
        )
        self.output_dim = step_dim

    def forward(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        x = self._get_embeddings(X)
        steps_output, M_loss = self.tabnet_encoder(x)
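        # sum the outputs of all decision steps; M_loss is the sparse
        # regularization term computed by the encoder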
        res = torch.sum(torch.stack(steps_output, dim=0), dim=0)
        return (res, M_loss)

    def forward_masks(self, X: Tensor) -> Tuple[Tensor, Dict[int, Tensor]]:
        x = self._get_embeddings(X)
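        # the encoder returns the aggregated feature mask plus a dict with
        # one mask per decision step (see the return type hint)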
        return self.tabnet_encoder.forward_masks(x)


class TabNetPredLayer(nn.Module):
    def __init__(self, inp: int, out: int):
        r"""This class is a 'hack' required because TabNet is a very particular
        model within ``WideDeep``.

        TabNet's forward method within ``WideDeep`` outputs two tensors: one
        with the last layer's activations and one with the sparse
        regularization factor. Since the output needs to be collected by
        ``WideDeep`` to then sequentially build the output layer (connection
        to the output neuron(s)), I need to code a custom ``TabNetPredLayer``
        that accepts two inputs. This will be used by the ``WideDeep`` class.
        """
        super(TabNetPredLayer, self).__init__()
        self.pred_layer = nn.Linear(inp, out, bias=False)
        initialize_non_glu(self.pred_layer, inp, out)

    def forward(self, tabnet_tuple: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        res, M_loss = tabnet_tuple[0], tabnet_tuple[1]
        return self.pred_layer(res), M_loss
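

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): wires the TabNet output tuple
    # to a TabNetPredLayer by hand. In practice this connection is handled by
    # the ``WideDeep`` class; the data below mirrors the docstring example.
    X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), dim=1)
    colnames = ["a", "b", "c", "d", "e"]
    cat_embed_input = [(c, 4, 8) for c in colnames[:4]]
    column_idx = {k: v for v, k in enumerate(colnames)}
    model = TabNet(
        column_idx=column_idx,
        cat_embed_input=cat_embed_input,
        continuous_cols=["e"],
    )
    res, M_loss = model(X_tab)  # res: (5, step_dim), M_loss: scalar
    pred_layer = TabNetPredLayer(model.output_dim, 1)
    out, M_loss = pred_layer((res, M_loss))  # out: (5, 1)
    # per-step feature masks, e.g. for interpretability
    M_explain, masks = model.forward_masks(X_tab)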