tab_net.py 11.4 KB
Newer Older
J
jrzaurin 已提交
1 2 3 4
import torch
from torch import nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
5
from pytorch_widedeep.models.tabular.tabnet._layers import (
6
    TabNetEncoder,
7
    FeatTransformer,
8 9
    initialize_non_glu,
)
10 11
from pytorch_widedeep.models.tabular._base_tabular_model import (
    BaseTabularModelWithoutAttention,
12
)
13 14


15
class TabNet(BaseTabularModelWithoutAttention):
    r"""Defines a `TabNet model <https://arxiv.org/abs/1908.07442>`_ that
    can be used as the ``deeptabular`` component of a Wide & Deep model or
    independently by itself.

    The implementation in this library is fully based on that `here
    <https://github.com/dreamquark-ai/tabnet>`_, simply adapted so that it
    can work within the ``WideDeep`` frame. Therefore, **all credit to the
    dreamquark-ai team**

    Parameters
    ----------
    column_idx: Dict
        Dict containing the index of the columns that will be passed through
        the ``TabNet`` model. Required to slice the tensors. e.g. {'education':
        0, 'relationship': 1, 'workclass': 2, ...}
    cat_embed_input: List, Optional, default = None
        List of Tuples with the column name, number of unique values and
        embedding dimension. e.g. [(education, 11, 32), ...]
    cat_embed_dropout: float, default = 0.1
        Categorical embeddings dropout
    use_cat_bias: bool, default = False,
        Boolean indicating if bias will be used for the categorical embeddings
    cat_embed_activation: Optional, str, default = None,
        Activation function for the categorical embeddings, if any. `'tanh'`,
        `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    continuous_cols: List, Optional, default = None
        List with the name of the numeric (aka continuous) columns
    cont_norm_layer: Optional, str, default = None
        Type of normalization layer applied to the continuous features. Options
        are: 'layernorm', 'batchnorm' or None.
    embed_continuous: bool, default = False,
        Boolean indicating if the continuous columns will be embedded
        (i.e. passed each through a linear layer with or without activation)
    cont_embed_dim: int, default = 32,
        Size of the continuous embeddings
    cont_embed_dropout: float, default = 0.1,
        Dropout for the continuous embeddings
    use_cont_bias: bool, default = True,
        Boolean indicating if bias will be used for the continuous embeddings
    cont_embed_activation: Optional, str, default = None,
        Activation function for the continuous embeddings, if any. `'tanh'`,
        `'relu'`, `'leaky_relu'` and `'gelu'` are supported.
    n_steps: int, default = 3
        number of decision steps
    step_dim: int, default = 8
        Step's output dimension. This is the output dimension that
        ``WideDeep`` will collect and connect to the output neuron(s). For
        a better understanding of the function of this and the upcoming
        parameters, please see the `paper
        <https://arxiv.org/abs/1908.07442>`_.
    attn_dim: int, default = 8
        Attention dimension
    dropout: float, default = 0.0
        GLU block's internal dropout
    n_glu_step_dependent: int, default = 2
        number of GLU Blocks [FC -> BN -> GLU] that are step dependent
    n_glu_shared: int, default = 2
        number of GLU Blocks [FC -> BN -> GLU] that will be shared
        across decision steps
    ghost_bn: bool, default=True
        Boolean indicating if `Ghost Batch Normalization
        <https://arxiv.org/abs/1705.08741>`_ will be used.
    virtual_batch_size: int, default = 128
        Batch size when using Ghost Batch Normalization
    momentum: float, default = 0.02
        Ghost Batch Normalization's momentum. The dreamquark-ai team advises
        very low values. However, high values are used in the original
        publication. During our tests higher values led to better results
    gamma: float, default = 1.3
        Relaxation parameter in the paper. When gamma = 1, a feature is
        enforced to be used only at one decision step. As gamma increases,
        more flexibility is provided to use a feature at multiple decision
        steps
    epsilon: float, default = 1e-15
        Float to avoid log(0). Always keep low
    mask_type: str, default = "sparsemax"
        Mask function to use. Either "sparsemax" or "entmax"

    Attributes
    ----------
    cat_and_cont_embed: ``nn.Module``
        This is the module that processes the categorical and continuous columns
    encoder: ``nn.Module``
        ``Module`` containing the TabNet encoder. See the `paper
        <https://arxiv.org/abs/1908.07442>`_.
    output_dim: int
        The output dimension of the model. This is a required attribute
        necessary to build the ``WideDeep`` class

    Example
    -------
    >>> import torch
    >>> from pytorch_widedeep.models import TabNet
    >>> X_tab = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), dim=1)
    >>> colnames = ['a', 'b', 'c', 'd', 'e']
    >>> cat_embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
    >>> column_idx = {k:v for v,k in enumerate(colnames)}
    >>> model = TabNet(column_idx=column_idx, cat_embed_input=cat_embed_input, continuous_cols = ['e'])
    >>> out = model(X_tab)
    """

    def __init__(
        self,
        column_idx: Dict[str, int],
        cat_embed_input: Optional[List[Tuple[str, int, int]]] = None,
        cat_embed_dropout: float = 0.1,
        use_cat_bias: bool = False,
        cat_embed_activation: Optional[str] = None,
        continuous_cols: Optional[List[str]] = None,
        cont_norm_layer: Optional[str] = None,
        embed_continuous: bool = False,
        cont_embed_dim: int = 32,
        cont_embed_dropout: float = 0.1,
        use_cont_bias: bool = True,
        cont_embed_activation: Optional[str] = None,
        n_steps: int = 3,
        step_dim: int = 8,
        attn_dim: int = 8,
        dropout: float = 0.0,
        n_glu_step_dependent: int = 2,
        n_glu_shared: int = 2,
        ghost_bn: bool = True,
        virtual_batch_size: int = 128,
        momentum: float = 0.02,
        gamma: float = 1.3,
        epsilon: float = 1e-15,
        mask_type: str = "sparsemax",
    ):
        # All embedding-related arguments are handled by the base class.
        super().__init__(
            column_idx=column_idx,
            cat_embed_input=cat_embed_input,
            cat_embed_dropout=cat_embed_dropout,
            use_cat_bias=use_cat_bias,
            cat_embed_activation=cat_embed_activation,
            continuous_cols=continuous_cols,
            cont_norm_layer=cont_norm_layer,
            embed_continuous=embed_continuous,
            cont_embed_dim=cont_embed_dim,
            cont_embed_dropout=cont_embed_dropout,
            use_cont_bias=use_cont_bias,
            cont_embed_activation=cont_embed_activation,
        )

        # Keep the TabNet-specific hyper-parameters on the instance.
        self.n_steps = n_steps
        self.step_dim = step_dim
        self.attn_dim = attn_dim
        self.dropout = dropout
        self.n_glu_step_dependent = n_glu_step_dependent
        self.n_glu_shared = n_glu_shared
        self.ghost_bn = ghost_bn
        self.virtual_batch_size = virtual_batch_size
        self.momentum = momentum
        self.gamma = gamma
        self.epsilon = epsilon
        self.mask_type = mask_type

        # Embeddings are instantiated at the base model
        self.embed_out_dim = self.cat_and_cont_embed.output_dim

        # TabNet
        self.encoder = TabNetEncoder(
            self.embed_out_dim,
            n_steps,
            step_dim,
            attn_dim,
            dropout,
            n_glu_step_dependent,
            n_glu_shared,
            ghost_bn,
            virtual_batch_size,
            momentum,
            gamma,
            epsilon,
            mask_type,
        )

    def forward(
        self, X: Tensor, prior: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        r"""Embeds ``X``, runs the encoder and aggregates its step outputs.

        Parameters
        ----------
        X: Tensor
            Input tensor, passed to ``_get_embeddings``.
        prior: Tensor, Optional, default = None
            Forwarded untouched to the encoder.

        Returns
        -------
        Tuple[Tensor, Tensor]
            The element-wise sum of the per-step outputs returned by the
            encoder, and the second value returned by the encoder
            (``M_loss``).
        """
        x = self._get_embeddings(X)
        steps_output, M_loss = self.encoder(x, prior)
        # Stack the per-step outputs on a new leading axis and sum over it.
        res = torch.sum(torch.stack(steps_output, dim=0), dim=0)
        return (res, M_loss)

    def forward_masks(self, X: Tensor) -> Tuple[Tensor, Dict[int, Tensor]]:
        r"""Embeds ``X`` and returns the result of the encoder's
        ``forward_masks`` method.
        """
        x = self._get_embeddings(X)
        return self.encoder.forward_masks(x)

    @property
    def output_dim(self) -> int:
        r"""The output dimension of the model, i.e. ``step_dim``."""
        return self.step_dim
208 209 210 211


class TabNetPredLayer(nn.Module):
    r"""Prediction layer for TabNet within ``WideDeep``.

    This class is a 'hack' required because TabNet is a very particular
    model within ``WideDeep``.

    TabNet's forward method within ``WideDeep`` outputs two tensors, one
    with the last layer's activations and the sparse regularization
    factor. Since the output needs to be collected by ``WideDeep`` to then
    sequentially build the output layer (connection to the output
    neuron(s)), a custom prediction layer that accepts two inputs is
    needed. This will be used by the ``WideDeep`` class.

    Parameters
    ----------
    inp: int
        Number of input features of the linear prediction layer
    out: int
        Number of output features of the linear prediction layer
    """

    def __init__(self, inp: int, out: int):
        super().__init__()
        # Bias-free linear head, initialised with the project's helper.
        self.pred_layer = nn.Linear(inp, out, bias=False)
        initialize_non_glu(self.pred_layer, inp, out)

    def forward(self, tabnet_tuple: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        r"""Applies the linear layer to the first element of
        ``tabnet_tuple`` and passes the second element through unchanged.
        """
        res, M_loss = tabnet_tuple[0], tabnet_tuple[1]
        return self.pred_layer(res), M_loss
229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303


class TabNetDecoder(nn.Module):
    r"""TabNet decoder: one ``FeatTransformer`` per decision step followed
    by a bias-free linear reconstruction layer.

    Parameters
    ----------
    embed_dim: int
        Output size of the reconstruction layer. Also used as the input
        size of the first shared linear layer and as both size arguments of
        every ``FeatTransformer``.
    n_steps: int, default = 3
        number of decision steps, i.e. number of ``FeatTransformer`` built
    step_dim: int, default = 8
        Input size of the reconstruction layer
    attn_dim: int, default = 8
        Together with ``step_dim`` sets the size of the shared linear layers
    dropout: float, default = 0.0
        Passed to every ``FeatTransformer``
    n_glu_step_dependent: int, default = 2
        Passed to every ``FeatTransformer``
    n_glu_shared: int, default = 2
        number of shared linear layers created
    ghost_bn: bool, default = True
        Passed to every ``FeatTransformer``
    virtual_batch_size: int, default = 128
        Passed to every ``FeatTransformer``
    momentum: float, default = 0.02
        Passed to every ``FeatTransformer``
    gamma: float, default = 1.3
        Only stored on the instance
    epsilon: float, default = 1e-15
        Only stored on the instance
    mask_type: str, default = "sparsemax"
        Only stored on the instance

    Attributes
    ----------
    feat_transformers: ``nn.ModuleList``
        One ``FeatTransformer`` per decision step
    reconstruction_layer: ``nn.Module``
        Linear layer from ``step_dim`` to ``embed_dim`` features
    """

    def __init__(
        self,
        embed_dim: int,
        n_steps: int = 3,
        step_dim: int = 8,
        attn_dim: int = 8,
        dropout: float = 0.0,
        n_glu_step_dependent: int = 2,
        n_glu_shared: int = 2,
        ghost_bn: bool = True,
        virtual_batch_size: int = 128,
        momentum: float = 0.02,
        gamma: float = 1.3,
        epsilon: float = 1e-15,
        mask_type: str = "sparsemax",
    ):
        super().__init__()

        self.n_steps = n_steps
        self.step_dim = step_dim
        self.attn_dim = attn_dim
        self.dropout = dropout
        self.n_glu_step_dependent = n_glu_step_dependent
        self.n_glu_shared = n_glu_shared
        self.ghost_bn = ghost_bn
        self.virtual_batch_size = virtual_batch_size
        self.momentum = momentum
        self.gamma = gamma
        self.epsilon = epsilon
        self.mask_type = mask_type

        # NOTE(review): the first shared layer maps ``embed_dim`` to
        # ``2 * (step_dim + attn_dim)``, every FeatTransformer below is built
        # with ``(embed_dim, embed_dim)`` and the reconstruction layer
        # consumes ``step_dim`` features. Whether these sizes are mutually
        # compatible depends on FeatTransformer's internals, which are not
        # visible here -- TODO confirm.
        hidden_dim = step_dim + attn_dim
        shared_layers = nn.ModuleList()
        for i in range(n_glu_shared):
            # Only the first shared layer reads the raw ``embed_dim`` input.
            in_features = embed_dim if i == 0 else hidden_dim
            shared_layers.append(nn.Linear(in_features, 2 * hidden_dim, bias=False))

        # The same ``shared_layers`` object is handed to every step.
        self.feat_transformers = nn.ModuleList()
        for _ in range(n_steps):
            transformer = FeatTransformer(
                embed_dim,
                embed_dim,
                dropout,
                shared_layers,
                n_glu_step_dependent,
                ghost_bn,
                virtual_batch_size,
                momentum=momentum,
            )
            self.feat_transformers.append(transformer)

        self.reconstruction_layer = nn.Linear(step_dim, embed_dim, bias=False)
        initialize_non_glu(self.reconstruction_layer, step_dim, embed_dim)

    def forward(self, X: List[Tensor]) -> Tensor:
        r"""Transforms each step's tensor with its own ``FeatTransformer``,
        sums the results and applies the reconstruction layer.

        Parameters
        ----------
        X: List[Tensor]
            One tensor per decision step; the i-th tensor is passed through
            the i-th element of ``feat_transformers``.

        Returns
        -------
        Tensor
            Output of ``reconstruction_layer`` applied to the summed,
            transformed step tensors.

        Raises
        ------
        ValueError
            If ``X`` is empty.
        """
        out = None
        for step_idx, x in enumerate(X):
            x = self.feat_transformers[step_idx](x)
            # Start the running sum from the first tensor rather than a
            # Python float so that dtype and device follow the inputs.
            out = x if out is None else torch.add(out, x)
        if out is None:
            raise ValueError("TabNetDecoder.forward received no step outputs")
        return self.reconstruction_layer(out)

    def forward_masks(self, X: Tensor) -> Tuple[Tensor, Dict[int, Tensor]]:
        r"""Not available for the decoder.

        The decoder defines neither ``_get_embeddings`` nor ``encoder``, so
        there are no masks to compute.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "TabNetDecoder has no encoder and therefore cannot compute masks"
        )