Unverified commit e83e4ebc authored by BreezeDeus, committed by GitHub

Merge pull request #66 from breezedeus/dev

V1.1.0
DATA_ROOT_DIR = data/sample-data
REC_DATA_ROOT_DIR = data/sample-data-lst
# Allowed values for `EMB_MODEL_TYPE`: ['conv', 'conv-lite', 'densenet', 'densenet-lite']
EMB_MODEL_TYPE = densenet-lite
# Allowed values for `SEQ_MODEL_TYPE`: ['lstm', 'gru', 'fc']
SEQ_MODEL_TYPE = fc
MODEL_NAME = $(EMB_MODEL_TYPE)-$(SEQ_MODEL_TYPE)
# Generate the *.lst files
gen-lst:
python scripts/im2rec.py --list --num-label 20 --chunks 1 \
--train-idx-fp $(DATA_ROOT_DIR)/train.txt --test-idx-fp $(DATA_ROOT_DIR)/test.txt --prefix $(REC_DATA_ROOT_DIR)/sample-data
# Generate the *.idx and *.rec files from the *.lst files.
# The actual image files live in the `examples` directory; specify it with `--root`.
gen-rec:
python scripts/im2rec.py --pack-label --color 1 --num-thread 1 --prefix $(REC_DATA_ROOT_DIR) --root examples
# Train the model
train:
python scripts/cnocr_train.py --gpu 0 --emb_model_type $(EMB_MODEL_TYPE) --seq_model_type $(SEQ_MODEL_TYPE) \
--optimizer adam --epoch 20 --lr 1e-4 \
--train_file $(REC_DATA_ROOT_DIR)/sample-data_train --test_file $(REC_DATA_ROOT_DIR)/sample-data_test
# Evaluate the model on the test set; details of all bad cases are written to the folder `evaluate/$(MODEL_NAME)`
evaluate:
python scripts/cnocr_evaluate.py --model-name $(MODEL_NAME) --model-epoch 1 -v -i $(DATA_ROOT_DIR)/test.txt \
--image-prefix-dir examples --batch-size 128 -o evaluate/$(MODEL_NAME)
predict:
python scripts/cnocr_predict.py --model_name $(MODEL_NAME) --file examples/rand_cn1.png
.PHONY: gen-lst gen-rec train evaluate predict
DATA_ROOT_DIR = /data2/ocr/outer
REC_DATA_ROOT_DIR = /dev/data/jinlong/data
# ['conv', 'conv-lite', 'densenet', 'densenet-lite']
EMB_MODEL_TYPE = densenet-lite
SEQ_MODEL_TYPE = lstm
MODEL_NAME = $(EMB_MODEL_TYPE)-$(SEQ_MODEL_TYPE)
EPOCH = 20
gen-lst:
python scripts/im2rec.py --list --num-label 20 --chunks 1 --train-idx-fp $(DATA_ROOT_DIR)/train.txt \
--test-idx-fp $(DATA_ROOT_DIR)/test.txt --prefix $(DATA_ROOT_DIR)/lst/cnocr
gen-rec:
python scripts/im2rec.py --pack-label --color 1 --num-thread 1 --prefix $(DATA_ROOT_DIR)/lst --root $(DATA_ROOT_DIR)
##
## copy rec dir from $(DATA_ROOT_DIR) to $(REC_DATA_ROOT_DIR)
##
train:
nohup python scripts/cnocr_train.py --gpu 2 --emb_model_type $(EMB_MODEL_TYPE) --seq_model_type $(SEQ_MODEL_TYPE) \
--optimizer Adam --epoch $(EPOCH) --lr 3e-4 \
--train_file $(REC_DATA_ROOT_DIR)/lst/cnocr_train --test_file $(REC_DATA_ROOT_DIR)/lst/cnocr_test \
>> nohup-$(MODEL_NAME).out 2>&1 &
predict:
python scripts/cnocr_predict.py --model_name $(MODEL_NAME) --file examples/rand_cn1.png
.PHONY: gen-lst gen-rec train predict
# Update 2019.07.25: cnocr V1.0.0 released
`cnocr` v1.0.0 is released, with higher prediction efficiency. **The new models are not compatible with the models of previous versions**, so if you are upgrading you need to re-download the latest model files. See below for details (the procedure is the same as before).
Main changes:
- **The crnn model now supports variable-width prediction, which improves prediction efficiency**
- Support fine-tuning (continued training) of an existing model with specific data
- Bug fixes, such as the training `accuracy` always being `0`
- The `mxnet` dependency is upgraded from `1.3.1` to `1.4.1`
# cnocr
**cnocr** is a **Python 3** package for Chinese OCR. It ships with trained recognition models, so it can be used right after installation.
The current recognition model is a **crnn**, with a recognition accuracy of about `98.8%`.
This project originated from our own internal needs at [爱因互动 Ein+](https://einplus.cn); many thanks to the company for its support.
## Features
Most of the code is forked from [crnn-mxnet-chinese-text-recognition](https://github.com/diaomin/crnn-mxnet-chinese-text-recognition); thanks to its author.
The original project was not very convenient to use, so I wrapped and refactored it. The main changes are:
* The MXNet WarpCTC Loss, which needs an extra installation step, is replaced with the native MXNet CTC Loss, so installation is trivial!
* Trained Chinese OCR models are bundled, so no extra training is needed!
* A prediction (inference) interface is added, so it is easy to use!
## Installation
```bash
pip install cnocr
```
> Note: please use Python 3 (3.4, 3.5, 3.6 and later should all work); it has not been tested under Python 2.
## Usage
The first time cnocr is used, it automatically downloads a zipped model file from [Dropbox](https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=0) and stores it in the `~/.cnocr` directory.
The code then unzips the downloaded file and puts the extracted model files in the `~/.cnocr/models` directory.
If the zip file cannot be downloaded from [Dropbox](https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=0) automatically, download it manually and put it in the `~/.cnocr` directory.
An alternative download location is [Baidu NetDisk](https://pan.baidu.com/s/1DWV3H2UWmzOU6d48UbTYVw) (extraction code: `ss81`).
Once the zip file is in place, the code handles the rest automatically.
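If the models should live somewhere other than `~/.cnocr`, the root directory can also be redirected through the `CNOCR_HOME` environment variable (read by `data_dir()` in `cnocr/utils.py`). A minimal sketch, with a hypothetical path:
```python
import os

# hypothetical location; set it before importing cnocr, because the
# default `root` argument of `CnOcr` is evaluated at import time
os.environ['CNOCR_HOME'] = '/path/to/cnocr-home'

from cnocr import CnOcr

ocr = CnOcr()
```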
### Predicting in Code
Three functions are provided; each is described below.
#### 1. Function `CnOcr.ocr(img_fp)`
The function `CnOcr.ocr(img_fp)` recognizes text in an image containing multiple (or single) lines of text.
**Function description**
- Input parameter `img_fp`: the path of the image file to recognize (as in the example below); or an array already read from an image file, of type `mx.nd.NDArray` or `np.ndarray`, with integer values in `[0, 255]` and shape `(height, width, 3)`, where the third dimension is the channel and should be in `RGB` order.
- Return value: a nested `list`, such as `[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]`.
**Example**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr('examples/multi-line_cn1.png')
print("Predicted Chars:", res)
```
or:
```python
import mxnet as mx
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'examples/multi-line_cn1.png'
img = mx.image.imread(img_fp, 1)
res = ocr.ocr(img)
print("Predicted Chars:", res)
```
The image file [examples/multi-line_cn1.png](./examples/multi-line_cn1.png) used above looks like this:
![examples/multi-line_cn1.png](./examples/multi-line_cn1.png)
The prediction code above returns:
```bash
Predicted Chars: [['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'],
['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'],
['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '―', '这', '个', '账'],
['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'],
['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'],
['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'],
['等', '多', '种', '形', '式', '。']]
```
#### 2. Function `CnOcr.ocr_for_single_line(img_fp)`
If you know in advance that the image to predict contains only a single line of text, use `CnOcr.ocr_for_single_line(img_fp)`. Compared with `CnOcr.ocr()`, its result is more reliable because it does not need the extra line-splitting step.
**Function description**
- Input parameter `img_fp`: the path of a single-line text image file (as in the example below); or an array already read from an image file, of type `mx.nd.NDArray` or `np.ndarray`, with integer values in `[0, 255]` and shape `(height, width)` or `(height, width, channel)`. Without a channel dimension the input is treated as a grayscale image. The channel dimension may be `1` (grayscale) or `3` (color); a color image should be in `RGB` order.
- Return value: a `list`, such as `['你', '好']`.
**Example**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr_for_single_line('examples/rand_cn1.png')
print("Predicted Chars:", res)
```
or:
```python
import mxnet as mx
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'examples/rand_cn1.png'
img = mx.image.imread(img_fp, 1)
res = ocr.ocr_for_single_line(img)
print("Predicted Chars:", res)
```
For the image file [examples/rand_cn1.png](./examples/rand_cn1.png):
![examples/rand_cn1.png](./examples/rand_cn1.png)
the prediction is:
```bash
Predicted Chars: ['笠', '淡', '嘿', '骅', '谧', '鼎', '臭', '姚', '歼', '蠢', '驼', '耳', '裔', '挝', '涯', '狗', '蒽', '子', '犷']
```
#### 3. Function `CnOcr.ocr_for_single_lines(img_list)`
The function `CnOcr.ocr_for_single_lines(img_list)` **predicts a batch of single-line text images**. `CnOcr.ocr(img_fp)` and `CnOcr.ocr_for_single_line(img_fp)` both call `CnOcr.ocr_for_single_lines(img_list)` internally.
**Function description**
- Input parameter `img_list`: a `list` whose elements are arrays already read from image files, of type `mx.nd.NDArray` or `np.ndarray`, with integer values in `[0, 255]` and shape `(height, width)` or `(height, width, channel)`. Without a channel dimension the input is treated as a grayscale image. The channel dimension may be `1` (grayscale) or `3` (color); a color image should be in `RGB` order.
- Return value: a nested `list`, such as `[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]`.
**Example**
```python
import mxnet as mx
from cnocr import CnOcr
from cnocr.line_split import line_split
ocr = CnOcr()
img_fp = 'examples/multi-line_cn1.png'
img = mx.image.imread(img_fp, 1).asnumpy()
line_imgs = line_split(img, blank=True)
line_img_list = [line_img for line_img, _ in line_imgs]
res = ocr.ocr_for_single_lines(line_img_list)
print("Predicted Chars:", res)
```
For more detailed usage, see the test cases in [tests/test_cnocr.py](./tests/test_cnocr.py).
### Using the Script
You can also predict from the command line:
```bash
python scripts/cnocr_predict.py --file examples/multi-line_cn1.png
```
The output is the same as above.
### Training Your Own Model
cnocr works out of the box after installation, but if you **really** want to train your own model, use a command like:
```bash
python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr
```
Fine-tuning an existing model with specific data (continued training) is now also supported; use a command like:
```bash
python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr --load_epoch 20
```
See [scripts/run_cnocr_train.sh](./scripts/run_cnocr_train.sh) for more example commands.
## Future Work
* [x] Support images containing multiple lines of text (`Done`)
* [x] crnn model supports variable-width prediction, for more flexibility (`Done`)
* [x] Improve the test cases (`Doing`)
* [x] Fix bugs (the code is still somewhat messy) (`Doing`)
* [ ] Support `space` recognition (`V1.0.0` added spaces to the training set, but judging from the predictions, spaces are still not recognized)
* [ ] Try new models such as DenseNet and ResNet to further improve recognition accuracy
See the [中文README](./README.md) for the Chinese version.
# Update 2019.07.25: release cnocr V1.0.0
`cnocr` `v1.0.0` is released, and it is more efficient for prediction. **The new version of the model is not compatible with the previous version**, so if you are upgrading, please download the latest model files again. See below for the details (the procedure is the same as before).
Main changes are:
- **The new crnn model supports prediction for variable-width image files, so it is more efficient for prediction.**
- Support fine-tuning the existing model with specific data.
- Fix bugs, such as the training `accuracy` always being `0`.
- The `mxnet` dependency is upgraded from `1.3.1` to `1.4.1`.
# cnocr
A Python package for Chinese OCR, with trained models included,
so it can be used directly after installation.
The accuracy of the current crnn model is about `98.8%`.
The project originates from our own ([爱因互动 Ein+](https://einplus.cn)) internal needs.
Many thanks for the internal support.
## Changes
Most of the codes are adapted from [crnn-mxnet-chinese-text-recognition](https://github.com/diaomin/crnn-mxnet-chinese-text-recognition).
Much thanks to the author.
Some changes are:
* use the native MXNet CTC Loss instead of WarpCTC Loss. No more complicated installation.
* public pre-trained models for anyone. No more days of training.
* add online `predict` function and script. Easy to use.
## Installation
```bash
pip install cnocr
```
> Please use Python 3 (3.4, 3.5 and 3.6 should work). Python 2 is not tested.
## Usage
The first time cnocr is used, the model files will be downloaded automatically from
[Dropbox](https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=0) to `~/.cnocr`.
The zip file will be extracted and you can find the resulting model files in `~/.cnocr/models` by default.
If the automatic download fails, you can download the zip file manually
from [Baidu NetDisk](https://pan.baidu.com/s/1DWV3H2UWmzOU6d48UbTYVw) with extraction code `ss81`, and put the zip file in `~/.cnocr`. The code will handle the rest.
### Predict
Three functions are provided for prediction.
#### 1. `CnOcr.ocr(img_fp)`
The function `CnOcr.ocr(img_fp)` can recognize text in an image containing multiple lines of text (or a single line).
**Function Description**
- input parameter `img_fp`: image file path; or a color image as `mx.nd.NDArray` or `np.ndarray`, with shape `(height, width, 3)` and channels in RGB order.
- return: `List(List(Char))`, such as: `[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]`.
**Use Case**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr('examples/multi-line_cn1.png')
print("Predicted Chars:", res)
```
or:
```python
import mxnet as mx
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'examples/multi-line_cn1.png'
img = mx.image.imread(img_fp, 1)
res = ocr.ocr(img)
print("Predicted Chars:", res)
```
The code above recognizes the text in the image file [examples/multi-line_cn1.png](./examples/multi-line_cn1.png):
![examples/multi-line_cn1.png](./examples/multi-line_cn1.png)
The OCR result should be:
```bash
Predicted Chars: [['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'],
['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'],
['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '―', '这', '个', '账'],
['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'],
['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'],
['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'],
['等', '多', '种', '形', '式', '。']]
```
#### 2. `CnOcr.ocr_for_single_line(img_fp)`
If you know that the image you're predicting contains only one line of text, the function `CnOcr.ocr_for_single_line(img_fp)` can be used instead. Compared with `CnOcr.ocr()`, the result of `CnOcr.ocr_for_single_line()` is more reliable because no extra line-splitting step is required.
**Function Description**
- input parameter `img_fp`: image file path; or color image `mx.nd.NDArray` or `np.ndarray`, with shape `[height, width]` or `[height, width, channel]`. The optional channel should be `1` (gray image) or `3` (color image).
- return: `List(Char)`, such as: `['你', '好']`.
**Use Case**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr_for_single_line('examples/rand_cn1.png')
print("Predicted Chars:", res)
```
or:
```python
import mxnet as mx
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'examples/rand_cn1.png'
img = mx.image.imread(img_fp, 1)
res = ocr.ocr_for_single_line(img)
print("Predicted Chars:", res)
```
The code above recognizes the text in the image file [examples/rand_cn1.png](./examples/rand_cn1.png):
![examples/rand_cn1.png](./examples/rand_cn1.png)
The OCR result should be:
```bash
Predicted Chars: ['笠', '淡', '嘿', '骅', '谧', '鼎', '臭', '姚', '歼', '蠢', '驼', '耳', '裔', '挝', '涯', '狗', '蒽', '子', '犷']
```
#### 3. `CnOcr.ocr_for_single_lines(img_list)`
Function `CnOcr.ocr_for_single_lines(img_list)` can predict a batch of single-line-text image arrays. Actually, `CnOcr.ocr(img_fp)` and `CnOcr.ocr_for_single_line(img_fp)` both invoke `CnOcr.ocr_for_single_lines(img_list)` internally.
**Function Description**
- input parameter `img_list`: list of images, in which each element should be a line-image array of type `mx.nd.NDArray` or `np.ndarray`. Each element should be a tensor with integer values ranging from `0` to `255`, and with shape `[height, width]` or `[height, width, channel]`. The optional channel should be `1` (gray image) or `3` (color image).
- return: `List(List(Char))`, such as: `[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]`.
**Use Case**
```python
import mxnet as mx
from cnocr import CnOcr
from cnocr.line_split import line_split
ocr = CnOcr()
img_fp = 'examples/multi-line_cn1.png'
img = mx.image.imread(img_fp, 1).asnumpy()
line_imgs = line_split(img, blank=True)
line_img_list = [line_img for line_img, _ in line_imgs]
res = ocr.ocr_for_single_lines(line_img_list)
print("Predicted Chars:", res)
```
More use cases can be found at [tests/test_cnocr.py](./tests/test_cnocr.py).
### Using the Script
```bash
python scripts/cnocr_predict.py --file examples/multi-line_cn1.png
```
### (Not Necessary) Train
You can use the package without any training. But if you really want to train your own models, run:
```bash
python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr
```
Fine-tuning an existing model with specific data is also supported. Please refer to the following command:
```bash
python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr --load_epoch 20
```
More references can be found at [scripts/run_cnocr_train.sh](./scripts/run_cnocr_train.sh).
## Future Work
* [x] support multi-line-characters recognition (`Done`)
* [x] crnn model supports prediction for variable-width image files (`Done`)
* [x] Add Unit Tests (`Doing`)
* [x] Bugfixes (`Doing`)
* [ ] Support space recognition (tried, but not successful so far)
* [ ] Try other models such as DenseNet, ResNet
__version__ = '1.0.0'
__version__ = '1.1.0'
......@@ -21,13 +21,19 @@ import numpy as np
from PIL import Image
from cnocr.__version__ import __version__
from cnocr.consts import MODEL_EPOCE
from cnocr.consts import AVAILABLE_MODELS
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
from cnocr.fit.lstm import init_states
from cnocr.fit.ctc_metrics import CtcMetrics
from cnocr.data_utils.data_iter import SimpleBatch
from cnocr.symbols.crnn import crnn_lstm
from cnocr.utils import data_dir, get_model_file, read_charset, normalize_img_array
from cnocr.symbols.crnn import gen_network
from cnocr.utils import (
data_dir,
get_model_file,
read_charset,
normalize_img_array,
check_model_name,
)
from cnocr.line_split import line_split
......@@ -54,7 +60,6 @@ def rescale_img(img, hp):
img = mx.nd.array(img)
scale = hp.img_height / img.shape[0]
new_width = int(scale * img.shape[1])
hp._seq_length = new_width // 8
if len(img.shape) == 2: # mx.image.imresize needs the third dim
img = mx.nd.expand_dims(img, 2)
img = mx.image.imresize(img, w=new_width, h=hp.img_height).asnumpy()
......@@ -64,11 +69,13 @@ def rescale_img(img, hp):
def lstm_init_states(batch_size, hp):
""" Returns a tuple of names and zero arrays for LSTM init states"""
init_shapes = init_states(batch_size=batch_size, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden)
init_shapes = init_states(
batch_size=batch_size,
num_lstm_layer=hp.num_lstm_layer,
num_hidden=hp.num_hidden,
)
init_names = [s[0] for s in init_shapes]
init_arrays = [mx.nd.zeros(x[1]) for x in init_shapes]
# init_names.append('seq_length')
# init_arrays.append(hp.seq_length)
return init_names, init_arrays
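# Usage sketch (assuming the default `Hyperparams`, i.e. a 2-layer LSTM with
# 100 hidden units):
#   names, arrays = lstm_init_states(batch_size=1, hp=Hyperparams())
#   # `names` holds the init-state symbol names; `arrays` the matching zeros.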
......@@ -86,31 +93,62 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None):
pred_fc = sym.get_internals()['pred_fc_output']
sym = mx.sym.softmax(data=pred_fc)
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=data_names, label_names=None)
mod = mx.mod.Module(
symbol=sym, context=mx.cpu(), data_names=data_names, label_names=None
)
mod.bind(for_training=False, data_shapes=data_shapes)
mod.set_params(arg_params, aux_params, allow_missing=False)
return mod
class CnOcr(object):
MODEL_FILE_PREFIX = 'model-v{}'.format(__version__)
MODEL_FILE_PREFIX = 'cnocr-v{}'.format(__version__)
def __init__(
self,
model_name='conv-lite-fc',
model_epoch=None,
cand_alphabet=None,
root=data_dir(),
):
"""
def __init__(self, root=data_dir(), model_epoch=MODEL_EPOCE):
self._model_dir = os.path.join(root, 'models')
self._model_epoch = model_epoch
self._assert_and_prepare_model_files(root)
self._alphabet, _ = read_charset(os.path.join(self._model_dir, 'label_cn.txt'))
:param model_name: model name
:param model_epoch: model epoch (which checkpoint to load)
:param cand_alphabet: candidate set of characters to be recognized. Default `None`, meaning the set of recognizable characters is not restricted
:param root: root directory for the model files.
On Linux/Mac the default is `~/.cnocr`, so the model files sit in a folder like `~/.cnocr/1.1.0/conv-lite-fc-0027`.
On Windows the default is ``.
"""
check_model_name(model_name)
self._model_name = model_name
self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
self._model_epoch = model_epoch or AVAILABLE_MODELS[model_name][0]
root = os.path.join(root, __version__)
self._model_dir = os.path.join(root, self._model_name)
self._assert_and_prepare_model_files()
self._alphabet, inv_alph_dict = read_charset(
os.path.join(self._model_dir, 'label_cn.txt')
)
self._cand_alph_idx = None
if cand_alphabet is not None:
self._cand_alph_idx = [0] + [inv_alph_dict[word] for word in cand_alphabet]
self._cand_alph_idx.sort()
self._hp = Hyperparams()
self._hp._loss_type = None # infer mode
self._mod = self._get_module(self._hp)
self._mod = self._get_module()
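# Usage sketch for the new constructor (assuming the model files can be
# downloaded, or already sit under `~/.cnocr`); `cand_alphabet` restricts
# recognition to the given characters:
#   ocr = CnOcr(model_name='conv-lite-fc', cand_alphabet='0123456789')
#   res = ocr.ocr_for_single_line('examples/rand_cn1.png')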
def _assert_and_prepare_model_files(self, root):
def _assert_and_prepare_model_files(self):
model_dir = self._model_dir
model_files = ['label_cn.txt',
'%s-%04d.params' % (self.MODEL_FILE_PREFIX, self._model_epoch),
'%s-symbol.json' % self.MODEL_FILE_PREFIX]
model_files = [
'label_cn.txt',
'%s-%04d.params' % (self._model_file_prefix, self._model_epoch),
'%s-symbol.json' % self._model_file_prefix,
]
file_prepared = True
for f in model_files:
f = os.path.join(model_dir, f)
......@@ -121,17 +159,18 @@ class CnOcr(object):
if file_prepared:
return
if os.path.exists(model_dir):
os.removedirs(model_dir)
get_model_file(root)
get_model_file(model_dir)
def _get_module(self, hp):
network = crnn_lstm(hp)
prefix = os.path.join(self._model_dir, self.MODEL_FILE_PREFIX)
# import pdb; pdb.set_trace()
def _get_module(self):
network, self._hp = gen_network(self._model_name, self._hp)
hp = self._hp
prefix = os.path.join(self._model_dir, self._model_file_prefix)
data_names = ['data']
data_shapes = [(data_names[0], (hp.batch_size, 1, hp.img_height, hp.img_width))]
mod = load_module(prefix, self._model_epoch, data_names, data_shapes, network=network)
print('loading model parameters from dir %s' % self._model_dir)
mod = load_module(
prefix, self._model_epoch, data_names, data_shapes, network=network
)
return mod
def ocr(self, img_fp):
......@@ -141,7 +180,9 @@ class CnOcr(object):
:return: List(List(Char)), such as:
[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]
"""
if isinstance(img_fp, str) and os.path.isfile(img_fp):
if isinstance(img_fp, str):
if not os.path.isfile(img_fp):
raise FileNotFoundError(img_fp)
img = mx.image.imread(img_fp, 1).asnumpy()
elif isinstance(img_fp, mx.nd.NDArray):
img = img_fp.asnumpy()
......@@ -151,6 +192,8 @@ class CnOcr(object):
raise TypeError('Inappropriate argument type.')
if min(img.shape[0], img.shape[1]) < 2:
return ''
if img.mean() < 145:  # invert dark-background/light-text images into light-background/dark-text
img = 255 - img
line_imgs = line_split(img, blank=True)
line_img_list = [line_img for line_img, _ in line_imgs]
line_chars_list = self.ocr_for_single_lines(line_img_list)
......@@ -164,7 +207,9 @@ class CnOcr(object):
The optional channel should be 1 (gray image) or 3 (color image).
:return: character list, such as ['你', '好']
"""
if isinstance(img_fp, str) and os.path.isfile(img_fp):
if isinstance(img_fp, str):
if not os.path.isfile(img_fp):
raise FileNotFoundError(img_fp)
img = read_ocr_img(img_fp)
elif isinstance(img_fp, mx.nd.NDArray) or isinstance(img_fp, np.ndarray):
img = img_fp
......@@ -191,20 +236,30 @@ class CnOcr(object):
batch_size = len(img_list)
img_list, img_widths = self._pad_arrays(img_list)
# import pdb; pdb.set_trace()
sample = SimpleBatch(
data_names=['data'],
data=[mx.nd.array(img_list)])
sample = SimpleBatch(data_names=['data'], data=[mx.nd.array(img_list)])
prob = self._predict(sample)
prob = np.reshape(prob, (-1, batch_size, prob.shape[1])) # [seq_len, batch_size, num_classes]
# [seq_len, batch_size, num_classes]
prob = np.reshape(prob, (-1, batch_size, prob.shape[1]))
if self._cand_alph_idx is not None:
prob = prob * self._gen_mask(prob.shape)
max_width = max(img_widths)
res = []
for i in range(batch_size):
res.append(self._gen_line_pred_chars(prob[:, i, :], img_widths[i], max_width))
res.append(
self._gen_line_pred_chars(prob[:, i, :], img_widths[i], max_width)
)
return res
def _gen_mask(self, prob_shape):
mask_shape = list(prob_shape)
mask_shape[1] = 1
mask = np.zeros(mask_shape, dtype='int8')
mask[:, :, self._cand_alph_idx] = 1
return mask
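# A tiny numeric sketch of the masking above: probabilities of classes
# outside `self._cand_alph_idx` are zeroed before CTC decoding.
#   prob = np.full((2, 1, 5), 0.2)   # (seq_len, batch_size, num_classes)
#   mask = np.zeros((2, 1, 5), dtype='int8')
#   mask[:, :, [0, 3]] = 1           # keep the blank (0) and candidate 3
#   masked = prob * mask             # every other class now has prob 0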
def _preprocess_img_array(self, img):
"""
:param img: image array with type mx.nd.NDArray or np.ndarray,
......@@ -253,8 +308,6 @@ class CnOcr(object):
:return:
"""
class_ids = np.argmax(line_prob, axis=-1)
# idxs = list(zip(range(len(class_ids)), class_ids))
# probs = [line_prob[e[0], e[1]] for e in idxs]
if img_width < max_img_width:
comp_ratio = self._hp.seq_len_cmpr_ratio
......@@ -262,32 +315,7 @@ class CnOcr(object):
if end_idx < len(class_ids):
class_ids[end_idx:] = 0
prediction, start_end_idx = CtcMetrics.ctc_label(class_ids.tolist())
# print(start_end_idx)
alphabet = self._alphabet
res = [alphabet[p] for p in prediction]
res = [alphabet[p] if alphabet[p] != '<space>' else ' ' for p in prediction]
# res = self._insert_space_char(res, start_end_idx)
return res
def _insert_space_char(self, pred_chars, start_end_idx, min_interval=None):
if len(pred_chars) < 2:
return pred_chars
assert len(pred_chars) == len(start_end_idx)
if min_interval is None:
# automatically estimate the minimum gap
intervals = {start_end_idx[idx][0] - start_end_idx[idx-1][1] for idx in range(1, len(start_end_idx))}
if len(intervals) >= 3:
intervals = sorted(list(intervals))
if intervals[0] < 1:  # exclude zero-width gaps
intervals = intervals[1:]
min_interval = intervals[2]
else:
min_interval = start_end_idx[-1][1] # no space will be inserted
res_chars = [pred_chars[0]]
for idx in range(1, len(pred_chars)):
if start_end_idx[idx][0] - start_end_idx[idx-1][1] >= min_interval:
res_chars.append(' ')
res_chars.append(pred_chars[idx])
return res_chars
# coding: utf-8
import os
import string
from .__version__ import __version__
MODEL_BASE_URL = 'https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=1'
MODEL_EPOCE = 20
EMB_MODEL_TYPES = ['conv', 'conv-lite', 'densenet', 'densenet-lite']
SEQ_MODEL_TYPES = ['lstm', 'gru', 'fc']
ZIP_FILE_NAME = 'cnocr-models-v{}.zip'.format(__version__)
root_url = (
'https://raw.githubusercontent.com/breezedeus/cnocr-models/master/models/%s'
% __version__
)
# name: (epochs, url)
AVAILABLE_MODELS = {
'conv-lstm': (50, root_url + '/conv-lstm.zip'),
'conv-lite-lstm': (45, root_url + '/conv-lite-lstm.zip'),
'conv-lite-fc': (27, root_url + '/conv-lite-fc.zip'),
'densenet-lite-lstm': (42, root_url + '/densenet-lite-lstm.zip'),
'densenet-lite-fc': (32, root_url + '/densenet-lite-fc.zip'),
}
# candidate character sets
NUMBERS = string.digits + string.punctuation
ENG_LETTERS = string.digits + string.ascii_letters + string.punctuation
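# Usage sketch: look up the default epoch and download URL of a model.
#   epochs, url = AVAILABLE_MODELS['densenet-lite-fc']   # -> (32, '...zip')
# `NUMBERS` / `ENG_LETTERS` are candidate sets meant for
# `CnOcr(cand_alphabet=...)`, assuming all of their characters appear in the
# model's charset.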
import random
import numpy as np
from PIL import Image
from mxnet import nd
from mxnet.image import Augmenter
class GrayAug(Augmenter):
"""FIXME don't use this one"""
def __call__(self, img):
"""
:param img: nd.NDArray with shape [height, width, channel] and dtype 'float32'. channel should be 3.
:return: nd.NDArray with shape [height, width, 1] and dtype 'uint8'.
"""
if img.dtype != np.uint8:
img = img.astype('uint8')
# color to gray
img = np.array(Image.fromarray(img.asnumpy()).convert('L'))
return nd.expand_dims(nd.array(img, dtype='uint8'), 2)
class FgBgFlipAug(Augmenter):
"""前景色背景色对调。
Parameters
----------
p : float
Probability to swap the foreground and background colors
"""
def __init__(self, p):
super(FgBgFlipAug, self).__init__(p=p)
self.p = p
def __call__(self, src):
"""Augmenter body"""
if random.random() < self.p:
src = 255 - src
return src
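# Usage sketch: `mxnet.image.ImageIter` accepts custom augmenters through
# its `aug_list` argument, so roughly half of the images get inverted:
#   train_iter = ImageIter(..., aug_list=[FgBgFlipAug(p=0.5)])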
......@@ -3,9 +3,13 @@ from __future__ import print_function
import os
from PIL import Image
import numpy as np
from mxnet import nd
import mxnet as mx
import random
from mxnet.image import ImageIter
from mxnet.io import io
from ..utils import normalize_img_array
from .multiproc_data import MPData
......@@ -209,7 +213,7 @@ class MPOcrImages(object):
img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
img = np.array(img)
# print(img.shape)
img = np.transpose(img, (1, 0)) # res: [1, width, height]
img = np.transpose(img, (1, 0)) # res: [width, height]
img = normalize_img_array(img)
# print(np.mean(img), np.std(img))
# if len(img.shape) == 2:
......@@ -310,3 +314,48 @@ class OCRIter(mx.io.DataIter):
data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
yield data_batch
class GrayImageIter(ImageIter):
def __init__(self, batch_size, data_shape, **kwargs):
assert 'data_name' not in kwargs and 'label_name' not in kwargs
super(GrayImageIter, self).__init__(batch_size, data_shape, data_name='data', label_name='label', **kwargs)
self.provide_data = [('data', (batch_size, 1) + data_shape[1:])]
def next(self):
"""
:return: io.DataBatch, whose attribute `data` is an nd.NDArray
with shape [batch_size, 1, height, width] and dtype 'float32'.
"""
data_batch = super().next()
new_data = [self._post_process(sub_data) for sub_data in data_batch.data]
# data_names = ['data']
# label_names = ['label']
# return SimpleBatch(data_names, [new_data], label_names, data_batch.label)
return io.DataBatch(new_data, data_batch.label, pad=data_batch.pad)
@classmethod
def _post_process(cls, data):
"""
:param data: nd.NDArray with shape [batch_size, channel, height, width]. channel should be 3.
:return: nd.NDArray with shape [batch_size, 1, height, width] and dtype 'float32'.
"""
data_shape = list(data.shape)
data_shape[1] = 1 # [batch_size, 1, height, width]
new_data = nd.zeros(tuple(data_shape), dtype='float32')
batch_size = data.shape[0]
for i in range(batch_size):
img = data[i] # shape: [channel, height, width]
if img.dtype != np.uint8:
img = img.astype('uint8')
# color to gray
img = np.array(Image.fromarray(img.transpose((1, 2, 0)).asnumpy()).convert('L'))
img = normalize_img_array(img, dtype='float32')
new_data[i] = nd.expand_dims(nd.array(img), 0) # res shape: [1, height, width]
return new_data
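# Usage sketch (hypothetical paths): iterate grayscale batches from a
# RecordIO dataset produced by scripts/im2rec.py:
#   it = GrayImageIter(batch_size=128, data_shape=(3, 32, 280),
#                      label_width=20, path_imgrec='data/cnocr_train.rec',
#                      path_imgidx='data/cnocr_train.idx')
#   batch = it.next()   # batch.data[0].shape == (128, 1, 32, 280)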
......@@ -9,16 +9,18 @@ def _load_model(args):
assert args.prefix is not None
model_prefix = args.prefix
sym, arg_params, aux_params = mx.model.load_checkpoint(
model_prefix, args.load_epoch)
model_prefix, args.load_epoch
)
logging.info('Loaded model %s-%04d.params', model_prefix, args.load_epoch)
return sym, arg_params, aux_params
def fit(network, data_train, data_val, metrics, args, hp, data_names=None):
if args.gpu:
if args.gpu > 0:
contexts = [mx.context.gpu(i) for i in range(args.gpu)]
else:
contexts = [mx.context.cpu(i) for i in range(args.cpu)]
contexts = [mx.context.cpu()]
logging.info('hp: %s', hp)
sym, arg_params, aux_params = _load_model(args)
if sym is not None:
......@@ -27,48 +29,35 @@ def fit(network, data_train, data_val, metrics, args, hp, data_names=None):
os.makedirs(os.path.dirname(args.prefix))
module = mx.mod.Module(
symbol=network,
data_names=["data"] if data_names is None else data_names,
label_names=['label'],
context=contexts)
# from mxnet import nd
# import numpy as np
# data = nd.random.uniform(shape=(128, 1, 32, 100))
# label = np.random.randint(1, 11, size=(128, 4))
# module.bind(data_shapes=[('data', (128, 1, 32, 100))], label_shapes=[('label', (128, 4))])
# # e = module.bind()
# # f = e.forward(is_train=False)
# module.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))
# from ..data_utils.data_iter import SimpleBatch
# data_all = [data]
# label_all = [mx.nd.array(label)]
# # print(label_all[0])
# # data_names = ['data'] + init_state_names
# data_names = ['data']
# label_names = ['label']
#
# data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
# module.forward(data_batch)
# f = module.get_outputs()
# import pdb; pdb.set_trace()
symbol=network,
data_names=["data"] if data_names is None else data_names,
label_names=['label'],
context=contexts,
)
begin_epoch = args.load_epoch if args.load_epoch else 0
num_epoch = hp.num_epoch + begin_epoch
module.fit(train_data=data_train,
eval_data=data_val,
begin_epoch=begin_epoch,
num_epoch=num_epoch,
# use metrics.accuracy or metrics.accuracy_lcs
eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
optimizer='AdaDelta',
optimizer_params={'learning_rate': hp.learning_rate,
# 'momentum': hp.momentum,
'wd': 0.00001,
},
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
arg_params=arg_params,
aux_params=aux_params,
batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
)
\ No newline at end of file
optimizer_params = {
'learning_rate': hp.learning_rate,
# 'momentum': hp.momentum,
'wd': hp.wd,
}
if hp.clip_gradient is not None:
optimizer_params['clip_gradient'] = hp.clip_gradient
module.fit(
train_data=data_train,
eval_data=data_val,
begin_epoch=begin_epoch,
num_epoch=num_epoch,
# use metrics.accuracy or metrics.accuracy_lcs
eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
optimizer=hp.optimizer,
optimizer_params=optimizer_params,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
arg_params=arg_params,
aux_params=aux_params,
batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
)
......@@ -65,7 +65,8 @@ def lstm2(net, num_lstm_layer, num_hidden):
# import pdb; pdb.set_trace()
output = lstm(net) # res: `(sequence_length, batch_size, 2*num_hidden)`
# print('7', output.infer_shape()[1])
return mx.symbol.reshape(output, shape=(-3, -2)) # res: (bz * 35, c)
return output
# return mx.symbol.reshape(output, shape=(-3, -2)) # res: (bz * 35, c)
# - **out**: output tensor with shape `(sequence_length, batch_size, num_hidden)`
# when `layout` is "TNC". If `bidirectional` is True, output shape will instead
# be `(sequence_length, batch_size, 2*num_hidden)`
......
......@@ -7,41 +7,47 @@ class CnHyperparams(object):
"""
def __init__(self):
# Training hyper parameters
self._train_epoch_size = 2560000
self._eval_epoch_size = 3000
# self._train_epoch_size = 2560000
# self._eval_epoch_size = 3000
self._num_epoch = 20
self.optimizer = "Adam"
self._learning_rate = 0.001
self._momentum = 0.9
self._bn_mom = 0.9
self._workspace = 512
self._loss_type = "ctc" # ["warpctc" "ctc"]
self.wd = 0.00001
self.clip_gradient = None # `None`: don't use clip gradient
# self._momentum = 0.9
# self._bn_mom = 0.9
# self._workspace = 512
self._batch_size = 128
self._num_classes = 6426  # should be 6426 (not 5990)
self._img_width = 280
self._img_height = 32
# DenseNet hyper parameters
self._depth = 161
self._growrate = 32
self._reduction = 0.5
# LSTM hyper parameters
self.seq_model_type = 'lstm'
self._num_hidden = 100
self._num_lstm_layer = 2
# self._seq_length = 35
self.seq_len_cmpr_ratio = 8  # ratio by which the model compresses the image width (caused by its conv layers)
self._seq_length = self._img_width // self.seq_len_cmpr_ratio
self._num_label = 10
# ratio by which the model compresses the image width (caused by its conv layers); depends on the model
self.seq_len_cmpr_ratio = None
# sequence length; depends on the model
self._seq_length = None
self._num_label = 20
self._drop_out = 0.5
@property
def train_epoch_size(self):
return self._train_epoch_size
def __repr__(self):
return str(self.__dict__)
@property
def eval_epoch_size(self):
return self._eval_epoch_size
def set_seq_length(self, seq_len):
self._seq_length = seq_len
# @property
# def train_epoch_size(self):
# return self._train_epoch_size
#
# @property
# def eval_epoch_size(self):
# return self._eval_epoch_size
@property
def num_epoch(self):
......@@ -55,17 +61,17 @@ class CnHyperparams(object):
def momentum(self):
return self._momentum
@property
def bn_mom(self):
return self._bn_mom
@property
def workspace(self):
return self._workspace
# @property
# def bn_mom(self):
# return self._bn_mom
#
# @property
# def workspace(self):
# return self._workspace
@property
def loss_type(self):
return self._loss_type
return "ctc"
@property
def batch_size(self):
......
from __future__ import print_function
class Hyperparams(object):
"""
Hyperparameters for LSTM network
"""
def __init__(self):
# Training hyper parameters
self._train_epoch_size = 30000
self._eval_epoch_size = 3000
self._num_epoch = 20
self._learning_rate = 0.001
self._momentum = 0.9
self._bn_mom = 0.9
self._workspace = 512
self._loss_type = "ctc" # ["warpctc" "ctc"]
self._batch_size = 128
self._num_classes = 11
self._img_width = 100
self._img_height = 32
# DenseNet hyper parameters
self._depth = 161
self._growrate = 32
self._reduction = 0.5
# LSTM hyper parameters
self._num_hidden = 100
self._num_lstm_layer = 2
self.seq_len_cmpr_ratio = 8  # ratio by which the model compresses the image width (caused by its conv layers)
self._seq_length = self._img_width // self.seq_len_cmpr_ratio
self._num_label = 4
self._drop_out = 0.5
@property
def train_epoch_size(self):
return self._train_epoch_size
@property
def eval_epoch_size(self):
return self._eval_epoch_size
@property
def num_epoch(self):
return self._num_epoch
@property
def learning_rate(self):
return self._learning_rate
@property
def momentum(self):
return self._momentum
@property
def bn_mom(self):
return self._bn_mom
@property
def workspace(self):
return self._workspace
@property
def loss_type(self):
return self._loss_type
@property
def batch_size(self):
return self._batch_size
@property
def num_classes(self):
return self._num_classes
@property
def img_width(self):
return self._img_width
@property
def img_height(self):
return self._img_height
@property
def depth(self):
return self._depth
@property
def growrate(self):
return self._growrate
@property
def reduction(self):
return self._reduction
@property
def num_hidden(self):
return self._num_hidden
@property
def num_lstm_layer(self):
return self._num_lstm_layer
@property
def seq_length(self):
return self._seq_length
@property
def num_label(self):
return self._num_label
@property
def dropout(self):
return self._drop_out
......@@ -20,51 +20,104 @@ LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick Haffner.
Gradient-based learning applied to document recognition.
Proceedings of the IEEE (1998)
"""
from copy import deepcopy
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.rnn.rnn_layer import LSTM, GRU
from .densenet import DenseNet
from ..fit.ctc_loss import add_ctc_loss
from ..fit.lstm import lstm2
def crnn_no_lstm(hp):
def gen_network(model_name, hp):
hp = deepcopy(hp)
hp.seq_model_type = model_name.rsplit('-', maxsplit=1)[-1]
model_name = model_name.lower()
if model_name.startswith('densenet'):
hp.seq_len_cmpr_ratio = 4
hp.set_seq_length(hp.img_width // 4 - 1)
layer_channels = (
(32, 64, 128, 256)
if model_name.startswith('densenet-lite')
else (64, 128, 256, 512)
)
densenet = DenseNet(layer_channels)
model = CRnn(hp, densenet)
elif model_name.startswith('conv-lite'):
hp.seq_len_cmpr_ratio = 4
hp.set_seq_length(hp.img_width // 4 - 1)
model = lambda data: crnn_lstm_lite(hp, data)
elif model_name.startswith('conv'):
hp.seq_len_cmpr_ratio = 8
hp.set_seq_length(hp.img_width // 8)
model = lambda data: crnn_lstm(hp, data)
else:
raise NotImplementedError('bad model_name: %s' % model_name)
# input
data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
return pipline(model, hp), hp
kernel_size = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
padding_size = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
layer_size = [min(32*2**(i+1), 512) for i in range(len(kernel_size))]
def convRelu(i, input_data, bn=True):
layer = mx.symbol.Convolution(name='conv-%d' % i, data=input_data, kernel=kernel_size[i], pad=padding_size[i],
num_filter=layer_size[i])
if bn:
layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
layer = mx.sym.LeakyReLU(data=layer,name='leakyrelu-%d' % i)
return layer
def get_infer_shape(sym_model, hp):
init_states = {
'data': (hp.batch_size, 1, hp.img_height, hp.img_width),
'label': (hp.batch_size, hp.num_label),
}
internals = sym_model.get_internals()
_, out_shapes, _ = internals.infer_shape(**init_states)
shape_dict = dict(zip(internals.list_outputs(), out_shapes))
return shape_dict
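# Usage sketch (assuming the default hyperparameters build the CTC training
# graph): print the output shape of every internal node.
#   from ..hyperparams.cn_hyperparams import CnHyperparams
#   sym, hp = gen_network('conv-lite-fc', CnHyperparams())
#   for name, shape in get_infer_shape(sym, hp).items():
#       print(name, shape)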
net = convRelu(0, data) # bz x f x 32 x 200
max = mx.sym.Pooling(data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2))
avg = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2))
net = max - avg # 16 x 100
net = convRelu(1, net)
net = mx.sym.Pooling(data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2)) # bz x f x 8 x 50
net = convRelu(2, net, True)
net = convRelu(3, net)
net = mx.sym.Pooling(data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 2)) # bz x f x 4 x 25
net = convRelu(4, net, True)
net = convRelu(5, net)
net = mx.symbol.Pooling(data=net, kernel=(4, 1), pool_type='avg', name='pool1') # bz x f x 1 x 25
if hp.dropout > 0:
net = mx.symbol.Dropout(data=net, p=hp.dropout)
def gen_seq_model(hp):
if hp.seq_model_type.lower() == 'lstm':
seq_model = LSTM(hp.num_hidden, hp.num_lstm_layer, bidirectional=True)
elif hp.seq_model_type.lower() == 'gru':
seq_model = GRU(hp.num_hidden, hp.num_lstm_layer, bidirectional=True)
else:
def fc_seq_model(data):
fc = mx.sym.FullyConnected(
data, num_hidden=2 * hp.num_hidden, flatten=False, name='seq-fc'
)
net = mx.sym.Activation(data=fc, act_type='relu', name='seq-relu')
return net
seq_model = fc_seq_model
return seq_model
class CRnn(nn.HybridBlock):
def __init__(self, hp, emb_model, **kw):
super().__init__(**kw)
self.hp = hp
self.emb_model = emb_model
self.dropout = nn.Dropout(hp.dropout)
self.seq_model = gen_seq_model(hp)
def hybrid_forward(self, F, X):
embs = self.emb_model(X) # res: bz x emb_size x 1 x seq_len
hp = self.hp
if hp.dropout > 0:
embs = self.dropout(embs)
embs = F.squeeze(embs, axis=2) # res: bz x emb_size x seq_len
embs = F.transpose(embs, axes=(2, 0, 1)) # res: seq_len x bz x emb_size
# res: `(sequence_length, batch_size, 2*num_hidden)`
return self.seq_model(embs)
net = mx.sym.transpose(data=net, axes=[1,0,2,3]) # f x bz x 1 x 25
net = mx.sym.flatten(data=net) # f x (bz x 25)
hidden_concat = mx.sym.transpose(data=net, axes=[1,0]) # (bz x 25) x f
# mx.sym.transpose(net, [])
pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes) # (bz x 25) x num_classes
def pipline(model, hp, data=None):
# build the complete computation graph (training, or inference when `hp.loss_type` is empty)
data = data if data is not None else mx.sym.Variable('data')
output = model(data)
output = mx.symbol.reshape(output, shape=(-3, -2)) # res: (seq_len * bz, c)
pred = mx.sym.FullyConnected(
data=output, num_hidden=hp.num_classes, name='pred_fc'
) # (bz x 35) x num_classes
# print('pred', pred.infer_shape()[1])
# import pdb; pdb.set_trace()
if hp.loss_type:
# Training mode, add loss
......@@ -74,26 +127,87 @@ def crnn_no_lstm(hp):
return mx.sym.softmax(data=pred, name='softmax')
def crnn_lstm(hp):
def convRelu(i, input_data, kernel_size, layer_size, padding_size, bn=True):
layer = mx.symbol.Convolution(
name='conv-%d' % i,
data=input_data,
kernel=kernel_size,
pad=padding_size,
num_filter=layer_size,
)
# in_channel = input_data.infer_shape()[1][0][1]
# num_params = in_channel * kernel_size[0] * kernel_size[1] * layer_size
# print('number of conv-%d layer parameters: %d' % (i, num_params))
if bn:
layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d' % i)
# layer = mx.symbol.Convolution(name='conv-%d-1x1' % i, data=layer, kernel=(1, 1), pad=(0, 0),
# num_filter=layer_size[i])
# if bn:
# layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d-1x1' % i)
# layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d-1x1' % i)
return layer
# input
data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
# data = mx.sym.Variable('data', shape=(128, 1, 32, 100))
# label = mx.sym.Variable('label', shape=(128, 4))
def bottle_conv(i, input_data, kernel_size, layer_size, padding_size, bn=True):
bottle_channel = layer_size // 2
layer = mx.symbol.Convolution(
name='conv-%d-1-1x1' % i,
data=input_data,
kernel=(1, 1),
pad=(0, 0),
num_filter=bottle_channel,
)
layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d-1' % i)
layer = mx.symbol.Convolution(
name='conv-%d' % i,
data=layer,
kernel=kernel_size,
pad=padding_size,
num_filter=bottle_channel,
)
layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d-2' % i)
layer = mx.symbol.Convolution(
name='conv-%d-2-1x1' % i,
data=layer,
kernel=(1, 1),
pad=(0, 0),
num_filter=layer_size,
)
# in_channel = input_data.infer_shape()[1][0][1]
# num_params = in_channel * bottle_channel
# num_params += bottle_channel * kernel_size[0] * kernel_size[1] * bottle_channel
# num_params += bottle_channel * layer_size
# print('number of bottle-conv-%d layer parameters: %d' % (i, num_params))
if bn:
layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d' % i)
return layer
def crnn_lstm(hp, data):
kernel_size = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
padding_size = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
layer_size = [min(32*2**(i+1), 512) for i in range(len(kernel_size))]
layer_size = [min(32 * 2 ** (i + 1), 512) for i in range(len(kernel_size))]
def convRelu(i, input_data, bn=True):
layer = mx.symbol.Convolution(name='conv-%d' % i, data=input_data, kernel=kernel_size[i], pad=padding_size[i],
num_filter=layer_size[i])
layer = mx.symbol.Convolution(
name='conv-%d' % i,
data=input_data,
kernel=kernel_size[i],
pad=padding_size[i],
num_filter=layer_size[i],
)
if bn:
layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
layer = mx.sym.LeakyReLU(data=layer,name='leakyrelu-%d' % i)
layer = mx.symbol.Convolution(name='conv-%d-1x1' % i, data=layer, kernel=(1, 1), pad=(0, 0),
num_filter=layer_size[i])
layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d' % i)
layer = mx.symbol.Convolution(
name='conv-%d-1x1' % i,
data=layer,
kernel=(1, 1),
pad=(0, 0),
num_filter=layer_size[i],
)
if bn:
layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d-1x1' % i)
layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d-1x1' % i)
......@@ -101,63 +215,98 @@ def crnn_lstm(hp):
net = convRelu(0, data) # bz x f x 32 x 280
# print('0', net.infer_shape()[1])
max = mx.sym.Pooling(data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2))
avg = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2))
net = convRelu(1, net)
max = mx.sym.Pooling(
data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2)
)
avg = mx.sym.Pooling(
data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2)
)
net = max - avg # 8 x 70
net = convRelu(1, net)
# print('2', net.infer_shape()[1])
net = mx.sym.Pooling(data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2)) # res: bz x f x 8 x 70
net = mx.sym.Pooling(
data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2)
) # res: bz x f x 8 x 70
# print('3', net.infer_shape()[1])
net = convRelu(2, net, True)
net = convRelu(3, net)
net = mx.sym.Pooling(data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 2)) # res: bz x f x 4 x 35
net = mx.sym.Pooling(
data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 2)
) # res: bz x f x 4 x 35
# print('4', net.infer_shape()[1])
net = convRelu(4, net, True)
net = convRelu(5, net)
net = mx.symbol.Pooling(data=net, kernel=(4, 1), pool_type='avg', name='pool1') # res: bz x f x 1 x 35
net = mx.symbol.Pooling(
data=net, kernel=(4, 1), pool_type='avg', name='pool1'
) # res: bz x f x 1 x 35
# print('5', net.infer_shape()[1])
if hp.dropout > 0:
net = mx.symbol.Dropout(data=net, p=hp.dropout)
hidden_concat = lstm2(net, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden)
# import pdb; pdb.set_trace()
net = mx.symbol.squeeze(net, axis=2) # res: bz x emb_size x seq_len
net = mx.symbol.transpose(net, axes=(2, 0, 1))
# mx.sym.transpose(net, [])
pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes, name='pred_fc') # (bz x 35) x num_classes
# print('pred', pred.infer_shape()[1])
seq_model = gen_seq_model(hp)
hidden_concat = seq_model(net)
if hp.loss_type:
# Training mode, add loss
return add_ctc_loss(pred, hp.seq_length, hp.num_label, hp.loss_type)
# else:
# # Inference mode, add softmax
# return mx.sym.softmax(data=pred, name='softmax')
return hidden_concat
from ..hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
if __name__ == '__main__':
hp = Hyperparams()
init_states = {}
init_states['data'] = (hp.batch_size, 1, hp.img_height, hp.img_width)
init_states['label'] = (hp.batch_size, hp.num_label)
def crnn_lstm_lite(hp, data):
kernel_size = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
padding_size = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
layer_size = [min(32 * 2 ** (i + 1), 512) for i in range(len(kernel_size))]
# init_c = {('l%d_init_c' % l): (hp.batch_size, hp.num_hidden) for l in range(hp.num_lstm_layer*2)}
# init_h = {('l%d_init_h' % l): (hp.batch_size, hp.num_hidden) for l in range(hp.num_lstm_layer*2)}
#
# for item in init_c:
# init_states[item] = init_c[item]
# for item in init_h:
# init_states[item] = init_h[item]
net = convRelu(
0, data, kernel_size[0], layer_size[0], padding_size[0]
) # bz x 64 x 32 x 280
# print('0', net.infer_shape()[1])
net = convRelu(
1, net, kernel_size[1], layer_size[1], padding_size[1], True
) # bz x 128 x 16 x 140
# print('1', net.infer_shape()[1])
net = mx.sym.Pooling(
data=net, name='pool-0', pool_type='max', kernel=(2, 2), stride=(2, 2)
)
# avg = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2))
# net = max - avg # bz x 64 x 16 x 140
# print('2', net.infer_shape()[1])
# res: bz x 128 x 8 x 70
# net = mx.sym.Pooling(data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2))
net = convRelu(
2, net, kernel_size[2], layer_size[2], padding_size[2]
) # res: bz x 256 x 8 x 70
# print('3', net.infer_shape()[1])
net = convRelu(
3, net, kernel_size[3], layer_size[3], padding_size[3], True
) # res: bz x 512 x 8 x 70
# res: bz x 512 x 4 x 35
x = net = mx.sym.Pooling(
data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2)
)
# print('4', net.infer_shape()[1])
net = bottle_conv(4, net, kernel_size[4], layer_size[4], padding_size[4])
net = bottle_conv(5, net, kernel_size[5], layer_size[5], padding_size[5], True) + x
# res: bz x 512 x 1 x 35; the height becomes 1 because the pooling uses no padding
net = mx.symbol.Pooling(
data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 1)
)
# print('5', net.infer_shape()[1])
# net = mx.symbol.Convolution(name='conv-%d' % 6, data=net, kernel=(4, 1), num_filter=layer_size[5])
net = bottle_conv(6, net, (4, 1), layer_size[5], (0, 0))
# print('6', net.infer_shape()[1])
# num_params = layer_size[5] * 4 * 1 * layer_size[5]
# print('number of conv-%d layer parameters: %d' % (6, num_params))
symbol = crnn_no_lstm(hp)
interals = symbol.get_internals()
_, out_shapes, _ = interals.infer_shape(**init_states)
shape_dict = dict(zip(interals.list_outputs(), out_shapes))
if hp.dropout > 0:
net = mx.symbol.Dropout(data=net, p=hp.dropout)
for item in shape_dict:
print(item,shape_dict[item])
net = mx.symbol.squeeze(net, axis=2) # res: bz x emb_size x seq_len
net = mx.symbol.transpose(net, axes=(2, 0, 1))
seq_model = gen_seq_model(hp)
hidden_concat = seq_model(net)
# print('sequence length:', hp.seq_length)
return hidden_concat
"""From DenseNet in Gluon.
from gluoncv.model_zoo.densenet import DenseNet
"""
import logging
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import HybridConcurrent, Identity
logger = logging.getLogger(__name__)
def cal_num_params(net):
import numpy as np
params = [p for p in net.collect_params().values()]
for p in params:
logger.info(p)
total = sum([np.prod(p.shape) for p in params])
logger.info(f'total params: {total}')
return total
# Helpers
def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index):
out = nn.HybridSequential(prefix='stage%d_' % stage_index)
with out.name_scope():
for _ in range(num_layers):
out.add(_make_dense_layer(growth_rate, bn_size, dropout))
return out
def _make_dense_layer(growth_rate, bn_size, dropout):
"""Convolutional(1x1) --> Convolutional(3x3)"""
new_features = nn.HybridSequential(prefix='')
new_features.add(nn.BatchNorm())
new_features.add(nn.Activation('relu'))
new_features.add(nn.Conv2D(bn_size * growth_rate, kernel_size=1, use_bias=False))
new_features.add(nn.BatchNorm())
new_features.add(nn.Activation('relu'))
new_features.add(nn.Conv2D(growth_rate, kernel_size=3, padding=1, use_bias=False))
if dropout:
new_features.add(nn.Dropout(dropout))
out = _make_residual(new_features)
return out
def _make_residual(cell_net):
out = HybridConcurrent(axis=1, prefix='')
# add the original channels back, so the final out_channel = in_channel + net.out_channel
out.add(Identity())
out.add(cell_net)
return out
def _make_transition(num_output_features, strides=2):
out = nn.HybridSequential(prefix='')
out.add(nn.BatchNorm())
out.add(nn.Activation('relu'))
out.add(nn.Conv2D(num_output_features, kernel_size=1, use_bias=False))
out.add(nn.MaxPool2D(pool_size=2, strides=strides))
return out
class DenseNet(HybridBlock):
r"""Densenet-BC model from the
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.
Parameters
----------
layer_channels : tuple of int
Number of output channels of each of the four stages,
e.g. (32, 64, 128, 256) for the lite variant and (64, 128, 256, 512) otherwise.
"""
def __init__(self, layer_channels, **kwargs):
assert len(layer_channels) == 4
super(DenseNet, self).__init__(**kwargs)
with self.name_scope():
# Stage 0
self.features = nn.HybridSequential(prefix='')
self.features.add(_make_first_stage_net((layer_channels[0], layer_channels[1])))
self.features.add(_make_transition(layer_channels[1]))
# self.features.add(nn.Conv2D(num_init_features, kernel_size=3,
# strides=1, padding=1, use_bias=False))
# self.features.add(nn.BatchNorm())
# self.features.add(nn.Activation('relu'))
# self.features.add(nn.MaxPool2D(pool_size=2, strides=2))
# Add dense blocks
# Stage 1
self.features.add(_make_inter_stage_net(1, num_layers=2, growth_rate=layer_channels[0]))
self.features.add(_make_transition(layer_channels[2]))
# Stage 2
self.features.add(_make_inter_stage_net(2, num_layers=2, growth_rate=layer_channels[1]))
self.features.add(_make_transition(layer_channels[3], strides=(2, 1)))
# self.features.add(nn.MaxPool2D(pool_size=2, strides=(2, 1)))
# self.features.add(_make_transition(512))
# Stage 3
self.features.add(_make_final_stage_net(3, out_channels=layer_channels[3]))
# num_features = num_init_features
# for i, num_layers in enumerate(block_config):
# self.features.add(_make_dense_block(num_layers, bn_size, growth_rate, dropout, i+1))
# num_features = num_features + num_layers * growth_rate
# if i != len(block_config) - 1:
# self.features.add(_make_transition(num_features // 2))
# num_features = num_features // 2
# self.features.add(nn.BatchNorm())
# self.features.add(nn.Activation('relu'))
# self.features.add(nn.AvgPool2D(pool_size=7))
# self.features.add(nn.Flatten())
# self.output = nn.Dense(classes)
def hybrid_forward(self, F, x):
x = self.features(x)
# x = self.output(x)
return x
def _make_first_stage_net(out_channels):
features = nn.HybridSequential(prefix='stage%d_' % 0)
with features.name_scope():
features.add(nn.Conv2D(out_channels[0], kernel_size=3,
strides=1, padding=1, use_bias=False))
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
features.add(nn.Conv2D(out_channels[1], kernel_size=3,
strides=1, padding=1, use_bias=False))
# features.add(nn.BatchNorm())
# features.add(nn.Activation('relu'))
return _make_residual(features)
def _make_inter_stage_net(stage_index, num_layers=2, growth_rate=128):
return _make_dense_block(num_layers, bn_size=2, growth_rate=growth_rate, dropout=0.0, stage_index=stage_index)
def _make_final_stage_net(stage_index, out_channels):
features = nn.HybridSequential(prefix='stage%d_' % stage_index)
with features.name_scope():
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
features.add(nn.Conv2D(out_channels // 4, kernel_size=1, use_bias=False))
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
features.add(nn.Conv2D(out_channels, kernel_size=(4, 1), use_bias=False))
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
return features
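# Usage sketch: build the backbone used by the 'densenet-lite-*' models and
# push a dummy 32x280 grayscale batch through it:
#   import mxnet as mx
#   net = DenseNet((32, 64, 128, 256))
#   net.initialize()
#   out = net(mx.nd.zeros((1, 1, 32, 280)))  # expected shape: (1, 256, 1, 69)
#   cal_num_params(net)                      # log the parameter count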
......@@ -18,10 +18,9 @@
import os
import platform
import zipfile
import numpy as np
from mxnet.gluon.utils import download
from .consts import MODEL_BASE_URL, ZIP_FILE_NAME
from .consts import AVAILABLE_MODELS, EMB_MODEL_TYPES, SEQ_MODEL_TYPES
def data_dir_default():
......@@ -44,7 +43,13 @@ def data_dir():
return os.getenv('CNOCR_HOME', data_dir_default())
def get_model_file(root=data_dir()):
def check_model_name(model_name):
emb_model_type, seq_model_type = model_name.rsplit('-', maxsplit=1)
assert emb_model_type in EMB_MODEL_TYPES
assert seq_model_type in SEQ_MODEL_TYPES
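# e.g. 'densenet-lite-fc' splits into emb type 'densenet-lite' and seq type
# 'fc', so:
#   check_model_name('densenet-lite-fc')   # passes silently
#   check_model_name('resnet-fc')          # raises AssertionError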
def get_model_file(model_dir):
r"""Return location for the downloaded models on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
......@@ -52,7 +57,7 @@ def get_model_file(root=data_dir()):
Parameters
----------
root : str, default $CNOCR_HOME
model_dir : str, default $CNOCR_HOME
Location for keeping the model parameters.
Returns
......@@ -60,18 +65,22 @@ def get_model_file(root=data_dir()):
file_path
Path to the requested pretrained model file.
"""
root = os.path.expanduser(root)
os.makedirs(root, exist_ok=True)
model_dir = os.path.expanduser(model_dir)
par_dir = os.path.dirname(model_dir)
os.makedirs(par_dir, exist_ok=True)
zip_file_path = os.path.join(root, ZIP_FILE_NAME)
zip_file_path = model_dir + '.zip'
if not os.path.exists(zip_file_path):
download(MODEL_BASE_URL, path=zip_file_path, overwrite=True)
model_name = os.path.basename(model_dir)
if model_name not in AVAILABLE_MODELS:
raise NotImplementedError('%s is not an available model' % model_name)
url = AVAILABLE_MODELS[model_name][1]
download(url, path=zip_file_path, overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(root)
zf.extractall(par_dir)
os.remove(zip_file_path)
return os.path.join(root, 'models')
return model_dir
def read_charset(charset_fp):
......@@ -81,12 +90,18 @@ def read_charset(charset_fp):
for line in fp:
alphabet.append(line.rstrip('\n'))
# print('Alphabet size: %d' % len(alphabet))
try:
space_idx = alphabet.index('<space>')
alphabet[space_idx] = ' '
except ValueError:
pass
inv_alph_dict = {_char: idx for idx, _char in enumerate(alphabet)}
# inv_alph_dict[' '] = inv_alph_dict['<space>'] # 对应空格
return alphabet, inv_alph_dict
def normalize_img_array(img):
def normalize_img_array(img, dtype='float32'):
""" rescale to [-1.0, 1.0] """
# return (img / 255.0 - 0.5) * 2
return (img - np.mean(img)) / (np.std(img) + 1e-6)
img = img.astype(dtype)
# return (img - np.mean(img, dtype=dtype)) / 255.0
return img / 255.0
# return (img - np.median(img)) / (np.std(img, dtype=dtype) + 1e-6)  # after this transform some images become unrecognizable
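# Usage sketch: a uint8 image is cast to float32 and scaled into [0.0, 1.0].
#   import numpy as np
#   normalize_img_array(np.array([[0, 128, 255]], dtype='uint8'))
#   # -> array([[0. , 0.502, 1. ]], dtype=float32)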
0 0
1 5276
2 10092
3 14816
4 22952
5 28336
6 34228
7 39028
0 3814 6028 1206 304 3074 1706 3426 2705 3780 2750 0 0 0 0 0 0 0 0 0 0 chn-00199980.jpg
1 1249 601 1588 1394 71 6198 482 238 3605 2648 0 0 0 0 0 0 0 0 0 0 chn-00199981.jpg
2 1575 4639 3279 2798 4389 1407 2756 2496 526 3174 0 0 0 0 0 0 0 0 0 0 chn-00199984.jpg
3 401 6415 3804 4895 1133 2139 4923 35 4850 6228 0 0 0 0 0 0 0 0 0 0 chn-00199985.jpg
4 242 242 106 111 456 6425 81 377 456 270 86 242 106 111 139 6425 111 47 107 146 00199971.jpg
5 111 344 205 248 270 106 450 139 6425 518 146 377 248 86 47 344 6425 107 139 86 00199975.jpg
6 1656 81 377 111 107 6425 377 111 450 107 558 107 695 242 558 6425 894 107 2143 377 00199978.jpg
7 107 450 47 6425 518 106 139 139 106 81 111 106 111 456 6425 469 47 111 107 86 00199979.jpg
8 6018 1768 4847 1135 1527 1741 292 5226 2837 3127 0 0 0 0 0 0 0 0 0 0 chn-00199989.jpg
9 107 423 47 146 6425 450 146 81 372 377 242 456 107 86 106 81 111 139 6425 139 00199980.jpg
10 205 106 111 139 86 106 111 248 86 106 780 47 111 47 139 139 6425 47 111 344 00199985.jpg
chn-00199980.jpg 3814 6028 1206 304 3074 1706 3426 2705 3780 2750
chn-00199981.jpg 1249 601 1588 1394 71 6198 482 238 3605 2648
chn-00199984.jpg 1575 4639 3279 2798 4389 1407 2756 2496 526 3174
chn-00199985.jpg 401 6415 3804 4895 1133 2139 4923 35 4850 6228
00199971.jpg 242 242 106 111 456 6425 81 377 456 270 86 242 106 111 139 6425 111 47 107 146
00199975.jpg 111 344 205 248 270 106 450 139 6425 518 146 377 248 86 47 344 6425 107 139 86
00199978.jpg 1656 81 377 111 107 6425 377 111 450 107 558 107 695 242 558 6425 894 107 2143 377
00199979.jpg 107 450 47 6425 518 106 139 139 106 81 111 106 111 456 6425 469 47 111 107 86
chn-00199989.jpg 6018 1768 4847 1135 1527 1741 292 5226 2837 3127
00199980.jpg 107 423 47 146 6425 450 146 81 372 377 242 456 107 86 106 81 111 139 6425 139
00199985.jpg 205 106 111 139 86 106 111 248 86 106 780 47 111 47 139 139 6425 47 111 344
from __future__ import print_function
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr.utils import data_dir, read_charset
BAD_CHARS = [5751, 5539, 5536, 5535, 5464, 4105]
def main():
charset_fp = os.path.join(data_dir(), 'models', 'label_cn.txt')
alphabet, inv_alph_dict = read_charset(charset_fp)
for idx in BAD_CHARS:
print('idx: {}, char: {}'.format(idx, alphabet[idx]))
if __name__ == '__main__':
main()
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" An example of predicting CAPTCHA image data with a LSTM network pre-trained with a CTC loss"""
from __future__ import print_function
import sys
import os
import time
import argparse
from operator import itemgetter
from pathlib import Path
from collections import Counter
import mxnet as mx
import Levenshtein
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr import CnOcr
def evaluate():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-name", help="model name", type=str, default='densenet-lite-lstm'
)
parser.add_argument("--model-epoch", type=int, default=None, help="model epoch")
parser.add_argument(
"-i",
"--input-fp",
default='test.txt',
help="the file path with image names and labels",
)
parser.add_argument(
"--image-prefix-dir", default='.', help="directory containing the images; image paths recorded in the index file are relative to it"
)
parser.add_argument("--batch-size", type=int, default=128, help="batch size")
parser.add_argument(
"-v",
"--verbose",
action='store_true',
help="whether to print details to screen",
)
parser.add_argument(
"-o",
"--output-dir",
default=False,
help="the output directory which records the analysis results",
)
args = parser.parse_args()
ocr = CnOcr(model_name=args.model_name, model_epoch=args.model_epoch)
alphabet = ocr._alphabet
fn_labels_list = read_input_file(args.input_fp)
miss_cnt, redundant_cnt = Counter(), Counter()
model_time_cost = 0.0
start_idx = 0
bad_cnt = 0
badcases = []
while start_idx < len(fn_labels_list):
print('start_idx: ', start_idx)
batch = fn_labels_list[start_idx : start_idx + args.batch_size]
batch_img_fns = []
batch_labels = []
batch_imgs = []
for fn, labels in batch:
batch_labels.append(labels)
img_fp = os.path.join(args.image_prefix_dir, fn)
batch_img_fns.append(img_fp)
img = mx.image.imread(img_fp, 1).asnumpy()
batch_imgs.append(img)
start_time = time.time()
batch_preds = ocr.ocr_for_single_lines(batch_imgs)
model_time_cost += time.time() - start_time
for bad_info in compare_preds_to_reals(
batch_preds, batch_labels, batch_img_fns, alphabet
):
if args.verbose:
print('\t'.join(bad_info))
distance = Levenshtein.distance(bad_info[1], bad_info[2])
bad_info.insert(0, distance)
badcases.append(bad_info)
miss_cnt.update(list(bad_info[-2]))
redundant_cnt.update(list(bad_info[-1]))
bad_cnt += 1
start_idx += args.batch_size
badcases.sort(key=itemgetter(0), reverse=True)
output_dir = Path(args.output_dir)
if not output_dir.exists():
os.makedirs(output_dir)
with open(output_dir / 'badcases.txt', 'w') as f:
f.write(
'\t'.join(
[
'distance',
'image_fp',
'real_words',
'pred_words',
'miss_words',
'redundant_words',
]
)
+ '\n'
)
for bad_info in badcases:
f.write('\t'.join(map(str, bad_info)) + '\n')
with open(output_dir / 'miss_words_stat.txt', 'w') as f:
for word, num in miss_cnt.most_common():
f.write('\t'.join([word, str(num)]) + '\n')
with open(output_dir / 'redundant_words_stat.txt', 'w') as f:
for word, num in redundant_cnt.most_common():
f.write('\t'.join([word, str(num)]) + '\n')
print(
"number of total cases: %d, time cost per image: %f, number of bad cases: %d"
% (len(fn_labels_list), model_time_cost / len(fn_labels_list), bad_cnt)
)
def read_input_file(in_fp):
fn_labels_list = []
with open(in_fp) as f:
for line in f:
fields = line.strip().split()
fn_labels_list.append((fields[0], fields[1:]))
return fn_labels_list
def compare_preds_to_reals(batch_preds, batch_reals, batch_img_fns, alphabet):
for preds, reals, img_fn in zip(batch_preds, batch_reals, batch_img_fns):
reals = [alphabet[int(_id)] for _id in reals if _id != '0'] # '0' is padding id
if preds == reals:
continue
preds_set, reals_set = set(preds), set(reals)
miss_words = reals_set.difference(preds_set)
redundant_words = preds_set.difference(reals_set)
yield [
img_fn,
''.join(reals),
''.join(preds),
''.join(miss_words),
''.join(redundant_words),
]
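# Illustrative example: with reals being the characters of '你好世界' and preds
# those of '你好世间', miss_words={'界'} and redundant_words={'间'}. Both are
# character sets, so duplicates and ordering are intentionally ignored.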
if __name__ == '__main__':
evaluate()
@@ -30,12 +30,20 @@ from cnocr import CnOcr
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name", help="model name", type=str, default='conv-lite-fc'
)
parser.add_argument("--model_epoch", type=int, default=None, help="model epoch")
parser.add_argument("-f", "--file", help="Path to the image file")
parser.add_argument("-s", "--single-line", default=False,
help="Whether the image only includes one-line characters")
parser.add_argument(
"-s",
"--single-line",
default=False,
help="Whether the image only includes one-line characters",
)
args = parser.parse_args()
ocr = CnOcr()
ocr = CnOcr(model_name=args.model_name, model_epoch=args.model_epoch)
if args.single_line:
res = ocr.ocr_for_single_line(args.file)
else:
......
@@ -21,14 +21,17 @@ import argparse
import logging
import os
import sys
import mxnet as mx
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cnocr.__version__ import __version__
from cnocr.consts import EMB_MODEL_TYPES, SEQ_MODEL_TYPES
from cnocr.utils import data_dir
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
from cnocr.hyperparams.hyperparams2 import Hyperparams as Hyperparams2
from cnocr.data_utils.data_iter import ImageIterLstm, MPOcrImages, OCRIter
from cnocr.symbols.crnn import crnn_no_lstm, crnn_lstm
from cnocr.hyperparams.cn_hyperparams import CnHyperparams
from cnocr.data_utils.data_iter import GrayImageIter
from cnocr.data_utils.aug import FgBgFlipAug
from cnocr.symbols.crnn import gen_network
from cnocr.fit.ctc_metrics import CtcMetrics
from cnocr.fit.fit import fit
@@ -36,144 +39,162 @@ from cnocr.fit.fit import fit
def parse_args():
# Parse command line arguments
parser = argparse.ArgumentParser()
default_model_prefix = os.path.join(data_dir(), 'models', 'model-v{}'.format(__version__))
parser.add_argument("--dataset",
help="use which kind of dataset, captcha or cn_ocr",
choices=['captcha', 'cn_ocr'],
type=str, default='captcha')
parser.add_argument("--data_root", help="Path to image files", type=str,
default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator')
parser.add_argument("--train_file", help="Path to train txt file", type=str,
default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator/train.txt')
parser.add_argument("--test_file", help="Path to test txt file", type=str,
default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator/test.txt')
parser.add_argument("--cpu",
help="Number of CPUs for training [Default 8]. Ignored if --gpu is specified.",
type=int, default=2)
parser.add_argument("--gpu", help="Number of GPUs for training [Default 0]", type=int)
parser.add_argument('--load_epoch', type=int,
help='load the model on an epoch using the model-load-prefix [Default: no trained model will be loaded]')
parser.add_argument("--prefix", help="Checkpoint prefix [Default '{}']".format(default_model_prefix),
default=default_model_prefix)
parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc')
parser.add_argument("--num_proc", help="Number CAPTCHA generating processes [Default 4]", type=int, default=4)
parser.add_argument("--font_path", help="Path to ttf font file or directory containing ttf files")
return parser.parse_args()
def get_fonts(path):
fonts = list()
if os.path.isdir(path):
for filename in os.listdir(path):
if filename.endswith('.ttf') or filename.endswith('.ttc'):
fonts.append(os.path.join(path, filename))
else:
fonts.append(path)
return fonts
def run_captcha(args):
from cnocr.data_utils.captcha_generator import MPDigitCaptcha
hp = Hyperparams2()
network = crnn_lstm(hp)
# arg_shape, out_shape, aux_shape = network.infer_shape(data=(128, 1, 32, 100), label=(128, 10),
# l0_init_h=(128, 100), l1_init_h=(128, 100), l2_init_h=(128, 100), l3_init_h=(128, 100))
# print(dict(zip(network.list_arguments(), arg_shape)))
# import pdb; pdb.set_trace()
# Start a multiprocessor captcha image generator
mp_captcha = MPDigitCaptcha(
font_paths=get_fonts(args.font_path), h=hp.img_width, w=hp.img_height,
num_digit_min=3, num_digit_max=4, num_processes=args.num_proc, max_queue_size=hp.batch_size * 2)
mp_captcha.start()
# img, num = mp_captcha.get()
# print(img.shape, num)
# import numpy as np
# import cv2
# img = np.transpose(img, (1, 0))
# cv2.imwrite('captcha1.png', img * 255)
# import sys
# sys.exit(0)
# import pdb; pdb.set_trace()
# init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
# init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
# init_states = init_c + init_h
# data_names = ['data'] + [x[0] for x in init_states]
data_names = ['data']
parser.add_argument(
"--emb_model_type",
help="which embedding model to use",
choices=EMB_MODEL_TYPES,
type=str,
default='conv-lite',
)
parser.add_argument(
"--seq_model_type",
help='which sequence model to use',
default='fc',
type=str,
choices=SEQ_MODEL_TYPES,
)
parser.add_argument(
"--train_file",
help="Path to train txt file",
type=str,
default='data/sample-data-lst/train.txt',
)
parser.add_argument(
"--test_file",
help="Path to test txt file",
type=str,
default='data/sample-data-lst/test.txt',
)
parser.add_argument(
"--use_train_image_aug",
action='store_true',
help="Whether to use image augmentation for training",
)
parser.add_argument(
"--gpu",
help="Number of GPUs for training [Default 0, means using cpu]",
type=int,
default=0,
)
parser.add_argument(
"--optimizer",
help="optimizer for training [Default: Adam]",
type=str,
default='Adam',
)
parser.add_argument(
'--epoch', type=int, default=20, help='train epochs [Default: 20]'
)
parser.add_argument(
'--load_epoch',
type=int,
help='load the model on an epoch using the model-load-prefix [Default: no trained model will be loaded]',
)
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument(
'--wd', type=float, default=0.0, help='weight decay factor [Default: 0.0]'
)
parser.add_argument(
'--clip_gradient',
type=float,
default=None,
help='value for gradient clipping [Default: None, meaning no gradient clipping]',
)
parser.add_argument(
"--out_model_dir",
help='output model directory',
default=os.path.join(data_dir(), __version__),
)
return parser.parse_args()
data_train = OCRIter(
hp.train_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_captcha, num_label=hp.num_label,
name='train')
data_val = OCRIter(
hp.eval_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_captcha, num_label=hp.num_label,
name='val')
def train_cnocr(args):
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
args.model_name = args.emb_model_type + '-' + args.seq_model_type
out_dir = os.path.join(args.out_model_dir, args.model_name)
print('save models to dir: %s' % out_dir, flush=True)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
args.prefix = os.path.join(
out_dir, 'cnocr-v{}-{}'.format(__version__, args.model_name)
)
hp = CnHyperparams()
hp = _update_hp(hp, args)
network, hp = gen_network(args.model_name, hp)
metrics = CtcMetrics(hp.seq_length)
fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp, data_names=data_names)
mp_captcha.reset()
def run_cn_ocr(args):
hp = Hyperparams()
network = crnn_lstm(hp)
mp_data_train = MPOcrImages(args.data_root, args.train_file, (hp.img_width, hp.img_height), hp.num_label,
num_processes=args.num_proc, max_queue_size=hp.batch_size * 100)
# img, num = mp_data_train.get()
# print(img.shape)
# print(mp_data_train.shape)
# import pdb; pdb.set_trace()
# import numpy as np
# import cv2
# img = np.transpose(img, (1, 0))
# cv2.imwrite('captcha1.png', img * 255)
# import pdb; pdb.set_trace()
mp_data_test = MPOcrImages(args.data_root, args.test_file, (hp.img_width, hp.img_height), hp.num_label,
num_processes=max(args.num_proc // 2, 1), max_queue_size=hp.batch_size * 10)
mp_data_train.start()
mp_data_test.start()
# init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
# init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
# init_states = init_c + init_h
# data_names = ['data'] + [x[0] for x in init_states]
data_train, data_val = _gen_iters(
hp, args.train_file, args.test_file, args.use_train_image_aug
)
data_names = ['data']
data_train = OCRIter(
hp.train_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_data_train, num_label=hp.num_label,
name='train')
data_val = OCRIter(
hp.eval_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_data_test, num_label=hp.num_label,
name='val')
# data_train = ImageIterLstm(
# args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train")
# data_val = ImageIterLstm(
# args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val")
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
metrics = CtcMetrics(hp.seq_length)
fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp, data_names=data_names)
mp_data_train.reset()
mp_data_test.reset()
fit(
network=network,
data_train=data_train,
data_val=data_val,
metrics=metrics,
args=args,
hp=hp,
data_names=data_names,
)
def _update_hp(hp, args):
hp.seq_model_type = args.seq_model_type
hp._num_epoch = args.epoch
hp.optimizer = args.optimizer
hp._learning_rate = args.lr
hp.wd = args.wd
hp.clip_gradient = args.clip_gradient
return hp
def _gen_iters(hp, train_fp_prefix, val_fp_prefix, use_train_image_aug):
height, width = hp.img_height, hp.img_width
augs = None
if use_train_image_aug:
augs = mx.image.CreateAugmenter(
data_shape=(3, height, width),
resize=0,
rand_crop=False,
rand_resize=False,
rand_mirror=False,
mean=None,
std=None,
brightness=0.001,
contrast=0.001,
saturation=0.001,
hue=0.05,
pca_noise=0.1,
inter_method=2,
)
augs.append(FgBgFlipAug(p=0.2))
train_iter = GrayImageIter(
batch_size=hp.batch_size,
data_shape=(3, height, width),
label_width=hp.num_label,
dtype='int32',
shuffle=True,
path_imgrec=str(train_fp_prefix) + ".rec",
path_imgidx=str(train_fp_prefix) + ".idx",
aug_list=augs,
)
val_iter = GrayImageIter(
batch_size=hp.batch_size,
data_shape=(3, height, width),
label_width=hp.num_label,
dtype='int32',
path_imgrec=str(val_fp_prefix) + ".rec",
path_imgidx=str(val_fp_prefix) + ".idx",
)
return train_iter, val_iter
if __name__ == '__main__':
args = parse_args()
if args.dataset == 'captcha':
run_captcha(args)
else:
run_cn_ocr(args)
train_cnocr(args)
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Some bad training samples can be spotted in badcases.txt. This script filters those samples out of the sample file.
"""
from __future__ import print_function
import argparse
import logging
def parse_args():
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"--sample_file", help="Path to train or test txt file", type=str, required=True
)
parser.add_argument(
"--badcases_file", help="Path to badcases file from evaluate mode", type=str, required=True
)
parser.add_argument(
"--distance_thrsh", help="samples with distance >= thrsh will be deleted", type=int, default=2
)
parser.add_argument(
"-o",
"--output_file",
help="the new sample file",
type=str,
required=True,
)
return parser.parse_args()
def read_badcases_file(fp, dist_thrsh):
badcases = set()
with open(fp) as f:
for line in f:
line = line.strip()
fields = line.split('\t')
if fields[0] == 'distance':
continue
dist, fp = int(fields[0]), fields[1]
if dist >= dist_thrsh:
fp = '/'.join(fp.split('/')[-2:])
badcases.add(fp)
print('get %d badcase samples' % len(badcases))
return badcases
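# Illustrative badcases file content (tab-separated; this is the format written
# by scripts/cnocr_evaluate.py, whose header line is skipped above):
#   distance\timage_fp\treal_words\tpred_words\tmiss_words\tredundant_words
#   3\texamples/00199980.jpg\t<real>\t<pred>\t<missed chars>\t<redundant chars>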
def process_sample_file(in_fp, out_fp, badcases):
num_deleted = 0
with open(in_fp) as in_f, open(out_fp, 'w') as out_f:
for line in in_f:
line = line.strip()
fields = line.split()
sample_fp = fields[0]
if sample_fp in badcases:
num_deleted += 1
else:
out_f.write(line + '\n')
print('%d samples are deleted' % num_deleted)
def filter(args):
"""选择包含给定id的样本,并按包含数量从高到低排序。"""
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
badcases = read_badcases_file(args.badcases_file, args.distance_thrsh)
process_sample_file(args.sample_file, args.output_file, badcases)
if __name__ == '__main__':
args = parse_args()
filter(args)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import os
import sys
from pathlib import Path
curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../python"))
import mxnet as mx
import random
import argparse
import cv2
import time
import traceback
try:
import multiprocessing
except ImportError:
multiprocessing = None
def list_image(root, recursive, exts):
"""Traverses the root of directory that contains images and
generates image list iterator.
Parameters
----------
root: string
recursive: bool
exts: string
Returns
-------
image iterator that contains all the image under the specified path
"""
i = 0
if recursive:
cat = {}
for path, dirs, files in os.walk(root, followlinks=True):
dirs.sort()
files.sort()
for fname in files:
fpath = os.path.join(path, fname)
suffix = os.path.splitext(fname)[1].lower()
if os.path.isfile(fpath) and (suffix in exts):
if path not in cat:
cat[path] = len(cat)
yield (i, os.path.relpath(fpath, root), cat[path])
i += 1
for k, v in sorted(cat.items(), key=lambda x: x[1]):
print(os.path.relpath(k, root), v)
else:
for fname in sorted(os.listdir(root)):
fpath = os.path.join(root, fname)
suffix = os.path.splitext(fname)[1].lower()
if os.path.isfile(fpath) and (suffix in exts):
yield (i, os.path.relpath(fpath, root), 0)
i += 1
def write_list(path_out, image_list):
"""Hepler function to write image list into the file.
The format is as below,
integer_image_index \t float_label_index \t path_to_image
Note that the blank between number and tab is only used for readability.
Parameters
----------
path_out: string
image_list: list
"""
with open(path_out, 'w') as fout:
for i, item in enumerate(image_list):
line = '%d\t' % item[0]
for j in item[2:]:
line += '%f\t' % j
line += '%s\n' % item[1]
fout.write(line)
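# Illustrative output line (tab-separated: index, label value(s), image path):
#   '0\t5.000000\tpath/to/image.jpg'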
def make_list(args):
"""Generates .lst file.
Parameters
----------
args: object that contains all the arguments
"""
image_list = list_image(args.root, args.recursive, args.exts)
image_list = list(image_list)
if args.shuffle is True:
random.seed(100)
random.shuffle(image_list)
N = len(image_list)
chunk_size = (N + args.chunks - 1) // args.chunks
for i in range(args.chunks):
chunk = image_list[i * chunk_size : (i + 1) * chunk_size]
if args.chunks > 1:
str_chunk = '_%d' % i
else:
str_chunk = ''
sep = int(chunk_size * args.train_ratio)
sep_test = int(chunk_size * args.test_ratio)
if args.train_ratio == 1.0:
write_list(args.prefix + str_chunk + '.lst', chunk)
else:
if args.test_ratio:
write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
if args.train_ratio + args.test_ratio < 1.0:
write_list(
args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep :]
)
write_list(
args.prefix + str_chunk + '_train.lst', chunk[sep_test : sep_test + sep]
)
def make_list_new(args):
def read_file(fp, begin_idx=0):
res_list = []
with open(fp) as f:
for idx, line in enumerate(f):
eles = line.strip().split(' ')
fname = eles[0]
labels = eles[1:]
if len(labels) < args.num_label:
labels.extend(['0'] * (args.num_label - len(labels)))
res_list.append((str(begin_idx + idx), '\t'.join(labels), fname))
return res_list
def write_to_file(fp, the_list):
with open(fp, 'w') as f:
for ele in the_list:
f.write('\t'.join(ele) + '\n')
if os.path.isdir(args.prefix):
working_dir = args.prefix
prefix = ''
else:
working_dir = os.path.dirname(args.prefix)
if not os.path.exists(working_dir):
os.makedirs(working_dir)
prefix = os.path.basename(args.prefix)
test_list = read_file(args.test_idx_fp)
write_to_file(os.path.join(working_dir, prefix + '_test.lst'), test_list)
train_list = read_file(args.train_idx_fp, len(test_list))
num_chunks = args.chunks
if num_chunks <= 1:
write_to_file(os.path.join(working_dir, prefix + '_train.lst'), train_list)
else:
chunk_size = (len(train_list) + num_chunks - 1) // num_chunks
for i in range(num_chunks):
chunk = train_list[i * chunk_size: (i + 1) * chunk_size]
write_to_file(os.path.join(working_dir, prefix + '_%d_train.lst' % i), chunk)
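# Illustrative output line of make_list_new (tab-separated: index, num_label
# label ids padded with '0', then the file name), matching the sample data shown
# earlier in this diff:
#   '0\t3814\t6028\t1206\t...\t0\t0\tchn-00199980.jpg'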
def read_list(path_in):
"""Reads the .lst file and generates corresponding iterator.
Parameters
----------
path_in: string
Returns
-------
item iterator that contains information in .lst file
"""
with open(path_in) as fin:
while True:
line = fin.readline()
if not line:
break
line = [i.strip() for i in line.strip().split('\t')]
line_len = len(line)
# check the data format of .lst file
if line_len < 3:
print(
'lst should have at least three parts, but only has %s parts for %s'
% (line_len, line)
)
continue
try:
item = [int(line[0])] + [line[-1]] + [int(i) for i in line[1:-1]]
except Exception as e:
print('Parsing lst met error for %s, detail: %s' % (line, e))
continue
yield item
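# Illustrative parsed item for the .lst line '0\t5\t7\timg.jpg':
#   [0, 'img.jpg', 5, 7]  # i.e. [index, path, label_1, ..., label_n]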
def image_encode(args, i, item, q_out):
"""Reads, preprocesses, packs the image and put it back in output queue.
Parameters
----------
args: object
i: int
item: list
q_out: queue
"""
fullpath = os.path.join(args.root, item[1])
if len(item) > 3 and args.pack_label:
header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
else:
header = mx.recordio.IRHeader(0, item[2], item[0], 0)
if args.pass_through:
try:
with open(fullpath, 'rb') as fin:
img = fin.read()
s = mx.recordio.pack(header, img)
q_out.put((i, s, item))
except Exception as e:
traceback.print_exc()
print('pack_img error:', item[1], e)
q_out.put((i, None, item))
return
try:
img = cv2.imread(fullpath, args.color)
except:
traceback.print_exc()
print('imread error trying to load file: %s ' % fullpath)
q_out.put((i, None, item))
return
if img is None:
print('imread read blank (None) image for file: %s' % fullpath)
q_out.put((i, None, item))
return
if args.center_crop:
if img.shape[0] > img.shape[1]:
margin = (img.shape[0] - img.shape[1]) // 2
img = img[margin : margin + img.shape[1], :]
else:
margin = (img.shape[1] - img.shape[0]) // 2
img = img[:, margin : margin + img.shape[0]]
if args.resize:
if img.shape[0] > img.shape[1]:
newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
else:
newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
img = cv2.resize(img, newsize)
try:
s = mx.recordio.pack_img(
header, img, quality=args.quality, img_fmt=args.encoding
)
q_out.put((i, s, item))
except Exception as e:
traceback.print_exc()
print('pack_img error on file: %s' % fullpath, e)
q_out.put((i, None, item))
return
def read_worker(args, q_in, q_out):
"""Function that will be spawned to fetch the image
from the input queue and put it back to output queue.
Parameters
----------
args: object
q_in: queue
q_out: queue
"""
while True:
deq = q_in.get()
if deq is None:
break
i, item = deq
image_encode(args, i, item, q_out)
def write_worker(q_out, fname, working_dir):
"""Function that will be spawned to fetch processed image
from the output queue and write to the .rec file.
Parameters
----------
q_out: queue
fname: string
working_dir: string
"""
pre_time = time.time()
count = 0
fname = os.path.basename(fname)
fname_rec = os.path.splitext(fname)[0] + '.rec'
fname_idx = os.path.splitext(fname)[0] + '.idx'
record = mx.recordio.MXIndexedRecordIO(
os.path.join(working_dir, fname_idx), os.path.join(working_dir, fname_rec), 'w'
)
buf = {}
more = True
while more:
deq = q_out.get()
if deq is not None:
i, s, item = deq
buf[i] = (s, item)
else:
more = False
while count in buf:
s, item = buf[count]
del buf[count]
if s is not None:
record.write_idx(item[0], s)
if count % 1000 == 0:
cur_time = time.time()
print('time:', cur_time - pre_time, ' count:', count)
pre_time = cur_time
count += 1
def parse_args():
"""Defines all arguments.
Returns
-------
args object that contains all the params
"""
def_data_dir = Path('data/sample-data')
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description='Create an image list or \
make a record database by reading from an image list',
)
parser.add_argument('--prefix', help='prefix of input/output lst and rec files.')
cgroup = parser.add_argument_group('Options for creating image lists')
cgroup.add_argument(
'--list',
action='store_true',
help='If this is set im2rec will create image list(s) by traversing root folder\
and output to <prefix>.lst.\
Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec',
)
# cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
# help='list of acceptable image extensions.')
cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
cgroup.add_argument(
'--train-idx-fp',
type=str,
help='original index file for train data.',
default=def_data_dir / 'train.txt',
)
cgroup.add_argument(
'--test-idx-fp',
type=str,
help='original index file for test data.',
default=def_data_dir / 'test.txt',
)
cgroup.add_argument('--num-label', type=int, default=20, help='number of labels (characters) per sample; shorter label sequences are padded with 0')
# cgroup.add_argument('--train-ratio', type=float, default=1.0,
# help='Ratio of images to use for training.')
# cgroup.add_argument('--test-ratio', type=float, default=0,
# help='Ratio of images to use for testing.')
# cgroup.add_argument('--recursive', action='store_true',
# help='If true recursively walk through subdirs and assign an unique label\
# to images in each folder. Otherwise only include images in the root folder\
# and give them label 0.')
cgroup.add_argument(
'--no-shuffle',
dest='shuffle',
action='store_false',
help='If this is passed, \
im2rec will not randomize the image order in <prefix>.lst',
)
rgroup = parser.add_argument_group('Options for creating database')
rgroup.add_argument('--root', help='path to folder containing images.')
rgroup.add_argument(
'--pass-through',
action='store_true',
help='whether to skip transformation and save image as is',
)
rgroup.add_argument(
'--resize',
type=int,
default=0,
help='resize the shorter edge of image to the newsize, original images will\
be packed by default.',
)
rgroup.add_argument(
'--center-crop',
action='store_true',
help='specify whether to crop the center image to make it rectangular.',
)
rgroup.add_argument(
'--quality',
type=int,
default=95,
help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9',
)
rgroup.add_argument(
'--num-thread',
type=int,
default=1,
help='number of thread to use for encoding. order of images will be different\
from the input list if >1. the input list will be modified to match the\
resulting order.',
)
rgroup.add_argument(
'--color',
type=int,
default=1,
choices=[-1, 0, 1],
help='specify the color mode of the loaded image.\
1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\
0: Loads image in grayscale mode.\
-1:Loads image as such including alpha channel.',
)
rgroup.add_argument(
'--encoding',
type=str,
default='.jpg',
choices=['.jpg', '.png'],
help='specify the encoding of the images.',
)
rgroup.add_argument(
'--pack-label',
action='store_true',
help='Whether to also pack multi dimensional label in the record file',
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
# if the '--list' is used, it generates .lst file
if args.list:
# make_list(args)
make_list_new(args)
# otherwise read .lst file to generates .rec file
else:
args.prefix = os.path.abspath(args.prefix)
args.root = os.path.abspath(args.root)
if os.path.isdir(args.prefix):
working_dir = args.prefix
else:
working_dir = os.path.dirname(args.prefix)
files = [
os.path.join(working_dir, fname)
for fname in os.listdir(working_dir)
if os.path.isfile(os.path.join(working_dir, fname))
]
count = 0
for fname in files:
if fname.startswith(args.prefix) and fname.endswith('.lst'):
print('Creating .rec file from', fname, 'in', working_dir)
count += 1
image_list = read_list(fname)
# -- write_record -- #
if args.num_thread > 1 and multiprocessing is not None:
q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
q_out = multiprocessing.Queue(1024)
# define the process
read_process = [
multiprocessing.Process(
target=read_worker, args=(args, q_in[i], q_out)
)
for i in range(args.num_thread)
]
# process images with num_thread process
for p in read_process:
p.start()
# only use one process to write .rec to avoid race-condition
write_process = multiprocessing.Process(
target=write_worker, args=(q_out, fname, working_dir)
)
write_process.start()
# put the image list into input queue
for i, item in enumerate(image_list):
q_in[i % len(q_in)].put((i, item))
for q in q_in:
q.put(None)
for p in read_process:
p.join()
q_out.put(None)
write_process.join()
else:
print(
'multiprocessing not available, fall back to single threaded encoding'
)
try:
import Queue as queue
except ImportError:
import queue
q_out = queue.Queue()
fname = os.path.basename(fname)
fname_rec = os.path.splitext(fname)[0] + '.rec'
fname_idx = os.path.splitext(fname)[0] + '.idx'
record = mx.recordio.MXIndexedRecordIO(
os.path.join(working_dir, fname_idx),
os.path.join(working_dir, fname_rec),
'w',
)
cnt = 0
pre_time = time.time()
for i, item in enumerate(image_list):
image_encode(args, i, item, q_out)
if q_out.empty():
continue
_, s, _ = q_out.get()
record.write_idx(item[0], s)
if cnt > 0 and cnt % 1000 == 0:
cur_time = time.time()
print('time:', cur_time - pre_time, ' count:', cnt)
pre_time = cur_time
cnt += 1
print('time:', time.time() - pre_time, ' count:', cnt)
if not count:
print('Did not find any .lst file with prefix %s' % args.prefix)
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Select samples that contain a given id, sorted by occurrence count from high to low.
Mainly intended to find samples containing 'l', to check whether the model can predict
sequences like 'll', where adjacent characters take up very little width.
"""
from __future__ import print_function
import argparse
import logging
import os
import shutil
def parse_args():
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"--sample_file", help="Path to train or test txt file", type=str, required=True
)
parser.add_argument(
"--prefix", help="prefix directory for image files", required=True
)
parser.add_argument(
"--target_id", help="target id to select", type=int, default=242
)
parser.add_argument(
"--num_samples", help="how many samples to be selected", type=int, default=256
)
parser.add_argument(
"-o",
"--output_dir",
help="the directory for storing selected sample images",
type=str,
required=True,
)
return parser.parse_args()
def read_file(fp, target_id):
target_id = str(target_id)
res_list = []
with open(fp) as f:
for line in f:
line = line.strip()
fields = line.split()
sample_fp, ids = fields[0], fields[1:]
target_cnt = ids.count(target_id)
if target_cnt > 0:
res_list.append((target_cnt, sample_fp, line))
return res_list
def copy_files(cand_list, prefix, out_dir):
if not os.path.exists(out_dir):
os.makedirs(out_dir)
with open(os.path.join(out_dir, 'labels.txt'), 'w') as label_f:
for _, sample_fp, line in cand_list:
fp = os.path.join(prefix, sample_fp)
shutil.copy(fp, out_dir)
label_f.write(line + '\n')
def select(args):
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
cand_list = read_file(args.sample_file, args.target_id)
cand_list.sort(key=lambda x: x[0], reverse=True)
cand_list = cand_list[: args.num_samples]
copy_files(cand_list, args.prefix, args.output_dir)
if __name__ == '__main__':
args = parse_args()
select(args)
@@ -18,16 +18,19 @@ exec(
)
required = [
'numpy>=1.14.0,<1.15.0',
'numpy>=1.14.0,<1.20.0',
'pillow>=5.3.0',
'mxnet>=1.4.1,<1.5.0',
'gluoncv>=0.3.0,<0.4.0',
'mxnet>=1.5.0,<1.7.0',
'gluoncv>=0.3.0,<0.7.0',
]
extras_require = {
"dev": ["pip-tools", "pytest", "python-Levenshtein"],
}
setup(
name=PACKAGE_NAME,
version=about['__version__'],
description="Package for Chinese OCR, which can be used after installed without training yourself OCR model",
description="Simple package for Chinese OCR, with small pretrained models",
long_description=long_description,
long_description_content_type="text/markdown",
author='breezedeus',
@@ -36,10 +39,9 @@ setup(
url='https://github.com/breezedeus/cnocr',
platforms=["Mac", "Linux", "Windows"],
packages=find_packages(),
# entry_points={'cnocr_predict': ['chitchatbot=chitchatbot.cli:main'],
# 'cnocr_train': ['chitchatbot=chitchatbot:Spec']},
include_package_data=True,
install_requires=required,
extras_require=extras_require,
zip_safe=False,
classifiers=[
'Development Status :: 4 - Beta',
......
@@ -6,53 +6,111 @@ import numpy as np
import mxnet as mx
from mxnet import nd
from PIL import Image
import Levenshtein
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr import CnOcr
from cnocr.line_split import line_split
from cnocr.data_utils.aug import GrayAug
CNOCR = CnOcr()
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
example_dir = os.path.join(root_dir, 'examples')
CNOCR = CnOcr(model_name='conv-lite-fc', model_epoch=None)
SINGLE_LINE_CASES = [
('20457890_2399557098.jpg', [['就', '会', '哈', '哈', '大', '笑', '。', '3', '.', '0']]),
('rand_cn1.png', [['笠', '淡', '嘿', '骅', '谧', '鼎', '臭', '姚', '歼', '蠢', '驼', '耳', '裔', '挝', '涯', '狗', '蒽', '子', '犷']])
('20457890_2399557098.jpg', ['就会哈哈大笑。3.0']),
('rand_cn1.png', ['笠淡嘿骅谧鼎皋姚歼蠢驼耳胬挝涯狗蒽孓犷']),
('rand_cn2.png', ['凉芦']),
('helloworld.jpg', ['Hello World!你好世界']),
]
MULTIPLE_LINE_CASES = [
('multi-line_cn1.png', [['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'],
['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'],
['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '―', '这', '个', '账'],
['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'],
['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'],
['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'],
['等', '多', '种', '形', '式', '。']]),
('multi-line_cn2.png', [['。', '当', '然', ',', '在', '媒', '介', '越', '来', '越', '多', '的', '情', '形', '下', ','],
['意', '味', '着', '传', '播', '方', '式', '的', '变', '化', '。', '过', '去', '主', '流'],
['的', '是', '大', '众', '传', '播', ',', '现', '在', '互', '动', '性', '和', '定', '制'],
['性', '带', '来', '了', '新', '的', '挑', '战', '—', '—', '如', '何', '让', '品', '牌'],
['与', '消', '费', '者', '更', '加', '互', '动', '。']]),
('hybrid.png', ['o12345678']),
(
'multi-line_en_black.png',
[
'transforms the image many times. First, the image goes through many convolutional layers. In those',
'convolutional layers, the network learns new and increasingly complex features in its layers. Then the ',
'transformed image information goes through the fully connected layers and turns into a classification ',
'or prediction.',
],
),
(
'multi-line_en_white.png',
[
'This chapter is currently only available in this web version. ebook and print will follow.',
'Convolutional neural networks learn abstract features and concepts from raw image pixels. Feature',
'Visualization visualizes the learned features by activation maximization. Network Dissection labels',
'neural network units (e.g. channels) with human concepts.',
],
),
(
'multi-line_cn1.png',
[
'网络支付并无本质的区别,因为',
'每一个手机号码和邮件地址背后',
'都会对应着一个账户--这个账',
'户可以是信用卡账户、借记卡账',
'户,也包括邮局汇款、手机代',
'收、电话代收、预付费卡和点卡',
'等多种形式。',
],
),
(
'multi-line_cn2.png',
[
'当然,在媒介越来越多的情形下,',
'意味着传播方式的变化。过去主流',
'的是大众传播,现在互动性和定制',
'性带来了新的挑战——如何让品牌',
'与消费者更加互动。',
],
),
]
CASES = SINGLE_LINE_CASES + MULTIPLE_LINE_CASES
def print_preds(pred):
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
def cal_score(preds, expected):
if len(preds) != len(expected):
return 0
total_cnt = 0
total_dist = 0
for real, pred in zip(expected, preds):
pred = ''.join(pred)
distance = Levenshtein.distance(real, pred)
total_dist += distance
total_cnt += len(real)
return 1.0 - float(total_dist) / total_cnt
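# Illustrative arithmetic (not part of the original tests): with
# expected=['你好世界'] and preds=[['你', '好', '世', '间']], the edit distance is
# 1 over 4 characters, so cal_score returns 1 - 1/4 = 0.75.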
@pytest.mark.parametrize('img_fp, expected', CASES)
def test_ocr(img_fp, expected):
ocr = CNOCR
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
img_fp = os.path.join(root_dir, 'examples', img_fp)
# img_fp = 'multi-line-game.jpeg'
pred = ocr.ocr(img_fp)
print('\n')
print("Predicted Chars:", pred)
assert expected == pred
print_preds(pred)
assert cal_score(pred, expected) >= 0.9
img = mx.image.imread(img_fp, 1)
pred = ocr.ocr(img)
print("Predicted Chars:", pred)
assert expected == pred
print_preds(pred)
assert cal_score(pred, expected) >= 0.9
img = mx.image.imread(img_fp, 1).asnumpy()
pred = ocr.ocr(img)
print("Predicted Chars:", pred)
assert expected == pred
print_preds(pred)
assert cal_score(pred, expected) >= 0.9
@pytest.mark.parametrize('img_fp, expected', SINGLE_LINE_CASES)
@@ -62,26 +120,30 @@ def test_ocr_for_single_line(img_fp, expected):
img_fp = os.path.join(root_dir, 'examples', img_fp)
pred = ocr.ocr_for_single_line(img_fp)
print('\n')
print("Predicted Chars:", pred)
assert expected[0] == pred
print_preds(pred)
assert cal_score([pred], expected) >= 0.9
img = mx.image.imread(img_fp, 1)
pred = ocr.ocr_for_single_line(img)
print("Predicted Chars:", pred)
assert expected[0] == pred
print_preds(pred)
assert cal_score([pred], expected) >= 0.9
img = mx.image.imread(img_fp, 1).asnumpy()
pred = ocr.ocr_for_single_line(img)
print("Predicted Chars:", pred)
assert expected[0] == pred
print_preds(pred)
assert cal_score([pred], expected) >= 0.9
img = np.array(Image.fromarray(img).convert('L'))
assert len(img.shape) == 2
pred = ocr.ocr_for_single_line(img)
print("Predicted Chars:", pred)
assert expected[0] == pred
print_preds(pred)
assert cal_score([pred], expected) >= 0.9
img = np.expand_dims(img, axis=2)
assert len(img.shape) == 3 and img.shape[2] == 1
pred = ocr.ocr_for_single_line(img)
print("Predicted Chars:", pred)
assert expected[0] == pred
print_preds(pred)
assert cal_score([pred], expected) >= 0.9
@pytest.mark.parametrize('img_fp, expected', MULTIPLE_LINE_CASES)
@@ -90,13 +152,43 @@ def test_ocr_for_single_lines(img_fp, expected):
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
img_fp = os.path.join(root_dir, 'examples', img_fp)
img = mx.image.imread(img_fp, 1).asnumpy()
if img.mean() < 145:  # invert white-on-black images to black-on-white
img = 255 - img
line_imgs = line_split(img, blank=True)
line_img_list = [line_img for line_img, _ in line_imgs]
pred = ocr.ocr_for_single_lines(line_img_list)
print('\n')
print("Predicted Chars:", pred)
assert expected == pred
print_preds(pred)
assert cal_score(pred, expected) >= 0.9
line_img_list = [nd.array(line_img) for line_img in line_img_list]
pred = ocr.ocr_for_single_lines(line_img_list)
print_preds(pred)
assert cal_score(pred, expected) >= 0.9
@pytest.mark.parametrize('img_fp, expected', SINGLE_LINE_CASES)
def test_gray_aug(img_fp, expected):
img_fp = os.path.join(example_dir, img_fp)
img = mx.image.imread(img_fp, 1)
aug = GrayAug()
res_img = aug(img)
print(res_img.shape, res_img.dtype)
def test_cand_alphabet():
from cnocr.consts import NUMBERS
img_fp = os.path.join(example_dir, 'hybrid.png')
ocr = CnOcr()
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == 'o12345678'
ocr = CnOcr(cand_alphabet=NUMBERS)
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert expected == pred
assert len(pred) == 1 and pred[0] == '012345678'
# coding: utf-8
import os
import sys
import logging
from copy import deepcopy
import pytest
import mxnet as mx
from mxnet import nd
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr.consts import EMB_MODEL_TYPES, SEQ_MODEL_TYPES
from cnocr.hyperparams.cn_hyperparams import CnHyperparams
from cnocr.symbols.densenet import _make_dense_layer, DenseNet, cal_num_params
from cnocr.symbols.crnn import (
CRnn,
pipline,
gen_network,
get_infer_shape,
crnn_lstm,
crnn_lstm_lite,
)
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
logger = logging.getLogger(__name__)
HP = CnHyperparams()
def test_dense_layer():
x = nd.random.randn(128, 64, 32, 280)
net = _make_dense_layer(64, 2, 0.1)
net.initialize()
y = net(x)
logger.info(net)
logger.info(y.shape)
def test_densenet():
x = nd.random.randn(128, 64, 32, 280)
layer_channels = (64, 128, 256, 512)
net = DenseNet(layer_channels)
net.initialize()
y = net(x)
logger.info(net)
logger.info(y.shape) # (128, 512, 1, 69)
assert y.shape[2] == 1
logger.info('number of parameters: %d', cal_num_params(net)) # 1748224
def test_crnn():
_hp = deepcopy(HP)
_hp.set_seq_length(_hp.img_width // 4 - 1)
x = nd.random.randn(128, 64, 32, 280)
layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
for layer_channels in layer_channels_list:
densenet = DenseNet(layer_channels)
crnn = CRnn(_hp, densenet)
crnn.initialize()
y = crnn(x)
logger.info(
'output shape: %s', y.shape
) # res: `(sequence_length, batch_size, 2*num_hidden)`
assert y.shape == (_hp.seq_length, _hp.batch_size, 2 * _hp.num_hidden)
logger.info('number of parameters: %d', cal_num_params(crnn))
def test_crnn_lstm():
hp = deepcopy(HP)
hp.set_seq_length(hp.img_width // 8)
data = mx.sym.Variable('data', shape=(128, 1, 32, 280))
pred = crnn_lstm(hp, data)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
assert pred_shape == (hp.seq_length, hp.batch_size, 2 * hp.num_hidden)
def test_crnn_lstm_lite():
hp = deepcopy(HP)
hp.set_seq_length(hp.img_width // 4 - 1)
data = mx.sym.Variable('data', shape=(128, 1, 32, 280))
pred = crnn_lstm_lite(hp, data)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
assert pred_shape == (hp.seq_length, hp.batch_size, 2 * hp.num_hidden)
def test_pipline():
hp = deepcopy(HP)
hp.set_seq_length(hp.img_width // 4 - 1)
hp._loss_type = None # infer mode
layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
for layer_channels in layer_channels_list:
densenet = DenseNet(layer_channels)
crnn = CRnn(hp, densenet)
data = mx.sym.Variable('data', shape=(128, 1, 32, 280))
pred = pipline(crnn, hp, data)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
assert pred_shape == (hp.batch_size * hp.seq_length, hp.num_classes)
MODEL_NAMES = []
for emb_model in EMB_MODEL_TYPES:
for seq_model in SEQ_MODEL_TYPES:
MODEL_NAMES.append('%s-%s' % (emb_model, seq_model))
@pytest.mark.parametrize(
'model_name', MODEL_NAMES
)
def test_gen_networks(model_name):
logger.info('model_name: %s', model_name)
network, hp = gen_network(model_name, HP)
shape_dict = get_infer_shape(network, HP)
logger.info('shape_dict: %s', shape_dict)
assert shape_dict['pred_fc_output'] == (
hp.batch_size * hp.seq_length,
hp.num_classes,
)
# coding: utf-8
import os
import sys
from pathlib import Path
import mxnet as mx
import numpy as np
from mxnet import nd
@@ -9,12 +10,104 @@ import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr.data_utils.aug import FgBgFlipAug
from cnocr.data_utils.data_iter import GrayImageIter
LST_DIR = Path('data/sample-data-lst')
DATA_DIR = Path('data/sample-data')
def test_nd():
ele = np.reshape(np.array(range(2*3)), (2, 3))
ele = np.reshape(np.array(range(2 * 3)), (2, 3))
data = [ele, ele + 10]
new = nd.array([ele])
assert new.shape == (1, 2, 3)
new = nd.array(data)
assert new.shape == (2, 2, 3)
print(new)
def _read_lst_file(fp):
with open(fp) as f:
for line in f:
_, fname = line.strip().rsplit('\t', maxsplit=1)
yield str(DATA_DIR / fname)
@pytest.mark.parametrize(
'fp_prefix',
[
LST_DIR / 'sample-data_test',
],
)
def test_iter(fp_prefix):
augs = mx.image.CreateAugmenter(
data_shape=(3, 32, 280),
resize=0,
rand_crop=False,
rand_resize=False,
rand_mirror=False,
mean=None,
std=None,
brightness=0.001,
contrast=0.001,
saturation=0.001,
hue=0.05,
pca_noise=0.1,
inter_method=2,
)
augs.append(FgBgFlipAug(p=0.2))
data_iter = GrayImageIter(
batch_size=2,
data_shape=(3, 32, 280),
label_width=20,
path_imgrec=str(fp_prefix) + ".rec",
path_imgidx=str(fp_prefix) + ".idx",
aug_list=augs,
)
expected_img_fps = _read_lst_file(str(fp_prefix) + ".lst")
expected_imgs = [
mx.image.imread(fp, 1) for fp in expected_img_fps
] # shape of each one: (32, 280, 3)
# data_iter is an instance of (a subclass of) mxnet.image.ImageIter
# reset() rewinds the iterator to the beginning of the data
data_iter.reset()
# batch is an mxnet.io.DataBatch, since that is what next() returns
batch = data_iter.next()
# data is an NDArray holding the first batch; since batch_size here is 2, its shape is (2, 3, 32, 280)
data = batch.data[0] # shape of each one: (3, 32, 280)
# import pdb; pdb.set_trace()
from matplotlib import pyplot as plt
# this loop reads and displays every image in the batch
for i in range(2):
plt.subplot(4, 1, i * 2 + 1)
print(data[i].shape)
# print(
# nd.sum(nd.abs(data[i].astype(np.uint8))),
# nd.sum(expected_imgs[i].transpose((2, 0, 1))),
# )
# print(
# nd.sum(
# nd.abs(data[i].astype(np.uint8) - expected_imgs[i].transpose((2, 0, 1)))
# )
# )
# print(float(data[i].min()), float(data[i].max()))
new_img = data[i].asnumpy() * 255
plt.imshow(new_img.astype(np.uint8).squeeze(axis=0), cmap='gray')
import cv2
cv2.imwrite(f'new-{i}.png', new_img.astype(np.uint8).squeeze(axis=0))
plt.subplot(4, 1, i * 2 + 2)
plt.imshow(expected_imgs[i].asnumpy())
plt.show()
def test_lr_scheduler():
from mxnet import lr_scheduler, optimizer
scheduler = lr_scheduler.FactorScheduler(base_lr=1, step=250, factor=0.5)
optim = optimizer.SGD(learning_rate=0.1, lr_scheduler=scheduler)
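# Illustrative behaviour (approximate; FactorScheduler is stateful): with
# step=250 and factor=0.5 the learning rate halves every 250 updates, roughly
# lr(t) = base_lr * 0.5 ** (t // 250). Called with increasing update counts:
#   scheduler(1)    # -> 1.0
#   scheduler(251)  # -> 0.5
#   scheduler(501)  # -> 0.25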