未验证 提交 742efd3b 编写于 作者: B BreezeDeus 提交者: GitHub

Merge pull request #103 from breezedeus/dev-v1.2

fix: allow to initialize multiple instances
......@@ -10,7 +10,7 @@ English [README](./README_en.md).
# 最近更新 【2020.05.25】:V1.2.1
# 最近更新 【2020.05.29】:V1.2.2
主要变更:
......@@ -21,6 +21,7 @@ English [README](./README_en.md).
* 默认模型由之前的`conv-lite-fc`改为`densenet-lite-fc`
* 预测支持使用GPU。
* bugfixs:
* 修复同时初始化多个实例时会报错的问题;
* Web 调用时的内存泄露。感谢 [@myuanz](https://github.com/myuanz)
* 输入图片宽度很小时导致异常;
* 去掉 `f-print`
......@@ -150,6 +151,7 @@ class CnOcr(object):
cand_alphabet=None,
root=data_dir(),
context='cpu',
name=None,
):
```
......@@ -159,10 +161,12 @@ class CnOcr(object):
* `model_epoch`: 模型迭代次数。默认为 `None`,表示使用默认的迭代次数值。对于模型名称 `densenet-lite-fc`就是 `40`
* `cand_alphabet`: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围。`cnocr.consts`中内置了两个候选集合:(1) 数字和标点 `NUMBERS`;(2) 英文字母、数字和标点 `ENG_LETTERS`
* 例如对于图片 ![examples/hybrid.png](./examples/hybrid.png) ,不做约束时识别结果为 `o12345678`;如果加入数字约束时(`ocr = CnOcr(cand_alphabet=NUMBERS)`),识别结果为 `012345678`
* `cand_alphabet`也可以初始化后通过类函数 `CnOcr.set_cand_alphabet(cand_alphabet)` 进行设置。这样同一个实例也可以指定不同的`cand_alphabet`进行识别。
* `root`: 模型文件所在的根目录。
* Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/1.2.0/densenet-lite-fc`
* Windows下默认值为 `C:\Users\<username>\AppData\Roaming\cnocr`
* `context`:预测使用的机器资源,可取值为字符串`cpu``gpu`,或者 `mx.Context`实例。
* `name`:正在初始化的这个实例的名称。如果需要同时初始化多个实例,需要为不同的实例指定不同的名称。
......
# Release Notes
### Update 2020.05.29: 发布 cnocr V1.2.2
主要变更:
* `CnOcr`加入类函数 `CnOcr.set_cand_alphabet(cand_alphabet) `。可通过此类函数设置`cand_alphabet`。这样同一个实例也可以指定不同的`cand_alphabet`进行识别。
* bugfix:
* 修复同时初始化多个实例时会报错的问题。
### Update 2020.05.25: 发布 cnocr V1.2.1
bugfix:
主要变更:
* 修复了zip文件名的typo。
* bugfix:
* 修复了zip文件名的typo。
......
__version__ = '1.2.1'
__version__ = '1.2.2'
......@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import os
import re
import logging
import mxnet as mx
import numpy as np
......@@ -25,7 +26,6 @@ from cnocr.consts import MODEL_VERSION, AVAILABLE_MODELS
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
from cnocr.fit.lstm import init_states
from cnocr.fit.ctc_metrics import CtcMetrics
from cnocr.data_utils.data_iter import SimpleBatch
from cnocr.symbols.crnn import gen_network
from cnocr.utils import (
data_dir,
......@@ -83,7 +83,16 @@ def lstm_init_states(batch_size, hp):
return init_names, init_arrays
def load_module(prefix, epoch, data_names, data_shapes, network=None, context='cpu'):
def load_module(
prefix,
epoch,
data_names,
data_shapes,
*,
network=None,
net_prefix=None,
context='cpu'
):
"""
Loads the model from checkpoint specified by prefix and epoch, binds it
to an executor, and sets its parameters and returns a mx.mod.Module
......@@ -92,9 +101,13 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None, context='c
if network is not None:
sym = network
net_prefix = net_prefix or ''
if net_prefix:
arg_params = {rename_params(k, net_prefix): v for k, v in arg_params.items()}
aux_params = {rename_params(k, net_prefix): v for k, v in aux_params.items()}
# We don't need CTC loss for prediction, just a simple softmax will suffice.
# We get the output of the layer just before the loss layer ('pred_fc') and add softmax on top
pred_fc = sym.get_internals()['pred_fc_output']
pred_fc = sym.get_internals()[net_prefix + 'pred_fc_output']
sym = mx.sym.softmax(data=pred_fc)
if not check_context(context):
......@@ -110,6 +123,12 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None, context='c
return mod
def rename_params(k, net_prefix):
pat = re.compile(r'^(densenet|crnn|gru|lstm)\d*_')
k = pat.sub('', k, 1)
return net_prefix + k
class CnOcr(object):
MODEL_FILE_PREFIX = 'cnocr-v{}'.format(MODEL_VERSION)
......@@ -120,6 +139,7 @@ class CnOcr(object):
cand_alphabet=None,
root=data_dir(),
context='cpu',
name=None,
):
"""
......@@ -130,6 +150,7 @@ class CnOcr(object):
Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/1.1.0/conv-lite-fc-0027`。
Windows下默认值为 ``。
:param context: 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为CPU。
:param name: 正在初始化的这个实例名称。如果需要同时初始化多个实例,需要为不同的实例指定不同的名称。
"""
check_model_name(model_name)
self._model_name = model_name
......@@ -139,17 +160,17 @@ class CnOcr(object):
root = os.path.join(root, MODEL_VERSION)
self._model_dir = os.path.join(root, self._model_name)
self._assert_and_prepare_model_files()
self._alphabet, inv_alph_dict = read_charset(
self._alphabet, self._inv_alph_dict = read_charset(
os.path.join(self._model_dir, 'label_cn.txt')
)
self._cand_alph_idx = None
if cand_alphabet is not None:
self._cand_alph_idx = [0] + [inv_alph_dict[word] for word in cand_alphabet]
self._cand_alph_idx.sort()
self.set_cand_alphabet(cand_alphabet)
self._hp = Hyperparams()
self._hp._loss_type = None # infer mode
# 传入''的话,也改成传入None
self._net_prefix = None if name == '' else name
self._mod = self._get_module(context)
......@@ -174,7 +195,7 @@ class CnOcr(object):
get_model_file(model_dir)
def _get_module(self, context):
network, self._hp = gen_network(self._model_name, self._hp)
network, self._hp = gen_network(self._model_name, self._hp, self._net_prefix)
hp = self._hp
prefix = os.path.join(self._model_dir, self._model_file_prefix)
data_names = ['data']
......@@ -186,10 +207,23 @@ class CnOcr(object):
data_names,
data_shapes,
network=network,
net_prefix=self._net_prefix,
context=context,
)
return mod
def set_cand_alphabet(self, cand_alphabet):
"""
设置待识别字符的候选集合。
:param cand_alphabet: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
:return: None
"""
if cand_alphabet is None:
self._cand_alph_idx = None
else:
self._cand_alph_idx = [0] + [self._inv_alph_dict[word] for word in cand_alphabet]
self._cand_alph_idx.sort()
def ocr(self, img_fp):
"""
:param img_fp: image file path; or color image mx.nd.NDArray or np.ndarray,
......
......@@ -29,7 +29,7 @@ from .densenet import DenseNet
from ..fit.ctc_loss import add_ctc_loss
def gen_network(model_name, hp):
def gen_network(model_name, hp, net_prefix=None):
hp = deepcopy(hp)
hp.seq_model_type = model_name.rsplit('-', maxsplit=1)[-1]
model_name = model_name.lower()
......@@ -45,23 +45,30 @@ def gen_network(model_name, hp):
)
seq_len = hp.img_width // 8 if shorter else hp.img_width // 4
hp.set_seq_length(seq_len)
densenet = DenseNet(layer_channels, shorter=shorter)
densenet = DenseNet(layer_channels, shorter=shorter, prefix=net_prefix)
densenet.hybridize()
model = CRnn(hp, densenet)
model = CRnn(hp, densenet, prefix=net_prefix)
elif model_name.startswith('conv-lite'):
hp.seq_len_cmpr_ratio = 4
shorter = model_name.startswith('conv-lite-s-')
seq_len = hp.img_width // 8 if shorter else hp.img_width // 4 - 1
hp.set_seq_length(seq_len)
model = lambda data: crnn_lstm_lite(hp, data, shorter=shorter)
def model(data):
with mx.name.Prefix(net_prefix or ''):
return crnn_lstm_lite(hp, data, shorter=shorter)
elif model_name.startswith('conv'):
hp.seq_len_cmpr_ratio = 8
hp.set_seq_length(hp.img_width // 8)
model = lambda data: crnn_lstm(hp, data)
def model(data):
with mx.name.Prefix(net_prefix or ''):
return crnn_lstm(hp, data)
else:
raise NotImplementedError('bad model_name: %s', model_name)
return pipline(model, hp), hp
return pipline(model, hp, net_prefix=net_prefix), hp
def get_infer_shape(sym_model, hp):
......@@ -75,18 +82,25 @@ def get_infer_shape(sym_model, hp):
return shape_dict
def gen_seq_model(hp):
def gen_seq_model(hp, **kw):
if hp.seq_model_type.lower() == 'lstm':
seq_model = LSTM(hp.num_hidden, hp.num_lstm_layer, bidirectional=True)
seq_model = LSTM(hp.num_hidden, hp.num_lstm_layer, bidirectional=True, **kw)
elif hp.seq_model_type.lower() == 'gru':
seq_model = GRU(hp.num_hidden, hp.num_lstm_layer, bidirectional=True)
seq_model = GRU(hp.num_hidden, hp.num_lstm_layer, bidirectional=True, **kw)
else:
def fc_seq_model(data):
fc = mx.sym.FullyConnected(
data, num_hidden=hp.num_hidden, flatten=False, name='seq-fc'
)
net = mx.sym.Activation(data=fc, act_type='relu', name='seq-relu')
if kw.get('prefix', None):
with mx.name.Prefix(kw['prefix']):
fc = mx.sym.FullyConnected(
data, num_hidden=hp.num_hidden, flatten=False, name='seq-fc'
)
net = mx.sym.Activation(data=fc, act_type='relu', name='seq-relu')
else:
fc = mx.sym.FullyConnected(
data, num_hidden=hp.num_hidden, flatten=False, name='seq-fc'
)
net = mx.sym.Activation(data=fc, act_type='relu', name='seq-relu')
return net
seq_model = fc_seq_model
......@@ -100,7 +114,7 @@ class CRnn(nn.HybridBlock):
self.emb_model = emb_model
self.dropout = nn.Dropout(hp.dropout)
self.seq_model = gen_seq_model(hp)
self.seq_model = gen_seq_model(hp, **kw)
def hybrid_forward(self, F, X):
embs = self.emb_model(X) # res: bz x emb_size x 1 x seq_len
......@@ -114,15 +128,22 @@ class CRnn(nn.HybridBlock):
return self.seq_model(embs)
def pipline(model, hp, data=None):
def pipline(model, hp, data=None, *, net_prefix=''):
# 构建用于训练的整个计算图
data = data if data is not None else mx.sym.Variable('data')
output = model(data)
output = mx.symbol.reshape(output, shape=(-3, -2)) # res: (seq_len * bz, c)
pred = mx.sym.FullyConnected(
data=output, num_hidden=hp.num_classes, name='pred_fc'
) # (bz x 35) x num_classes
if net_prefix:
with mx.name.Prefix(net_prefix):
output = mx.symbol.reshape(output, shape=(-3, -2)) # res: (seq_len * bz, c)
pred = mx.sym.FullyConnected(
data=output, num_hidden=hp.num_classes, name='pred_fc'
) # (bz x 35) x num_classes
else:
output = mx.symbol.reshape(output, shape=(-3, -2)) # res: (seq_len * bz, c)
pred = mx.sym.FullyConnected(
data=output, num_hidden=hp.num_classes, name='pred_fc'
) # (bz x 35) x num_classes
# print('pred', pred.infer_shape()[1])
# import pdb; pdb.set_trace()
......
......@@ -12,6 +12,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr import CnOcr
from cnocr.consts import NUMBERS, AVAILABLE_MODELS
from cnocr.line_split import line_split
from cnocr.data_utils.aug import GrayAug
......@@ -176,19 +177,51 @@ def test_gray_aug(img_fp, expected):
print(res_img.shape, res_img.dtype)
def test_cand_alphabet():
from cnocr import NUMBERS
def test_cand_alphabet1():
img_fp = os.path.join(example_dir, 'hybrid.png')
ocr = CnOcr(name='instance1')
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == 'o12345678'
ocr = CnOcr(name='instance2', cand_alphabet=NUMBERS)
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == '012345678'
def test_cand_alphabet2():
img_fp = os.path.join(example_dir, 'hybrid.png')
ocr = CnOcr()
ocr = CnOcr(name='instance1')
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == 'o12345678'
ocr = CnOcr(cand_alphabet=NUMBERS)
ocr.set_cand_alphabet(NUMBERS)
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == '012345678'
INSTANCE_ID = 0
@pytest.mark.parametrize('model_name', AVAILABLE_MODELS.keys())
def test_multiple_instances(model_name):
global INSTANCE_ID
print('test multiple instances for model_name: %s' % model_name)
img_fp = os.path.join(example_dir, 'hybrid.png')
INSTANCE_ID += 1
print('instance id: %d' % INSTANCE_ID)
cnocr1 = CnOcr(model_name, name='instance-%d' % INSTANCE_ID)
print_preds(cnocr1.ocr(img_fp))
INSTANCE_ID += 1
print('instance id: %d' % INSTANCE_ID)
cnocr2 = CnOcr(model_name, name='instance-%d' % INSTANCE_ID, cand_alphabet=NUMBERS)
print_preds(cnocr2.ocr(img_fp))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册