# Intelligent Operations: General Chinese Table Recognition
- [1. Background](#1-background)
- [2. Chinese Table Recognition](#2-chinese-table-recognition)
- [2.1 Environment Setup](#21-environment-setup)
- [2.2 Dataset Preparation](#22-dataset-preparation)
- [2.2.1 Train/Test Split](#221-traintest-split)
- [2.2.2 Inspecting the Dataset](#222-inspecting-the-dataset)
- [2.3 Training](#23-training)
- [2.4 Evaluation](#24-evaluation)
- [2.5 Inference with the Training Engine](#25-inference-with-the-training-engine)
- [2.6 Model Export](#26-model-export)
- [2.7 Inference with the Prediction Engine](#27-inference-with-the-prediction-engine)
- [2.8 Table Recognition](#28-table-recognition)
- [3. Table Attribute Recognition](#3-table-attribute-recognition)
- [3.1 Code, Environment, and Data Preparation](#31-code-environment-and-data-preparation)
- [3.1.1 Code](#311-code)
- [3.1.2 Environment](#312-environment)
- [3.1.3 Data](#313-data)
- [3.2 Training the Attribute Model](#32-training-the-attribute-model)
- [3.3 Attribute Model Inference and Deployment](#33-attribute-model-inference-and-deployment)
- [3.3.1 Model Conversion](#331-model-conversion)
- [3.3.2 Model Inference](#332-model-inference)
## 1. Background
Chinese table recognition is widely used in the financial industry, for example in insurance claims, financial report analysis, and data entry. Today most of this work is still done by manual transcription, so automatic table recognition has become a pressing need.
![](https://ai-studio-static-online.cdn.bcebos.com/d1e7780f0c7745ada4be540decefd6288e4d59257d8141f6842682a4c05d28b6)
In finance, table images come in four main forms: dense-cell tables such as itemized statements, large-cell tables such as application forms, photographed tables, and tilted tables.
![](https://ai-studio-static-online.cdn.bcebos.com/da82ae8ef8fd479aaa38e1049eb3a681cf020dc108fa458eb3ec79da53b45fd1)
![](https://ai-studio-static-online.cdn.bcebos.com/5ffff2093a144a6993a75eef71634a52276015ee43a04566b9c89d353198c746)
Existing table recognition algorithms do not handle these scenarios well. In this project we use SLANet, the table recognition model newly released in PP-Structurev2, to demonstrate Chinese table recognition. To streamline the review workflow, we also use a table attribute recognition model to classify each table image's attributes and estimate how hard it will be to recognize, which speeds up manual proofreading.
AI Studio project: https://aistudio.baidu.com/aistudio/projectdetail/4588067
## 2. Chinese Table Recognition
### 2.1 Environment Setup
```python
# download the PaddleOCR code
! git clone -b dygraph https://gitee.com/paddlepaddle/PaddleOCR
```
```python
# install the PaddleOCR dependencies
! pip install -r PaddleOCR/requirements.txt --force-reinstall
! pip install protobuf==3.19
```
### 2.2 Dataset Preparation
The dataset used in this example was produced with this table [generation tool](https://github.com/WenmuZhou/TableGeneration).
Unpack the dataset and check its size with the following commands:
```python
! cd data/data165849 && tar -xf table_gen_dataset.tar && cd -
! wc -l data/data165849/table_gen_dataset/gt.txt
```
#### 2.2.1 Train/Test Split
Split the dataset with the script below: 90% of the samples go to the training set and 10% to the test set.
```python
import random

with open('/home/aistudio/data/data165849/table_gen_dataset/gt.txt') as f:
    lines = f.readlines()
random.shuffle(lines)
train_len = int(len(lines) * 0.9)
train_list = lines[:train_len]
val_list = lines[train_len:]

# save the split results
with open('/home/aistudio/train.txt', 'w', encoding='utf-8') as f:
    f.writelines(train_list)
with open('/home/aistudio/val.txt', 'w', encoding='utf-8') as f:
    f.writelines(val_list)
```
After the split, the dataset looks like this:
|Split|Images|Image directory|Label file|
|---|---|---|---|
|Training set|18000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/train.txt|
|Test set|2000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/val.txt|
#### 2.2.2 Inspecting the Dataset
```python
import cv2
import os, json
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

def parse_line(data_dir, line):
    """Parse one annotation line and resolve the image path."""
    data_line = line.strip("\n")
    info = json.loads(data_line)
    file_name = info['filename']
    cells = info['html']['cells'].copy()
    structure = info['html']['structure']['tokens'].copy()

    img_path = os.path.join(data_dir, file_name)
    if not os.path.exists(img_path):
        print(img_path)
        return None
    data = {
        'img_path': img_path,
        'cells': cells,
        'structure': structure,
        'file_name': file_name
    }
    return data

def draw_bbox(img_path, points, color=(255, 0, 0), thickness=2):
    """Draw cell polygons onto the image."""
    if isinstance(img_path, str):
        img_path = cv2.imread(img_path)
    img_path = img_path.copy()
    for point in points:
        cv2.polylines(img_path, [point.astype(int)], True, color, thickness)
    return img_path

def rebuild_html(data):
    """Insert the cell texts back into the structure tokens to rebuild the html."""
    html_code = data['structure']
    cells = data['cells']
    to_insert = [i for i, tag in enumerate(html_code) if tag in ('<td>', '>')]

    for i, cell in zip(to_insert[::-1], cells[::-1]):
        if cell['tokens']:
            text = ''.join(cell['tokens'])
            # skip empty text
            sp_char_list = ['<b>', '</b>', '\u2028', ' ', '<i>', '</i>']
            text_remove_style = skip_char(text, sp_char_list)
            if len(text_remove_style) == 0:
                continue
            html_code.insert(i + 1, text)

    html_code = ''.join(html_code)
    return html_code

def skip_char(text, sp_char_list):
    """
    skip empty cell
    @param text: text in cell
    @param sp_char_list: style char and special code
    @return:
    """
    for sp_char in sp_char_list:
        text = text.replace(sp_char, '')
    return text

save_dir = '/home/aistudio/vis'
os.makedirs(save_dir, exist_ok=True)
image_dir = '/home/aistudio/data/data165849/'
html_str = '<table border="1">'

# parse the annotation and rebuild the html table
data = parse_line(image_dir, val_list[0])

img = cv2.imread(data['img_path'])
img_name = ''.join(os.path.basename(data['file_name']).split('.')[:-1])
img_save_name = os.path.join(save_dir, img_name)
boxes = [np.array(x['bbox']) for x in data['cells']]
show_img = draw_bbox(data['img_path'], boxes)
cv2.imwrite(img_save_name + '_show.jpg', show_img)

html = rebuild_html(data)
html_str += html
html_str += '</table>'

# display the annotated html string
from IPython.core.display import display, HTML
display(HTML(html_str))
# display the cell boxes
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
```
### 2.3 Training
We use [SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/configs/table/SLANet.yml), the table recognition model from PP-Structurev2.
SLANet is the new table recognition model introduced in PP-Structurev2. Compared with TableRec-RARE from PP-Structurev1, it improves accuracy by 4.7% and TEDS by 2% at the same speed (the TEDS metric is defined after the table).
|Algorithm|Acc|[TEDS (Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed|
| --- | --- | --- | ---|
| EDD<sup>[2]</sup> |x| 88.3% |x|
| TableRec-RARE(ours) | 71.73%| 93.88% |779ms|
| SLANet(ours) | 76.31%| 95.89%|766ms|
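
For reference, TEDS (from the PubTabNet work linked above) treats the predicted and ground-truth tables as HTML trees and scores their similarity via tree edit distance:

$$\mathrm{TEDS}(T_a, T_b) = 1 - \frac{\mathrm{EditDist}(T_a, T_b)}{\max(|T_a|, |T_b|)}$$

where $\mathrm{EditDist}$ is the tree edit distance and $|T|$ is the number of nodes in tree $T$.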
Before training, download the pretrained model with the following commands:
```python
# enter the PaddleOCR working directory
import os
os.chdir('/home/aistudio/PaddleOCR')
# download the English pretrained model
! wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar --no-check-certificate
! cd ./pretrain_models/ && tar xf en_ppstructure_mobile_v2.0_SLANet_train.tar && cd ../
```
Training is launched with the command below. The config fields that need to be changed are:
|Field|New value|Meaning|
|---|---|---|
|Global.pretrained_model|./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy.pdparams|Path to the English table pretrained model|
|Global.eval_batch_step|562|How often (in steps) the model is evaluated; usually set to the number of steps in one epoch (18000 training images / batch size 32 ≈ 562)|
|Optimizer.lr.name|Const|Learning-rate scheduler|
|Optimizer.lr.learning_rate|0.0005|Learning rate, set to 0.05× the previous value|
|Train.dataset.data_dir|/home/aistudio/data/data165849|Training image directory|
|Train.dataset.label_file_list|/home/aistudio/train.txt|Training label file (written by the split script above)|
|Train.loader.batch_size_per_card|32|Per-card batch size during training|
|Train.loader.num_workers|1|Number of data-loading worker processes for training; must be set to 1 on AI Studio|
|Eval.dataset.data_dir|/home/aistudio/data/data165849|Test image directory|
|Eval.dataset.label_file_list|/home/aistudio/val.txt|Test label file (written by the split script above)|
|Eval.loader.batch_size_per_card|32|Per-card batch size during evaluation|
|Eval.loader.num_workers|1|Number of data-loading worker processes for evaluation; must be set to 1 on AI Studio|
A config with these changes already applied is provided at `/home/aistudio/SLANet_ch.yml`.
```python
import os
os.chdir('/home/aistudio/PaddleOCR')
! python3 tools/train.py -c /home/aistudio/SLANet_ch.yml
```
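If you prefer not to maintain a separate YAML file, the same fields can be overridden on the command line with `-o` (a sketch equivalent to the table above; paths assume the AI Studio layout used in this project):
```python
# training launch with command-line overrides instead of an edited config (sketch);
# each -o key corresponds to one row of the table above
! python3 tools/train.py -c configs/table/SLANet.yml \
    -o Global.pretrained_model=./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy.pdparams \
       Global.eval_batch_step=562 \
       Optimizer.lr.name=Const \
       Optimizer.lr.learning_rate=0.0005 \
       Train.dataset.data_dir=/home/aistudio/data/data165849 \
       Train.dataset.label_file_list=[/home/aistudio/train.txt] \
       Train.loader.batch_size_per_card=32 \
       Train.loader.num_workers=1 \
       Eval.dataset.data_dir=/home/aistudio/data/data165849 \
       Eval.dataset.label_file_list=[/home/aistudio/val.txt] \
       Eval.loader.batch_size_per_card=32 \
       Eval.loader.num_workers=1
```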
The best accuracy, 97.49%, is reached after roughly 7 epochs.
### 2.4 Evaluation
After training, evaluate the accuracy of the best model on the test set with:
```python
! python3 tools/eval.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams
```
### 2.5 Inference with the Training Engine
Run inference on a single image with the training engine:
```python
import os;os.chdir('/home/aistudio/PaddleOCR')
! python3 tools/infer_table.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.infer_img=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg
```
```python
import cv2
from matplotlib import pyplot as plt
%matplotlib inline
# show the original image
show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
# show the predicted cells
show_img = cv2.imread('/home/aistudio/PaddleOCR/output/infer/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
```
### 2.6 Model Export
Export the model as an inference model with:
```python
! python3 tools/export_model.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.save_inference_dir=/home/aistudio/SLANet_ch/infer
```
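After export, `/home/aistudio/SLANet_ch/infer` should contain the standard Paddle inference files (a typical layout; the file names follow Paddle's static-graph export convention):
```
SLANet_ch/infer/
├── inference.pdiparams
├── inference.pdiparams.info
└── inference.pdmodel
```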
### 2.7 Inference with the Prediction Engine
Run inference on a single image with the prediction engine:
```python
import os; os.chdir('/home/aistudio/PaddleOCR/ppstructure')
! python3 table/predict_structure.py \
--table_model_dir=/home/aistudio/SLANet_ch/infer \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \
--output=../output/inference
```
```python
# show the original image
show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
# show the predicted cells
show_img = cv2.imread('/home/aistudio/PaddleOCR/output/inference/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
```
### 2.8 Table Recognition
Once the table structure model is trained, it can be combined with the OCR detection and recognition models to recognize the full table content.
First download the PP-OCRv3 text detection and recognition models:
```python
# download and unpack the PP-OCRv3 text detection and recognition models
! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar --no-check-certificate
! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar --no-check-certificate
! cd ./inference/ && tar xf ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar && cd ../
```
After the models are downloaded, run table recognition with the following command:
```python
import os;os.chdir('/home/aistudio/PaddleOCR/ppstructure')
! python3 table/predict_table.py \
--det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \
--rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \
--table_model_dir=/home/aistudio/SLANet_ch/infer \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \
--output=../output/table
```
```python
# show the original image
show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
# show the prediction result
from IPython.core.display import display, HTML
display(HTML('<html><body><table><tr><td colspan="5">alleadersh</td><td rowspan="2">不贰过,推</td><td rowspan="2">从自己参与浙江数</td><td rowspan="2">。另一方</td></tr><tr><td>AnSha</td><td>自己越</td><td>共商共建工作协商</td><td>w.east </td><td>抓好改革试点任务</td></tr><tr><td>Edime</td><td>ImisesElec</td><td>怀天下”。</td><td></td><td>22.26 </td><td>31.61</td><td>4.30 </td><td>794.94</td></tr><tr><td rowspan="2">ip</td><td> Profundi</td><td>:2019年12月1</td><td>Horspro</td><td>444.48</td><td>2.41 </td><td>87</td><td>679.98</td></tr><tr><td> iehaiTrain</td><td>组长蒋蕊</td><td>Toafterdec</td><td>203.43</td><td>23.54 </td><td>4</td><td>4266.62</td></tr><tr><td>Tyint </td><td> roudlyRol</td><td>谢您的好意,我知道</td><td>ErChows</td><td></td><td>48.90</td><td>1031</td><td>6</td></tr><tr><td>NaFlint</td><td></td><td>一辈的</td><td>aterreclam</td><td>7823.86</td><td>9829.23</td><td>7.96 </td><td> 3068</td></tr><tr><td>家上下游企业,5</td><td>Tr</td><td>景象。当地球上的我们</td><td>Urelaw</td><td>799.62</td><td>354.96</td><td>12.98</td><td>33 </td></tr><tr><td>赛事(</td><td> uestCh</td><td>复制的业务模式并</td><td>Listicjust</td><td>9.23</td><td></td><td>92</td><td>53.22</td></tr><tr><td> Ca</td><td> Iskole</td><td>扶贫"之名引导</td><td> Papua </td><td>7191.90</td><td>1.65</td><td>3.62</td><td>48</td></tr><tr><td rowspan="2">避讳</td><td>ir</td><td>但由于</td><td>Fficeof</td><td>0.22</td><td>6.37</td><td>7.17</td><td>3397.75</td></tr><tr><td>ndaTurk</td><td>百处遗址</td><td>gMa</td><td>1288.34</td><td>2053.66</td><td>2.29</td><td>885.45</td></tr></table></body></html>'))
```
## 3. Table Attribute Recognition
### 3.1 Code, Environment, and Data Preparation
#### 3.1.1 Code
First, prepare the code for training the table attribute model. PaddleClas integrates the PULC solution, which can quickly produce an attribute recognition model that runs in about 2 ms on CPU. Clone the PaddleClas code:
```python
! git clone -b develop https://gitee.com/paddlepaddle/PaddleClas
```
#### 3.1.2 Environment
Next, install the dependencies required for training with PaddleClas:
```python
! pip install -r PaddleClas/requirements.txt --force-reinstall
! pip install protobuf==3.20.0
```
#### 3.1.3 Data
Finally, prepare the training data. We define six table attributes: source, number of tables, color, clarity, presence of obstructions, and angle. They are visualized below:
![](https://user-images.githubusercontent.com/45199522/190587903-ccdfa6fb-51e8-42de-b08b-a127cb04e304.png)
We provide a demo subset of the table attribute data for quick experimentation. Download it as follows:
```python
%cd PaddleClas/dataset
!wget https://paddleclas.bj.bcebos.com/data/PULC/table_attribute.tar
!tar -xf table_attribute.tar
%cd ../
```
### 3.2 Training the Attribute Model
The overall training pipeline is:
![](https://user-images.githubusercontent.com/45199522/190599426-3415b38e-e16e-4e68-9253-2ff531b1b5ca.png)
1. During training, the preprocessed image is fed into the backbone network, which extracts features of the table image. The features go through an FC output layer followed by a Sigmoid activation, and a cross-entropy loss is computed against the ground-truth labels; the optimizer updates the backbone parameters by gradient descent on this loss. After several epochs of training, the backbone can predict well on unseen images.
2. During inference, the preprocessed image is fed into the backbone, which, loaded with the learned weights, predicts a 6-dimensional vector whose elements are the probabilities of the corresponding attributes. Thresholding these probabilities yields the final output, which describes the table's six attributes (a sketch of this decoding step follows).
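
As an illustration of the thresholding step, here is a minimal sketch. It assumes a 0.5 threshold and uses the attribute names that appear in the demo predictions of section 3.3.2; the counterpart names for the 0-valued cases of attributes 2, 3, and 5 are assumptions, and the real label list lives in the PULC table_attribute config:
```python
# Minimal sketch of the attribute decoding described above (assumed 0.5 threshold).
# Each position of the 6-dim sigmoid output is one binary table attribute;
# value 1 selects the first name of the pair, value 0 the second.
ATTRIBUTE_PAIRS = [
    ("Scanned", "Photo"),                     # table source
    ("Little", "Numerous"),                   # number of tables ("Numerous" is assumed)
    ("Black-and-White", "Color"),             # table color ("Color" is assumed)
    ("Clear", "Blurry"),                      # clarity
    ("Without-Obstacles", "With-Obstacles"),  # obstructions ("With-Obstacles" is assumed)
    ("Horizontal", "Tilted"),                 # angle
]

def decode_attributes(probs, threshold=0.5):
    output = [1 if p >= threshold else 0 for p in probs]
    attributes = [names[0] if o else names[1]
                  for names, o in zip(ATTRIBUTE_PAIRS, output)]
    return {'attributes': attributes, 'output': output}

# a 6-dim probability vector like the one the model produces
print(decode_attributes([0.92, 0.88, 0.97, 0.08, 0.75, 0.11]))
# {'attributes': ['Scanned', 'Little', 'Black-and-White', 'Blurry', 'Without-Obstacles', 'Tilted'], 'output': [1, 1, 1, 0, 1, 0]}
```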
Once the data is ready, training can be launched with a single command:
```python
!python tools/train.py -c ./ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.device=cpu -o Global.epochs=10
```
### 3.3 Attribute Model Inference and Deployment
#### 3.3.1 Model Conversion
After training, the model must be converted into an inference model for deployment. The conversion script is:
```python
!python tools/export_model.py -c ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.pretrained_model=output/PPLCNet_x1_0/best_model
```
This generates an `inference` folder in the current directory, containing the inference model with the best accuracy so far.
#### 3.3.2 Model Inference
Install the paddleclas package needed for inference; at the moment this means installing the develop whl package of paddleclas:
```python
!pip install https://paddleclas.bj.bcebos.com/whl/paddleclas-0.0.0-py3-none-any.whl
```
Change into the `deploy` directory to run inference:
```python
%cd deploy/
```
The inference commands are:
```python
!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_9.jpg"
!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_3253.jpg"
```
Input table image:
![](https://user-images.githubusercontent.com/45199522/190596141-74f4feda-b082-46d7-908d-b0bd5839b430.png)
Prediction:
```
val_9.jpg: {'attributes': ['Scanned', 'Little', 'Black-and-White', 'Clear', 'Without-Obstacles', 'Horizontal'], 'output': [1, 1, 1, 1, 1, 1]}
```
Input table image:
![](https://user-images.githubusercontent.com/45199522/190597086-2e685200-22d0-4042-9e46-f61f24e02e4e.png)
Prediction:
```
val_3253.jpg: {'attributes': ['Photo', 'Little', 'Black-and-White', 'Blurry', 'Without-Obstacles', 'Tilted'], 'output': [0, 1, 1, 0, 1, 0]}
```
Comparing the two images: the first is clear and its predicted attributes indicate an easy case, so we can place more trust in its table recognition result; the second is blurry and tilted, so its recognition result may contain errors and needs further manual verification. Attribute recognition thus lets "human" and "machine" effort be combined effectively, safeguarding accuracy when table recognition is deployed.
# Intelligent Financial Verification: Key Information Extraction from Scanned Contracts
This project combines OCR with general-purpose information extraction to review and compare key information in contracts. In this chapter you will learn how to:
1. Extract the text of scanned documents with PaddleOCR
2. Extract custom information with PaddleNLP
Open the [AI Studio project](https://aistudio.baidu.com/aistudio/projectdetail/4545772)
## 1. Project Background
Contract review is widely practiced in large and medium-sized enterprises, listed companies, and securities and fund firms, and is an important risk-control task.
- Content comparison: quickly locate the regions that differ between contract versions, e.g. verifying that a signed paper contract matches its electronic version before filing.
- Compliance checks: legal staff review contracts for completeness, consistency between amounts written in words and in figures, consistency of the signing parties, and symmetry of the two parties' rights and obligations.
- Risk identification: contract review can surface factual and numerical risk points, such as an ambiguous delivery location, an inconsistent total price, or missing key clauses.
![](https://ai-studio-static-online.cdn.bcebos.com/d5143df967fa4364a38868793fe7c57b0c0b1213930243babd6ae01423dcbc4d)
Traditionally, paper contracts are reviewed manually, which is costly, labor-intensive, and slow, and a single mistake can cause huge losses.
For these scenarios, this project uses PaddleOCR to extract the text quickly; after fine-tuning on a small amount of data, the key information can be extracted accurately, **efficiently supporting content comparison, compliance checks, and risk identification while improving efficiency and reducing risk**.
![](https://ai-studio-static-online.cdn.bcebos.com/54f3053e6e1b47a39b26e757006fe2c44910d60a3809422ab76c25396b92e69b)
## 2. Solution
### 2.1 Extracting Text from Scanned Contracts
PaddleOCR's open-source models make text extraction from scanned documents quick, with recognition accuracy above 95% on clear documents. Let's try it:
#### 2.1.1 Environment Setup
[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) provides accurate, lightweight models for general scenarios, covers the full preprocessing-inference-postprocessing pipeline, and can be installed with pip:
```
python -m pip install paddleocr
```
#### 2.1.2 A Quick Test
Use a contract image as a test sample to get a feel for the PP-OCRv3 model:
<img src=https://ai-studio-static-online.cdn.bcebos.com/46258d0dc9dc40bab3ea0e70434e4a905646df8a647f4c49921e217de5142def width=300>
Extract the text with the Chinese detection + recognition models by instantiating the PaddleOCR class:
```
from paddleocr import PaddleOCR, draw_ocr
# PaddleOCR currently supports 80+ languages, e.g. Chinese+English, English, French, German, Korean, and Japanese; switch with the lang parameter
ocr = PaddleOCR(use_angle_cls=False, lang="ch") # need to run only once to download and load model into memory
```
Prediction runs with a single call; the result includes the `detection boxes` and the `recognized text`:
```
img_path = "./test_img/hetong2.jpg"
result = ocr.ocr(img_path, cls=False)
for line in result:
    print(line)

# visualize the result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.show()
```
#### 2.1.3 Image Preprocessing
The visualization above shows that the red seal partially covers the text and degrades recognition. Since the seal is red, it appears nearly white in the red (R) channel, so extracting that channel effectively removes the seal from the image:
```
import cv2
import numpy as np
import matplotlib.pyplot as plt

# read the image with three channels
image = cv2.imread("./test_img/hetong2.jpg", cv2.IMREAD_COLOR)

# split into the B, G, and R channels
Bch, Gch, Rch = cv2.split(image)

# save each channel as a grayscale image
cv2.imwrite('blue_channel.jpg', Bch)
cv2.imwrite('green_channel.jpg', Gch)
cv2.imwrite('red_channel.jpg', Rch)
```
#### 2.1.4 Contract Text Extraction
After the preprocessing in 2.1.3, the red channel of the contract photo has been separated out, giving a relatively cleaner image; now run the ppocr model again to extract the text:
```
import numpy as np
import cv2
img_path = './red_channel.jpg'
result = ocr.ocr(img_path, cls=False)
# visualize the result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf')
im_show = Image.fromarray(im_show)
vis = np.array(im_show)
im_show.show()
```
Ignoring the box coordinates, collect the complete contract text:
```
txts = [line[1][0] for line in result]
all_context = "\n".join(txts)
print(all_context)
```
These steps complete the first stage of key information extraction from scanned contracts: text extraction. Next, the key information can be extracted from the recognized text.
### 2.2 Key Information Extraction
#### 2.2.1 Environment Setup
Install PaddleNLP:
```
pip install --upgrade pip
pip install --upgrade paddlenlp
```
#### 2.2.2 Extracting Key Information from Contracts
PaddleNLP uses Taskflow to manage prediction for many scenarios. Its `information_extraction` task is trained on a large amount of labeled data and can generally be used out of the box in common scenarios; only the keywords need to be changed. For contract information extraction we define the extraction keywords as:
甲方 (Party A), 乙方 (Party B), 币种 (currency), 金额 (amount), 付款方式 (payment method)
Taking the OCR-extracted text as input, three lines of code extract the key information from the contract text above:
```
from paddlenlp import Taskflow
schema = ["甲方","乙方","总价"]
ie = Taskflow('information_extraction', schema=schema)
ie.set_schema(schema)
ie(all_context)
```
As you can see, the UIE model extracts the key information accurately, ready for downstream comparison or review.
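For each schema key, the Taskflow returns a list of extracted spans with their text, character offsets, and a confidence score. The result has roughly this shape (illustrative placeholders, not the actual output for this contract):
```
[{'甲方': [{'text': '<Party A name>', 'start': 10, 'end': 18, 'probability': 0.99}],
  '乙方': [{'text': '<Party B name>', 'start': 25, 'end': 33, 'probability': 0.98}],
  '总价': [{'text': '<total price>', 'start': 110, 'end': 121, 'probability': 0.97}]}]
```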
## 3. Tuning
### 3.1 Tuning the Text Recognition Post-processing
In real image collection, some images may be warped, causing missed detections with the default parameters and hurting key information extraction.
For example, the image below:
<img src="https://ai-studio-static-online.cdn.bcebos.com/fe350481be0241c58736d487d1bf06c2e65911bf01254a79944be629c4c10091" height="300" width="300">
Predicting directly:
```
img_path = "./test_img/hetong3.jpg"
# run prediction
result = ocr.ocr(img_path, cls=False)
# visualize the result
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.show()
```
The visualization shows missed detections on the warped image. Misses like this can usually be fixed by adjusting post-processing parameters, without retraining: they happen when the segmentation map produced by the detection model is too small and the candidate boxes are filtered out for low scores. Two parameter changes usually help:
- set `use_dilation=True` to dilate the segmentation regions
- lower the `det_db_box_thresh` threshold
```
# re-instantiate PaddleOCR with adjusted post-processing parameters
ocr = PaddleOCR(use_angle_cls=False, lang="ch", det_db_box_thresh=0.3, use_dilation=True)
# predict and visualize
img_path = "./test_img/hetong3.jpg"
# run prediction
result = ocr.ocr(img_path, cls=False)
# visualize the result
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.show()
```
The missed detections are now recovered. Extract the complete text:
```
txts = [line[1][0] for line in result]
context = "\n".join(txts)
print(context)
```
### 3.2 Tuning the Key Information Extraction
UIE is trained on a large amount of labeled data and works well out of the box. For a specific domain, however, some entities may fail to be extracted. Common remedies are:
- modify the schema
- add rule-based (regex) extraction
- fine-tune the model on a small labeled sample
**Modifying the schema**
The closer the prompt is to the wording of the original text, the better the extraction. For example, given this passage from the contract (kept in Chinese, since extraction matches its exact wording):
```
三:合同价格:总价为人民币大写:参拾玖万捌仟伍佰
元,小写:398500.00元。总价中包括站房工程建设、安装
及相关避雷、消防、接地、电力、材料费、检验费、安全、
验收等所需费用及其他相关费用和税金。
```
schema = ["总金额"] 时无法准确抽取,与原文描述差异较大。 修改 schema = ["总价"] 再次尝试:
```
from paddlenlp import Taskflow
# schema = ["总金额"]
schema = ["总价"]
ie = Taskflow('information_extraction', schema=schema)
ie.set_schema(schema)
ie(all_context)
```
**Model fine-tuning**
UIE is built on `Prompt`-based modeling, and prompt fine-tuning is very effective on small samples. For detailed data annotation and fine-tuning steps, see:
[Major upgrade of PaddleNLP information extraction!](https://aistudio.baidu.com/aistudio/projectdetail/3914778?channelType=0&channel=0)
[Work-order information extraction](https://aistudio.baidu.com/aistudio/projectdetail/3914778?contributionType=1)
[Express waybill information extraction](https://aistudio.baidu.com/aistudio/projectdetail/4038499?contributionType=1)
## Summary
Key information extraction from scanned contracts can be implemented with the PaddleOCR + PaddleNLP combination. Both tools offer:
* Ease of use: one-line whl installation, 3-line invocation
* Leading accuracy: strong models covering almost all application scenarios
* Low tuning cost: the OCR model adapts to slightly distorted scans via post-processing parameters, and the UIE model can be fine-tuned with very few labeled samples.
## Homework
Try to extract the [甲方, 乙方] (Party A, Party B) keywords from the scanned contract `test_img/homework.png`:
<img src=https://ai-studio-static-online.cdn.bcebos.com/50a49a3c9f8348bfa04e8c8b97d3cce0d0dd6b14040f43939268d120688ef7ca width=300 height=400>
To get more domain-specific models, scan the QR code below and fill in the questionnaire to join the official PaddleOCR group, where you can get model download links, the "Dive into OCR" e-book, and a full set of OCR learning materials 🎁
<img src=https://ai-studio-static-online.cdn.bcebos.com/606538b59ea845cb99943b1dec6efe724e78f75c1e9c49228c7bf7da9f8837f5 width=300 height=300>
```diff
@@ -68,6 +68,7 @@ Train:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -83,7 +84,6 @@ Train:
     drop_last: False
     batch_size_per_card: 2
     num_workers: 8
-    collate_fn: ListCollator

 Eval:
   dataset:
@@ -105,6 +105,7 @@ Eval:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -120,4 +121,3 @@ Eval:
     drop_last: False
     batch_size_per_card: 8
     num_workers: 8
-    collate_fn: ListCollator
```
```diff
@@ -73,6 +73,7 @@ Train:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -82,13 +83,12 @@ Train:
           order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
+          keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: True
     drop_last: False
     batch_size_per_card: 2
     num_workers: 4
-    collate_fn: ListCollator

 Eval:
   dataset:
@@ -112,6 +112,7 @@ Eval:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -121,11 +122,9 @@ Eval:
           order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
     batch_size_per_card: 8
     num_workers: 8
-    collate_fn: ListCollator
```
```diff
@@ -57,14 +57,16 @@ Loss:
         mode: "l2"
         model_name_pairs:
         - ["Student", "Teacher"]
-        key: hidden_states_5
+        key: hidden_states
+        index: 5
         name: "loss_5"
     - DistillationVQADistanceLoss:
         weight: 0.5
         mode: "l2"
         model_name_pairs:
         - ["Student", "Teacher"]
-        key: hidden_states_8
+        key: hidden_states
+        index: 8
         name: "loss_8"
@@ -116,6 +118,7 @@ Train:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -125,13 +128,12 @@ Train:
           order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
+          keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: True
     drop_last: False
     batch_size_per_card: 2
     num_workers: 4
-    collate_fn: ListCollator

 Eval:
   dataset:
@@ -155,6 +157,7 @@ Eval:
       - VQAReTokenRelation:
       - VQAReTokenChunk:
           max_seq_len: *max_seq_len
+      - TensorizeEntitiesRelations:
       - Resize:
           size: [224,224]
       - NormalizeImage:
@@ -164,12 +167,11 @@ Eval:
           order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
   loader:
     shuffle: False
     drop_last: False
     batch_size_per_card: 8
     num_workers: 8
-    collate_fn: ListCollator
```
```diff
@@ -70,14 +70,16 @@ Loss:
         mode: "l2"
         model_name_pairs:
         - ["Student", "Teacher"]
-        key: hidden_states_5
+        key: hidden_states
+        index: 5
         name: "loss_5"
     - DistillationVQADistanceLoss:
         weight: 0.5
         mode: "l2"
         model_name_pairs:
         - ["Student", "Teacher"]
-        key: hidden_states_8
+        key: hidden_states
+        index: 8
         name: "loss_8"
```
```diff
@@ -49,6 +49,11 @@ DECLARE_int32(rec_batch_num);
 DECLARE_string(rec_char_dict_path);
 DECLARE_int32(rec_img_h);
 DECLARE_int32(rec_img_w);
+// layout model related
+DECLARE_string(layout_model_dir);
+DECLARE_string(layout_dict_path);
+DECLARE_double(layout_score_threshold);
+DECLARE_double(layout_nms_threshold);
 // structure model related
 DECLARE_string(table_model_dir);
 DECLARE_int32(table_max_len);
@@ -60,3 +65,4 @@ DECLARE_bool(det);
 DECLARE_bool(rec);
 DECLARE_bool(cls);
 DECLARE_bool(table);
+DECLARE_bool(layout);
\ No newline at end of file
```
```diff
@@ -14,26 +14,12 @@
 #pragma once

-#include "opencv2/core.hpp"
-#include "opencv2/imgcodecs.hpp"
-#include "opencv2/imgproc.hpp"
 #include "paddle_api.h"
 #include "paddle_inference_api.h"
-#include <chrono>
-#include <iomanip>
-#include <iostream>
-#include <ostream>
-#include <vector>
-
-#include <cstring>
-#include <fstream>
-#include <numeric>
-
 #include <include/preprocess_op.h>
 #include <include/utility.h>

-using namespace paddle_infer;
-
 namespace PaddleOCR {

 class Classifier {
@@ -66,7 +52,7 @@ public:
            std::vector<float> &cls_scores, std::vector<double> &times);

 private:
-  std::shared_ptr<Predictor> predictor_;
+  std::shared_ptr<paddle_infer::Predictor> predictor_;

   bool use_gpu_ = false;
   int gpu_id_ = 0;
```
```diff
@@ -14,26 +14,12 @@
 #pragma once

-#include "opencv2/core.hpp"
-#include "opencv2/imgcodecs.hpp"
-#include "opencv2/imgproc.hpp"
 #include "paddle_api.h"
 #include "paddle_inference_api.h"
-#include <chrono>
-#include <iomanip>
-#include <iostream>
-#include <ostream>
-#include <vector>
-
-#include <cstring>
-#include <fstream>
-#include <numeric>
-
 #include <include/postprocess_op.h>
 #include <include/preprocess_op.h>

-using namespace paddle_infer;
-
 namespace PaddleOCR {

 class DBDetector {
@@ -41,7 +27,7 @@ public:
   explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
                       const int &gpu_id, const int &gpu_mem,
                       const int &cpu_math_library_num_threads,
-                      const bool &use_mkldnn, const string &limit_type,
+                      const bool &use_mkldnn, const std::string &limit_type,
                       const int &limit_side_len, const double &det_db_thresh,
                       const double &det_db_box_thresh,
                       const double &det_db_unclip_ratio,
@@ -77,7 +63,7 @@ public:
            std::vector<double> &times);

 private:
-  std::shared_ptr<Predictor> predictor_;
+  std::shared_ptr<paddle_infer::Predictor> predictor_;

   bool use_gpu_ = false;
   int gpu_id_ = 0;
@@ -85,7 +71,7 @@ private:
   int cpu_math_library_num_threads_ = 4;
   bool use_mkldnn_ = false;
-  string limit_type_ = "max";
+  std::string limit_type_ = "max";
   int limit_side_len_ = 960;

   double det_db_thresh_ = 0.3;
```
```diff
@@ -14,27 +14,12 @@
 #pragma once

-#include "opencv2/core.hpp"
-#include "opencv2/imgcodecs.hpp"
-#include "opencv2/imgproc.hpp"
 #include "paddle_api.h"
 #include "paddle_inference_api.h"
-#include <chrono>
-#include <iomanip>
-#include <iostream>
-#include <ostream>
-#include <vector>
-
-#include <cstring>
-#include <fstream>
-#include <numeric>
-
 #include <include/ocr_cls.h>
-#include <include/preprocess_op.h>
 #include <include/utility.h>

-using namespace paddle_infer;
-
 namespace PaddleOCR {

 class CRNNRecognizer {
@@ -42,7 +27,7 @@ public:
   explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu,
                           const int &gpu_id, const int &gpu_mem,
                           const int &cpu_math_library_num_threads,
-                          const bool &use_mkldnn, const string &label_path,
+                          const bool &use_mkldnn, const std::string &label_path,
                           const bool &use_tensorrt,
                           const std::string &precision,
                           const int &rec_batch_num, const int &rec_img_h,
@@ -75,7 +60,7 @@ public:
            std::vector<float> &rec_text_scores, std::vector<double> &times);

 private:
-  std::shared_ptr<Predictor> predictor_;
+  std::shared_ptr<paddle_infer::Predictor> predictor_;

   bool use_gpu_ = false;
   int gpu_id_ = 0;
```
```diff
@@ -14,28 +14,9 @@
 #pragma once

-#include "opencv2/core.hpp"
-#include "opencv2/imgcodecs.hpp"
-#include "opencv2/imgproc.hpp"
-#include "paddle_api.h"
-#include "paddle_inference_api.h"
-#include <chrono>
-#include <iomanip>
-#include <iostream>
-#include <ostream>
-#include <vector>
-
-#include <cstring>
-#include <fstream>
-#include <numeric>
-
 #include <include/ocr_cls.h>
 #include <include/ocr_det.h>
 #include <include/ocr_rec.h>
-#include <include/preprocess_op.h>
-#include <include/utility.h>
-
-using namespace paddle_infer;

 namespace PaddleOCR {

@@ -43,21 +24,27 @@ class PPOCR {
 public:
   explicit PPOCR();
   ~PPOCR();
-  std::vector<std::vector<OCRPredictResult>>
-  ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
-      bool rec = true, bool cls = true);
+
+  std::vector<std::vector<OCRPredictResult>> ocr(std::vector<cv::Mat> img_list,
+                                                 bool det = true,
+                                                 bool rec = true,
+                                                 bool cls = true);
+  std::vector<OCRPredictResult> ocr(cv::Mat img, bool det = true,
+                                    bool rec = true, bool cls = true);
+
+  void reset_timer();
+  void benchmark_log(int img_num);

 protected:
-  void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
-           std::vector<double> &times);
+  std::vector<double> time_info_det = {0, 0, 0};
+  std::vector<double> time_info_rec = {0, 0, 0};
+  std::vector<double> time_info_cls = {0, 0, 0};
+
+  void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results);
   void rec(std::vector<cv::Mat> img_list,
-           std::vector<OCRPredictResult> &ocr_results,
-           std::vector<double> &times);
+           std::vector<OCRPredictResult> &ocr_results);
   void cls(std::vector<cv::Mat> img_list,
-           std::vector<OCRPredictResult> &ocr_results,
-           std::vector<double> &times);
-  void log(std::vector<double> &det_times, std::vector<double> &rec_times,
-           std::vector<double> &cls_times, int img_num);
+           std::vector<OCRPredictResult> &ocr_results);

 private:
   DBDetector *detector_ = nullptr;
```
```diff
@@ -14,27 +14,9 @@
 #pragma once

-#include "opencv2/core.hpp"
-#include "opencv2/imgcodecs.hpp"
-#include "opencv2/imgproc.hpp"
-#include "paddle_api.h"
-#include "paddle_inference_api.h"
-#include <chrono>
-#include <iomanip>
-#include <iostream>
-#include <ostream>
-#include <vector>
-
-#include <cstring>
-#include <fstream>
-#include <numeric>
-
 #include <include/paddleocr.h>
-#include <include/preprocess_op.h>
+#include <include/structure_layout.h>
 #include <include/structure_table.h>
-#include <include/utility.h>
-
-using namespace paddle_infer;

 namespace PaddleOCR {

@@ -42,23 +24,31 @@ class PaddleStructure : public PPOCR {
 public:
   explicit PaddleStructure();
   ~PaddleStructure();
-  std::vector<std::vector<StructurePredictResult>>
-  structure(std::vector<cv::String> cv_all_img_names, bool layout = false,
-            bool table = true);
+
+  std::vector<StructurePredictResult> structure(cv::Mat img,
+                                                bool layout = false,
+                                                bool table = true,
+                                                bool ocr = false);
+
+  void reset_timer();
+  void benchmark_log(int img_num);

 private:
-  StructureTableRecognizer *recognizer_ = nullptr;
+  std::vector<double> time_info_table = {0, 0, 0};
+  std::vector<double> time_info_layout = {0, 0, 0};
+
+  StructureTableRecognizer *table_model_ = nullptr;
+  StructureLayoutRecognizer *layout_model_ = nullptr;

-  void table(cv::Mat img, StructurePredictResult &structure_result,
-             std::vector<double> &time_info_table,
-             std::vector<double> &time_info_det,
-             std::vector<double> &time_info_rec,
-             std::vector<double> &time_info_cls);
+  void layout(cv::Mat img,
+              std::vector<StructurePredictResult> &structure_result);
+  void table(cv::Mat img, StructurePredictResult &structure_result);

   std::string rebuild_table(std::vector<std::string> rec_html_tags,
                             std::vector<std::vector<int>> rec_boxes,
                             std::vector<OCRPredictResult> &ocr_result);
-  float iou(std::vector<int> &box1, std::vector<int> &box2);
   float dis(std::vector<int> &box1, std::vector<int> &box2);

   static bool comparison_dis(const std::vector<float> &dis1,
```
```diff
@@ -14,24 +14,9 @@
 #pragma once

-#include "opencv2/core.hpp"
-#include "opencv2/imgcodecs.hpp"
-#include "opencv2/imgproc.hpp"
-
-#include <chrono>
-#include <iomanip>
-#include <iostream>
-#include <ostream>
-#include <vector>
-
-#include <cstring>
-#include <fstream>
-#include <numeric>
-
 #include "include/clipper.h"
 #include "include/utility.h"

-using namespace std;
-
 namespace PaddleOCR {

 class DBPostProcessor {
@@ -106,4 +91,27 @@ private:
   std::string beg = "sos";
 };

+class PicodetPostProcessor {
+public:
+  void init(std::string label_path, const double score_threshold = 0.4,
+            const double nms_threshold = 0.5,
+            const std::vector<int> &fpn_stride = {8, 16, 32, 64});
+  void Run(std::vector<StructurePredictResult> &results,
+           std::vector<std::vector<float>> outs, std::vector<int> ori_shape,
+           std::vector<int> resize_shape, int reg_max);
+  std::vector<int> fpn_stride_ = {8, 16, 32, 64};
+
+private:
+  StructurePredictResult disPred2Bbox(std::vector<float> bbox_pred, int label,
+                                      float score, int x, int y, int stride,
+                                      std::vector<int> im_shape, int reg_max);
+  void nms(std::vector<StructurePredictResult> &input_boxes,
+           float nms_threshold);
+
+  std::vector<std::string> label_list_;
+  double score_threshold_ = 0.4;
+  double nms_threshold_ = 0.5;
+  int num_class_ = 5;
+};
+
 } // namespace PaddleOCR
```
```diff
@@ -14,21 +14,12 @@
 #pragma once

-#include "opencv2/core.hpp"
-#include "opencv2/imgcodecs.hpp"
-#include "opencv2/imgproc.hpp"
-
-#include <chrono>
-#include <iomanip>
 #include <iostream>
-#include <ostream>
 #include <vector>
-
-#include <cstring>
-#include <fstream>
-#include <numeric>
-
-using namespace std;
-using namespace paddle;
+#include "opencv2/core.hpp"
+#include "opencv2/imgcodecs.hpp"
+#include "opencv2/imgproc.hpp"

 namespace PaddleOCR {

@@ -51,9 +42,9 @@ public:

 class ResizeImgType0 {
 public:
-  virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type,
-                   int limit_side_len, float &ratio_h, float &ratio_w,
-                   bool use_tensorrt);
+  virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
+                   std::string limit_type, int limit_side_len, float &ratio_h,
+                   float &ratio_w, bool use_tensorrt);
 };

 class CrnnResizeImg {
@@ -82,4 +73,10 @@ public:
                    const int max_len = 488);
 };

+class Resize {
+public:
+  virtual void Run(const cv::Mat &img, cv::Mat &resize_img, const int h,
+                   const int w);
+};
+
 } // namespace PaddleOCR
\ No newline at end of file
```
```cpp
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle_api.h"
#include "paddle_inference_api.h"

#include <include/postprocess_op.h>
#include <include/preprocess_op.h>

namespace PaddleOCR {

class StructureLayoutRecognizer {
public:
  explicit StructureLayoutRecognizer(
      const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
      const int &gpu_mem, const int &cpu_math_library_num_threads,
      const bool &use_mkldnn, const std::string &label_path,
      const bool &use_tensorrt, const std::string &precision,
      const double &layout_score_threshold,
      const double &layout_nms_threshold) {
    this->use_gpu_ = use_gpu;
    this->gpu_id_ = gpu_id;
    this->gpu_mem_ = gpu_mem;
    this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
    this->use_mkldnn_ = use_mkldnn;
    this->use_tensorrt_ = use_tensorrt;
    this->precision_ = precision;

    this->post_processor_.init(label_path, layout_score_threshold,
                               layout_nms_threshold);
    LoadModel(model_dir);
  }

  // Load Paddle inference model
  void LoadModel(const std::string &model_dir);

  void Run(cv::Mat img, std::vector<StructurePredictResult> &result,
           std::vector<double> &times);

private:
  std::shared_ptr<paddle_infer::Predictor> predictor_;

  bool use_gpu_ = false;
  int gpu_id_ = 0;
  int gpu_mem_ = 4000;
  int cpu_math_library_num_threads_ = 4;
  bool use_mkldnn_ = false;

  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
  std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
  bool is_scale_ = true;
  bool use_tensorrt_ = false;
  std::string precision_ = "fp32";

  // pre-process
  Resize resize_op_;
  Normalize normalize_op_;
  Permute permute_op_;

  // post-process
  PicodetPostProcessor post_processor_;
};

} // namespace PaddleOCR
```
```diff
@@ -14,26 +14,11 @@
 #pragma once

-#include "opencv2/core.hpp"
-#include "opencv2/imgcodecs.hpp"
-#include "opencv2/imgproc.hpp"
 #include "paddle_api.h"
 #include "paddle_inference_api.h"
-#include <chrono>
-#include <iomanip>
-#include <iostream>
-#include <ostream>
-#include <vector>
-
-#include <cstring>
-#include <fstream>
-#include <numeric>
-
 #include <include/postprocess_op.h>
 #include <include/preprocess_op.h>
-#include <include/utility.h>
-
-using namespace paddle_infer;

 namespace PaddleOCR {

@@ -42,7 +27,7 @@ public:
   explicit StructureTableRecognizer(
       const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
       const int &gpu_mem, const int &cpu_math_library_num_threads,
-      const bool &use_mkldnn, const string &label_path,
+      const bool &use_mkldnn, const std::string &label_path,
       const bool &use_tensorrt, const std::string &precision,
       const int &table_batch_num, const int &table_max_len,
       const bool &merge_no_span_structure) {
@@ -70,7 +55,7 @@ public:
            std::vector<double> &times);

 private:
-  std::shared_ptr<Predictor> predictor_;
+  std::shared_ptr<paddle_infer::Predictor> predictor_;

   bool use_gpu_ = false;
   int gpu_id_ = 0;
```
```diff
@@ -41,12 +41,13 @@ struct OCRPredictResult {
 };

 struct StructurePredictResult {
-  std::vector<int> box;
+  std::vector<float> box;
   std::vector<std::vector<int>> cell_box;
   std::string type;
   std::vector<OCRPredictResult> text_res;
   std::string html;
   float html_score = -1;
+  float confidence;
 };

 class Utility {
@@ -82,13 +83,20 @@ public:
   static void print_result(const std::vector<OCRPredictResult> &ocr_result);

-  static cv::Mat crop_image(cv::Mat &img, std::vector<int> &area);
+  static cv::Mat crop_image(cv::Mat &img, const std::vector<int> &area);
+  static cv::Mat crop_image(cv::Mat &img, const std::vector<float> &area);

   static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);

   static std::vector<int> xyxyxyxy2xyxy(std::vector<std::vector<int>> &box);
   static std::vector<int> xyxyxyxy2xyxy(std::vector<int> &box);

+  static float fast_exp(float x);
+  static std::vector<float>
+  activation_function_softmax(std::vector<float> &src);
+  static float iou(std::vector<int> &box1, std::vector<int> &box2);
+  static float iou(std::vector<float> &box1, std::vector<float> &box2);
+
 private:
   static bool comparison_box(const OCRPredictResult &result1,
                              const OCRPredictResult &result2) {
```
````diff
@@ -174,6 +174,9 @@ inference/
 |-- table
 |   |--inference.pdiparams
 |   |--inference.pdmodel
+|-- layout
+|   |--inference.pdiparams
+|   |--inference.pdmodel
 ```
@@ -278,8 +281,30 @@ Specifically,
     --cls=true \
 ```

+##### 7. layout+table
+```shell
+./build/ppocr --det_model_dir=inference/det_db \
+    --rec_model_dir=inference/rec_rcnn \
+    --table_model_dir=inference/table \
+    --image_dir=../../ppstructure/docs/table/table.jpg \
+    --layout_model_dir=inference/layout \
+    --type=structure \
+    --table=true \
+    --layout=true
+```
+
+##### 8. layout
+```shell
+./build/ppocr --layout_model_dir=inference/layout \
+    --image_dir=../../ppstructure/docs/table/1.png \
+    --type=structure \
+    --table=false \
+    --layout=true \
+    --det=false \
+    --rec=false
+```
+
-##### 7. table
+##### 9. table
 ```shell
 ./build/ppocr --det_model_dir=inference/det_db \
     --rec_model_dir=inference/rec_rcnn \
@@ -343,6 +368,16 @@ More parameters are as follows,
 |rec_img_h|int|48|image height of recognition|
 |rec_img_w|int|320|image width of recognition|

+- Layout related parameters
+
+|parameter|data type|default|meaning|
+| :---: | :---: | :---: | :---: |
+|layout_model_dir|string|-| Address of layout inference model|
+|layout_dict_path|string|../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt|dictionary file|
+|layout_score_threshold|float|0.5|Threshold of score.|
+|layout_nms_threshold|float|0.5|Threshold of nms.|
+
 - Table recognition related parameters

 |parameter|data type|default|meaning|
@@ -368,11 +403,51 @@ predict img: ../../doc/imgs/12.jpg
 The detection visualized image saved in ./output//12.jpg
 ```

-- table
+- layout+table
 ```bash
-predict img: ../../ppstructure/docs/table/table.jpg
-0 type: table, region: [0,0,371,293], res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
+predict img: ../../ppstructure/docs/table/1.png
+0 type: text, region: [12,729,410,848], score: 0.781044, res: count of ocr result is : 7
+********** print ocr result **********
+0 det boxes: [[4,1],[79,1],[79,12],[4,12]] rec text: CTW1500. rec score: 0.769472
+...
+6 det boxes: [[4,99],[391,99],[391,112],[4,112]] rec text: sate-of-the-artmethods[12.34.36l.ourapproachachieves rec score: 0.90414
+********** end print ocr result **********
+1 type: text, region: [69,342,342,359], score: 0.703666, res: count of ocr result is : 1
+********** print ocr result **********
+0 det boxes: [[8,2],[269,2],[269,13],[8,13]] rec text: Table6.Experimentalresults on CTW-1500 rec score: 0.890454
+********** end print ocr result **********
+2 type: text, region: [70,316,706,332], score: 0.659738, res: count of ocr result is : 2
+********** print ocr result **********
+0 det boxes: [[373,2],[630,2],[630,11],[373,11]] rec text: oroposals.andthegreencontoursarefinal rec score: 0.919729
+1 det boxes: [[8,3],[357,3],[357,11],[8,11]] rec text: Visualexperimentalresultshebluecontoursareboundar rec score: 0.915963
+********** end print ocr result **********
+3 type: text, region: [489,342,789,359], score: 0.630538, res: count of ocr result is : 1
+********** print ocr result **********
+0 det boxes: [[8,2],[294,2],[294,14],[8,14]] rec text: Table7.Experimentalresults onMSRA-TD500 rec score: 0.942251
+********** end print ocr result **********
+4 type: text, region: [444,751,841,848], score: 0.607345, res: count of ocr result is : 5
+********** print ocr result **********
+0 det boxes: [[19,3],[389,3],[389,17],[19,17]] rec text: Inthispaper,weproposeanovel adaptivebound rec score: 0.941031
+1 det boxes: [[4,22],[390,22],[390,36],[4,36]] rec text: aryproposalnetworkforarbitraryshapetextdetection rec score: 0.960172
+2 det boxes: [[4,42],[392,42],[392,56],[4,56]] rec text: whichadoptanboundaryproposalmodeltogeneratecoarse rec score: 0.934647
+3 det boxes: [[4,61],[389,61],[389,75],[4,75]] rec text: ooundaryproposals,andthenadoptanadaptiveboundary rec score: 0.946296
+4 det boxes: [[5,80],[387,80],[387,93],[5,93]] rec text: leformationmodelcombinedwithGCNandRNNtoper rec score: 0.952401
+********** end print ocr result **********
+5 type: title, region: [444,705,564,724], score: 0.785429, res: count of ocr result is : 1
+********** print ocr result **********
+0 det boxes: [[6,2],[113,2],[113,14],[6,14]] rec text: 5.Conclusion rec score: 0.856903
+********** end print ocr result **********
+6 type: table, region: [14,360,402,711], score: 0.963643, res: <html><body><table><thead><tr><td>Methods</td><td>Ext</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>TextSnake [18]</td><td>Syn</td><td>85.3</td><td>67.9</td><td>75.6</td><td></td></tr><tr><td>CSE [17]</td><td>MiLT</td><td>76.1</td><td>78.7</td><td>77.4</td><td>0.38</td></tr><tr><td>LOMO[40]</td><td>Syn</td><td>76.5</td><td>85.7</td><td>80.8</td><td>4.4</td></tr><tr><td>ATRR[35]</td><td>Sy-</td><td>80.2</td><td>80.1</td><td>80.1</td><td>-</td></tr><tr><td>SegLink++ [28]</td><td>Syn</td><td>79.8</td><td>82.8</td><td>81.3</td><td>-</td></tr><tr><td>TextField [37]</td><td>Syn</td><td>79.8</td><td>83.0</td><td>81.4</td><td>6.0</td></tr><tr><td>MSR[38]</td><td>Syn</td><td>79.0</td><td>84.1</td><td>81.5</td><td>4.3</td></tr><tr><td>PSENet-1s [33]</td><td>MLT</td><td>79.7</td><td>84.8</td><td>82.2</td><td>3.9</td></tr><tr><td>DB [12]</td><td>Syn</td><td>80.2</td><td>86.9</td><td>83.4</td><td>22.0</td></tr><tr><td>CRAFT [2]</td><td>Syn</td><td>81.1</td><td>86.0</td><td>83.5</td><td>-</td></tr><tr><td>TextDragon [5]</td><td>MLT+</td><td>82.8</td><td>84.5</td><td>83.6</td><td></td></tr><tr><td>PAN [34]</td><td>Syn</td><td>81.2</td><td>86.4</td><td>83.7</td><td>39.8</td></tr><tr><td>ContourNet [36]</td><td></td><td>84.1</td><td>83.7</td><td>83.9</td><td>4.5</td></tr><tr><td>DRRG [41]</td><td>MLT</td><td>83.02</td><td>85.93</td><td>84.45</td><td>-</td></tr><tr><td>TextPerception[23]</td><td>Syn</td><td>81.9</td><td>87.5</td><td>84.6</td><td></td></tr><tr><td>Ours</td><td> Syn</td><td>80.57</td><td>87.66</td><td>83.97</td><td>12.08</td></tr><tr><td>Ours</td><td></td><td>81.45</td><td>87.81</td><td>84.51</td><td>12.15</td></tr><tr><td>Ours</td><td>MLT</td><td>83.60</td><td>86.45</td><td>85.00</td><td>12.21</td></tr></tbody></table></body></html>
+The table visualized image saved in ./output//6_1.png
+7 type: table, region: [462,359,820,657], score: 0.953917, res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN[3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>:</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
+The table visualized image saved in ./output//7_1.png
+8 type: figure, region: [14,3,836,310], score: 0.969443, res: count of ocr result is : 26
+********** print ocr result **********
+0 det boxes: [[506,14],[539,15],[539,22],[506,21]] rec text: E rec score: 0.318073
+...
+25 det boxes: [[680,290],[759,288],[759,303],[680,305]] rec text: (d) CTW1500 rec score: 0.95911
+********** end print ocr result **********
 ```

 <a name="3"></a>
````
...@@ -184,6 +184,9 @@ inference/ ...@@ -184,6 +184,9 @@ inference/
|-- table |-- table
| |--inference.pdiparams | |--inference.pdiparams
| |--inference.pdmodel | |--inference.pdmodel
|-- layout
| |--inference.pdiparams
| |--inference.pdmodel
``` ```
<a name="22"></a> <a name="22"></a>
...@@ -288,7 +291,30 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir ...@@ -288,7 +291,30 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
--cls=true \ --cls=true \
``` ```
##### 7. 表格识别 ##### 7. 版面分析+表格识别
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--table_model_dir=inference/table \
--image_dir=../../ppstructure/docs/table/table.jpg \
--layout_model_dir=inference/layout \
--type=structure \
--table=true \
--layout=true
```
##### 8. 版面分析
```shell
./build/ppocr --layout_model_dir=inference/layout \
--image_dir=../../ppstructure/docs/table/1.png \
--type=structure \
--table=false \
--layout=true \
--det=false \
--rec=false
```
##### 9. 表格识别
```shell
./build/ppocr --det_model_dir=inference/det_db \
    --rec_model_dir=inference/rec_rcnn \
@@ -352,12 +378,22 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|rec_img_w|int|320|Input image width of the text recognition model|
- Layout analysis model parameters

|Parameter|Type|Default|Meaning|
| :---: | :---: | :---: | :---: |
|layout_model_dir|string|-|Path of the layout analysis inference model|
|layout_dict_path|string|../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt|Dictionary file|
|layout_score_threshold|float|0.5|Score threshold for detection boxes|
|layout_nms_threshold|float|0.5|NMS threshold|
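
As a quick usage sketch, these parameters can be overridden on the command line when running layout analysis; the threshold values below are arbitrary illustrative choices, not recommended settings:

```shell
./build/ppocr --layout_model_dir=inference/layout \
    --image_dir=../../ppstructure/docs/table/1.png \
    --type=structure \
    --layout=true \
    --table=false \
    --det=false \
    --rec=false \
    --layout_score_threshold=0.6 \
    --layout_nms_threshold=0.4
```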
- Table recognition model parameters

|Parameter|Type|Default|Meaning|
| :---: | :---: | :---: | :---: |
|table_model_dir|string|-|Path of the table recognition inference model|
|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict_ch.txt|Dictionary file|
|table_max_len|int|488|Long-side size of the input image for the table recognition model; the final network input is (table_max_len, table_max_len)|
|merge_no_span_structure|bool|true|Whether to merge `<td>` and `</td>` into `<td></td>`|
@@ -378,11 +414,51 @@ predict img: ../../doc/imgs/12.jpg
The detection visualized image saved in ./output//12.jpg
```
- layout+table
```bash
predict img: ../../ppstructure/docs/table/1.png
0 type: text, region: [12,729,410,848], score: 0.781044, res: count of ocr result is : 7
********** print ocr result **********
0 det boxes: [[4,1],[79,1],[79,12],[4,12]] rec text: CTW1500. rec score: 0.769472
...
6 det boxes: [[4,99],[391,99],[391,112],[4,112]] rec text: sate-of-the-artmethods[12.34.36l.ourapproachachieves rec score: 0.90414
********** end print ocr result **********
1 type: text, region: [69,342,342,359], score: 0.703666, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[8,2],[269,2],[269,13],[8,13]] rec text: Table6.Experimentalresults on CTW-1500 rec score: 0.890454
********** end print ocr result **********
2 type: text, region: [70,316,706,332], score: 0.659738, res: count of ocr result is : 2
********** print ocr result **********
0 det boxes: [[373,2],[630,2],[630,11],[373,11]] rec text: oroposals.andthegreencontoursarefinal rec score: 0.919729
1 det boxes: [[8,3],[357,3],[357,11],[8,11]] rec text: Visualexperimentalresultshebluecontoursareboundar rec score: 0.915963
********** end print ocr result **********
3 type: text, region: [489,342,789,359], score: 0.630538, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[8,2],[294,2],[294,14],[8,14]] rec text: Table7.Experimentalresults onMSRA-TD500 rec score: 0.942251
********** end print ocr result **********
4 type: text, region: [444,751,841,848], score: 0.607345, res: count of ocr result is : 5
********** print ocr result **********
0 det boxes: [[19,3],[389,3],[389,17],[19,17]] rec text: Inthispaper,weproposeanovel adaptivebound rec score: 0.941031
1 det boxes: [[4,22],[390,22],[390,36],[4,36]] rec text: aryproposalnetworkforarbitraryshapetextdetection rec score: 0.960172
2 det boxes: [[4,42],[392,42],[392,56],[4,56]] rec text: whichadoptanboundaryproposalmodeltogeneratecoarse rec score: 0.934647
3 det boxes: [[4,61],[389,61],[389,75],[4,75]] rec text: ooundaryproposals,andthenadoptanadaptiveboundary rec score: 0.946296
4 det boxes: [[5,80],[387,80],[387,93],[5,93]] rec text: leformationmodelcombinedwithGCNandRNNtoper rec score: 0.952401
********** end print ocr result **********
5 type: title, region: [444,705,564,724], score: 0.785429, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[6,2],[113,2],[113,14],[6,14]] rec text: 5.Conclusion rec score: 0.856903
********** end print ocr result **********
6 type: table, region: [14,360,402,711], score: 0.963643, res: <html><body><table><thead><tr><td>Methods</td><td>Ext</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>TextSnake [18]</td><td>Syn</td><td>85.3</td><td>67.9</td><td>75.6</td><td></td></tr><tr><td>CSE [17]</td><td>MiLT</td><td>76.1</td><td>78.7</td><td>77.4</td><td>0.38</td></tr><tr><td>LOMO[40]</td><td>Syn</td><td>76.5</td><td>85.7</td><td>80.8</td><td>4.4</td></tr><tr><td>ATRR[35]</td><td>Sy-</td><td>80.2</td><td>80.1</td><td>80.1</td><td>-</td></tr><tr><td>SegLink++ [28]</td><td>Syn</td><td>79.8</td><td>82.8</td><td>81.3</td><td>-</td></tr><tr><td>TextField [37]</td><td>Syn</td><td>79.8</td><td>83.0</td><td>81.4</td><td>6.0</td></tr><tr><td>MSR[38]</td><td>Syn</td><td>79.0</td><td>84.1</td><td>81.5</td><td>4.3</td></tr><tr><td>PSENet-1s [33]</td><td>MLT</td><td>79.7</td><td>84.8</td><td>82.2</td><td>3.9</td></tr><tr><td>DB [12]</td><td>Syn</td><td>80.2</td><td>86.9</td><td>83.4</td><td>22.0</td></tr><tr><td>CRAFT [2]</td><td>Syn</td><td>81.1</td><td>86.0</td><td>83.5</td><td>-</td></tr><tr><td>TextDragon [5]</td><td>MLT+</td><td>82.8</td><td>84.5</td><td>83.6</td><td></td></tr><tr><td>PAN [34]</td><td>Syn</td><td>81.2</td><td>86.4</td><td>83.7</td><td>39.8</td></tr><tr><td>ContourNet [36]</td><td></td><td>84.1</td><td>83.7</td><td>83.9</td><td>4.5</td></tr><tr><td>DRRG [41]</td><td>MLT</td><td>83.02</td><td>85.93</td><td>84.45</td><td>-</td></tr><tr><td>TextPerception[23]</td><td>Syn</td><td>81.9</td><td>87.5</td><td>84.6</td><td></td></tr><tr><td>Ours</td><td> Syn</td><td>80.57</td><td>87.66</td><td>83.97</td><td>12.08</td></tr><tr><td>Ours</td><td></td><td>81.45</td><td>87.81</td><td>84.51</td><td>12.15</td></tr><tr><td>Ours</td><td>MLT</td><td>83.60</td><td>86.45</td><td>85.00</td><td>12.21</td></tr></tbody></table></body></html>
The table visualized image saved in ./output//6_1.png
7 type: table, region: [462,359,820,657], score: 0.953917, res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN[3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>:</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
The table visualized image saved in ./output//7_1.png
8 type: figure, region: [14,3,836,310], score: 0.969443, res: count of ocr result is : 26
********** print ocr result **********
0 det boxes: [[506,14],[539,15],[539,22],[506,21]] rec text: E rec score: 0.318073
...
25 det boxes: [[680,290],[759,288],[759,303],[680,305]] rec text: (d) CTW1500 rec score: 0.95911
********** end print ocr result **********
```
<a name="3"></a>
......
@@ -51,6 +51,13 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
DEFINE_int32(rec_img_h, 48, "rec image height");
DEFINE_int32(rec_img_w, 320, "rec image width");
// layout model related
DEFINE_string(layout_model_dir, "", "Path of table layout inference model.");
DEFINE_string(layout_dict_path,
"../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt",
"Path of dictionary.");
DEFINE_double(layout_score_threshold, 0.5, "Threshold of score.");
DEFINE_double(layout_nms_threshold, 0.5, "Threshold of nms.");
// structure model related
DEFINE_string(table_model_dir, "", "Path of table structure inference model.");
DEFINE_int32(table_max_len, 488, "max len size of input image.");
@@ -66,3 +73,4 @@ DEFINE_bool(det, true, "Whether use det in forward.");
DEFINE_bool(rec, true, "Whether use rec in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward.");
DEFINE_bool(table, false, "Whether use table structure in forward.");
DEFINE_bool(layout, false, "Whether use layout analysis in forward.");
@@ -65,9 +65,18 @@ void check_params() {
      exit(1);
    }
  }
if (FLAGS_layout) {
if (FLAGS_layout_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[layout]: ./ppocr "
<< "--layout_model_dir=/PATH/TO/LAYOUT_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
exit(1);
}
}
if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" &&
FLAGS_precision != "int8") { FLAGS_precision != "int8") {
cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl; std::cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. "
<< std::endl;
exit(1); exit(1);
} }
} }
@@ -75,71 +84,94 @@ void check_params() {

void ocr(std::vector<cv::String> &cv_all_img_names) {
  PPOCR ocr = PPOCR();

  if (FLAGS_benchmark) {
    ocr.reset_timer();
  }

  std::vector<cv::Mat> img_list;
  std::vector<cv::String> img_names;
  for (int i = 0; i < cv_all_img_names.size(); ++i) {
    cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
    if (!img.data) {
      std::cerr << "[ERROR] image read failed! image path: "
                << cv_all_img_names[i] << std::endl;
      continue;
    }
    img_list.push_back(img);
    img_names.push_back(cv_all_img_names[i]);
  }

  std::vector<std::vector<OCRPredictResult>> ocr_results =
      ocr.ocr(img_list, FLAGS_det, FLAGS_rec, FLAGS_cls);

  for (int i = 0; i < img_names.size(); ++i) {
    std::cout << "predict img: " << cv_all_img_names[i] << std::endl;
    Utility::print_result(ocr_results[i]);
    if (FLAGS_visualize && FLAGS_det) {
      std::string file_name = Utility::basename(img_names[i]);
      cv::Mat srcimg = img_list[i];
      Utility::VisualizeBboxes(srcimg, ocr_results[i],
                               FLAGS_output + "/" + file_name);
    }
  }
  if (FLAGS_benchmark) {
    ocr.benchmark_log(cv_all_img_names.size());
  }
}
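
// Pipeline summary of structure() below: with --layout=true the layout model
// first proposes typed regions (text/title/table/figure); each table region
// is parsed by the table model, and the remaining regions are run through
// det+rec OCR. Without layout analysis the whole image is treated as one
// table region.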
void structure(std::vector<cv::String> &cv_all_img_names) {
  PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure();

  if (FLAGS_benchmark) {
    engine.reset_timer();
  }

  for (int i = 0; i < cv_all_img_names.size(); i++) {
    std::cout << "predict img: " << cv_all_img_names[i] << std::endl;
    cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
    if (!img.data) {
      std::cerr << "[ERROR] image read failed! image path: "
                << cv_all_img_names[i] << std::endl;
      continue;
    }

    std::vector<StructurePredictResult> structure_results = engine.structure(
        img, FLAGS_layout, FLAGS_table, FLAGS_det && FLAGS_rec);

    for (int j = 0; j < structure_results.size(); j++) {
      std::cout << j << "\ttype: " << structure_results[j].type
                << ", region: [";
      std::cout << structure_results[j].box[0] << ","
                << structure_results[j].box[1] << ","
                << structure_results[j].box[2] << ","
                << structure_results[j].box[3] << "], score: ";
      std::cout << structure_results[j].confidence << ", res: ";

      if (structure_results[j].type == "table") {
        std::cout << structure_results[j].html << std::endl;
        if (structure_results[j].cell_box.size() > 0 && FLAGS_visualize) {
          std::string file_name = Utility::basename(cv_all_img_names[i]);
          Utility::VisualizeBboxes(img, structure_results[j],
                                   FLAGS_output + "/" + std::to_string(j) +
                                       "_" + file_name);
        }
      } else {
        std::cout << "count of ocr result is : "
                  << structure_results[j].text_res.size() << std::endl;
        if (structure_results[j].text_res.size() > 0) {
          std::cout << "********** print ocr result "
                    << "**********" << std::endl;
          Utility::print_result(structure_results[j].text_res);
          std::cout << "********** end print ocr result "
                    << "**********" << std::endl;
        }
      }
    }
  }
  if (FLAGS_benchmark) {
    engine.benchmark_log(cv_all_img_names.size());
  }
}
int main(int argc, char **argv) {
@@ -149,19 +181,22 @@ int main(int argc, char **argv) {
  if (!Utility::PathExists(FLAGS_image_dir)) {
    std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
              << std::endl;
    exit(1);
  }

  std::vector<cv::String> cv_all_img_names;
  cv::glob(FLAGS_image_dir, cv_all_img_names);
  std::cout << "total images num: " << cv_all_img_names.size() << std::endl;

  if (!Utility::PathExists(FLAGS_output)) {
    Utility::CreateDir(FLAGS_output);
  }
  if (FLAGS_type == "ocr") {
    ocr(cv_all_img_names);
  } else if (FLAGS_type == "structure") {
    structure(cv_all_img_names);
  } else {
    std::cout << "only value in ['ocr','structure'] is supported" << std::endl;
  }
}
@@ -32,7 +32,7 @@ void Classifier::Run(std::vector<cv::Mat> img_list,
  for (int beg_img_no = 0; beg_img_no < img_num;
       beg_img_no += this->cls_batch_num_) {
    auto preprocess_start = std::chrono::steady_clock::now();
    int end_img_no = std::min(img_num, beg_img_no + this->cls_batch_num_);
    int batch_num = end_img_no - beg_img_no;
    // preprocess
    std::vector<cv::Mat> norm_img_batch;
@@ -97,7 +97,7 @@ void Classifier::Run(std::vector<cv::Mat> img_list,
}

void Classifier::LoadModel(const std::string &model_dir) {
  paddle_infer::Config config;
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");
@@ -112,7 +112,7 @@ void Classifier::LoadModel(const std::string &model_dir) {
        precision = paddle_infer::Config::Precision::kInt8;
      }
      config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
      if (!Utility::PathExists("./trt_cls_shape.txt")) {
        config.CollectShapeRangeInfo("./trt_cls_shape.txt");
      } else {
        config.EnableTunedTensorRtDynamicShape("./trt_cls_shape.txt", true);
@@ -136,6 +136,6 @@ void Classifier::LoadModel(const std::string &model_dir) {
  config.EnableMemoryOptim();
  config.DisableGlogInfo();

  this->predictor_ = paddle_infer::CreatePredictor(config);
}
} // namespace PaddleOCR
@@ -33,12 +33,11 @@ void DBDetector::LoadModel(const std::string &model_dir) {
        precision = paddle_infer::Config::Precision::kInt8;
      }
      config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false);
      if (!Utility::PathExists("./trt_det_shape.txt")) {
        config.CollectShapeRangeInfo("./trt_det_shape.txt");
      } else {
        config.EnableTunedTensorRtDynamicShape("./trt_det_shape.txt", true);
      }
    }
  } else {
    config.DisableGpu();
@@ -59,7 +58,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
  config.EnableMemoryOptim();
  // config.DisableGlogInfo();

  this->predictor_ = paddle_infer::CreatePredictor(config);
}

void DBDetector::Run(cv::Mat &img,
......
@@ -37,7 +37,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
  for (int beg_img_no = 0; beg_img_no < img_num;
       beg_img_no += this->rec_batch_num_) {
    auto preprocess_start = std::chrono::steady_clock::now();
    int end_img_no = std::min(img_num, beg_img_no + this->rec_batch_num_);
    int batch_num = end_img_no - beg_img_no;
    int imgH = this->rec_image_shape_[1];
    int imgW = this->rec_image_shape_[2];
@@ -46,7 +46,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
      int h = img_list[indices[ino]].rows;
      int w = img_list[indices[ino]].cols;
      float wh_ratio = w * 1.0 / h;
      max_wh_ratio = std::max(max_wh_ratio, wh_ratio);
    }
    int batch_width = imgW;
@@ -60,7 +60,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
      this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                              this->is_scale_);
      norm_img_batch.push_back(resize_img);
      batch_width = std::max(resize_img.cols, batch_width);
    }

    std::vector<float> input(batch_num * 3 * imgH * batch_width, 0.0f);
@@ -115,7 +115,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
        last_index = argmax_idx;
      }
      score /= count;
      if (std::isnan(score)) {
        continue;
      }
      rec_texts[indices[beg_img_no + m]] = str_res;
@@ -130,7 +130,6 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
}

void CRNNRecognizer::LoadModel(const std::string &model_dir) {
  paddle_infer::Config config;
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");
@@ -147,12 +146,11 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
      if (this->precision_ == "int8") {
        precision = paddle_infer::Config::Precision::kInt8;
      }
      if (!Utility::PathExists("./trt_rec_shape.txt")) {
        config.CollectShapeRangeInfo("./trt_rec_shape.txt");
      } else {
        config.EnableTunedTensorRtDynamicShape("./trt_rec_shape.txt", true);
      }
    }
  } else {
    config.DisableGpu();
@@ -177,7 +175,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
  config.EnableMemoryOptim();
  // config.DisableGlogInfo();

  this->predictor_ = paddle_infer::CreatePredictor(config);
}
} // namespace PaddleOCR
@@ -16,7 +16,7 @@
#include <include/paddleocr.h>

#include "auto_log/autolog.h"
#include <numeric>

namespace PaddleOCR {

PPOCR::PPOCR() {
@@ -44,84 +44,15 @@ PPOCR::PPOCR() {
  }
};
std::vector<std::vector<OCRPredictResult>>
PPOCR::ocr(std::vector<cv::Mat> img_list, bool det, bool rec, bool cls) {
  std::vector<std::vector<OCRPredictResult>> ocr_results;

  if (!det) {
    std::vector<OCRPredictResult> ocr_result;
    ocr_result.resize(img_list.size());
    if (cls && this->classifier_ != nullptr) {
      this->cls(img_list, ocr_result);
      for (int i = 0; i < img_list.size(); i++) {
        if (ocr_result[i].cls_label % 2 == 1 &&
            ocr_result[i].cls_score > this->classifier_->cls_thresh) {
@@ -130,43 +61,39 @@ PPOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
        }
      }
    }
    if (rec) {
      this->rec(img_list, ocr_result);
    }
    for (int i = 0; i < ocr_result.size(); ++i) {
      std::vector<OCRPredictResult> ocr_result_tmp;
      ocr_result_tmp.push_back(ocr_result[i]);
      ocr_results.push_back(ocr_result_tmp);
    }
  } else {
    for (int i = 0; i < img_list.size(); ++i) {
      std::vector<OCRPredictResult> ocr_result =
          this->ocr(img_list[i], true, rec, cls);
      ocr_results.push_back(ocr_result);
    }
  }
  return ocr_results;
}

std::vector<OCRPredictResult> PPOCR::ocr(cv::Mat img, bool det, bool rec,
                                         bool cls) {
  std::vector<OCRPredictResult> ocr_result;
  // det
  this->det(img, ocr_result);
  // crop image
  std::vector<cv::Mat> img_list;
  for (int j = 0; j < ocr_result.size(); j++) {
    cv::Mat crop_img;
    crop_img = Utility::GetRotateCropImage(img, ocr_result[j].box);
    img_list.push_back(crop_img);
  }
  // cls
  if (cls && this->classifier_ != nullptr) {
    this->cls(img_list, ocr_result);
    for (int i = 0; i < img_list.size(); i++) {
      if (ocr_result[i].cls_label % 2 == 1 &&
          ocr_result[i].cls_score > this->classifier_->cls_thresh) {
@@ -176,41 +103,93 @@ PPOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
      }
    }
  }
  // rec
  if (rec) {
    this->rec(img_list, ocr_result);
  }
  return ocr_result;
}

void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results) {
  std::vector<std::vector<std::vector<int>>> boxes;
  std::vector<double> det_times;

  this->detector_->Run(img, boxes, det_times);

  for (int i = 0; i < boxes.size(); i++) {
    OCRPredictResult res;
    res.box = boxes[i];
    ocr_results.push_back(res);
  }
  // sort boxes from top to bottom, from left to right
  Utility::sorted_boxes(ocr_results);
  this->time_info_det[0] += det_times[0];
  this->time_info_det[1] += det_times[1];
  this->time_info_det[2] += det_times[2];
}

void PPOCR::rec(std::vector<cv::Mat> img_list,
                std::vector<OCRPredictResult> &ocr_results) {
  std::vector<std::string> rec_texts(img_list.size(), "");
  std::vector<float> rec_text_scores(img_list.size(), 0);
  std::vector<double> rec_times;
  this->recognizer_->Run(img_list, rec_texts, rec_text_scores, rec_times);
  // output rec results
  for (int i = 0; i < rec_texts.size(); i++) {
    ocr_results[i].text = rec_texts[i];
    ocr_results[i].score = rec_text_scores[i];
  }
  this->time_info_rec[0] += rec_times[0];
  this->time_info_rec[1] += rec_times[1];
  this->time_info_rec[2] += rec_times[2];
}

void PPOCR::cls(std::vector<cv::Mat> img_list,
                std::vector<OCRPredictResult> &ocr_results) {
  std::vector<int> cls_labels(img_list.size(), 0);
  std::vector<float> cls_scores(img_list.size(), 0);
  std::vector<double> cls_times;
  this->classifier_->Run(img_list, cls_labels, cls_scores, cls_times);
  // output cls results
  for (int i = 0; i < cls_labels.size(); i++) {
    ocr_results[i].cls_label = cls_labels[i];
    ocr_results[i].cls_score = cls_scores[i];
  }
  this->time_info_cls[0] += cls_times[0];
  this->time_info_cls[1] += cls_times[1];
  this->time_info_cls[2] += cls_times[2];
}

void PPOCR::reset_timer() {
  this->time_info_det = {0, 0, 0};
  this->time_info_rec = {0, 0, 0};
  this->time_info_cls = {0, 0, 0};
}

void PPOCR::benchmark_log(int img_num) {
  if (this->time_info_det[0] + this->time_info_det[1] +
          this->time_info_det[2] >
      0) {
    AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
                           FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
                           FLAGS_precision, this->time_info_det, img_num);
    autolog_det.report();
  }
  if (this->time_info_rec[0] + this->time_info_rec[1] +
          this->time_info_rec[2] >
      0) {
    AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
                           FLAGS_enable_mkldnn, FLAGS_cpu_threads,
                           FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
                           this->time_info_rec, img_num);
    autolog_rec.report();
  }
  if (this->time_info_cls[0] + this->time_info_cls[1] +
          this->time_info_cls[2] >
      0) {
    AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
                           FLAGS_enable_mkldnn, FLAGS_cpu_threads,
                           FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
                           this->time_info_cls, img_num);
    autolog_cls.report();
  }
}

PPOCR::~PPOCR() {
  if (this->detector_ != nullptr) {
    delete this->detector_;
......
@@ -16,14 +16,19 @@
#include <include/paddlestructure.h>

#include "auto_log/autolog.h"
#include <numeric>
#include <sys/stat.h>

namespace PaddleOCR {

PaddleStructure::PaddleStructure() {
  if (FLAGS_layout) {
    this->layout_model_ = new StructureLayoutRecognizer(
        FLAGS_layout_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_layout_dict_path,
        FLAGS_use_tensorrt, FLAGS_precision, FLAGS_layout_score_threshold,
        FLAGS_layout_nms_threshold);
  }
  if (FLAGS_table) {
    this->table_model_ = new StructureTableRecognizer(
        FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
        FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
@@ -31,68 +36,63 @@ PaddleStructure::PaddleStructure() {
  }
};
std::vector<StructurePredictResult>
PaddleStructure::structure(cv::Mat srcimg, bool layout, bool table, bool ocr) {
  cv::Mat img;
  srcimg.copyTo(img);

  std::vector<StructurePredictResult> structure_results;

  if (layout) {
    this->layout(img, structure_results);
  } else {
    StructurePredictResult res;
    res.type = "table";
    res.box = std::vector<float>(4, 0.0);
    res.box[2] = img.cols;
    res.box[3] = img.rows;
    structure_results.push_back(res);
  }
  cv::Mat roi_img;
  for (int i = 0; i < structure_results.size(); i++) {
    // crop image
    roi_img = Utility::crop_image(img, structure_results[i].box);
    if (structure_results[i].type == "table" && table) {
      this->table(roi_img, structure_results[i]);
    } else if (ocr) {
      structure_results[i].text_res = this->ocr(roi_img, true, true, false);
    }
  }

  return structure_results;
};
void PaddleStructure::layout(
cv::Mat img, std::vector<StructurePredictResult> &structure_result) {
std::vector<double> layout_times;
this->layout_model_->Run(img, structure_result, layout_times);
this->time_info_layout[0] += layout_times[0];
this->time_info_layout[1] += layout_times[1];
this->time_info_layout[2] += layout_times[2];
}
void PaddleStructure::table(cv::Mat img,
                            StructurePredictResult &structure_result) {
  // predict structure
  std::vector<std::vector<std::string>> structure_html_tags;
  std::vector<float> structure_scores(1, 0);
  std::vector<std::vector<std::vector<int>>> structure_boxes;
  std::vector<double> structure_times;
  std::vector<cv::Mat> img_list;
  img_list.push_back(img);

  this->table_model_->Run(img_list, structure_html_tags, structure_scores,
                          structure_boxes, structure_times);

  this->time_info_table[0] += structure_times[0];
  this->time_info_table[1] += structure_times[1];
  this->time_info_table[2] += structure_times[2];

  std::vector<OCRPredictResult> ocr_result;
  std::string html;
@@ -100,22 +100,22 @@ void PaddleStructure::table(cv::Mat img,
  for (int i = 0; i < img_list.size(); i++) {
    // det
    this->det(img_list[i], ocr_result);
    // crop image
    std::vector<cv::Mat> rec_img_list;
    std::vector<int> ocr_box;
    for (int j = 0; j < ocr_result.size(); j++) {
      ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[j].box);
      ocr_box[0] = std::max(0, ocr_box[0] - expand_pixel);
      ocr_box[1] = std::max(0, ocr_box[1] - expand_pixel);
      ocr_box[2] = std::min(img_list[i].cols, ocr_box[2] + expand_pixel);
      ocr_box[3] = std::min(img_list[i].rows, ocr_box[3] + expand_pixel);

      cv::Mat crop_img = Utility::crop_image(img_list[i], ocr_box);
      rec_img_list.push_back(crop_img);
    }
    // rec
    this->rec(rec_img_list, ocr_result);
    // rebuild table
    html = this->rebuild_table(structure_html_tags[i], structure_boxes[i],
                               ocr_result);
@@ -130,7 +130,7 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
    std::vector<std::vector<int>> structure_boxes,
    std::vector<OCRPredictResult> &ocr_result) {
  // match text in same cell
  std::vector<std::vector<std::string>> matched(structure_boxes.size(),
                                                std::vector<std::string>());

  std::vector<int> ocr_box;
@@ -150,7 +150,7 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
      structure_box = structure_boxes[j];
    }
    dis_list[j][0] = this->dis(ocr_box, structure_box);
    dis_list[j][1] = 1 - Utility::iou(ocr_box, structure_box);
    dis_list[j][2] = j;
  }
  // find min dis idx
@@ -216,28 +216,6 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
  return html_str;
}
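
// Matching metric used by rebuild_table(): dis() below sums the L1 distances
// of the two boxes' corner coordinates and adds min(top-left corner distance,
// bottom-right corner distance). Each OCR box is then assigned to the table
// cell minimizing the (dis, 1 - IoU) pair, with IoU computed by the shared
// Utility::iou helper.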
float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) {
  int x1_1 = box1[0];
  int y1_1 = box1[1];
@@ -253,12 +231,64 @@ float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) {
      abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
  float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1);
  float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
  return dis + std::min(dis_2, dis_3);
}
void PaddleStructure::reset_timer() {
this->time_info_det = {0, 0, 0};
this->time_info_rec = {0, 0, 0};
this->time_info_cls = {0, 0, 0};
this->time_info_table = {0, 0, 0};
this->time_info_layout = {0, 0, 0};
}
void PaddleStructure::benchmark_log(int img_num) {
if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] >
0) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, this->time_info_det, img_num);
autolog_det.report();
}
if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] >
0) {
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
this->time_info_rec, img_num);
autolog_rec.report();
}
if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] >
0) {
AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_cls, img_num);
autolog_cls.report();
}
if (this->time_info_table[0] + this->time_info_table[1] +
this->time_info_table[2] >
0) {
AutoLogger autolog_table("table", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_table, img_num);
autolog_table.report();
}
if (this->time_info_layout[0] + this->time_info_layout[1] +
this->time_info_layout[2] >
0) {
AutoLogger autolog_layout("layout", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_layout, img_num);
autolog_layout.report();
}
} }
PaddleStructure::~PaddleStructure() {
  if (this->table_model_ != nullptr) {
    delete this->table_model_;
  }
};
......
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/postprocess_op.h>

namespace PaddleOCR {
@@ -431,7 +430,7 @@ void TablePostProcessor::Run(
      }
    }
    score /= count;
    if (std::isnan(score) || rec_boxes.size() == 0) {
      score = -1;
    }
    rec_scores.push_back(score);
@@ -440,4 +439,137 @@ void TablePostProcessor::Run(
  }
}
void PicodetPostProcessor::init(std::string label_path,
const double score_threshold,
const double nms_threshold,
const std::vector<int> &fpn_stride) {
this->label_list_ = Utility::ReadDict(label_path);
this->score_threshold_ = score_threshold;
this->nms_threshold_ = nms_threshold;
this->num_class_ = label_list_.size();
this->fpn_stride_ = fpn_stride;
}
void PicodetPostProcessor::Run(std::vector<StructurePredictResult> &results,
std::vector<std::vector<float>> outs,
std::vector<int> ori_shape,
std::vector<int> resize_shape, int reg_max) {
int in_h = resize_shape[0];
int in_w = resize_shape[1];
float scale_factor_h = resize_shape[0] / float(ori_shape[0]);
float scale_factor_w = resize_shape[1] / float(ori_shape[1]);
std::vector<std::vector<StructurePredictResult>> bbox_results;
bbox_results.resize(this->num_class_);
for (int i = 0; i < this->fpn_stride_.size(); ++i) {
int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]);
int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]);
for (int idx = 0; idx < feature_h * feature_w; idx++) {
// score and label
float score = 0;
int cur_label = 0;
for (int label = 0; label < this->num_class_; label++) {
if (outs[i][idx * this->num_class_ + label] > score) {
score = outs[i][idx * this->num_class_ + label];
cur_label = label;
}
}
// bbox
if (score > this->score_threshold_) {
int row = idx / feature_w;
int col = idx % feature_w;
std::vector<float> bbox_pred(
outs[i + this->fpn_stride_.size()].begin() + idx * 4 * reg_max,
outs[i + this->fpn_stride_.size()].begin() +
(idx + 1) * 4 * reg_max);
bbox_results[cur_label].push_back(
this->disPred2Bbox(bbox_pred, cur_label, score, col, row,
this->fpn_stride_[i], resize_shape, reg_max));
}
}
}
  for (int i = 0; i < bbox_results.size(); i++) {
    if (bbox_results[i].size() <= 0) {
continue;
}
this->nms(bbox_results[i], this->nms_threshold_);
for (auto box : bbox_results[i]) {
box.box[0] = box.box[0] / scale_factor_w;
box.box[2] = box.box[2] / scale_factor_w;
box.box[1] = box.box[1] / scale_factor_h;
box.box[3] = box.box[3] / scale_factor_h;
results.push_back(box);
}
}
}
StructurePredictResult
PicodetPostProcessor::disPred2Bbox(std::vector<float> bbox_pred, int label,
float score, int x, int y, int stride,
std::vector<int> im_shape, int reg_max) {
float ct_x = (x + 0.5) * stride;
float ct_y = (y + 0.5) * stride;
std::vector<float> dis_pred;
dis_pred.resize(4);
for (int i = 0; i < 4; i++) {
float dis = 0;
std::vector<float> bbox_pred_i(bbox_pred.begin() + i * reg_max,
bbox_pred.begin() + (i + 1) * reg_max);
std::vector<float> dis_after_sm =
Utility::activation_function_softmax(bbox_pred_i);
for (int j = 0; j < reg_max; j++) {
dis += j * dis_after_sm[j];
}
dis *= stride;
dis_pred[i] = dis;
}
float xmin = (std::max)(ct_x - dis_pred[0], .0f);
float ymin = (std::max)(ct_y - dis_pred[1], .0f);
float xmax = (std::min)(ct_x + dis_pred[2], (float)im_shape[1]);
float ymax = (std::min)(ct_y + dis_pred[3], (float)im_shape[0]);
StructurePredictResult result_item;
result_item.box = {xmin, ymin, xmax, ymax};
result_item.type = this->label_list_[label];
result_item.confidence = score;
return result_item;
}
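
// Decoding note for disPred2Bbox(): PicoDet regresses each of the four box
// sides as a discretized distribution (GFL style). The reg_max logits per
// side are softmax-normalized and the predicted distance is the expectation
// of that distribution scaled by the FPN stride:
//   d = stride * sum_j( j * softmax(logits)[j] ).
// The box is then (ct_x - d_left, ct_y - d_top, ct_x + d_right,
// ct_y + d_bottom), clipped to the resized image shape.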
void PicodetPostProcessor::nms(std::vector<StructurePredictResult> &input_boxes,
float nms_threshold) {
std::sort(input_boxes.begin(), input_boxes.end(),
[](StructurePredictResult a, StructurePredictResult b) {
return a.confidence > b.confidence;
});
std::vector<int> picked(input_boxes.size(), 1);
for (int i = 0; i < input_boxes.size(); ++i) {
if (picked[i] == 0) {
continue;
}
for (int j = i + 1; j < input_boxes.size(); ++j) {
if (picked[j] == 0) {
continue;
}
float iou = Utility::iou(input_boxes[i].box, input_boxes[j].box);
if (iou > nms_threshold) {
picked[j] = 0;
}
}
}
std::vector<StructurePredictResult> input_boxes_nms;
for (int i = 0; i < input_boxes.size(); ++i) {
if (picked[i] == 1) {
input_boxes_nms.push_back(input_boxes[i]);
}
}
input_boxes = input_boxes_nms;
}
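
// The routine above is standard greedy NMS: boxes are sorted by confidence,
// and any box whose IoU with an already-kept box exceeds nms_threshold is
// suppressed; the `picked` vector marks the survivors.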
} // namespace PaddleOCR
@@ -12,21 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/preprocess_op.h>

namespace PaddleOCR {
@@ -69,13 +54,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
}

void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
                         std::string limit_type, int limit_side_len,
                         float &ratio_h, float &ratio_w, bool use_tensorrt) {
  int w = img.cols;
  int h = img.rows;
  float ratio = 1.f;
  if (limit_type == "min") {
    int min_wh = std::min(h, w);
    if (min_wh < limit_side_len) {
      if (h < w) {
        ratio = float(limit_side_len) / float(h);
@@ -84,7 +69,7 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
      }
    }
  } else {
    int max_wh = std::max(h, w);
    if (max_wh > limit_side_len) {
      if (h > w) {
        ratio = float(limit_side_len) / float(h);
@@ -97,8 +82,8 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
  int resize_h = int(float(h) * ratio);
  int resize_w = int(float(w) * ratio);

  resize_h = std::max(int(round(float(resize_h) / 32) * 32), 32);
  resize_w = std::max(int(round(float(resize_w) / 32) * 32), 32);

  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
  ratio_h = float(resize_h) / float(h);
@@ -175,4 +160,9 @@ void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                     cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
void Resize::Run(const cv::Mat &img, cv::Mat &resize_img, const int h,
const int w) {
cv::resize(img, resize_img, cv::Size(w, h));
}
} // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/structure_layout.h>
namespace PaddleOCR {
void StructureLayoutRecognizer::Run(cv::Mat img,
                                    std::vector<StructurePredictResult> &result,
                                    std::vector<double> &times) {
  std::chrono::duration<float> preprocess_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
  std::chrono::duration<float> inference_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
  std::chrono::duration<float> postprocess_diff =
      std::chrono::steady_clock::now() - std::chrono::steady_clock::now();

  // preprocess
  auto preprocess_start = std::chrono::steady_clock::now();
  cv::Mat srcimg;
  img.copyTo(srcimg);
  cv::Mat resize_img;
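  // Note: the layout model is run at a fixed input size; Resize::Run takes
  // (h, w), so this resizes to height 800 and width 608.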
  this->resize_op_.Run(srcimg, resize_img, 800, 608);

  this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                          this->is_scale_);

  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
  this->permute_op_.Run(&resize_img, input.data());
  auto preprocess_end = std::chrono::steady_clock::now();
  preprocess_diff += preprocess_end - preprocess_start;

  // inference
  auto input_names = this->predictor_->GetInputNames();
  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
  auto inference_start = std::chrono::steady_clock::now();
  input_t->CopyFromCpu(input.data());
  this->predictor_->Run();

  // Get output tensors
  std::vector<std::vector<float>> out_tensor_list;
  std::vector<std::vector<int>> output_shape_list;
  auto output_names = this->predictor_->GetOutputNames();
  for (int j = 0; j < output_names.size(); j++) {
    auto output_tensor = this->predictor_->GetOutputHandle(output_names[j]);
    std::vector<int> output_shape = output_tensor->shape();
    int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                  std::multiplies<int>());
    output_shape_list.push_back(output_shape);

    std::vector<float> out_data;
    out_data.resize(out_num);
    output_tensor->CopyToCpu(out_data.data());
    out_tensor_list.push_back(out_data);
  }
  auto inference_end = std::chrono::steady_clock::now();
  inference_diff += inference_end - inference_start;

  // postprocess
  auto postprocess_start = std::chrono::steady_clock::now();
  std::vector<int> bbox_num;
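  // Note: this assumes a PicoDet-style detection head, whose outputs are the
  // per-FPN-level score tensors followed by the box-distribution tensors. The
  // first box tensor carries 4 * reg_max channels (one discretized
  // distribution per box side), so reg_max is recovered by dividing its last
  // dimension by 4.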
  int reg_max = 0;
  for (int i = 0; i < out_tensor_list.size(); i++) {
    if (i == this->post_processor_.fpn_stride_.size()) {
      reg_max = output_shape_list[i][2] / 4;
      break;
    }
  }
  std::vector<int> ori_shape = {srcimg.rows, srcimg.cols};
  std::vector<int> resize_shape = {resize_img.rows, resize_img.cols};
  this->post_processor_.Run(result, out_tensor_list, ori_shape, resize_shape,
                            reg_max);
  bbox_num.push_back(result.size());
  auto postprocess_end = std::chrono::steady_clock::now();
  postprocess_diff += postprocess_end - postprocess_start;

  times.push_back(double(preprocess_diff.count() * 1000));
  times.push_back(double(inference_diff.count() * 1000));
  times.push_back(double(postprocess_diff.count() * 1000));
}
void StructureLayoutRecognizer::LoadModel(const std::string &model_dir) {
  paddle_infer::Config config;
  if (Utility::PathExists(model_dir + "/inference.pdmodel") &&
      Utility::PathExists(model_dir + "/inference.pdiparams")) {
    config.SetModel(model_dir + "/inference.pdmodel",
                    model_dir + "/inference.pdiparams");
  } else if (Utility::PathExists(model_dir + "/model.pdmodel") &&
             Utility::PathExists(model_dir + "/model.pdiparams")) {
    config.SetModel(model_dir + "/model.pdmodel",
                    model_dir + "/model.pdiparams");
  } else {
    std::cerr << "[ERROR] cannot find model.pdiparams or inference.pdiparams "
              << "in " << model_dir << std::endl;
    exit(1);
  }

  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
    if (this->use_tensorrt_) {
      auto precision = paddle_infer::Config::Precision::kFloat32;
      if (this->precision_ == "fp16") {
        precision = paddle_infer::Config::Precision::kHalf;
      }
      if (this->precision_ == "int8") {
        precision = paddle_infer::Config::Precision::kInt8;
      }
      config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
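      // Note: on the first run no shape file exists yet, so tensor shape
      // ranges are collected and written to disk; later runs reuse the tuned
      // dynamic shapes instead of re-tuning the TensorRT engine for each new
      // input resolution.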
      if (!Utility::PathExists("./trt_layout_shape.txt")) {
        config.CollectShapeRangeInfo("./trt_layout_shape.txt");
      } else {
        config.EnableTunedTensorRtDynamicShape("./trt_layout_shape.txt", true);
      }
    }
  } else {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }

  // false for zero copy tensor
  config.SwitchUseFeedFetchOps(false);
  // true for multiple input
  config.SwitchSpecifyInputNames(true);

  config.SwitchIrOptim(true);
  config.EnableMemoryOptim();
  config.DisableGlogInfo();

  this->predictor_ = paddle_infer::CreatePredictor(config);
}
} // namespace PaddleOCR
@@ -34,7 +34,7 @@ void StructureTableRecognizer::Run(
       beg_img_no += this->table_batch_num_) {
    // preprocess
    auto preprocess_start = std::chrono::steady_clock::now();
    int end_img_no = std::min(img_num, beg_img_no + this->table_batch_num_);
    int batch_num = end_img_no - beg_img_no;
    std::vector<cv::Mat> norm_img_batch;
    std::vector<int> width_list;
@@ -118,7 +118,7 @@ void StructureTableRecognizer::Run(
}

void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
  paddle_infer::Config config;
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");
@@ -133,6 +133,11 @@ void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
        precision = paddle_infer::Config::Precision::kInt8;
      }
      config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
      if (!Utility::PathExists("./trt_table_shape.txt")) {
        config.CollectShapeRangeInfo("./trt_table_shape.txt");
      } else {
        config.EnableTunedTensorRtDynamicShape("./trt_table_shape.txt", true);
      }
    }
  } else {
    config.DisableGpu();
@@ -152,6 +157,6 @@ void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
  config.EnableMemoryOptim();
  config.DisableGlogInfo();

  this->predictor_ = paddle_infer::CreatePredictor(config);
}

} // namespace PaddleOCR
@@ -70,6 +70,7 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg,
                              const std::string &save_path) {
  cv::Mat img_vis;
  srcimg.copyTo(img_vis);
  img_vis = crop_image(img_vis, structure_result.box);
  for (int n = 0; n < structure_result.cell_box.size(); n++) {
    if (structure_result.cell_box[n].size() == 8) {
      cv::Point rook_points[4];
@@ -280,23 +281,29 @@ void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
  }
}
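// Note: crop_image zero-pads any part of the requested box that falls
// outside the image, so the returned crop always has the full box size.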
cv::Mat Utility::crop_image(cv::Mat &img, const std::vector<int> &box) {
  cv::Mat crop_im;
  int crop_x1 = std::max(0, box[0]);
  int crop_y1 = std::max(0, box[1]);
  int crop_x2 = std::min(img.cols - 1, box[2] - 1);
  int crop_y2 = std::min(img.rows - 1, box[3] - 1);

  crop_im = cv::Mat::zeros(box[3] - box[1], box[2] - box[0], 16);
  cv::Mat crop_im_window =
      crop_im(cv::Range(crop_y1 - box[1], crop_y2 + 1 - box[1]),
              cv::Range(crop_x1 - box[0], crop_x2 + 1 - box[0]));
  cv::Mat roi_img =
      img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
  crop_im_window += roi_img;
  return crop_im;
}

cv::Mat Utility::crop_image(cv::Mat &img, const std::vector<float> &box) {
  std::vector<int> box_int = {(int)box[0], (int)box[1], (int)box[2],
                              (int)box[3]};
  return crop_image(img, box_int);
}
void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) {
  std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
  if (ocr_result.size() > 0) {
@@ -341,4 +348,78 @@ std::vector<int> Utility::xyxyxyxy2xyxy(std::vector<int> &box) {
  return box1;
}
float Utility::fast_exp(float x) {
  union {
    uint32_t i;
    float f;
  } v{};
  v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
  return v.f;
}
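// Note: subtracting the max element (alpha) before exponentiating is the
// standard numerically stable softmax; it avoids overflow and leaves the
// result unchanged, since softmax(x) = softmax(x - c) for any constant c.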
std::vector<float>
Utility::activation_function_softmax(std::vector<float> &src) {
  int length = src.size();
  std::vector<float> dst;
  dst.resize(length);
  const float alpha = float(*std::max_element(&src[0], &src[0 + length]));
  float denominator{0};

  for (int i = 0; i < length; ++i) {
    dst[i] = fast_exp(src[i] - alpha);
    denominator += dst[i];
  }
  for (int i = 0; i < length; ++i) {
    dst[i] /= denominator;
  }
  return dst;
}
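// Note: the two iou overloads below compute standard intersection over union
// on [x1, y1, x2, y2] boxes; the small epsilon guards against division by
// zero for degenerate boxes. For example, boxes {0, 0, 10, 10} and
// {5, 5, 15, 15} overlap in a 5x5 region, giving IoU = 25 / (200 - 25) ≈ 0.143.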
float Utility::iou(std::vector<int> &box1, std::vector<int> &box2) {
  int area1 = std::max(0, box1[2] - box1[0]) * std::max(0, box1[3] - box1[1]);
  int area2 = std::max(0, box2[2] - box2[0]) * std::max(0, box2[3] - box2[1]);

  // compute the sum of the two areas
  int sum_area = area1 + area2;

  // find the corners of the intersection rectangle
  int x1 = std::max(box1[0], box2[0]);
  int y1 = std::max(box1[1], box2[1]);
  int x2 = std::min(box1[2], box2[2]);
  int y2 = std::min(box1[3], box2[3]);

  // check whether there is an intersection
  if (y1 >= y2 || x1 >= x2) {
    return 0.0;
  } else {
    int intersect = (x2 - x1) * (y2 - y1);
    return intersect / (sum_area - intersect + 0.00000001);
  }
}
float Utility::iou(std::vector<float> &box1, std::vector<float> &box2) {
  float area1 = std::max((float)0.0, box1[2] - box1[0]) *
                std::max((float)0.0, box1[3] - box1[1]);
  float area2 = std::max((float)0.0, box2[2] - box2[0]) *
                std::max((float)0.0, box2[3] - box2[1]);

  // compute the sum of the two areas
  float sum_area = area1 + area2;

  // find the corners of the intersection rectangle
  float x1 = std::max(box1[0], box2[0]);
  float y1 = std::max(box1[1], box2[1]);
  float x2 = std::min(box1[2], box2[2]);
  float y2 = std::min(box1[3], box2[3]);

  // check whether there is an intersection
  if (y1 >= y2 || x1 >= x2) {
    return 0.0;
  } else {
    float intersect = (x2 - x1) * (y2 - y1);
    return intersect / (sum_area - intersect + 0.00000001);
  }
}

} // namespace PaddleOCR
\ No newline at end of file
@@ -30,7 +30,7 @@

|Model|Backbone|Task|Config|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
|LayoutXLM|LayoutXLM-base|SER|[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
|LayoutXLM|LayoutXLM-base|RE|[re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|

<a name="2"></a>
@@ -52,14 +52,14 @@

### 4.1 Python Inference

- SER

First, convert the trained model into an inference model. Taking the LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)), it can be converted with the following command.

```bash
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
tar -xf ser_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
```

To run SER inference with the LayoutXLM model, execute the following command:

@@ -80,6 +80,34 @@ The SER visualization results are saved in the `./output` folder by default. An example result is shown below:

<img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800">
</div>
- RE
First, convert the trained model into an inference model. Taking the LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)), it can be converted with the following command.
```bash
wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
tar -xf re_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer
```
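After export, the directory passed to `--re_model_dir` below should contain the standard Paddle inference-model files — the same names the C++ `LoadModel` above probes for. A quick sanity check (the exact listing is an assumption; auxiliary files may differ):

```bash
ls ./inference/re_layoutxlm_infer
# expected: inference.pdiparams  inference.pdiparams.info  inference.pdmodel
```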
To run RE inference with the LayoutXLM model, execute the following command:
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_layoutxlm_infer \
--ser_model_dir=../inference/ser_layoutxlm_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf
```
The RE visualization results are saved in the `./output` folder by default. An example result is shown below:
<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>
<a name="4-2"></a>

### 4.2 C++ Inference and Deployment
@@ -23,7 +23,7 @@ VI-LayoutXLM improves on LayoutXLM; during downstream task training it…

|Model|Backbone|Task|Config|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
|VI-LayoutXLM|VI-LayoutXLM-base|SER|[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
|VI-LayoutXLM|VI-LayoutXLM-base|RE|[re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|

<a name="2"></a>

@@ -45,7 +45,7 @@

### 4.1 Python Inference

- SER

First, convert the trained model into an inference model. Taking the VI-LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)), it can be converted with the following command.

@@ -74,6 +74,36 @@ The SER visualization results are saved in the `./output` folder by default. An example result is shown below:

<img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800">
</div>
- RE
First, convert the trained model into an inference model. Taking the VI-LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)), it can be converted with the following command.
```bash
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
tar -xf re_vi_layoutxlm_xfund_pretrained.tar
python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
```
To run RE inference with the VI-LayoutXLM model, execute the following command. Note `--use_visual_backbone=False`: VI-LayoutXLM removes the visual backbone, so the predictor must be told not to expect one.
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_vi_layoutxlm_infer \
--ser_model_dir=../inference/ser_vi_layoutxlm_infer \
--use_visual_backbone=False \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```
The RE visualization results are saved in the `./output` folder by default. An example result is shown below:
<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>
<a name="4-2"></a>

### 4.2 C++ Inference and Deployment
@@ -22,7 +22,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.

|Model|Backbone|Task|Config|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
|VI-LayoutXLM|VI-LayoutXLM-base|SER|[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
|VI-LayoutXLM|VI-LayoutXLM-base|RE|[re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|

Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.

@@ -37,7 +37,7 @@ Please refer to [KIE tutorial](./kie_en.md). PaddleOCR has modularized the code…

### 4.1 Python Inference

- SER

First, we need to export the trained model into an inference model. Take the VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.

@@ -70,6 +70,41 @@ The SER visualization results are saved in the `./output` folder by default. The…

</div>
- RE
First, we need to export the trained model into an inference model. Take the VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.
```bash
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
tar -xf re_vi_layoutxlm_xfund_pretrained.tar
python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
```
Use the following command to run RE inference with the VI-LayoutXLM model. Note `--use_visual_backbone=False`: VI-LayoutXLM removes the visual backbone, so the predictor must be told not to expect one.
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_vi_layoutxlm_infer \
--ser_model_dir=../inference/ser_vi_layoutxlm_infer \
--use_visual_backbone=False \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```
The RE visualization results are saved in the `./output` folder by default. The results are as follows.
<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>
### 4.2 C++ Inference

Not supported
ppstructure/docs/layout/layout.png: image replaced (178.6 KB → 1.2 MB)