diff --git a/Makefile b/Makefile index e9b061b3e64aadcee1436e24068c8afe561d8cea..47221308f45a018f1d5a8f3384a2d4e212671020 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,12 @@ -# 可取值:['densenet-s'] -ENCODER_NAME = densenet-s +# 可取值:['densenet_lite_136'] +ENCODER_NAME = densenet_lite_136 # 可取值:['fc', 'gru', 'lstm'] -DECODER_NAME = gru +DECODER_NAME = fc MODEL_NAME = $(ENCODER_NAME)-$(DECODER_NAME) EPOCH = 41 INDEX_DIR = data/test -TRAIN_CONFIG_FP = examples/train_config.json +TRAIN_CONFIG_FP = docs/examples/train_config.json # 训练模型 train: @@ -14,19 +14,28 @@ train: # 在测试集上评估模型,所有badcases的具体信息会存放到文件夹 `evaluate/$(MODEL_NAME)` 中 evaluate: - python scripts/cnocr_evaluate.py --model-name $(MODEL_NAME) --model-epoch 1 -v -i $(DATA_ROOT_DIR)/test.txt \ - --image-prefix-dir examples --batch-size 128 -o evaluate/$(MODEL_NAME) + cnocr evaluate --model-name $(MODEL_NAME) -i data/test/dev.tsv \ + --image-folder data/images --batch-size 128 -o eval_results/$(MODEL_NAME) predict: - cnocr predict -m $(MODEL_NAME) -i examples/rand_cn1.png + cnocr predict -m $(MODEL_NAME) -i docs/examples/rand_cn1.png + + +doc: +# pip install mkdocs +# pip install mkdocs-macros-plugin +# pip install mkdocs-material +# pip install mkdocstrings + python -m mkdocs serve +# python -m mkdocs build package: python setup.py sdist bdist_wheel -VERSION = 2.0.1 +VERSION = 2.1.0 upload: python -m twine upload dist/cnocr-$(VERSION)* --verbose -.PHONY: train evaluate predict package upload +.PHONY: train evaluate predict doc package upload diff --git a/README.md b/README.md index 5d395b80bb6737260a0299a121fdbd6c5c77c6bd..cd4a13e77be6b1c16bfa997f2560418afb23b1f0 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ English [README](./README_en.md) (`out-dated`). + # cnocr **cnocr** 是 **Python 3** 下的**文字识别**(**Optical Character Recognition**,简称**OCR**)工具包,支持**中文**、**英文**的常见字符识别,自带了多个训练好的识别模型,安装后即可直接使用。欢迎扫码加入QQ交流群: @@ -7,23 +8,9 @@ English [README](./README_en.md) (`out-dated`). ![QQ群二维码](./docs/cnocr-qq.jpg) +## 详细文档 -# 最近更新 【2021.08.26】:V2.0.0 - -主要变更: - -* MXNet 越来越小众化,故从基于 MXNet 的实现转为基于 **PyTorch** 的实现; -* 重新实现了识别模型,优化了训练数据,重新训练模型; -* 优化了能识别的字符集合; -* 优化了对英文的识别效果; -* 优化了对场景文字的识别效果; -* 使用接口略有调整。 - - - -更多更新说明见 [RELEASE Notes](./RELEASE.md)。 - - +见 [CnOcr在线文档](https://cnocr.readthedocs.io/) 。 ## 使用场景说明 @@ -35,20 +22,20 @@ English [README](./README_en.md) (`out-dated`). | 图片 | OCR结果 | | ------------------------------------------------------------ | ------------------------------------------------------------ | -| ![examples/helloworld.jpg](./examples/helloworld.jpg) | Hello world!你好世界 | -| ![examples/chn-00199989.jpg](./examples/chn-00199989.jpg) | 铑泡胭释邑疫反隽寥缔 | -| ![examples/chn-00199980.jpg](./examples/chn-00199980.jpg) | 拇箬遭才柄腾戮胖惬炫 | -| ![examples/chn-00199984.jpg](./examples/chn-00199984.jpg) | 寿猿嗅髓孢刀谎弓供捣 | -| ![examples/chn-00199985.jpg](./examples/chn-00199985.jpg) | 马靼蘑熨距额猬要藕萼 | -| ![examples/chn-00199981.jpg](./examples/chn-00199981.jpg) | 掉江悟厉励.谌查门蠕坑 | -| ![examples/00199975.jpg](./examples/00199975.jpg) | nd-chips fructed ast | -| ![examples/00199978.jpg](./examples/00199978.jpg) | zouna unpayably Raqu | -| ![examples/00199979.jpg](./examples/00199979.jpg) | ape fissioning Senat | -| ![examples/00199971.jpg](./examples/00199971.jpg) | ling oughtlins near | -| ![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) | 网络支付并无本质的区别,因为
每一个手机号码和邮件地址背后
都会对应着一个账户--这个账
户可以是信用卡账户、借记卡账
户,也包括邮局汇款、手机代
收、电话代收、预付费卡和点卡
等多种形式。 | -| ![examples/multi-line_cn2.png](./examples/multi-line_cn2.png) | 当然,在媒介越来越多的情形下,
意味着传播方式的变化。过去主流
的是大众传播,现在互动性和定制
性带来了新的挑战——如何让品牌
与消费者更加互动。 | -| ![examples/multi-line_en_white.png](./examples/multi-line_en_white.png) | This chapter is currently only available in this web version. ebook and print will follow.
Convolutional neural networks learn abstract features and concepts from raw image pixels. Feature
Visualization visualizes the learned features by activation maximization. Network Dissection labels
neural network units (e.g. channels) with human concepts. | -| ![examples/multi-line_en_black.png](./examples/multi-line_en_black.png) | transforms the image many times. First, the image goes through many convolutional layers. In those
convolutional layers, the network learns new and increasingly complex features in its layers. Then the
transformed image information goes through the fully connected layers and turns into a classification
or prediction. | +| ![docs/examples/helloworld.jpg](./docs/examples/helloworld.jpg) | Hello world!你好世界 | +| ![docs/examples/chn-00199989.jpg](./docs/examples/chn-00199989.jpg) | 铑泡胭释邑疫反隽寥缔 | +| ![docs/examples/chn-00199980.jpg](./docs/examples/chn-00199980.jpg) | 拇箬遭才柄腾戮胖惬炫 | +| ![docs/examples/chn-00199984.jpg](./docs/examples/chn-00199984.jpg) | 寿猿嗅髓孢刀谎弓供捣 | +| ![docs/examples/chn-00199985.jpg](./docs/examples/chn-00199985.jpg) | 马靼蘑熨距额猬要藕萼 | +| ![docs/examples/chn-00199981.jpg](./docs/examples/chn-00199981.jpg) | 掉江悟厉励.谌查门蠕坑 | +| ![docs/examples/00199975.jpg](./docs/examples/00199975.jpg) | nd-chips fructed ast | +| ![docs/examples/00199978.jpg](./docs/examples/00199978.jpg) | zouna unpayably Raqu | +| ![docs/examples/00199979.jpg](./docs/examples/00199979.jpg) | ape fissioning Senat | +| ![docs/examples/00199971.jpg](./docs/examples/00199971.jpg) | ling oughtlins near | +| ![docs/examples/multi-line_cn1.png](./docs/examples/multi-line_cn1.png) | 网络支付并无本质的区别,因为
每一个手机号码和邮件地址背后
都会对应着一个账户--这个账
户可以是信用卡账户、借记卡账
户,也包括邮局汇款、手机代
收、电话代收、预付费卡和点卡
等多种形式。 | +| ![docs/examples/multi-line_cn2.png](./docs/examples/multi-line_cn2.png) | 当然,在媒介越来越多的情形下,
意味着传播方式的变化。过去主流
的是大众传播,现在互动性和定制
性带来了新的挑战——如何让品牌
与消费者更加互动。 | +| ![docs/examples/multi-line_en_white.png](./docs/examples/multi-line_en_white.png) | This chapter is currently only available in this web version. ebook and print will follow.
Convolutional neural networks learn abstract features and concepts from raw image pixels. Feature
Visualization visualizes the learned features by activation maximization. Network Dissection labels
neural network units (e.g. channels) with human concepts. | +| ![docs/examples/multi-line_en_black.png](./docs/examples/multi-line_en_black.png) | transforms the image many times. First, the image goes through many convolutional layers. In those
convolutional layers, the network learns new and increasingly complex features in its layers. Then the
transformed image information goes through the fully connected layers and turns into a classification
or prediction. | @@ -73,375 +60,6 @@ pip install cnocr -i https://pypi.doubanio.com/simple > 注意:请使用 **Python3**(3.6以及之后版本应该都行),没测过Python2下是否ok。 - -## 可直接使用的模型 - -cnocr的ocr模型可以分为两阶段:第一阶段是获得ocr图片的局部编码向量,第二部分是对局部编码向量进行序列学习,获得序列编码向量。目前的PyTorch版本的两个阶段分别包含以下模型: - -1. 局部编码模型(emb model) - * **`densenet-s`**:一个小型的`densenet`网络; -2. 序列编码模型(seq model) - * **`lstm`**:一层的LSTM网络; - * **`gru`**:一层的GRU网络; - * **`fc`**:两层的全连接网络。 - - - -cnocr **V2.0** 目前包含以下可直接使用的模型,训练好的模型都放在 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 项目中,可免费下载使用: - -| 模型名称 | 局部编码模型 | 序列编码模型 | 模型大小 | 迭代次数 | 测试集准确率 | -| :------- | ------------ | ------------ | -------- | ------ | -------- | -| densenet-s-gru | densenet-s | gru | 11 M | 11 | 95.5% | -| densenet-s-fc | densenet-s | fc | 8.7 M | 39 | 91.9% | - -> 模型名称是由局部编码模型和序列编码模型名称拼接而成。 - - - - - -## 使用方法 - -首次使用cnocr时,系统会**自动下载** zip格式的模型压缩文件,并存于 `~/.cnocr`目录(Windows下默认路径为 `C:\Users\\AppData\Roaming\cnocr`)。 -下载后的zip文件代码会自动对其解压,然后把解压后的模型相关目录放于`~/.cnocr/2.0`目录中。 - -如果系统无法自动成功下载zip文件,则需要手动从 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 下载此zip文件并把它放于 `~/.cnocr/2.0`目录。如果Github下载太慢,也可以从 [百度云盘](https://pan.baidu.com/s/1c68zjHfTVeqiSMXBEPYMrg) 下载, 提取码为 ` 9768`。 - -放置好zip文件后,后面的事代码就会自动执行了。 - - - -### 图片预测 - -类`CnOcr`是OCR的主类,包含了三个函数针对不同场景进行文字识别。类`CnOcr`的初始化函数如下: - -```python -class CnOcr(object): - def __init__( - self, - model_name: str = 'densenet-s-fc' - *, - cand_alphabet: Optional[Union[Collection, str]] = None, - context: str = 'cpu', # ['cpu', 'gpu', 'cuda'] - model_fp: Optional[str] = None, - root: Union[str, Path] = data_dir(), - **kwargs, - ): -``` - -其中的几个参数含义如下: - -* `model_name`: 模型名称,即上面表格第一列中的值。默认为 `densenet-s-fc`。 - -* `cand_alphabet`: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围。取值可以是字符串,如 `"0123456789"`,或者字符列表,如 `["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]`。 - - * `cand_alphabet`也可以初始化后通过类函数 `CnOcr.set_cand_alphabet(cand_alphabet)` 进行设置。这样同一个实例也可以指定不同的`cand_alphabet`进行识别。 - -* `context`:预测使用的机器资源,可取值为字符串`cpu`、`gpu`、`cuda:0`等。 - -* `model_fp`: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件(`.ckpt` 文件)。 - -* `root`: 模型文件所在的根目录。 - - * Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.0/densenet-s-fc`。 - * Windows下默认值为 `C:\Users\\AppData\Roaming\cnocr`。 - - - -每个参数都有默认取值,所以可以不传入任何参数值进行初始化:`ocr = CnOcr()`。 - - - - -类`CnOcr`主要包含三个函数,下面分别说明。 - - - -#### 1. 
函数`CnOcr.ocr(img_fp)` - -函数`CnOcr.ocr(img_fp)`可以对包含多行文字(或单行)的图片进行文字识别。 - - - -**函数说明**: - -- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]`,`channel` 可以等于`1`(灰度图片)或者`3`(`RGB`格式的彩色图片)。 -- 返回值:为一个嵌套的`list`,其中的每个元素存储了对一行文字的识别结果,其中也包含了识别概率值。类似这样`[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]`,其中的数字为对应的识别概率值。 - - - -**调用示例**: - - -```python -from cnocr import CnOcr - -ocr = CnOcr() -res = ocr.ocr('examples/multi-line_cn1.png') -print("Predicted Chars:", res) -``` - -或: -```python -from cnocr.utils import read_img -from cnocr import CnOcr - -ocr = CnOcr() -img_fp = 'examples/multi-line_cn1.png' -img = read_img(img_fp) -res = ocr.ocr(img) -print("Predicted Chars:", res) -``` - - - -上面使用的图片文件 [examples/multi-line_cn1.png](./examples/multi-line_cn1.png)内容如下: - -![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) - - - -上面预测代码段的返回结果如下: - -```bash -Predicted Chars: [ - (['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'], 0.8677546381950378), - (['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'], 0.6706454157829285), - (['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '一', '这', '个', '账'], 0.5052655935287476), - (['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'], 0.7785991430282593), - (['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'], 0.37458470463752747), - (['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'], 0.7326119542121887), - (['等', '多', '种', '形', '式', '。'], 0.14462216198444366)] -``` - - - -#### 2. 函数`CnOcr.ocr_for_single_line(img_fp)` - -如果明确知道要预测的图片中只包含了单行文字,可以使用函数`CnOcr.ocr_for_single_line(img_fp)`进行识别。和 `CnOcr.ocr()`相比,`CnOcr.ocr_for_single_line()`结果可靠性更强,因为它不需要做额外的分行处理。 - -**函数说明**: - -- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]`,`channel` 可以等于`1`(灰度图片)或者`3`(`RGB`格式的彩色图片)。 -- 返回值:为一个`tuple`,其中存储了对一行文字的识别结果,也包含了识别概率值。类似这样`(['第', '一', '行'], 0.80)`,其中的数字为对应的识别概率值。 - - - -**调用示例**: - -```python -from cnocr import CnOcr - -ocr = CnOcr() -res = ocr.ocr_for_single_line('examples/rand_cn1.png') -print("Predicted Chars:", res) -``` - -或: - -```python -from cnocr.utils import read_img -from cnocr import CnOcr - -ocr = CnOcr() -img_fp = 'examples/rand_cn1.png' -img = read_img(img_fp) -res = ocr.ocr_for_single_line(img) -print("Predicted Chars:", res) -``` - - -对图片文件 [examples/rand_cn1.png](./examples/rand_cn1.png): - -![examples/rand_cn1.png](./examples/rand_cn1.png) - -的预测结果如下: - -```bash -Predicted Chars: (['笠', '淡', '嘿', '骅', '谧', '鼎', '皋', '姚', '歼', '蠢', '驼', '耳', '胬', '挝', '涯', '狗', '蒽', '了', '狞'], 0.7832438349723816) -``` - - - -#### 3. 
函数`CnOcr.ocr_for_single_lines(img_list, batch_size=1)` - -函数`CnOcr.ocr_for_single_lines(img_list)`可以**对多个单行文字图片进行批量预测**。函数`CnOcr.ocr(img_fp)`和`CnOcr.ocr_for_single_line(img_fp)`内部其实都是调用的函数`CnOcr.ocr_for_single_lines(img_list)`。 - - - -**函数说明**: - -- 输入参数` img_list`: 为一个`list`;其中每个元素可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]`,`channel` 可以等于`1`(灰度图片)或者`3`(`RGB`格式的彩色图片)。 -- 输入参数 `batch_size`: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`。 -- 返回值:为一个嵌套的`list`,其中的每个元素存储了对一行文字的识别结果,其中也包含了识别概率值。类似这样`[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]`,其中的数字为对应的识别概率值。 - - - -**调用示例**: - -```python -import numpy as np - -from cnocr.utils import read_img -from cnocr import CnOcr, line_split - -ocr = CnOcr() -img_fp = 'examples/multi-line_cn1.png' -img = read_img(img_fp) -line_imgs = line_split(np.squeeze(img, -1), blank=True) -line_img_list = [line_img for line_img, _ in line_imgs] -res = ocr.ocr_for_single_lines(line_img_list) -print("Predicted Chars:", res) -``` - - - -更详细的使用方法,可参考 [tests/test_cnocr.py](./tests/test_cnocr.py) 中提供的测试用例。 - - - -### 结合文字检测引擎 **[cnstd](https://github.com/breezedeus/cnstd)** 使用 - -对于一般的场景图片(如照片、票据等),需要先利用场景文字检测引擎 **[cnstd](https://github.com/breezedeus/cnstd)** 定位到文字所在位置,然后再利用 **cnocr** 进行文本识别。 - -```python -from cnstd import CnStd -from cnocr import CnOcr - -std = CnStd() -cn_ocr = CnOcr() - -box_infos = std.detect('examples/taobao.jpg') - -for box_info in box_infos['detected_texts']: - cropped_img = box_info['cropped_img'] - ocr_res = cn_ocr.ocr_for_single_line(cropped_img) - print('ocr result: %s' % str(ocr_out)) -``` - -注:运行上面示例需要先安装 **[cnstd](https://github.com/breezedeus/cnstd)** : - -```bash -pip install cnstd -``` - -**[cnstd](https://github.com/breezedeus/cnstd)** 相关的更多使用说明请参考其项目地址。 - - - - - -### 脚本使用 - -**cnocr** 包含了几个命令行工具,安装 **cnocr** 后即可使用。 - - - -#### 预测单个文件或文件夹中所有图片 - -使用命令 **`cnocr predict`** 预测单个文件或文件夹中所有图片,以下是使用说明: - - - -```bash -(venv) ➜ cnocr git:(pytorch) ✗ cnocr predict -h -Usage: cnocr predict [OPTIONS] - -Options: - -m, --model-name [densenet-s-lstm|densenet-s-gru|densenet-s-fc] - 模型名称。默认值为 densenet-s-fc - --model_epoch INTEGER model epoch。默认为 `None`,表示使用系统自带的预训练模型 - -p, --pretrained-model-fp TEXT 使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型 - --context TEXT 使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为 - `cpu` - - -i, --img-file-or-dir TEXT 输入图片的文件路径或者指定的文件夹 [required] - -s, --single-line 是否输入图片只包含单行文字。对包含单行文字的图片,不做按行切分;否则会先对图片按行分割后 - 再进行识别 - - -h, --help Show this message and exit. -``` - - - -例如可以使用以下命令对图片 `examples/rand_cn1.png` 进行文字识别: - -```bash -cnstd predict -i examples/rand_cn1.png -s -``` - - - -具体使用也可参考文件 [Makefile](./Makefile) 。 - - - -#### 模型训练 - -使用命令 **`cnocr train`** 训练文本检测模型,以下是使用说明: - -```bash -(venv) ➜ cnocr git:(pytorch) ✗ cnocr train -h -Usage: cnocr train [OPTIONS] - -Options: - -m, --model-name [densenet-s-fc|densenet-s-lstm|densenet-s-gru] - 模型名称。默认值为 densenet-s-fc - -i, --index-dir TEXT 索引文件所在的文件夹,会读取文件夹中的 train.tsv 和 dev.tsv 文件 - [required] - - --train-config-fp TEXT 训练使用的json配置文件,参考 `example/train_config.json` - [required] - - -r, --resume-from-checkpoint TEXT - 恢复此前中断的训练状态,继续训练。默认为 `None` - -p, --pretrained-model-fp TEXT 导入的训练好的模型,作为初始模型。优先级低于"--restore-training- - fp",当传入"--restore-training-fp"时,此传入失效。默认为 - `None` - - -h, --help Show this message and exit. 
-``` - - - -例如可以使用以下命令进行训练: - -```bash -cnocr train -m densenet-s-fc --index-dir data/test --train-config-fp examples/train_config.json -``` - - - -训练数据的格式见文件夹 [data/test](./data/test) 中的 [train.tsv](./data/test/train.tsv) 和 [dev.tsv](./data/test/dev.tsv) 文件。 - - - -具体使用也可参考文件 [Makefile](./Makefile) 。 - - - -#### 模型转存 - -训练好的模型会存储训练状态,使用命令 **`cnocr resave`** 去掉与预测无关的数据,降低模型大小。 - -```bash -(venv) ➜ cnocr git:(pytorch) ✗ cnocr resave -h -Usage: cnocr resave [OPTIONS] - - 训练好的模型会存储训练状态,使用此命令去掉预测时无关的数据,降低模型大小 - -Options: - -i, --input-model-fp TEXT 输入的模型文件路径 [required] - -o, --output-model-fp TEXT 输出的模型文件路径 [required] - -h, --help Show this message and exit. -``` - - - - - - - ## 未来工作 * [x] 支持图片包含多行文字 (`Done`) @@ -452,6 +70,6 @@ Options: * [x] 尝试新模型,如 DenseNet,进一步提升识别准确率(since `V1.1.0`) * [x] 优化训练集,去掉不合理的样本;在此基础上,重新训练各个模型 * [x] 由 MXNet 改为 PyTorch 架构(since `V2.0.0`) -* [ ] 基于 PyTorch 训练更高效的模型 +* [x] 基于 PyTorch 训练更高效的模型 * [ ] 支持列格式的文字识别 diff --git a/cnocr/__version__.py b/cnocr/__version__.py index 46d64a79f4a461ecb407585344f6895e031c374f..d1cfb8bd08b0fb9c85406b76881e5db717c94741 100644 --- a/cnocr/__version__.py +++ b/cnocr/__version__.py @@ -17,4 +17,4 @@ # specific language governing permissions and limitations # under the License. -__version__ = '2.0.1' +__version__ = '2.1.0' diff --git a/cnocr/cli.py b/cnocr/cli.py index 808a39f133287d9803467679e48fc366d7f0d14f..2e4878a85abbfda74e702e90a7bfa06d09408ffb 100644 --- a/cnocr/cli.py +++ b/cnocr/cli.py @@ -21,15 +21,19 @@ from __future__ import absolute_import, division, print_function import os import logging import time -import click +from collections import Counter import json import glob +from operator import itemgetter +from pathlib import Path +import click +import Levenshtein from torchvision import transforms as T from cnocr.consts import MODEL_VERSION, ENCODER_CONFIGS, DECODER_CONFIGS -from cnocr.utils import set_logger, load_model_params, check_model_name -from cnocr.data_utils.aug import NormalizeAug, RandomPaddingAug +from cnocr.utils import set_logger, load_model_params, check_model_name, save_img, read_img +from cnocr.data_utils.aug import NormalizeAug, RandomPaddingAug, RandomStretchAug, RandomCrop from cnocr.dataset import OcrDataModule from cnocr.trainer import PlTrainer, resave_model from cnocr import CnOcr, gen_model @@ -37,7 +41,7 @@ from cnocr import CnOcr, gen_model _CONTEXT_SETTINGS = {"help_option_names": ['-h', '--help']} logger = set_logger(log_level=logging.INFO) -DEFAULT_MODEL_NAME = 'densenet-s-fc' +DEFAULT_MODEL_NAME = 'densenet_lite_136-fc' LEGAL_MODEL_NAMES = { enc_name + '-' + dec_name for enc_name in ENCODER_CONFIGS.keys() @@ -54,7 +58,7 @@ def cli(): @click.option( '-m', '--model-name', - type=click.Choice(LEGAL_MODEL_NAMES), + type=str, default=DEFAULT_MODEL_NAME, help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME, ) @@ -69,7 +73,7 @@ def cli(): '--train-config-fp', type=str, required=True, - help='训练使用的json配置文件,参考 `example/train_config.json`', + help='训练使用的json配置文件,参考 `docs/examples/train_config.json`', ) @click.option( '-r', @@ -92,8 +96,10 @@ def train( check_model_name(model_name) train_transform = T.Compose( [ - T.RandomInvert(p=0.5), - T.RandomRotation(degrees=2), + RandomStretchAug(min_ratio=0.5, max_ratio=1.5), + # RandomCrop((8, 10)), + T.RandomInvert(p=0.2), + T.RandomApply([T.RandomRotation(degrees=1)], p=0.4), # T.RandomAutocontrast(p=0.05), # T.RandomPosterize(bits=4, p=0.3), # T.RandomAdjustSharpness(sharpness_factor=0.5, p=0.3), @@ -101,7 +107,6 @@ def train( # 
T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.5), NormalizeAug(), # RandomPaddingAug(p=0.5, max_pad_len=72), - ] ) val_transform = NormalizeAug() @@ -119,6 +124,18 @@ def train( pin_memory=train_config['pin_memory'], ) + # train_ds = data_mod.train + # for i in range(min(100, len(train_ds))): + # visualize_example(train_transform(train_ds[i][0]), 'debugs/train-1-%d' % i) + # visualize_example(train_transform(train_ds[i][0]), 'debugs/train-2-%d' % i) + # visualize_example(train_transform(train_ds[i][0]), 'debugs/train-3-%d' % i) + # val_ds = data_mod.val + # for i in range(min(10, len(val_ds))): + # visualize_example(val_transform(val_ds[i][0]), 'debugs/val-1-%d' % i) + # visualize_example(val_transform(val_ds[i][0]), 'debugs/val-2-%d' % i) + # visualize_example(val_transform(val_ds[i][0]), 'debugs/val-2-%d' % i) + # return + trainer = PlTrainer( train_config, ckpt_fn=['cnocr', 'v%s' % MODEL_VERSION, model_name] ) @@ -133,20 +150,21 @@ def train( ) +def visualize_example(example, fp_prefix): + if not os.path.exists(os.path.dirname(fp_prefix)): + os.makedirs(os.path.dirname(fp_prefix)) + image = example + save_img(image, '%s-image.jpg' % fp_prefix) + + @cli.command('predict') @click.option( '-m', '--model-name', - type=click.Choice(LEGAL_MODEL_NAMES), + type=str, default=DEFAULT_MODEL_NAME, help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME, ) -@click.option( - "--model_epoch", - type=int, - default=None, - help="model epoch。默认为 `None`,表示使用系统自带的预训练模型", -) @click.option( '-p', '--pretrained-model-fp', @@ -155,6 +173,7 @@ def train( help='使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型', ) @click.option( + "-c", "--context", help="使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为 `cpu`", type=str, @@ -167,15 +186,8 @@ def train( is_flag=True, help="是否输入图片只包含单行文字。对包含单行文字的图片,不做按行切分;否则会先对图片按行分割后再进行识别", ) -def predict( - model_name, model_epoch, pretrained_model_fp, context, img_file_or_dir, single_line -): - ocr = CnOcr( - model_name=model_name, - model_epoch=model_epoch, - model_fp=pretrained_model_fp, - context=context, - ) +def predict(model_name, pretrained_model_fp, context, img_file_or_dir, single_line): + ocr = CnOcr(model_name=model_name, model_fp=pretrained_model_fp, context=context) ocr_func = ocr.ocr_for_single_line if single_line else ocr.ocr fp_list = [] if os.path.isfile(img_file_or_dir): @@ -197,6 +209,158 @@ def predict( logger.info('\npred: %s, with probability %f' % (''.join(preds), prob)) +@cli.command('evaluate') +@click.option( + '-m', + '--model-name', + type=str, + default=DEFAULT_MODEL_NAME, + help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME, +) +@click.option( + '-p', + '--pretrained-model-fp', + type=str, + default=None, + help='使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型', +) +@click.option( + "-c", + "--context", + help="使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为 `cpu`", + type=str, + default='cpu', +) +@click.option( + "-i", + "--eval-index-fp", + type=str, + help='待评估文件所在的索引文件,格式与训练时训练集索引文件相同,每行格式为 `<图片路径>\t<以空格分割的labels>`', + default='test.txt', +) +@click.option("--img-folder", required=True, help="图片所在文件夹,相对于索引文件中记录的图片位置") +@click.option("--batch-size", type=int, help="batch size. 
默认值:128", default=128) +@click.option( + '-o', + '--output-dir', + type=str, + default='eval_results', + help='存放评估结果的文件夹。默认值:`eval_results`', +) +@click.option( + "-v", "--verbose", is_flag=True, help="whether to print details to screen", +) +def evaluate( + model_name, + pretrained_model_fp, + context, + eval_index_fp, + img_folder, + batch_size, + output_dir, + verbose, +): + ocr = CnOcr(model_name=model_name, model_fp=pretrained_model_fp, context=context) + + fn_labels_list = read_input_file(eval_index_fp) + + miss_cnt, redundant_cnt = Counter(), Counter() + total_time_cost = 0.0 + bad_cnt = 0 + badcases = [] + + start_idx = 0 + while start_idx < len(fn_labels_list): + logger.info('start_idx: %d', start_idx) + batch = fn_labels_list[start_idx : start_idx + batch_size] + img_fps = [os.path.join(img_folder, fn) for fn, _ in batch] + reals = [labels for _, labels in batch] + + imgs = [read_img(img) for img in img_fps] + start_time = time.time() + outs = ocr.ocr_for_single_lines(imgs, batch_size=1) + total_time_cost += time.time() - start_time + + preds = [out[0] for out in outs] + for bad_info in compare_preds_to_reals(preds, reals, img_fps): + if verbose: + logger.info('\t'.join(bad_info)) + distance = Levenshtein.distance(bad_info[1], bad_info[2]) + bad_info.insert(0, distance) + badcases.append(bad_info) + miss_cnt.update(list(bad_info[-2])) + redundant_cnt.update(list(bad_info[-1])) + bad_cnt += 1 + + start_idx += batch_size + + badcases.sort(key=itemgetter(0), reverse=True) + + output_dir = Path(output_dir) + if not output_dir.exists(): + os.makedirs(output_dir) + with open(output_dir / 'badcases.txt', 'w') as f: + f.write( + '\t'.join( + [ + 'distance', + 'image_fp', + 'real_words', + 'pred_words', + 'miss_words', + 'redundant_words', + ] + ) + + '\n' + ) + for bad_info in badcases: + f.write('\t'.join(map(str, bad_info)) + '\n') + with open(output_dir / 'miss_words_stat.txt', 'w') as f: + for word, num in miss_cnt.most_common(): + f.write('\t'.join([word, str(num)]) + '\n') + with open(output_dir / 'redundant_words_stat.txt', 'w') as f: + for word, num in redundant_cnt.most_common(): + f.write('\t'.join([word, str(num)]) + '\n') + + logger.info( + "number of total cases: %d, number of bad cases: %d, acc: %.4f, time cost per image: %f" + % ( + len(fn_labels_list), + bad_cnt, + 1.0 - bad_cnt / len(fn_labels_list), + total_time_cost / len(fn_labels_list), + ) + ) + + +def read_input_file(in_fp): + fn_labels_list = [] + with open(in_fp) as f: + for line in f: + fields = line.strip().split('\t') + labels = fields[1].split(' ') + labels = [l if l != '' else ' ' for l in labels] + fn_labels_list.append((fields[0], labels)) + return fn_labels_list + + +def compare_preds_to_reals(batch_preds, batch_reals, batch_img_fns): + for preds, reals, img_fn in zip(batch_preds, batch_reals, batch_img_fns): + if preds == reals: + continue + preds_set, reals_set = set(preds), set(reals) + + miss_words = reals_set.difference(preds_set) + redundant_words = preds_set.difference(reals_set) + yield [ + img_fn, + ''.join(reals), + ''.join(preds), + ''.join(miss_words), + ''.join(redundant_words), + ] + + @cli.command('resave') @click.option('-i', '--input-model-fp', type=str, required=True, help='输入的模型文件路径') @click.option('-o', '--output-model-fp', type=str, required=True, help='输出的模型文件路径') diff --git a/cnocr/cn_ocr.py b/cnocr/cn_ocr.py index 1f6bc9697efdcdb8be7e5dc9be56c15727252fd7..0fe876cee59e817de7391077c6a59e220607ddd1 100644 --- a/cnocr/cn_ocr.py +++ b/cnocr/cn_ocr.py @@ -48,11 +48,6 @@ logger = 
logging.getLogger(__name__) def gen_model(model_name, vocab): check_model_name(model_name) - if not model_name.startswith('densenet-s'): - logger.warning( - 'only "densenet-s" is supported now, use "densenet-s-fc" by default' - ) - model_name = 'densenet-s-fc' model = OcrModel.from_name(model_name, vocab) return model @@ -62,7 +57,7 @@ class CnOcr(object): def __init__( self, - model_name: str = 'densenet-s-fc', + model_name: str = 'densenet_lite_136-fc', *, cand_alphabet: Optional[Union[Collection, str]] = None, context: str = 'cpu', # ['cpu', 'gpu', 'cuda'] @@ -71,14 +66,28 @@ class CnOcr(object): **kwargs, ): """ + 识别模型初始化函数。 + + Args: + model_name (str): 模型名称。默认为 `densenet_lite_136-fc` + cand_alphabet (Optional[Union[Collection, str]]): 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围 + context (str): 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为 `cpu` + model_fp (Optional[str]): 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.ckpt' 文件) + root (Union[str, Path]): 模型文件所在的根目录。 + Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.1/densenet_lite_136-fc`。 + Windows下默认值为 `C:/Users//AppData/Roaming/cnocr`。 + **kwargs: 目前未被使用。 + + Examples: + 使用默认参数: + >>> ocr = CnOcr() + + 使用指定模型: + >>> ocr = CnOcr(model_name='densenet_lite_136-fc') + + 识别时只考虑数字: + >>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789') - :param model_name: 模型名称。默认为 `densenet-s-fc` - :param cand_alphabet: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围 - :param context: 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为 `cpu` - :param model_fp: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.ckpt' 文件) - :param root: 模型文件所在的根目录。 - Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.0/densenet-s-fc`。 - Windows下默认值为 `C:/Users//AppData/Roaming/cnocr`。 """ if 'name' in kwargs: logger.warning( @@ -144,8 +153,13 @@ class CnOcr(object): def set_cand_alphabet(self, cand_alphabet: Optional[Union[Collection, str]]): """ 设置待识别字符的候选集合。 - :param cand_alphabet: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围 - :return: None + + Args: + cand_alphabet (Optional[Union[Collection, str]]): 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围 + + Returns: + None + """ if cand_alphabet is None: self._candidates = None @@ -169,10 +183,15 @@ class CnOcr(object): self, img_fp: Union[str, Path, torch.Tensor, np.ndarray] ) -> List[Tuple[List[str], float]]: """ - :param img_fp: image file path; or color image torch.Tensor or np.ndarray, - with shape [height, width] or [height, width, channel]. - channel should be 1 (gray image) or 3 (RGB formatted color image). scaled in [0, 255]. - :return: list of (list of chars, prob), such as + 识别函数。 + + Args: + img_fp (Union[str, Path, torch.Tensor, np.ndarray]): image file path; or color image torch.Tensor or np.ndarray, + with shape [height, width] or [height, width, channel]. + channel should be 1 (gray image) or 3 (RGB formatted color image). scaled in [0, 255]. + + Returns: + list of (list of chars, prob), such as [(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)] """ img = self._prepare_img(img_fp) @@ -190,11 +209,15 @@ class CnOcr(object): self, img_fp: Union[str, Path, torch.Tensor, np.ndarray] ) -> np.ndarray: """ - :param img: image array with type torch.Tensor or np.ndarray, - with shape [height, width] or [height, width, channel]. - channel should be 1 (gray image) or 3 (color image). - :return: np.ndarray, with shape (height, width, 1), dtype uint8, scale [0, 255] + Args: + img_fp (Union[str, Path, torch.Tensor, np.ndarray]): + image array with type torch.Tensor or np.ndarray, + with shape [height, width] or [height, width, channel]. 
+ channel should be 1 (gray image) or 3 (color image). + + Returns: + np.ndarray: with shape (height, width, 1), dtype uint8, scale [0, 255] """ img = img_fp if isinstance(img_fp, (str, Path)): @@ -226,10 +249,15 @@ class CnOcr(object): ) -> Tuple[List[str], float]: """ Recognize characters from an image with only one-line characters. - :param img_fp: image file path; or image torch.Tensor or np.ndarray, - with shape [height, width] or [height, width, channel]. - The optional channel should be 1 (gray image) or 3 (color image). - :return: (list of chars, prob), such as (['你', '好'], 0.80) + + Args: + img_fp (Union[str, Path, torch.Tensor, np.ndarray]): + image file path; or image torch.Tensor or np.ndarray, + with shape [height, width] or [height, width, channel]. + The optional channel should be 1 (gray image) or 3 (color image). + + Returns: + tuple: (list of chars, prob), such as (['你', '好'], 0.80) """ img = self._prepare_img(img_fp) res = self.ocr_for_single_lines([img]) @@ -242,27 +270,49 @@ class CnOcr(object): ) -> List[Tuple[List[str], float]]: """ Batch recognize characters from a list of one-line-characters images. - :param img_list: list of images, in which each element should be a line image array, - with type torch.Tensor or np.ndarray. - Each element should be a tensor with values ranging from 0 to 255, - and with shape [height, width] or [height, width, channel]. - The optional channel should be 1 (gray image) or 3 (color image). - :param batch_size: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`。 - :return: list of (list of chars, prob), such as + + Args: + img_list (List[Union[str, Path, torch.Tensor, np.ndarray]]): + list of images, in which each element should be a line image array, + with type torch.Tensor or np.ndarray. + Each element should be a tensor with values ranging from 0 to 255, + and with shape [height, width] or [height, width, channel]. + The optional channel should be 1 (gray image) or 3 (color image). + 注:img_list 不宜包含太多图片,否则同时导入这些图片会消耗很多内存。 + batch_size: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`。 + + Returns: + list: list of (list of chars, prob), such as [(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)] """ if len(img_list) == 0: return [] + img_list = [self._prepare_img(img) for img in img_list] img_list = [self._transform_img(img) for img in img_list] + should_sort = batch_size > 1 and len(img_list) // batch_size > 1 + + if should_sort: + # 把图片按宽度从小到大排列,提升效率 + sorted_idx_list = sorted( + range(len(img_list)), key=lambda i: img_list[i].shape[2] + ) + sorted_img_list = [img_list[i] for i in sorted_idx_list] + else: + sorted_idx_list = range(len(img_list)) + sorted_img_list = img_list + idx = 0 - out = [] - while idx * batch_size < len(img_list): - imgs = img_list[idx * batch_size : (idx + 1) * batch_size] + sorted_out = [] + while idx * batch_size < len(sorted_img_list): + imgs = sorted_img_list[idx * batch_size : (idx + 1) * batch_size] batch_out = self._predict(imgs) - out.extend(batch_out['preds']) + sorted_out.extend(batch_out['preds']) idx += 1 + out = [None] * len(sorted_out) + for idx, pred in zip(sorted_idx_list, sorted_out): + out[idx] = pred res = [] for line in out: @@ -274,11 +324,13 @@ class CnOcr(object): def _transform_img(self, img: np.ndarray) -> torch.Tensor: """ - :param img: image array with type torch.Tensor or np.ndarray, - with shape [height, width] or [height, width, channel]. - channel shoule be 1 (gray image) or 3 (color image). 
+ Args: + img: image array with type torch.Tensor or np.ndarray, + with shape [height, width] or [height, width, channel]. + channel shoule be 1 (gray image) or 3 (color image). - :return: torch.Tensor, with shape (1, height, width) + Returns: + torch.Tensor: with shape (1, height, width) """ img = rescale_img(img.transpose((2, 0, 1))) # res: [C, H, W] return NormalizeAug()(img).to(device=torch.device(self.context)) diff --git a/cnocr/consts.py b/cnocr/consts.py index b0c6087638924e543315caea192a38e5ec6f1c87..62597fcd1d6147aff9af21fe3a8c453036902ada 100644 --- a/cnocr/consts.py +++ b/cnocr/consts.py @@ -30,38 +30,89 @@ IMG_STANDARD_HEIGHT = 32 VOCAB_FP = Path(__file__).parent / 'label_cn.txt' ENCODER_CONFIGS = { - 'densenet-s': { # 长度压缩至 1/8(seq_len == 35),输出的向量长度为 4*128 = 512 + 'densenet': { # 长度压缩至 1/8(seq_len == 35),输出的向量长度为 4*128 = 512 'growth_rate': 32, 'block_config': [2, 2, 2, 2], 'num_init_features': 64, 'out_length': 512, # 输出的向量长度为 4*128 = 512 }, + 'densenet_1112': { # 长度压缩至 1/8(seq_len == 35) + 'growth_rate': 32, + 'block_config': [1, 1, 1, 2], + 'num_init_features': 64, + 'out_length': 400, + }, + 'densenet_1114': { # 长度压缩至 1/8(seq_len == 35) + 'growth_rate': 32, + 'block_config': [1, 1, 1, 4], + 'num_init_features': 64, + 'out_length': 656, + }, + 'densenet_1122': { # 长度压缩至 1/8(seq_len == 35) + 'growth_rate': 32, + 'block_config': [1, 1, 2, 2], + 'num_init_features': 64, + 'out_length': 464, + }, + 'densenet_1124': { # 长度压缩至 1/8(seq_len == 35) + 'growth_rate': 32, + 'block_config': [1, 1, 2, 4], + 'num_init_features': 64, + 'out_length': 720, + }, + 'densenet_lite_113': { # 长度压缩至 1/8(seq_len == 35),输出的向量长度为 2*136 = 272 + 'growth_rate': 32, + 'block_config': [1, 1, 3], + 'num_init_features': 64, + 'out_length': 272, # 输出的向量长度为 2*80 = 160 + }, + 'densenet_lite_114': { # 长度压缩至 1/8(seq_len == 35) + 'growth_rate': 32, + 'block_config': [1, 1, 4], + 'num_init_features': 64, + 'out_length': 336, + }, + 'densenet_lite_124': { # 长度压缩至 1/8(seq_len == 35) + 'growth_rate': 32, + 'block_config': [1, 2, 4], + 'num_init_features': 64, + 'out_length': 368, + }, + 'densenet_lite_134': { # 长度压缩至 1/8(seq_len == 35) + 'growth_rate': 32, + 'block_config': [1, 3, 4], + 'num_init_features': 64, + 'out_length': 400, + }, + 'densenet_lite_136': { # 长度压缩至 1/8(seq_len == 35) + 'growth_rate': 32, + 'block_config': [1, 3, 6], + 'num_init_features': 64, + 'out_length': 528, + }, + 'mobilenetv3_tiny': {'arch': 'tiny', 'out_length': 384,}, + 'mobilenetv3_small': {'arch': 'small', 'out_length': 384,}, } DECODER_CONFIGS = { - 'lstm': { - 'input_size': 512, # 对应 encoder 的输出向量长度 - 'rnn_units': 128, - }, - 'gru': { - 'input_size': 512, # 对应 encoder 的输出向量长度 - 'rnn_units': 128, - }, - 'fc': { - 'input_size': 512, # 对应 encoder 的输出向量长度 - 'hidden_size': 256, - 'dropout': 0.3, - } + 'lstm': {'rnn_units': 128,}, + 'gru': {'rnn_units': 128,}, + 'fc': {'hidden_size': 128, 'dropout': 0.1,}, + 'fcfull': {'hidden_size': 256, 'dropout': 0.3,}, } root_url = ( - 'https://beiye-model.oss-cn-beijing.aliyuncs.com/models/cnocr/%s/' + 'https://huggingface.co/breezedeus/cnstd-cnocr-models/resolve/main/models/cnocr/%s/' % MODEL_VERSION ) -# name: (epochs, url) +# name: (epoch, url) AVAILABLE_MODELS = { - 'densenet-s-fc': (8, root_url + 'densenet-s-fc-v2.0.1.zip'), - 'densenet-s-gru': (14, root_url + 'densenet-s-gru-v2.0.1.zip'), + 'densenet_lite_114-fc': (37, root_url + 'densenet_lite_114-fc.zip'), + 'densenet_lite_124-fc': (39, root_url + 'densenet_lite_124-fc.zip'), + 'densenet_lite_134-fc': (34, root_url + 
'densenet_lite_134-fc.zip'), + 'densenet_lite_136-fc': (39, root_url + 'densenet_lite_136-fc.zip'), + 'densenet_lite_134-gru': (2, root_url + 'densenet_lite_134-gru.zip'), + 'densenet_lite_136-gru': (2, root_url + 'densenet_lite_136-gru.zip'), } # 候选字符集合 diff --git a/cnocr/data_utils/aug.py b/cnocr/data_utils/aug.py index d53399af95f1baa59d09484fe163caa7e22b164b..f8353eaa84fecea817b17f79f34eed045ec866aa 100644 --- a/cnocr/data_utils/aug.py +++ b/cnocr/data_utils/aug.py @@ -18,8 +18,10 @@ # under the License. import random +from typing import Tuple import torch +import torchvision.transforms.functional as F from ..utils import normalize_img_array @@ -32,6 +34,7 @@ class FgBgFlipAug(object): p : float Probability to flip image horizontally """ + def __init__(self, p): self.p = p @@ -47,6 +50,71 @@ class NormalizeAug(object): return normalize_img_array(img) +class RandomStretchAug(object): + """对图片在宽度上做随机拉伸""" + + def __init__(self, min_ratio=0.9, max_ratio=1.1): + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + def __call__(self, img: torch.Tensor): + """ + + :param img: [C, H, W] + :return: + """ + _, h, w = img.shape + new_w_ratio = self.min_ratio + random.random() * ( + self.max_ratio - self.min_ratio + ) + return F.resize(img, [h, int(w * new_w_ratio)]) + + +class RandomCrop(torch.nn.Module): + def __init__( + self, crop_size: Tuple[int, int], interpolation=F.InterpolationMode.BILINEAR + ): + super().__init__() + self.crop_size = crop_size + self.interpolation = interpolation + + def get_params(self, ori_w, ori_h) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. + + Args: + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + while True: + h_top, h_bot = ( + random.randint(0, self.crop_size[0]), + random.randint(0, self.crop_size[0]), + ) + w_left, w_right = ( + random.randint(0, self.crop_size[1]), + random.randint(0, self.crop_size[1]), + ) + h = ori_h - h_top - h_bot + w = ori_w - w_left - w_right + if h < ori_h * 0.5 or w < ori_w * 0.9: + continue + + return h_top, w_left, h, w + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped and resized. + + Returns: + PIL Image or Tensor: Randomly cropped and resized image. + """ + ori_w, ori_h = F._get_image_size(img) + i, j, h, w = self.get_params(ori_w, ori_h) + return F.resized_crop(img, i, j, h, w, (ori_h, ori_w), self.interpolation) + + class RandomPaddingAug(object): def __init__(self, p, max_pad_len): self.p = p diff --git a/cnocr/data_utils/block_shuffle.py b/cnocr/data_utils/block_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..b632a31cb326a5fcf940f2e0612518be141b59f2 --- /dev/null +++ b/cnocr/data_utils/block_shuffle.py @@ -0,0 +1,54 @@ +# coding: utf-8 +# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# Credits: adapted from https://mp.weixin.qq.com/s/xGvaW87UQFjetc5xFmKxWg + + +import random +from torch.utils.data import Dataset, DataLoader + + +class BlockShuffleDataLoader(DataLoader): + def __init__( + self, dataset: Dataset, **kwargs + ): + """ + 对 OcrDataset 数据集实现Block Shuffle功能,按文字数量从少到多的顺序排列样本(相同长度样本则随机排列) + Args: + dataset: OcrDataset类的实例,其中中必须包含labels_list变量,并且该变量为一个list + **kwargs: + """ + assert isinstance( + dataset.labels_list, list + ), "dataset为OcrDataset类的实例,其中必须包含labels_list变量,并且该变量为一个list" + kwargs['shuffle'] = False + super().__init__(dataset, **kwargs) + + def __iter__(self): + self.block_shuffle2() + return super().__iter__() + + def block_shuffle2(self): + idx_list = list(range(len(self.dataset))) + random.shuffle(idx_list) + random.shuffle(idx_list) + idx_list.sort(key=lambda idx: len(self.dataset.labels_list[idx])) + for attr in ('img_fp_list', 'labels_list'): + ori_list = getattr(self.dataset, attr) + new_list = [ori_list[idx] for idx in idx_list] + setattr(self.dataset, attr, new_list) diff --git a/cnocr/dataset.py b/cnocr/dataset.py index c57d443d4d2475ec2c82417608487d940e834f62..7d6aa2aefd74a74d49c2b5d68dc4119ada0f9263 100644 --- a/cnocr/dataset.py +++ b/cnocr/dataset.py @@ -59,9 +59,9 @@ def collate_fn(img_labels: List[Tuple[str, str]], transformers: Callable = None) img_list, labels_list = zip(*img_labels) label_lengths = torch.tensor([len(labels) for labels in labels_list]) - img_lengths = torch.tensor([img.size(2) for img in img_list]) if transformers is not None: img_list = [transformers(img) for img in img_list] + img_lengths = torch.tensor([img.size(2) for img in img_list]) imgs = pad_img_seq(img_list) return imgs, img_lengths, labels_list, label_lengths diff --git a/cnocr/lr_scheduler.py b/cnocr/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..10b6286206032aa3e753c5b4d606b9fab6db9f7a --- /dev/null +++ b/cnocr/lr_scheduler.py @@ -0,0 +1,199 @@ +# coding: utf-8 +# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from copy import deepcopy +import math + +import torch +from torch.optim.lr_scheduler import ( + _LRScheduler, + StepLR, + LambdaLR, + CyclicLR, + CosineAnnealingWarmRestarts, + MultiStepLR, + OneCycleLR, +) + + +def get_lr_scheduler(config, optimizer): + orig_lr = config['learning_rate'] + lr_sch_config = deepcopy(config['lr_scheduler']) + lr_sch_name = lr_sch_config.pop('name') + epochs = config['epochs'] + steps_per_epoch = config['steps_per_epoch'] + + if lr_sch_name == 'multi_step': + milestones = [v * steps_per_epoch for v in lr_sch_config['milestones']] + return MultiStepLR( + optimizer, milestones=milestones, gamma=lr_sch_config['gamma'], + ) + elif lr_sch_name == 'cos_warmup': + min_lr_mult_factor = lr_sch_config.get('min_lr_mult_factor', 0.1) + warmup_epochs = lr_sch_config.get('warmup_epochs', 0.1) + return WarmupCosineAnnealingRestarts( + optimizer, + first_cycle_steps=steps_per_epoch * epochs, + max_lr=orig_lr, + min_lr=orig_lr * min_lr_mult_factor, + warmup_steps=int(steps_per_epoch * warmup_epochs), + ) + elif lr_sch_name == 'cos_anneal': + # 5 个 epochs, 一个循环 + return CosineAnnealingWarmRestarts( + optimizer, T_0=5 * steps_per_epoch, T_mult=1, eta_min=orig_lr * 0.1 + ) + elif lr_sch_name == 'cyclic': + return CyclicLR( + optimizer, + base_lr=orig_lr / 10.0, + max_lr=orig_lr, + step_size_up=5 * steps_per_epoch, # 5 个 epochs, 从最小base_lr上升到最大max_lr + cycle_momentum=False, + ) + elif lr_sch_name == 'one_cycle': + return OneCycleLR( + optimizer, max_lr=orig_lr, epochs=epochs, steps_per_epoch=steps_per_epoch, + ) + + step_size = lr_sch_config['step_size'] + gamma = lr_sch_config['gamma'] + if step_size is None or gamma is None: + return LambdaLR(optimizer, lr_lambda=lambda _: 1) + return StepLR(optimizer, step_size, gamma=gamma) + + +class WarmupCosineAnnealingRestarts(_LRScheduler): + """ + from https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py + + optimizer (Optimizer): Wrapped optimizer. + first_cycle_steps (int): First cycle step size. + cycle_mult(float): Cycle steps magnification. Default: -1. + max_lr(float): First cycle's max learning rate. Default: 0.1. + min_lr(float): Min learning rate. Default: 0.001. + warmup_steps(int): Linear warmup step size. Default: 0. + gamma(float): Decrease rate of max learning rate by cycle. Default: 1. + last_epoch (int): The index of last epoch. Default: -1. 
+ """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + first_cycle_steps: int, + cycle_mult: float = 1.0, + max_lr: float = 0.1, + min_lr: float = 0.001, + warmup_steps: int = 0, + gamma: float = 1.0, + last_epoch: int = -1, + ): + assert warmup_steps < first_cycle_steps + + self.first_cycle_steps = first_cycle_steps # first cycle step size + self.cycle_mult = cycle_mult # cycle steps magnification + self.base_max_lr = max_lr # first max learning rate + self.max_lr = max_lr # max learning rate in the current cycle + self.min_lr = min_lr # min learning rate + self.warmup_steps = warmup_steps # warmup step size + self.gamma = gamma # decrease rate of max learning rate by cycle + + self.cur_cycle_steps = first_cycle_steps # first cycle step size + self.cycle = 0 # cycle count + self.step_in_cycle = last_epoch # step size of the current cycle + + super().__init__(optimizer, last_epoch) + + # set learning rate min_lr + self.init_lr() + + def init_lr(self): + self.base_lrs = [] + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.min_lr + self.base_lrs.append(self.min_lr) + + def get_lr(self): + if self.step_in_cycle == -1: + return self.base_lrs + elif self.step_in_cycle < self.warmup_steps: + return [ + (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + + base_lr + for base_lr in self.base_lrs + ] + else: + return [ + base_lr + + (self.max_lr - base_lr) + * ( + 1 + + math.cos( + math.pi + * (self.step_in_cycle - self.warmup_steps) + / (self.cur_cycle_steps - self.warmup_steps) + ) + ) + / 2 + for base_lr in self.base_lrs + ] + + def step(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.step_in_cycle = self.step_in_cycle + 1 + if self.step_in_cycle >= self.cur_cycle_steps: + self.cycle += 1 + self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps + self.cur_cycle_steps = ( + int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + + self.warmup_steps + ) + else: + if epoch >= self.first_cycle_steps: + if self.cycle_mult == 1.0: + self.step_in_cycle = epoch % self.first_cycle_steps + self.cycle = epoch // self.first_cycle_steps + else: + n = int( + math.log( + ( + epoch / self.first_cycle_steps * (self.cycle_mult - 1) + + 1 + ), + self.cycle_mult, + ) + ) + self.cycle = n + self.step_in_cycle = epoch - int( + self.first_cycle_steps + * (self.cycle_mult ** n - 1) + / (self.cycle_mult - 1) + ) + self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** ( + n + ) + else: + self.cur_cycle_steps = self.first_cycle_steps + self.step_in_cycle = epoch + + self.max_lr = self.base_max_lr * (self.gamma ** self.cycle) + self.last_epoch = math.floor(epoch) + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group['lr'] = lr diff --git a/cnocr/models/densenet.py b/cnocr/models/densenet.py index 77826164580d1ff4c620c369bda9fa26a08eb2ab..7082ccc01e446af34e589c6cc6d9ad348698bb85 100644 --- a/cnocr/models/densenet.py +++ b/cnocr/models/densenet.py @@ -29,7 +29,7 @@ class DenseNet(densenet.DenseNet): def __init__( self, growth_rate: int = 32, - block_config: Tuple[int, int, int, int] = (6, 12, 24, 16), + block_config: Tuple[int, int, int, int] = (2, 2, 2, 2), num_init_features: int = 64, bn_size: int = 4, drop_rate: float = 0, @@ -46,13 +46,36 @@ class DenseNet(densenet.DenseNet): ) self.block_config = block_config + + delattr(self, 'classifier') self.features.conv0 = nn.Conv2d( 1, num_init_features, kernel_size=3, stride=1, padding=1, bias=False ) self.features.pool0 = 
nn.MaxPool2d(kernel_size=3, stride=1, padding=1) - delattr(self, 'classifier') + + last_denselayer = self._get_last_denselayer(len(self.block_config)) + conv = last_denselayer.conv2 + in_channels, out_channels = conv.in_channels, conv.out_channels + last_denselayer.conv2 = nn.Conv2d( + in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False + ) + + # for i in range(1, len(self.block_config)): + # transition = getattr(self.features, 'transition%d' % i) + # in_channels, out_channels = transition.conv.in_channels, transition.conv.out_channels + # trans = _MaxPoolTransition(num_input_features=in_channels, + # num_output_features=out_channels) + # setattr(self.features, 'transition%d' % i, trans) + self._post_init_weights() + def _get_last_denselayer(self, block_num): + denseblock = getattr(self.features, 'denseblock%d' % block_num) + i = 1 + while hasattr(denseblock, 'denselayer%d' % i): + i += 1 + return getattr(denseblock, 'denselayer%d' % (i-1)) + @property def compress_ratio(self): return 2 ** (len(self.block_config) - 1) @@ -71,3 +94,51 @@ class DenseNet(densenet.DenseNet): def forward(self, x: Tensor) -> Tensor: features = self.features(x) return features + + +class DenseNetLite(DenseNet): + def __init__( + self, + growth_rate: int = 32, + block_config: Tuple[int, int, int] = (2, 2, 2), + num_init_features: int = 64, + bn_size: int = 4, + drop_rate: float = 0, + memory_efficient: bool = False, + ) -> None: + super().__init__( + growth_rate, + block_config, + num_init_features, + bn_size, + drop_rate, + memory_efficient=memory_efficient, + ) + self.features.pool0 = nn.AvgPool2d(kernel_size=2, stride=2) + + # last max pool, pool 1/8 to 1/16 for height dimension + self.features.add_module( + 'pool5', nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1)) + ) + + @property + def compress_ratio(self): + return 2 ** len(self.block_config) + + +class _MaxPoolTransition(nn.Sequential): + def __init__(self, num_input_features: int, num_output_features: int) -> None: + super().__init__() + self.add_module('norm', nn.BatchNorm2d(num_input_features)) + self.add_module('relu', nn.ReLU(inplace=True)) + self.add_module( + 'conv', + nn.Conv2d( + num_input_features, + num_output_features, + kernel_size=1, + stride=1, + bias=False, + ), + ) + self.add_module('pool', nn.MaxPool2d(kernel_size=2, stride=2)) diff --git a/cnocr/models/mobilenet.py b/cnocr/models/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..5cfbfc8f255bc1120eee973e7cabf5b6eea973f3 --- /dev/null +++ b/cnocr/models/mobilenet.py @@ -0,0 +1,189 @@ +# coding: utf-8 +# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# adapted from: torchvision/models/mobilenetv3.py + +from functools import partial +from typing import Any, List, Optional, Callable + +from torch import nn, Tensor +from torchvision.models.mobilenetv2 import ConvBNActivation +from torchvision.models import mobilenetv3 +from torchvision.models.mobilenetv3 import InvertedResidualConfig + + +class MobileNetV3(mobilenetv3.MobileNetV3): + def __init__( + self, + inverted_residual_setting: List[InvertedResidualConfig], + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any + ) -> None: + super().__init__(inverted_residual_setting, 1, 2, block, norm_layer) + delattr(self, 'classifier') + + firstconv_input_channels = self.features[0][0].out_channels + self.features[0] = ConvBNActivation( + 1, + firstconv_input_channels, + kernel_size=3, + stride=2, + norm_layer=norm_layer, + activation_layer=nn.Hardswish, + ) + + lastconv_input_channels = self.features[-1][0].in_channels + lastconv_output_channels = 2 * lastconv_input_channels + self.features[-1] = ConvBNActivation( + lastconv_input_channels, + lastconv_output_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.Hardswish, + ) + self.avgpool = nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1)) + + self._post_init_weights() + + @property + def compress_ratio(self): + return 8 + + def _post_init_weights(self): + # Official init from torch repo. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def forward(self, x: Tensor) -> Tensor: + features = self.features(x) + features = self.avgpool(features) + return features + + +def _mobilenet_v3_conf( + arch: str, + width_mult: float = 1.0, + reduced_tail: bool = False, + dilated: bool = False, + **kwargs: Any +): + reduce_divider = 2 if reduced_tail else 1 + dilation = 2 if dilated else 1 + + bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) + adjust_channels = partial( + InvertedResidualConfig.adjust_channels, width_mult=width_mult + ) + + if arch == "mobilenet_v3_tiny": + inverted_residual_setting = [ + bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), + bneck_conf(24, 5, 96, 40, False, "HS", 2, 1), # C3 + bneck_conf(40, 5, 120, 48, False, "HS", 1, 1), + # bneck_conf(48, 5, 144, 48, False, "HS", 1, 1), + bneck_conf( + 48, 5, 288, 96 // reduce_divider, False, "HS", 2, dilation + ), # C4 + bneck_conf( + 96 // reduce_divider, + 5, + 128 // reduce_divider, + 96 // reduce_divider, + True, + "HS", + 1, + dilation, + ), + bneck_conf( + 96 // reduce_divider, + 5, + 128 // reduce_divider, + 96 // reduce_divider, + True, + "HS", + 1, + dilation, + ), + ] + elif arch == "mobilenet_v3_small": + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), # C1 + bneck_conf(16, 3, 72, 24, False, "RE", 1, 1), # C2 + bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), + bneck_conf(24, 5, 96, 40, False, "HS", 2, 1), # C3 + bneck_conf(40, 5, 240, 40, False, "HS", 1, 1), + bneck_conf(40, 5, 240, 40, False, "HS", 1, 1), + bneck_conf(40, 5, 120, 48, True, "HS", 1, 1), + bneck_conf(48, 5, 144, 48, True, "HS", 1, 1), + bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4 + bneck_conf( + 96 // reduce_divider, + 5, + 576 // 
reduce_divider, + 96 // reduce_divider, + True, + "HS", + 1, + dilation, + ), + bneck_conf( + 96 // reduce_divider, + 5, + 576 // reduce_divider, + 96 // reduce_divider, + True, + "HS", + 1, + dilation, + ), + ] + else: + raise ValueError("Unsupported model type {}".format(arch)) + + return inverted_residual_setting + + +def _mobilenet_v3_model( + inverted_residual_setting: List[InvertedResidualConfig], **kwargs: Any +): + model = MobileNetV3(inverted_residual_setting, **kwargs) + return model + + +def gen_mobilenet_v3(arch: str = 'tiny', **kwargs: Any) -> MobileNetV3: + """ + Constructs a small MobileNetV3 architecture from + `"Searching for MobileNetV3" `_. + + Args: + arch (str): arch name; values: 'tiny' or 'small' + + """ + arch = 'mobilenet_v3_%s' % arch + inverted_residual_setting = _mobilenet_v3_conf(arch, **kwargs) + return _mobilenet_v3_model(inverted_residual_setting, **kwargs) diff --git a/cnocr/models/ocr_model.py b/cnocr/models/ocr_model.py index d5f1c1534729b411597d41f7178cca4d6b603874..6b834f6f4c0bd51e2619cc88256aead6e765a87e 100644 --- a/cnocr/models/ocr_model.py +++ b/cnocr/models/ocr_model.py @@ -30,7 +30,8 @@ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence from .ctc import CTCPostProcessor from ..consts import ENCODER_CONFIGS, DECODER_CONFIGS from ..data_utils.utils import encode_sequences -from .densenet import DenseNet +from .densenet import DenseNet, DenseNetLite +from .mobilenet import gen_mobilenet_v3 class EncoderManager(object): @@ -45,9 +46,16 @@ class EncoderManager(object): assert config is not None and 'name' in config name = config.pop('name') - if name.lower() == 'densenet-s': + if name.lower().startswith('densenet_lite'): + out_length = config.pop('out_length') + encoder = DenseNetLite(**config) + elif name.lower().startswith('densenet'): out_length = config.pop('out_length') encoder = DenseNet(**config) + elif name.lower().startswith('mobilenet'): + arch = config['arch'] + out_length = config.pop('out_length') + encoder = gen_mobilenet_v3(arch) else: raise ValueError('not supported encoder name: %s' % name) return encoder, out_length @@ -86,11 +94,11 @@ class DecoderManager(object): bidirectional=True, ) out_length = config['rnn_units'] * 2 - elif name.lower() == 'fc': + elif name.lower() in ('fc', 'fcfull'): decoder = nn.Sequential( nn.Dropout(p=config['dropout']), # nn.Tanh(), - nn.Linear(config['input_size'], config['hidden_size']), + nn.Linear(input_size, config['hidden_size']), nn.Dropout(p=config['dropout']), nn.Tanh(), ) diff --git a/cnocr/trainer.py b/cnocr/trainer.py index 4a884d14a49e139f8f5e5c860bec2852a6c46886..c13359d7edbe70f43a7c27d384915a7e305c8f18 100644 --- a/cnocr/trainer.py +++ b/cnocr/trainer.py @@ -26,16 +26,12 @@ import numpy as np import torch import torch.optim as optim from torch import nn +from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor -from torch.optim.lr_scheduler import ( - StepLR, - LambdaLR, - CyclicLR, - CosineAnnealingWarmRestarts, - MultiStepLR, -) -from torch.utils.data import DataLoader + +from .lr_scheduler import get_lr_scheduler + logger = logging.getLogger(__name__) @@ -67,37 +63,6 @@ def get_optimizer(name: str, model, learning_rate, weight_decay): return optimizer -def get_lr_scheduler(config, optimizer): - orig_lr = config['learning_rate'] - lr_sch_config = deepcopy(config['lr_scheduler']) - lr_sch_name = lr_sch_config.pop('name') - - if lr_sch_name == 'multi_step': - return MultiStepLR( - 
optimizer, - milestones=lr_sch_config['milestones'], - gamma=lr_sch_config['gamma'], - ) - elif lr_sch_name == 'cos_anneal': - return CosineAnnealingWarmRestarts( - optimizer, T_0=4, T_mult=1, eta_min=orig_lr / 10.0 - ) - elif lr_sch_name == 'cyclic': - return CyclicLR( - optimizer, - base_lr=orig_lr / 10.0, - max_lr=orig_lr, - step_size_up=2, - cycle_momentum=False, - ) - - step_size = lr_sch_config['step_size'] - gamma = lr_sch_config['gamma'] - if step_size is None or gamma is None: - return LambdaLR(optimizer, lr_lambda=lambda _: 1) - return StepLR(optimizer, step_size, gamma=gamma) - - class Accuracy(object): @classmethod def complete_match(cls, labels: List[List[str]], preds: List[List[str]]): @@ -144,6 +109,11 @@ class WrapperLightningModule(pl.LightningModule): else: setattr(self.model, 'current_epoch', self.current_epoch) res = self.model.calculate_loss(batch) + + # update lr scheduler + sch = self.lr_schedulers() + sch.step() + losses = res['loss'] self.log( 'train_loss', @@ -214,7 +184,7 @@ class PlTrainer(object): max_epochs=self.config.get('epochs', 20), precision=self.config.get('precision', 32), callbacks=callbacks, - stochastic_weight_avg=True, + stochastic_weight_avg=False, ) def fit( @@ -241,6 +211,12 @@ class PlTrainer(object): no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch. """ + steps_per_epoch = ( + len(train_dataloader) + if train_dataloader is not None + else len(datamodule.train_dataloader()) + ) + self.config['steps_per_epoch'] = steps_per_epoch if resume_from_checkpoint is not None: pl_module = WrapperLightningModule.load_from_checkpoint( resume_from_checkpoint, config=self.config, model=model diff --git a/cnocr/utils.py b/cnocr/utils.py index bcbb81eeb79f6ae72eb5f79ba9852a965f1db3ec..6998541fec8c797c924575ec9d78f772f636f2ce 100644 --- a/cnocr/utils.py +++ b/cnocr/utils.py @@ -102,7 +102,7 @@ def data_dir(): def check_model_name(model_name): - encoder_type, decoder_type = model_name.rsplit('-', maxsplit=1) + encoder_type, decoder_type = model_name.split('-')[:2] assert encoder_type in ENCODER_CONFIGS assert decoder_type in DECODER_CONFIGS @@ -362,3 +362,9 @@ def load_model_params(model, param_fp, device='cpu'): state_dict[k.split('.', maxsplit=1)[1]] = v model.load_state_dict(state_dict) return model + + +def get_model_size(model, only_trainable=False): + if only_trainable: + return sum(p.numel() for p in model.parameters() if p.requires_grad) + return sum(p.numel() for p in model.parameters()) diff --git a/RELEASE.md b/docs/RELEASE.md similarity index 92% rename from RELEASE.md rename to docs/RELEASE.md index 545f0d43105f089b50922632fcd4d69e7797f00e..02122125afc36f81a607aa2693cbd1d79d1f3c30 100644 --- a/RELEASE.md +++ b/docs/RELEASE.md @@ -1,5 +1,15 @@ # Release Notes +### Update 2021.11.06: 发布 cnocr V2.1.0 + +主要变更: + +* 使用了更精简的模型架构:`densenet_lite_*`; +* 使用了更丰富的数据重新训练了所有模型,精度相较于之前版本更高; +* 提供了更多预训练好的模型; +* 加入了 `cnocr evaluate` 命令以评估效果。 + + ### Update 2021.09.21: 发布 cnocr V2.0.1 主要变更: diff --git a/docs/cnocr/cn_ocr.md b/docs/cnocr/cn_ocr.md new file mode 100644 index 0000000000000000000000000000000000000000..2f0def5d0ebb94747998a0436f3ce1ed228cfc1d --- /dev/null +++ b/docs/cnocr/cn_ocr.md @@ -0,0 +1 @@ +:::cnocr.cn_ocr diff --git a/docs/cnstd_cnocr.md b/docs/cnstd_cnocr.md new file mode 100644 index 0000000000000000000000000000000000000000..6f6f3c0278668f0657bd9a82b0789360966e9d6b --- /dev/null +++ b/docs/cnstd_cnocr.md @@ -0,0 +1,32 @@ + +# 
强强联合:[CnStd](https://github.com/breezedeus/cnstd) + CnOcr
+
+关于为什么要结合 [CnStd](https://github.com/breezedeus/cnstd) 和 CnOcr 一起使用,可参考 [场景文字识别介绍](std_ocr.md)。
+
+对于一般的场景图片(如照片、票据等),需要先利用场景文字检测引擎 **[cnstd](https://github.com/breezedeus/cnstd)** 定位到文字所在位置,然后再利用 **cnocr** 进行文本识别。
+
+```python
+from cnstd import CnStd
+from cnocr import CnOcr
+
+std = CnStd()
+cn_ocr = CnOcr()
+
+box_infos = std.detect('examples/taobao.jpg')
+
+for box_info in box_infos['detected_texts']:
+    cropped_img = box_info['cropped_img']
+    ocr_res = cn_ocr.ocr_for_single_line(cropped_img)
+    print('ocr result: %s' % str(ocr_res))
+```
+
+注:运行上面示例需要先安装 **[cnstd](https://github.com/breezedeus/cnstd)** :
+
+```bash
+pip install cnstd
+```
+
+**[cnstd](https://github.com/breezedeus/cnstd)** 相关的更多使用说明请参考其项目地址。
+
+可基于 [在线 Demo](demo.md) 查看 CnStd + CnOcr 的联合效果。
+
diff --git a/docs/command.md b/docs/command.md
new file mode 100644
index 0000000000000000000000000000000000000000..a90b34d584f18947e36d829d2b6e87de40d20893
--- /dev/null
+++ b/docs/command.md
@@ -0,0 +1,146 @@
+# 脚本使用
+
+**cnocr** 包含了几个命令行工具,安装 **cnocr** 后即可使用。
+
+
+
+## 预测单个文件或文件夹中所有图片
+
+使用命令 **`cnocr predict`** 预测单个文件或文件夹中所有图片,以下是使用说明:
+
+
+
+```bash
+(venv) ➜ cnocr git:(dev) ✗ cnocr predict -h
+Usage: cnocr predict [OPTIONS]
+
+Options:
+  -m, --model-name TEXT           模型名称。默认值为 densenet_lite_136-fc
+  -p, --pretrained-model-fp TEXT  使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型
+  -c, --context TEXT              使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为
+                                  `cpu`
+
+  -i, --img-file-or-dir TEXT      输入图片的文件路径或者指定的文件夹  [required]
+  -s, --single-line               是否输入图片只包含单行文字。对包含单行文字的图片,不做按行切分;否则会先对图片按行分割后
+                                  再进行识别
+
+  -h, --help                      Show this message and exit.
+```
+
+
+
+例如可以使用以下命令对图片 `docs/examples/rand_cn1.png` 进行文字识别:
+
+```bash
+cnocr predict -i docs/examples/rand_cn1.png -s
+```
+
+
+
+具体使用也可参考文件 [Makefile](https://github.com/breezedeus/cnocr/blob/master/Makefile) 。
+
+
+
+
+## 模型评估
+
+使用命令 **`cnocr evaluate`** 在指定的数据集上评估模型效果,以下是使用说明:
+
+
+
+```bash
+(venv) ➜ cnocr git:(dev) ✗ cnocr evaluate -h
+Usage: cnocr evaluate [OPTIONS]
+
+Options:
+  -m, --model-name TEXT           模型名称。默认值为 densenet_lite_136-fc
+  -p, --pretrained-model-fp TEXT  使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型
+  -c, --context TEXT              使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为
+                                  `cpu`
+
+  -i, --eval-index-fp TEXT        待评估文件所在的索引文件,格式与训练时训练集索引文件相同,每行格式为 `<图片路径>
+                                  <以空格分割的labels>`
+
+  --img-folder TEXT               图片所在文件夹,相对于索引文件中记录的图片位置  [required]
+  --batch-size INTEGER            batch size. 默认值:`128`
+  -o, --output-dir TEXT           存放评估结果的文件夹。默认值:`eval_results`
+  -v, --verbose                   whether to print details to screen
+  -h, --help                      Show this message and exit.
+```
+
+
+
+例如可以使用以下命令评估 `data/test/dev.tsv` 中指定的所有样本:
+
+```bash
+cnocr evaluate -i data/test/dev.tsv --image-folder data/images
+```
+
+
+
+具体使用也可参考文件 [Makefile](https://github.com/breezedeus/cnocr/blob/master/Makefile) 。
+
+
+
+## 模型训练
+
+使用命令 **`cnocr train`** 训练文字识别模型,以下是使用说明:
+
+```bash
+(venv) ➜ cnocr git:(dev) ✗ cnocr train -h
+Usage: cnocr train [OPTIONS]
+
+Options:
+  -m, --model-name TEXT           模型名称。默认值为 densenet_lite_136-fc
+  -i, --index-dir TEXT            索引文件所在的文件夹,会读取文件夹中的 train.tsv 和 dev.tsv 文件
+                                  [required]
+
+  --train-config-fp TEXT          训练使用的json配置文件,参考
+                                  `docs/examples/train_config.json`
+                                  [required]
+
+  -r, --resume-from-checkpoint TEXT
+                                  恢复此前中断的训练状态,继续训练。默认为 `None`
+  -p, --pretrained-model-fp TEXT  导入的训练好的模型,作为初始模型。优先级低于"--restore-training-
+                                  fp",当传入"--restore-training-fp"时,此传入失效。默认为
+                                  `None`
+
+  -h, --help                      Show this message and exit.
+``` + + + +例如可以使用以下命令进行训练: + +```bash +cnocr train -m densenet_lite_136-fc --index-dir data/test --train-config-fp docs/examples/train_config.json +``` + + + +训练数据的格式见文件夹 [data/test](https://github.com/breezedeus/cnocr/blob/master/data/test) 中的 [train.tsv](https://github.com/breezedeus/cnocr/blob/master/data/test/train.tsv) 和 [dev.tsv](https://github.com/breezedeus/cnocr/blob/master/data/test/dev.tsv) 文件。 + + + +具体使用也可参考文件 [Makefile](https://github.com/breezedeus/cnocr/blob/master/Makefile) 。 + + + +## 模型转存 + +训练好的模型会存储训练状态,使用命令 **`cnocr resave`** 去掉与预测无关的数据,降低模型大小。 + +```bash +(venv) ➜ cnocr git:(pytorch) ✗ cnocr resave -h +Usage: cnocr resave [OPTIONS] + + 训练好的模型会存储训练状态,使用此命令去掉预测时无关的数据,降低模型大小 + +Options: + -i, --input-model-fp TEXT 输入的模型文件路径 [required] + -o, --output-model-fp TEXT 输出的模型文件路径 [required] + -h, --help Show this message and exit. +``` + + + diff --git a/docs/contact.md b/docs/contact.md new file mode 100644 index 0000000000000000000000000000000000000000..1c729073bafb902447ca1325416e70649bc66efc --- /dev/null +++ b/docs/contact.md @@ -0,0 +1,7 @@ + +# QQ 交流群 + +欢迎扫码加入QQ交流群: + +![QQ群二维码](./cnocr-qq.jpg) + diff --git a/docs/demo.md b/docs/demo.md new file mode 100644 index 0000000000000000000000000000000000000000..d5dcd9f74d355f320c1a79643e7c4bb8325d336a --- /dev/null +++ b/docs/demo.md @@ -0,0 +1,5 @@ +# 在线 Demo + +地址:[https://share.streamlit.io/breezedeus/cnstd/st-deploy/cnstd/app.py](https://share.streamlit.io/breezedeus/cnstd/st-deploy/cnstd/app.py) 。 + +![Demo](figs/demo.jpg) \ No newline at end of file diff --git a/examples/00010965.jpg b/docs/examples/00010965.jpg similarity index 100% rename from examples/00010965.jpg rename to docs/examples/00010965.jpg diff --git a/examples/00010991.jpg b/docs/examples/00010991.jpg similarity index 100% rename from examples/00010991.jpg rename to docs/examples/00010991.jpg diff --git a/examples/00010994.jpg b/docs/examples/00010994.jpg similarity index 100% rename from examples/00010994.jpg rename to docs/examples/00010994.jpg diff --git a/examples/00199971.jpg b/docs/examples/00199971.jpg similarity index 100% rename from examples/00199971.jpg rename to docs/examples/00199971.jpg diff --git a/examples/00199975.jpg b/docs/examples/00199975.jpg similarity index 100% rename from examples/00199975.jpg rename to docs/examples/00199975.jpg diff --git a/examples/00199978.jpg b/docs/examples/00199978.jpg similarity index 100% rename from examples/00199978.jpg rename to docs/examples/00199978.jpg diff --git a/examples/00199979.jpg b/docs/examples/00199979.jpg similarity index 100% rename from examples/00199979.jpg rename to docs/examples/00199979.jpg diff --git a/examples/00199980.jpg b/docs/examples/00199980.jpg similarity index 100% rename from examples/00199980.jpg rename to docs/examples/00199980.jpg diff --git a/examples/00199985.jpg b/docs/examples/00199985.jpg similarity index 100% rename from examples/00199985.jpg rename to docs/examples/00199985.jpg diff --git a/examples/20457890_2399557098.jpg b/docs/examples/20457890_2399557098.jpg similarity index 100% rename from examples/20457890_2399557098.jpg rename to docs/examples/20457890_2399557098.jpg diff --git a/examples/captcha1.png b/docs/examples/captcha1.png similarity index 100% rename from examples/captcha1.png rename to docs/examples/captcha1.png diff --git a/examples/chn-00199980.jpg b/docs/examples/chn-00199980.jpg similarity index 100% rename from examples/chn-00199980.jpg rename to docs/examples/chn-00199980.jpg diff --git a/examples/chn-00199981.jpg 
b/docs/examples/chn-00199981.jpg similarity index 100% rename from examples/chn-00199981.jpg rename to docs/examples/chn-00199981.jpg diff --git a/examples/chn-00199984.jpg b/docs/examples/chn-00199984.jpg similarity index 100% rename from examples/chn-00199984.jpg rename to docs/examples/chn-00199984.jpg diff --git a/examples/chn-00199985.jpg b/docs/examples/chn-00199985.jpg similarity index 100% rename from examples/chn-00199985.jpg rename to docs/examples/chn-00199985.jpg diff --git a/examples/chn-00199989.jpg b/docs/examples/chn-00199989.jpg similarity index 100% rename from examples/chn-00199989.jpg rename to docs/examples/chn-00199989.jpg diff --git a/examples/helloworld.jpg b/docs/examples/helloworld.jpg similarity index 100% rename from examples/helloworld.jpg rename to docs/examples/helloworld.jpg diff --git a/examples/hybrid.png b/docs/examples/hybrid.png similarity index 100% rename from examples/hybrid.png rename to docs/examples/hybrid.png diff --git a/examples/multi-line_cn1.png b/docs/examples/multi-line_cn1.png similarity index 100% rename from examples/multi-line_cn1.png rename to docs/examples/multi-line_cn1.png diff --git a/examples/multi-line_cn2.png b/docs/examples/multi-line_cn2.png similarity index 100% rename from examples/multi-line_cn2.png rename to docs/examples/multi-line_cn2.png diff --git a/examples/multi-line_en_black.png b/docs/examples/multi-line_en_black.png similarity index 100% rename from examples/multi-line_en_black.png rename to docs/examples/multi-line_en_black.png diff --git a/examples/multi-line_en_white.png b/docs/examples/multi-line_en_white.png similarity index 100% rename from examples/multi-line_en_white.png rename to docs/examples/multi-line_en_white.png diff --git a/examples/rand_cn1.png b/docs/examples/rand_cn1.png similarity index 100% rename from examples/rand_cn1.png rename to docs/examples/rand_cn1.png diff --git a/examples/rand_cn2.png b/docs/examples/rand_cn2.png similarity index 100% rename from examples/rand_cn2.png rename to docs/examples/rand_cn2.png diff --git a/docs/examples/taobao4.jpg b/docs/examples/taobao4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..22dc27bc7a3354a7c7713768d304e75a6b061d01 Binary files /dev/null and b/docs/examples/taobao4.jpg differ diff --git a/examples/train_config.json b/docs/examples/train_config.json similarity index 68% rename from examples/train_config.json rename to docs/examples/train_config.json index ea0473b4d1ad23cd78a71a18f3e8967607fe03f3..9df5970fcfa54e405f30b596d1e74f1811daa077 100644 --- a/examples/train_config.json +++ b/docs/examples/train_config.json @@ -1,19 +1,19 @@ { - "vocab_fp": "label_cn.txt", + "vocab_fp": "cnocr/label_cn.txt", "img_folder": "data/images", "gpus": 0, - "epochs": 2, - "batch_size": 64, + "epochs": 20, + "batch_size": 4, "num_workers": 0, "pin_memory": false, "optimizer": "adam", "learning_rate": 1e-3, "weight_decay": 0, "lr_scheduler": { - "name": "multi_step", - "milestones": [5, 10, 16, 22, 30], - "gamma": 0.5 + "name": "cos_warmup", + "min_lr_mult_factor": 0.01, + "warmup_epochs": 0.2 }, "precision": 32, "limit_train_batches": 1.0, diff --git a/examples/train_config_gpu.json b/docs/examples/train_config_gpu.json similarity index 68% rename from examples/train_config_gpu.json rename to docs/examples/train_config_gpu.json index df5afc1e40f1588532d21d1314dc1d7d23d24d39..211b3d6d29b8b9f7c4213167eeba8e895bbc62b9 100644 --- a/examples/train_config_gpu.json +++ b/docs/examples/train_config_gpu.json @@ -1,17 +1,19 @@ { - "vocab_fp": 
"label_cn.txt", - "img_folder": "data/images", + "vocab_fp": "cnocr/label_cn.txt", + "img_folder": "data/output_normal", "gpus": [0], "epochs": 40, - "batch_size": 200, + "batch_size": 100, "num_workers": 12, "pin_memory": true, "optimizer": "adam", "learning_rate": 3e-3, "weight_decay": 0, "lr_scheduler": { - "name": "multi_step", + "name": "cos_warmup", + "min_lr_mult_factor": 0.01, + "warmup_epochs": 0.2, "milestones": [5, 10, 16, 22, 30], "gamma": 0.5 }, diff --git a/docs/faq.md b/docs/faq.md new file mode 100644 index 0000000000000000000000000000000000000000..ce35e0bf4a632b608d05c54d04b833905b2c9d13 --- /dev/null +++ b/docs/faq.md @@ -0,0 +1,19 @@ +# 常见问题(FAQ) + +## CnOcr 是免费的吗? + +CnOcr是免费的,而且是开源的。可以按需自行调整发布或商业使用。 + + +## CnOcr 能识别英文以及空格吗? + +可以。 + +## CnOcr 能识别繁体中文吗? + +不能。 + +## CnOcr 能识别竖排文字的图片吗? + +不能。 + diff --git a/docs/figs/demo.jpg b/docs/figs/demo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d122265285d1f15747c62d1722fd5e87d9c2e046 Binary files /dev/null and b/docs/figs/demo.jpg differ diff --git a/docs/figs/jinlong.ico b/docs/figs/jinlong.ico new file mode 100755 index 0000000000000000000000000000000000000000..53a1890232341cb74ae073cd28cd47b3c4eaab85 Binary files /dev/null and b/docs/figs/jinlong.ico differ diff --git a/docs/figs/jinlong.png b/docs/figs/jinlong.png new file mode 100644 index 0000000000000000000000000000000000000000..f7641e6393cd4632dcfd22b0d2cdf82ef0b0a9ce Binary files /dev/null and b/docs/figs/jinlong.png differ diff --git a/docs/figs/std-ocr.jpg b/docs/figs/std-ocr.jpg new file mode 100644 index 0000000000000000000000000000000000000000..42c332304780af1dc66bfc58062e5752d5a4829b Binary files /dev/null and b/docs/figs/std-ocr.jpg differ diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000000000000000000000000000000000000..a40b03cf382703310a7711debe94204ed53e24aa --- /dev/null +++ b/docs/index.md @@ -0,0 +1,100 @@ +# CnOcr + +**[CnOcr](https://github.com/breezedeus/cnocr)** 是 **Python 3** 下的**文字识别**(**Optical Character Recognition**,简称**OCR**)工具包, +支持**中文**、**英文**的常见字符识别,自带了多个[训练好的识别模型](models.md),安装后即可直接使用。 +欢迎扫码加入[QQ交流群](contact.md)。 + +CnOcr的目标是**使用简单**。 + +可以使用 [在线 Demo](demo.md) 查看效果。 + + +## 安装简单 + +嗯,安装真的很简单。 + +```bash +pip install cnocr +``` + +更多说明可见 [安装文档](install.md)。 + +## 使用简单 + +使用 `CnOcr.ocr()` 识别下图: + +![多行文字图片](examples/multi-line_cn1.png) + +**调用示例**: + +```python +from cnocr import CnOcr + +ocr = CnOcr() +res = ocr.ocr('examples/multi-line_cn1.png') +print("Predicted Chars:", res) +``` + +或: +```python +from cnocr.utils import read_img +from cnocr import CnOcr + +ocr = CnOcr() +img_fp = 'examples/multi-line_cn1.png' +img = read_img(img_fp) +res = ocr.ocr(img) +print("Predicted Chars:", res) +``` + +更多说明可见 [使用方法](usage.md)。 + + +## 命令行工具 + +具体见 [命令行工具](command.md)。 + +### 训练自己的模型 + +具体见 [模型训练](train.md)。 + +## 效果示例 + +| 图片 | OCR结果 | +| ------------------------------------------------------------ | ------------------------------------------------------------ | +| ![examples/helloworld.jpg](./examples/helloworld.jpg) | Hello world!你好世界 | +| ![examples/chn-00199989.jpg](./examples/chn-00199989.jpg) | 铑泡胭释邑疫反隽寥缔 | +| ![examples/chn-00199980.jpg](./examples/chn-00199980.jpg) | 拇箬遭才柄腾戮胖惬炫 | +| ![examples/chn-00199984.jpg](./examples/chn-00199984.jpg) | 寿猿嗅髓孢刀谎弓供捣 | +| ![examples/chn-00199985.jpg](./examples/chn-00199985.jpg) | 马靼蘑熨距额猬要藕萼 | +| ![examples/chn-00199981.jpg](./examples/chn-00199981.jpg) | 掉江悟厉励.谌查门蠕坑 | +| ![examples/00199975.jpg](./examples/00199975.jpg) | nd-chips fructed ast | +| 
![examples/00199978.jpg](./examples/00199978.jpg) | zouna unpayably Raqu | +| ![examples/00199979.jpg](./examples/00199979.jpg) | ape fissioning Senat | +| ![examples/00199971.jpg](./examples/00199971.jpg) | ling oughtlins near | +| ![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) | 网络支付并无本质的区别,因为
每一个手机号码和邮件地址背后
都会对应着一个账户--这个账
户可以是信用卡账户、借记卡账
户,也包括邮局汇款、手机代
收、电话代收、预付费卡和点卡
等多种形式。 | +| ![examples/multi-line_cn2.png](./examples/multi-line_cn2.png) | 当然,在媒介越来越多的情形下,
意味着传播方式的变化。过去主流
的是大众传播,现在互动性和定制
性带来了新的挑战——如何让品牌
与消费者更加互动。 | +| ![examples/multi-line_en_white.png](./examples/multi-line_en_white.png) | This chapter is currently only available
in this web version. ebook and print will follow.
Convolutional neural networks learn abstract
features and concepts from raw image pixels. Feature
Visualization visualizes the learned features
by activation maximization. Network Dissection labels
neural network units (e.g. channels) with human concepts. | +| ![examples/multi-line_en_black.png](./examples/multi-line_en_black.png) | transforms the image many times. First, the image
goes through many convolutional layers. In those
convolutional layers, the network learns new
and increasingly complex features in its layers. Then the
transformed image information goes through
the fully connected layers and turns into a classification
or prediction. | + + +## 其他文档 + +* [场景文字识别技术介绍(PPT+视频)](std_ocr.md) +* 对于通用场景的文字识别,使用 [文本检测CnStd + 文字识别CnOcr](cnstd_cnocr.md) +* [RELEASE文档](RELEASE.md) + + +## 未来工作 + +* [x] 支持图片包含多行文字 (`Done`) +* [x] crnn模型支持可变长预测,提升灵活性 (since `V1.0.0`) +* [x] 完善测试用例 (`Doing`) +* [x] 修bugs(目前代码还比较凌乱。。) (`Doing`) +* [x] 支持`空格`识别(since `V1.1.0`) +* [x] 尝试新模型,如 DenseNet,进一步提升识别准确率(since `V1.1.0`) +* [x] 优化训练集,去掉不合理的样本;在此基础上,重新训练各个模型 +* [x] 由 MXNet 改为 PyTorch 架构(since `V2.0.0`) +* [ ] 基于 PyTorch 训练更高效的模型 +* [ ] 支持列格式的文字识别 + diff --git a/docs/install.md b/docs/install.md new file mode 100644 index 0000000000000000000000000000000000000000..bde2f850d4c8d4bb5981947537b314169eaf86b9 --- /dev/null +++ b/docs/install.md @@ -0,0 +1,21 @@ +## 安装 + +嗯,安装真的很简单。 + +```bash +pip install cnocr +``` + + + +安装速度慢的话,可以指定国内的安装源,如使用豆瓣源: + +```bash +pip install cnocr -i https://pypi.doubanio.com/simple +``` + + + +> 注意:请使用 **Python3**(3.6以及之后版本应该都行),没测过Python2下是否ok。 + + diff --git a/docs/intro-cnstd-cnocr.pdf b/docs/intro-cnstd-cnocr.pdf new file mode 100644 index 0000000000000000000000000000000000000000..56931f78b6815cf084045cd6e47abad140affdc0 Binary files /dev/null and b/docs/intro-cnstd-cnocr.pdf differ diff --git a/docs/models.md b/docs/models.md new file mode 100644 index 0000000000000000000000000000000000000000..b9118daba97388237f850aee95006ffd584d4990 --- /dev/null +++ b/docs/models.md @@ -0,0 +1,26 @@ +## 可直接使用的模型 + +cnocr的ocr模型可以分为两阶段:第一阶段是获得ocr图片的局部编码向量,第二部分是对局部编码向量进行序列学习,获得序列编码向量。目前的PyTorch版本的两个阶段分别包含以下模型: + +1. 局部编码模型(emb model) + * **`densenet_lite_`**:一个微型的`densenet`网络;其中的``表示模型中每个block包含的层数。 + * **`densenet`**:一个小型的`densenet`网络; +2. 序列编码模型(seq model) + * **`fc`**:两层的全连接网络; + * **`gru`**:一层的GRU网络; + * **`lstm`**:一层的LSTM网络。 + + + +cnocr **V2.1** 目前包含以下可直接使用的模型,训练好的模型都放在 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 项目中,可免费下载使用: + +| Name | 参数规模 | 模型文件大小 | 准确度 | 平均推断耗时(毫秒/图) | +| --- | --- | --- | --- | --- | +| densenet\_lite\_114-fc | 1.3 M | 4.9 M | 0.9274 | 9.229 | +| densenet\_lite\_124-fc | 1.3 M | 5.1 M | 0.9429 | 10.112 | +| densenet\_lite\_134-fc | 1.4 M | 5.4 M | 0.954 | 10.843 | +| densenet\_lite\_136-fc | 1.5M | 5.9 M | 0.9631 | 11.499 | +| densenet\_lite\_134-gru | 2.9 M | 11 M | 0.9738 | 17.042 | +| densenet\_lite\_136-gru | 3.1 M | 12 M | 0.9756 | 17.725 | + +> 模型名称是由局部编码模型和序列编码模型名称拼接而成,以符合"-"分割。 diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d8f1468c6d71364bda7a45167fc8f8901068275e --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,60 @@ +# +# This file is autogenerated by pip-compile +# To update, run: +# +# pip-compile --output-file=requirements.txt requirements.in +# +--index-url https://pypi.doubanio.com/simple + +absl-py==0.13.0 # via tensorboard +aiohttp==3.7.4.post0 # via fsspec +async-timeout==3.0.1 # via aiohttp +attrs==21.2.0 # via aiohttp +cachetools==4.2.2 # via google-auth +certifi==2020.4.5.1 # via requests +chardet==3.0.4 # via aiohttp, requests +click==8.0.1 # via -r requirements.in +fsspec[http]==2021.7.0 # via pytorch-lightning +future==0.18.2 # via pytorch-lightning +google-auth-oauthlib==0.4.5 # via tensorboard +google-auth==1.35.0 # via google-auth-oauthlib, tensorboard +grpcio==1.39.0 # via tensorboard +idna==2.9 # via requests, yarl +markdown==3.3.4 # via tensorboard +multidict==5.1.0 # via aiohttp, yarl +numpy==1.18.3 # via -r requirements.in, pytorch-lightning, tensorboard, torchmetrics, torchvision +oauthlib==3.1.1 # via requests-oauthlib +packaging==21.0 # via 
pytorch-lightning, torchmetrics +pillow==5.3.0 # via -r requirements.in, torchvision +protobuf==3.17.3 # via tensorboard +pyasn1-modules==0.2.8 # via google-auth +pyasn1==0.4.8 # via pyasn1-modules, rsa +pydeprecate==0.3.1 # via pytorch-lightning +pyparsing==2.4.7 # via packaging +pytorch-lightning==1.4.4 # via -r requirements.in +pyyaml==5.4.1 # via pytorch-lightning +requests-oauthlib==1.3.0 # via google-auth-oauthlib +requests==2.23.0 # via fsspec, requests-oauthlib, tensorboard +rsa==4.7.2 # via google-auth +six==1.14.0 # via absl-py, google-auth, grpcio, protobuf +tensorboard-data-server==0.6.1 # via tensorboard +tensorboard-plugin-wit==1.8.0 # via tensorboard +tensorboard==2.6.0 # via pytorch-lightning +torch==1.9.0 # via -r requirements.in, pytorch-lightning, torchmetrics, torchvision +torchmetrics==0.5.0 # via pytorch-lightning +torchvision==0.10.0 # via -r requirements.in +tqdm==4.45.0 # via -r requirements.in, pytorch-lightning +typing-extensions==3.10.0.0 # via aiohttp, pytorch-lightning, torch +urllib3==1.25.9 # via requests +werkzeug==2.0.1 # via tensorboard +wheel==0.37.0 # via tensorboard +yarl==1.6.3 # via aiohttp + +# The following packages are considered to be unsafe in a requirements file: +# setuptools + +# for mkdocs +mkdocs +mkdocs-macros-plugin +mkdocs-material +mkdocstrings diff --git a/docs/static/css/custom.css b/docs/static/css/custom.css new file mode 100644 index 0000000000000000000000000000000000000000..b5340d6f9ba8b33ba4cb0db7162410146a079cb7 --- /dev/null +++ b/docs/static/css/custom.css @@ -0,0 +1,80 @@ +/* Sidebar */ +.md-sidebar { + width: 10rem; +} + +/* Indenting docstrings */ +div.doc-contents:not(.first) { + padding-left: 25px; + border-left: 4px solid #e6e6e6; + margin-bottom: 80px; +} + +/* Functions inside classes */ +.md-typeset h5 { + font-size: 0.8rem; + text-transform: none !important; +} + +/* Code highlights (custom) */ +.highlight .hll { background-color: #ffffcc } +.highlight .c { color: #408080; font-style: italic } /* Comment */ +.highlight .err { border: 1px solid #FF0000 } /* Error */ +.highlight .k { color: #008000; font-weight: bold } /* Keyword */ +.highlight .o { color: #AE2FFE } /* Operator */ +.highlight .cm { color: #408080; font-style: italic } /* Comment.Multiline */ +.highlight .cp { color: #BC7A00 } /* Comment.Preproc */ +.highlight .c1 { color: #408080; font-style: italic } /* Comment.Single */ +.highlight .cs { color: #408080; font-style: italic } /* Comment.Special */ +.highlight .gd { color: #A00000 } /* Generic.Deleted */ +.highlight .ge { font-style: italic } /* Generic.Emph */ +.highlight .gr { color: #FF0000 } /* Generic.Error */ +.highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ +.highlight .gi { color: #00A000 } /* Generic.Inserted */ +.highlight .go { color: #808080 } /* Generic.Output */ +.highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ +.highlight .gs { font-weight: bold } /* Generic.Strong */ +.highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ +.highlight .gt { color: #0040D0 } /* Generic.Traceback */ +.highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ +.highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ +.highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ +.highlight .kp { color: #008000 } /* Keyword.Pseudo */ +.highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ +.highlight .kt { color: #B00040 } /* Keyword.Type */ +.highlight .m { 
color: #008000 } /* Literal.Number */ +.highlight .s { color: #BA2121 } /* Literal.String */ +.highlight .na { color: #7D9029 } /* Name.Attribute */ +.highlight .nb { color: #008000 } /* Name.Builtin */ +.highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */ +.highlight .no { color: #880000 } /* Name.Constant */ +.highlight .nd { color: #AA22FF } /* Name.Decorator */ +.highlight .ni { color: #999999; font-weight: bold } /* Name.Entity */ +.highlight .ne { color: #D2413A; font-weight: bold } /* Name.Exception */ +.highlight .nf { color: #0000FF } /* Name.Function */ +.highlight .nl { color: #A0A000 } /* Name.Label */ +.highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ +.highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ +.highlight .nv { color: #19177C } /* Name.Variable */ +.highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */ +.highlight .w { color: #bbbbbb } /* Text.Whitespace */ +.highlight .mf { color: #008000 } /* Literal.Number.Float */ +.highlight .mh { color: #008000 } /* Literal.Number.Hex */ +.highlight .mi { color: #008000 } /* Literal.Number.Integer */ +.highlight .mo { color: #008000 } /* Literal.Number.Oct */ +.highlight .sb { color: #BA2121 } /* Literal.String.Backtick */ +.highlight .sc { color: #BA2121 } /* Literal.String.Char */ +.highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ +.highlight .s2 { color: #BA2121 } /* Literal.String.Double */ +.highlight .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */ +.highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ +.highlight .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */ +.highlight .sx { color: #008000 } /* Literal.String.Other */ +.highlight .sr { color: #BB6688 } /* Literal.String.Regex */ +.highlight .s1 { color: #BA2121 } /* Literal.String.Single */ +.highlight .ss { color: #19177C } /* Literal.String.Symbol */ +.highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */ +.highlight .vc { color: #19177C } /* Name.Variable.Class */ +.highlight .vg { color: #19177C } /* Name.Variable.Global */ +.highlight .vi { color: #19177C } /* Name.Variable.Instance */ +.highlight .il { color: #008000 } /* Literal.Number.Integer.Long */ diff --git a/docs/std_ocr.md b/docs/std_ocr.md new file mode 100644 index 0000000000000000000000000000000000000000..14678471b666669cb80a7e6878f7985aef3e6f19 --- /dev/null +++ b/docs/std_ocr.md @@ -0,0 +1,29 @@ +# 场景文字识别技术介绍 + +为了识别一张图片中的文字,通常包含两个步骤: + +1. **文本检测**:检测出图片中文字所在的位置; +2. 
**文字识别**:识别包含文字的图片局部,预测具体的文字。
+
+如下图:
+
+![文字识别流程](figs/std-ocr.jpg)
+
+更多相关介绍可参考作者分享:**文本检测与识别**([PPT](intro-cnstd-cnocr.pdf)、[B站视频](https://www.bilibili.com/video/BV1uU4y1N7Ba))。
+
+---
+
+cnocr 主要功能是上面的第二步,也即文字识别。有些应用场景(如下图的文字截图图片等),待检测的图片背景很简单,如白色或其他纯色,
+cnocr 内置的文字检测和分行模块可以处理这种简单场景。
+
+![文字截图图片](examples/multi-line_cn1.png)
+
+
+但如果用于其他复杂的场景文字图片(如下图)的识别,
+cnocr 需要结合其他的场景文字检测引擎使用,推荐文字检测引擎 **[CnStd](https://github.com/breezedeus/cnstd)** 。
+
+![复杂场景文字图片](examples/taobao4.jpg)
+
+
+具体使用方式,可参考 [文本检测CnStd + 文字识别CnOcr](cnstd_cnocr.md)。
+
diff --git a/docs/train.md b/docs/train.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa33eca172258b28b4903b77ea6a4dfa64dedbee
--- /dev/null
+++ b/docs/train.md
@@ -0,0 +1,64 @@
+# 模型训练
+
+自带模型基于 `500+万` 的文字图片训练而成。
+
+
+## 训练命令
+
+[命令行工具](command.md) 介绍了训练命令。使用命令 **`cnocr train`** 训练文字识别模型,以下是使用说明:
+
+```bash
+(venv) ➜ cnocr git:(dev) ✗ cnocr train -h
+Usage: cnocr train [OPTIONS]
+
+Options:
+  -m, --model-name TEXT           模型名称。默认值为 densenet_lite_136-fc
+  -i, --index-dir TEXT            索引文件所在的文件夹,会读取文件夹中的 train.tsv 和 dev.tsv 文件
+                                  [required]
+
+  --train-config-fp TEXT          训练使用的json配置文件,参考
+                                  `docs/examples/train_config.json`
+                                  [required]
+
+  -r, --resume-from-checkpoint TEXT
+                                  恢复此前中断的训练状态,继续训练。默认为 `None`
+  -p, --pretrained-model-fp TEXT  导入的训练好的模型,作为初始模型。优先级低于"--restore-training-
+                                  fp",当传入"--restore-training-fp"时,此传入失效。默认为
+                                  `None`
+
+  -h, --help                      Show this message and exit.
+```
+
+
+
+例如可以使用以下命令进行训练:
+
+```bash
+cnocr train -m densenet_lite_136-fc --index-dir data/test --train-config-fp docs/examples/train_config.json
+```
+
+
+训练数据的格式见文件夹 [data/test](https://github.com/breezedeus/cnocr/blob/master/data/test) 中的 [train.tsv](https://github.com/breezedeus/cnocr/blob/master/data/test/train.tsv) 和 [dev.tsv](https://github.com/breezedeus/cnocr/blob/master/data/test/dev.tsv) 文件。
+
+
+
+具体使用也可参考文件 [Makefile](https://github.com/breezedeus/cnocr/blob/master/Makefile) 。
+
+
+
+## 模型精调
+
+如果需要在已有模型的基础上精调模型,需要把训练配置中的学习率设置得较小,`lr_scheduler`的设置可参考以下:
+
+```json
+    "learning_rate": 3e-5,
+    "lr_scheduler": {
+        "name": "cos_warmup",
+        "min_lr_mult_factor": 0.01,
+        "warmup_epochs": 2
+    },
+```
+
+
+
+> 注:需要尽量避免过度精调!
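+
+例如,下面给出一个精调命令的大致写法(仅为示意:`-p` 传入的模型文件路径 `~/.cnocr/2.1/densenet_lite_136-fc/model.ckpt` 是假设的占位路径,请替换成自己实际要精调的 `.ckpt` 文件;训练配置文件中应使用上面给出的较小学习率):
+
+```bash
+# 通过 -p/--pretrained-model-fp 指定已训练好的模型作为初始模型(路径仅为示例)
+cnocr train -m densenet_lite_136-fc --index-dir data/test \
+    --train-config-fp docs/examples/train_config.json \
+    -p ~/.cnocr/2.1/densenet_lite_136-fc/model.ckpt
+```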
diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 0000000000000000000000000000000000000000..6387bc85b0e6df696f0dfdb74e8e9de5026b4c32 --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,280 @@ +# 使用方法 + +## 模型文件自动下载 +首次使用cnocr时,系统会**自动下载** zip格式的模型压缩文件,并存于 `~/.cnocr`目录(Windows下默认路径为 `C:\Users\\AppData\Roaming\cnocr`)。 +下载后的zip文件代码会自动对其解压,然后把解压后的模型相关目录放于`~/.cnocr/2.1`目录中。 + +如果系统无法自动成功下载zip文件,则需要手动从 **[cnstd-cnocr-models](https://huggingface.co/breezedeus/cnstd-cnocr-models/tree/main)** 下载此zip文件并把它放于 `~/.cnocr/2.1`目录。如果下载太慢,也可以从 [百度云盘](https://pan.baidu.com/s/1N6HoYearUzU0U8NTL3K35A) 下载, 提取码为 ` gcig`。 + +放置好zip文件后,后面的事代码就会自动执行了。 + +## 预测代码 + +### 针对多行文字的图片识别 + +如果待识别的图片包含多行文字,或者可能包含多行文字(如下图),可以使用 `CnOcr.ocr()` 进行识别。 + +![多行文字图片](examples/multi-line_cn1.png) + +**调用示例**: + +```python +from cnocr import CnOcr + +ocr = CnOcr() +res = ocr.ocr('docs/examples/multi-line_cn1.png') +print("Predicted Chars:", res) +``` + +或: +```python +from cnocr.utils import read_img +from cnocr import CnOcr + +ocr = CnOcr() +img_fp = 'docs/examples/multi-line_cn1.png' +img = read_img(img_fp) +res = ocr.ocr(img) +print("Predicted Chars:", res) +``` + +### 针对单行文字的图片识别 + +如果明确知道待识别的图片包含单行文字(如下图),可以使用 `CnOcr.ocr_for_single_line()` 进行识别。 + +![单行文字图片](examples/helloworld.jpg) + +**调用示例**: + +```python +from cnocr import CnOcr + +ocr = CnOcr() +res = ocr.ocr_for_single_line('docs/examples/helloworld.jpg') +print("Predicted Chars:", res) +``` + +或: + +```python +from cnocr.utils import read_img +from cnocr import CnOcr + +ocr = CnOcr() +img_fp = 'docs/examples/helloworld.jpg' +img = read_img(img_fp) +res = ocr.ocr_for_single_line(img) +print("Predicted Chars:", res) +``` + + +## 效果示例 + +| 图片 | OCR结果 | +| ------------------------------------------------------------ | ------------------------------------------------------------ | +| ![examples/helloworld.jpg](./examples/helloworld.jpg) | Hello world!你好世界 | +| ![examples/chn-00199989.jpg](./examples/chn-00199989.jpg) | 铑泡胭释邑疫反隽寥缔 | +| ![examples/chn-00199980.jpg](./examples/chn-00199980.jpg) | 拇箬遭才柄腾戮胖惬炫 | +| ![examples/chn-00199984.jpg](./examples/chn-00199984.jpg) | 寿猿嗅髓孢刀谎弓供捣 | +| ![examples/chn-00199985.jpg](./examples/chn-00199985.jpg) | 马靼蘑熨距额猬要藕萼 | +| ![examples/chn-00199981.jpg](./examples/chn-00199981.jpg) | 掉江悟厉励.谌查门蠕坑 | +| ![examples/00199975.jpg](./examples/00199975.jpg) | nd-chips fructed ast | +| ![examples/00199978.jpg](./examples/00199978.jpg) | zouna unpayably Raqu | +| ![examples/00199979.jpg](./examples/00199979.jpg) | ape fissioning Senat | +| ![examples/00199971.jpg](./examples/00199971.jpg) | ling oughtlins near | +| ![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) | 网络支付并无本质的区别,因为
每一个手机号码和邮件地址背后
都会对应着一个账户--这个账
户可以是信用卡账户、借记卡账
户,也包括邮局汇款、手机代
收、电话代收、预付费卡和点卡
等多种形式。 | +| ![examples/multi-line_cn2.png](./examples/multi-line_cn2.png) | 当然,在媒介越来越多的情形下,
意味着传播方式的变化。过去主流
的是大众传播,现在互动性和定制
性带来了新的挑战——如何让品牌
与消费者更加互动。 | +| ![examples/multi-line_en_white.png](./examples/multi-line_en_white.png) | This chapter is currently only available
in this web version. ebook and print will follow.
Convolutional neural networks learn abstract
features and concepts from raw image pixels. Feature
Visualization visualizes the learned features
by activation maximization. Network Dissection labels
neural network units (e.g. channels) with human concepts. | +| ![examples/multi-line_en_black.png](./examples/multi-line_en_black.png) | transforms the image many times. First, the image
goes through many convolutional layers. In those
convolutional layers, the network learns new
and increasingly complex features in its layers. Then the
transformed image information goes through
the fully connected layers and turns into a classification
or prediction. | + + + +## 详细使用说明 + +[类CnOcr](cnocr/cn_ocr.md) 是识别主类,包含了三个函数针对不同场景进行文字识别。类`CnOcr`的初始化函数如下: + +```python +class CnOcr(object): + def __init__( + self, + model_name: str = 'densenet_lite_136-fc' + *, + cand_alphabet: Optional[Union[Collection, str]] = None, + context: str = 'cpu', # ['cpu', 'gpu', 'cuda'] + model_fp: Optional[str] = None, + root: Union[str, Path] = data_dir(), + **kwargs, + ): +``` + +其中的几个参数含义如下: + +* `model_name`: 模型名称,即上面表格第一列中的值。默认为 `densenet_lite_136-fc`。 + +* `cand_alphabet`: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围。取值可以是字符串,如 `"0123456789"`,或者字符列表,如 `["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]`。 + + * `cand_alphabet`也可以初始化后通过类函数 `CnOcr.set_cand_alphabet(cand_alphabet)` 进行设置。这样同一个实例也可以指定不同的`cand_alphabet`进行识别。 + +* `context`:预测使用的机器资源,可取值为字符串`cpu`、`gpu`、`cuda:0`等。 + +* `model_fp`: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件(`.ckpt` 文件)。 + +* `root`: 模型文件所在的根目录。 + + * Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.1/densenet_lite_136-fc`。 + * Windows下默认值为 `C:\Users\\AppData\Roaming\cnocr`。 + + + +每个参数都有默认取值,所以可以不传入任何参数值进行初始化:`ocr = CnOcr()`。 + + +--- + + + +类`CnOcr`主要包含三个函数,下面分别说明。 + + + +### 1. 函数`CnOcr.ocr(img_fp)` + +函数`CnOcr.ocr(img_fp)`可以对包含多行文字(或单行)的图片进行文字识别。 + + + +**函数说明**: + +- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]`,`channel` 可以等于`1`(灰度图片)或者`3`(`RGB`格式的彩色图片)。 +- 返回值:为一个嵌套的`list`,其中的每个元素存储了对一行文字的识别结果,其中也包含了识别概率值。类似这样`[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]`,其中的数字为对应的识别概率值。 + + + +**调用示例**: + + +```python +from cnocr import CnOcr + +ocr = CnOcr() +res = ocr.ocr('examples/multi-line_cn1.png') +print("Predicted Chars:", res) +``` + +或: +```python +from cnocr.utils import read_img +from cnocr import CnOcr + +ocr = CnOcr() +img_fp = 'examples/multi-line_cn1.png' +img = read_img(img_fp) +res = ocr.ocr(img) +print("Predicted Chars:", res) +``` + + + +上面使用的图片文件 [docs/examples/multi-line_cn1.png](./examples/multi-line_cn1.png)内容如下: + +![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) + + + +上面预测代码段的返回结果如下: + +```bash +Predicted Chars: [ + (['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'], 0.8677546381950378), + (['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'], 0.6706454157829285), + (['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '一', '这', '个', '账'], 0.5052655935287476), + (['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'], 0.7785991430282593), + (['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'], 0.37458470463752747), + (['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'], 0.7326119542121887), + (['等', '多', '种', '形', '式', '。'], 0.14462216198444366)] +``` + + + +### 2. 
函数`CnOcr.ocr_for_single_line(img_fp)` + +如果明确知道要预测的图片中只包含了单行文字,可以使用函数`CnOcr.ocr_for_single_line(img_fp)`进行识别。和 `CnOcr.ocr()`相比,`CnOcr.ocr_for_single_line()`结果可靠性更强,因为它不需要做额外的分行处理。 + +**函数说明**: + +- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]`,`channel` 可以等于`1`(灰度图片)或者`3`(`RGB`格式的彩色图片)。 +- 返回值:为一个`tuple`,其中存储了对一行文字的识别结果,也包含了识别概率值。类似这样`(['第', '一', '行'], 0.80)`,其中的数字为对应的识别概率值。 + + + +**调用示例**: + +```python +from cnocr import CnOcr + +ocr = CnOcr() +res = ocr.ocr_for_single_line('examples/rand_cn1.png') +print("Predicted Chars:", res) +``` + +或: + +```python +from cnocr.utils import read_img +from cnocr import CnOcr + +ocr = CnOcr() +img_fp = 'examples/rand_cn1.png' +img = read_img(img_fp) +res = ocr.ocr_for_single_line(img) +print("Predicted Chars:", res) +``` + + +对图片文件 [docs/examples/rand_cn1.png](./examples/rand_cn1.png): + +![examples/rand_cn1.png](./examples/rand_cn1.png) + +的预测结果如下: + +```bash +Predicted Chars: (['笠', '淡', '嘿', '骅', '谧', '鼎', '皋', '姚', '歼', '蠢', '驼', '耳', '胬', '挝', '涯', '狗', '蒽', '了', '狞'], 0.7832438349723816) +``` + + + +### 3. 函数`CnOcr.ocr_for_single_lines(img_list, batch_size=1)` + +函数`CnOcr.ocr_for_single_lines(img_list)`可以**对多个单行文字图片进行批量预测**。函数`CnOcr.ocr(img_fp)`和`CnOcr.ocr_for_single_line(img_fp)`内部其实都是调用的函数`CnOcr.ocr_for_single_lines(img_list)`。 + + + +**函数说明**: + +- 输入参数` img_list`: 为一个`list`;其中每个元素可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]`,`channel` 可以等于`1`(灰度图片)或者`3`(`RGB`格式的彩色图片)。 +- 输入参数 `batch_size`: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`。 +- 返回值:为一个嵌套的`list`,其中的每个元素存储了对一行文字的识别结果,其中也包含了识别概率值。类似这样`[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]`,其中的数字为对应的识别概率值。 + + + +**调用示例**: + +```python +import numpy as np + +from cnocr.utils import read_img +from cnocr import CnOcr, line_split + +ocr = CnOcr() +img_fp = 'examples/multi-line_cn1.png' +img = read_img(img_fp) +line_imgs = line_split(np.squeeze(img, -1), blank=True) +line_img_list = [line_img for line_img, _ in line_imgs] +res = ocr.ocr_for_single_lines(line_img_list) +print("Predicted Chars:", res) +``` + + + +更详细的使用方法,可参考 [tests/test_cnocr.py](https://github.com/breezedeus/cnocr/blob/master/tests/test_cnocr.py) 中提供的测试用例。 + diff --git a/examples/taobao4.jpg b/examples/taobao4.jpg deleted file mode 100644 index 69e620abc5ea78ea57a6aff2a1816a0e6424bea5..0000000000000000000000000000000000000000 Binary files a/examples/taobao4.jpg and /dev/null differ diff --git a/gpu.Makefile b/gpu.Makefile index 99363a62ef0aef171bc1c5d17f926d6d75d72b91..039f8d2e52dc7392215702f35cc9de9dcaf5011a 100644 --- a/gpu.Makefile +++ b/gpu.Makefile @@ -1,24 +1,23 @@ # 可取值:['densenet-s'] -ENCODER_NAME = densenet-s +ENCODER_NAME = densenet-lite-136 # 可取值:['fc', 'gru', 'lstm'] -DECODER_NAME = gru +DECODER_NAME = fclite MODEL_NAME = $(ENCODER_NAME)-$(DECODER_NAME) -EPOCH = 41 -INDEX_DIR = data -TRAIN_CONFIG_FP = examples/train_config_gpu.json +INDEX_DIR = data/output_normal +TRAIN_CONFIG_FP = docs/examples/train_config_gpu.json train: cnocr train -m $(MODEL_NAME) --index-dir $(INDEX_DIR) --train-config-fp $(TRAIN_CONFIG_FP) evaluate: - python scripts/cnocr_evaluate.py --model-name $(MODEL_NAME) --model-epoch $(EPOCH) -i $(REC_DATA_ROOT_DIR)/test-part.txt --image-prefix-dir $(REC_DATA_ROOT_DIR) --batch-size 128 --gpu 1 -o evaluate/$(MODEL_NAME)-$(EPOCH) + cnocr evaluate -m $(MODEL_NAME) -i 
$(REC_DATA_ROOT_DIR)/test-part.txt --image-folder $(REC_DATA_ROOT_DIR) --batch-size 128 -c cuda:0 -o eval_results/$(MODEL_NAME)-$(EPOCH) filter: python scripts/filter_samples.py --sample_file $(REC_DATA_ROOT_DIR)/test-part.txt --badcases_file evaluate/$(MODEL_NAME)-$(EPOCH)/badcases.txt --distance_thrsh 2 -o $(REC_DATA_ROOT_DIR)/new.txt predict: - cnocr predict -m $(MODEL_NAME) -f examples/rand_cn1.png + cnocr predict -m $(MODEL_NAME) -f docs/examples/rand_cn1.png diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000000000000000000000000000000000000..ec3af45526e6bc5cdf226429c6fd726964fca96f --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,113 @@ +# Project information +site_name: CnOcr +site_url: https://cnocr.readthedocs.io +site_description: CnOcr 使用说明 +site_author: Breezedeus + +# Repository +repo_url: https://github.com/breezedeus/cnocr +repo_name: Breezedeus/CnOcr +edit_uri: "" #disables edit button + +# Copyright +copyright: Copyright © 2021 + +# Social media +extra: + social: + - icon: fontawesome/brands/github + link: https://github.com/breezedeus + - icon: fontawesome/brands/zhihu + link: https://www.zhihu.com/people/breezedeus-50 + - icon: fontawesome/brands/youtube + link: https://space.bilibili.com/509307267 + - icon: fontawesome/brands/twitter + link: https://twitter.com/breezedeus + +# Configuration +theme: + name: material +# name: readthedocs + logo: figs/jinlong.png + favicon: figs/jinlong.ico + palette: + primary: indigo + accent: indigo + font: + text: Roboto + code: Roboto Mono + features: + - navigation.tabs + - navigation.expand + icon: + repo: fontawesome/brands/github + +# Extensions +markdown_extensions: + - meta + - pymdownx.emoji: + emoji_index: !!python/name:materialx.emoji.twemoji + emoji_generator: !!python/name:materialx.emoji.to_svg + - admonition # alerts + - pymdownx.details # collapsible alerts + - pymdownx.superfences # nest code and content inside alerts + - attr_list # add HTML and CSS to Markdown elements + - pymdownx.inlinehilite # inline code highlights + - pymdownx.keys # show keystroke symbols + - pymdownx.snippets # insert content from other files + - pymdownx.tabbed # content tabs + - footnotes + - def_list + - pymdownx.arithmatex: # mathjax + generic: true + - pymdownx.tasklist: + custom_checkbox: true + clickable_checkbox: false + - codehilite + - pymdownx.highlight: + use_pygments: true + - toc: + toc_depth: 4 + +# Plugins +plugins: + - search + - macros + - mkdocstrings: + default_handler: python + handlers: + python: + rendering: + show_root_heading: false + show_source: true + show_category_heading: true + watch: + - cnocr + +# Extra CSS +extra_css: + - static/css/custom.css + +# Extra JS +extra_javascript: + - https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + +# Page tree +nav: + - Home: index.md + - Docs: + - 场景文字识别技术介绍: std_ocr.md + - 安装: install.md + - 使用方法: usage.md + - 命令行工具: command.md + - 在线 Demo: demo.md + - 自带模型: models.md + - 文本检测CnStd + 文字识别CnOcr: cnstd_cnocr.md + - 模型训练: train.md + - QQ交流群: contact.md + - 常见问题(FAQ): faq.md + - RELEASE 文档: RELEASE.md + - Python API: + - CnOcr 类: cnocr/cn_ocr.md \ No newline at end of file diff --git a/requirements.in b/requirements.in index f8acf7d414fe15c25ca000e5b9e232276e0fc0b2..ba4b4c77f18751454a916421bc6855574c138f8a 100644 --- a/requirements.in +++ b/requirements.in @@ -1,7 +1,8 @@ click tqdm -torch>=1.7.0 -torchvision +torch>=1.8.0 
+torchvision>=0.9.0 numpy pytorch-lightning pillow>=5.3.0 +python-Levenshtein diff --git a/requirements.txt b/requirements.txt index 0424a127501f1a0c3960259695af93f2ca297aec..dd051cdf02148bb643dbead5d11ea951d6cd72df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,6 +31,7 @@ pyasn1-modules==0.2.8 # via google-auth pyasn1==0.4.8 # via pyasn1-modules, rsa pydeprecate==0.3.1 # via pytorch-lightning pyparsing==2.4.7 # via packaging +python-levenshtein==0.12.0 # via -r requirements.in pytorch-lightning==1.4.4 # via -r requirements.in pyyaml==5.4.1 # via pytorch-lightning requests-oauthlib==1.3.0 # via google-auth-oauthlib diff --git a/setup.py b/setup.py index 97280b177882e0c39e0fd85ce1fec2975502e95b..d44a6fb8b8ac78d2ff8107e6bd081f6c7cb4c13e 100644 --- a/setup.py +++ b/setup.py @@ -39,11 +39,12 @@ exec( required = [ "click", "tqdm", - "torch>=1.7.0", - "torchvision", + "torch>=1.8.0", + "torchvision>=0.9.0", 'numpy', "pytorch-lightning", "pillow>=5.3.0", + "python-Levenshtein", ] extras_require = { "dev": ["pip-tools", "pytest", "python-Levenshtein"], diff --git a/tests/test_cnocr.py b/tests/test_cnocr.py index 98b1b71e92f72897bbde25a30382b604f5f00b06..68b2776100394422842303886c7a0172e6ee7cb5 100644 --- a/tests/test_cnocr.py +++ b/tests/test_cnocr.py @@ -33,7 +33,7 @@ from cnocr.consts import NUMBERS, AVAILABLE_MODELS from cnocr.line_split import line_split root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -example_dir = os.path.join(root_dir, 'examples') +example_dir = os.path.join(root_dir, 'docs/examples') CNOCR = CnOcr(model_name='densenet-s-fc', model_epoch=None) SINGLE_LINE_CASES = [ diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d26aeecdd698c90ed79374e10397f37534b15a42..88c98d0f1612171c55abf5bc428b20a867f74ac2 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -9,7 +9,7 @@ from torchvision import transforms sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) -EXAMPLE_DIR = Path(__file__).parent.parent / 'examples' +EXAMPLE_DIR = Path(__file__).parent.parent / 'docs/examples' INDEX_DIR = Path(__file__).parent.parent / 'data/test' from cnocr.utils import save_img diff --git a/tests/test_models.py b/tests/test_models.py index aef6c6642e1dea3de9e70c4789298a256d131a8e..2f65f03d64e4e6b7691b3da99da422be9a47a03d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,124 +3,171 @@ import os import sys from copy import deepcopy import pytest -import mxnet as mx -from mxnet import nd +import torch +from torch import nn +from torchvision.models import ( + resnet50, + resnet34, + resnet18, + mobilenet_v3_large, + mobilenet_v3_small, + shufflenet_v2_x1_0, + shufflenet_v2_x1_5, + shufflenet_v2_x2_0, +) + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) -from cnocr.consts import EMB_MODEL_TYPES, SEQ_MODEL_TYPES -from cnocr.utils import set_logger -from cnocr.hyperparams.cn_hyperparams import CnHyperparams -from cnocr.symbols.densenet import _make_dense_layer, DenseNet, cal_num_params -from cnocr.symbols.crnn import ( - CRnn, - pipline, - gen_network, - get_infer_shape, - crnn_lstm, - crnn_lstm_lite, -) +from cnocr.utils import set_logger, get_model_size +from cnocr.consts import IMG_STANDARD_HEIGHT, ENCODER_CONFIGS, DECODER_CONFIGS +from cnocr.models.densenet import DenseNet, DenseNetLite +from cnocr.models.mobilenet import gen_mobilenet_v3 
-logger = set_logger('info') -HP = CnHyperparams() +logger = set_logger('info') -def test_dense_layer(): - x = nd.random.randn(128, 64, 32, 280) - net = _make_dense_layer(64, 2, 0.1) - net.initialize() - y = net(x) - logger.info(net) - logger.info(y.shape) +def test_conv(): + conv = nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2, bias=False) + input = torch.rand(1, 32, 10, 4) + res = conv(input) + logger.info(res.shape) def test_densenet(): width = 280 - x = nd.random.randn(128, 64, 32, width) - layer_channels = (64, 128, 256, 512) - for shorter in (False, True): - net = DenseNet(layer_channels, shorter=shorter) - net.initialize() - y = net(x) - logger.info(net) - logger.info(y.shape) # (128, 512, 1, 70) or (128, 512, 1, 35) - assert y.shape[2] == 1 - expected_seq_len = width // 8 if shorter else width // 4 - assert y.shape[3] == expected_seq_len - logger.info('number of parameters: %d', cal_num_params(net)) # 1748224 - - -def test_crnn(): - _hp = deepcopy(HP) - _hp.set_seq_length(_hp.img_width // 4) - x = nd.random.randn(128, 64, 32, 280) - layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)] - for layer_channels in layer_channels_list: - densenet = DenseNet(layer_channels) - crnn = CRnn(_hp, densenet) - crnn.initialize() - y = crnn(x) - logger.info( - 'output shape: %s', y.shape - ) # res: `(sequence_length, batch_size, 2*num_hidden)` - assert y.shape == (_hp.seq_length, _hp.batch_size, 2 * _hp.num_hidden) - logger.info('number of parameters: %d', cal_num_params(crnn)) - - -def test_crnn_lstm(): - hp = deepcopy(HP) - hp.set_seq_length(hp.img_width // 8) - data = mx.sym.Variable('data', shape=(128, 1, 32, 280)) - pred = crnn_lstm(HP, data) - pred_shape = pred.infer_shape()[1][0] - logger.info('shape of pred: %s', pred_shape) - assert pred_shape == (hp.seq_length, hp.batch_size, 2 * hp.num_hidden) - - -def test_crnn_lstm_lite(): - hp = deepcopy(HP) - width = hp.img_width # 280 - data = mx.sym.Variable('data', shape=(128, 1, 32, width)) - for shorter in (False, True): - pred = crnn_lstm_lite(HP, data, shorter=shorter) - pred_shape = pred.infer_shape()[1][0] - logger.info('shape of pred: %s', pred_shape) - seq_len = hp.img_width // 8 if shorter else hp.img_width // 4 - 1 - assert pred_shape == (seq_len, hp.batch_size, 2 * hp.num_hidden) - - -def test_pipline(): - hp = deepcopy(HP) - hp.set_seq_length(hp.img_width // 4) - hp._loss_type = None # infer mode - layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)] - for layer_channels in layer_channels_list: - densenet = DenseNet(layer_channels) - crnn = CRnn(hp, densenet) - data = mx.sym.Variable('data', shape=(128, 1, 32, 280)) - pred = pipline(crnn, hp, data) - pred_shape = pred.infer_shape()[1][0] - logger.info('shape of pred: %s', pred_shape) - assert pred_shape == (hp.batch_size * hp.seq_length, hp.num_classes) + img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width) + net = DenseNet(32, [2, 2, 2, 2], 64) + net.eval() + logger.info(net) + logger.info(f'model size: {get_model_size(net)}') # 406464 + logger.info(img.shape) + res = net(img) + logger.info(res.shape) + assert tuple(res.shape) == (4, 128, 4, 35) + + net = DenseNet(32, [1, 1, 1, 4], 64) + net.eval() + logger.info(net) + logger.info(f'model size: {get_model_size(net)}') # 301440 + logger.info(img.shape) + res = net(img) + logger.info(res.shape) + # assert tuple(res.shape) == (4, 100, 4, 35) + # + # net = DenseNet(32, [1, 1, 2, 2], 64) + # net.eval() + # logger.info(net) + # logger.info(f'model size: {get_model_size(net)}') # 243616 + # logger.info(img.shape) 
+ # res = net(img) + # logger.info(res.shape) + # assert tuple(res.shape) == (4, 116, 4, 35) + # + # net = DenseNet(32, [1, 2, 2, 2], 64) + # net.eval() + # logger.info(net) + # logger.info(f'model size: {get_model_size(net)}') # 230680 + # logger.info(img.shape) + # res = net(img) + # logger.info(res.shape) + # assert tuple(res.shape) == (4, 124, 4, 35) + # + # net = DenseNet(32, [1, 1, 2, 4], 64) + # net.eval() + # logger.info(net) + # logger.info(f'model size: {get_model_size(net)}') # 230680 + # logger.info(img.shape) + # res = net(img) + # logger.info(res.shape) + # assert tuple(res.shape) == (4, 180, 4, 35) + + +def test_densenet_lite(): + width = 280 + img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width) + # net = DenseNetLite(32, [2, 2, 2], 64) + # net.eval() + # logger.info(net) + # logger.info(f'model size: {get_model_size(net)}') # 302976 + # logger.info(img.shape) + # res = net(img) + # logger.info(res.shape) + # assert tuple(res.shape) == (4, 128, 2, 35) + + # net = DenseNetLite(32, [2, 1, 1], 64) + # net.eval() + # logger.info(net) + # logger.info(f'model size: {get_model_size(net)}') # 197952 + # logger.info(img.shape) + # res = net(img) + # logger.info(res.shape) + # assert tuple(res.shape) == (4, 80, 2, 35) + + net = DenseNetLite(32, [1, 3, 4], 64) + net.eval() + logger.info(net) + logger.info(f'model size: {get_model_size(net)}') # 186672 + logger.info(img.shape) + res = net(img) + logger.info(res.shape) + assert tuple(res.shape) == (4, 200, 2, 35) + + net = DenseNetLite(32, [1, 3, 6], 64) + net.eval() + logger.info(net) + logger.info(f'model size: {get_model_size(net)}') # 186672 + logger.info(img.shape) + res = net(img) + logger.info(res.shape) + assert tuple(res.shape) == (4, 264, 2, 35) + + # net = DenseNetLite(32, [1, 2, 2], 64) + # net.eval() + # logger.info(net) + # logger.info(f'model size: {get_model_size(net)}') # + # logger.info(img.shape) + # res = net(img) + # logger.info(res.shape) + # assert tuple(res.shape) == (4, 120, 2, 35) + + +def test_mobilenet(): + width = 280 + img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width) + net = gen_mobilenet_v3('tiny') + net.eval() + logger.info(net) + res = net(img) + logger.info(f'model size: {get_model_size(net)}') # 186672 + logger.info(res.shape) + assert tuple(res.shape) == (4, 192, 2, 35) + net = gen_mobilenet_v3('small') + net.eval() + logger.info(net) + res = net(img) + logger.info(f'model size: {get_model_size(net)}') # 186672 + logger.info(res.shape) + assert tuple(res.shape) == (4, 192, 2, 35) MODEL_NAMES = [] -for emb_model in EMB_MODEL_TYPES: - for seq_model in SEQ_MODEL_TYPES: +for emb_model in ENCODER_CONFIGS: + for seq_model in DECODER_CONFIGS: MODEL_NAMES.append('%s-%s' % (emb_model, seq_model)) -@pytest.mark.parametrize( - 'model_name', MODEL_NAMES -) -def test_gen_networks(model_name): - logger.info('model_name: %s', model_name) - network, hp = gen_network(model_name, HP) - shape_dict = get_infer_shape(network, HP) - logger.info('shape_dict: %s', shape_dict) - assert shape_dict['pred_fc_output'] == ( - hp.batch_size * hp.seq_length, - hp.num_classes, - ) +# @pytest.mark.parametrize( +# 'model_name', MODEL_NAMES +# ) +# def test_gen_networks(model_name): +# logger.info('model_name: %s', model_name) +# network, hp = gen_network(model_name, HP) +# shape_dict = get_infer_shape(network, HP) +# logger.info('shape_dict: %s', shape_dict) +# assert shape_dict['pred_fc_output'] == ( +# hp.batch_size * hp.seq_length, +# hp.num_classes, +# ) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 
5b61a6844be6868951ae82b3552dc60791ceb632..25f363eaf99923bbb30864d39b68b51faaa755b5 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -36,65 +36,3 @@ def test_crnn(): crnn = OcrModel(net, vocab=ENG_LETTERS, lstm_features=512, rnn_units=128) res2 = crnn(img) print(res2) - - -def test_crnn_for_variable_length(): - vocab, letter2id = read_charset(VOCAB_FP) - net = DenseNet(32, [2, 2, 2, 2], 64) - crnn = OcrModel(net, vocab=vocab, lstm_features=512, rnn_units=128) - crnn.eval() - model_fp = VOCAB_FP.parent / 'models/last.ckpt' - if model_fp.exists(): - print(f'load model params from {model_fp}') - load_model_params(crnn, model_fp) - width = 280 - img1 = torch.rand(1, IMG_STANDARD_HEIGHT, width) - img2 = torch.rand(1, IMG_STANDARD_HEIGHT, width // 2) - img3 = torch.rand(1, IMG_STANDARD_HEIGHT, width * 2) - imgs = pad_img_seq([img1, img2, img3]) - input_lengths = torch.Tensor([width, width // 2, width * 2]) - out = crnn( - imgs, input_lengths=input_lengths, return_model_output=True, return_preds=True, - ) - print(out['preds']) - - padded = torch.zeros((3, 1, IMG_STANDARD_HEIGHT, 50)) - imgs2 = torch.cat((imgs, padded), dim=-1) - out2 = crnn( - imgs2, input_lengths=input_lengths, return_model_output=True, return_preds=True, - ) - print(out2['preds']) - # breakpoint() - - -def test_crnn_for_variable_length2(): - vocab, letter2id = read_charset(VOCAB_FP) - net = DenseNet(32, [2, 2, 2, 2], 64) - crnn = OcrModel(net, vocab=vocab, lstm_features=512, rnn_units=128) - crnn.eval() - model_fp = VOCAB_FP.parent / 'models/last.ckpt' - if model_fp.exists(): - print(f'load model params from {model_fp}') - load_model_params(crnn, model_fp) - img_fps = ('helloworld.jpg', 'helloworld-ch.jpg') - imgs = [] - input_lengths = [] - for fp in img_fps: - img = read_img(VOCAB_FP.parent / 'examples' / fp) - img = rescale_img(img) - input_lengths.append(img.shape[2]) - imgs.append(normalize_img_array(img)) - imgs = pad_img_seq(imgs) - input_lengths = torch.Tensor(input_lengths) - out = crnn( - imgs, input_lengths=input_lengths, return_model_output=True, return_preds=True, - ) - print(out['preds']) - - padded = torch.zeros((2, 1, IMG_STANDARD_HEIGHT, 80)) - imgs2 = torch.cat((imgs, padded), dim=-1) - out2 = crnn( - imgs2, input_lengths=input_lengths, return_model_output=True, return_preds=True, - ) - print(out2['preds']) - # breakpoint() diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 8c83008c5718f7f5eb79941c2a2a2a27563a4043..53a434c5c2aa0f43522dbd70828e327291211ee0 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -9,7 +9,7 @@ from torchvision import transforms sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) -EXAMPLE_DIR = Path(__file__).parent.parent / 'examples' +EXAMPLE_DIR = Path(__file__).parent.parent / 'docs/examples' INDEX_DIR = Path(__file__).parent.parent / 'data/test' IMAGE_DIR = Path(__file__).parent.parent / 'data/images' diff --git a/tests/test_utils.py b/tests/test_utils.py index 2028a9a34c90f22c7f4155bd577f5db46f2dbdb2..0ece844f7b9b8e00e4155eac24f285927da22dd1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,7 +7,7 @@ from mxnet.gluon.utils import download sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) -EXAMPLE_DIR = Path(__file__).parent.parent / 'examples' +EXAMPLE_DIR = Path(__file__).parent.parent / 'docs/examples' from cnocr.utils import 
check_context, read_img