from collections import OrderedDict

import numpy as np
import torch
from torch import nn
from torch.nn import Module

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP


class BasicBlock(nn.Module):
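    r"""Residual block for flat ``(batch, features)`` inputs: ``LIN -> BN ->
    LeakyReLU -> [DP] -> LIN -> BN``, after which the (possibly resized)
    identity is added back and a final ``LeakyReLU`` is applied. When
    ``inp != out``, ``resize`` must project the identity to ``out`` units
    so the residual addition is well defined.
    """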
    def __init__(
        self, inp: int, out: int, dropout: float = 0.0, resize: Optional[Module] = None
    ):
        super(BasicBlock, self).__init__()

        self.lin1 = nn.Linear(inp, out)
        self.bn1 = nn.BatchNorm1d(out)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        if dropout > 0.0:
            self.dropout = True
            self.dp = nn.Dropout(dropout)
        else:
            self.dropout = False
        self.lin2 = nn.Linear(out, out)
        self.bn2 = nn.BatchNorm1d(out)
        self.resize = resize

    def forward(self, x):

        identity = x

        out = self.lin1(x)
        out = self.bn1(out)
        out = self.leaky_relu(out)
        if self.dropout:
            out = self.dp(out)

        out = self.lin2(out)
        out = self.bn2(out)

        if self.resize is not None:
            identity = self.resize(x)

        out += identity
        out = self.leaky_relu(out)

        return out


class DenseResnet(nn.Module):
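    r"""Stack of ``BasicBlock`` s where consecutive entries of ``blocks_dims``
    give the input and output units of each block. A ``Linear`` +
    ``BatchNorm1d`` projection is prepended when ``input_dim`` differs from
    the first entry. A minimal usage sketch:

    >>> import torch
    >>> resnet = DenseResnet(input_dim=10, blocks_dims=[8, 4], dropout=0.1)
    >>> out = resnet(torch.rand(5, 10))  # out has shape (5, 4)
    """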
    def __init__(self, input_dim: int, blocks_dims: List[int], dropout: float):
        super(DenseResnet, self).__init__()

        self.input_dim = input_dim
        self.blocks_dims = blocks_dims
        self.dropout = dropout

        if input_dim != blocks_dims[0]:
            self.dense_resnet = nn.Sequential(
                OrderedDict(
                    [
                        ("lin1", nn.Linear(input_dim, blocks_dims[0])),
                        ("bn1", nn.BatchNorm1d(blocks_dims[0])),
                    ]
                )
            )
        else:
            self.dense_resnet = nn.Sequential()

        for i in range(1, len(blocks_dims)):
            resize = None
            if blocks_dims[i - 1] != blocks_dims[i]:
                resize = nn.Sequential(
                    nn.Linear(blocks_dims[i - 1], blocks_dims[i]),
                    nn.BatchNorm1d(blocks_dims[i]),
                )
            self.dense_resnet.add_module(
                "block_{}".format(i - 1),
                BasicBlock(blocks_dims[i - 1], blocks_dims[i], dropout, resize),
            )

    def forward(self, X: Tensor) -> Tensor:
        return self.dense_resnet(X)


class TabResnet(nn.Module):
    def __init__(
        self,
        embed_input: List[Tuple[str, int, int]],
        column_idx: Dict[str, int],
        blocks_dims: List[int] = [200, 100, 100],
        blocks_dropout: float = 0.1,
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.1,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = False,
        embed_dropout: float = 0.1,
        continuous_cols: Optional[List[str]] = None,
        batchnorm_cont: bool = False,
        concat_cont_first: bool = True,
    ):
        r"""Defines a so-called ``TabResnet`` model that can be used as the
        ``deeptabular`` component of a Wide & Deep model.

        This class combines embedding representations of the categorical
        features with numerical (aka continuous) features. These are then
        passed through a series of Resnet blocks. See
        ``pytorch_widedeep.models.tab_resnet.BasicBlock`` for details
        on the structure of each block.

        .. note:: Unlike ``TabMlp``, ``TabResnet`` assumes that there are always
            categorical columns

        Parameters
        ----------
        embed_input: List
            List of Tuples with the column name, number of unique values and
            embedding dimension. e.g. [('education', 11, 32), ...].
        column_idx: Dict
            Dict containing the index of the columns that will be passed through
            the ``TabResnet`` model. Required to slice the tensors. e.g. {'education':
            0, 'relationship': 1, 'workclass': 2, ...}
        blocks_dims: List, default = [200, 100, 100]
            List of integers that define the input and output units of each block.
            For example: [200, 100, 100] will generate 2 blocks. The first will
            receive a tensor of size 200 and output a tensor of size 100, and the
            second will receive a tensor of size 100 and output a tensor of size
            100. See ``pytorch_widedeep.models.tab_resnet.BasicBlock`` for
            details on the structure of each block.
        blocks_dropout: float, default = 0.1
            Block's `"internal"` dropout. This dropout is applied to the first
            of the two dense layers that comprise each ``BasicBlock``.
        mlp_hidden_dims: List, Optional, default = None
            List with the number of neurons per dense layer in the MLP. e.g:
            [64, 32]. If ``None`` the output of the Resnet Blocks will be
            connected directly to the output neuron(s), i.e. using an MLP is
            optional.
        mlp_activation: str, default = "relu"
            Activation function for the dense layers of the MLP
        mlp_dropout: float, default = 0.1
            float with the dropout between the dense layers of the MLP.
        mlp_batchnorm: bool, default = False
            Boolean indicating whether or not batch normalization will be applied
            to the dense layers
        mlp_batchnorm_last: bool, default = False
            Boolean indicating whether or not batch normalization will be applied
            to the last of the dense layers
        mlp_linear_first: bool, default = False
            Boolean indicating the order of the operations in the dense
            layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
            LIN -> ACT]``
        embed_dropout: float, default = 0.1
            embeddings dropout
        continuous_cols: List, Optional, default = None
            List with the name of the numeric (aka continuous) columns
        batchnorm_cont: bool, default = False
            Boolean indicating whether or not to apply batch normalization to the
            continuous input
        concat_cont_first: bool, default = True
            Boolean indicating whether the continuous columns will be
            concatenated with the Embeddings and then passed through the
            Resnet blocks (``True``) or, alternatively, will be concatenated
            with the result of passing the embeddings through the Resnet
            Blocks (``False``)

        Attributes
        ----------
        dense_resnet: ``nn.Sequential``
            deep dense Resnet model that will receive the concatenation of the
            embeddings and the continuous columns
        embed_layers: ``nn.ModuleDict``
            ``ModuleDict`` with the embedding layers
        tab_resnet_mlp: ``nn.Sequential``
            if ``mlp_hidden_dims`` is not ``None``, this attribute will be an MLP
            model that will receive:

            - the result of concatenating the embeddings and the continuous
              columns -- if present -- and passing them through the
              ``dense_resnet`` (``concat_cont_first = True``), or

            - the result of passing the embeddings through the ``dense_resnet``
              and then concatenating the output with the continuous columns --
              if present -- (``concat_cont_first = False``)

        output_dim: ``int``
            The output dimension of the model. This is a required attribute
            necessary to build the WideDeep class

        Example
        --------
        >>> import torch
        >>> from pytorch_widedeep.models import TabResnet
        >>> X_deep = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
        >>> colnames = ['a', 'b', 'c', 'd', 'e']
        >>> embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
        >>> column_idx = {k:v for v,k in enumerate(colnames)}
        >>> model = TabResnet(blocks_dims=[16,4], column_idx=column_idx, embed_input=embed_input,
        ... continuous_cols = ['e'])
        >>> out = model(X_deep)
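        >>> # a minimal sketch of the optional MLP head, concatenating the
        >>> # continuous column after the resnet blocks
        >>> model_mlp = TabResnet(blocks_dims=[16,4], column_idx=column_idx, embed_input=embed_input,
        ... continuous_cols = ['e'], mlp_hidden_dims=[8,2], concat_cont_first=False)
        >>> out = model_mlp(X_deep)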
        """
        super(TabResnet, self).__init__()

        self.embed_input = embed_input
        self.column_idx = column_idx
        self.blocks_dims = blocks_dims
        self.blocks_dropout = blocks_dropout
        self.mlp_activation = mlp_activation
        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first
        self.embed_dropout = embed_dropout
        self.continuous_cols = continuous_cols
        self.batchnorm_cont = batchnorm_cont
        self.concat_cont_first = concat_cont_first

        if len(self.blocks_dims) < 2:
            raise ValueError(
                "'blocks' must contain at least two elements, e.g. [256, 128]"
            )

        # Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
        self.embed_layers = nn.ModuleDict(
            {
                "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
                for col, val, dim in self.embed_input
            }
        )
        self.embedding_dropout = nn.Dropout(embed_dropout)
        emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])

        # Continuous
        if self.continuous_cols is not None:
            cont_inp_dim = len(self.continuous_cols)
            if self.batchnorm_cont:
                self.norm = nn.BatchNorm1d(cont_inp_dim)
        else:
            cont_inp_dim = 0

        # DenseResnet
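        # concat_cont_first decides whether the continuous columns enter the
        # resnet blocks (True) or are appended to their output (False); this
        # changes both the resnet input dim and the model's output_dim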
        if self.concat_cont_first:
            dense_resnet_input_dim = emb_inp_dim + cont_inp_dim
            self.output_dim = blocks_dims[-1]
        else:
            dense_resnet_input_dim = emb_inp_dim
            self.output_dim = cont_inp_dim + blocks_dims[-1]
        self.tab_resnet = DenseResnet(
            dense_resnet_input_dim, blocks_dims, blocks_dropout  # type: ignore[arg-type]
        )

        # MLP
        if self.mlp_hidden_dims is not None:
            if self.concat_cont_first:
                mlp_input_dim = blocks_dims[-1]
            else:
                mlp_input_dim = cont_inp_dim + blocks_dims[-1]
            mlp_hidden_dims = [mlp_input_dim] + mlp_hidden_dims
            self.tab_resnet_mlp = MLP(
                mlp_hidden_dims,
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
            self.output_dim = mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:  # type: ignore
        r"""Forward pass that concatenates the continuous features with the
        embeddings. The result is then passed through a series of dense Resnet
        blocks"""
        embed = [
            self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
            for col, _, _ in self.embed_input
        ]
        x = torch.cat(embed, 1)
        x = self.embedding_dropout(x)

        # continuous columns are sliced by index and, depending on
        # concat_cont_first, concatenated with the embeddings before or after
        # the resnet blocks
        if self.continuous_cols is not None:
            cont_idx = [self.column_idx[col] for col in self.continuous_cols]
            x_cont = X[:, cont_idx].float()
            if self.batchnorm_cont:
                x_cont = self.norm(x_cont)
            if self.concat_cont_first:
                x = torch.cat([x, x_cont], 1)
                out = self.tab_resnet(x)
            else:
                out = torch.cat([self.tab_resnet(x), x_cont], 1)
        else:
            out = self.tab_resnet(x)

        if self.mlp_hidden_dims is not None:
            out = self.tab_resnet_mlp(out)

        return out