diff --git a/tests/test_trainer.py b/tests/test_trainer.py index f1f5c876314337a845721f352e3ec3767b0fa4ec..8c83008c5718f7f5eb79941c2a2a2a27563a4043 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -13,10 +13,9 @@ EXAMPLE_DIR = Path(__file__).parent.parent / 'examples' INDEX_DIR = Path(__file__).parent.parent / 'data/test' IMAGE_DIR = Path(__file__).parent.parent / 'data/images' +from cnocr import gen_model from cnocr.data_utils.aug import NormalizeAug from cnocr.dataset import OcrDataModule -from cnocr.models.densenet import DenseNet -from cnocr.models.crnn import CRNN from cnocr.trainer import PlTrainer train_transform = transforms.Compose( @@ -31,12 +30,6 @@ train_transform = transforms.Compose( val_transform = NormalizeAug() -def gen_model(vocab): - net = DenseNet(32, [2, 2, 2, 2], 64) - crnn = CRNN(net, vocab=vocab, lstm_features=512, rnn_units=128) - return crnn - - def test_trainer(): data_mod = OcrDataModule( index_dir=INDEX_DIR, @@ -48,7 +41,6 @@ def test_trainer(): num_workers=0, pin_memory=False, ) - # data_mod.setup() config = { 'epochs': 2, @@ -64,5 +56,5 @@ def test_trainer(): "pl_checkpoint_mode": "max", } trainer = PlTrainer(config) - model = gen_model(data_mod.vocab) + model = gen_model('densenet-s-lstm', data_mod.vocab) trainer.fit(model, datamodule=data_mod)