提交 d09ded2c 编写于 作者: B breezedeus

adapt to new design

上级 3a3bcbaf
DATA_ROOT_DIR = /data2/ocr/outer
REC_DATA_ROOT_DIR = /dev/data/jinlong/data
# ['conv', 'conv-lite-rnn', 'densenet', 'densenet-lite']
# ['conv', 'conv-lite', 'densenet', 'densenet-lite']
EMB_MODEL_TYPE = densenet-lite
SEQ_MODEL_TYPE = lstm
MODEL_NAME = $(EMB_MODEL_TYPE)-$(SEQ_MODEL_TYPE)
......@@ -19,7 +19,7 @@ gen-rec:
train:
nohup python scripts/cnocr_train.py --gpu 2 --emb_model_type $(EMB_MODEL_TYPE) --seq_model_type $(SEQ_MODEL_TYPE) \
--optimizer Adam --epoch $(EPOCH) --lr 3e-4 --model_name $(MODEL_NAME) \
--optimizer Adam --epoch $(EPOCH) --lr 3e-4 \
--train_file $(REC_DATA_ROOT_DIR)/lst/cnocr_train --test_file $(REC_DATA_ROOT_DIR)/lst/cnocr_test \
>> nohup-$(MODEL_NAME).out 2>&1 &
......
# Update 2019.07.25: 发布 cnocr V1.0.0
# Release Notes
### Update 2020.04.20: 发布 cnocr V1.1.0
### Update 2019.07.25: 发布 cnocr V1.0.0
`cnocr`发布了预测效率更高的新版本v1.0.0。**新版本的模型跟以前版本的模型不兼容**。所以如果大家是升级的话,需要重新下载最新的模型文件。具体说明见下面(流程和原来相同)。
......@@ -15,13 +21,39 @@
**cnocr**是用来做中文OCR的**Python 3**包。cnocr自带了训练好的识别模型,所以安装后即可直接使用。
目前使用的识别模型是**crnn**,识别准确度约为 `98.8%`
本项目起源于我们自己 ([爱因互动 Ein+](https://einplus.cn)) 内部的项目需求,所以非常感谢公司的支持。
## 可直接使用的模型
cnocr的ocr模型可以分为两阶段:第一阶段是获得ocr图片的局部编码向量,第二部分是对局部编码向量进行序列学习,获得序列编码向量。目前两个阶段分别包含以下的模型:
1. 局部编码模型(emb model)
* `conv`:多层的卷积网络;
* `conv-lite`:更小的多层卷积网络;
* `densenet`:一个小型的`densenet`网络;
* `densenet-lite`:一个更小的`densenet`网络。
2. 序列编码模型(seq model)
* `lstm`:两层的LSTM网络;
* `gru`:两层的GRU网络;
* `fc`:两层的全连接网络。
cnocr目前包含以下可直接使用的模型:
| 模型名称 | 局部编码模型 | 序列编码模型 | 模型大小 | 迭代次数 | 测试集准确率 |
| :------- | ------------ | ------------ | -------- | ------ | -------- |
| conv-lstm | conv | lstm | 36M | 50 | 98.5% |
| conv-lite-lstm | conv-lite | lstm | 23M | 45 | 98.6% |
| conv-lite-fc | conv-lite | fc | 20M | 27 | 98.6% |
| densenet-lite-lstm | densenet-lite | lstm | 8.6M | 42 | 98.6% |
| densenet-lite-fc | densenet-lite | fc | 6.8M | 32 | 97% |
> 模型名称是由局部编码模型和序列编码模型名称拼接而成。
## 特色
本项目的大部分代码都fork自 [crnn-mxnet-chinese-text-recognition](https://github.com/diaomin/crnn-mxnet-chinese-text-recognition),感谢作者。
本项目的初期代码都fork自 [crnn-mxnet-chinese-text-recognition](https://github.com/diaomin/crnn-mxnet-chinese-text-recognition),感谢作者。
但源项目使用起来不够方便,所以我在此基础上做了一些封装和重构。主要变化如下:
......
......@@ -21,13 +21,13 @@ import numpy as np
from PIL import Image
from cnocr.__version__ import __version__
from cnocr.consts import MODEL_EPOCE, EMB_MODEL_TYPES, SEQ_MODEL_TYPES
from cnocr.consts import AVAILABLE_MODELS
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
from cnocr.fit.lstm import init_states
from cnocr.fit.ctc_metrics import CtcMetrics
from cnocr.data_utils.data_iter import SimpleBatch
from cnocr.symbols.crnn import gen_network
from cnocr.utils import data_dir, get_model_file, read_charset, normalize_img_array
from cnocr.utils import data_dir, get_model_file, read_charset, normalize_img_array, check_model_name
from cnocr.line_split import line_split
......@@ -95,14 +95,16 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None):
class CnOcr(object):
MODEL_FILE_PREFIX = 'cnocr-v{}'.format(__version__)
def __init__(self, model_name='conv-lite-lstm', root=data_dir(), model_epoch=MODEL_EPOCE,
cand_alphabet=None):
self._check_model_name(model_name)
def __init__(self, model_name='conv-lite-lstm', model_epoch=None,
cand_alphabet=None, root=data_dir()):
check_model_name(model_name)
self._model_name = model_name
self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
self._model_dir = os.path.join(root, 'models')
self._model_epoch = model_epoch
self._assert_and_prepare_model_files(root)
self._model_epoch = model_epoch or AVAILABLE_MODELS[model_name][0]
root = os.path.join(root, __version__)
self._model_dir = os.path.join(root, '%s-%04d' % (self._model_name, self._model_epoch))
self._assert_and_prepare_model_files()
self._alphabet, inv_alph_dict = read_charset(os.path.join(self._model_dir, 'label_cn.txt'))
self._cand_alph_idx = None
......@@ -115,12 +117,7 @@ class CnOcr(object):
self._mod = self._get_module()
def _check_model_name(self, model_name):
emb_model_type, seq_model_type = model_name.rsplit('-', maxsplit=1)
assert emb_model_type in EMB_MODEL_TYPES
assert seq_model_type in SEQ_MODEL_TYPES
def _assert_and_prepare_model_files(self, root):
def _assert_and_prepare_model_files(self):
model_dir = self._model_dir
model_files = ['label_cn.txt',
'%s-%04d.params' % (self._model_file_prefix, self._model_epoch),
......@@ -137,15 +134,15 @@ class CnOcr(object):
if os.path.exists(model_dir):
os.removedirs(model_dir)
get_model_file(root)
get_model_file(model_dir)
def _get_module(self):
network, self._hp = gen_network(self._model_name, self._hp)
hp = self._hp
prefix = os.path.join(self._model_dir, self._model_file_prefix)
# import pdb; pdb.set_trace()
data_names = ['data']
data_shapes = [(data_names[0], (hp.batch_size, 1, hp.img_height, hp.img_width))]
print('loading model parameters from dir %s' % self._model_dir)
mod = load_module(prefix, self._model_epoch, data_names, data_shapes, network=network)
return mod
......
# coding: utf-8
import string
from .__version__ import __version__
MODEL_BASE_URL = 'https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=1'
MODEL_EPOCE = 20
EMB_MODEL_TYPES = ['conv', 'conv-lite', 'densenet', 'densenet-lite']
SEQ_MODEL_TYPES = ['lstm', 'gru', 'fc']
ZIP_FILE_NAME = 'cnocr-models-v{}.zip'.format(__version__)
# name: (epochs, url)
AVAILABLE_MODELS = {
'conv-lstm': (50, ),
'conv-lite-lstm': (45, ),
'conv-lite-fc': (27, ),
'densenet-lite-lstm': (45, ),
'densenet-lite-fc': (32, ),
}
# 候选字符集合
NUMBERS = string.digits + string.punctuation
......
......@@ -47,7 +47,7 @@ def gen_network(model_name, hp):
hp.seq_len_cmpr_ratio = 4
hp.set_seq_length(hp.img_width // 4 - 1)
model = lambda data: crnn_lstm_lite(hp, data)
elif model_name == 'conv-lstm':
elif model_name.startswith('conv'):
hp.seq_len_cmpr_ratio = 8
hp.set_seq_length(hp.img_width // 8)
model = lambda data: crnn_lstm(hp, data)
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
import sys
import os
import time
import argparse
from operator import itemgetter
from pathlib import Path
......@@ -31,19 +32,14 @@ import Levenshtein
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr import CnOcr
from cnocr.consts import MODEL_NAMES
def evaluate():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-name",
help="model name",
choices=MODEL_NAMES,
type=str,
default='conv-rnn',
"--model-name", help="model name", type=str, default='densenet-lite-lstm'
)
parser.add_argument("--model-epoch", type=int, help="model epoch")
parser.add_argument("--model-epoch", type=int, default=None, help="model epoch")
parser.add_argument(
"-i",
"--input-fp",
......@@ -75,6 +71,7 @@ def evaluate():
fn_labels_list = read_input_file(args.input_fp)
miss_cnt, redundant_cnt = Counter(), Counter()
model_time_cost = 0.0
start_idx = 0
bad_cnt = 0
badcases = []
......@@ -91,7 +88,9 @@ def evaluate():
img = mx.image.imread(img_fp, 1).asnumpy()
batch_imgs.append(img)
start_time = time.time()
batch_preds = ocr.ocr_for_single_lines(batch_imgs)
model_time_cost += time.time() - start_time
for bad_info in compare_preds_to_reals(
batch_preds, batch_labels, batch_img_fns, alphabet
):
......@@ -110,7 +109,7 @@ def evaluate():
output_dir = Path(args.output_dir)
if not output_dir.exists():
output_dir.mkdir()
os.makedirs(output_dir)
with open(output_dir / 'badcases.txt', 'w') as f:
f.write(
'\t'.join(
......@@ -135,8 +134,8 @@ def evaluate():
f.write('\t'.join([word, str(num)]) + '\n')
print(
"number of total cases: %d, number of bad cases: %d"
% (len(fn_labels_list), bad_cnt)
"number of total cases: %d, time cost per image: %f, number of bad cases: %d"
% (len(fn_labels_list), model_time_cost / len(fn_labels_list), bad_cnt)
)
......
......@@ -26,24 +26,24 @@ import argparse
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr import CnOcr
from cnocr.consts import MODEL_NAMES
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
help="model name",
choices=MODEL_NAMES,
type=str,
default='conv-rnn',
"--model_name", help="model name", type=str, default='densenet-lite-lstm'
)
parser.add_argument("--model_epoch", type=int, default=None, help="model epoch")
parser.add_argument("-f", "--file", help="Path to the image file")
parser.add_argument("-s", "--single-line", default=False,
help="Whether the image only includes one-line characters")
parser.add_argument(
"-s",
"--single-line",
default=False,
help="Whether the image only includes one-line characters",
)
args = parser.parse_args()
ocr = CnOcr(model_name=MODEL_NAMES)
ocr = CnOcr(model_name=args.model_name, model_epoch=args.model_epoch)
if args.single_line:
res = ocr.ocr_for_single_line(args.file)
else:
......
......@@ -10,7 +10,7 @@ from mxnet import nd
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr.consts import MODEL_NAMES
from cnocr.consts import EMB_MODEL_TYPES, SEQ_MODEL_TYPES
from cnocr.hyperparams.cn_hyperparams import CnHyperparams
from cnocr.symbols.densenet import _make_dense_layer, DenseNet, cal_num_params
from cnocr.symbols.crnn import (
......@@ -102,6 +102,12 @@ def test_pipline():
assert pred_shape == (hp.batch_size * hp.seq_length, hp.num_classes)
MODEL_NAMES = []
for emb_model in EMB_MODEL_TYPES:
for seq_model in SEQ_MODEL_TYPES:
MODEL_NAMES.append('%s-%s' % (emb_model, seq_model))
@pytest.mark.parametrize(
'model_name', MODEL_NAMES
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册