提交 dadb7960 编写于 作者: B breezedeus

refactor to OcrModel

上级 2adae2d3
......@@ -13,10 +13,9 @@ EXAMPLE_DIR = Path(__file__).parent.parent / 'examples'
INDEX_DIR = Path(__file__).parent.parent / 'data/test'
IMAGE_DIR = Path(__file__).parent.parent / 'data/images'
from cnocr import gen_model
from cnocr.data_utils.aug import NormalizeAug
from cnocr.dataset import OcrDataModule
from cnocr.models.densenet import DenseNet
from cnocr.models.crnn import CRNN
from cnocr.trainer import PlTrainer
train_transform = transforms.Compose(
......@@ -31,12 +30,6 @@ train_transform = transforms.Compose(
val_transform = NormalizeAug()
def gen_model(vocab):
net = DenseNet(32, [2, 2, 2, 2], 64)
crnn = CRNN(net, vocab=vocab, lstm_features=512, rnn_units=128)
return crnn
def test_trainer():
data_mod = OcrDataModule(
index_dir=INDEX_DIR,
......@@ -48,7 +41,6 @@ def test_trainer():
num_workers=0,
pin_memory=False,
)
# data_mod.setup()
config = {
'epochs': 2,
......@@ -64,5 +56,5 @@ def test_trainer():
"pl_checkpoint_mode": "max",
}
trainer = PlTrainer(config)
model = gen_model(data_mod.vocab)
model = gen_model('densenet-s-lstm', data_mod.vocab)
trainer.fit(model, datamodule=data_mod)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册