Unverified commit e0bf9c9c · Author: shaohua.zhang · Committer: GitHub

Merge pull request #13 from PaddlePaddle/develop

update-2020-8-26
[English](README.md) | 简体中文
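## Introduction
PaddleOCR aims to build a rich, leading, and practical OCR tool library that helps users train better models and put them into production.
**Recent updates**
- 2020.8.24 Support installing and using PaddleOCR through the whl package; see the [PaddleOCR Package guide](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/whl.md)
- 2020.8.21 Updated the recording and PPT of the August 18 Bilibili live class, lesson 2, an easy-to-learn and easy-to-use OCR toolkit, [available here](https://aistudio.baidu.com/aistudio/education/group/info/1519)
- 2020.8.16 Open-sourced the text detection algorithm [SAST](https://arxiv.org/abs/1908.05498) and the text recognition algorithm [SRN](https://arxiv.org/abs/2003.12294)
- 2020.7.23 Released the recording and PPT of the July 21 Bilibili live class, lesson 1, a full walkthrough of the PaddleOCR open-source toolkit, [available here](https://aistudio.baidu.com/aistudio/course/introduce/1519)
- 2020.7.15 Added mobile demos based on EasyEdge and Paddle-Lite, supporting iOS and Android
- [more](./doc/doc_ch/update.md)
## Features
- Ultra-lightweight Chinese OCR model, with a total model size of only 8.6M
- A single model supports recognition of mixed Chinese, English, and digits, vertical text, and long text
- Detection model DB (4.1M) + recognition model CRNN (4.5M)
- Practical general-purpose Chinese OCR model
- Multiple inference deployment options, including service deployment and on-device deployment
- Multiple text detection training algorithms: EAST, DB
- Multiple text recognition training algorithms: Rosetta, CRNN, STAR-Net, RARE
- Runs on Linux, Windows, macOS, and other systems
## Quick experience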
<div align="center">
<img src="doc/imgs_results/11.jpg" width="800">
</div>
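The image above shows results from the ultra-lightweight Chinese OCR model; for more results, see the [visualization page](./doc/doc_ch/visualization.md).
- Online demo of the ultra-lightweight Chinese OCR: https://www.paddlepaddle.org.cn/hub/scene/ocr
- Mobile demo (based on EasyEdge and Paddle-Lite, supports iOS and Android): [get the installer QR code](https://ai.baidu.com/easyedge/app/openSource?from=paddlelite)
Android users can also scan the QR code below to install the app.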
<div align="center">
<img src="./doc/ocr-android-easyedge.png" width = "200" height = "200" />
</div>
- [**Quick start with the Chinese OCR models**](./doc/doc_ch/quickstart.md)
## Chinese OCR model list
|Model name|Description|Detection model|Recognition model|Recognition model with space support|
|-|-|-|-|-|
|chinese_db_crnn_mobile|Ultra-lightweight Chinese OCR model|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)
|chinese_db_crnn_server|General Chinese OCR model|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance.tar)
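## Tutorials
- [Installation](./doc/doc_ch/installation.md)
- [Quick start with the Chinese OCR models](./doc/doc_ch/quickstart.md)
- Algorithm introduction
- [Text detection](#文本检测算法)
- [Text recognition](#文本识别算法)
- [End-to-end OCR](#端到端OCR算法)
- Model training/evaluation
- [Text detection](./doc/doc_ch/detection.md)
- [Text recognition](./doc/doc_ch/recognition.md)
- [yml configuration file guide](./doc/doc_ch/config.md)
- [Chinese OCR training and prediction tricks](./doc/doc_ch/tricks.md)
- Inference and deployment
- [Inference with the Python prediction engine](./doc/doc_ch/inference.md)
- [Inference with the C++ prediction engine](./deploy/cpp_infer/readme.md)
- [Service deployment](./doc/doc_ch/serving.md)
- [On-device deployment](./deploy/lite/readme.md)
- Model quantization and compression (coming soon)
- [Benchmark](./doc/doc_ch/benchmark.md)
- Datasets
- [General Chinese/English OCR datasets](./doc/doc_ch/datasets.md)
- [Handwritten Chinese OCR datasets](./doc/doc_ch/handwritten_datasets.md)
- [Vertical-domain and multilingual OCR datasets](./doc/doc_ch/vertical_and_multilingual_datasets.md)
- [Common data annotation tools](./doc/doc_ch/data_annotation.md)
- [Common data synthesis tools](./doc/doc_ch/data_synthesis.md)
- [FAQ](#FAQ)
- Visualization
- [Ultra-lightweight Chinese OCR visualization](#超轻量级中文OCR效果展示)
- [General Chinese OCR visualization](#通用中文OCR效果展示)
- [Chinese OCR visualization with space support](#支持空格的中文OCR效果展示)
- [Community](#欢迎加入PaddleOCR技术交流群)
- [References](./doc/doc_ch/reference.md)
- [License](#许可证书)
- [Contribution](#贡献代码)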
<a name="算法介绍"></a>
## Algorithm introduction
<a name="文本检测算法"></a>
### 1. Text detection algorithms
Text detection algorithms open-sourced in PaddleOCR:
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] DB([paper](https://arxiv.org/abs/1911.08947))
- [x] SAST([paper](https://arxiv.org/abs/1908.05498)) (developed by Baidu)
On the ICDAR2015 public text detection dataset, the results are as follows:
|Model|Backbone|precision|recall|Hmean|Download link|
|-|-|-|-|-|-|
|EAST|ResNet50_vd|88.18%|85.51%|86.82%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)|
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)|
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)|
|DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)|
|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[Download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)|
On the Total-text public text detection dataset, the results are as follows:
|Model|Backbone|precision|recall|Hmean|Download link|
|-|-|-|-|-|-|
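|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[Download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)|
**Note:** SAST training additionally used public datasets such as icdar2013, icdar2017, COCO-Text, and ArT for fine-tuning. The English public datasets used by PaddleOCR, in an organized format, can be downloaded from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (extraction code: 2bpi).
Using the [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt) street-view dataset, 30,000 images in total, the configurations and pre-trained files for training the Chinese detection models are as follows:
|Model|Backbone|Configuration file|Pre-trained model|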
|-|-|-|-|
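|Ultra-lightweight Chinese model|MobileNetV3|det_mv3_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|
|General Chinese OCR model|ResNet50_vd|det_r50_vd_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|
* Note: To train and evaluate the DB models above, set the post-processing parameters box_thresh=0.6 and unclip_ratio=1.5. When training with different datasets or models, these two parameters can be tuned for better results.
For training and using the PaddleOCR text detection algorithms, see [the text detection part of model training/evaluation](./doc/doc_ch/detection.md) in the tutorials.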
<a name="文本识别算法"></a>
### 2. Text recognition algorithms
Text recognition algorithms open-sourced in PaddleOCR:
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
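- [x] SRN([paper](https://arxiv.org/abs/2003.12294)) (developed by Baidu)
Following the [DTRB](https://arxiv.org/abs/1904.01906) training and evaluation protocol for text recognition, the models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. The results are as follows:
|Model|Backbone|Avg Accuracy|Model storage name|Download link|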
|-|-|-|-|-|
|Rosetta|Resnet34_vd|80.24%|rec_r34_vd_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_none_none_ctc.tar)|
|Rosetta|MobileNetV3|78.16%|rec_mv3_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_none_none_ctc.tar)|
|CRNN|Resnet34_vd|82.20%|rec_r34_vd_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_none_bilstm_ctc.tar)|
|CRNN|MobileNetV3|79.37%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_none_bilstm_ctc.tar)|
|STAR-Net|Resnet34_vd|83.93%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_ctc.tar)|
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)|
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|
**Note:** The SRN model uses a data perturbation method to augment the two training sets mentioned above. The augmented data can be downloaded from [Baidu Drive](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA), extraction code: y3ry.
The original paper uses two-stage training and reports an average accuracy of 89.74%; PaddleOCR uses one-stage training, with an average accuracy of 88.33%. Both sets of pre-trained weights are available at the [download link](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar).
From the [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt) street-view dataset, 300,000 text images were cropped according to the ground truth and position-calibrated. In addition, 5 million synthetic images were generated from the LSVT corpus to train the Chinese models. The configurations and pre-trained files are as follows:
|Model|Backbone|Configuration file|Pre-trained model|
|-|-|-|-|
|Ultra-lightweight Chinese model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|
|General Chinese OCR model|Resnet34_vd|rec_chinese_common_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|
For training and using the PaddleOCR text recognition algorithms, see [the text recognition part of model training/evaluation](./doc/doc_ch/recognition.md) in the tutorials.
<a name="端到端OCR算法"></a>
### 3. End-to-end OCR algorithms
- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808) (developed by Baidu, coming soon)
## Visualization
<a name="超轻量级中文OCR效果展示"></a>
### 1. Ultra-lightweight Chinese OCR visualization [more](./doc/doc_ch/visualization.md)
<div align="center">
<img src="doc/imgs_results/1.jpg" width="800">
</div>
<a name="通用中文OCR效果展示"></a>
### 2. General Chinese OCR visualization [more](./doc/doc_ch/visualization.md)
<div align="center">
<img src="doc/imgs_results/chinese_db_crnn_server/11.jpg" width="800">
</div>
<a name="支持空格的中文OCR效果展示"></a>
### 3. Chinese OCR visualization with space support [more](./doc/doc_ch/visualization.md)
<div align="center">
<img src="doc/imgs_results/chinese_db_crnn_server/en_paper.jpg" width="800">
</div>
<a name="FAQ"></a>
## FAQ
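1. **Error when converting the attention recognition model: KeyError: 'predict'**
This issue has been fixed; please update to the latest code.
2. **About inference speed**
When an image contains a lot of text, prediction takes longer. You can use --rec_batch_num to set a smaller prediction batch size; the default is 30, and it can be changed to 10 or another value.
3. **Service deployment and mobile deployment**
The Serving-based service deployment solution and the Paddle Lite-based mobile deployment solution are expected to be released one after another in mid-to-late June. Stay tuned.
4. **Release schedule for self-developed algorithms**
The self-developed algorithms SAST, SRN, and End2End-PSL will be released successively in July and August. Stay tuned.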
[more](./doc/doc_ch/FAQ.md)
<a name="欢迎加入PaddleOCR技术交流群"></a>
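## Join the PaddleOCR technical exchange group
Scan the QR code below and complete the questionnaire to get the group QR code and a set of OCR model-training tips.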
<div align="center">
<img src="./doc/joinus.jpg" width = "200" height = "200" />
</div>
<a name="许可证书"></a>
## License
This project is released under the <a href="https://github.com/PaddlePaddle/PaddleOCR/blob/master/LICENSE">Apache 2.0 license</a>.
<a name="贡献代码"></a>
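## Contribution
We warmly welcome code contributions to PaddleOCR and greatly appreciate your feedback.
- Many thanks to [Khanh Tran](https://github.com/xxxpsyduck) and [Karl Horky](https://github.com/karlhorky) for contributing to and revising the English documentation.
- Many thanks to [zhangxin](https://github.com/ZhangXinNan) ([Blog](https://blog.csdn.net/sdlypyzq)) for contributing a new visualization method, adding .gitignore, and removing the need to set the PYTHONPATH environment variable manually.
- Many thanks to [lyl120117](https://github.com/lyl120117) for contributing the code that prints the network structure.
- Many thanks to [xiangyubo](https://github.com/xiangyubo) for contributing the handwritten Chinese OCR dataset.
- Many thanks to [authorfu](https://github.com/authorfu) for contributing the Android demo code and to [xiadeye](https://github.com/xiadeye) for contributing the iOS demo code.
- Many thanks to [BeyondYourself](https://github.com/BeyondYourself) for many great suggestions and for simplifying part of PaddleOCR's code style.
- Many thanks to [tangmq](https://gitee.com/tangmq) for adding Dockerized deployment services to PaddleOCR, which support quickly publishing callable RESTful API services.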
English | [简体中文](README.md)
## Introduction
PaddleOCR aims to create rich, leading, and practical OCR tools that help users train better models and put them into practice.
**Recent updates**
- 2020.8.24, Support installing and using PaddleOCR through the whl package; please refer to [PaddleOCR Package](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/whl_en.md)
- 2020.8.16, Release text detection algorithm [SAST](https://arxiv.org/abs/1908.05498) and text recognition algorithm [SRN](https://arxiv.org/abs/2003.12294)
- 2020.7.23, Release the recording and PPT of the live class on Bilibili, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519)
- 2020.7.15, Add mobile App demo, supporting both iOS and Android (based on EasyEdge and Paddle Lite)
- 2020.7.15, Improve the deployment ability: add C++ inference and serving deployment. In addition, benchmarks of the ultra-lightweight OCR model are provided.
- 2020.7.15, Add several related datasets, data annotation and synthesis tools.
- [more](./doc/doc_en/update_en.md)
## Features
- Ultra-lightweight OCR model, total model size is only 8.6M
- A single model supports recognition of mixed Chinese, English, and digits, vertical text, and long text
- Detection model DB (4.1M) + recognition model CRNN (4.5M)
- Various text detection algorithms: EAST, DB
- Various text recognition algorithms: Rosetta, CRNN, STAR-Net, RARE
- Support Linux, Windows, macOS and other systems.
## Visualization
![](doc/imgs_results/11.jpg)
![](doc/imgs_results/img_10.jpg)
[More visualization](./doc/doc_en/visualization_en.md)
You can also quickly try the ultra-lightweight OCR: [Online Experience](https://www.paddlepaddle.org.cn/hub/scene/ocr)
Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Android systems): [Sign in to the website to obtain the QR code for installing the App](https://ai.baidu.com/easyedge/app/openSource?from=paddlelite)
You can also scan the QR code below to install the App (**Android only**)
<div align="center">
<img src="./doc/ocr-android-easyedge.png" width = "200" height = "200" />
</div>
- [**OCR Quick Start**](./doc/doc_en/quickstart_en.md)
<a name="Supported-Chinese-model-list"></a>
### Supported Models:
|Model Name|Description|Detection model link|Recognition model link|Recognition model link (with space support)|
|-|-|-|-|-|
|db_crnn_mobile|ultra-lightweight OCR model|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)
|db_crnn_server|General OCR model|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance_infer.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance.tar)
## Tutorials
- [Installation](./doc/doc_en/installation_en.md)
- [Quick Start](./doc/doc_en/quickstart_en.md)
- Algorithm introduction
- [Text Detection Algorithm](#TEXTDETECTIONALGORITHM)
- [Text Recognition Algorithm](#TEXTRECOGNITIONALGORITHM)
- [END-TO-END OCR Algorithm](#ENDENDOCRALGORITHM)
- Model training/evaluation
- [Text Detection](./doc/doc_en/detection_en.md)
- [Text Recognition](./doc/doc_en/recognition_en.md)
- [Yml Configuration](./doc/doc_en/config_en.md)
- [Tricks](./doc/doc_en/tricks_en.md)
- Deployment
- [Python Inference](./doc/doc_en/inference_en.md)
- [C++ Inference](./deploy/cpp_infer/readme_en.md)
- [Serving](./doc/doc_en/serving_en.md)
- [Mobile](./deploy/lite/readme_en.md)
- Model Quantization and Compression (coming soon)
- [Benchmark](./doc/doc_en/benchmark_en.md)
- Datasets
- [General OCR Datasets (Chinese/English)](./doc/doc_en/datasets_en.md)
- [Handwritten OCR Datasets (Chinese)](./doc/doc_en/handwritten_datasets_en.md)
- [Various OCR Datasets (multilingual)](./doc/doc_en/vertical_and_multilingual_datasets_en.md)
- [Data Annotation Tools](./doc/doc_en/data_annotation_en.md)
- [Data Synthesis Tools](./doc/doc_en/data_synthesis_en.md)
- [FAQ](#FAQ)
- Visualization
- [Ultra-lightweight Chinese/English OCR Visualization](#UCOCRVIS)
- [General Chinese/English OCR Visualization](#GeOCRVIS)
- [Chinese/English OCR Visualization (Space Support)](#SpaceOCRVIS)
- [Community](#Community)
- [References](./doc/doc_en/reference_en.md)
- [License](#LICENSE)
- [Contribution](#CONTRIBUTION)
<a name="TEXTDETECTIONALGORITHM"></a>
## Text Detection Algorithm
PaddleOCR open source text detection algorithms list:
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] DB([paper](https://arxiv.org/abs/1911.08947))
- [x] SAST([paper](https://arxiv.org/abs/1908.05498)) (developed by Baidu)
On the ICDAR2015 dataset, the text detection results are as follows:
|Model|Backbone|precision|recall|Hmean|Download link|
|-|-|-|-|-|-|
|EAST|ResNet50_vd|88.18%|85.51%|86.82%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)|
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)|
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)|
|DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)|
|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[Download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)|
On the Total-Text dataset, the text detection results are as follows:
|Model|Backbone|precision|recall|Hmean|Download link|
|-|-|-|-|-|-|
|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[Download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)|
**Note:** Additional public datasets, such as icdar2013, icdar2017, COCO-Text, and ArT, were added to the training of the SAST model. The English public datasets used by PaddleOCR, in an organized format, can be downloaded from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
Using the [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) street view dataset, which contains 30,000 (3w) training images in total, the related configuration and pre-trained models for the text detection task are as follows:
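The Hmean column in the detection tables above agrees, to within rounding of the published inputs, with the harmonic mean of precision and recall, 2PR / (P + R). The following minimal sketch recomputes it for a few rows; the function name `hmean` is illustrative and is not part of PaddleOCR.

```python
def hmean(precision, recall):
    """Harmonic mean of precision and recall (both given in percent)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall) pairs copied from the ICDAR2015 table above.
rows = {
    "EAST ResNet50_vd": (88.18, 85.51),
    "DB ResNet50_vd": (83.79, 80.65),
    "SAST ResNet50_vd": (92.18, 82.96),
}

for name, (p, r) in rows.items():
    print(f"{name}: Hmean = {hmean(p, r):.2f}%")
```

Because the table lists precision and recall already rounded to two decimals, a recomputed Hmean can differ from the published value in the last digit for some rows.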
|Model|Backbone|Configuration file|Pre-trained model|
|-|-|-|-|
|ultra-lightweight OCR model|MobileNetV3|det_mv3_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|
|General OCR model|ResNet50_vd|det_r50_vd_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|
* Note: For the training and evaluation of the above DB models, the post-processing parameters box_thresh=0.6 and unclip_ratio=1.5 need to be set. When training with different datasets or different models, these two parameters can be adjusted for better results.
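As background for the two parameters in the note above: in the DB paper, a detected (shrunk) text region is expanded back by an offset distance D = A × r / L, where A is the polygon area, L is its perimeter, and r is the unclip ratio; a score threshold then discards low-confidence boxes. The sketch below only computes that distance and applies a threshold. The helper names are hypothetical, and this is an illustration under those assumptions, not PaddleOCR's post-processing code.

```python
import math


def polygon_area(points):
    """Absolute polygon area via the shoelace formula."""
    total = 0.0
    n = len(points)
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]
        total += x1 * y2 - x2 * y1
    return abs(total) / 2.0


def polygon_perimeter(points):
    """Sum of the edge lengths of a closed polygon."""
    n = len(points)
    total = 0.0
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]
        total += math.hypot(x2 - x1, y2 - y1)
    return total


def unclip_distance(points, unclip_ratio=1.5):
    """Offset distance D = A * r / L used to expand a shrunk text region."""
    return polygon_area(points) * unclip_ratio / polygon_perimeter(points)


def keep_box(score, box_thresh=0.6):
    """Keep a candidate box only if its score reaches box_thresh."""
    return score >= box_thresh


box = [(0, 0), (100, 0), (100, 20), (0, 20)]  # a 100 x 20 text region
print(unclip_distance(box))  # area 2000, perimeter 240
print(keep_box(0.55), keep_box(0.8))
```

A larger `unclip_ratio` grows every box further outward, while a higher `box_thresh` drops more low-score candidates, which is why both are worth tuning per dataset.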
For the training guide and use of PaddleOCR text detection algorithms, please refer to the document [Text detection model training/evaluation/prediction](./doc/doc_en/detection_en.md)
<a name="TEXTRECOGNITIONALGORITHM"></a>
## Text Recognition Algorithm
PaddleOCR open-source text recognition algorithms list:
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
- [x] SRN([paper](https://arxiv.org/abs/2003.12294)) (developed by Baidu)
Following the training and evaluation protocol of [DTRB](https://arxiv.org/abs/1904.01906), the text recognition algorithms above are trained on MJSynth and SynthText and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. The results are as follows:
|Model|Backbone|Avg Accuracy|Module combination|Download link|
|-|-|-|-|-|
|Rosetta|Resnet34_vd|80.24%|rec_r34_vd_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_none_none_ctc.tar)|
|Rosetta|MobileNetV3|78.16%|rec_mv3_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_none_none_ctc.tar)|
|CRNN|Resnet34_vd|82.20%|rec_r34_vd_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_none_bilstm_ctc.tar)|
|CRNN|MobileNetV3|79.37%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_none_bilstm_ctc.tar)|
|STAR-Net|Resnet34_vd|83.93%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_ctc.tar)|
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)|
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|
**Note:** The SRN model uses a data perturbation method to augment the two training sets mentioned above, and the augmented data can be downloaded from [Baidu Drive](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA) (download code: y3ry).
The average accuracy of the two-stage training in the original paper is 89.74%, and that of the one-stage training in PaddleOCR is 88.33%. Both pre-trained weights can be downloaded [here](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar).
We use the [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) dataset and crop out 300,000 (30w) training images from the original photos using the position ground truth, with calibration where needed. In addition, 5 million (500w) synthetic images are generated from the LSVT corpus to train the model. The related configuration and pre-trained models are as follows:
|Model|Backbone|Configuration file|Pre-trained model|Recognition model with space support|
|-|-|-|-|-|
|ultra-lightweight OCR model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)|
|General OCR model|Resnet34_vd|rec_chinese_common_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_enhance.tar)|
For the training guide and use of PaddleOCR text recognition algorithms, please refer to the document [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md)
<a name="ENDENDOCRALGORITHM"></a>
## END-TO-END OCR Algorithm
- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808) (developed by Baidu, coming soon)
## Visualization
<a name="UCOCRVIS"></a>
### 1. Ultra-lightweight Chinese/English OCR Visualization [more](./doc/doc_en/visualization_en.md)
<div align="center">
<img src="doc/imgs_results/1.jpg" width="800">
</div>
<a name="GeOCRVIS"></a>
### 2. General Chinese/English OCR Visualization [more](./doc/doc_en/visualization_en.md)
<div align="center">
<img src="doc/imgs_results/chinese_db_crnn_server/11.jpg" width="800">
</div>
<a name="SpaceOCRVIS"></a>
### 3. Chinese/English OCR Visualization (Space Support) [more](./doc/doc_en/visualization_en.md)
<div align="center">
<img src="doc/imgs_results/chinese_db_crnn_server/en_paper.jpg" width="800">
</div>
<a name="FAQ"></a>
## FAQ
1. Error when using attention-based recognition model: KeyError: 'predict'
The inference of recognition model based on attention loss is still being debugged. For Chinese text recognition, it is recommended to choose the recognition model based on CTC loss first. In practice, it is also found that the recognition model based on attention loss is not as effective as the one based on CTC loss.
2. About inference speed
When there are a lot of texts in the picture, the prediction time will increase. You can use `--rec_batch_num` to set a smaller prediction batch size. The default value is 30, which can be changed to 10 or other values.
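To illustrate what a batch-size setting such as `--rec_batch_num` controls, the sketch below splits a list of cropped text regions into consecutive batches; each batch would be handled in one forward pass, so the batch size bounds the work and memory per pass. The function `iter_batches` and the string stand-ins for image crops are assumptions for illustration, not PaddleOCR's actual implementation.

```python
def iter_batches(items, batch_num=30):
    """Yield consecutive slices of at most batch_num items."""
    if batch_num < 1:
        raise ValueError("batch_num must be at least 1")
    for start in range(0, len(items), batch_num):
        yield items[start:start + batch_num]


crops = [f"crop_{i}" for i in range(65)]  # stand-ins for cropped text images
print([len(b) for b in iter_batches(crops, batch_num=30)])
print([len(b) for b in iter_batches(crops, batch_num=10)])
```

With 65 crops, a batch size of 30 gives three passes and a batch size of 10 gives seven smaller ones; which setting is faster depends on the hardware and the image sizes.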
3. Service deployment and mobile deployment
It is expected that the service deployment based on Serving and the mobile deployment based on Paddle Lite will be released successively in mid-to-late June. Stay tuned for more updates.
4. Release time of self-developed algorithm
Baidu's self-developed algorithms, such as SAST, SRN, and End2End-PSL, will be released in June or July. Stay tuned.
[more](./doc/doc_en/FAQ_en.md)
<a name="Community"></a>
## Community
Scan the QR code below with WeChat and complete the questionnaire to join the official technical exchange group.
<div align="center">
<img src="./doc/joinus.jpg" width = "200" height = "200" />
</div>
<a name="LICENSE"></a>
## License
This project is released under the <a href="https://github.com/PaddlePaddle/PaddleOCR/blob/master/LICENSE">Apache 2.0 license</a>.
<a name="CONTRIBUTION"></a>
## Contribution
We welcome all contributions to PaddleOCR and greatly appreciate your feedback.
- Many thanks to [Khanh Tran](https://github.com/xxxpsyduck) and [Karl Horky](https://github.com/karlhorky) for contributing to and revising the English documentation.
- Many thanks to [zhangxin](https://github.com/ZhangXinNan) for contributing the new visualization function, adding .gitignore, and removing the need to set PYTHONPATH manually.
- Many thanks to [lyl120117](https://github.com/lyl120117) for contributing the code for printing the network structure.
- Thanks to [xiangyubo](https://github.com/xiangyubo) for contributing the handwritten Chinese OCR datasets.
- Thanks to [authorfu](https://github.com/authorfu) for contributing the Android demo and to [xiadeye](https://github.com/xiadeye) for contributing the iOS demo.
- Thanks to [BeyondYourself](https://github.com/BeyondYourself) for contributing many great suggestions and simplifying part of the code style.
- Thanks to [tangmq](https://gitee.com/tangmq) for contributing Dockerized deployment services to PaddleOCR and supporting the rapid release of callable RESTful API services.
@@ -27,7 +27,7 @@ Architecture:
   function: ppocr.modeling.architectures.rec_model,RecModel
 Backbone:
-  function: ppocr.modeling.backbones.rec_resnet50_fpn,ResNet
+  function: ppocr.modeling.backbones.rec_resnet_fpn,ResNet
   layers: 50
 Head:
......
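The `function:` values in this config follow a `module.path,ClassName` pattern, and the hunk above changes only the module part of the backbone entry. As a rough illustration of how such a string could be resolved at run time, the sketch below uses `importlib`. The helper `resolve` is hypothetical, and a standard-library target stands in for the PaddleOCR module so the example runs without PaddlePaddle installed; PaddleOCR's own loader may differ.

```python
import importlib


def resolve(spec):
    """Resolve a 'module.path,AttrName' string to the named attribute."""
    module_path, _, attr_name = spec.partition(",")
    module_path, attr_name = module_path.strip(), attr_name.strip()
    if not module_path or not attr_name:
        raise ValueError(f"expected 'module,Name', got {spec!r}")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# Stand-in target; in the config above the string would be
# "ppocr.modeling.backbones.rec_resnet_fpn,ResNet".
cls = resolve("collections,OrderedDict")
print(cls.__name__)
```

Under this kind of scheme, a config string has to track the module's file name, so a renamed backbone module would need a matching config change like the one in the hunk.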
@@ -36,4 +36,6 @@ def read_params():
     # cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
     # cfg.use_space_char = True
-    return cfg
\ No newline at end of file
+    cfg.use_zero_copy_run = False
+    return cfg
@@ -38,4 +38,6 @@ def read_params():
     cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
     cfg.use_space_char = True
-    return cfg
\ No newline at end of file
+    cfg.use_zero_copy_run = False
+    return cfg
@@ -38,4 +38,6 @@ def read_params():
     cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
     cfg.use_space_char = True
-    return cfg
\ No newline at end of file
+    cfg.use_zero_copy_run = False
+    return cfg
@@ -45,7 +45,7 @@ At present, the open source model, dataset and magnitude are as follows:
 Among them, the public datasets are opensourced, users can search and download by themselves, or refer to [Chinese data set](./datasets_en.md), synthetic data is not opensourced, users can use open-source synthesis tools to synthesize data themselves. Current available synthesis tools include [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText), [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator), etc.
 10. **Error in using the model with TPS module for prediction**
-Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3](108) != Grid dimension[2](100)
+Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3]\(108) != Grid dimension[2]\(100)
 Solution:TPS does not support variable shape. Please set --rec_image_shape='3,32,100' and --rec_char_type='en'
 11. **Custom dictionary used during training, the recognition results show that words do not appear in the dictionary**
......
@@ -214,6 +214,8 @@ class SimpleReader(object):
         self.mode = params['mode']
         self.infer_img = params['infer_img']
         self.use_tps = False
+        if "num_heads" in params:
+            self.num_heads = params['num_heads']
         if "tps" in params:
             self.use_tps = True
         self.use_distort = False
@@ -251,12 +253,19 @@ class SimpleReader(object):
                     img = cv2.imread(single_img)
                     if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-                    norm_img = process_image(
-                        img=img,
-                        image_shape=self.image_shape,
-                        char_ops=self.char_ops,
-                        tps=self.use_tps,
-                        infer_mode=True)
+                    if self.loss_type == 'srn':
+                        norm_img = process_image_srn(
+                            img=img,
+                            image_shape=self.image_shape,
+                            num_heads=self.num_heads,
+                            max_text_length=self.max_text_length)
+                    else:
+                        norm_img = process_image(
+                            img=img,
+                            image_shape=self.image_shape,
+                            char_ops=self.char_ops,
+                            tps=self.use_tps,
+                            infer_mode=True)
                     yield norm_img
             else:
                 with open(self.label_file_path, "rb") as fin:
@@ -286,14 +295,25 @@ class SimpleReader(object):
                         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                     label = substr[1]
-                    outs = process_image(
-                        img=img,
-                        image_shape=self.image_shape,
-                        label=label,
-                        char_ops=self.char_ops,
-                        loss_type=self.loss_type,
-                        max_text_length=self.max_text_length,
-                        distort=self.use_distort)
+                    if self.loss_type == "srn":
+                        outs = process_image_srn(
+                            img=img,
+                            image_shape=self.image_shape,
+                            num_heads=self.num_heads,
+                            max_text_length=self.max_text_length,
+                            label=label,
+                            char_ops=self.char_ops,
+                            loss_type=self.loss_type)
+                    else:
+                        outs = process_image(
+                            img=img,
+                            image_shape=self.image_shape,
+                            label=label,
+                            char_ops=self.char_ops,
+                            loss_type=self.loss_type,
+                            max_text_length=self.max_text_length,
+                            distort=self.use_distort)
                     if outs is None:
                         continue
                     yield outs
......
...@@ -410,7 +410,8 @@ def resize_norm_img_srn(img, image_shape): ...@@ -410,7 +410,8 @@ def resize_norm_img_srn(img, image_shape):
def srn_other_inputs(image_shape, def srn_other_inputs(image_shape,
num_heads, num_heads,
max_text_length): max_text_length,
char_num):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8)) feature_dim = int((imgH / 8) * (imgW / 8))
...@@ -418,7 +419,7 @@ def srn_other_inputs(image_shape, ...@@ -418,7 +419,7 @@ def srn_other_inputs(image_shape,
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64') encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64') gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64') lbl_weight = np.array([int(char_num-1)] * max_text_length).reshape((-1,1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length]) gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
@@ -441,17 +442,18 @@ def process_image_srn(img,
                       loss_type=None):
     norm_img = resize_norm_img_srn(img, image_shape)
     norm_img = norm_img[np.newaxis, :]
+    char_num = char_ops.get_char_num()
     [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
-        srn_other_inputs(image_shape, num_heads, max_text_length)
+        srn_other_inputs(image_shape, num_heads, max_text_length,char_num)
     if label is not None:
-        char_num = char_ops.get_char_num()
         text = char_ops.encode(label)
         if len(text) == 0 or len(text) > max_text_length:
             return None
         else:
             if loss_type == "srn":
-                text_padded = [37] * max_text_length
+                text_padded = [int(char_num-1)] * max_text_length
                 for i in range(len(text)):
                     text_padded[i] = text[i]
                     lbl_weight[i] = [1.0]
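The hunk above replaces the hard-coded padding index `37` with `char_num - 1`, so the padding slot follows the size of the character dictionary instead of the 36-character English set. A minimal sketch of that padding step follows. The helper name `pad_srn_label` and its return shape are hypothetical, and it assumes the last class index is reserved for padding, as the diff suggests; the real `process_image_srn` does more than this.

```python
import numpy as np


def pad_srn_label(text, max_text_length, char_num):
    """Pad an encoded label to max_text_length (hypothetical helper).

    Assumes index char_num - 1 is the padding class. Returns None for
    empty or over-long labels, mirroring the check in the diff.
    """
    if len(text) == 0 or len(text) > max_text_length:
        return None
    pad_idx = int(char_num - 1)
    text_padded = [pad_idx] * max_text_length
    # The weight array is initialised with the padding index, as in
    # srn_other_inputs, then set to 1 at every real character position.
    lbl_weight = np.array([pad_idx] * max_text_length).reshape(
        (-1, 1)).astype("int64")
    for i, idx in enumerate(text):
        text_padded[i] = idx
        lbl_weight[i] = [1.0]
    padded = np.array(text_padded, dtype="int64").reshape((-1, 1))
    return padded, lbl_weight
```

With a 38-class dictionary, `pad_srn_label([3, 5], 4, 38)` pads the label with index 37; a larger dictionary moves the padding index accordingly.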
......
@@ -22,12 +22,12 @@ import paddle
 import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
-__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
+__all__ = [
+    "ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"
+]
 Trainable = True
-w_nolr = fluid.ParamAttr(
-    trainable = Trainable)
+w_nolr = fluid.ParamAttr(trainable=Trainable)
 train_parameters = {
     "input_size": [3, 224, 224],
     "input_mean": [0.485, 0.456, 0.406],
@@ -40,12 +40,12 @@ train_parameters = {
     }
 }
 class ResNet():
     def __init__(self, params):
         self.layers = params['layers']
         self.params = train_parameters
     def __call__(self, input):
         layers = self.layers
         supported_layers = [18, 34, 50, 101, 152]
@@ -60,12 +60,17 @@ class ResNet():
             depth = [3, 4, 23, 3]
         elif layers == 152:
             depth = [3, 8, 36, 3]
-        stride_list = [(2,2),(2,2),(1,1),(1,1)]
+        stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
         num_filters = [64, 128, 256, 512]
         conv = self.conv_bn_layer(
-            input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1")
-        F = []
+            input=input,
+            num_filters=64,
+            filter_size=7,
+            stride=2,
+            act='relu',
+            name="conv1")
+        F = []
         if layers >= 50:
             for block in range(len(depth)):
                 for i in range(depth[block]):
@@ -79,26 +84,67 @@ class ResNet():
                     conv = self.bottleneck_block(
                         input=conv,
                         num_filters=num_filters[block],
-                        stride=stride_list[block] if i == 0 else 1, name=conv_name)
+                        stride=stride_list[block] if i == 0 else 1,
+                        name=conv_name)
+                F.append(conv)
+        else:
+            for block in range(len(depth)):
+                for i in range(depth[block]):
+                    conv_name = "res" + str(block + 2) + chr(97 + i)
+                    if i == 0 and block != 0:
+                        stride = (2, 1)
+                    else:
+                        stride = (1, 1)
+                    conv = self.basic_block(
+                        input=conv,
+                        num_filters=num_filters[block],
+                        stride=stride,
+                        if_first=block == i == 0,
+                        name=conv_name)
                 F.append(conv)
         base = F[-1]
         for i in [-2, -3]:
             b, c, w, h = F[i].shape
-            if (w,h) == base.shape[2:]:
+            if (w, h) == base.shape[2:]:
                 base = base
             else:
-                base = fluid.layers.conv2d_transpose( input=base, num_filters=c,filter_size=4, stride=2,
-                    padding=1,act=None,
+                base = fluid.layers.conv2d_transpose(
+                    input=base,
+                    num_filters=c,
+                    filter_size=4,
+                    stride=2,
+                    padding=1,
+                    act=None,
                     param_attr=w_nolr,
                     bias_attr=w_nolr)
-                base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
+                base = fluid.layers.batch_norm(
+                    base, act="relu", param_attr=w_nolr, bias_attr=w_nolr)
             base = fluid.layers.concat([base, F[i]], axis=1)
-            base = fluid.layers.conv2d(base, num_filters=c, filter_size=1, param_attr=w_nolr, bias_attr=w_nolr)
-            base = fluid.layers.conv2d(base, num_filters=c, filter_size=3,padding = 1, param_attr=w_nolr, bias_attr=w_nolr)
-            base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
-        base = fluid.layers.conv2d(base, num_filters=512, filter_size=1,bias_attr=w_nolr,param_attr=w_nolr)
+            base = fluid.layers.conv2d(
+                base,
+                num_filters=c,
+                filter_size=1,
+                param_attr=w_nolr,
+                bias_attr=w_nolr)
+            base = fluid.layers.conv2d(
+                base,
+                num_filters=c,
+                filter_size=3,
+                padding=1,
+                param_attr=w_nolr,
+                bias_attr=w_nolr)
+            base = fluid.layers.batch_norm(
+                base, act="relu", param_attr=w_nolr, bias_attr=w_nolr)
+        base = fluid.layers.conv2d(
+            base,
+            num_filters=512,
+            filter_size=1,
+            bias_attr=w_nolr,
+            param_attr=w_nolr)
         return base
@@ -113,13 +159,14 @@ class ResNet():
         conv = fluid.layers.conv2d(
             input=input,
             num_filters=num_filters,
-            filter_size= 2 if stride==(1,1) else filter_size,
-            dilation = 2 if stride==(1,1) else 1,
+            filter_size=2 if stride == (1, 1) else filter_size,
+            dilation=2 if stride == (1, 1) else 1,
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=groups,
             act=None,
-            param_attr=ParamAttr(name=name + "_weights",trainable = Trainable),
+            param_attr=ParamAttr(
+                name=name + "_weights", trainable=Trainable),
             bias_attr=False,
             name=name + '.conv2d.output.1')
@@ -127,28 +174,35 @@ class ResNet():
             bn_name = "bn_" + name
         else:
             bn_name = "bn" + name[3:]
-        return fluid.layers.batch_norm(input=conv,
-                                       act=act,
-                                       name=bn_name + '.output.1',
-                                       param_attr=ParamAttr(name=bn_name + '_scale',trainable = Trainable),
-                                       bias_attr=ParamAttr(bn_name + '_offset',trainable = Trainable),
-                                       moving_mean_name=bn_name + '_mean',
-                                       moving_variance_name=bn_name + '_variance', )
+        return fluid.layers.batch_norm(
+            input=conv,
+            act=act,
+            name=bn_name + '.output.1',
+            param_attr=ParamAttr(
+                name=bn_name + '_scale', trainable=Trainable),
+            bias_attr=ParamAttr(
+                bn_name + '_offset', trainable=Trainable),
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance', )
     def shortcut(self, input, ch_out, stride, is_first, name):
         ch_in = input.shape[1]
         if ch_in != ch_out or stride != 1 or is_first == True:
-            if stride == (1,1):
+            if stride == (1, 1):
                 return self.conv_bn_layer(input, ch_out, 1, 1, name=name)
             else: #stride == (2,2)
                 return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
         else:
             return input
     def bottleneck_block(self, input, num_filters, stride, name):
         conv0 = self.conv_bn_layer(
-            input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
+            input=input,
+            num_filters=num_filters,
+            filter_size=1,
+            act='relu',
+            name=name + "_branch2a")
         conv1 = self.conv_bn_layer(
             input=conv0,
             num_filters=num_filters,
@@ -157,16 +211,36 @@ class ResNet():
             act='relu',
             name=name + "_branch2b")
         conv2 = self.conv_bn_layer(
-            input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
+            input=conv1,
+            num_filters=num_filters * 4,
+            filter_size=1,
+            act=None,
+            name=name + "_branch2c")
-        short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1")
+        short = self.shortcut(
+            input,
+            num_filters * 4,
+            stride,
+            is_first=False,
+            name=name + "_branch1")
-        return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name + ".add.output.5")
+        return fluid.layers.elementwise_add(
+            x=short, y=conv2, act='relu', name=name + ".add.output.5")
     def basic_block(self, input, num_filters, stride, is_first, name):
-        conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride,
-                                   name=name + "_branch2a")
-        conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None,
-                                   name=name + "_branch2b")
-        short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
+        conv0 = self.conv_bn_layer(
+            input=input,
+            num_filters=num_filters,
+            filter_size=3,
+            act='relu',
+            stride=stride,
+            name=name + "_branch2a")
+        conv1 = self.conv_bn_layer(
+            input=conv0,
+            num_filters=num_filters,
+            filter_size=3,
+            act=None,
+            name=name + "_branch2b")
+        short = self.shortcut(
+            input, num_filters, stride, is_first, name=name + "_branch1")
         return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
@@ -26,8 +26,6 @@ class CharacterOps(object):
         self.character_type = config['character_type']
         self.loss_type = config['loss_type']
         self.max_text_len = config['max_text_length']
-        if self.loss_type == "srn" and self.character_type != "en":
-            raise Exception("SRN can only support in character_type == en")
         if self.character_type == "en":
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)
@@ -160,13 +158,15 @@ def cal_predicts_accuracy_srn(char_ops,
     acc_num = 0
     img_num = 0
+    char_num = char_ops.get_char_num()
     total_len = preds.shape[0]
     img_num = int(total_len / max_text_len)
     for i in range(img_num):
         cur_label = []
         cur_pred = []
         for j in range(max_text_len):
-            if labels[j + i * max_text_len] != 37: #0
+            if labels[j + i * max_text_len] != int(char_num-1): #0
                 cur_label.append(labels[j + i * max_text_len][0])
             else:
                 break
@@ -178,7 +178,7 @@ def cal_predicts_accuracy_srn(char_ops,
             elif j == len(cur_label) and j == max_text_len:
                 acc_num += 1
                 break
-            elif j == len(cur_label) and preds[j + i * max_text_len][0] == 37:
+            elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(char_num-1):
                 acc_num += 1
                 break
     acc = acc_num * 1.0 / img_num
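The two hunks above make the accuracy check compare against `char_num - 1` rather than the literal `37`. The sketch below restates the same sequence-level rule in vectorised form. The function name `srn_accuracy` is hypothetical and the code is a simplification for illustration, not the repository's implementation. A sample counts as correct when the prediction matches every label character and then emits the padding index, or when the label fills all `max_text_len` slots.

```python
import numpy as np


def srn_accuracy(preds, labels, max_text_len, char_num):
    """Sequence-level accuracy with a dictionary-dependent padding index.

    preds and labels hold img_num * max_text_len class indices each.
    Assumes index char_num - 1 is the padding class.
    """
    pad_idx = int(char_num - 1)
    preds = np.asarray(preds).reshape(-1, max_text_len)
    labels = np.asarray(labels).reshape(-1, max_text_len)
    acc_num = 0
    for pred, label in zip(preds, labels):
        # The label length is the position of the first padding index.
        pad_pos = np.where(label == pad_idx)[0]
        n = int(pad_pos[0]) if len(pad_pos) else max_text_len
        chars_match = np.array_equal(pred[:n], label[:n])
        # The prediction must stop where the label stops.
        ends_ok = n == max_text_len or pred[n] == pad_idx
        if chars_match and ends_ok:
            acc_num += 1
    return acc_num / len(labels)
```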
......
@@ -140,12 +140,12 @@ def main():
             preds = preds.reshape(-1)
             preds_text = char_ops.decode(preds)
         elif loss_type == "srn":
-            cur_pred = []
+            char_num = char_ops.get_char_num()
             preds = np.array(predict[0])
             preds = preds.reshape(-1)
             probs = np.array(predict[1])
             ind = np.argmax(probs, axis=1)
-            valid_ind = np.where(preds != 37)[0]
+            valid_ind = np.where(preds != int(char_num-1))[0]
             if len(valid_ind) == 0:
                 continue
             score = np.mean(probs[valid_ind, ind[valid_ind]])
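The inference hunk above filters out padding positions before averaging the per-character probabilities. A minimal, self-contained sketch of that scoring step is below. The function name `srn_confidence` is hypothetical, and it assumes `probs` has one row per output position and one column per class, with `char_num - 1` as the padding class.

```python
import numpy as np


def srn_confidence(preds, probs, char_num):
    """Mean top-1 probability over non-padding positions.

    Returns None when every position is padding, which corresponds to
    the `continue` in the diff.
    """
    pad_idx = int(char_num - 1)
    preds = np.asarray(preds).reshape(-1)
    probs = np.asarray(probs)
    ind = np.argmax(probs, axis=1)
    valid_ind = np.where(preds != pad_idx)[0]
    if len(valid_ind) == 0:
        return None
    return float(np.mean(probs[valid_ind, ind[valid_ind]]))
```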
......