import string
import warnings

import numpy as np
import pytest
from torch import nn

from pytorch_widedeep.models import (
    Wide,
    TabMlp,
    TabNet,
    WideDeep,
    TabTransformer,
)
from pytorch_widedeep.metrics import R2Score
from pytorch_widedeep.training import Trainer
from pytorch_widedeep.dataloaders import DataLoaderImbalanced

# Wide array: 32 rows x 10 categorical columns with values in [0, 50)
X_wide = np.random.choice(50, (32, 10))

# Deep Array: 5 embedding + 5 continuous columns named "a".."j"
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 32) for _ in range(5)]
# (name, n_categories, embedding_dim) triples for the mlp-style models
embed_input = [(name, 5, 16) for name in colnames[:5]]
# (name, n_categories) pairs for the transformer-style models
embed_input_tt = [(name, 5) for name in colnames[:5]]
cont_cols = [np.random.rand(32) for _ in range(5)]
column_idx = {name: idx for idx, name in enumerate(colnames)}
X_tab = np.vstack(embed_cols + cont_cols).T

# Target
target_regres = np.random.random(32)
target_binary = np.random.choice(2, 32)
target_binary_imbalanced = np.random.choice(2, 32, p=[0.75, 0.25])
target_multic = np.random.choice(3, 32)

# Test dictionary
X_test = {"X_wide": X_wide, "X_tab": X_tab}

##############################################################################
# Test that the three possible methods (regression, binary and multiclass)
# work well
##############################################################################
@pytest.mark.parametrize(
    "X_wide, X_tab, target, objective, X_test, pred_dim, probs_dim, uncertainties_pred_dim",
    [
        (X_wide, X_tab, target_regres, "regression", None, 1, None, 4),
        (X_wide, X_tab, target_binary, "binary", None, 1, 2, 3),
        (X_wide, X_tab, target_multic, "multiclass", None, 3, 3, 4),
        (X_wide, X_tab, target_regres, "regression", X_test, 1, None, 4),
        (X_wide, X_tab, target_binary, "binary", X_test, 1, 2, 3),
        (X_wide, X_tab, target_multic, "multiclass", X_test, 3, 3, 4),
    ],
)
def test_fit_objectives(
    X_wide,
    X_tab,
    target,
    objective,
    X_test,
    pred_dim,
    probs_dim,
    uncertainties_pred_dim,
):
    """Fit a Wide + TabMlp model for each objective and check the shapes
    returned by predict, predict_proba and predict_uncertainty."""
    wide_component = Wide(np.unique(X_wide).shape[0], pred_dim)
    tab_mlp = TabMlp(
        column_idx=column_idx,
        cat_embed_input=embed_input,
        continuous_cols=colnames[-5:],
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
    )
    model = WideDeep(wide=wide_component, deeptabular=tab_mlp, pred_dim=pred_dim)

    trainer = Trainer(model, objective=objective, verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)

    preds = trainer.predict(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    probs = trainer.predict_proba(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    unc_preds = trainer.predict_uncertainty(
        X_wide=X_wide, X_tab=X_tab, X_test=X_test, uncertainty_granularity=5
    )

    expected = (32, probs_dim, uncertainties_pred_dim)
    if objective == "regression":
        # for regression probs_dim is None, so probs itself (not a shape) is
        # compared against the sentinel
        observed = (preds.shape[0], probs, unc_preds.shape[1])
    else:
        observed = (preds.shape[0], probs.shape[1], unc_preds.shape[1])
    assert observed == expected


##############################################################################
# Simply Test that runs with the deephead parameter
##############################################################################
def test_fit_with_deephead():
    """Smoke test: WideDeep built with a custom ``deephead`` fits and
    produces predictions of the expected shapes."""
    wide_component = Wide(np.unique(X_wide).shape[0], 1)
    tab_mlp = TabMlp(
        column_idx=column_idx,
        cat_embed_input=embed_input,
        continuous_cols=colnames[-5:],
        mlp_hidden_dims=[32, 16],
    )
    custom_head = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 4))
    model = WideDeep(
        wide=wide_component, deeptabular=tab_mlp, pred_dim=1, deephead=custom_head
    )

    trainer = Trainer(model, objective="binary", verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target_binary, batch_size=16)

    preds = trainer.predict(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    probs = trainer.predict_proba(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    unc_preds = trainer.predict_uncertainty(
        X_wide=X_wide, X_tab=X_tab, X_test=X_test, uncertainty_granularity=5
    )

    assert preds.shape[0] == 32
    assert probs.shape[1] == 2
    assert unc_preds.shape[1] == 3


##############################################################################
# Repeat 1st set of tests with the TabTransformer
##############################################################################


@pytest.mark.parametrize(
    "X_wide, X_tab, target, objective, X_wide_test, X_tab_test, X_test, pred_dim, probs_dim, uncertainties_pred_dim",
    [
        (X_wide, X_tab, target_regres, "regression", X_wide, X_tab, None, 1, None, 4),
        (X_wide, X_tab, target_binary, "binary", X_wide, X_tab, None, 1, 2, 3),
        (X_wide, X_tab, target_multic, "multiclass", X_wide, X_tab, None, 3, 3, 4),
        (X_wide, X_tab, target_regres, "regression", None, None, X_test, 1, None, 4),
        (X_wide, X_tab, target_binary, "binary", None, None, X_test, 1, 2, 3),
        (X_wide, X_tab, target_multic, "multiclass", None, None, X_test, 3, 3, 4),
    ],
)
def test_fit_objectives_tab_transformer(
    X_wide,
    X_tab,
    target,
    objective,
    X_wide_test,
    X_tab_test,
    X_test,
    pred_dim,
    probs_dim,
    uncertainties_pred_dim,
):
    """Fit a Wide + TabTransformer model for each objective and check the
    shapes returned by predict, predict_proba and predict_uncertainty.

    BUGFIX: ``X_wide_test`` and ``X_tab_test`` were parametrized but never
    used — the predict calls always ran on the training arrays, so the two
    parametrized paths (separate test arrays vs a single ``X_test`` dict)
    exercised the same call signature. They are now forwarded to the predict
    methods; the fixture values make the results identical to before.
    """
    wide = Wide(np.unique(X_wide).shape[0], pred_dim)
    tab_transformer = TabTransformer(
        # reuse the module-level mapping (identical to rebuilding it inline)
        column_idx=column_idx,
        cat_embed_input=embed_input_tt,
        continuous_cols=colnames[5:],
    )
    model = WideDeep(wide=wide, deeptabular=tab_transformer, pred_dim=pred_dim)
    trainer = Trainer(model, objective=objective, verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)
    preds = trainer.predict(X_wide=X_wide_test, X_tab=X_tab_test, X_test=X_test)
    probs = trainer.predict_proba(
        X_wide=X_wide_test, X_tab=X_tab_test, X_test=X_test
    )
    unc_preds = trainer.predict_uncertainty(
        X_wide=X_wide_test,
        X_tab=X_tab_test,
        X_test=X_test,
        uncertainty_granularity=5,
    )
    if objective == "regression":
        # regression has no probabilities: probs itself is compared to the
        # sentinel (None), not a shape
        assert (preds.shape[0], probs, unc_preds.shape[1]) == (
            32,
            probs_dim,
            uncertainties_pred_dim,
        )
    else:
        assert (preds.shape[0], probs.shape[1], unc_preds.shape[1]) == (
            32,
            probs_dim,
            uncertainties_pred_dim,
        )


##############################################################################
# Repeat 1st set of tests with TabNet
##############################################################################


@pytest.mark.parametrize(
    "X_wide, X_tab, target, objective, X_wide_test, X_tab_test, X_test, pred_dim, probs_dim, uncertainties_pred_dim",
    [
        (X_wide, X_tab, target_regres, "regression", X_wide, X_tab, None, 1, None, 4),
        (X_wide, X_tab, target_binary, "binary", X_wide, X_tab, None, 1, 2, 3),
        (X_wide, X_tab, target_multic, "multiclass", X_wide, X_tab, None, 3, 3, 4),
        (X_wide, X_tab, target_regres, "regression", None, None, X_test, 1, None, 4),
        (X_wide, X_tab, target_binary, "binary", None, None, X_test, 1, 2, 3),
        (X_wide, X_tab, target_multic, "multiclass", None, None, X_test, 3, 3, 4),
    ],
)
def test_fit_objectives_tabnet(
    X_wide,
    X_tab,
    target,
    objective,
    X_wide_test,
    X_tab_test,
    X_test,
    pred_dim,
    probs_dim,
    uncertainties_pred_dim,
):
    """Fit a Wide + TabNet model for each objective and check the shapes
    returned by predict, predict_proba and predict_uncertainty.

    BUGFIX: ``X_wide_test`` and ``X_tab_test`` were parametrized but never
    used — the predict calls always ran on the training arrays. They are now
    forwarded to the predict methods so both parametrized paths (separate
    test arrays vs a single ``X_test`` dict) are actually exercised. The
    local previously (and misleadingly) named ``tab_transformer`` is renamed
    to ``tabnet``.
    """
    warnings.filterwarnings("ignore")
    wide = Wide(np.unique(X_wide).shape[0], pred_dim)
    tabnet = TabNet(
        # reuse the module-level mapping (identical to rebuilding it inline)
        column_idx=column_idx,
        cat_embed_input=embed_input,
        continuous_cols=colnames[5:],
    )
    model = WideDeep(wide=wide, deeptabular=tabnet, pred_dim=pred_dim)
    trainer = Trainer(model, objective=objective, verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)
    preds = trainer.predict(X_wide=X_wide_test, X_tab=X_tab_test, X_test=X_test)
    probs = trainer.predict_proba(
        X_wide=X_wide_test, X_tab=X_tab_test, X_test=X_test
    )
    unc_preds = trainer.predict_uncertainty(
        X_wide=X_wide_test,
        X_tab=X_tab_test,
        X_test=X_test,
        uncertainty_granularity=5,
    )
    if objective == "regression":
        # regression has no probabilities: probs itself is compared to the
        # sentinel (None), not a shape
        assert (preds.shape[0], probs, unc_preds.shape[1]) == (
            32,
            probs_dim,
            uncertainties_pred_dim,
        )
    else:
        assert (preds.shape[0], probs.shape[1], unc_preds.shape[1]) == (
            32,
            probs_dim,
            uncertainties_pred_dim,
        )


##############################################################################
# Test fit with R2 for regression
##############################################################################


def test_fit_with_regression_and_metric():
    """R2Score passed via ``metrics`` is tracked in the training history."""
    wide_component = Wide(np.unique(X_wide).shape[0], 1)
    tab_mlp = TabMlp(
        column_idx=column_idx,
        cat_embed_input=embed_input,
        continuous_cols=colnames[-5:],
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
    )
    model = WideDeep(wide=wide_component, deeptabular=tab_mlp, pred_dim=1)
    trainer = Trainer(model, objective="regression", metrics=[R2Score], verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target_regres, batch_size=16)
    # the metric is logged under its train-prefixed name
    assert "train_r2" in trainer.history


##############################################################################
# Test aliases
##############################################################################


def test_aliases():
    """``loss`` and ``warmup`` are accepted as aliases of ``objective`` and
    ``finetune``, and the alias usage is recorded on the trainer."""
    wide_component = Wide(np.unique(X_wide).shape[0], 1)
    tab_mlp = TabMlp(
        column_idx=column_idx,
        cat_embed_input=embed_input,
        continuous_cols=colnames[-5:],
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
    )
    model = WideDeep(wide=wide_component, deeptabular=tab_mlp, pred_dim=1)
    trainer = Trainer(model, loss="regression", verbose=0)
    trainer.fit(
        X_wide=X_wide, X_tab=X_tab, target=target_regres, batch_size=16, warmup=True
    )
    aliases_used = trainer.__wd_aliases_used
    assert "train_loss" in trainer.history
    assert aliases_used["objective"] == "loss"
    assert aliases_used["finetune"] == "warmup"


##############################################################################
# Test custom dataloader
##############################################################################


def test_custom_dataloader():
    """Training runs end to end with ``DataLoaderImbalanced`` as the loader."""
    wide_component = Wide(np.unique(X_wide).shape[0], 1)
    tab_mlp = TabMlp(
        column_idx=column_idx,
        cat_embed_input=embed_input,
        continuous_cols=colnames[-5:],
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
    )
    model = WideDeep(wide=wide_component, deeptabular=tab_mlp)
    trainer = Trainer(model, loss="binary", verbose=0)
    trainer.fit(
        X_wide=X_wide,
        X_tab=X_tab,
        target=target_binary_imbalanced,
        batch_size=16,
        custom_dataloader=DataLoaderImbalanced,
    )
    # simply checking that runs with DataLoaderImbalanced
    assert "train_loss" in trainer.history


##############################################################################
# Test raise warning for multiclass classification
##############################################################################


def test_multiclass_warning():
    """Constructing a Trainer with the ``multiclass`` objective around a
    1-dim model raises ``ValueError``.

    NOTE(review): despite the function (and banner) saying "warning", the
    expected outcome is an exception, not a warning — confirm against the
    Trainer's validation logic.
    """
    wide_component = Wide(np.unique(X_wide).shape[0], 1)
    tab_mlp = TabMlp(
        column_idx=column_idx,
        cat_embed_input=embed_input,
        continuous_cols=colnames[-5:],
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
    )
    model = WideDeep(wide=wide_component, deeptabular=tab_mlp)

    with pytest.raises(ValueError):
        trainer = Trainer(model, loss="multiclass", verbose=0)  # noqa: F841