未验证 提交 453d21d0 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #288 from tink2123/add_distort_space

Add distort space
......@@ -4,11 +4,11 @@
PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。
**近期更新**
- 2020.7.9 添加支持空格的识别模型,[识别效果](#支持空格的中文OCR效果展示)
- 2020.7.9 添加数据增强、学习率衰减策略,具体参考[配置文件](./doc/doc_ch/config.md)
- 2020.6.8 添加[数据集](./doc/doc_ch/datasets.md),并保持持续更新
- 2020.6.5 支持 `attetnion` 模型导出 `inference_model`
- 2020.6.5 支持单独预测识别时,输出结果得分
- 2020.5.30 提供超轻量级中文OCR在线体验
- 2020.5.30 模型预测、训练支持Windows系统
- [more](./doc/doc_ch/update.md)
## 特性
......@@ -18,22 +18,24 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
- 多种文本检测训练算法,EAST、DB
- 多种文本识别训练算法,Rosetta、CRNN、STAR-Net、RARE
<a name="支持的中文模型列表"></a>
### 支持的中文模型列表:
|模型名称|模型简介|检测模型地址|识别模型地址|
|-|-|-|-|
|chinese_db_crnn_mobile|超轻量级中文OCR模型|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar) & [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar) & [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|
|chinese_db_crnn_server|通用中文OCR模型|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db_infer.tar) & [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar) & [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|
|模型名称|模型简介|检测模型地址|识别模型地址|支持空格的识别模型地址|
|-|-|-|-|-|
|chinese_db_crnn_mobile|超轻量级中文OCR模型|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)
|chinese_db_crnn_server|通用中文OCR模型|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|[inference模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance.tar)
超轻量级中文OCR在线体验地址:https://www.paddlepaddle.org.cn/hub/scene/ocr
**也可以按如下教程快速体验超轻量级中文OCR和通用中文OCR模型。**
**也可以按如下教程快速体验中文OCR模型。**
## **超轻量级中文OCR以及通用中文OCR体验**
![](doc/imgs_results/11.jpg)
上图是超轻量级中文OCR模型效果展示,更多效果图请见文末[超轻量级中文OCR效果展示](#超轻量级中文OCR效果展示)[通用中文OCR效果展示](#通用中文OCR效果展示)
上图是超轻量级中文OCR模型效果展示,更多效果图请见文末[超轻量级中文OCR效果展示](#超轻量级中文OCR效果展示)
[通用中文OCR效果展示](#通用中文OCR效果展示)[支持空格的中文OCR效果展示](#支持空格的中文OCR效果展示)
#### 1.环境配置
......@@ -44,7 +46,21 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
*windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下*
#### (1)超轻量级中文OCR模型下载
#### 下载检测/识别模型并解压
复制[中文模型列表](#支持的中文模型列表) 中的检测和识别 `inference模型` 地址,下载并解压:
```
mkdir inference && cd inference
# 下载检测模型并解压
wget {url/of/detection/inference_model} && tar xf {name/of/detection/inference_model/package}
# 下载识别模型并解压
wget {url/of/recognition/inference_model} && tar xf {name/of/recognition/inference_model/package}
cd ..
```
以超轻量级模型为例:
```
mkdir inference && cd inference
# 下载超轻量级中文OCR模型的检测模型并解压
......@@ -53,14 +69,18 @@ wget https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar && tar xf
wget https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar && tar xf ch_rec_mv3_crnn_infer.tar
cd ..
```
#### (2)通用中文OCR模型下载
解压完毕后应有如下文件结构:
```
mkdir inference && cd inference
# 下载通用中文OCR模型的检测模型并解压
wget https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db_infer.tar && tar xf ch_det_r50_vd_db_infer.tar
# 下载通用中文OCR模型的识别模型并解压
wget https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar && tar xf ch_rec_r34_vd_crnn_infer.tar
cd ..
|-inference
|-ch_rec_mv3_crnn
|- model
|- params
|-ch_det_mv3_db
|- model
|- params
...
```
#### 3.单张图像或者图像集合预测
......@@ -85,6 +105,13 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_mode
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_det_r50_vd_db/" --rec_model_dir="./inference/ch_rec_r34_vd_crnn/"
```
带空格的通用中文OCR模型的体验可以按照上述步骤下载相应的模型,并且更新相关的参数,示例如下:
```
# 预测image_dir指定的单张图像
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_12.jpg" --det_model_dir="./inference/ch_det_r50_vd_db/" --rec_model_dir="./inference/ch_rec_r34_vd_crnn_enhance/"
```
更多的文本检测、识别串联推理使用方式请参考文档教程中[基于预测引擎推理](./doc/doc_ch/inference.md)
## 文档教程
......@@ -159,6 +186,7 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[文本识
<a name="超轻量级中文OCR效果展示"></a>
## 超轻量级中文OCR效果展示
![](doc/imgs_results/1.jpg)
![](doc/imgs_results/7.jpg)
![](doc/imgs_results/12.jpg)
......@@ -174,6 +202,15 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[文本识
![](doc/imgs_results/chinese_db_crnn_server/2.jpg)
![](doc/imgs_results/chinese_db_crnn_server/8.jpg)
<a name="支持空格的中文OCR效果展示"></a>
## 支持空格的中文OCR效果展示
### 轻量级模型
![](doc/imgs_results/img_11.jpg)
### 通用模型
![](doc/imgs_results/chinese_db_crnn_server/en_paper.jpg)
<a name="FAQ"></a>
## FAQ
1. **转换attention识别模型时报错:KeyError: 'predict'**
......
......@@ -3,12 +3,12 @@ English | [简体中文](README.md)
## INTRODUCTION
PaddleOCR aims to create a rich, leading, and practical OCR tools that help users train better models and apply them into practice.
**Recent updates**
**Recent updates**
- 2020.7.9 Add recognition model to support space, [recognition result](#space Chinese OCR results)
- 2020.7.9 Add data auguments and learning rate decay strategies,please read [config](./doc/doc_en/config_en.md)
- 2020.6.8 Add [dataset](./doc/doc_en/datasets_en.md) and keep updating
- 2020.6.5 Support exporting `attention` model to `inference_model`
- 2020.6.5 Support separate prediction and recognition, output result score
- 2020.5.30 Provide lightweight Chinese OCR online experience
- 2020.5.30 Model prediction and training supported on Windows system
- [more](./doc/doc_en/update_en.md)
## FEATURES
......@@ -18,12 +18,13 @@ PaddleOCR aims to create a rich, leading, and practical OCR tools that help user
- Various text detection algorithms: EAST, DB
- Various text recognition algorithms: Rosetta, CRNN, STAR-Net, RARE
<a name="Supported-Chinese-model-list"></a>
### Supported Chinese models list:
|Model Name|Description |Detection Model link|Recognition Model link|
|-|-|-|-|
|chinese_db_crnn_mobile|lightweight Chinese OCR model|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|
|chinese_db_crnn_server|General Chinese OCR model|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|
|Model Name|Description |Detection Model link|Recognition Model link| Support for space Recognition Model link|
|-|-|-|-|-|
|chinese_db_crnn_mobile|lightweight Chinese OCR model|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) / [pre-train model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)
|chinese_db_crnn_server|General Chinese OCR model|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance_infer.tar) / [pre-train model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance.tar)
For testing our Chinese OCR online:https://www.paddlepaddle.org.cn/hub/scene/ocr
......@@ -34,7 +35,7 @@ For testing our Chinese OCR online:https://www.paddlepaddle.org.cn/hub/scene/o
![](doc/imgs_results/11.jpg)
The picture above is the result of our lightweight Chinese OCR model. For more testing results, please see the end of the article [lightweight Chinese OCR results](#lightweight-Chinese-OCR-results) and [General Chinese OCR results](#General-Chinese-OCR-results).
The picture above is the result of our lightweight Chinese OCR model. For more testing results, please see the end of the article [lightweight Chinese OCR results](#lightweight-Chinese-OCR-results) , [General Chinese OCR results](#General-Chinese-OCR-results) and [Support for space Recognition Model](#Space-Chinese-OCR-results).
#### 1. ENVIRONMENT CONFIGURATION
......@@ -45,22 +46,42 @@ Please see [Quick installation](./doc/doc_en/installation_en.md)
#### (1) Download lightweight Chinese OCR models
*If wget is not installed in the windows system, you can copy the link to the browser to download the model. After model downloaded, unzip it and place it in the corresponding directory*
Copy the detection and recognition 'inference model' address in [Chinese model List](#Supported-Chinese-model-list), download and unpack:
```
mkdir inference && cd inference
# Download the detection part of the Chinese OCR and decompress it
wget {url/of/detection/inference_model} && tar xf {name/of/detection/inference_model/package}
# Download the recognition part of the Chinese OCR and decompress it
wget {url/of/recognition/inference_model} && tar xf {name/of/recognition/inference_model/package}
cd ..
```
Take lightweight Chinese OCR model as an example:
```
mkdir inference && cd inference
# Download the detection part of the lightweight Chinese OCR and decompress it
wget https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar && tar xf ch_det_mv3_db_infer.tar
# Download the recognition part of the lightweight Chinese OCR and decompress it
wget https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar && tar xf ch_rec_mv3_crnn_infer.tar
# Download the space-recognized part of the lightweight Chinese OCR and decompress it
wget https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar && tar xf ch_rec_mv3_crnn_enhance_infer.tar
cd ..
```
#### (2) Download General Chinese OCR models
After the decompression is completed, the file structure should be as follows:
```
mkdir inference && cd inference
# Download the detection part of the general Chinese OCR model and decompress it
wget https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db_infer.tar && tar xf ch_det_r50_vd_db_infer.tar
# Download the recognition part of the generic Chinese OCR model and decompress it
wget https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar && tar xf ch_rec_r34_vd_crnn_infer.tar
cd ..
|-inference
|-ch_rec_mv3_crnn
|- model
|- params
|-ch_det_mv3_db
|- model
|- params
...
```
#### 3. SINGLE IMAGE AND BATCH PREDICTION
......@@ -85,6 +106,13 @@ To run inference of the Generic Chinese OCR model, follow these steps above to d
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_det_r50_vd_db/" --rec_model_dir="./inference/ch_rec_r34_vd_crnn/"
```
To run inference of the space-Generic Chinese OCR model, follow these steps above to download the corresponding models and update the relevant parameters. Examples are as follows:
```
# Prediction on a single image by specifying image path to image_dir
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_12.jpg" --det_model_dir="./inference/ch_det_r50_vd_db/" --rec_model_dir="./inference/ch_rec_r34_vd_crnn_enhance/"
```
For more text detection and recognition models, please refer to the document [Inference](./doc/doc_en/inference_en.md)
## DOCUMENTATION
......@@ -147,15 +175,15 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) dataset and cropout 30w traning data from original photos by using position groundtruth and make some calibration needed. In addition, based on the LSVT corpus, 500w synthetic data is generated to train the Chinese model. The related configuration and pre-trained models are as follows:
|Model|Backbone|Configuration file|Pre-trained model|
|-|-|-|-|
|lightweight Chinese model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|
|General Chinese OCR model|Resnet34_vd|rec_chinese_common_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|
|lightweight Chinese model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)|
|General Chinese OCR model|Resnet34_vd|rec_chinese_common_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance.tar)|
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md)
## END-TO-END OCR ALGORITHM
- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808)(Baidu Self-Research, comming soon)
<a name="lightweight Chinese OCR results"></a>
<a name="lightweight-Chinese-OCR-results"></a>
## LIGHTWEIGHT CHINESE OCR RESULTS
![](doc/imgs_results/1.jpg)
![](doc/imgs_results/7.jpg)
......@@ -166,12 +194,23 @@ Please refer to the document for training guide and use of PaddleOCR text recogn
![](doc/imgs_results/16.png)
![](doc/imgs_results/22.jpg)
<a name="General Chinese OCR results"></a>
<a name="General-Chinese-OCR-results"></a>
## General Chinese OCR results
![](doc/imgs_results/chinese_db_crnn_server/11.jpg)
![](doc/imgs_results/chinese_db_crnn_server/2.jpg)
![](doc/imgs_results/chinese_db_crnn_server/8.jpg)
<a name="Space-Chinese-OCR-results"></a>
## space Chinese OCR results
### LIGHTWEIGHT CHINESE OCR RESULTS
![](doc/imgs_results/img_11.jpg)
### General Chinese OCR results
![](doc/imgs_results/chinese_db_crnn_server/en_paper.jpg)
<a name="FAQ"></a>
## FAQ
1. Error when using attention-based recognition model: KeyError: 'predict'
......
......@@ -14,6 +14,8 @@ Global:
character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: ctc
distort: false
use_space_char: false
reader_yml: ./configs/rec/rec_chinese_reader.yml
pretrain_weights:
checkpoints:
......
......@@ -14,6 +14,8 @@ Global:
character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: ctc
distort: false
use_space_char: false
reader_yml: ./configs/rec/rec_chinese_reader.yml
pretrain_weights:
checkpoints:
......
......@@ -13,6 +13,7 @@ Global:
max_text_length: 25
character_type: en
loss_type: ctc
distort: true
reader_yml: ./configs/rec/rec_icdar15_reader.yml
pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy
checkpoints:
......
......@@ -30,6 +30,8 @@
| character_type | 设置字符类型 | ch | en/ch, en时将使用默认dict,ch时使用自定义dict|
| character_dict_path | 设置字典路径 | ./ppocr/utils/ic15_dict.txt | \ |
| loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention |
| distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) |
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml | \ |
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy | \ |
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |
......
......@@ -94,7 +94,10 @@ word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,
`ppocr/utils/ic15_dict.txt` 是一个包含36个字符的英文字典,
您可以按需使用。
如需自定义dic文件,请修改 `configs/rec/rec_icdar15_train.yml` 中的 `character_dict_path` 字段, 并将 `character_type` 设置为 `ch`
如需自定义dic文件,请在 `configs/rec/rec_icdar15_train.yml` 中添加 `character_dict_path` 字段, 并将 `character_type` 设置为 `ch`
*如果希望支持识别"空格"类别, 请将yml文件中的 `use_space_char` 字段设置为 `true`。`use_space_char` 仅在 `character_type=ch` 时生效*
### 启动训练
......@@ -124,6 +127,18 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 tools/train.py -c configs/rec/rec_icdar15_train.yml
```
- 数据增强
PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入扰动,请在配置文件中设置 `distort: true`
默认的扰动方式有:颜色空间转换(cvtColor)、模糊(blur)、抖动(jitter)、噪声(Gasuss noise)、随机切割(random crop)、透视(perspective)、颜色反转(reverse)。
训练过程中每种扰动方式以50%的概率被选择,具体代码实现请参考:[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
*由于OpenCV的兼容性问题,扰动操作暂时只支持GPU*
- 训练
PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_train.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/rec_CRNN/best_accuracy`
如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。
......@@ -157,12 +172,26 @@ Global:
character_type: ch
# 添加自定义字典,如修改字典请将路径指向新字典
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
# 训练时添加数据增强
distort: true
# 识别空格
use_space_char: true
...
# 修改reader类型
reader_yml: ./configs/rec/rec_chinese_reader.yml
...
...
Optimizer:
...
# 添加学习率衰减策略
decay:
function: cosine_decay
# 每个 epoch 包含 iter 数
step_each_epoch: 20
# 总共训练epoch数
total_epoch: 1000
```
**注意,预测/评估时的配置文件请务必与训练一致。**
......
......@@ -30,6 +30,8 @@ Take `rec_chinese_lite_train.yml` as an example
| character_type | Set character type | ch | en/ch, the default dict will be used for en, and the custom dict will be used for ch|
| character_dict_path | Set dictionary path | ./ppocr/utils/ic15_dict.txt | \ |
| loss_type | Set loss type | ctc | Supports two types of loss: ctc / attention |
| distort | Set use distort | false | Support distort type ,read [img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) |
| use_space_char | Wether to recognize space | false | Only support in character_type=ch mode |
| reader_yml | Set the reader configuration file | ./configs/rec/rec_icdar15_reader.yml | \ |
| pretrain_weights | Load pre-trained model path | ./pretrain_models/CRNN/best_accuracy | \ |
| checkpoints | Load saved model path | None | Used to load saved parameters to continue training after interruption |
......
......@@ -158,9 +158,23 @@ Global:
...
# Modify reader type
reader_yml: ./configs/rec/rec_chinese_reader.yml
# Whether to use data augmentation
distort: true
# Whether to recognize spaces
use_space_char: true
...
...
Optimizer:
...
# Add learning rate decay strategy
decay:
function: cosine_decay
# Each epoch contains iter number
step_each_epoch: 20
# Total epoch number
total_epoch: 1000
```
**Note that the configuration file for prediction/evaluation must be consistent with the training.**
......
......@@ -45,12 +45,20 @@ class LMDBReader(object):
self.use_tps = False
if "tps" in params:
self.ues_tps = True
self.use_distort = False
if "distort" in params:
self.use_distort = params['distort'] and params['use_gpu']
if not params['use_gpu']:
logger.info(
"Distort operation can only support in GPU. Distort will be set to False."
)
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
self.drop_last = True
else:
self.batch_size = params['test_batch_size_per_card']
self.drop_last = False
self.use_distort = False
self.infer_img = params['infer_img']
def load_hierarchical_lmdb_dataset(self):
......@@ -142,7 +150,8 @@ class LMDBReader(object):
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length)
max_text_length=self.max_text_length,
distort=self.use_distort)
if outs is None:
continue
yield outs
......@@ -185,12 +194,20 @@ class SimpleReader(object):
self.use_tps = False
if "tps" in params:
self.use_tps = True
self.use_distort = False
if "distort" in params:
self.use_distort = params['distort'] and params['use_gpu']
if not params['use_gpu']:
logger.info(
"Distort operation can only support in GPU.Distort will be set to False."
)
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
self.drop_last = True
else:
self.batch_size = params['test_batch_size_per_card']
self.drop_last = False
self.use_distort = False
def __call__(self, process_id):
if self.mode != 'train':
......@@ -232,9 +249,14 @@ class SimpleReader(object):
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
label = substr[1]
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
outs = process_image(
img=img,
image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length,
distort=self.use_distort)
if outs is None:
continue
yield outs
......
......@@ -15,6 +15,7 @@
import math
import cv2
import numpy as np
import random
from ppocr.utils.utility import initial_logger
logger = initial_logger()
......@@ -89,6 +90,254 @@ def get_img_data(value):
return imgori
def flag():
"""
flag
"""
return 1 if random.random() > 0.5000001 else -1
def cvtColor(img):
"""
cvtColor
"""
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
delta = 0.001 * random.random() * flag()
hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
return new_img
def blur(img):
"""
blur
"""
h, w, _ = img.shape
if h > 10 and w > 10:
return cv2.GaussianBlur(img, (5, 5), 1)
else:
return img
def jitter(img):
"""
jitter
"""
w, h, _ = img.shape
if h > 10 and w > 10:
thres = min(w, h)
s = int(random.random() * thres * 0.01)
src_img = img.copy()
for i in range(s):
img[i:, i:, :] = src_img[:w - i, :h - i, :]
return img
else:
return img
def add_gasuss_noise(image, mean=0, var=0.1):
"""
Gasuss noise
"""
noise = np.random.normal(mean, var**0.5, image.shape)
out = image + 0.5 * noise
out = np.clip(out, 0, 255)
out = np.uint8(out)
return out
def get_crop(image):
"""
random crop
"""
h, w, _ = image.shape
top_min = 1
top_max = 8
top_crop = int(random.randint(top_min, top_max))
top_crop = min(top_crop, h - 1)
crop_img = image.copy()
ratio = random.randint(0, 1)
if ratio:
crop_img = crop_img[top_crop:h, :, :]
else:
crop_img = crop_img[0:h - top_crop, :, :]
return crop_img
class Config:
"""
Config
"""
def __init__(self, ):
self.anglex = random.random() * 30
self.angley = random.random() * 15
self.anglez = random.random() * 10
self.fov = 42
self.r = 0
self.shearx = random.random() * 0.3
self.sheary = random.random() * 0.05
self.borderMode = cv2.BORDER_REPLICATE
def make(self, w, h, ang):
"""
make
"""
self.anglex = random.random() * 5 * flag()
self.angley = random.random() * 5 * flag()
self.anglez = -1 * random.random() * int(ang) * flag()
self.fov = 42
self.r = 0
self.shearx = 0
self.sheary = 0
self.borderMode = cv2.BORDER_REPLICATE
self.w = w
self.h = h
self.perspective = True
self.crop = True
self.affine = False
self.reverse = True
self.noise = True
self.jitter = True
self.blur = True
self.color = True
def rad(x):
"""
rad
"""
return x * np.pi / 180
def get_warpR(config):
"""
get_warpR
"""
anglex, angley, anglez, fov, w, h, r = \
config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
if w > 69 and w < 112:
anglex = anglex * 1.5
z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
# Homogeneous coordinate transformation matrix
rx = np.array([[1, 0, 0, 0],
[0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
0,
-np.sin(rad(anglex)),
np.cos(rad(anglex)),
0,
], [0, 0, 0, 1]], np.float32)
ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
[0, 1, 0, 0], [
-np.sin(rad(angley)),
0,
np.cos(rad(angley)),
0,
], [0, 0, 0, 1]], np.float32)
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
[0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
r = rx.dot(ry).dot(rz)
# generate 4 points
pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
p3 = np.array([0, h, 0, 0], np.float32) - pcenter
p4 = np.array([w, h, 0, 0], np.float32) - pcenter
dst1 = r.dot(p1)
dst2 = r.dot(p2)
dst3 = r.dot(p3)
dst4 = r.dot(p4)
list_dst = np.array([dst1, dst2, dst3, dst4])
org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
dst = np.zeros((4, 2), np.float32)
# Project onto the image plane
dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]
warpR = cv2.getPerspectiveTransform(org, dst)
dst1, dst2, dst3, dst4 = dst
r1 = int(min(dst1[1], dst2[1]))
r2 = int(max(dst3[1], dst4[1]))
c1 = int(min(dst1[0], dst3[0]))
c2 = int(max(dst2[0], dst4[0]))
try:
ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))
dx = -c1
dy = -r1
T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
ret = T1.dot(warpR)
except:
ratio = 1.0
T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
ret = T1
return ret, (-r1, -c1), ratio, dst
def get_warpAffine(config):
"""
get_warpAffine
"""
anglez = config.anglez
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
return rz
def warp(img, ang):
"""
warp
"""
h, w, _ = img.shape
config = Config()
config.make(w, h, ang)
new_img = img
if config.perspective:
tp = random.randint(1, 100)
if tp >= 50:
warpR, (r1, c1), ratio, dst = get_warpR(config)
new_w = int(np.max(dst[:, 0])) - int(np.min(dst[:, 0]))
new_img = cv2.warpPerspective(
new_img,
warpR, (int(new_w * ratio), h),
borderMode=config.borderMode)
if config.crop:
img_height, img_width = img.shape[0:2]
tp = random.randint(1, 100)
if tp >= 50 and img_height >= 20 and img_width >= 20:
new_img = get_crop(new_img)
if config.affine:
warpT = get_warpAffine(config)
new_img = cv2.warpAffine(
new_img, warpT, (w, h), borderMode=config.borderMode)
if config.blur:
tp = random.randint(1, 100)
if tp >= 50:
new_img = blur(new_img)
if config.color:
tp = random.randint(1, 100)
if tp >= 50:
new_img = cvtColor(new_img)
if config.jitter:
new_img = jitter(new_img)
if config.noise:
tp = random.randint(1, 100)
if tp >= 50:
new_img = add_gasuss_noise(new_img)
if config.reverse:
tp = random.randint(1, 100)
if tp >= 50:
new_img = 255 - new_img
return new_img
def process_image(img,
image_shape,
label=None,
......@@ -96,7 +345,10 @@ def process_image(img,
loss_type=None,
max_text_length=None,
tps=None,
infer_mode=False):
infer_mode=False,
distort=False):
if distort:
img = warp(img, 10)
if infer_mode and char_ops.character_type == "ch" and not tps:
norm_img = resize_norm_img_chinese(img, image_shape)
else:
......
......@@ -30,12 +30,17 @@ class CharacterOps(object):
dict_character = list(self.character_str)
elif self.character_type == "ch":
character_dict_path = config['character_dict_path']
add_space = False
if 'use_space_char' in config:
add_space = config['use_space_char']
self.character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
if add_space:
self.character_str += " "
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
......@@ -93,7 +98,7 @@ class CharacterOps(object):
if is_remove_duplicate:
if idx > 0 and text_index[idx - 1] == text_index[idx]:
continue
char_list.append(self.character[text_index[idx]])
char_list.append(self.character[int(text_index[idx])])
text = ''.join(char_list)
return text
......
......@@ -39,7 +39,8 @@ class TextRecognizer(object):
self.rec_algorithm = args.rec_algorithm
char_ops_params = {
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
if self.rec_algorithm != "RARE":
char_ops_params['loss_type'] = 'ctc'
......
......@@ -63,6 +63,7 @@ def parse_args():
"--rec_char_dict_path",
type=str,
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=bool, default=True)
return parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册