test_fit_methods.py 5.0 KB
Newer Older
1
import string
2 3

import numpy as np
4
import pytest
5
from torch import nn
6

7
from pytorch_widedeep.models import Wide, WideDeep, DeepDense, TabTransformer
8 9

# Wide array: 32 rows x 10 columns of integer codes drawn from [0, 50)
X_wide = np.random.choice(50, (32, 10))

# Deep (tabular) array: 5 categorical columns ("a"-"e") with values in
# [0, 5), followed by 5 continuous columns ("f"-"j") with values in [0, 1)
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 32) for _ in range(5)]
# (column name, 5, 16) per categorical column for DeepDense; 5 matches the
# number of distinct categories generated above, 16 is presumably the
# embedding dimension
embed_input = [(col, 5, 16) for col in colnames[:5]]
# (column name, 5) per categorical column for TabTransformer
embed_input_tt = [(col, 5) for col in colnames[:5]]
cont_cols = [np.random.rand(32) for _ in range(5)]
# column name -> column position in X_tab
deep_column_idx = {k: v for v, k in enumerate(colnames)}
# stack the 10 columns and transpose to shape (32, 10)
X_tab = np.vstack(embed_cols + cont_cols).transpose()

# Targets for each of the three methods under test
target_regres = np.random.random(32)
target_binary = np.random.choice(2, 32)
target_multic = np.random.choice(3, 32)

# Test dictionary: the same inputs packed for the `X_test` argument
X_test = {"X_wide": X_wide, "X_tab": X_tab}
J
jrzaurin 已提交
28

29 30

##############################################################################
31
# Test that the three possible methods (regression, binary and multiclass)
32 33 34
# work well
##############################################################################
@pytest.mark.parametrize(
    "X_wide, X_tab, target, method, X_wide_test, X_tab_test, X_test, pred_dim, probs_dim",
    [
        (X_wide, X_tab, target_regres, "regression", X_wide, X_tab, None, 1, None),
        (X_wide, X_tab, target_binary, "binary", X_wide, X_tab, None, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", X_wide, X_tab, None, 3, 3),
        (X_wide, X_tab, target_regres, "regression", None, None, X_test, 1, None),
        (X_wide, X_tab, target_binary, "binary", None, None, X_test, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", None, None, X_test, 3, 3),
    ],
)
def test_fit_methods(
    X_wide,
    X_tab,
    target,
    method,
    X_wide_test,
    X_tab_test,
    X_test,
    pred_dim,
    probs_dim,
):
    """Fit a Wide + DeepDense model for each method and check output shapes.

    Each method is run twice. One case predicts from the separate
    ``X_wide_test``/``X_tab_test`` arrays; the other predicts only from the
    ``X_test`` dict. ``probs_dim`` is ``None`` for regression, where no
    class probabilities are checked.
    """
    wide = Wide(np.unique(X_wide).shape[0], pred_dim)
    deeptabular = DeepDense(
        hidden_layers=[32, 16],
        dropout=[0.5, 0.5],
        deep_column_idx=deep_column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    model = WideDeep(wide=wide, deeptabular=deeptabular, pred_dim=pred_dim)
    model.compile(method=method, verbose=0)
    model.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)

    preds = model.predict(X_wide=X_wide_test, X_tab=X_tab_test, X_test=X_test)
    assert preds.shape[0] == 32

    # `assert a, b` would treat `b` as the failure message and never check
    # it, so the probability shape gets its own assertion.
    if probs_dim is not None:
        probs = model.predict_proba(
            X_wide=X_wide_test, X_tab=X_tab_test, X_test=X_test
        )
        assert probs.shape == (32, probs_dim)


##############################################################################
# Simply test that the model runs with the deephead parameter
##############################################################################
def test_fit_with_deephead():
    """Fit a binary Wide + DeepDense model with a custom ``deephead``.

    Checks the shapes of the predictions and the class probabilities.
    """
    wide = Wide(np.unique(X_wide).shape[0], 1)
    deeptabular = DeepDense(
        hidden_layers=[32, 16],
        deep_column_idx=deep_column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    # The first Linear takes 16 inputs, matching the last hidden layer above.
    # Presumably it consumes the deeptabular output; confirm against WideDeep.
    deephead = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 4))
    model = WideDeep(wide=wide, deeptabular=deeptabular, pred_dim=1, deephead=deephead)
    model.compile(method="binary", verbose=0)
    model.fit(X_wide=X_wide, X_tab=X_tab, target=target_binary, batch_size=16)
    preds = model.predict(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    probs = model.predict_proba(X_wide=X_wide, X_tab=X_tab, X_test=X_test)
    # Two separate asserts: `assert a, b` would only use `b` as the message.
    assert preds.shape[0] == 32
    assert probs.shape == (32, 2)
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136


##############################################################################
# Repeat 1st set of tests with the TabTransformer
##############################################################################


@pytest.mark.parametrize(
    "X_wide, X_tab, target, method, X_wide_test, X_tab_test, X_test, pred_dim, probs_dim",
    [
        (X_wide, X_tab, target_regres, "regression", X_wide, X_tab, None, 1, None),
        (X_wide, X_tab, target_binary, "binary", X_wide, X_tab, None, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", X_wide, X_tab, None, 3, 3),
        (X_wide, X_tab, target_regres, "regression", None, None, X_test, 1, None),
        (X_wide, X_tab, target_binary, "binary", None, None, X_test, 1, 2),
        (X_wide, X_tab, target_multic, "multiclass", None, None, X_test, 3, 3),
    ],
)
def test_fit_methods_tab_transformer(
    X_wide,
    X_tab,
    target,
    method,
    X_wide_test,
    X_tab_test,
    X_test,
    pred_dim,
    probs_dim,
):
    """Fit a Wide + TabTransformer model for each method and check shapes.

    This mirrors the DeepDense method tests, using a TabTransformer as the
    deeptabular component. ``probs_dim`` is ``None`` for regression, where
    no class probabilities are checked.
    """
    wide = Wide(np.unique(X_wide).shape[0], pred_dim)
    # Reuse the module-level column index and continuous-column slice so both
    # test families see exactly the same column layout.
    tab_transformer = TabTransformer(
        deep_column_idx=deep_column_idx,
        embed_input=embed_input_tt,
        continuous_cols=colnames[-5:],
    )
    model = WideDeep(wide=wide, deeptabular=tab_transformer, pred_dim=pred_dim)
    model.compile(method=method, verbose=0)
    model.fit(X_wide=X_wide, X_tab=X_tab, target=target, batch_size=16)

    preds = model.predict(X_wide=X_wide_test, X_tab=X_tab_test, X_test=X_test)
    assert preds.shape[0] == 32

    # `assert a, b` would treat `b` as the failure message and never check
    # it, so the probability shape gets its own assertion.
    if probs_dim is not None:
        probs = model.predict_proba(
            X_wide=X_wide_test, X_tab=X_tab_test, X_test=X_test
        )
        assert probs.shape == (32, probs_dim)