提交 041490b3 编写于 作者: B breezedeus

add `hybridize()`

上级 e806e453
......@@ -42,6 +42,7 @@ def gen_network(model_name, hp):
else (64, 128, 256, 512)
)
densenet = DenseNet(layer_channels)
densenet.hybridize()
model = CRnn(hp, densenet)
elif model_name.startswith('conv-lite'):
hp.seq_len_cmpr_ratio = 4
......
# coding: utf-8
import os
import sys
import logging
from copy import deepcopy
import pytest
import mxnet as mx
......@@ -44,7 +43,7 @@ def test_densenet():
net.initialize()
y = net(x)
logger.info(net)
logger.info(y.shape) # (128, 512, 1, 69)
logger.info(y.shape) # (128, 512, 1, 70)
assert y.shape[2] == 1
logger.info('number of parameters: %d', cal_num_params(net)) # 1748224
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册