test_models.py 3.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
# coding: utf-8
import os
import sys
import logging
from copy import deepcopy
import pytest
import mxnet as mx
from mxnet import nd

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))

from cnocr.hyperparams.cn_hyperparams import CnHyperparams
from cnocr.symbols.densenet import _make_dense_layer, DenseNet, cal_num_params
from cnocr.symbols.crnn import (
    CRnn,
    pipline,
    gen_network,
    get_infer_shape,
    crnn_lstm,
    crnn_lstm_lite,
)

# Log line format: timestamp (padded to 15 chars) followed by the message.
head = '%(asctime)-15s %(message)s'
# Configure the root logger at import time so every test's output is visible
# at DEBUG level.
logging.basicConfig(format=head, level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Shared default hyperparameters. Tests that change the sequence length first
# take a deepcopy of this instance rather than editing it directly.
HP = CnHyperparams()


def test_dense_layer():
    """Smoke-test one layer built by ``_make_dense_layer``.

    Runs a forward pass on a random 4-D batch and checks that the result is
    still 4-D and keeps the input's batch dimension.
    """
    x = nd.random.randn(128, 64, 32, 280)
    # NOTE(review): argument meanings (presumably growth rate, bottleneck
    # multiplier, dropout rate) — confirm against cnocr.symbols.densenet.
    net = _make_dense_layer(64, 2, 0.1)
    net.initialize()
    y = net(x)
    logger.info(net)
    logger.info(y.shape)
    # Without these checks the test passed for any output whatsoever.
    assert len(y.shape) == 4
    assert y.shape[0] == x.shape[0]


def test_densenet():
    """Forward a random batch through ``DenseNet`` and check the height collapses to 1.

    Inline notes from a reference run record an output shape of
    ``(128, 512, 1, 69)`` and 1748224 parameters for this configuration.
    """
    inputs = nd.random.randn(128, 64, 32, 280)
    model = DenseNet((64, 128, 256, 512))
    model.initialize()
    out = model(inputs)
    logger.info(model)
    logger.info(out.shape)
    # The feature map's height axis must be reduced to a single row.
    assert out.shape[2] == 1
    n_params = cal_num_params(model)
    logger.info('number of parameters: %d', n_params)


def test_crnn():
    """Check the imperative ``CRnn`` output shape for two DenseNet widths.

    With ``seq_length = img_width // 4 - 1`` the network output is expected to
    be ``(seq_length, batch_size, 2 * num_hidden)``.
    """
    hp = deepcopy(HP)
    hp.set_seq_length(hp.img_width // 4 - 1)
    # Derive batch size and width from hp so the input agrees with the values
    # the assertion (and the seq_length above) are computed from.
    x = nd.random.randn(hp.batch_size, 64, 32, hp.img_width)
    layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
    for layer_channels in layer_channels_list:
        densenet = DenseNet(layer_channels)
        crnn = CRnn(hp, densenet)
        crnn.initialize()
        y = crnn(x)
        logger.info(
            'output shape: %s', y.shape
        )  # res: `(sequence_length, batch_size, 2*num_hidden)`
        assert y.shape == (hp.seq_length, hp.batch_size, 2 * hp.num_hidden)
        logger.info('number of parameters: %d', cal_num_params(crnn))


def test_crnn_lstm():
    """Check the output shape of the symbolic ``crnn_lstm`` network.

    The expected sequence length is ``img_width // 8``. The network is built
    from the same adjusted ``hp`` copy that the assertion reads. Before, the
    global ``HP`` (with its unadjusted sequence length) was passed in.
    """
    hp = deepcopy(HP)
    hp.set_seq_length(hp.img_width // 8)
    # Input batch size / width follow hp so they match the asserted shape.
    data = mx.sym.Variable('data', shape=(hp.batch_size, 1, 32, hp.img_width))
    pred = crnn_lstm(hp, data)
    # infer_shape() returns (arg_shapes, out_shapes, aux_shapes); take the
    # first output shape.
    pred_shape = pred.infer_shape()[1][0]
    logger.info('shape of pred: %s', pred_shape)
    assert pred_shape == (hp.seq_length, hp.batch_size, 2 * hp.num_hidden)


def test_crnn_lstm_lite():
    """Check the output shape of the symbolic ``crnn_lstm_lite`` network.

    The expected sequence length is ``img_width // 4 - 1``. The network is
    built from the same adjusted ``hp`` copy that the assertion reads. Before,
    the global ``HP`` (with its unadjusted sequence length) was passed in.
    """
    hp = deepcopy(HP)
    hp.set_seq_length(hp.img_width // 4 - 1)
    # Input batch size / width follow hp so they match the asserted shape.
    data = mx.sym.Variable('data', shape=(hp.batch_size, 1, 32, hp.img_width))
    pred = crnn_lstm_lite(hp, data)
    # infer_shape() returns (arg_shapes, out_shapes, aux_shapes); take the
    # first output shape.
    pred_shape = pred.infer_shape()[1][0]
    logger.info('shape of pred: %s', pred_shape)
    assert pred_shape == (hp.seq_length, hp.batch_size, 2 * hp.num_hidden)


def test_pipline():
    """Check the prediction shape produced by ``pipline`` for two DenseNet widths.

    For each channel configuration a DenseNet-backed ``CRnn`` is passed to
    ``pipline(..., need_pred=True)``. The returned prediction symbol must
    infer to ``(batch_size * seq_length, num_classes)``.
    """
    hp = deepcopy(HP)
    hp.set_seq_length(hp.img_width // 4 - 1)
    # The input symbol is the same for every configuration, so build it once.
    # Its batch size / width follow hp so they agree with the assertion.
    data = mx.sym.Variable('data', shape=(hp.batch_size, 1, 32, hp.img_width))
    layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
    for layer_channels in layer_channels_list:
        densenet = DenseNet(layer_channels)
        crnn = CRnn(hp, densenet)
        _, pred = pipline(crnn, hp, data, need_pred=True)  # loss is unused here
        # infer_shape() returns (arg_shapes, out_shapes, aux_shapes).
        pred_shape = pred.infer_shape()[1][0]
        logger.info('shape of pred: %s', pred_shape)
        assert pred_shape == (hp.batch_size * hp.seq_length, hp.num_classes)


@pytest.mark.parametrize(
    'model_name', ['conv-rnn', 'conv-rnn-lite', 'densenet-rnn', 'densenet-rnn-lite']
)
def test_gen_networks(model_name):
    """Build each named network via ``gen_network`` and check the FC output shape.

    ``gen_network`` returns a network and an ``hp`` object. The expected
    shape is computed from that returned ``hp``, so shape inference uses it
    too. The module-level ``HP`` is not used here.
    """
    logger.info('model_name: %s', model_name)
    # Pass a copy in case gen_network adjusts the hyperparameters in place;
    # this keeps the shared HP unchanged for other tests and parameter cases.
    network, hp = gen_network(model_name, deepcopy(HP))
    shape_dict = get_infer_shape(network, hp)
    logger.info('shape_dict: %s', shape_dict)
    assert shape_dict['pred_fc_output'] == (
        hp.batch_size * hp.seq_length,
        hp.num_classes,
    )