from collections import OrderedDict

import numpy as np
import torch
from torch import nn
from torch.nn import Module

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP


class BasicBlock(nn.Module):
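    r"""Basic residual block: Linear -> BatchNorm1d -> LeakyReLU -> (optional)
    Dropout -> Linear -> BatchNorm1d, plus a skip connection (resized via
    ``resize`` when input and output dims differ) and a final LeakyReLU."""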
    def __init__(self, inp: int, out: int, dropout: float = 0.0, resize: Module = None):
        super(BasicBlock, self).__init__()

        self.lin1 = nn.Linear(inp, out)
        self.bn1 = nn.BatchNorm1d(out)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        if dropout > 0.0:
            self.dropout = True
            self.dp = nn.Dropout(dropout)
        else:
            self.dropout = False
        self.lin2 = nn.Linear(out, out)
        self.bn2 = nn.BatchNorm1d(out)
        self.resize = resize

    def forward(self, x):

        identity = x

        out = self.lin1(x)
        out = self.bn1(out)
        out = self.leaky_relu(out)
        if self.dropout:
            out = self.dp(out)

        out = self.lin2(out)
        out = self.bn2(out)

        if self.resize is not None:
            identity = self.resize(x)

        out += identity
        out = self.leaky_relu(out)

        return out


class DenseResnet(nn.Module):
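    r"""Sequence of ``BasicBlock`` modules defined by ``blocks_dims``. If
    ``input_dim`` does not match ``blocks_dims[0]``, a Linear + BatchNorm1d
    'stem' first resizes the input to ``blocks_dims[0]``."""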
    def __init__(self, input_dim: int, blocks_dims: List[int], dropout: float):
        super(DenseResnet, self).__init__()

        self.input_dim = input_dim
        self.blocks_dims = blocks_dims
        self.dropout = dropout

        if input_dim != blocks_dims[0]:
            self.dense_resnet = nn.Sequential(
                OrderedDict(
                    [
                        ("lin1", nn.Linear(input_dim, blocks_dims[0])),
                        ("bn1", nn.BatchNorm1d(blocks_dims[0])),
                    ]
                )
            )
        else:
            self.dense_resnet = nn.Sequential()
        for i in range(1, len(blocks_dims)):
            resize = None
            if blocks_dims[i - 1] != blocks_dims[i]:
                resize = nn.Sequential(
                    nn.Linear(blocks_dims[i - 1], blocks_dims[i]),
                    nn.BatchNorm1d(blocks_dims[i]),
                )
            self.dense_resnet.add_module(
                "block_{}".format(i - 1),
                BasicBlock(blocks_dims[i - 1], blocks_dims[i], dropout, resize),
            )

    def forward(self, X: Tensor) -> Tensor:
        return self.dense_resnet(X)


class TabResnet(nn.Module):
    def __init__(
        self,
        embed_input: List[Tuple[str, int, int]],
        column_idx: Dict[str, int],
        blocks_dims: List[int],
        blocks_dropout: float = 0.0,
        mlp_hidden_dims: Optional[List[int]] = None,
        mlp_activation: str = "relu",
        mlp_dropout: float = 0.0,
        mlp_batchnorm: bool = False,
        mlp_batchnorm_last: bool = False,
        mlp_linear_first: bool = False,
        embed_dropout: Optional[float] = 0.0,
        continuous_cols: Optional[List[str]] = None,
        batchnorm_cont: Optional[bool] = False,
        concat_cont_first: Optional[bool] = True,
    ):
        r"""Defines a so-called ``TabResnet`` model that can be used as the
        ``deeptabular`` component of a Wide & Deep model.

        This class combines embedding representations of the categorical
        features with numerical (aka continuous) features. These are then
        passed through a series of Resnet blocks. See
        ``pytorch_widedeep.models.tab_resnet.BasicBlock`` for details
        on the structure of each block.

        Parameters
        ----------
        embed_input: List
            List of Tuples with the column name, number of unique values and
            embedding dimension. e.g. [(education, 11, 32), ...].
        column_idx: Dict
            Dict containing the index of the columns that will be passed through
            the ``TabResnet`` model. Required to slice the tensors. e.g. {'education':
            0, 'relationship': 1, 'workclass': 2, ...}
        blocks_dims: List
            List of integers that define the input and output units of each block.
            For example: ``[128, 64, 32]`` will generate 2 blocks. The first will
            receive a tensor of size 128 and output a tensor of size 64, and the
            second will receive a tensor of size 64 and output a tensor of size
            32. See ``pytorch_widedeep.models.tab_resnet.BasicBlock`` for
            details on the structure of each block.
        blocks_dropout: float, default = 0.0
            Block's `"internal"` dropout. This dropout is applied to the first
            of the two dense layers that comprise each ``BasicBlock``.
        mlp_hidden_dims: List, Optional, default = None
            List with the number of neurons per dense layer in the mlp. e.g: [64,32]
        mlp_activation: str, default = "relu"
            Activation function for the dense layers of the MLP
        mlp_dropout: float, default = 0.
            float with the dropout between the dense layers of the MLP.
        mlp_batchnorm: bool, default = False
            Boolean indicating whether or not batch normalization will be applied
            to the dense layers
        mlp_batchnorm_last: bool, default = False
            Boolean indicating whether or not batch normalization will be applied
            to the last of the dense layers
        mlp_linear_first: bool, default = False
            Boolean indicating the order of the operations in the dense
            layer. If ``True: [LIN -> ACT -> BN -> DP]``. If ``False: [BN -> DP ->
            LIN -> ACT]``
        embed_dropout: float, Optional, default = 0.0
            embeddings dropout
        continuous_cols: List, Optional
            List with the name of the numeric (aka continuous) columns
        batchnorm_cont: bool, default = False
            Boolean indicating whether or not to apply batch normalization to the
            continuous input
        concat_cont_first: bool, Optional, default = True
            Boolean indicating whether the continuous columns will be
            concatenated with the embeddings and then passed through the resnet
            blocks (``True``) or, alternatively, concatenated with the output
            of the resnet blocks (``False``)

            .. note:: Unlike ``TabMlp``, ``TabResnet`` assumes that there are categorical
                columns

        Attributes
        ----------
        tab_resnet: ``nn.Module``
            deep dense resnet model that will receive the concatenation of the
            embeddings and the continuous columns (when ``concat_cont_first = True``)
        embed_layers: ``nn.ModuleDict``
            ``ModuleDict`` with the embedding layers
        output_dim: `int`
            The output dimension of the model. This is a required attribute
            necessary to build the WideDeep class
        tab_resnet_mlp: ``nn.Sequential``
            if ``mlp_hidden_dims`` is not ``None``, this attribute will be an mlp
            model that will receive i) the concatenation of the embeddings and
            the continuous columns (if present) after being passed through the
            resnet blocks, or ii) the result of passing the embeddings through
            the resnet blocks and then concatenating the output with the
            continuous columns (if present)

        Example
        --------
        >>> import torch
        >>> from pytorch_widedeep.models import TabResnet
        >>> X_deep = torch.cat((torch.empty(5, 4).random_(4), torch.rand(5, 1)), axis=1)
        >>> colnames = ['a', 'b', 'c', 'd', 'e']
        >>> embed_input = [(u,i,j) for u,i,j in zip(colnames[:4], [4]*4, [8]*4)]
        >>> column_idx = {k:v for v,k in enumerate(colnames)}
        >>> model = TabResnet(blocks_dims=[16,4], column_idx=column_idx, embed_input=embed_input)
        >>> out = model(X_deep)
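        >>> # e.g. additionally treating column 'e' as continuous and adding an MLP head
        >>> model_mlp = TabResnet(blocks_dims=[16, 8], column_idx=column_idx, embed_input=embed_input, continuous_cols=['e'], mlp_hidden_dims=[8, 4])
        >>> out = model_mlp(X_deep)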
        """
        super(TabResnet, self).__init__()

        self.embed_input = embed_input
        self.column_idx = column_idx
        self.blocks_dims = blocks_dims
        self.blocks_dropout = blocks_dropout
        self.mlp_activation = mlp_activation
        self.mlp_hidden_dims = mlp_hidden_dims
        self.mlp_batchnorm = mlp_batchnorm
        self.mlp_batchnorm_last = mlp_batchnorm_last
        self.mlp_linear_first = mlp_linear_first
        self.embed_dropout = embed_dropout
        self.continuous_cols = continuous_cols
        self.batchnorm_cont = batchnorm_cont
        self.concat_cont_first = concat_cont_first

        if len(self.blocks_dims) < 2:
            raise ValueError(
                "'blocks_dims' must contain at least two elements, e.g. [256, 128]"
            )

        # Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
        self.embed_layers = nn.ModuleDict(
            {
                "emb_layer_" + col: nn.Embedding(val + 1, dim, padding_idx=0)
                for col, val, dim in self.embed_input
            }
        )
        self.embedding_dropout = nn.Dropout(embed_dropout)
        emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])

        # Continuous
        if self.continuous_cols is not None:
            cont_inp_dim = len(self.continuous_cols)
            if self.batchnorm_cont:
                self.norm = nn.BatchNorm1d(cont_inp_dim)
        else:
            cont_inp_dim = 0

        # DenseResnet
        if self.concat_cont_first:
            dense_resnet_input_dim = emb_inp_dim + cont_inp_dim
            self.output_dim = blocks_dims[-1]
        else:
            dense_resnet_input_dim = emb_inp_dim
            self.output_dim = cont_inp_dim + blocks_dims[-1]
        self.tab_resnet = DenseResnet(
            dense_resnet_input_dim, blocks_dims, blocks_dropout
        )

        # MLP
        if self.mlp_hidden_dims is not None:
            if self.concat_cont_first:
                mlp_input_dim = blocks_dims[-1]
            else:
                mlp_input_dim = cont_inp_dim + blocks_dims[-1]
            mlp_hidden_dims = [mlp_input_dim] + mlp_hidden_dims
            self.tab_resnet_mlp = MLP(
                mlp_hidden_dims,
                mlp_activation,
                mlp_dropout,
                mlp_batchnorm,
                mlp_batchnorm_last,
                mlp_linear_first,
            )
            self.output_dim = mlp_hidden_dims[-1]

    def forward(self, X: Tensor) -> Tensor:  # type: ignore
        r"""Forward pass that concatenates the continuous features with the
        embeddings. The result is then passed through a series of dense Resnet
        blocks"""
        embed = [
            self.embed_layers["emb_layer_" + col](X[:, self.column_idx[col]].long())
            for col, _, _ in self.embed_input
        ]
        x = torch.cat(embed, 1)
        x = self.embedding_dropout(x)

        if self.continuous_cols is not None:
            cont_idx = [self.column_idx[col] for col in self.continuous_cols]
            x_cont = X[:, cont_idx].float()
            if self.batchnorm_cont:
                x_cont = self.norm(x_cont)
            if self.concat_cont_first:
                x = torch.cat([x, x_cont], 1)
                out = self.tab_resnet(x)
            else:
                out = torch.cat([self.tab_resnet(x), x_cont], 1)
        else:
            out = self.tab_resnet(x)

        if self.mlp_hidden_dims is not None:
            out = self.tab_resnet_mlp(out)

        return out