test_mc_deep_text.py 1.7 KB
Newer Older
1 2 3 4
import numpy as np
import torch
import pytest

5
from pytorch_widedeep.models import DeepText
6

J
jrzaurin 已提交
7 8 9 10 11
padded_sequences = np.random.choice(np.arange(1, 100), (100, 48))
padded_sequences = np.hstack(
    (np.repeat(np.array([[0, 0]]), 100, axis=0), padded_sequences)
)
pretrained_embeddings = np.random.rand(1000, 64).astype("float32")
12 13
vocab_size = 1000

J
jrzaurin 已提交
14

15 16 17
###############################################################################
# Without Pretrained Embeddings
###############################################################################
J
jrzaurin 已提交
18 19 20
model1 = DeepText(vocab_size=vocab_size, embed_dim=32, padding_idx=0)


21
def test_deep_test():
J
jrzaurin 已提交
22 23 24
    out = model1(torch.from_numpy(padded_sequences))
    assert out.size(0) == 100 and out.size(1) == 64

25 26 27 28 29

###############################################################################
# With Pretrained Embeddings
###############################################################################
model2 = DeepText(
J
jrzaurin 已提交
30 31 32 33
    vocab_size=vocab_size, embedding_matrix=pretrained_embeddings, padding_idx=0
)


34
def test_deep_test_pretrained():
J
jrzaurin 已提交
35 36 37
    out = model2(torch.from_numpy(padded_sequences))
    assert out.size(0) == 100 and out.size(1) == 64

38 39 40 41 42

###############################################################################
# Make sure it throws a warning
###############################################################################
def test_catch_warning():
J
jrzaurin 已提交
43 44 45 46 47 48 49 50 51
    with pytest.warns(UserWarning):
        model3 = DeepText(
            vocab_size=vocab_size,
            embed_dim=32,
            embedding_matrix=pretrained_embeddings,
            padding_idx=0,
        )
    out = model3(torch.from_numpy(padded_sequences))
    assert out.size(0) == 100 and out.size(1) == 64