test_mc_deep_text.py 3.1 KB
Newer Older
1 2 3 4
import numpy as np
import torch
import pytest

5
from pytorch_widedeep.models import DeepText
6

J
jrzaurin 已提交
7 8 9 10 11
padded_sequences = np.random.choice(np.arange(1, 100), (100, 48))
padded_sequences = np.hstack(
    (np.repeat(np.array([[0, 0]]), 100, axis=0), padded_sequences)
)
pretrained_embeddings = np.random.rand(1000, 64).astype("float32")
12 13
vocab_size = 1000

J
jrzaurin 已提交
14

15 16 17
###############################################################################
# Without Pretrained Embeddings
###############################################################################
J
jrzaurin 已提交
18 19 20
model1 = DeepText(vocab_size=vocab_size, embed_dim=32, padding_idx=0)


21
def test_deep_text():
J
jrzaurin 已提交
22 23 24
    out = model1(torch.from_numpy(padded_sequences))
    assert out.size(0) == 100 and out.size(1) == 64

25 26 27 28 29

###############################################################################
# With Pretrained Embeddings
###############################################################################
model2 = DeepText(
30
    vocab_size=vocab_size, embed_matrix=pretrained_embeddings, padding_idx=0
J
jrzaurin 已提交
31 32 33
)


34
def test_deep_text_pretrained():
J
jrzaurin 已提交
35 36 37
    out = model2(torch.from_numpy(padded_sequences))
    assert out.size(0) == 100 and out.size(1) == 64

38 39

###############################################################################
40 41
# Make sure it throws a UserWarning when the input embedding dimension and the
# dimension of the pretrained embeddings do not match.
42 43
###############################################################################
def test_catch_warning():
J
jrzaurin 已提交
44 45 46 47
    with pytest.warns(UserWarning):
        model3 = DeepText(
            vocab_size=vocab_size,
            embed_dim=32,
48
            embed_matrix=pretrained_embeddings,
J
jrzaurin 已提交
49 50 51 52
            padding_idx=0,
        )
    out = model3(torch.from_numpy(padded_sequences))
    assert out.size(0) == 100 and out.size(1) == 64
53 54 55 56 57 58 59


###############################################################################
# Without Pretrained Embeddings and head layers
###############################################################################

model4 = DeepText(
60
    vocab_size=vocab_size, embed_dim=32, padding_idx=0, head_layers_dim=[64, 16]
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
)


def test_deep_text_head_layers():
    out = model4(torch.from_numpy(padded_sequences))
    assert out.size(0) == 100 and out.size(1) == 16


###############################################################################
# Without Pretrained Embeddings, bidirectional
###############################################################################

model5 = DeepText(
    vocab_size=vocab_size, embed_dim=32, padding_idx=0, bidirectional=True
)


def test_deep_text_bidirectional():
    out = model1(torch.from_numpy(padded_sequences))
    assert out.size(0) == 100 and out.size(1) == 64
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97


###############################################################################
# Pretrained Embeddings made non-trainable
###############################################################################

model6 = DeepText(
    vocab_size=vocab_size,
    embed_matrix=pretrained_embeddings,
    embed_trainable=False,
    padding_idx=0,
)


def test_embed_non_trainable():
    out = model6(torch.from_numpy(padded_sequences))  # noqa: F841
    assert np.allclose(model6.word_embed.weight.numpy(), pretrained_embeddings)