Commit a8ba8963, authored by xuyang2233

merge conflict 0815

Global:
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/sr/sr_tsrn_transformer_strock/
  save_epoch_step: 3
  # evaluation is run every 1000 iterations
  eval_batch_step: [0, 1000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir: sr_output
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_52.png
  # for data or label process
  character_dict_path: ./train_data/srdata/english_decomposition.txt
  max_text_length: 100
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/sr/predicts_gestalt.txt

Optimizer:
  name: Adam
  beta1: 0.5
  beta2: 0.999
  clip_norm: 0.25
  lr:
    learning_rate: 0.0001

Architecture:
  model_type: sr
  algorithm: Gestalt
  Transform:
    name: TSRN
    STN: True
    infer_mode: False

Loss:
  name: StrokeFocusLoss
  character_dict_path: ./train_data/srdata/english_decomposition.txt

PostProcess:
  name: None

Metric:
  name: SRMetric
  main_indicator: all

Train:
  dataset:
    name: LMDBDataSetSR
    data_dir: ./train_data/srdata/train
    transforms:
      - SRResize:
          imgH: 32
          imgW: 128
          down_sample_scale: 2
      - SRLabelEncode: # Class handling label
      - KeepKeys:
          keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    batch_size_per_card: 16
    drop_last: True
    num_workers: 4

Eval:
  dataset:
    name: LMDBDataSetSR
    data_dir: ./train_data/srdata/test
    transforms:
      - SRResize:
          imgH: 32
          imgW: 128
          down_sample_scale: 2
      - SRLabelEncode: # Class handling label
      - KeepKeys:
          keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 16
    num_workers: 4
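# Typical usage of this config with the standard PaddleOCR entry points
# (see the Text Gestalt guide later in this commit; weight paths are illustrative):
#   python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
#   python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml \
#       -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png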
# Key Information Extraction Algorithm - LayoutXLM
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)
- [Citation](#citation)
<a name="1"></a>
## 1. Introduction
Paper:
> [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836)
>
> Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei
>
> 2021
On the XFUND_zh dataset, the reproduced results of the algorithm are as follows:
|Model|Backbone|Task|Config|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
|LayoutXLM|LayoutXLM-base|SER|[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
|LayoutXLM|LayoutXLM-base|RE|[re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model (coming soon)]()|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to the [key information extraction tutorial](./kie.md). PaddleOCR has modularized the code, so training a different key information extraction model only requires **switching the configuration file**.
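As a minimal illustration (assuming the standard `tools/train.py` entry point used throughout PaddleOCR; the output directory below is only an example), SER training with the config from the table above can be launched as follows:
```bash
# single-GPU SER training with LayoutXLM on XFUND_zh (illustrative)
python3 tools/train.py \
    -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml \
    -o Global.save_model_dir=./output/ser_layoutxlm_xfund_zh/
```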
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
**Note:** RE-task inference is still being adapted. The following takes the SER task as an example to introduce key information extraction with the LayoutXLM model.
First, convert the trained model into an inference model. Taking the LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)), the conversion can be done with the following commands.
``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
tar -xf ser_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm
```
For SER-task inference with the LayoutXLM model, run the following command:
```bash
cd ppstructure
python3 vqa/predict_vqa_token_ser.py \
--vqa_algorithm=LayoutXLM \
--ser_model_dir=../inference/ser_layoutxlm_infer \
--image_dir=./docs/vqa/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf
```
The SER visualization results are saved in the `./output` directory by default. An example of the results is shown below:
<div align="center">
<img src="../../ppstructure/docs/vqa/result_ser/zh_val_42_ser.jpg" width="800">
</div>
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{DBLP:journals/corr/abs-2104-08836,
author = {Yiheng Xu and
Tengchao Lv and
Lei Cui and
Guoxin Wang and
Yijuan Lu and
Dinei Flor{\^{e}}ncio and
Cha Zhang and
Furu Wei},
title = {LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich
Document Understanding},
journal = {CoRR},
volume = {abs/2104.08836},
year = {2021},
url = {https://arxiv.org/abs/2104.08836},
eprinttype = {arXiv},
eprint = {2104.08836},
timestamp = {Thu, 14 Oct 2021 09:17:23 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-2104-08836.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@article{DBLP:journals/corr/abs-1912-13318,
author = {Yiheng Xu and
Minghao Li and
Lei Cui and
Shaohan Huang and
Furu Wei and
Ming Zhou},
title = {LayoutLM: Pre-training of Text and Layout for Document Image Understanding},
journal = {CoRR},
volume = {abs/1912.13318},
year = {2019},
url = {http://arxiv.org/abs/1912.13318},
eprinttype = {arXiv},
eprint = {1912.13318},
timestamp = {Mon, 01 Jun 2020 16:20:46 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-1912-13318.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@article{DBLP:journals/corr/abs-2012-14740,
author = {Yang Xu and
Yiheng Xu and
Tengchao Lv and
Lei Cui and
Furu Wei and
Guoxin Wang and
Yijuan Lu and
Dinei A. F. Flor{\^{e}}ncio and
Cha Zhang and
Wanxiang Che and
Min Zhang and
Lidong Zhou},
title = {LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding},
journal = {CoRR},
volume = {abs/2012.14740},
year = {2020},
url = {https://arxiv.org/abs/2012.14740},
eprinttype = {arXiv},
eprint = {2012.14740},
timestamp = {Tue, 27 Jul 2021 09:53:52 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-2012-14740.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
# Key Information Extraction Algorithm - SDMGR
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
  - [3.1 Training](#31-training)
  - [3.2 Evaluation](#32-evaluation)
  - [3.3 Prediction](#33-prediction)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)
- [Citation](#citation)
<a name="1"></a>
## 1. Introduction
Paper:
> [Spatial Dual-Modality Graph Reasoning for Key Information Extraction](https://arxiv.org/abs/2103.14470)
>
> Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang
>
> 2021
On the public wildreceipt receipt dataset, the reproduced results of the algorithm are as follows:
|Model|Backbone|Config|Hmean|Download link|
| --- | --- | --- | --- | --- |
|SDMGR|VGG6|[configs/kie/sdmgr/kie_unet_sdmgr.yml](../../configs/kie/sdmgr/kie_unet_sdmgr.yml)|86.7%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)/[inference model (coming soon)]()|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
SDMGR is a key information extraction algorithm that classifies each detected text region into predefined categories, such as order ID, invoice number, and amount.
Training and evaluation use the wildreceipt dataset, which can be downloaded with the following command:
```bash
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
```
Create a soft link to the dataset under the `PaddleOCR/train_data` directory:
```
cd PaddleOCR/ && mkdir train_data && cd train_data
ln -s ../../wildreceipt ./
```
### 3.1 Training
Training uses the configuration file `configs/kie/sdmgr/kie_unet_sdmgr.yml`, in which the default training data path is `train_data/wildreceipt`. Once the data is ready, start training with the following command:
```
python3 tools/train.py -c configs/kie/sdmgr/kie_unet_sdmgr.yml -o Global.save_model_dir=./output/kie/
```
### 3.2 Evaluation
Run the following command to evaluate the model:
```bash
python3 tools/eval.py -c configs/kie/sdmgr/kie_unet_sdmgr.yml -o Global.checkpoints=./output/kie/best_accuracy
```
An example of the output is shown below.
```py
[2022/08/10 05:22:23] ppocr INFO: metric eval ***************
[2022/08/10 05:22:23] ppocr INFO: hmean:0.8670120239257812
[2022/08/10 05:22:23] ppocr INFO: fps:10.18816520530961
```
### 3.3 Prediction
Run the following command for prediction. Prediction requires a text file that stores the image paths together with their OCR results, specified via `Global.infer_img`.
```bash
python3 tools/infer_kie.py -c configs/kie/sdmgr/kie_unet_sdmgr.yml -o Global.checkpoints=kie_vgg16/best_accuracy Global.infer_img=./train_data/wildreceipt/1.txt
```
The prediction results are saved in the `./output/sdmgr_kie/predicts_kie.txt` file, and the visualization results are saved in the `./output/sdmgr_kie/kie_results/` directory.
The visualization results are shown below:
<div align="center">
<img src="../../ppstructure/docs/imgs/sdmgr_result.png" width="800">
</div>
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
Not supported
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@misc{sun2021spatial,
title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
year={2021},
eprint={2103.14470},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
# Key Information Extraction Algorithm - VI-LayoutXLM
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)
- [Citation](#citation)
<a name="1"></a>
## 1. Introduction
VI-LayoutXLM is an improved version of LayoutXLM: the visual backbone is removed during downstream-task fine-tuning, which further speeds up inference while keeping accuracy essentially unchanged.
On the XFUND_zh dataset, the reproduced results of the algorithm are as follows:
|Model|Backbone|Task|Config|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
|VI-LayoutXLM|VI-LayoutXLM-base|SER|[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
|VI-LayoutXLM|VI-LayoutXLM-base|RE|[re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model (coming soon)]()|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to the [key information extraction tutorial](./kie.md). PaddleOCR has modularized the code, so training a different key information extraction model only requires **switching the configuration file**.
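As a minimal illustration (assuming the standard `tools/train.py` entry point; the UDML distillation config from the table above can be substituted, and the output directory is only an example), SER training can be launched as follows:
```bash
# single-GPU SER training with VI-LayoutXLM on XFUND_zh (illustrative)
python3 tools/train.py \
    -c configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \
    -o Global.save_model_dir=./output/ser_vi_layoutxlm_xfund_zh/
```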
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
**Note:** RE-task inference is still being adapted. The following takes the SER task as an example to introduce key information extraction with the VI-LayoutXLM model.
First, convert the trained model into an inference model. Taking the VI-LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)), the conversion can be done with the following commands.
``` bash
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar
tar -xf ser_vi_layoutxlm_xfund_pretrained.tar
python3 tools/export_model.py -c configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/ser_vi_layoutxlm_infer
```
For SER-task inference with the VI-LayoutXLM model, run the following command:
```bash
cd ppstructure
python3 vqa/predict_vqa_token_ser.py \
--vqa_algorithm=LayoutXLM \
--ser_model_dir=../inference/ser_vi_layoutxlm_infer \
--image_dir=./docs/vqa/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```
The SER visualization results are saved in the `./output` directory by default. An example of the results is shown below:
<div align="center">
<img src="../../ppstructure/docs/vqa/result_ser/zh_val_42_ser.jpg" width="800">
</div>
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{DBLP:journals/corr/abs-2104-08836,
author = {Yiheng Xu and
Tengchao Lv and
Lei Cui and
Guoxin Wang and
Yijuan Lu and
Dinei Flor{\^{e}}ncio and
Cha Zhang and
Furu Wei},
title = {LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich
Document Understanding},
journal = {CoRR},
volume = {abs/2104.08836},
year = {2021},
url = {https://arxiv.org/abs/2104.08836},
eprinttype = {arXiv},
eprint = {2104.08836},
timestamp = {Thu, 14 Oct 2021 09:17:23 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-2104-08836.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@article{DBLP:journals/corr/abs-1912-13318,
author = {Yiheng Xu and
Minghao Li and
Lei Cui and
Shaohan Huang and
Furu Wei and
Ming Zhou},
title = {LayoutLM: Pre-training of Text and Layout for Document Image Understanding},
journal = {CoRR},
volume = {abs/1912.13318},
year = {2019},
url = {http://arxiv.org/abs/1912.13318},
eprinttype = {arXiv},
eprint = {1912.13318},
timestamp = {Mon, 01 Jun 2020 16:20:46 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-1912-13318.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@article{DBLP:journals/corr/abs-2012-14740,
author = {Yang Xu and
Yiheng Xu and
Tengchao Lv and
Lei Cui and
Furu Wei and
Guoxin Wang and
Yijuan Lu and
Dinei A. F. Flor{\^{e}}ncio and
Cha Zhang and
Wanxiang Che and
Min Zhang and
Lidong Zhou},
title = {LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding},
journal = {CoRR},
volume = {abs/2012.14740},
year = {2020},
url = {https://arxiv.org/abs/2012.14740},
eprinttype = {arXiv},
eprint = {2012.14740},
timestamp = {Tue, 27 Jul 2021 09:53:52 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-2012-14740.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
# Algorithm Overview
- [1. Two-stage OCR Algorithms](#1)
  - [1.1 Text Detection Algorithms](#11)
  - [1.2 Text Recognition Algorithms](#12)
- [2. End-to-end OCR Algorithms](#2)
- [3. Table Recognition Algorithms](#3)
- [4. Key Information Extraction Algorithms](#4)
This page lists the OCR algorithms supported by PaddleOCR, together with the models and metrics of each algorithm on **public English datasets**. It is mainly intended as an algorithm introduction and performance comparison; for models trained on other datasets, including Chinese, please refer to the [PP-OCR v2.0 series model list](./models_list.md).
@@ -116,3 +117,34 @@
|Model|Backbone|Config|Acc|Download link|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
## 4. Key Information Extraction Algorithms
Supported key information extraction algorithms (click the links for tutorials):
- [x] [VI-LayoutXLM](./algorithm_kie_vi_laoutxlm.md)
- [x] [LayoutLM](./algorithm_kie_laoutxlm.md)
- [x] [LayoutLMv2](./algorithm_kie_laoutxlm.md)
- [x] [LayoutXLM](./algorithm_kie_laoutxlm.md)
- [x] [SDMGR](./algorithm_kie_sdmgr.md)
On the public wildreceipt receipt dataset, the reproduced results are as follows:
|Model|Backbone|Config|Hmean|Download link|
| --- | --- | --- | --- | --- |
|SDMGR|VGG6|[configs/kie/sdmgr/kie_unet_sdmgr.yml](../../configs/kie/sdmgr/kie_unet_sdmgr.yml)|86.7%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)|
On the public XFUND_zh dataset, the results are as follows:
|Model|Backbone|Task|Config|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
|VI-LayoutXLM| VI-LayoutXLM-base | SER | [ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|**93.19%**|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)|
|LayoutXLM| LayoutXLM-base | SER | [ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)|
|LayoutLM| LayoutLM-base | SER | [ser_layoutlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml)|77.31%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar)|
|LayoutLMv2| LayoutLMv2-base | SER | [ser_layoutlmv2_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutlmv2_xfund_zh.yml)|85.44%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar)|
|VI-LayoutXLM| VI-LayoutXLM-base | RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|**83.92%**|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)|
|LayoutXLM| LayoutXLM-base | RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)|
|LayoutLMv2| LayoutLMv2-base | RE | [re_layoutlmv2_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutlmv2_xfund_zh.yml)|67.77%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar)|
@@ -101,7 +101,7 @@ python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pre
Run the following command for model inference:
```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt' --use_space_char=False
# To predict all images in a folder, set image_dir to a directory, e.g. --image_dir='./doc/imgs_words_en/'.
```
@@ -110,7 +110,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png'
After running the command, the prediction result (recognized text and score) of the image above is printed to the screen, for example:
The result is as follows:
```shell
Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493)
```
**Note**
......
# Text Gestalt
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
  - [3.1 Training](#3-1)
  - [3.2 Evaluation](#3-2)
  - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
  - [4.1 Python Inference](#4-1)
  - [4.2 C++ Inference](#4-2)
  - [4.3 Serving](#4-3)
  - [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)
> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang
> AAAI, 2022
Following the data download instructions of [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt), the super-resolution results on the TextZoom test set are as follows:
|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Gestalt|tsrn|19.28|0.6560|[configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[trained model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR has modularized the code, so training a different recognition model only requires **switching the configuration file**.
- Training
After the data preparation is completed, training can be started. The training commands are as follows:
```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
# Multi-GPU training, specify the GPU ids through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
```
- Evaluation
```
# GPU evaluation, Global.pretrained_model is the weights to be evaluated
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
- Prediction:
```
# The configuration file used for prediction must be consistent with training
python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```
![](../imgs_words_en/word_52.png)
After running the command, the super-resolution result of the above image is as follows:
![](../imgs_results/sr_word_52.png)
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during text super-resolution training into an inference model. Taking the Text-Gestalt trained [model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) as an example, it can be converted with the following command:
```shell
python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```
For Text-Gestalt text super-resolution inference, run the following command:
```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```
After running the command, the super-resolution result of the image is as follows:
![](../imgs_results/sr_word_52.png)
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@inproceedings{chen2022text,
title={Text gestalt: Stroke-aware scene text image super-resolution},
author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={36},
number={1},
pages={285--293},
year={2022}
}
```
# Information Extraction Datasets
This page collects commonly used DocVQA datasets and is updated continuously; contributions of new datasets are welcome.
- [FUNSD dataset](#funsd)
- [XFUND dataset](#xfund)
- [wildreceipt dataset](#wildreceipt)
<a name="funsd"></a>
## 1. FUNSD dataset
- **Source**: https://guillaumejaume.github.io/FUNSD/
- **Introduction**: The FUNSD dataset is a dataset for form understanding. It contains 199 real, fully annotated scanned images, covering market reports, advertisements, academic reports, and so on, split into 149 training images and 50 test images. The FUNSD dataset is suitable for many kinds of DocVQA tasks, such as field-level entity classification and field-level entity linking. Some images and annotation boxes are visualized below:
<div align="center">
@@ -16,12 +20,33 @@
- **Download**: https://guillaumejaume.github.io/FUNSD/download/
<a name="xfund"></a>
## 2. XFUND dataset
- **Source**: https://github.com/doc-analysis/XFUND
- **Introduction**: XFUND is a multilingual form understanding dataset. It contains forms in 7 different languages, all manually annotated with key-value pairs. The data for each language contains 199 forms, split into 149 training images and 50 test images. Some images and annotation boxes are visualized below:
<div align="center">
<img src="../../datasets/xfund_demo/gt_zh_train_0.jpg" width="500">
<img src="../../datasets/xfund_demo/gt_zh_train_1.jpg" width="500">
</div>
- **Download**: https://github.com/doc-analysis/XFUND/releases/tag/v1.0
<a name="wildreceipt"></a>
## 3. wildreceipt dataset
- **Source**: https://arxiv.org/abs/2103.14470
- **Introduction**: wildreceipt is an English receipt dataset containing 26 categories (the category system here includes the `Ignore` category) and roughly 50,000 annotated text boxes. The training set contains 1267 images and the test set contains 472 images. Some images and annotation boxes are visualized below:
<div align="center">
<img src="../../datasets/wildreceipt_demo/2769.jpeg" width="500">
<img src="../../datasets/wildreceipt_demo/1bbe854b8817dedb8585e0732089fd1f752d2cec.jpeg" width="500">
</div>
**Note:** Text boxes labeled `Ignore` or `Others` are not visualized here.
- **Download**
  - Original data: [link](https://download.openmmlab.com/mmocr/data/wildreceipt.tar)
  - Data converted to the format used for PaddleOCR training: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar)
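As a convenience, a minimal sketch of fetching the PaddleOCR-format data and linking it under `train_data` (mirroring the commands in the SDMGR guide earlier in this commit; adapt the relative paths to your directory layout):
```bash
# download the PaddleOCR-format wildreceipt data and link it into train_data (illustrative)
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
cd PaddleOCR/ && mkdir -p train_data && cd train_data
ln -s ../../wildreceipt ./
```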
@@ -90,7 +90,7 @@ After the conversion is successful, there are three files in the directory:
For VisionLAN text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt' --use_space_char=False
```
![](../imgs_words/en/word_2.png)
@@ -98,7 +98,7 @@ python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png'
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
The result is as follows:
```shell
Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493)
```
<a name="4-2"></a>
......
# Text Gestalt
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)
> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang
> AAAI, 2022
Referring to the [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt) data download instructions, the performance of the super-resolution algorithm on the TextZoom test set is as follows:
|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[train model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
```
Evaluation:
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```
![](../imgs_words_en/word_52.png)
After executing the command, the super-resolution result of the above image is as follows:
![](../imgs_results/sr_word_52.png)
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during training into an inference model. Taking the trained Text-Gestalt [model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) as an example, you can use the following command to convert it:
```shell
python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```
For Text-Gestalt super-resolution model inference, the following commands can be executed:
```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```
After executing the command, the super-resolution result of the above image is as follows:
![](../imgs_results/sr_word_52.png)
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@inproceedings{chen2022text,
title={Text gestalt: Stroke-aware scene text image super-resolution},
author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={36},
number={1},
pages={285--293},
year={2022}
}
```
@@ -34,7 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
@@ -54,7 +54,8 @@ def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)
    support_dict = [
        'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
        'LMDBDataSetSR'
    ]
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception(
......
@@ -1236,6 +1236,54 @@ class ABINetLabelEncode(BaseRecLabelEncode):
        return dict_character
class SRLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(SRLabelEncode, self).__init__(max_text_length,
character_dict_path, use_space_char)
self.dic = {}
with open(character_dict_path, 'r') as fin:
for line in fin.readlines():
line = line.strip()
character, sequence = line.split()
self.dic[character] = sequence
english_stroke_alphabet = '0123456789'
self.english_stroke_dict = {}
for index in range(len(english_stroke_alphabet)):
self.english_stroke_dict[english_stroke_alphabet[index]] = index
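# encode() maps a text label to its stroke-level representation: each character
# is replaced by its stroke sequence from the decomposition dict and terminated
# with '0'. For a hypothetical dict {'a': '12', 'b': '345'}, encode('ab') builds
# the stroke string '1203450' and returns its length together with a zero-padded
# int64 tensor whose positions 1..length-1 hold the digit indices of that string
# (position 0 is left as an implicit start token).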
def encode(self, label):
stroke_sequence = ''
for character in label:
if character not in self.dic:
continue
else:
stroke_sequence += self.dic[character]
stroke_sequence += '0'
label = stroke_sequence
length = len(label)
input_tensor = np.zeros(self.max_text_len).astype("int64")
for j in range(length - 1):
input_tensor[j + 1] = self.english_stroke_dict[label[j]]
return length, input_tensor
def __call__(self, data):
text = data['label']
length, input_tensor = self.encode(text)
data["length"] = length
data["input_tensor"] = input_tensor
if text is None:
return None
return data
class SPINLabelEncode(AttnLabelEncode):
    """ Convert between text-label and text-index """
......
@@ -24,6 +24,7 @@ import six
import cv2
import numpy as np
import math
from PIL import Image
class DecodeImage(object):
@@ -440,3 +441,52 @@ class KieResize(object):
        points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
        points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
        return points
class SRResize(object):
def __init__(self,
imgH=32,
imgW=128,
down_sample_scale=4,
keep_ratio=False,
min_ratio=1,
mask=False,
infer_mode=False,
**kwargs):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
self.min_ratio = min_ratio
self.down_sample_scale = down_sample_scale
self.mask = mask
self.infer_mode = infer_mode
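# __call__ resizes and normalizes the low-resolution image to
# (imgW // down_sample_scale, imgH // down_sample_scale); unless infer_mode is
# set, it also resizes the high-resolution image to (imgW, imgH) and leaves the
# label untouched.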
def __call__(self, data):
imgH = self.imgH
imgW = self.imgW
images_lr = data["image_lr"]
transform2 = ResizeNormalize(
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
images_lr = transform2(images_lr)
data["img_lr"] = images_lr
if self.infer_mode:
return data
images_HR = data["image_hr"]
label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR)
data["img_hr"] = images_HR
return data
class ResizeNormalize(object):
def __init__(self, size, interpolation=Image.BICUBIC):
self.size = size
self.interpolation = interpolation
def __call__(self, img):
img = img.resize(self.size, self.interpolation)
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
@@ -16,6 +16,9 @@ import os
from paddle.io import Dataset
import lmdb
import cv2
import string
import six
from PIL import Image
from .imaug import transform, create_operators
@@ -116,3 +119,58 @@ class LMDBDataSet(Dataset):
    def __len__(self):
        return self.data_idx_order_list.shape[0]
class LMDBDataSetSR(LMDBDataSet):
def buf2PIL(self, txn, key, type='RGB'):
imgbuf = txn.get(key)
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
im = Image.open(buf).convert(type)
return im
def str_filt(self, str_, voc_type):
alpha_dict = {
'digit': string.digits,
'lower': string.digits + string.ascii_lowercase,
'upper': string.digits + string.ascii_letters,
'all': string.digits + string.ascii_letters + string.punctuation
}
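# keep only the characters allowed by voc_type ('upper' keeps digits plus ASCII
# letters); e.g. str_filt('Café-24', 'upper') -> 'Caf24', since accented
# characters and punctuation are dropped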
if voc_type == 'lower':
str_ = str_.lower()
for char in str_:
if char not in alpha_dict[voc_type]:
str_ = str_.replace(char, '')
return str_
def get_lmdb_sample_info(self, txn, index):
self.voc_type = 'upper'
self.max_len = 100
self.test = False
label_key = b'label-%09d' % index
word = str(txn.get(label_key).decode())
img_HR_key = b'image_hr-%09d' % index # 128*32
img_lr_key = b'image_lr-%09d' % index # 64*16
try:
img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
except IOError:
return self[index + 1]
if len(word) > self.max_len:
return self[index + 1]
label_str = self.str_filt(word, self.voc_type)
return img_HR, img_lr, label_str
def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
file_idx)
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
img_HR, img_lr, label_str = sample_info
data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
@@ -57,6 +57,9 @@ from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
# sr loss
from .stroke_focus_loss import StrokeFocusLoss
def build_loss(config):
    support_dict = [
@@ -64,7 +67,7 @@ def build_loss(config):
        'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
        'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/stroke_focus_loss.py
"""
import cv2
import sys
import time
import string
import random
import numpy as np
import paddle.nn as nn
import paddle
class StrokeFocusLoss(nn.Layer):
def __init__(self, character_dict_path=None, **kwargs):
super(StrokeFocusLoss, self).__init__(character_dict_path)
self.mse_loss = nn.MSELoss()
self.ce_loss = nn.CrossEntropyLoss()
self.l1_loss = nn.L1Loss()
self.english_stroke_alphabet = '0123456789'
self.english_stroke_dict = {}
for index in range(len(self.english_stroke_alphabet)):
self.english_stroke_dict[self.english_stroke_alphabet[
index]] = index
stroke_decompose_lines = open(character_dict_path, 'r').readlines()
self.dic = {}
for line in stroke_decompose_lines:
line = line.strip()
character, sequence = line.split()
self.dic[character] = sequence
def forward(self, pred, data):
sr_img = pred["sr_img"]
hr_img = pred["hr_img"]
mse_loss = self.mse_loss(sr_img, hr_img)
word_attention_map_gt = pred["word_attention_map_gt"]
word_attention_map_pred = pred["word_attention_map_pred"]
hr_pred = pred["hr_pred"]
sr_pred = pred["sr_pred"]
attention_loss = paddle.nn.functional.l1_loss(word_attention_map_gt,
word_attention_map_pred)
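# total loss: pixel-level MSE plus the stroke attention alignment (L1) loss,
# weighted by 50 and then scaled by 100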
loss = (mse_loss + attention_loss * 50) * 100
return {
"mse_loss": mse_loss,
"attention_loss": attention_loss,
"loss": loss
}
@@ -30,13 +30,13 @@ from .table_metric import TableMetric
from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric
from .sr_metric import SRMetric
def build_metric(config):
    support_dict = [
        "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
        "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
        'VQAReTokenMetric', 'SRMetric'
    ]
    config = copy.deepcopy(config)
......
@@ -16,6 +16,7 @@ import Levenshtein
import string
class RecMetric(object):
    def __init__(self,
                 main_indicator='acc',
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/utils/ssim_psnr.py
"""
from math import exp
import paddle
import paddle.nn.functional as F
import paddle.nn as nn
import string
class SSIM(nn.Layer):
def __init__(self, window_size=11, size_average=True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = self.create_window(window_size, self.channel)
def gaussian(self, window_size, sigma):
gauss = paddle.to_tensor([
exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
for x in range(window_size)
])
return gauss / gauss.sum()
def create_window(self, window_size, channel):
_1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
window = _2D_window.expand([channel, 1, window_size, window_size])
return window
def _ssim(self, img1, img2, window, window_size, channel,
size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(
img1 * img1, window, padding=window_size // 2,
groups=channel) - mu1_sq
sigma2_sq = F.conv2d(
img2 * img2, window, padding=window_size // 2,
groups=channel) - mu2_sq
sigma12 = F.conv2d(
img1 * img2, window, padding=window_size // 2,
groups=channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean([1, 2, 3])
def ssim(self, img1, img2, window_size=11, size_average=True):
(_, channel, _, _) = img1.shape
window = self.create_window(window_size, channel)
return self._ssim(img1, img2, window, window_size, channel,
size_average)
def forward(self, img1, img2):
(_, channel, _, _) = img1.shape
if channel == self.channel and self.window.dtype == img1.dtype:
window = self.window
else:
window = self.create_window(self.window_size, channel)
self.window = window
self.channel = channel
return self._ssim(img1, img2, window, self.window_size, channel,
self.size_average)
class SRMetric(object):
def __init__(self, main_indicator='all', **kwargs):
self.main_indicator = main_indicator
self.eps = 1e-5
self.psnr_result = []
self.ssim_result = []
self.calculate_ssim = SSIM()
self.reset()
def reset(self):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0
self.psnr_result = []
self.ssim_result = []
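# PSNR for images in [0, 1], rescaled to the 8-bit range:
# PSNR = 20 * log10(255 / sqrt(MSE(img1 * 255, img2 * 255)))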
def calculate_psnr(self, img1, img2):
# img1 and img2 have range [0, 1]
mse = ((img1 * 255 - img2 * 255)**2).mean()
if mse == 0:
return float('inf')
return 20 * paddle.log10(255.0 / paddle.sqrt(mse))
def _normalize_text(self, text):
text = ''.join(
filter(lambda x: x in (string.digits + string.ascii_letters), text))
return text.lower()
def __call__(self, pred_label, *args, **kwargs):
metric = {}
images_sr = pred_label["sr_img"]
images_hr = pred_label["hr_img"]
psnr = self.calculate_psnr(images_sr, images_hr)
ssim = self.calculate_ssim(images_sr, images_hr)
self.psnr_result.append(psnr)
self.ssim_result.append(ssim)
def get_metric(self):
"""
return metrics {
'psnr_avg': 0,
'ssim_avg': 0,
'all': 0
}
"""
self.psnr_avg = sum(self.psnr_result) / len(self.psnr_result)
self.psnr_avg = round(self.psnr_avg.item(), 6)
self.ssim_avg = sum(self.ssim_result) / len(self.ssim_result)
self.ssim_avg = round(self.ssim_avg.item(), 6)
self.all_avg = self.psnr_avg + self.ssim_avg
self.reset()
return {
'psnr_avg': self.psnr_avg,
"ssim_avg": self.ssim_avg,
"all": self.all_avg
}
@@ -14,6 +14,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
@@ -46,6 +47,10 @@ class BaseModel(nn.Layer):
            in_channels = self.transform.out_channels
        # build backbone, backbone is need for det, rec and cls
        if 'Backbone' not in config or config['Backbone'] is None:
            self.use_backbone = False
        else:
            self.use_backbone = True
            config["Backbone"]['in_channels'] = in_channels
            self.backbone = build_backbone(config["Backbone"], model_type)
            in_channels = self.backbone.out_channels
@@ -77,6 +82,7 @@ class BaseModel(nn.Layer):
        y = dict()
        if self.use_transform:
            x = self.transform(x)
        if self.use_backbone:
            x = self.backbone(x)
        if isinstance(x, dict):
            y.update(x)
......
@@ -113,7 +113,6 @@ class LayoutLMv2ForSer(NLPBaseModel):
            pretrained,
            checkpoints,
            num_classes=num_classes)
        if hasattr(self.model.layoutlmv2, "use_visual_backbone"
                   ) and self.model.layoutlmv2.use_visual_backbone is False:
            self.use_visual_backbone = False
@@ -155,7 +154,9 @@ class LayoutXLMForSer(NLPBaseModel):
            pretrained,
            checkpoints,
            num_classes=num_classes)
        if hasattr(self.model.layoutxlm, "use_visual_backbone"
                   ) and self.model.layoutxlm.use_visual_backbone is False:
            self.use_visual_backbone = False

    def forward(self, x):
        if self.use_visual_backbone is True:
@@ -185,6 +186,9 @@ class LayoutLMv2ForRe(NLPBaseModel):
        super(LayoutLMv2ForRe, self).__init__(
            LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re",
            pretrained, checkpoints)
        if hasattr(self.model.layoutlmv2, "use_visual_backbone"
                   ) and self.model.layoutlmv2.use_visual_backbone is False:
            self.use_visual_backbone = False

    def forward(self, x):
        x = self.model(
@@ -207,7 +211,6 @@ class LayoutXLMForRe(NLPBaseModel):
        super(LayoutXLMForRe, self).__init__(
            LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re",
            pretrained, checkpoints)
        if hasattr(self.model.layoutxlm, "use_visual_backbone"
                   ) and self.model.layoutxlm.use_visual_backbone is False:
            self.use_visual_backbone = False
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import math, copy
import numpy as np
# stroke-level alphabet
alphabet = '0123456789'
def get_alphabet_len():
return len(alphabet)
def subsequent_mask(size):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
"""
mask = paddle.ones([1, size, size], dtype='float32')
mask_inf = paddle.triu(
paddle.full(
shape=[1, size, size], dtype='float32', fill_value='-inf'),
diagonal=1)
mask = mask + mask_inf
padding_mask = paddle.equal(mask, paddle.to_tensor(1, dtype=mask.dtype))
return padding_mask
def clones(module, N):
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
def attention(query, key, value, mask=None, dropout=None, attention_map=None):
d_k = query.shape[-1]
scores = paddle.matmul(query,
paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)
if mask is not None:
scores = masked_fill(scores, mask == 0, float('-inf'))
else:
pass
p_attn = F.softmax(scores, axis=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return paddle.matmul(p_attn, value), p_attn
class MultiHeadedAttention(nn.Layer):
def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
self.compress_attention = compress_attention
self.compress_attention_linear = nn.Linear(h, 1)
def forward(self, query, key, value, mask=None, attention_map=None):
if mask is not None:
mask = mask.unsqueeze(1)
nbatches = query.shape[0]
query, key, value = \
[paddle.transpose(l(x).reshape([nbatches, -1, self.h, self.d_k]), [0,2,1,3])
for l, x in zip(self.linears, (query, key, value))]
x, attention_map = attention(
query,
key,
value,
mask=mask,
dropout=self.dropout,
attention_map=attention_map)
x = paddle.reshape(
paddle.transpose(x, [0, 2, 1, 3]),
[nbatches, -1, self.h * self.d_k])
return self.linears[-1](x), attention_map
class ResNet(nn.Layer):
def __init__(self, num_in, block, layers):
super(ResNet, self).__init__()
self.conv1 = nn.Conv2D(num_in, 64, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2D(64, use_global_stats=True)
self.relu1 = nn.ReLU()
self.pool = nn.MaxPool2D((2, 2), (2, 2))
self.conv2 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(128, use_global_stats=True)
self.relu2 = nn.ReLU()
self.layer1_pool = nn.MaxPool2D((2, 2), (2, 2))
self.layer1 = self._make_layer(block, 128, 256, layers[0])
self.layer1_conv = nn.Conv2D(256, 256, 3, 1, 1)
self.layer1_bn = nn.BatchNorm2D(256, use_global_stats=True)
self.layer1_relu = nn.ReLU()
self.layer2_pool = nn.MaxPool2D((2, 2), (2, 2))
self.layer2 = self._make_layer(block, 256, 256, layers[1])
self.layer2_conv = nn.Conv2D(256, 256, 3, 1, 1)
self.layer2_bn = nn.BatchNorm2D(256, use_global_stats=True)
self.layer2_relu = nn.ReLU()
self.layer3_pool = nn.MaxPool2D((2, 2), (2, 2))
self.layer3 = self._make_layer(block, 256, 512, layers[2])
self.layer3_conv = nn.Conv2D(512, 512, 3, 1, 1)
self.layer3_bn = nn.BatchNorm2D(512, use_global_stats=True)
self.layer3_relu = nn.ReLU()
self.layer4_pool = nn.MaxPool2D((2, 2), (2, 2))
self.layer4 = self._make_layer(block, 512, 512, layers[3])
self.layer4_conv2 = nn.Conv2D(512, 1024, 3, 1, 1)
self.layer4_conv2_bn = nn.BatchNorm2D(1024, use_global_stats=True)
self.layer4_conv2_relu = nn.ReLU()
def _make_layer(self, block, inplanes, planes, blocks):
if inplanes != planes:
downsample = nn.Sequential(
nn.Conv2D(inplanes, planes, 3, 1, 1),
nn.BatchNorm2D(
planes, use_global_stats=True), )
else:
downsample = None
layers = []
layers.append(block(inplanes, planes, downsample))
for i in range(1, blocks):
layers.append(block(planes, planes, downsample=None))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.layer1_pool(x)
x = self.layer1(x)
x = self.layer1_conv(x)
x = self.layer1_bn(x)
x = self.layer1_relu(x)
x = self.layer2(x)
x = self.layer2_conv(x)
x = self.layer2_bn(x)
x = self.layer2_relu(x)
x = self.layer3(x)
x = self.layer3_conv(x)
x = self.layer3_bn(x)
x = self.layer3_relu(x)
x = self.layer4(x)
x = self.layer4_conv2(x)
x = self.layer4_conv2_bn(x)
x = self.layer4_conv2_relu(x)
return x
class Bottleneck(nn.Layer):
def __init__(self, input_dim):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2D(input_dim, input_dim, 1)
self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1)
self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
class PositionalEncoding(nn.Layer):
"Implement the PE function."
def __init__(self, dropout, dim, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
pe = paddle.zeros([max_len, dim])
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
div_term = paddle.exp(
paddle.arange(0, dim, 2).astype('float32') *
(-math.log(10000.0) / dim))
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = paddle.unsqueeze(pe, 0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :paddle.shape(x)[1]]
return self.dropout(x)
class PositionwiseFeedForward(nn.Layer):
"Implements FFN equation."
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout, mode="downscale_in_infer")
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class Generator(nn.Layer):
"Define standard linear + softmax generation step."
def __init__(self, d_model, vocab):
super(Generator, self).__init__()
self.proj = nn.Linear(d_model, vocab)
self.relu = nn.ReLU()
def forward(self, x):
out = self.proj(x)
return out
class Embeddings(nn.Layer):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, x):
embed = self.lut(x) * math.sqrt(self.d_model)
return embed
class LayerNorm(nn.Layer):
"Construct a layernorm module (See citation for details)."
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.a_2 = self.create_parameter(
shape=[features],
default_initializer=paddle.nn.initializer.Constant(1.0))
self.b_2 = self.create_parameter(
shape=[features],
default_initializer=paddle.nn.initializer.Constant(0.0))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
class Decoder(nn.Layer):
def __init__(self):
super(Decoder, self).__init__()
self.mask_multihead = MultiHeadedAttention(
h=16, d_model=1024, dropout=0.1)
self.mul_layernorm1 = LayerNorm(1024)
self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
self.mul_layernorm2 = LayerNorm(1024)
self.pff = PositionwiseFeedForward(1024, 2048)
self.mul_layernorm3 = LayerNorm(1024)
def forward(self, text, conv_feature, attention_map=None):
text_max_length = text.shape[1]
mask = subsequent_mask(text_max_length)
result = text
result = self.mul_layernorm1(result + self.mask_multihead(
text, text, text, mask=mask)[0])
b, c, h, w = conv_feature.shape
conv_feature = paddle.transpose(
conv_feature.reshape([b, c, h * w]), [0, 2, 1])
word_image_align, attention_map = self.multihead(
result,
conv_feature,
conv_feature,
mask=None,
attention_map=attention_map)
result = self.mul_layernorm2(result + word_image_align)
result = self.mul_layernorm3(result + self.pff(result))
return result, attention_map
class BasicBlock(nn.Layer):
def __init__(self, inplanes, planes, downsample):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2D(
inplanes, planes, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(
planes, planes, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample != None:
residual = self.downsample(residual)
out += residual
out = self.relu(out)
return out
class Encoder(nn.Layer):
def __init__(self):
super(Encoder, self).__init__()
self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])
def forward(self, input):
conv_result = self.cnn(input)
return conv_result
class Transformer(nn.Layer):
def __init__(self, in_channels=1):
super(Transformer, self).__init__()
word_n_class = get_alphabet_len()
self.embedding_word_with_upperword = Embeddings(512, word_n_class)
self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)
self.encoder = Encoder()
self.decoder = Decoder()
self.generator_word_with_upperword = Generator(1024, word_n_class)
for p in self.parameters():
if p.dim() > 1:
nn.initializer.XavierNormal(p)
def forward(self, image, text_length, text_input, attention_map=None):
if image.shape[1] == 3:
R = image[:, 0:1, :, :]
G = image[:, 1:2, :, :]
B = image[:, 2:3, :, :]
image = 0.299 * R + 0.587 * G + 0.114 * B
conv_feature = self.encoder(image) # batch, 1024, 8, 32
max_length = max(text_length)
text_input = text_input[:, :max_length]
text_embedding = self.embedding_word_with_upperword(
text_input) # batch, text_max_length, 512
postion_embedding = self.pe(
paddle.zeros(text_embedding.shape)) # batch, text_max_length, 512
text_input_with_pe = paddle.concat([text_embedding, postion_embedding],
2) # batch, text_max_length, 1024
batch, seq_len, _ = text_input_with_pe.shape
text_input_with_pe, word_attention_map = self.decoder(
text_input_with_pe, conv_feature)
word_decoder_result = self.generator_word_with_upperword(
text_input_with_pe)
if self.training:
total_length = paddle.sum(text_length)
probs_res = paddle.zeros([total_length, get_alphabet_len()])
start = 0
for index, length in enumerate(text_length):
length = int(length.numpy())
probs_res[start:start + length, :] = word_decoder_result[
index, 0:0 + length, :]
start = start + length
return probs_res, word_attention_map, None
else:
return word_decoder_result
...@@ -18,10 +18,10 @@ __all__ = ['build_transform'] ...@@ -18,10 +18,10 @@ __all__ = ['build_transform']
def build_transform(config): def build_transform(config):
from .tps import TPS from .tps import TPS
from .stn import STN_ON from .stn import STN_ON
from .tsrn import TSRN
from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
support_dict = ['TPS', 'STN_ON', 'GA_SPIN', 'TSRN']
support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/model/tsrn.py
"""
import math
import paddle
import paddle.nn.functional as F
from paddle import nn
from collections import OrderedDict
import sys
import numpy as np
import warnings
import math, copy
import cv2
warnings.filterwarnings("ignore")
from .tps_spatial_transformer import TPSSpatialTransformer
from .stn import STN as STN_model
from ppocr.modeling.heads.sr_rensnet_transformer import Transformer
class TSRN(nn.Layer):
def __init__(self,
in_channels,
scale_factor=2,
width=128,
height=32,
STN=False,
srb_nums=5,
mask=False,
hidden_units=32,
infer_mode=False,
**kwargs):
super(TSRN, self).__init__()
in_planes = 3
if mask:
in_planes = 4
assert math.log(scale_factor, 2) % 1 == 0
upsample_block_num = int(math.log(scale_factor, 2))
self.block1 = nn.Sequential(
nn.Conv2D(
in_planes, 2 * hidden_units, kernel_size=9, padding=4),
nn.PReLU())
self.srb_nums = srb_nums
for i in range(srb_nums):
setattr(self, 'block%d' % (i + 2),
RecurrentResidualBlock(2 * hidden_units))
setattr(
self,
'block%d' % (srb_nums + 2),
nn.Sequential(
nn.Conv2D(
2 * hidden_units,
2 * hidden_units,
kernel_size=3,
padding=1),
nn.BatchNorm2D(2 * hidden_units)))
block_ = [
UpsampleBLock(2 * hidden_units, 2)
for _ in range(upsample_block_num)
]
block_.append(
nn.Conv2D(
2 * hidden_units, in_planes, kernel_size=9, padding=4))
setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_))
self.tps_inputsize = [height // scale_factor, width // scale_factor]
tps_outputsize = [height // scale_factor, width // scale_factor]
num_control_points = 20
tps_margins = [0.05, 0.05]
self.stn = STN
if self.stn:
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
margins=tuple(tps_margins))
self.stn_head = STN_model(
in_channels=in_planes,
num_ctrlpoints=num_control_points,
activation='none')
self.out_channels = in_channels
self.r34_transformer = Transformer()
for param in self.r34_transformer.parameters():
param.trainable = False
self.infer_mode = infer_mode
def forward(self, x):
output = {}
if self.infer_mode:
output["lr_img"] = x
y = x
else:
output["lr_img"] = x[0]
output["hr_img"] = x[1]
y = x[0]
if self.stn and self.training:
_, ctrl_points_x = self.stn_head(y)
y, _ = self.tps(y, ctrl_points_x)
block = {'1': self.block1(y)}
for i in range(self.srb_nums + 1):
block[str(i + 2)] = getattr(self,
'block%d' % (i + 2))(block[str(i + 1)])
block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \
((block['1'] + block[str(self.srb_nums + 2)]))
sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
output["sr_img"] = sr_img
if self.training:
hr_img = x[1]
length = x[2]
input_tensor = x[3]
# add transformer
sr_pred, word_attention_map_pred, _ = self.r34_transformer(
sr_img, length, input_tensor)
hr_pred, word_attention_map_gt, _ = self.r34_transformer(
hr_img, length, input_tensor)
output["hr_img"] = hr_img
output["hr_pred"] = hr_pred
output["word_attention_map_gt"] = word_attention_map_gt
output["sr_pred"] = sr_pred
output["word_attention_map_pred"] = word_attention_map_pred
return output
class RecurrentResidualBlock(nn.Layer):
def __init__(self, channels):
super(RecurrentResidualBlock, self).__init__()
self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2D(channels)
self.gru1 = GruBlock(channels, channels)
self.prelu = mish()
self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2D(channels)
self.gru2 = GruBlock(channels, channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose(
[0, 1, 3, 2])
return self.gru2(x + residual)
class UpsampleBLock(nn.Layer):
def __init__(self, in_channels, up_scale):
super(UpsampleBLock, self).__init__()
self.conv = nn.Conv2D(
in_channels, in_channels * up_scale**2, kernel_size=3, padding=1)
self.pixel_shuffle = nn.PixelShuffle(up_scale)
self.prelu = mish()
def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.prelu(x)
return x
class mish(nn.Layer):
def __init__(self, ):
super(mish, self).__init__()
self.activated = True
def forward(self, x):
if self.activated:
x = x * (paddle.tanh(F.softplus(x)))
return x
class GruBlock(nn.Layer):
def __init__(self, in_channels, out_channels):
super(GruBlock, self).__init__()
assert out_channels % 2 == 0
self.conv1 = nn.Conv2D(
in_channels, out_channels, kernel_size=1, padding=0)
self.gru = nn.GRU(out_channels,
out_channels // 2,
direction='bidirectional')
def forward(self, x):
# x: b, c, w, h
x = self.conv1(x)
x = x.transpose([0, 2, 3, 1]) # b, w, h, c
batch_size, w, h, c = x.shape
x = x.reshape([-1, h, c]) # b*w, h, c
x, _ = self.gru(x)
x = x.reshape([-1, w, h, c])
x = x.transpose([0, 3, 1, 2])
return x
...@@ -780,7 +780,7 @@ class VLLabelDecode(BaseRecLabelDecode): ...@@ -780,7 +780,7 @@ class VLLabelDecode(BaseRecLabelDecode):
) + length[i])].topk(1)[0][:, 0] ) + length[i])].topk(1)[0][:, 0]
preds_prob = paddle.exp( preds_prob = paddle.exp(
paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6)) paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
text.append((preds_text, preds_prob)) text.append((preds_text, preds_prob.numpy()[0]))
if label is None: if label is None:
return text return text
label = self.decode(label) label = self.decode(label)
......
...@@ -56,7 +56,7 @@ def load_model(config, model, optimizer=None, model_type='det'): ...@@ -56,7 +56,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
is_float16 = False is_float16 = False
if model_type == 'vqa': if model_type == 'vqa':
# NOTE: for vqa model, resume training is not supported now # NOTE: for vqa model dsitillation, resume training is not supported now
if config["Architecture"]["algorithm"] in ["Distillation"]: if config["Architecture"]["algorithm"] in ["Distillation"]:
return best_model_dict return best_model_dict
checkpoints = config['Architecture']['Backbone']['checkpoints'] checkpoints = config['Architecture']['Backbone']['checkpoints']
...@@ -148,10 +148,14 @@ def load_pretrained_params(model, path): ...@@ -148,10 +148,14 @@ def load_pretrained_params(model, path):
"The {}.pdparams does not exists!".format(path) "The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams') params = paddle.load(path + '.pdparams')
state_dict = model.state_dict() state_dict = model.state_dict()
new_state_dict = {} new_state_dict = {}
is_float16 = False is_float16 = False
for k1 in params.keys(): for k1 in params.keys():
if k1 not in state_dict.keys(): if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1)) logger.warning("The pretrained params {} not in model".format(k1))
else: else:
...@@ -187,7 +191,6 @@ def save_model(model, ...@@ -187,7 +191,6 @@ def save_model(model,
""" """
_mkdir_if_not_exist(model_path, logger) _mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
if config['Architecture']["model_type"] != 'vqa':
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa': if config['Architecture']["model_type"] != 'vqa':
paddle.save(model.state_dict(), model_prefix + '.pdparams') paddle.save(model.state_dict(), model_prefix + '.pdparams')
......
- [关键信息提取(Key Information Extraction)](#关键信息提取key-information-extraction)
- [1. 快速使用](#1-快速使用)
- [2. 执行训练](#2-执行训练)
- [3. 执行评估](#3-执行评估)
- [4. 参考文献](#4-参考文献)
# 关键信息提取(Key Information Extraction)
本节介绍PaddleOCR中关键信息提取SDMGR方法的快速使用和训练方法。
SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类为预定义的类别,如订单ID、发票号码,金额等。
## 1. 快速使用
训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集:
```
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
```
执行预测:
```
cd PaddleOCR/
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar && tar xf kie_vgg16.tar
python3.7 tools/infer_kie.py -c configs/kie/kie_unet_sdmgr.yml -o Global.checkpoints=kie_vgg16/best_accuracy Global.infer_img=../wildreceipt/1.txt
```
执行预测后的结果保存在`./output/sdmgr_kie/predicts_kie.txt`文件中,可视化结果保存在`/output/sdmgr_kie/kie_results/`目录下。
可视化结果如下图所示:
<div align="center">
<img src="./imgs/0.png" width="800">
</div>
## 2. 执行训练
创建数据集软链到PaddleOCR/train_data目录下:
```
cd PaddleOCR/ && mkdir train_data && cd train_data
ln -s ../../wildreceipt ./
```
训练采用的配置文件是configs/kie/kie_unet_sdmgr.yml,配置文件中默认训练数据路径是`train_data/wildreceipt`,准备好数据后,可以通过如下指令执行训练:
```
python3.7 tools/train.py -c configs/kie/kie_unet_sdmgr.yml -o Global.save_model_dir=./output/kie/
```
## 3. 执行评估
```
python3.7 tools/eval.py -c configs/kie/kie_unet_sdmgr.yml -o Global.checkpoints=./output/kie/best_accuracy
```
## 4. 参考文献
<!-- [ALGORITHM] -->
```bibtex
@misc{sun2021spatial,
title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
year={2021},
eprint={2103.14470},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
- [2. OCR和表格识别模型](#2-ocr和表格识别模型) - [2. OCR和表格识别模型](#2-ocr和表格识别模型)
- [2.1 OCR](#21-ocr) - [2.1 OCR](#21-ocr)
- [2.2 表格识别模型](#22-表格识别模型) - [2.2 表格识别模型](#22-表格识别模型)
- [3. VQA模型](#3-vqa模型) - [3. KIE模型](#3-kie模型)
- [4. KIE模型](#4-kie模型)
<a name="1"></a> <a name="1"></a>
...@@ -38,19 +37,26 @@ ...@@ -38,19 +37,26 @@
|en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | |en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
<a name="3"></a> <a name="3"></a>
## 3. VQA模型
|模型名称|模型简介|推理模型大小|下载地址| ## 3. KIE模型
| --- | --- | --- | --- |
|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="4"></a> 在XFUND_zh数据集上,不同模型的精度与V100 GPU上速度信息如下所示。
## 4. KIE模型
|模型名称|模型简介|模型大小|下载地址| |模型名称|模型简介 | 推理模型大小| 精度(hmean) | 预测耗时(ms) | 下载地址|
| --- | --- | --- | --- | | --- | --- | --- |--- |--- | --- |
|SDMGR|关键信息提取模型|78M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)| |ser_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的SER模型|1.1G| 93.19% | 15.49 | [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar) |
|re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) |
|ser_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的SER模型|1.4G| 90.38% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfund中文数据集上训练的SER模型|778M| 85.44% | 31.46 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M| 67.77% | 31.46 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfund_zh|基于LayoutLM在xfund中文数据集上训练的SER模型|430M| 77.31% | - |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
* 注:上述预测耗时信息仅包含了inference模型的推理耗时,没有统计预处理与后处理耗时,测试环境为`V100 GPU + CUDA 10.2 + CUDNN 8.1.1 + TRT 7.2.3.4`
在wildreceipt数据集上,SDMGR模型精度与下载地址如下所示。
|模型名称|模型简介|模型大小|精度|下载地址|
| --- | --- | --- |--- | --- |
|SDMGR|关键信息提取模型|78M| 86.70% | [推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)|
...@@ -51,6 +51,8 @@ def init_args(): ...@@ -51,6 +51,8 @@ def init_args():
"--ser_dict_path", "--ser_dict_path",
type=str, type=str,
default="../train_data/XFUND/class_list_xfun.txt") default="../train_data/XFUND/class_list_xfun.txt")
# need to be None or tb-yx
parser.add_argument("--ocr_order_method", type=str, default=None)
# params for inference # params for inference
parser.add_argument( parser.add_argument(
"--mode", "--mode",
......
# 怎样完成基于图像数据的信息抽取任务
- [1. 简介](#1-简介)
- [1.1 背景](#11-背景)
- [1.2 主流方法](#12-主流方法)
- [2. 关键信息抽取任务流程](#2-关键信息抽取任务流程)
- [2.1 训练OCR模型](#21-训练OCR模型)
- [2.2 训练KIE模型](#22-训练KIE模型)
- [3. 参考文献](#3-参考文献)
## 1. 简介
### 1.1 背景
关键信息抽取 (Key Information Extraction, KIE)指的是是从文本或者图像中,抽取出关键的信息。针对文档图像的关键信息抽取任务作为OCR的下游任务,存在非常多的实际应用场景,如表单识别、车票信息抽取、身份证信息抽取等。然而,使用人力从这些文档图像中提取或者收集关键信息耗时费力,怎样自动化融合图像中的视觉、布局、文字等特征并完成关键信息抽取是一个价值与挑战并存的问题。
对于特定场景的文档图像,其中的关键信息位置、版式等较为固定,因此在研究早期有很多基于模板匹配的方法进行关键信息的抽取,考虑到其流程较为简单,该方法仍然被广泛应用在目前的很多场景中。但是这种基于模板匹配的方法在应用到不同的场景中时,需要耗费大量精力去调整与适配模板,迁移成本较高。
文档图像中的KIE一般包含2个子任务,示意图如下图所示。
* (1)SER: 语义实体识别 (Semantic Entity Recognition),对每一个检测到的文本进行分类,如将其分为姓名,身份证。如下图中的黑色框和红色框。
* (2)RE: 关系抽取 (Relation Extraction),对每一个检测到的文本进行分类,如将其分为问题 (key) 和答案 (value) 。然后对每一个问题找到对应的答案,相当于完成key-value的匹配过程。如下图中的红色框和黑色框分别代表问题和答案,黄色线代表问题和答案之间的对应关系。
<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/184588654-d87f54f3-13ab-42c4-afc0-da79bead3f14.png" width="800">
</div>
### 1.2 基于深度学习的主流方法
一般的KIE方法基于命名实体识别(Named Entity Recognition,NER)来展开研究,但是此类方法仅使用了文本信息而忽略了位置与视觉特征信息,因此精度受限。近几年大多学者开始融合多个模态的输入信息,进行特征融合,并对多模态信息进行处理,从而提升KIE的精度。主要方法有以下几种
* (1)基于Grid的方法:此类方法主要关注图像层面多模态信息的融合,文本大多大多为字符粒度,对文本与结构结构信息的嵌入方式较为简单,如Chargrid[1]等算法。
* (2)基于Token的方法:此类方法参考NLP中的BERT等方法,将位置、视觉等特征信息共同编码到多模态模型中,并且在大规模数据集上进行预训练,从而在下游任务中,仅需要少量的标注数据便可以获得很好的效果。如LayoutLM[2], LayoutLMv2[3], LayoutXLM[4], StrucText[5]等算法。
* (3)基于GCN的方法:此类方法尝试学习图像、文字之间的结构信息,从而可以解决开集信息抽取的问题(训练集中没有见过的模板),如GCN[6]、SDMGR[7]等算法。
* (4)基于End-to-end的方法:此类方法将现有的OCR文字识别以及KIE信息抽取2个任务放在一个统一的网络中进行共同学习,并在学习过程中相互加强。如Trie[8]等算法。
更多关于该系列算法的详细介绍,请参考“动手学OCR·十讲”课程的课节六部分:[文档分析理论与实践](https://aistudio.baidu.com/aistudio/education/group/info/25207)
## 2. 关键信息抽取任务流程
PaddleOCR中实现了LayoutXLM等算法(基于Token),同时,在PP-Structurev2中,对LayoutXLM多模态预训练模型的网络结构进行简化,去除了其中的Visual backbone部分,设计了视觉无关的VI-LayoutXLM模型,同时引入符合人类阅读顺序的排序逻辑以及UDML知识蒸馏策略,最终同时提升了关键信息抽取模型的精度与推理速度。
下面介绍怎样基于PaddleOCR完成关键信息抽取任务。
在非End-to-end的KIE方法中,完成关键信息抽取,至少需要**2个步骤**:首先使用OCR模型,完成文字位置与内容的提取,然后使用KIE模型,根据图像、文字位置以及文字内容,提取出其中的关键信息。
### 2.1 训练OCR模型
#### 2.1.1 文本检测
**(1)数据**
PaddleOCR中提供的模型大多数为通用模型,在进行文本检测的过程中,相邻文本行的检测一般是根据位置的远近进行区分,如上图,使用PP-OCRv3通用中英文检测模型进行文本检测时,容易将”民族“与“汉”这2个代表不同的字段检测到一起,从而增加后续KIE任务的难度。因此建议在做KIE任务的过程中,首先训练一个针对该文档数据集的检测模型。
在数据标注时,关键信息的标注需要隔开,比上图中的 “民族汉” 3个字相隔较近,此时需要将”民族“与”汉“标注为2个文本检测框,否则会增加后续KIE任务的难度。
对于下游任务,一般来说,`200~300`张的文本训练数据即可保证基本的训练效果,如果没有太多的先验知识,可以先标注 **`200~300`** 张图片,进行后续文本检测模型的训练。
**(2)模型**
在模型选择方面,推荐使用PP-OCRv3_det,关于更多关于检测模型的训练方法介绍,请参考:[OCR文本检测模型训练教程](../../doc/doc_ch/detection.md)[PP-OCRv3 文本检测模型训练教程](../../doc/doc_ch/PPOCRv3_det_train.md)
#### 2.1.2 文本识别
相对自然场景,文档图像中的文本内容识别难度一般相对较低(背景相对不太复杂),因此**优先建议**尝试PaddleOCR中提供的PP-OCRv3通用文本识别模型([PP-OCRv3模型库链接](../../doc/doc_ch/models_list.md))。
**(1)数据**
然而,在部分文档场景中也会存在一些挑战,如身份证场景中存在着罕见字,在发票等场景中的字体比较特殊,这些问题都会增加文本识别的难度,此时如果希望保证或者进一步提升模型的精度,建议基于特定文档场景的文本识别数据集,加载PP-OCRv3模型进行微调。
在模型微调的过程中,建议准备至少`5000`张垂类场景的文本识别图像,可以保证基本的模型微调效果。如果希望提升模型的精度与泛化能力,可以合成更多与该场景类似的文本识别数据,从公开数据集中收集通用真实文本识别数据,一并添加到该场景的文本识别训练任务过程中。在训练过程中,建议每个epoch的真实垂类数据、合成数据、通用数据比例在`1:1:1`左右,这可以通过设置不同数据源的采样比例进行控制。如有3个训练文本文件,分别包含1W、2W、5W条数据,那么可以在配置文件中设置数据如下:
```yml
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list_1W.txt
- ./train_data/train_list_2W.txt
- ./train_data/train_list_5W.txt
ratio_list: [1.0, 0.5, 0.2]
...
```
**(2)模型**
在模型选择方面,推荐使用通用中英文文本识别模型PP-OCRv3_rec,关于更多关于文本识别模型的训练方法介绍,请参考:[OCR文本识别模型训练教程](../../doc/doc_ch/recognition.md)[PP-OCRv3文本识别模型库与配置文件](../../doc/doc_ch/models_list.md)
### 2.2 训练KIE模型
对于识别得到的文字进行关键信息抽取,有2种主要的方法。
(1)直接使用SER,获取关键信息的类别:如身份证场景中,将“姓名“与”张三“分别标记为`name_key``name_value`。最终识别得到的类别为`name_value`对应的**文本字段**即为我们所需要的关键信息。
(2)联合SER与RE进行使用:这种方法中,首先使用SER,获取图像文字内容中所有的key与value,然后使用RE方法,对所有的key与value进行配对,找到映射关系,从而完成关键信息的抽取。
#### 2.2.1 SER
以身份证场景为例, 关键信息一般包含`姓名``性别``民族`等,我们直接将对应的字段标注为特定的类别即可,如下图所示。
<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/184526682-8b810397-5a93-4395-93da-37b8b8494c41.png" width="500">
</div>
**注意:**
- 标注过程中,对于无关于KIE关键信息的文本内容,均需要将其标注为`other`类别,相当于背景信息。如在身份证场景中,如果我们不关注性别信息,那么可以将“性别”与“男”这2个字段的类别均标注为`other`
- 标注过程中,需要以**文本行**为单位进行标注,无需标注单个字符的位置信息。
数据量方面,一般来说,对于比较固定的场景,**50张**左右的训练图片即可达到可以接受的效果,可以使用[PPOCRLabel](../../PPOCRLabel/README_ch.md)完成KIE的标注过程。
模型方面,推荐使用PP-Structurev2中提出的VI-LayoutXLM模型,它基于LayoutXLM模型进行改进,去除其中的视觉特征提取模块,在精度基本无损的情况下,进一步提升了模型推理速度。更多教程请参考:[VI-LayoutXLM算法介绍](../../doc/doc_ch/algorithm_kie_vi_layoutxlm.md)[KIE关键信息抽取使用教程](../../doc/doc_ch/kie.md)
#### 2.2.2 SER + RE
该过程主要包含SER与RE 2个过程。SER阶段主要用于识别出文档图像中的所有key与value,RE阶段主要用于对所有的key与value进行匹配。
以身份证场景为例, 关键信息一般包含`姓名``性别``民族`等关键信息,在SER阶段,我们需要识别所有的question (key) 与answer (value) 。标注如下所示。每个字段的类别信息(`label`字段)可以是question、answer或者other(与待抽取的关键信息无关的字段)
<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/184526785-c3d2d310-cd57-4d31-b933-912716b29856.jpg" width="500">
</div>
在RE阶段,需要标注每个字段的的id与连接信息,如下图所示。
<div align="center">
<img src="https://user-images.githubusercontent.com/14270174/184528728-626f77eb-fd9f-4709-a7dc-5411cc417dab.jpg" width="500">
</div>
每个文本行字段中,需要添加`id``linking`字段信息,`id`记录该文本行的唯一标识,同一张图片中的不同文本内容不能重复,`linking`是一个列表,记录了不同文本之间的连接信息。如字段“出生”的id为0,字段“1996年1月11日”的id为1,那么它们均有[[0, 1]]的`linking`标注,表示该id=0与id=1的字段构成key-value的关系(姓名、性别等字段类似,此处不再一一赘述)。
**注意:**
- 标注过程中,如果value是多个字符,那么linking中可以新增一个key-value对,如`[[0, 1], [0, 2]]`
数据量方面,一般来说,对于比较固定的场景,**50张**左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。
模型方面,推荐使用PP-Structurev2中提出的VI-LayoutXLM模型,它基于LayoutXLM模型进行改进,去除其中的视觉特征提取模块,在精度基本无损的情况下,进一步提升了模型推理速度。更多教程请参考:[VI-LayoutXLM算法介绍](../../doc/doc_ch/algorithm_kie_vi_layoutxlm.md)[KIE关键信息抽取使用教程](../../doc/doc_ch/kie.md)
## 3. 参考文献
[1] Katti A R, Reisswig C, Guder C, et al. Chargrid: Towards understanding 2d documents[J]. arXiv preprint arXiv:1809.08799, 2018.
[2] Xu Y, Li M, Cui L, et al. Layoutlm: Pre-training of text and layout for document image understanding[C]//Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2020: 1192-1200.
[3] Xu Y, Xu Y, Lv T, et al. LayoutLMv2: Multi-modal pre-training for visually-rich document understanding[J]. arXiv preprint arXiv:2012.14740, 2020.
[4]: Xu Y, Lv T, Cui L, et al. Layoutxlm: Multimodal pre-training for multilingual visually-rich document understanding[J]. arXiv preprint arXiv:2104.08836, 2021.
[5] Li Y, Qian Y, Yu Y, et al. StrucTexT: Structured Text Understanding with Multi-Modal Transformers[C]//Proceedings of the 29th ACM International Conference on Multimedia. 2021: 1912-1920.
[6] Liu X, Gao F, Zhang Q, et al. Graph convolution for multimodal information extraction from visually rich documents[J]. arXiv preprint arXiv:1903.11279, 2019.
[7] Sun H, Kuang Z, Yue X, et al. Spatial Dual-Modality Graph Reasoning for Key Information Extraction[J]. arXiv preprint arXiv:2103.14470, 2021.
[8] Zhang P, Xu Y, Cheng Z, et al. Trie: End-to-end text reading and information extraction for document understanding[C]//Proceedings of the 28th ACM International Conference on Multimedia. 2020: 1413-1422.
...@@ -40,14 +40,16 @@ logger = get_logger() ...@@ -40,14 +40,16 @@ logger = get_logger()
class SerPredictor(object): class SerPredictor(object):
def __init__(self, args): def __init__(self, args):
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False) self.ocr_engine = PaddleOCR(
use_angle_cls=False, show_log=False, use_gpu=args.use_gpu)
pre_process_list = [{ pre_process_list = [{
'VQATokenLabelEncode': { 'VQATokenLabelEncode': {
'algorithm': args.vqa_algorithm, 'algorithm': args.vqa_algorithm,
'class_path': args.ser_dict_path, 'class_path': args.ser_dict_path,
'contains_re': False, 'contains_re': False,
'ocr_engine': self.ocr_engine 'ocr_engine': self.ocr_engine,
'order_method': args.ocr_order_method,
} }
}, { }, {
'VQATokenPad': { 'VQATokenPad': {
......
...@@ -78,20 +78,10 @@ def export_single_model(model, ...@@ -78,20 +78,10 @@ def export_single_model(model,
shape=[None, 3, 64, 512], dtype="float32"), shape=[None, 3, 64, 512], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "RobustScanner": elif arch_config["model_type"] == "sr":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"), shape=[None, 3, 16, 64], dtype="float32")
[
paddle.static.InputSpec(
shape=[None, ],
dtype="float32"),
paddle.static.InputSpec(
shape=[None, max_text_length],
dtype="int64")
]
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ViTSTR": elif arch_config["algorithm"] == "ViTSTR":
...@@ -119,6 +109,22 @@ def export_single_model(model, ...@@ -119,6 +109,22 @@ def export_single_model(model,
shape=[None, 3, 64, 256], dtype="float32"), shape=[None, 3, 64, 256], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "RobustScanner":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
[
paddle.static.InputSpec(
shape=[None, ],
dtype="float32"),
paddle.static.InputSpec(
shape=[None, max_text_length],
dtype="int64")
]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [ input_spec = [
paddle.static.InputSpec( paddle.static.InputSpec(
...@@ -132,7 +138,7 @@ def export_single_model(model, ...@@ -132,7 +138,7 @@ def export_single_model(model,
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype="int64"), # image shape=[None, 3, 224, 224], dtype="int64"), # image
] ]
if arch_config["algorithm"] == "LayoutLM": if model.backbone.use_visual_backbone is False:
input_spec.pop(4) input_spec.pop(4)
model = to_static(model, input_spec=[input_spec]) model = to_static(model, input_spec=[input_spec])
else: else:
...@@ -211,6 +217,9 @@ def main(): ...@@ -211,6 +217,9 @@ def main():
else: # base rec model else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num config["Architecture"]["Head"]["out_channels"] = char_num
# for sr algorithm
if config["Architecture"]["model_type"] == "sr":
config['Architecture']["Transform"]['infer_mode'] = True
model = build_model(config["Architecture"]) model = build_model(config["Architecture"])
load_model(config, model, model_type=config['Architecture']["model_type"]) load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval() model.eval()
......
...@@ -349,8 +349,7 @@ class TextRecognizer(object): ...@@ -349,8 +349,7 @@ class TextRecognizer(object):
for beg_img_no in range(0, img_num, batch_num): for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num) end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = [] norm_img_batch = []
# imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape[:3]
imgH, imgW = self.rec_image_shape[-2:]
max_wh_ratio = imgW / imgH max_wh_ratio = imgW / imgH
# max_wh_ratio = 0 # max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from PIL import Image
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import math
import time
import traceback
import paddle
import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
logger = get_logger()
class TextSR(object):
def __init__(self, args):
self.sr_image_shape = [int(v) for v in args.sr_image_shape.split(",")]
self.sr_batch_num = args.sr_batch_num
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'sr', logger)
self.benchmark = args.benchmark
if args.benchmark:
import auto_log
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
model_name="sr",
model_precision=args.precision,
batch_size=args.sr_batch_num,
data_shape="dynamic",
save_path=None, #args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
time_keys=[
'preprocess_time', 'inference_time', 'postprocess_time'
],
warmup=0,
logger=logger)
def resize_norm_img(self, img):
imgC, imgH, imgW = self.sr_image_shape
img = img.resize((imgW // 2, imgH // 2), Image.BICUBIC)
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
def __call__(self, img_list):
img_num = len(img_list)
batch_num = self.sr_batch_num
st = time.time()
st = time.time()
all_result = [] * img_num
if self.benchmark:
self.autolog.times.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.sr_image_shape
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[ino])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
if self.benchmark:
self.autolog.times.stamp()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if len(outputs) != 1:
preds = outputs
else:
preds = outputs[0]
all_result.append(outputs)
if self.benchmark:
self.autolog.times.end(stamp=True)
return all_result, time.time() - st
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextSR(args)
valid_image_file_list = []
img_list = []
# warmup 2 times
if args.warmup:
img = np.random.uniform(0, 255, [16, 64, 3]).astype(np.uint8)
for i in range(2):
res = text_recognizer([img] * int(args.sr_batch_num))
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = Image.open(image_file).convert("RGB")
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
preds, _ = text_recognizer(img_list)
for beg_no in range(len(preds)):
sr_img = preds[beg_no][1]
lr_img = preds[beg_no][0]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i] * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(valid_image_file_list[
beg_no * args.sr_batch_num + i])[-1]
cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
fm_sr[:, :, ::-1])
logger.info("The visualized image saved in infer_result/sr_{}".
format(img_name_pure))
except Exception as E:
logger.info(traceback.format_exc())
logger.info(E)
exit()
if args.benchmark:
text_recognizer.autolog.report()
if __name__ == "__main__":
main(utility.parse_args())
...@@ -121,6 +121,11 @@ def init_args(): ...@@ -121,6 +121,11 @@ def init_args():
parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--warmup", type=str2bool, default=False) parser.add_argument("--warmup", type=str2bool, default=False)
# SR parmas
parser.add_argument("--sr_model_dir", type=str)
parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
parser.add_argument("--sr_batch_num", type=int, default=1)
# #
parser.add_argument( parser.add_argument(
"--draw_img_save_dir", type=str, default="./inference_results") "--draw_img_save_dir", type=str, default="./inference_results")
...@@ -156,6 +161,8 @@ def create_predictor(args, mode, logger): ...@@ -156,6 +161,8 @@ def create_predictor(args, mode, logger):
model_dir = args.table_model_dir model_dir = args.table_model_dir
elif mode == 'ser': elif mode == 'ser':
model_dir = args.ser_model_dir model_dir = args.ser_model_dir
elif mode == "sr":
model_dir = args.sr_model_dir
else: else:
model_dir = args.e2e_model_dir model_dir = args.e2e_model_dir
...@@ -205,17 +212,23 @@ def create_predictor(args, mode, logger): ...@@ -205,17 +212,23 @@ def create_predictor(args, mode, logger):
workspace_size=1 << 30, workspace_size=1 << 30,
precision_mode=precision, precision_mode=precision,
max_batch_size=args.max_batch_size, max_batch_size=args.max_batch_size,
min_subgraph_size=args.min_subgraph_size, # skip the minmum trt subgraph min_subgraph_size=args.
min_subgraph_size, # skip the minmum trt subgraph
use_calib_mode=False) use_calib_mode=False)
# collect shape # collect shape
if args.shape_info_filename is not None: if args.shape_info_filename is not None:
if not os.path.exists(args.shape_info_filename): if not os.path.exists(args.shape_info_filename):
config.collect_shape_range_info(args.shape_info_filename) config.collect_shape_range_info(args.shape_info_filename)
logger.info(f"collect dynamic shape info into : {args.shape_info_filename}") logger.info(
f"collect dynamic shape info into : {args.shape_info_filename}"
)
else: else:
logger.info(f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again.") logger.info(
config.enable_tuned_tensorrt_dynamic_shape(args.shape_info_filename, True) f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again."
)
config.enable_tuned_tensorrt_dynamic_shape(
args.shape_info_filename, True)
use_dynamic_shape = True use_dynamic_shape = True
if mode == "det": if mode == "det":
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
import json
from PIL import Image
import cv2
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
def main():
global_config = config['Global']
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
# sr transform
config['Architecture']["Transform"]['infer_mode'] = True
model = build_model(config['Architecture'])
load_model(config, model)
# create data ops
transforms = []
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
continue
elif op_name in ['SRResize']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['imge_lr']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path', "./infer_result")
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
img = Image.open(file).convert("RGB")
data = {'image_lr': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
sr_img = preds["sr_img"][0]
lr_img = preds["lr_img"][0]
fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(file)[-1]
cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
fm_sr[:, :, ::-1])
logger.info("The visualized image saved in infer_result/sr_{}".format(
img_name_pure))
logger.info("success!")
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main()
...@@ -104,8 +104,6 @@ class SerPredictor(object): ...@@ -104,8 +104,6 @@ class SerPredictor(object):
batch = transform(data, self.ops) batch = transform(data, self.ops)
batch = to_tensor(batch) batch = to_tensor(batch)
preds = self.model(batch) preds = self.model(batch)
if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
preds = preds[0]
post_result = self.post_process_class( post_result = self.post_process_class(
preds, segment_offset_ids=batch[6], ocr_infos=batch[7]) preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
......
...@@ -25,6 +25,8 @@ import datetime ...@@ -25,6 +25,8 @@ import datetime
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from tqdm import tqdm from tqdm import tqdm
import cv2
import numpy as np
from argparse import ArgumentParser, RawDescriptionHelpFormatter from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats from ppocr.utils.stats import TrainingStats
...@@ -262,6 +264,7 @@ def train(config, ...@@ -262,6 +264,7 @@ def train(config,
config, 'Train', device, logger, seed=epoch) config, 'Train', device, logger, seed=epoch)
max_iter = len(train_dataloader) - 1 if platform.system( max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader) ) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader): for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options) profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - reader_start train_reader_cost += time.time() - reader_start
...@@ -289,7 +292,7 @@ def train(config, ...@@ -289,7 +292,7 @@ def train(config,
else: else:
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa', 'sr']:
preds = model(batch) preds = model(batch)
else: else:
preds = model(images) preds = model(images)
...@@ -297,11 +300,12 @@ def train(config, ...@@ -297,11 +300,12 @@ def train(config,
avg_loss = loss['loss'] avg_loss = loss['loss']
avg_loss.backward() avg_loss.backward()
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
if model_type in ['kie']: if model_type in ['kie', 'sr']:
eval_class(preds, batch) eval_class(preds, batch)
elif model_type in ['table']: elif model_type in ['table']:
post_result = post_process_class(preds, batch) post_result = post_process_class(preds, batch)
...@@ -480,6 +484,7 @@ def eval(model, ...@@ -480,6 +484,7 @@ def eval(model,
leave=True) leave=True)
max_iter = len(valid_dataloader) - 1 if platform.system( max_iter = len(valid_dataloader) - 1 if platform.system(
) == "Windows" else len(valid_dataloader) ) == "Windows" else len(valid_dataloader)
sum_images = 0
for idx, batch in enumerate(valid_dataloader): for idx, batch in enumerate(valid_dataloader):
if idx >= max_iter: if idx >= max_iter:
break break
...@@ -493,6 +498,20 @@ def eval(model, ...@@ -493,6 +498,20 @@ def eval(model,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa']:
preds = model(batch) preds = model(batch)
elif model_type in ['sr']:
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
sum_images, i), fm_lr)
else: else:
preds = model(images) preds = model(images)
else: else:
...@@ -500,6 +519,20 @@ def eval(model, ...@@ -500,6 +519,20 @@ def eval(model,
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa']:
preds = model(batch) preds = model(batch)
elif model_type in ['sr']:
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
sum_images, i), fm_lr)
else: else:
preds = model(images) preds = model(images)
...@@ -517,12 +550,15 @@ def eval(model, ...@@ -517,12 +550,15 @@ def eval(model,
elif model_type in ['table', 'vqa']: elif model_type in ['table', 'vqa']:
post_result = post_process_class(preds, batch_numpy) post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy) eval_class(post_result, batch_numpy)
elif model_type in ['sr']:
eval_class(preds, batch_numpy)
else: else:
post_result = post_process_class(preds, batch_numpy[1]) post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch_numpy) eval_class(post_result, batch_numpy)
pbar.update(1) pbar.update(1)
total_frame += len(images) total_frame += len(images)
sum_images += 1
# Get final metric,eg. acc or hmean # Get final metric,eg. acc or hmean
metric = eval_class.get_metric() metric = eval_class.get_metric()
...@@ -616,7 +652,8 @@ def preprocess(is_train=False): ...@@ -616,7 +652,8 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', 'RobustScanner' 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt', 'RobustScanner'
] ]
if use_xpu: if use_xpu:
......
...@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer): ...@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1 config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
model = apply_to_static(model, config, logger) model = apply_to_static(model, config, logger)
# build loss # build loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册