提交 e515dd4b 编写于 作者: B breezedeus

update to v2.1

上级 55e56b33
# 可取值:['densenet-s']
ENCODER_NAME = mobilenetv3_tiny
# 可取值:['densenet_lite_136']
ENCODER_NAME = densenet_lite_136
# 可取值:['fc', 'gru', 'lstm']
DECODER_NAME = fc
MODEL_NAME = $(ENCODER_NAME)-$(DECODER_NAME)
......
English [README](./README_en.md) (`out-dated`).
# cnocr
**cnocr****Python 3** 下的**文字识别****Optical Character Recognition**,简称**OCR**)工具包,支持**中文****英文**的常见字符识别,自带了多个训练好的识别模型,安装后即可直接使用。欢迎扫码加入QQ交流群:
......@@ -7,23 +8,9 @@ English [README](./README_en.md) (`out-dated`).
![QQ群二维码](./docs/cnocr-qq.jpg)
## 详细文档
# 最近更新 【2021.08.26】:V2.0.0
主要变更:
* MXNet 越来越小众化,故从基于 MXNet 的实现转为基于 **PyTorch** 的实现;
* 重新实现了识别模型,优化了训练数据,重新训练模型;
* 优化了能识别的字符集合;
* 优化了对英文的识别效果;
* 优化了对场景文字的识别效果;
* 使用接口略有调整。
更多更新说明见 [RELEASE Notes](docs/RELEASE.md)
[CnOcr在线文档](https://cnocr.readthedocs.io/)
## 使用场景说明
......@@ -73,375 +60,6 @@ pip install cnocr -i https://pypi.doubanio.com/simple
> 注意:请使用 **Python3**(3.6以及之后版本应该都行),没测过Python2下是否ok。
## 可直接使用的模型
cnocr的ocr模型可以分为两阶段:第一阶段是获得ocr图片的局部编码向量,第二部分是对局部编码向量进行序列学习,获得序列编码向量。目前的PyTorch版本的两个阶段分别包含以下模型:
1. 局部编码模型(emb model)
* **`densenet-s`**:一个小型的`densenet`网络;
2. 序列编码模型(seq model)
* **`lstm`**:一层的LSTM网络;
* **`gru`**:一层的GRU网络;
* **`fc`**:两层的全连接网络。
cnocr **V2.0** 目前包含以下可直接使用的模型,训练好的模型都放在 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 项目中,可免费下载使用:
| 模型名称 | 局部编码模型 | 序列编码模型 | 模型大小 | 迭代次数 | 测试集准确率 |
| :------- | ------------ | ------------ | -------- | ------ | -------- |
| densenet-s-gru | densenet-s | gru | 11 M | 11 | 95.5% |
| densenet-s-fc | densenet-s | fc | 8.7 M | 39 | 91.9% |
> 模型名称是由局部编码模型和序列编码模型名称拼接而成。
## 使用方法
首次使用cnocr时,系统会**自动下载** zip格式的模型压缩文件,并存于 `~/.cnocr`目录(Windows下默认路径为 `C:\Users\<username>\AppData\Roaming\cnocr`)。
下载后的zip文件代码会自动对其解压,然后把解压后的模型相关目录放于`~/.cnocr/2.0`目录中。
如果系统无法自动成功下载zip文件,则需要手动从 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 下载此zip文件并把它放于 `~/.cnocr/2.0`目录。如果Github下载太慢,也可以从 [百度云盘](https://pan.baidu.com/s/1c68zjHfTVeqiSMXBEPYMrg) 下载, 提取码为 ` 9768`
放置好zip文件后,后面的事代码就会自动执行了。
### 图片预测
`CnOcr`是OCR的主类,包含了三个函数针对不同场景进行文字识别。类`CnOcr`的初始化函数如下:
```python
class CnOcr(object):
def __init__(
self,
model_name: str = 'densenet-s-fc'
*,
cand_alphabet: Optional[Union[Collection, str]] = None,
context: str = 'cpu', # ['cpu', 'gpu', 'cuda']
model_fp: Optional[str] = None,
root: Union[str, Path] = data_dir(),
**kwargs,
):
```
其中的几个参数含义如下:
* `model_name`: 模型名称,即上面表格第一列中的值。默认为 `densenet-s-fc`
* `cand_alphabet`: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围。取值可以是字符串,如 `"0123456789"`,或者字符列表,如 `["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]`
* `cand_alphabet`也可以初始化后通过类函数 `CnOcr.set_cand_alphabet(cand_alphabet)` 进行设置。这样同一个实例也可以指定不同的`cand_alphabet`进行识别。
* `context`:预测使用的机器资源,可取值为字符串`cpu``gpu``cuda:0`等。
* `model_fp`: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件(`.ckpt` 文件)。
* `root`: 模型文件所在的根目录。
* Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.0/densenet-s-fc`
* Windows下默认值为 `C:\Users\<username>\AppData\Roaming\cnocr`
每个参数都有默认取值,所以可以不传入任何参数值进行初始化:`ocr = CnOcr()`
`CnOcr`主要包含三个函数,下面分别说明。
#### 1. 函数`CnOcr.ocr(img_fp)`
函数`CnOcr.ocr(img_fp)`可以对包含多行文字(或单行)的图片进行文字识别。
**函数说明**
- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor``np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]``channel` 可以等于`1`(灰度图片)或者`3``RGB`格式的彩色图片)。
- 返回值:为一个嵌套的`list`,其中的每个元素存储了对一行文字的识别结果,其中也包含了识别概率值。类似这样`[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]`,其中的数字为对应的识别概率值。
**调用示例**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr('docs/examples/multi-line_cn1.png')
print("Predicted Chars:", res)
```
或:
```python
from cnocr.utils import read_img
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'docs/examples/multi-line_cn1.png'
img = read_img(img_fp)
res = ocr.ocr(img)
print("Predicted Chars:", res)
```
上面使用的图片文件 [docs/examples/multi-line_cn1.png](./docs/examples/multi-line_cn1.png)内容如下:
![docs/examples/multi-line_cn1.png](./docs/examples/multi-line_cn1.png)
上面预测代码段的返回结果如下:
```bash
Predicted Chars: [
(['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'], 0.8677546381950378),
(['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'], 0.6706454157829285),
(['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '一', '这', '个', '账'], 0.5052655935287476),
(['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'], 0.7785991430282593),
(['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'], 0.37458470463752747),
(['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'], 0.7326119542121887),
(['等', '多', '种', '形', '式', '。'], 0.14462216198444366)]
```
#### 2. 函数`CnOcr.ocr_for_single_line(img_fp)`
如果明确知道要预测的图片中只包含了单行文字,可以使用函数`CnOcr.ocr_for_single_line(img_fp)`进行识别。和 `CnOcr.ocr()`相比,`CnOcr.ocr_for_single_line()`结果可靠性更强,因为它不需要做额外的分行处理。
**函数说明**
- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor``np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]``channel` 可以等于`1`(灰度图片)或者`3``RGB`格式的彩色图片)。
- 返回值:为一个`tuple`,其中存储了对一行文字的识别结果,也包含了识别概率值。类似这样`(['第', '一', '行'], 0.80)`,其中的数字为对应的识别概率值。
**调用示例**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr_for_single_line('docs/examples/rand_cn1.png')
print("Predicted Chars:", res)
```
或:
```python
from cnocr.utils import read_img
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'docs/examples/rand_cn1.png'
img = read_img(img_fp)
res = ocr.ocr_for_single_line(img)
print("Predicted Chars:", res)
```
对图片文件 [docs/examples/rand_cn1.png](./docs/examples/rand_cn1.png)
![docs/examples/rand_cn1.png](./docs/examples/rand_cn1.png)
的预测结果如下:
```bash
Predicted Chars: (['笠', '淡', '嘿', '骅', '谧', '鼎', '皋', '姚', '歼', '蠢', '驼', '耳', '胬', '挝', '涯', '狗', '蒽', '了', '狞'], 0.7832438349723816)
```
#### 3. 函数`CnOcr.ocr_for_single_lines(img_list, batch_size=1)`
函数`CnOcr.ocr_for_single_lines(img_list)`可以**对多个单行文字图片进行批量预测**。函数`CnOcr.ocr(img_fp)``CnOcr.ocr_for_single_line(img_fp)`内部其实都是调用的函数`CnOcr.ocr_for_single_lines(img_list)`
**函数说明**
- 输入参数` img_list`: 为一个`list`;其中每个元素可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor``np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]``channel` 可以等于`1`(灰度图片)或者`3``RGB`格式的彩色图片)。
- 输入参数 `batch_size`: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`
- 返回值:为一个嵌套的`list`,其中的每个元素存储了对一行文字的识别结果,其中也包含了识别概率值。类似这样`[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]`,其中的数字为对应的识别概率值。
**调用示例**
```python
import numpy as np
from cnocr.utils import read_img
from cnocr import CnOcr, line_split
ocr = CnOcr()
img_fp = 'docs/examples/multi-line_cn1.png'
img = read_img(img_fp)
line_imgs = line_split(np.squeeze(img, -1), blank=True)
line_img_list = [line_img for line_img, _ in line_imgs]
res = ocr.ocr_for_single_lines(line_img_list)
print("Predicted Chars:", res)
```
更详细的使用方法,可参考 [tests/test_cnocr.py](./tests/test_cnocr.py) 中提供的测试用例。
### 结合文字检测引擎 **[cnstd](https://github.com/breezedeus/cnstd)** 使用
对于一般的场景图片(如照片、票据等),需要先利用场景文字检测引擎 **[cnstd](https://github.com/breezedeus/cnstd)** 定位到文字所在位置,然后再利用 **cnocr** 进行文本识别。
```python
from cnstd import CnStd
from cnocr import CnOcr
std = CnStd()
cn_ocr = CnOcr()
box_infos = std.detect('docs/examples/taobao.jpg')
for box_info in box_infos['detected_texts']:
cropped_img = box_info['cropped_img']
ocr_res = cn_ocr.ocr_for_single_line(cropped_img)
print('ocr result: %s' % str(ocr_out))
```
注:运行上面示例需要先安装 **[cnstd](https://github.com/breezedeus/cnstd)**
```bash
pip install cnstd
```
**[cnstd](https://github.com/breezedeus/cnstd)** 相关的更多使用说明请参考其项目地址。
### 脚本使用
**cnocr** 包含了几个命令行工具,安装 **cnocr** 后即可使用。
#### 预测单个文件或文件夹中所有图片
使用命令 **`cnocr predict`** 预测单个文件或文件夹中所有图片,以下是使用说明:
```bash
(venv) ➜ cnocr git:(pytorch) ✗ cnocr predict -h
Usage: cnocr predict [OPTIONS]
Options:
-m, --model-name [densenet-s-lstm|densenet-s-gru|densenet-s-fc]
模型名称。默认值为 densenet-s-fc
--model_epoch INTEGER model epoch。默认为 `None`,表示使用系统自带的预训练模型
-p, --pretrained-model-fp TEXT 使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型
--context TEXT 使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为
`cpu`
-i, --img-file-or-dir TEXT 输入图片的文件路径或者指定的文件夹 [required]
-s, --single-line 是否输入图片只包含单行文字。对包含单行文字的图片,不做按行切分;否则会先对图片按行分割后
再进行识别
-h, --help Show this message and exit.
```
例如可以使用以下命令对图片 `docs/examples/rand_cn1.png` 进行文字识别:
```bash
cnstd predict -i docs/examples/rand_cn1.png -s
```
具体使用也可参考文件 [Makefile](./Makefile)
#### 模型训练
使用命令 **`cnocr train`** 训练文本检测模型,以下是使用说明:
```bash
(venv) ➜ cnocr git:(pytorch) ✗ cnocr train -h
Usage: cnocr train [OPTIONS]
Options:
-m, --model-name [densenet-s-fc|densenet-s-lstm|densenet-s-gru]
模型名称。默认值为 densenet-s-fc
-i, --index-dir TEXT 索引文件所在的文件夹,会读取文件夹中的 train.tsv 和 dev.tsv 文件
[required]
--train-config-fp TEXT 训练使用的json配置文件,参考 `example/train_config.json`
[required]
-r, --resume-from-checkpoint TEXT
恢复此前中断的训练状态,继续训练。默认为 `None`
-p, --pretrained-model-fp TEXT 导入的训练好的模型,作为初始模型。优先级低于"--restore-training-
fp",当传入"--restore-training-fp"时,此传入失效。默认为
`None`
-h, --help Show this message and exit.
```
例如可以使用以下命令进行训练:
```bash
cnocr train -m densenet-s-fc --index-dir data/test --train-config-fp docs/examples/train_config.json
```
训练数据的格式见文件夹 [data/test](./data/test) 中的 [train.tsv](./data/test/train.tsv)[dev.tsv](./data/test/dev.tsv) 文件。
具体使用也可参考文件 [Makefile](./Makefile)
#### 模型转存
训练好的模型会存储训练状态,使用命令 **`cnocr resave`** 去掉与预测无关的数据,降低模型大小。
```bash
(venv) ➜ cnocr git:(pytorch) ✗ cnocr resave -h
Usage: cnocr resave [OPTIONS]
训练好的模型会存储训练状态,使用此命令去掉预测时无关的数据,降低模型大小
Options:
-i, --input-model-fp TEXT 输入的模型文件路径 [required]
-o, --output-model-fp TEXT 输出的模型文件路径 [required]
-h, --help Show this message and exit.
```
## 未来工作
* [x] 支持图片包含多行文字 (`Done`)
......@@ -452,6 +70,6 @@ Options:
* [x] 尝试新模型,如 DenseNet,进一步提升识别准确率(since `V1.1.0`
* [x] 优化训练集,去掉不合理的样本;在此基础上,重新训练各个模型
* [x] 由 MXNet 改为 PyTorch 架构(since `V2.0.0`
* [ ] 基于 PyTorch 训练更高效的模型
* [x] 基于 PyTorch 训练更高效的模型
* [ ] 支持列格式的文字识别
......@@ -41,7 +41,7 @@ from cnocr import CnOcr, gen_model
_CONTEXT_SETTINGS = {"help_option_names": ['-h', '--help']}
logger = set_logger(log_level=logging.INFO)
DEFAULT_MODEL_NAME = 'densenet-s-fc'
DEFAULT_MODEL_NAME = 'densenet_lite_136-fc'
LEGAL_MODEL_NAMES = {
enc_name + '-' + dec_name
for enc_name in ENCODER_CONFIGS.keys()
......@@ -58,7 +58,7 @@ def cli():
@click.option(
'-m',
'--model-name',
type=click.Choice(LEGAL_MODEL_NAMES),
type=str,
default=DEFAULT_MODEL_NAME,
help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME,
)
......@@ -97,7 +97,7 @@ def train(
train_transform = T.Compose(
[
RandomStretchAug(min_ratio=0.5, max_ratio=1.5),
RandomCrop((8, 10)),
# RandomCrop((8, 10)),
T.RandomInvert(p=0.2),
T.RandomApply([T.RandomRotation(degrees=1)], p=0.4),
# T.RandomAutocontrast(p=0.05),
......@@ -161,7 +161,7 @@ def visualize_example(example, fp_prefix):
@click.option(
'-m',
'--model-name',
type=click.Choice(LEGAL_MODEL_NAMES),
type=str,
default=DEFAULT_MODEL_NAME,
help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME,
)
......@@ -213,7 +213,7 @@ def predict(model_name, pretrained_model_fp, context, img_file_or_dir, single_li
@click.option(
'-m',
'--model-name',
type=click.Choice(LEGAL_MODEL_NAMES),
type=str,
default=DEFAULT_MODEL_NAME,
help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME,
)
......
......@@ -57,7 +57,7 @@ class CnOcr(object):
def __init__(
self,
model_name: str = 'densenet_lite_124-fc',
model_name: str = 'densenet_lite_136-fc',
*,
cand_alphabet: Optional[Union[Collection, str]] = None,
context: str = 'cpu', # ['cpu', 'gpu', 'cuda']
......@@ -69,12 +69,12 @@ class CnOcr(object):
识别模型初始化函数。
Args:
model_name (str): 模型名称。默认为 `densenet_lite_124-fc`
model_name (str): 模型名称。默认为 `densenet_lite_136-fc`
cand_alphabet (Optional[Union[Collection, str]]): 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
context (str): 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为 `cpu`
model_fp (Optional[str]): 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.ckpt' 文件)
root (Union[str, Path]): 模型文件所在的根目录。
Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.0/densenet-s-fc`。
Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.1/densenet_lite_136-fc`。
Windows下默认值为 `C:/Users/<username>/AppData/Roaming/cnocr`。
**kwargs: 目前未被使用。
......@@ -83,10 +83,10 @@ class CnOcr(object):
>>> ocr = CnOcr()
使用指定模型:
>>> ocr = CnOcr(model_name='densenet_lite_124-fc')
>>> ocr = CnOcr(model_name='densenet_lite_136-fc')
识别时只考虑数字:
>>> ocr = CnOcr(model_name='densenet_lite_124-fc', cand_alphabet='0123456789')
>>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')
"""
if 'name' in kwargs:
......
......@@ -60,7 +60,6 @@ ENCODER_CONFIGS = {
'num_init_features': 64,
'out_length': 720,
},
'densenet_lite_113': { # 长度压缩至 1/8(seq_len == 35),输出的向量长度为 2*136 = 272
'growth_rate': 32,
'block_config': [1, 1, 3],
......@@ -91,52 +90,29 @@ ENCODER_CONFIGS = {
'num_init_features': 64,
'out_length': 528,
},
'mobilenetv3_tiny': {
'arch': 'tiny',
'out_length': 384,
},
'mobilenetv3_small': {
'arch': 'small',
'out_length': 384,
}
'mobilenetv3_tiny': {'arch': 'tiny', 'out_length': 384,},
'mobilenetv3_small': {'arch': 'small', 'out_length': 384,},
}
DECODER_CONFIGS = {
'lstm': {
# 'input_size': 512, # 对应 encoder 的输出向量长度
'rnn_units': 128,
},
'gru': {
# 'input_size': 512, # 对应 encoder 的输出向量长度
'rnn_units': 128,
},
'fcfull': {
# 'input_size': 512, # 对应 encoder 的输出向量长度
'hidden_size': 256,
'dropout': 0.3,
},
'fc': {
# 'input_size': 512, # 对应 encoder 的输出向量长度
'hidden_size': 128,
'dropout': 0.1,
},
'lstm': {'rnn_units': 128,},
'gru': {'rnn_units': 128,},
'fc': {'hidden_size': 128, 'dropout': 0.1,},
'fcfull': {'hidden_size': 256, 'dropout': 0.3,},
}
root_url = (
'https://beiye-model.oss-cn-beijing.aliyuncs.com/models/cnocr/%s/'
'https://huggingface.co/breezedeus/cnstd-cnocr-models/resolve/main/models/cnocr/%s/'
% MODEL_VERSION
)
# name: (epochs, url)
# name: (epoch, url)
AVAILABLE_MODELS = {
# 'densenet-s-fc': (8, root_url + 'densenet-s-fc-v2.0.1.zip'),
# 'densenet-s-gru': (14, root_url + 'densenet-s-gru-v2.0.1.zip'),
# 'densenet-lite-113-fclite': (33, root_url + '.zip'),
'densenet_lite_114-fc': (31, root_url + '.zip'),
'densenet_lite_124-fc': (36, root_url + '.zip'),
'densenet_lite_134-fc': (38, root_url + '.zip'),
'densenet_lite_136-fc': (17, root_url + '.zip'),
'densenet_lite_136-fc-scene': (17, root_url + '.zip'),
'densenet_lite_114-fc': (37, root_url + 'densenet_lite_114-fc.zip'),
'densenet_lite_124-fc': (39, root_url + 'densenet_lite_124-fc.zip'),
'densenet_lite_134-fc': (34, root_url + 'densenet_lite_134-fc.zip'),
'densenet_lite_136-fc': (39, root_url + 'densenet_lite_136-fc.zip'),
'densenet_lite_134-gru': (2, root_url + 'densenet_lite_134-gru.zip'),
'densenet_lite_136-gru': (2, root_url + 'densenet_lite_136-gru.zip'),
}
# 候选字符集合
......
# 常见问题(FAQ)
## CnOcr 是免费的吗?
CnOcr是免费的,而且是开源的。可以按需自行调整发布或商业使用。
## CnOcr 能识别英文以及空格吗?
可以。
## CnOcr 能识别繁体中文吗?
不能。
## CnOcr 能识别竖排文字的图片吗?
不能。
......@@ -3,24 +3,29 @@
cnocr的ocr模型可以分为两阶段:第一阶段是获得ocr图片的局部编码向量,第二部分是对局部编码向量进行序列学习,获得序列编码向量。目前的PyTorch版本的两个阶段分别包含以下模型:
1. 局部编码模型(emb model)
* **`densenet-s`**:一个小型的`densenet`网络;
* **`densenet_lite_<numbers>`**:一个微型的`densenet`网络;其中的`<number>`表示模型中每个block包含的层数。
* **`densenet`**:一个小型的`densenet`网络;
2. 序列编码模型(seq model)
* **`lstm`**:一层的LSTM网络;
* **`fc`**:两层的全连接网络;
* **`gru`**:一层的GRU网络;
* **`fc`**:两层的全连接网络。
* **`lstm`**:一层的LSTM网络。
cnocr **V2.0** 目前包含以下可直接使用的模型,训练好的模型都放在 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 项目中,可免费下载使用:
cnocr **V2.1** 目前包含以下可直接使用的模型,训练好的模型都放在 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 项目中,可免费下载使用:
| 模型名称 | 局部编码模型 | 序列编码模型 | 模型大小 | 迭代次数 | 测试集准确率 |
| :------- | ------------ | ------------ | -------- | ------ | -------- |
| densenet-s-gru | densenet-s | gru | 11 M | 11 | 95.5% |
| densenet-s-fc | densenet-s | fc | 8.7 M | 39 | 91.9% |
| Name | 参数规模 | 模型文件大小 | 准确度 | 平均推断耗时(毫秒/图) |
| --- | --- | --- | --- | --- |
| densenet\_lite\_114-fc | 1.3 M | 4.9 M | 0.9274 | 9.229 |
| densenet\_lite\_124-fc | 1.3 M | 5.1 M | 0.9429 | 10.112 |
| densenet\_lite\_134-fc | 1.4 M | 5.4 M | 0.954 | 10.843 |
| densenet\_lite\_136-fc | 1.5M | 5.9 M | 0.9631 | 11.499 |
| densenet\_lite\_134-gru | 2.9 M | 11 M | 0.9738 | 17.042 |
| densenet\_lite\_136-gru | 3.1 M | 12 M | 0.9756 | 17.725 |
> 模型名称是由局部编码模型和序列编码模型名称拼接而成。
> 模型名称是由局部编码模型和序列编码模型名称拼接而成,以符合"-"分割。
......@@ -107,6 +107,7 @@ nav:
- 文本检测CnStd + 文字识别CnOcr: cnstd_cnocr.md
- 模型训练: train.md
- QQ交流群: contact.md
- 常见问题(FAQ): faq.md
- RELEASE 文档: RELEASE.md
- Python API:
- CnOcr 类: cnocr/cn_ocr.md
\ No newline at end of file
......@@ -5,6 +5,17 @@ from copy import deepcopy
import pytest
import torch
from torch import nn
from torchvision.models import (
resnet50,
resnet34,
resnet18,
mobilenet_v3_large,
mobilenet_v3_small,
shufflenet_v2_x1_0,
shufflenet_v2_x1_5,
shufflenet_v2_x2_0,
)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
......@@ -12,6 +23,7 @@ sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr.utils import set_logger, get_model_size
from cnocr.consts import IMG_STANDARD_HEIGHT, ENCODER_CONFIGS, DECODER_CONFIGS
from cnocr.models.densenet import DenseNet, DenseNetLite
from cnocr.models.mobilenet import gen_mobilenet_v3
logger = set_logger('info')
......@@ -36,127 +48,110 @@ def test_densenet():
logger.info(res.shape)
assert tuple(res.shape) == (4, 128, 4, 35)
net = DenseNet(32, [2, 2, 1, 1], 64)
net = DenseNet(32, [1, 1, 1, 4], 64)
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') # 301440
logger.info(img.shape)
res = net(img)
logger.info(res.shape)
assert tuple(res.shape) == (4, 80, 4, 35)
# assert tuple(res.shape) == (4, 100, 4, 35)
#
# net = DenseNet(32, [1, 1, 2, 2], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 243616
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 116, 4, 35)
#
# net = DenseNet(32, [1, 2, 2, 2], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 230680
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 124, 4, 35)
#
# net = DenseNet(32, [1, 1, 2, 4], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 230680
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 180, 4, 35)
net = DenseNet(32, [2, 1, 1, 1], 64)
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') # 243616
logger.info(img.shape)
res = net(img)
logger.info(res.shape)
assert tuple(res.shape) == (4, 72, 4, 35)
net = DenseNet(32, [1, 1, 1, 2], 64)
def test_densenet_lite():
width = 280
img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width)
# net = DenseNetLite(32, [2, 2, 2], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 302976
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 128, 2, 35)
# net = DenseNetLite(32, [2, 1, 1], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 197952
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 80, 2, 35)
net = DenseNetLite(32, [1, 3, 4], 64)
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') # 230680
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(img.shape)
res = net(img)
logger.info(res.shape)
assert tuple(res.shape) == (4, 100, 4, 35)
assert tuple(res.shape) == (4, 200, 2, 35)
def test_densenet_lite():
width = 280
img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width)
net = DenseNetLite(32, [2, 2, 2], 64)
net = DenseNetLite(32, [1, 3, 6], 64)
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') # 302976
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(img.shape)
res = net(img)
logger.info(res.shape)
assert tuple(res.shape) == (4, 128, 2, 35)
assert tuple(res.shape) == (4, 264, 2, 35)
# net = DenseNetLite(32, [2, 1, 1], 64)
# net = DenseNetLite(32, [1, 2, 2], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 197952
# logger.info(f'model size: {get_model_size(net)}') #
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 80, 2, 35)
# assert tuple(res.shape) == (4, 120, 2, 35)
net = DenseNetLite(32, [1, 1, 2], 64)
def test_mobilenet():
width = 280
img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width)
net = gen_mobilenet_v3('tiny')
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(img.shape)
res = net(img)
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(res.shape)
assert tuple(res.shape) == (4, 104, 2, 35)
assert tuple(res.shape) == (4, 192, 2, 35)
net = DenseNetLite(32, [1, 2, 2], 64)
net = gen_mobilenet_v3('small')
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') #
logger.info(img.shape)
res = net(img)
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(res.shape)
assert tuple(res.shape) == (4, 120, 2, 35)
def test_crnn():
_hp = deepcopy(HP)
_hp.set_seq_length(_hp.img_width // 4)
x = nd.random.randn(128, 64, 32, 280)
layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
for layer_channels in layer_channels_list:
densenet = DenseNet(layer_channels)
crnn = CRnn(_hp, densenet)
crnn.initialize()
y = crnn(x)
logger.info(
'output shape: %s', y.shape
) # res: `(sequence_length, batch_size, 2*num_hidden)`
assert y.shape == (_hp.seq_length, _hp.batch_size, 2 * _hp.num_hidden)
logger.info('number of parameters: %d', cal_num_params(crnn))
def test_crnn_lstm():
hp = deepcopy(HP)
hp.set_seq_length(hp.img_width // 8)
data = mx.sym.Variable('data', shape=(128, 1, 32, 280))
pred = crnn_lstm(HP, data)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
assert pred_shape == (hp.seq_length, hp.batch_size, 2 * hp.num_hidden)
def test_crnn_lstm_lite():
hp = deepcopy(HP)
width = hp.img_width # 280
data = mx.sym.Variable('data', shape=(128, 1, 32, width))
for shorter in (False, True):
pred = crnn_lstm_lite(HP, data, shorter=shorter)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
seq_len = hp.img_width // 8 if shorter else hp.img_width // 4 - 1
assert pred_shape == (seq_len, hp.batch_size, 2 * hp.num_hidden)
def test_pipline():
hp = deepcopy(HP)
hp.set_seq_length(hp.img_width // 4)
hp._loss_type = None # infer mode
layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
for layer_channels in layer_channels_list:
densenet = DenseNet(layer_channels)
crnn = CRnn(hp, densenet)
data = mx.sym.Variable('data', shape=(128, 1, 32, 280))
pred = pipline(crnn, hp, data)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
assert pred_shape == (hp.batch_size * hp.seq_length, hp.num_classes)
assert tuple(res.shape) == (4, 192, 2, 35)
MODEL_NAMES = []
for emb_model in ENCODER_CONFIGS:
......@@ -164,15 +159,15 @@ for emb_model in ENCODER_CONFIGS:
MODEL_NAMES.append('%s-%s' % (emb_model, seq_model))
@pytest.mark.parametrize(
'model_name', MODEL_NAMES
)
def test_gen_networks(model_name):
logger.info('model_name: %s', model_name)
network, hp = gen_network(model_name, HP)
shape_dict = get_infer_shape(network, HP)
logger.info('shape_dict: %s', shape_dict)
assert shape_dict['pred_fc_output'] == (
hp.batch_size * hp.seq_length,
hp.num_classes,
)
# @pytest.mark.parametrize(
# 'model_name', MODEL_NAMES
# )
# def test_gen_networks(model_name):
# logger.info('model_name: %s', model_name)
# network, hp = gen_network(model_name, HP)
# shape_dict = get_infer_shape(network, HP)
# logger.info('shape_dict: %s', shape_dict)
# assert shape_dict['pred_fc_output'] == (
# hp.batch_size * hp.seq_length,
# hp.num_classes,
# )
......@@ -36,65 +36,3 @@ def test_crnn():
crnn = OcrModel(net, vocab=ENG_LETTERS, lstm_features=512, rnn_units=128)
res2 = crnn(img)
print(res2)
def test_crnn_for_variable_length():
vocab, letter2id = read_charset(VOCAB_FP)
net = DenseNet(32, [2, 2, 2, 2], 64)
crnn = OcrModel(net, vocab=vocab, lstm_features=512, rnn_units=128)
crnn.eval()
model_fp = VOCAB_FP.parent / 'models/last.ckpt'
if model_fp.exists():
print(f'load model params from {model_fp}')
load_model_params(crnn, model_fp)
width = 280
img1 = torch.rand(1, IMG_STANDARD_HEIGHT, width)
img2 = torch.rand(1, IMG_STANDARD_HEIGHT, width // 2)
img3 = torch.rand(1, IMG_STANDARD_HEIGHT, width * 2)
imgs = pad_img_seq([img1, img2, img3])
input_lengths = torch.Tensor([width, width // 2, width * 2])
out = crnn(
imgs, input_lengths=input_lengths, return_model_output=True, return_preds=True,
)
print(out['preds'])
padded = torch.zeros((3, 1, IMG_STANDARD_HEIGHT, 50))
imgs2 = torch.cat((imgs, padded), dim=-1)
out2 = crnn(
imgs2, input_lengths=input_lengths, return_model_output=True, return_preds=True,
)
print(out2['preds'])
# breakpoint()
def test_crnn_for_variable_length2():
vocab, letter2id = read_charset(VOCAB_FP)
net = DenseNet(32, [2, 2, 2, 2], 64)
crnn = OcrModel(net, vocab=vocab, lstm_features=512, rnn_units=128)
crnn.eval()
model_fp = VOCAB_FP.parent / 'models/last.ckpt'
if model_fp.exists():
print(f'load model params from {model_fp}')
load_model_params(crnn, model_fp)
img_fps = ('helloworld.jpg', 'helloworld-ch.jpg')
imgs = []
input_lengths = []
for fp in img_fps:
img = read_img(VOCAB_FP.parent / 'examples' / fp)
img = rescale_img(img)
input_lengths.append(img.shape[2])
imgs.append(normalize_img_array(img))
imgs = pad_img_seq(imgs)
input_lengths = torch.Tensor(input_lengths)
out = crnn(
imgs, input_lengths=input_lengths, return_model_output=True, return_preds=True,
)
print(out['preds'])
padded = torch.zeros((2, 1, IMG_STANDARD_HEIGHT, 80))
imgs2 = torch.cat((imgs, padded), dim=-1)
out2 = crnn(
imgs2, input_lengths=input_lengths, return_model_output=True, return_preds=True,
)
print(out2['preds'])
# breakpoint()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册