Commit b2dcae59 authored by breezedeus

replace the last max-pooling with a conv to keep the image length unchanged

Parent a2ab13e2
@@ -35,7 +35,7 @@ def gen_network(model_name, hp):
model_name = model_name.lower()
if model_name.startswith('densenet'):
hp.seq_len_cmpr_ratio = 4
hp.set_seq_length(hp.img_width // 4 - 1)
hp.set_seq_length(hp.img_width // 4)
layer_channels = (
(32, 64, 128, 256)
if model_name.startswith('densenet-lite')
@@ -288,7 +288,7 @@ def crnn_lstm_lite(hp, data):
# print('4', net.infer_shape()[1])
net = bottle_conv(4, net, kernel_size[4], layer_size[4], padding_size[4])
net = bottle_conv(5, net, kernel_size[5], layer_size[5], padding_size[5], True) + x
    # res: bz x 512 x 1 x 35, the height becomes 1 because no padding is used in the pooling
    # res: bz x 512 x 4 x 69, the length drops from 70 to 69 because no padding is used in the pooling
net = mx.symbol.Pooling(
data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 1)
)
......
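Side note on the shape arithmetic in the updated comment: it follows the usual output-size rule for convolution/pooling. The out_size helper below is hypothetical (not part of the repo) and only illustrates why the unpadded pool shrinks the width from 70 to 69, while a kernel-3 conv with padding 1, as used in the new _make_last_transition further down, keeps the width unchanged.

def out_size(size, kernel, stride, padding=0):
    # floor((size + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

assert out_size(70, kernel=2, stride=1, padding=0) == 69  # unpadded pooling along the width: 70 -> 69
assert out_size(70, kernel=3, stride=1, padding=1) == 70  # padded conv along the width: 70 -> 70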
@@ -29,11 +29,12 @@ logger = logging.getLogger(__name__)
def cal_num_params(net):
import numpy as np
params = [p for p in net.collect_params().values()]
for p in params:
logger.info(p)
total = sum([np.prod(p.shape) for p in params])
logger.info('total params: %d' % total)
logger.info('total params: %d', total)
return total
@@ -70,15 +71,6 @@ def _make_residual(cell_net):
return out
def _make_transition(num_output_features, strides=2):
out = nn.HybridSequential(prefix='')
out.add(nn.BatchNorm())
out.add(nn.Activation('relu'))
out.add(nn.Conv2D(num_output_features, kernel_size=1, use_bias=False))
out.add(nn.MaxPool2D(pool_size=2, strides=strides))
return out
class DenseNet(HybridBlock):
r"""Densenet-BC model from the
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.
@@ -99,13 +91,16 @@ class DenseNet(HybridBlock):
classes : int, default 1000
Number of classification classes.
"""
def __init__(self, layer_channels, **kwargs):
assert len(layer_channels) == 4
super(DenseNet, self).__init__(**kwargs)
with self.name_scope():
# Stage 0
self.features = nn.HybridSequential(prefix='')
self.features.add(_make_first_stage_net((layer_channels[0], layer_channels[1])))
self.features.add(
_make_first_stage_net((layer_channels[0], layer_channels[1]))
)
self.features.add(_make_transition(layer_channels[1]))
# self.features.add(nn.Conv2D(num_init_features, kernel_size=3,
# strides=1, padding=1, use_bias=False))
@@ -115,14 +110,17 @@ class DenseNet(HybridBlock):
# Add dense blocks
# Stage 1
self.features.add(_make_inter_stage_net(1, num_layers=2, growth_rate=layer_channels[0]))
self.features.add(
_make_inter_stage_net(1, num_layers=2, growth_rate=layer_channels[0])
)
self.features.add(_make_transition(layer_channels[2]))
# Stage 2
self.features.add(_make_inter_stage_net(2, num_layers=2, growth_rate=layer_channels[1]))
self.features.add(_make_transition(layer_channels[3], strides=(2, 1)))
# self.features.add(nn.MaxPool2D(pool_size=2, strides=(2, 1)))
# self.features.add(_make_transition(512))
self.features.add(
_make_inter_stage_net(2, num_layers=2, growth_rate=layer_channels[1])
)
# self.features.add(_make_transition(layer_channels[3], strides=(2, 1)))
self.features.add(_make_last_transition(layer_channels[3]))
# Stage 3
self.features.add(_make_final_stage_net(3, out_channels=layer_channels[3]))
@@ -150,19 +148,61 @@ class DenseNet(HybridBlock):
def _make_first_stage_net(out_channels):
features = nn.HybridSequential(prefix='stage%d_' % 0)
with features.name_scope():
features.add(nn.Conv2D(out_channels[0], kernel_size=3,
strides=1, padding=1, use_bias=False))
features.add(
nn.Conv2D(
out_channels[0], kernel_size=3, strides=1, padding=1, use_bias=False
)
)
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
features.add(nn.Conv2D(out_channels[1], kernel_size=3,
strides=1, padding=1, use_bias=False))
features.add(
nn.Conv2D(
out_channels[1], kernel_size=3, strides=1, padding=1, use_bias=False
)
)
# features.add(nn.BatchNorm())
# features.add(nn.Activation('relu'))
return _make_residual(features)
def _make_inter_stage_net(stage_index, num_layers=2, growth_rate=128):
return _make_dense_block(num_layers, bn_size=2, growth_rate=growth_rate, dropout=0.0, stage_index=stage_index)
return _make_dense_block(
num_layers,
bn_size=2,
growth_rate=growth_rate,
dropout=0.0,
stage_index=stage_index,
)
def _make_transition(num_output_features, strides=2):
out = nn.HybridSequential(prefix='')
out.add(nn.BatchNorm())
out.add(nn.Activation('relu'))
out.add(nn.Conv2D(num_output_features, kernel_size=1, use_bias=False))
out.add(nn.MaxPool2D(pool_size=2, strides=strides))
return out
def _make_last_transition(num_output_features):
out = nn.HybridSequential(prefix='last_trans_')
with out.name_scope():
out.add(nn.BatchNorm())
out.add(nn.Activation('relu'))
out.add(nn.Conv2D(num_output_features, kernel_size=1, use_bias=False))
out.add(nn.Activation('relu'))
out.add(
nn.Conv2D(
num_output_features,
groups=num_output_features,
kernel_size=(2, 3),
strides=(2, 1),
padding=(0, 1),
use_bias=False,
) # input shape: (8, 70), output shape: (4, 70)
)
# out.add(nn.MaxPool2D(pool_size=2, strides=strides))
return out
def _make_final_stage_net(stage_index, out_channels):
......
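A quick shape check for the new last-transition block (a sketch only; it assumes MXNet Gluon is installed and that _make_last_transition can be imported from cnocr.symbols.densenet, the same way the tests below import other private helpers such as _make_dense_layer):

from mxnet import nd
from cnocr.symbols.densenet import _make_last_transition  # assumed importable, like _make_dense_layer in the tests

block = _make_last_transition(256)
block.initialize()
x = nd.random.uniform(shape=(1, 256, 8, 70))  # bz x C x H x W, matching the commented input shape (8, 70)
y = block(x)
print(y.shape)  # expected (1, 256, 4, 70): height halved by the (2, 3)/(2, 1) depthwise conv, width kept at 70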
@@ -11,6 +11,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr.consts import EMB_MODEL_TYPES, SEQ_MODEL_TYPES
from cnocr.utils import set_logger
from cnocr.hyperparams.cn_hyperparams import CnHyperparams
from cnocr.symbols.densenet import _make_dense_layer, DenseNet, cal_num_params
from cnocr.symbols.crnn import (
@@ -22,9 +23,7 @@ from cnocr.symbols.crnn import (
crnn_lstm_lite,
)
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
logger = logging.getLogger(__name__)
logger = set_logger('info')
HP = CnHyperparams()
@@ -52,7 +51,7 @@ def test_densenet():
def test_crnn():
_hp = deepcopy(HP)
_hp.set_seq_length(_hp.img_width // 4 - 1)
_hp.set_seq_length(_hp.img_width // 4)
x = nd.random.randn(128, 64, 32, 280)
layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
for layer_channels in layer_channels_list:
......
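With the default img_width of 280 (matching the 280-wide random input in test_crnn above), the expected sequence length therefore moves from 69 to 70; a tiny sanity check, assuming img_width is indeed 280:

img_width = 280  # assumed default, consistent with the nd.random.randn(128, 64, 32, 280) input above
assert img_width // 4 == 70      # new expected sequence length
assert img_width // 4 - 1 == 69  # old value, one step short because of the unpadded last pooling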