import torch
from torch import Tensor, nn

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.models.tabular.mlp._layers import MLP


class CatSingleMlp(nn.Module):
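    """Reconstructs all categorical features with one shared MLP.

    The encodings of every categorical column (except the special
    "cls_token") are classified over the summed cardinalities, so the
    original label-encoded values can be recovered from the encoder output.
    """
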
    def __init__(self, input_dim, cat_embed_input, column_idx, activation):
        super(CatSingleMlp, self).__init__()

        self.input_dim = input_dim
        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.activation = activation

        # total number of classes to predict: the summed cardinalities of
        # all categorical columns, excluding the special "cls_token" column
        self.num_class = sum([ei[1] for ei in cat_embed_input if ei[0] != "cls_token"])

        self.mlp = MLP(
            d_hidden=[input_dim, self.num_class * 4, self.num_class],
            activation=activation,
            dropout=0.0,
            batchnorm=False,
            batchnorm_last=False,
            linear_first=False,
        )

    def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:
        # targets: the raw (label-encoded) categorical values, concatenated
        # across columns into one 1D tensor; "cls_token" is skipped
        x = torch.cat(
            [
                X[:, self.column_idx[col]].long()
                for col, _ in self.cat_embed_input
                if col != "cls_token"
            ]
        )

        # stack the per-column encodings along the batch dimension so a
        # single shared MLP can reconstruct every categorical feature
        cat_r_ = torch.cat(
            [
                r_[:, self.column_idx[col], :]
                for col, _ in self.cat_embed_input
                if col != "cls_token"
            ]
        )

        x_ = self.mlp(cat_r_)

        return x, x_


class CatFeaturesMlp(nn.Module):
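    """Reconstructs each categorical feature with its own MLP.

    One small classifier per categorical column (except "cls_token"),
    each predicting that column's original label-encoded value.
    """
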
    def __init__(self, input_dim, cat_embed_input, column_idx, activation):
        super(CatFeaturesMlp, self).__init__()

        self.input_dim = input_dim
        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.activation = activation

        self.mlp = nn.ModuleDict(
            {
                "mlp_"
                + col: MLP(
                    d_hidden=[
                        input_dim,
                        val * 4,
                        val,
                    ],
                    activation=activation,
                    dropout=0.0,
                    batchnorm=False,
                    batchnorm_last=False,
                    linear_first=False,
                )
                for col, val in self.cat_embed_input
                if col != "cls_token"
            }
        )

    def forward(self, X: Tensor, r_: Tensor) -> List[Tuple[Tensor, Tensor]]:
        # targets: one 1D tensor of label-encoded values per categorical
        # column ("cls_token" is skipped)
        x = [
            X[:, self.column_idx[col]].long()
            for col, _ in self.cat_embed_input
            if col != "cls_token"
        ]

        # reconstructions: each column's encoding goes through its own MLP
        x_ = [
            self.mlp["mlp_" + col](r_[:, self.column_idx[col], :])
            for col, _ in self.cat_embed_input
            if col != "cls_token"
        ]

        return list(zip(x, x_))


class ContSingleMlp(nn.Module):
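    """Reconstructs all continuous features with one shared MLP.

    The encodings of every continuous column are regressed back to the
    original (normalized) values with a single MLP.
    """
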
    def __init__(self, input_dim, continuous_cols, column_idx, activation):
        super(ContSingleMlp, self).__init__()

        self.input_dim = input_dim
        self.column_idx = column_idx
        self.continuous_cols = continuous_cols
        self.activation = activation

        self.mlp = MLP(
            d_hidden=[input_dim, input_dim * 2, 1],
            activation=activation,
            dropout=0.0,
            batchnorm=False,
            batchnorm_last=False,
            linear_first=False,
        )

    def forward(self, X: Tensor, r_: Tensor) -> Tuple[Tensor, Tensor]:
        # targets: the raw continuous values, concatenated across columns
        # and reshaped to (batch * n_cols, 1) to match the MLP output
        x = torch.cat(
            [
                X[:, self.column_idx[col]].float()
                for col in self.continuous_cols
                if col != "cls_token"
            ]
        ).unsqueeze(1)

        # stack the per-column encodings so a single shared MLP can
        # regress every continuous feature
        cont_r_ = torch.cat(
            [
                r_[:, self.column_idx[col], :]
                for col in self.continuous_cols
                if col != "cls_token"
            ]
        )

        x_ = self.mlp(cont_r_)

        return x, x_


class ContFeaturesMlp(nn.Module):
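    """Reconstructs each continuous feature with its own MLP.

    One small regressor per continuous column, each predicting that
    column's original (normalized) value.
    """
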
    def __init__(self, input_dim, continuous_cols, column_idx, activation):
        super(ContFeaturesMlp, self).__init__()

        self.input_dim = input_dim
        self.column_idx = column_idx
        self.continuous_cols = continuous_cols
        self.activation = activation

        self.mlp = nn.ModuleDict(
            {
                "mlp_"
                + col: MLP(
                    d_hidden=[
                        input_dim,
                        input_dim * 2,
                        1,
                    ],
                    activation=activation,
                    dropout=0.0,
                    batchnorm=False,
                    batchnorm_last=False,
                    linear_first=False,
                )
                for col in self.continuous_cols
                if col != "cls_token"
            }
        )

    def forward(self, X: Tensor, r_: Tensor) -> List[Tuple[Tensor, Tensor]]:
        # targets: one (batch, 1) tensor of raw values per continuous column
        x = [
            X[:, self.column_idx[col]].unsqueeze(1).float()
            for col in self.continuous_cols
            if col != "cls_token"
        ]

        # reconstructions: each column's encoding goes through its own MLP
        x_ = [
            self.mlp["mlp_" + col](r_[:, self.column_idx[col], :])
            for col in self.continuous_cols
            if col != "cls_token"
        ]

        return list(zip(x, x_))
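

if __name__ == "__main__":
    # Illustrative sketch, not part of the library API: it assumes the
    # categorical codes in X are globally label-encoded across columns
    # (as the attention-based preprocessor produces them) and that r_ is
    # the per-column encoder output of shape (batch, n_cols, input_dim).
    # The column names, cardinalities, and shapes below are made up.
    import torch.nn.functional as F

    column_idx = {"a": 0, "b": 1}
    cat_embed_input = [("a", 4), ("b", 6)]  # (column name, cardinality)

    denoise = CatSingleMlp(
        input_dim=16,
        cat_embed_input=cat_embed_input,
        column_idx=column_idx,
        activation="relu",
    )

    # column "a" takes codes 0-3, column "b" codes 4-9
    X = torch.cat(
        [torch.randint(0, 4, (32, 1)), torch.randint(4, 10, (32, 1))], dim=1
    )
    r_ = torch.randn(32, 2, 16)  # stand-in for the encoder output

    x, x_ = denoise(X, r_)  # targets (64,) and logits (64, 10)
    loss = F.cross_entropy(x_, x)
    print(loss.item())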