Commit d52e296d authored by breezedeus

feat: support mobilenet

Parent bec885b4
# Possible values: ['densenet-s']
ENCODER_NAME = densenet-lite-136
ENCODER_NAME = mobilenetv3_tiny
# Possible values: ['fc', 'gru', 'lstm']
DECODER_NAME = fclite
DECODER_NAME = fc
MODEL_NAME = $(ENCODER_NAME)-$(DECODER_NAME)
EPOCH = 41
......
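The resulting training target name is just the encoder and decoder names joined with a hyphen. A tiny Python illustration (not part of the commit) of what the Makefile variables above compose:

```python
# Illustrative only: mirrors MODEL_NAME = $(ENCODER_NAME)-$(DECODER_NAME) above.
encoder_name = 'mobilenetv3_tiny'
decoder_name = 'fc'
model_name = f'{encoder_name}-{decoder_name}'
print(model_name)  # mobilenetv3_tiny-fc
```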
......@@ -57,7 +57,7 @@ class CnOcr(object):
def __init__(
self,
model_name: str = 'densenet-s-fc',
model_name: str = 'densenet_lite_124-fc',
*,
cand_alphabet: Optional[Union[Collection, str]] = None,
context: str = 'cpu', # ['cpu', 'gpu', 'cuda']
......@@ -69,7 +69,7 @@ class CnOcr(object):
Initialization function for the recognition model.
Args:
model_name (str): model name. Defaults to `densenet-s-fc`
model_name (str): model name. Defaults to `densenet_lite_124-fc`
cand_alphabet (Optional[Union[Collection, str]]): candidate set of characters to be recognized. Defaults to `None`, meaning the recognizable characters are not restricted
context (str): 'cpu', or 'gpu'. Whether the CPU or the GPU is used for prediction. Defaults to `cpu`
model_fp (Optional[str]): if not using the built-in models, this parameter can directly specify the model file to use (a '.ckpt' file)
......@@ -83,10 +83,10 @@ class CnOcr(object):
>>> ocr = CnOcr()
Use a specified model:
>>> ocr = CnOcr(model_name='densenet-s-gru')
>>> ocr = CnOcr(model_name='densenet_lite_124-fc')
Consider only digits when recognizing:
>>> ocr = CnOcr(model_name='densenet-s-gru', cand_alphabet='0123456789')
>>> ocr = CnOcr(model_name='densenet_lite_124-fc', cand_alphabet='0123456789')
"""
if 'name' in kwargs:
......
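For context, a minimal usage sketch with the renamed default model (illustrative only; the image path is a placeholder, and `ocr()` is CnOcr's existing recognition entry point):

```python
from cnocr import CnOcr

# Instantiate with the new default model name used in this commit.
ocr = CnOcr(model_name='densenet_lite_124-fc')
# 'example.png' is a placeholder; pass any image containing text lines.
res = ocr.ocr('example.png')
print(res)
```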
......@@ -36,61 +36,70 @@ ENCODER_CONFIGS = {
'num_init_features': 64,
'out_length': 512, # output vector length: 4*128 = 512
},
'densenet-1112': { # length compressed to 1/8 (seq_len == 35)
'densenet_1112': { # length compressed to 1/8 (seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 1, 2],
'num_init_features': 64,
'out_length': 400,
},
'densenet-1114': { # length compressed to 1/8 (seq_len == 35)
'densenet_1114': { # length compressed to 1/8 (seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 1, 4],
'num_init_features': 64,
'out_length': 656,
},
'densenet-1122': { # length compressed to 1/8 (seq_len == 35)
'densenet_1122': { # length compressed to 1/8 (seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 2, 2],
'num_init_features': 64,
'out_length': 464,
},
'densenet-1124': { # length compressed to 1/8 (seq_len == 35)
'densenet_1124': { # length compressed to 1/8 (seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 2, 4],
'num_init_features': 64,
'out_length': 720,
},
'densenet-lite-113': { # length compressed to 1/8 (seq_len == 35); output vector length: 2*136 = 272
'densenet_lite_113': { # length compressed to 1/8 (seq_len == 35); output vector length: 2*136 = 272
'growth_rate': 32,
'block_config': [1, 1, 3],
'num_init_features': 64,
'out_length': 272, # output vector length: 2*136 = 272
},
'densenet-lite-114': { # length compressed to 1/8 (seq_len == 35)
'densenet_lite_114': { # length compressed to 1/8 (seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 4],
'num_init_features': 64,
'out_length': 336,
},
'densenet-lite-124': { # length compressed to 1/8 (seq_len == 35)
'densenet_lite_124': { # length compressed to 1/8 (seq_len == 35)
'growth_rate': 32,
'block_config': [1, 2, 4],
'num_init_features': 64,
'out_length': 368,
},
'densenet-lite-134': { # length compressed to 1/8 (seq_len == 35)
'densenet_lite_134': { # length compressed to 1/8 (seq_len == 35)
'growth_rate': 32,
'block_config': [1, 3, 4],
'num_init_features': 64,
'out_length': 400,
},
'densenet-lite-136': { # length compressed to 1/8 (seq_len == 35)
'densenet_lite_136': { # length compressed to 1/8 (seq_len == 35)
'growth_rate': 32,
'block_config': [1, 3, 6],
'num_init_features': 64,
'out_length': 528,
},
'mobilenetv3_tiny': {
'arch': 'tiny',
'out_length': 384,
},
'mobilenetv3_small': {
'arch': 'small',
'out_length': 384,
}
}
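A sketch of how one of the new mobilenet entries is consumed (illustrative only; it mirrors the EncoderManager change further down in this commit, and the import paths are assumptions based on the relative imports shown there):

```python
from cnocr.consts import ENCODER_CONFIGS              # assumed import path
from cnocr.models.mobilenet import gen_mobilenet_v3   # assumed path of the new module

# Copy the entry so the module-level config is not mutated.
config = dict(ENCODER_CONFIGS['mobilenetv3_tiny'])
config['name'] = 'mobilenetv3_tiny'

name = config.pop('name')
out_length = config.pop('out_length')  # 384: feature dimension handed to the decoder
if name.startswith('mobilenet'):
    encoder = gen_mobilenet_v3(config['arch'])  # arch == 'tiny'
```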
DECODER_CONFIGS = {
......@@ -102,12 +111,12 @@ DECODER_CONFIGS = {
# 'input_size': 512, # matches the encoder's output vector length
'rnn_units': 128,
},
'fc': {
'fcfull': {
# 'input_size': 512, # matches the encoder's output vector length
'hidden_size': 256,
'dropout': 0.3,
},
'fclite': {
'fc': {
# 'input_size': 512, # matches the encoder's output vector length
'hidden_size': 128,
'dropout': 0.1,
......@@ -120,13 +129,14 @@ root_url = (
)
# name: (epochs, url)
AVAILABLE_MODELS = {
'densenet-s-fc': (8, root_url + 'densenet-s-fc-v2.0.1.zip'),
'densenet-s-gru': (14, root_url + 'densenet-s-gru-v2.0.1.zip'),
# 'densenet-s-fc': (8, root_url + 'densenet-s-fc-v2.0.1.zip'),
# 'densenet-s-gru': (14, root_url + 'densenet-s-gru-v2.0.1.zip'),
# 'densenet-lite-113-fclite': (33, root_url + '.zip'),
'densenet-lite-114-fclite': (31, root_url + '.zip'),
'densenet-lite-124-fclite': (36, root_url + '.zip'),
'densenet-lite-134-fclite': (38, root_url + '.zip'),
'densenet-lite-136-fclite': (38, root_url + '.zip'),
'densenet_lite_114-fc': (31, root_url + '.zip'),
'densenet_lite_124-fc': (36, root_url + '.zip'),
'densenet_lite_134-fc': (38, root_url + '.zip'),
'densenet_lite_136-fc': (17, root_url + '.zip'),
'densenet_lite_136-fc-scene': (17, root_url + '.zip'),
}
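Each entry above maps a model name to an (epochs, download url) pair; the zip URLs are left incomplete in this diff. A lookup sketch (illustrative only; import path assumed):

```python
from cnocr.consts import AVAILABLE_MODELS  # assumed import path

epochs, url = AVAILABLE_MODELS['densenet_lite_136-fc']
print(epochs)  # 17
```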
# Candidate character set
......
# coding: utf-8
# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# adapted from: torchvision/models/mobilenetv3.py
from functools import partial
from typing import Any, List, Optional, Callable
from torch import nn, Tensor
from torchvision.models.mobilenetv2 import ConvBNActivation
from torchvision.models import mobilenetv3
from torchvision.models.mobilenetv3 import InvertedResidualConfig
class MobileNetV3(mobilenetv3.MobileNetV3):
def __init__(
self,
inverted_residual_setting: List[InvertedResidualConfig],
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any
) -> None:
super().__init__(inverted_residual_setting, 1, 2, block, norm_layer)
delattr(self, 'classifier')
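# replace the stem conv so the encoder takes single-channel (grayscale) images instead of 3-channel RGB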
firstconv_input_channels = self.features[0][0].out_channels
self.features[0] = ConvBNActivation(
1,
firstconv_input_channels,
kernel_size=3,
stride=2,
norm_layer=norm_layer,
activation_layer=nn.Hardswish,
)
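# rebuild the last conv with a 2x channel expansion (torchvision uses 6x) to keep the feature dimension compact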
lastconv_input_channels = self.features[-1][0].in_channels
lastconv_output_channels = 2 * lastconv_input_channels
self.features[-1] = ConvBNActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.Hardswish,
)
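# pool over height only; the width is kept as the sequence (time) dimension for CTC decoding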
self.avgpool = nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1))
self._post_init_weights()
@property
def compress_ratio(self):
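# the conv stack shrinks the input width (i.e. the output sequence length) by a factor of 8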
return 8
def _post_init_weights(self):
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def forward(self, x: Tensor) -> Tensor:
features = self.features(x)
features = self.avgpool(features)
return features
def _mobilenet_v3_conf(
arch: str,
width_mult: float = 1.0,
reduced_tail: bool = False,
dilated: bool = False,
**kwargs: Any
):
reduce_divider = 2 if reduced_tail else 1
dilation = 2 if dilated else 1
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(
InvertedResidualConfig.adjust_channels, width_mult=width_mult
)
if arch == "mobilenet_v3_tiny":
inverted_residual_setting = [
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, False, "HS", 2, 1), # C3
bneck_conf(40, 5, 120, 48, False, "HS", 1, 1),
# bneck_conf(48, 5, 144, 48, False, "HS", 1, 1),
bneck_conf(
48, 5, 288, 96 // reduce_divider, False, "HS", 2, dilation
), # C4
bneck_conf(
96 // reduce_divider,
5,
128 // reduce_divider,
96 // reduce_divider,
True,
"HS",
1,
dilation,
),
bneck_conf(
96 // reduce_divider,
5,
128 // reduce_divider,
96 // reduce_divider,
True,
"HS",
1,
dilation,
),
]
elif arch == "mobilenet_v3_small":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 1, 1), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, False, "HS", 2, 1), # C3
bneck_conf(40, 5, 240, 40, False, "HS", 1, 1),
bneck_conf(40, 5, 240, 40, False, "HS", 1, 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(
96 // reduce_divider,
5,
576 // reduce_divider,
96 // reduce_divider,
True,
"HS",
1,
dilation,
),
bneck_conf(
96 // reduce_divider,
5,
576 // reduce_divider,
96 // reduce_divider,
True,
"HS",
1,
dilation,
),
]
else:
raise ValueError("Unsupported model type {}".format(arch))
return inverted_residual_setting
def _mobilenet_v3_model(
inverted_residual_setting: List[InvertedResidualConfig], **kwargs: Any
):
model = MobileNetV3(inverted_residual_setting, **kwargs)
return model
def gen_mobilenet_v3(arch: str = 'tiny', **kwargs: Any) -> MobileNetV3:
"""
Constructs a MobileNetV3 architecture ('tiny' or 'small' variant) from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
arch (str): arch name; values: 'tiny' or 'small'
"""
arch = 'mobilenet_v3_%s' % arch
inverted_residual_setting = _mobilenet_v3_conf(arch, **kwargs)
return _mobilenet_v3_model(inverted_residual_setting, **kwargs)
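A quick smoke-test sketch for the new encoder (illustrative only; the import path is an assumption based on the relative imports in the next file, and it needs a torchvision version that still ships `ConvBNActivation`):

```python
import torch

from cnocr.models.mobilenet import gen_mobilenet_v3  # assumed module path

# Build the 'tiny' variant and feed it a dummy single-channel text-line image.
encoder = gen_mobilenet_v3('tiny')
dummy = torch.rand(1, 1, 32, 280)  # (batch, channel, height, width)
features = encoder(dummy)
# With three stride-2 stages plus the final (2, 1) average pooling, this should
# come out around (1, 192, 2, 35), i.e. out_length = 2 * 192 = 384 per time step.
print(features.shape)
```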
......@@ -31,6 +31,7 @@ from .ctc import CTCPostProcessor
from ..consts import ENCODER_CONFIGS, DECODER_CONFIGS
from ..data_utils.utils import encode_sequences
from .densenet import DenseNet, DenseNetLite
from .mobilenet import gen_mobilenet_v3
class EncoderManager(object):
......@@ -45,12 +46,16 @@ class EncoderManager(object):
assert config is not None and 'name' in config
name = config.pop('name')
if name.lower().startswith('densenet-lite'):
if name.lower().startswith('densenet_lite'):
out_length = config.pop('out_length')
encoder = DenseNetLite(**config)
elif name.lower().startswith('densenet'):
out_length = config.pop('out_length')
encoder = DenseNet(**config)
elif name.lower().startswith('mobilenet'):
arch = config['arch']
out_length = config.pop('out_length')
encoder = gen_mobilenet_v3(arch)
else:
raise ValueError('unsupported encoder name: %s' % name)
return encoder, out_length
......@@ -89,7 +94,7 @@ class DecoderManager(object):
bidirectional=True,
)
out_length = config['rnn_units'] * 2
elif name.lower() in ('fc', 'fclite'):
elif name.lower() in ('fc', 'fcfull'):
decoder = nn.Sequential(
nn.Dropout(p=config['dropout']),
# nn.Tanh(),
......
......@@ -102,7 +102,7 @@ def data_dir():
def check_model_name(model_name):
encoder_type, decoder_type = model_name.rsplit('-', maxsplit=1)
encoder_type, decoder_type = model_name.split('-')[:2]
assert encoder_type in ENCODER_CONFIGS
assert decoder_type in DECODER_CONFIGS
......
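Why the split changed: with the new scene-suffixed names, `rsplit('-', 1)` would peel `scene` off as the decoder, so the first two `'-'`-separated parts are taken instead:

```python
# Illustrative only: old vs. new parsing of a scene-suffixed model name.
model_name = 'densenet_lite_136-fc-scene'
print(model_name.rsplit('-', maxsplit=1))  # ['densenet_lite_136-fc', 'scene']
print(model_name.split('-')[:2])           # ['densenet_lite_136', 'fc']
```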