test_mc_tab_mlp.py 2.6 KB
Newer Older
1
import string
2

3 4
import numpy as np
import torch
5
import pytest
6

7
from pytorch_widedeep.models import TabMlp
8 9 10 11 12

colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 10) for _ in range(5)]
cont_cols = [np.random.rand(10) for _ in range(5)]

J
jrzaurin 已提交
13
X_deep = torch.from_numpy(np.vstack(embed_cols + cont_cols).transpose())
14 15 16
X_deep_emb = X_deep[:, :5]
X_deep_cont = X_deep[:, 5:]

J
jrzaurin 已提交
17

18 19 20
###############################################################################
# Embeddings and NO continuous_cols
###############################################################################
J
jrzaurin 已提交
21
embed_input = [(u, i, j) for u, i, j in zip(colnames[:5], [5] * 5, [16] * 5)]
22 23
model1 = TabMlp(
    mlp_hidden_dims=[32, 16],
24 25
    mlp_dropout=[0.5, 0.2],
    column_idx={k: v for v, k in enumerate(colnames[:5])},
J
jrzaurin 已提交
26 27 28
    embed_input=embed_input,
)

29 30

def test_deep_dense_embed():
J
jrzaurin 已提交
31 32 33
    out = model1(X_deep_emb)
    assert out.size(0) == 10 and out.size(1) == 16

34 35 36 37

###############################################################################
# Continous cols but NO embeddings
###############################################################################
J
jrzaurin 已提交
38
continuous_cols = colnames[-5:]
39 40
model2 = TabMlp(
    mlp_hidden_dims=[32, 16],
41 42
    mlp_dropout=[0.5, 0.2],
    column_idx={k: v for v, k in enumerate(colnames[5:])},
J
jrzaurin 已提交
43 44 45
    continuous_cols=continuous_cols,
)

46 47

def test_deep_dense_cont():
J
jrzaurin 已提交
48 49 50
    out = model2(X_deep_cont)
    assert out.size(0) == 10 and out.size(1) == 16

51 52

###############################################################################
53
# All parameters
54
###############################################################################
55
model3 = TabMlp(
56
    column_idx={k: v for v, k in enumerate(colnames)},
57
    mlp_hidden_dims=[32, 16, 8],
58
    mlp_dropout=0.1,
59 60
    mlp_batchnorm=True,
    mlp_batchnorm_last=False,
61
    mlp_linear_first=False,
J
jrzaurin 已提交
62
    embed_input=embed_input,
63
    embed_dropout=0.1,
J
jrzaurin 已提交
64
    continuous_cols=continuous_cols,
65
    batchnorm_cont=True,
J
jrzaurin 已提交
66 67
)

68 69

def test_deep_dense():
J
jrzaurin 已提交
70
    out = model3(X_deep)
71 72 73 74 75 76 77 78 79 80 81 82
    assert out.size(0) == 10 and out.size(1) == 8


###############################################################################
# Test raise ValueError
###############################################################################


def test_act_fn_ValueError():
    with pytest.raises(ValueError):
        model4 = TabMlp(  # noqa: F841
            mlp_hidden_dims=[32, 16],
83 84 85
            mlp_dropout=[0.5, 0.2],
            mlp_activation="javier",
            column_idx={k: v for v, k in enumerate(colnames)},
86 87 88
            embed_input=embed_input,
            continuous_cols=continuous_cols,
        )