diff --git a/cnocr/symbols/crnn.py b/cnocr/symbols/crnn.py index 394189101d301a6c957fe10071eebf67a83c2f53..f5914e891b82578a59e1047a808a93dc72eb585c 100644 --- a/cnocr/symbols/crnn.py +++ b/cnocr/symbols/crnn.py @@ -42,6 +42,7 @@ def gen_network(model_name, hp): else (64, 128, 256, 512) ) densenet = DenseNet(layer_channels) + densenet.hybridize() model = CRnn(hp, densenet) elif model_name.startswith('conv-lite'): hp.seq_len_cmpr_ratio = 4 diff --git a/tests/test_models.py b/tests/test_models.py index 75f303990da9ba64eb7a2ce80fa2a9b4e184d7a2..8c6d834b3d80028ce4e0dd01b35ac1dba7d95e80 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,6 @@ # coding: utf-8 import os import sys -import logging from copy import deepcopy import pytest import mxnet as mx @@ -44,7 +43,7 @@ def test_densenet(): net.initialize() y = net(x) logger.info(net) - logger.info(y.shape) # (128, 512, 1, 69) + logger.info(y.shape) # (128, 512, 1, 70) assert y.shape[2] == 1 logger.info('number of parameters: %d', cal_num_params(net)) # 1748224