test_mc_deep_text.py 1.5 KB
Newer Older
1 2 3 4
import numpy as np
import torch
import pytest

5
from pytorch_widedeep.models import DeepText
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21

padded_sequences = np.random.choice(np.arange(1,100), (100, 48))
padded_sequences = np.hstack((np.repeat(np.array([[0,0]]), 100, axis=0), padded_sequences))
pretrained_embeddings = np.random.rand(1000, 64)
vocab_size = 1000

###############################################################################
# Without Pretrained Embeddings
###############################################################################
model1 = DeepText(
    vocab_size=vocab_size,
    embed_dim=32,
    padding_idx=0
    )
def test_deep_test():
	out = model1(torch.from_numpy(padded_sequences))
22
	assert out.size(0)==100 and out.size(1)==64
23 24 25 26 27 28 29 30 31 32 33

###############################################################################
# With Pretrained Embeddings
###############################################################################
model2 = DeepText(
    vocab_size=vocab_size,
    embedding_matrix=pretrained_embeddings,
    padding_idx=0
    )
def test_deep_test_pretrained():
	out = model2(torch.from_numpy(padded_sequences))
34
	assert out.size(0)==100 and out.size(1)==64
35 36 37 38 39 40 41 42 43 44 45 46

###############################################################################
# Make sure it throws a warning
###############################################################################
def test_catch_warning():
	with pytest.warns(UserWarning):
		model3 = DeepText(
		    vocab_size=vocab_size,
		    embed_dim=32,
		    embedding_matrix=pretrained_embeddings,
		    padding_idx=0
		    )