import string
import warnings

import numpy as np
import pytest
from torch import nn

from pytorch_widedeep.models import (
    Wide,
    TabMlp,
    TabNet,
    WideDeep,
    TabTransformer,
)
from pytorch_widedeep.metrics import R2Score
from pytorch_widedeep.training import Trainer

# Wide array
X_wide = np.random.choice(50, (32, 10))

# Deep array
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 32) for _ in range(5)]
embed_input = [(u, i, j) for u, i, j in zip(colnames[:5], [5] * 5, [16] * 5)]
embed_input_tt = [(u, i) for u, i in zip(colnames[:5], [5] * 5)]
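# embed_input holds (col_name, n_unique, embed_dim) 3-tuples, as expected by
# TabMlp/TabNet; embed_input_tt holds (col_name, n_unique) 2-tuples, since
# TabTransformer uses one shared embedding dim for all categorical columns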
cont_cols = [np.random.rand(32) for _ in range(5)]
column_idx = {k: v for v, k in enumerate(colnames)}
X_tab = np.vstack(embed_cols + cont_cols).transpose()
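# X_tab has shape (32, 10): one row per sample, with the 5 embedding columns
# ("a"-"e") first and the 5 continuous columns ("f"-"j") last, as in column_idx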

# Target
target_regres = np.random.random(32)
target_binary = np.random.choice(2, 32)
target_multic = np.random.choice(3, 32)

# Test dictionary
X_test = {"X_wide": X_wide, "X_tab": X_tab}
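# bundling both matrices into a single dict exercises the alternative predict
# API (trainer.predict(X_test=...)) in the parametrized tests below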


##############################################################################
# Test that the three possible objectives (regression, binary and multiclass)
# run end to end
##############################################################################
@pytest.mark.parametrize(
    "X_wide, X_tab, target, objective, X_wide_test, X_tab_test, X_test, pred_dim, probs_dim",
    [
        (X_wide, X_tab, target_regres, "regression", X_wide, X_tab, None, 1, None),
        (X_wide, X_tab, target_binary, "binary", X_wide, X_tab, None, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", X_wide, X_tab, None, 3, 3),
        (X_wide, X_tab, target_regres, "regression", None, None, X_test, 1, None),
        (X_wide, X_tab, target_binary, "binary", None, None, X_test, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", None, None, X_test, 3, 3),
    ],
)
def test_fit_objectives(
    X_wide,
    X_tab,
    target,
    objective,
    X_wide_test,
    X_tab_test,
    X_test,
    pred_dim,
    probs_dim,
):
    wide = Wide(np.unique(X_wide).shape[0], pred_dim)
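    # Wide's first argument is the size of the wide "vocabulary" (the number
    # of unique values in X_wide); pred_dim is 1 for regression/binary and
    # n_classes for multiclass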
    deeptabular = TabMlp(
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    model = WideDeep(wide=wide, deeptabular=deeptabular, pred_dim=pred_dim)
    trainer = Trainer(model, objective=objective, verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)
    preds = trainer.predict(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    # probabilities only make sense for the classification objectives
    if objective == "regression":
        probs = None
    else:
        probs = trainer.predict_proba(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    assert preds.shape[0] == 32
    if probs is not None:
        assert probs.shape[1] == probs_dim


##############################################################################
# Simply test that the Trainer runs with the deephead parameter
##############################################################################
def test_fit_with_deephead():
    wide = Wide(np.unique(X_wide).shape[0], 1)
    deeptabular = TabMlp(
        mlp_hidden_dims=[32, 16],
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
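    # a custom head runs on top of the deeptabular output, so its first Linear
    # layer must take TabMlp's last hidden dim (16) as input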
    deephead = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 4))
    model = WideDeep(wide=wide, deeptabular=deeptabular, pred_dim=1, deephead=deephead)
    trainer = Trainer(model, objective="binary", verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target_binary, batch_size=16)
    preds = trainer.predict(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    probs = trainer.predict_proba(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    assert preds.shape[0] == 32 and probs.shape[1] == 2


##############################################################################
# Repeat 1st set of tests with the TabTransformer
##############################################################################


@pytest.mark.parametrize(
    "X_wide, X_tab, target, objective, X_wide_test, X_tab_test, X_test, pred_dim, probs_dim",
    [
        (X_wide, X_tab, target_regres, "regression", X_wide, X_tab, None, 1, None),
        (X_wide, X_tab, target_binary, "binary", X_wide, X_tab, None, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", X_wide, X_tab, None, 3, 3),
        (X_wide, X_tab, target_regres, "regression", None, None, X_test, 1, None),
        (X_wide, X_tab, target_binary, "binary", None, None, X_test, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", None, None, X_test, 3, 3),
    ],
)
def test_fit_objectives_tab_transformer(
    X_wide,
    X_tab,
    target,
    objective,
    X_wide_test,
    X_tab_test,
    X_test,
    pred_dim,
    probs_dim,
):
    wide = Wide(np.unique(X_wide).shape[0], pred_dim)
    tab_transformer = TabTransformer(
        column_idx=column_idx,
        embed_input=embed_input_tt,
        continuous_cols=colnames[5:],
    )
    model = WideDeep(wide=wide, deeptabular=tab_transformer, pred_dim=pred_dim)
    trainer = Trainer(model, objective=objective, verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)
    preds = trainer.predict(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    # probabilities only make sense for the classification objectives
    if objective == "regression":
        probs = None
    else:
        probs = trainer.predict_proba(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    assert preds.shape[0] == 32
    if probs is not None:
        assert probs.shape[1] == probs_dim


##############################################################################
# Repeat 1st set of tests with TabNet
##############################################################################


@pytest.mark.parametrize(
    "X_wide, X_tab, target, objective, X_wide_test, X_tab_test, X_test, pred_dim, probs_dim",
    [
        (X_wide, X_tab, target_regres, "regression", X_wide, X_tab, None, 1, None),
        (X_wide, X_tab, target_binary, "binary", X_wide, X_tab, None, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", X_wide, X_tab, None, 3, 3),
        (X_wide, X_tab, target_regres, "regression", None, None, X_test, 1, None),
        (X_wide, X_tab, target_binary, "binary", None, None, X_test, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", None, None, X_test, 3, 3),
    ],
)
def test_fit_objectives_tabnet(
    X_wide,
    X_tab,
    target,
    objective,
    X_wide_test,
    X_tab_test,
    X_test,
    pred_dim,
    probs_dim,
):
    warnings.filterwarnings("ignore")
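    # silence warnings emitted while fitting TabNet; they are not what this
    # test asserts on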
    wide = Wide(np.unique(X_wide).shape[0], pred_dim)
    tabnet = TabNet(
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[5:],
    )
    model = WideDeep(wide=wide, deeptabular=tabnet, pred_dim=pred_dim)
    trainer = Trainer(model, objective=objective, verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)
    preds = trainer.predict(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    # probabilities only make sense for the classification objectives
    if objective == "regression":
        probs = None
    else:
        probs = trainer.predict_proba(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    assert preds.shape[0] == 32
    if probs is not None:
        assert probs.shape[1] == probs_dim


##############################################################################
# Test fit with R2 for regression
##############################################################################


def test_fit_with_regression_and_metric():
    wide = Wide(np.unique(X_wide).shape[0], 1)
    deeptabular = TabMlp(
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    model = WideDeep(wide=wide, deeptabular=deeptabular, pred_dim=1)
    trainer = Trainer(model, objective="regression", metrics=[R2Score], verbose=0)
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target_regres, batch_size=16)
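    # metrics are logged in trainer.history under their name prefixed with
    # "train_" (R2Score registers as "r2")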
    assert "train_r2" in trainer.history.keys()


##############################################################################
# Test aliases
##############################################################################


def test_aliases():
    wide = Wide(np.unique(X_wide).shape[0], 1)
    deeptabular = TabMlp(
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    model = WideDeep(wide=wide, deeptabular=deeptabular, pred_dim=1)
    trainer = Trainer(model, loss="regression", verbose=0)
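    # "loss" is an alias for "objective" and "warmup" for "finetune"; the
    # Trainer records which aliases were used in __wd_aliases_used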
    trainer.fit(
        X_wide=X_wide, X_tab=X_tab, target=target_regres, batch_size=16, warmup=True
    )
    assert (
        "train_loss" in trainer.history.keys()
        and trainer.__wd_aliases_used["objective"] == "loss"
        and trainer.__wd_aliases_used["finetune"] == "warmup"
    )