import string

import numpy as np
import torch
import pytest
import torch.nn.functional as F
from torch import nn
from sklearn.utils import Bunch
from torch.utils.data import Dataset, DataLoader

from pytorch_widedeep.models import Wide, TabMlp
from pytorch_widedeep.metrics import Accuracy, MultipleMetrics
from pytorch_widedeep.models.deep_image import conv_layer
from pytorch_widedeep.training._finetune import FineTune

use_cuda = torch.cuda.is_available()


# Define a series of simple models to quickly test the FineTune class
class TestDeepText(nn.Module):
    def __init__(self):
        super(TestDeepText, self).__init__()
        self.word_embed = nn.Embedding(5, 16, padding_idx=0)
        self.rnn = nn.LSTM(16, 8, batch_first=True)
        self.linear = nn.Linear(8, 1)

    def forward(self, X):
        embed = self.word_embed(X.long())
        o, (h, c) = self.rnn(embed)
        return self.linear(h).view(-1, 1)


class TestDeepImage(nn.Module):
    def __init__(self):
        super(TestDeepImage, self).__init__()

        self.conv_block = nn.Sequential(
            conv_layer(3, 64, 3),
            conv_layer(64, 128, 1, maxpool=False, adaptiveavgpool=True),
        )
        self.linear = nn.Linear(128, 1)

    def forward(self, X):
        x = self.conv_block(X)
        x = x.view(x.size(0), -1)
        return self.linear(x)


# Define a simple WideDeep Dataset
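# Each item is a (Bunch, target) pair, with one Bunch entry per model
# component: wide, deeptabular, deeptext and deepimage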
class WDset(Dataset):
    def __init__(self, X_wide, X_tab, X_text, X_img, target):

        self.X_wide = X_wide
        self.X_tab = X_tab
        self.X_text = X_text
        self.X_img = X_img
        self.Y = target

    def __getitem__(self, idx: int):

        X = Bunch(wide=self.X_wide[idx])
        X.deeptabular = self.X_tab[idx]
        X.deeptext = self.X_text[idx]
        X.deepimage = self.X_img[idx]
        y = self.Y[idx]
        return X, y

    def __len__(self):
        return len(self.X_tab)


# Remember that the FineTune class will be instantiated inside the WideDeep
# class and will take, among others, the activation_fn and the loss_fn of
# that class as parameters. Therefore, we define equivalent functions to
# replicate that scenario
# def activ_fn(inp):
#     return torch.sigmoid(inp)


def loss_fn(y_pred, y_true):
    return F.binary_cross_entropy_with_logits(y_pred, y_true.view(-1, 1))


# Define the data components:

# target
target = torch.empty(100, 1).random_(0, 2)

# wide
X_wide = torch.empty(100, 4).random_(1, 20)

# deep
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 100) for _ in range(5)]
cont_cols = [np.random.rand(100) for _ in range(5)]
embed_input = [(u, i, j) for u, i, j in zip(colnames[:5], [5] * 5, [16] * 5)]
column_idx = {k: v for v, k in enumerate(colnames[:10])}
continuous_cols = colnames[-5:]
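# columns 0-4 of X_tab hold the categorical (embedded) features and columns
# 5-9 the continuous ones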
X_tab = torch.from_numpy(np.vstack(embed_cols + cont_cols).transpose())

# text
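# (the all-zeros first column plays the role of the padding token, matching
# padding_idx=0 in TestDeepText)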
X_text = torch.cat((torch.zeros([100, 1]), torch.empty(100, 4).random_(1, 4)), axis=1)  # type: ignore[call-overload]

# image
X_image = torch.rand(100, 3, 28, 28)

# Define the model components

# wide
wide = Wide(X_wide.unique().size(0), 1)
if use_cuda:
    wide.cuda()

# deep
deeptabular = TabMlp(
    mlp_hidden_dims=[32, 16, 8],
    mlp_dropout=0.2,
    column_idx=column_idx,
    embed_input=embed_input,
    continuous_cols=continuous_cols,
)
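# TabMlp outputs an 8-dim representation (mlp_hidden_dims[-1]) and carries no
# prediction head, so append a Linear layer to produce a single logit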
deeptabular = nn.Sequential(deeptabular, nn.Linear(8, 1))  # type: ignore[assignment]
if use_cuda:
    deeptabular.cuda()

# text
deeptext = TestDeepText()
if use_cuda:
    deeptext.cuda()

# image
deepimage = TestDeepImage()
if use_cuda:
    deepimage.cuda()

# Define the loader
wdset = WDset(X_wide, X_tab, X_text, X_image, target)
wdloader = DataLoader(wdset, batch_size=10, shuffle=True)

# Instantiate the FineTune class
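# (positional args: the loss function, the metrics to compute, the objective
# -- "binary" here -- and, presumably, a verbosity flag)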
finetuner = FineTune(loss_fn, MultipleMetrics([Accuracy()]), "binary", False)

# List the layers for the finetune_gradual method
# deeptabular children -> TabMlp and the final Linear layer
# TabMlp children -> Embeddings and MLP
# MLP children -> dense layers
# so here we go...
last_linear = list(deeptabular.children())[1]
inverted_mlp_layers = list(
    list(list(deeptabular.named_modules())[11][1].children())[0].children()
)[::-1]
tab_layers = [last_linear] + inverted_mlp_layers
text_layers = [c for c in list(deeptext.children())[1:]][::-1]
image_layers = [c for c in list(deepimage.children())][::-1]
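# note that the lists are reversed ([::-1]) so the layer closest to the output
# comes first, which is the order in which the gradual finetuning routines
# unfreeze them (from the output backwards)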


###############################################################################
# Simply test that finetune_all runs
###############################################################################
@pytest.mark.parametrize(
    "model, modelname, loader, n_epochs, max_lr",
    [
        (wide, "wide", wdloader, 1, 0.01),
        (deeptabular, "deeptabular", wdloader, 1, 0.01),
        (deeptext, "deeptext", wdloader, 1, 0.01),
        (deepimage, "deepimage", wdloader, 1, 0.01),
    ],
)
def test_finetune_all(model, modelname, loader, n_epochs, max_lr):
    has_run = True
    try:
        finetuner.finetune_all(model, modelname, loader, n_epochs, max_lr)
    except Exception:
        has_run = False
    assert has_run


###############################################################################
# Simply test that finetune_gradual runs
###############################################################################
@pytest.mark.parametrize(
    "model, modelname, loader, max_lr, layers, routine",
    [
        (deeptabular, "deeptabular", wdloader, 0.01, tab_layers, "felbo"),
        (deeptabular, "deeptabular", wdloader, 0.01, tab_layers, "howard"),
        (deeptext, "deeptext", wdloader, 0.01, text_layers, "felbo"),
        (deeptext, "deeptext", wdloader, 0.01, text_layers, "howard"),
        (deepimage, "deepimage", wdloader, 0.01, image_layers, "felbo"),
        (deepimage, "deepimage", wdloader, 0.01, image_layers, "howard"),
    ],
)
def test_finetune_gradual(model, modelname, loader, max_lr, layers, routine):
    has_run = True
    try:
        finetuner.finetune_gradual(model, modelname, loader, max_lr, layers, routine)
    except Exception:
        has_run = False
    assert has_run