# Intelligent Operations: General Chinese Table Recognition
- [1. Background](#1-background)
- [2. Chinese Table Recognition](#2-chinese-table-recognition)
- [2.1 Environment Setup](#21-environment-setup)
- [2.2 Dataset Preparation](#22-dataset-preparation)
- [2.2.1 Train/Test Split](#221-traintest-split)
- [2.2.2 Inspecting the Dataset](#222-inspecting-the-dataset)
- [2.3 Training](#23-training)
- [2.4 Evaluation](#24-evaluation)
- [2.5 Inference with the Training Engine](#25-inference-with-the-training-engine)
- [2.6 Model Export](#26-model-export)
- [2.7 Inference with the Prediction Engine](#27-inference-with-the-prediction-engine)
- [2.8 Table Recognition](#28-table-recognition)
- [3. Table Attribute Recognition](#3-table-attribute-recognition)
- [3.1 Code, Environment, and Data Preparation](#31-code-environment-and-data-preparation)
- [3.1.1 Code Preparation](#311-code-preparation)
- [3.1.2 Environment Setup](#312-environment-setup)
- [3.1.3 Data Preparation](#313-data-preparation)
- [3.2 Table Attribute Recognition Training](#32-table-attribute-recognition-training)
- [3.3 Table Attribute Recognition Inference and Deployment](#33-table-attribute-recognition-inference-and-deployment)
- [3.3.1 Model Conversion](#331-model-conversion)
- [3.3.2 Model Inference](#332-model-inference)
## 1. Background
Chinese table recognition has broad applications in the finance industry, such as insurance claims, financial report analysis, and data entry. At present, tables in finance are still mostly transcribed by hand, so automatic table recognition is an urgent problem to solve.
![](https://ai-studio-static-online.cdn.bcebos.com/d1e7780f0c7745ada4be540decefd6288e4d59257d8141f6842682a4c05d28b6)
In finance, table images mainly take four forms: dense-cell tables such as statements, large-cell tables such as application forms, photographed tables, and tilted tables.
![](https://ai-studio-static-online.cdn.bcebos.com/da82ae8ef8fd479aaa38e1049eb3a681cf020dc108fa458eb3ec79da53b45fd1)
![](https://ai-studio-static-online.cdn.bcebos.com/5ffff2093a144a6993a75eef71634a52276015ee43a04566b9c89d353198c746)
Existing table recognition algorithms do not handle table images in these scenarios well. In this example, we use SLANet, the table recognition model newly released with PP-StructureV2, to demonstrate Chinese table recognition. To streamline the workflow, we also use a table attribute recognition model to judge how difficult each table image is, which speeds up manual proofreading.
AI Studio project link: https://aistudio.baidu.com/aistudio/projectdetail/4588067
## 2. Chinese Table Recognition
### 2.1 Environment Setup
```python
# download the PaddleOCR code
! git clone -b dygraph https://gitee.com/paddlepaddle/PaddleOCR
```
```python
# install PaddleOCR dependencies
! pip install -r PaddleOCR/requirements.txt --force-reinstall
! pip install protobuf==3.19
```
### 2.2 Dataset Preparation
The dataset used in this example was produced with a table [generation tool](https://github.com/WenmuZhou/TableGeneration).
Use the following commands to extract the dataset and check its size:
```python
! cd data/data165849 && tar -xf table_gen_dataset.tar && cd -
! wc -l data/data165849/table_gen_dataset/gt.txt
```
#### 2.2.1 Train/Test Split
Split the dataset with the code below: 90% goes to the training set and 10% to the test set.
```python
import random

with open('/home/aistudio/data/data165849/table_gen_dataset/gt.txt') as f:
    lines = f.readlines()
random.shuffle(lines)
train_len = int(len(lines) * 0.9)
train_list = lines[:train_len]
val_list = lines[train_len:]

# save the split results
with open('/home/aistudio/train.txt', 'w', encoding='utf-8') as f:
    f.writelines(train_list)
with open('/home/aistudio/val.txt', 'w', encoding='utf-8') as f:
    f.writelines(val_list)
```
After the split, the dataset looks like this:
|Split|Images|Image directory|Label file|
|---|---|---|---|
|Training set|18000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/train.txt|
|Test set|2000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/val.txt|
#### 2.2.2 Inspecting the Dataset
```python
import cv2
import os, json
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

def parse_line(data_dir, line):
    data_line = line.strip("\n")
    info = json.loads(data_line)
    file_name = info['filename']
    cells = info['html']['cells'].copy()
    structure = info['html']['structure']['tokens'].copy()

    img_path = os.path.join(data_dir, file_name)
    if not os.path.exists(img_path):
        print(img_path)
        return None
    data = {
        'img_path': img_path,
        'cells': cells,
        'structure': structure,
        'file_name': file_name
    }
    return data

def draw_bbox(img_path, points, color=(255, 0, 0), thickness=2):
    if isinstance(img_path, str):
        img_path = cv2.imread(img_path)
    img_path = img_path.copy()
    for point in points:
        cv2.polylines(img_path, [point.astype(int)], True, color, thickness)
    return img_path

def rebuild_html(data):
    html_code = data['structure']
    cells = data['cells']
    to_insert = [i for i, tag in enumerate(html_code) if tag in ('<td>', '>')]

    for i, cell in zip(to_insert[::-1], cells[::-1]):
        if cell['tokens']:
            text = ''.join(cell['tokens'])
            # skip empty text
            sp_char_list = ['<b>', '</b>', '\u2028', ' ', '<i>', '</i>']
            text_remove_style = skip_char(text, sp_char_list)
            if len(text_remove_style) == 0:
                continue
            html_code.insert(i + 1, text)

    html_code = ''.join(html_code)
    return html_code

def skip_char(text, sp_char_list):
    """
    skip empty cell
    @param text: text in cell
    @param sp_char_list: style char and special code
    @return:
    """
    for sp_char in sp_char_list:
        text = text.replace(sp_char, '')
    return text

save_dir = '/home/aistudio/vis'
os.makedirs(save_dir, exist_ok=True)
image_dir = '/home/aistudio/data/data165849/'
html_str = '<table border="1">'

# parse the annotation and rebuild the html table
data = parse_line(image_dir, val_list[0])

img = cv2.imread(data['img_path'])
img_name = ''.join(os.path.basename(data['file_name']).split('.')[:-1])
img_save_name = os.path.join(save_dir, img_name)
boxes = [np.array(x['bbox']) for x in data['cells']]
show_img = draw_bbox(data['img_path'], boxes)
cv2.imwrite(img_save_name + '_show.jpg', show_img)

html = rebuild_html(data)
html_str += html
html_str += '</table>'

# display the rebuilt html string
from IPython.core.display import display, HTML
display(HTML(html_str))
# show the cell boxes
plt.figure(figsize=(15, 15))
plt.imshow(show_img)
plt.show()
```
### 2.3 Training
We use [SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/configs/table/SLANet.yml), the table recognition model from PP-StructureV2.
SLANet is the table recognition model newly introduced in PP-StructureV2. Compared with TableRec-RARE in PP-StructureV1, accuracy improves by 4.7% and TEDS by 2% at the same speed.
|Algorithm|Acc|[TEDS (Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed|
| --- | --- | --- | ---|
| EDD<sup>[2]</sup> |x| 88.3% |x|
| TableRec-RARE(ours) | 71.73%| 93.88% |779ms|
| SLANet(ours) | 76.31%| 95.89%|766ms|
Before training, download the pretrained model with the following commands:
```python
import os
# change into the PaddleOCR working directory
os.chdir('/home/aistudio/PaddleOCR')
# download the English pretrained model
! wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar --no-check-certificate
! cd ./pretrain_models/ && tar xf en_ppstructure_mobile_v2.0_SLANet_train.tar && cd ../
```
Training can then be launched with the command below. The configuration fields to modify are:
|Field|New value|Meaning|
|---|---|---|
|Global.pretrained_model|./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy.pdparams|path to the English table pretrained model|
|Global.eval_batch_step|562|evaluate every this many steps, usually set to the number of steps in one epoch|
|Optimizer.lr.name|Const|learning-rate scheduler|
|Optimizer.lr.learning_rate|0.0005|learning rate, 0.05x the original value|
|Train.dataset.data_dir|/home/aistudio/data/data165849|directory containing the training images|
|Train.dataset.label_file_list|/home/aistudio/data/data165849/table_gen_dataset/train.txt|training label file|
|Train.loader.batch_size_per_card|32|training batch size per card|
|Train.loader.num_workers|1|number of data-loading workers for training; must be set to 1 on AI Studio|
|Eval.dataset.data_dir|/home/aistudio/data/data165849|directory containing the test images|
|Eval.dataset.label_file_list|/home/aistudio/data/data165849/table_gen_dataset/val.txt|test label file|
|Eval.loader.batch_size_per_card|32|evaluation batch size per card|
|Eval.loader.num_workers|1|number of data-loading workers for evaluation; must be set to 1 on AI Studio|
The modified configuration is stored at `/home/aistudio/SLANet_ch.yml`.
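Alternatively, the same fields can be overridden from the command line with `-o`, as PaddleOCR commands do elsewhere in this tutorial, instead of editing the YAML. A minimal sketch against the base config (the values mirror the table above; the exact `eval_batch_step` format depends on the config version):
```python
# command-line overrides equivalent to the table above (illustrative sketch)
! python3 tools/train.py -c configs/table/SLANet.yml \
    -o Global.pretrained_model=./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy \
       Global.eval_batch_step=562 \
       Optimizer.lr.name=Const \
       Optimizer.lr.learning_rate=0.0005 \
       Train.dataset.data_dir=/home/aistudio/data/data165849 \
       Train.dataset.label_file_list=[/home/aistudio/data/data165849/table_gen_dataset/train.txt] \
       Eval.dataset.data_dir=/home/aistudio/data/data165849 \
       Eval.dataset.label_file_list=[/home/aistudio/data/data165849/table_gen_dataset/val.txt]
```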
```python
import os
os.chdir('/home/aistudio/PaddleOCR')
! python3 tools/train.py -c /home/aistudio/SLANet_ch.yml
```
The model reaches its best accuracy, 97.49%, after about 7 epochs.
### 2.4 Evaluation
After training, evaluate the best model on the test set with:
```python
! python3 tools/eval.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams
```
### 2.5 Inference with the Training Engine
Run inference on a single image with the training engine:
```python
import os;os.chdir('/home/aistudio/PaddleOCR')
! python3 tools/infer_table.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.infer_img=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg
```
```python
import cv2
from matplotlib import pyplot as plt
%matplotlib inline
# show the original image
show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
# show the predicted cells
show_img = cv2.imread('/home/aistudio/PaddleOCR/output/infer/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
```
### 2.6 Model Export
Export the model to an inference model with:
```python
! python3 tools/export_model.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.save_inference_dir=/home/aistudio/SLANet_ch/infer
```
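After export, the target directory should contain the serialized graph and weights in Paddle's standard inference-model layout. A quick sanity check:
```python
import os
# expect inference.pdmodel, inference.pdiparams and inference.pdiparams.info
print(os.listdir('/home/aistudio/SLANet_ch/infer'))
```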
### 2.7 Inference with the Prediction Engine
Run inference on a single image with the prediction engine:
```python
os.chdir('/home/aistudio/PaddleOCR/ppstructure')
! python3 table/predict_structure.py \
--table_model_dir=/home/aistudio/SLANet_ch/infer \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \
--output=../output/inference
```
```python
# show the original image
show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
# show the predicted cells
show_img = cv2.imread('/home/aistudio/PaddleOCR/output/inference/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
```
### 2.8 Table Recognition
With the table structure model trained, it can be combined with OCR detection and recognition models to recognize the table content.
First, download the PP-OCRv3 text detection and recognition models:
```python
# download and extract the PP-OCRv3 text detection and recognition models
! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar --no-check-certificate
! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar --no-check-certificate
! cd ./inference/ && tar xf ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar && cd ../
```
After the models are downloaded, run table recognition with:
```python
import os;os.chdir('/home/aistudio/PaddleOCR/ppstructure')
! python3 table/predict_table.py \
--det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \
--rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \
--table_model_dir=/home/aistudio/SLANet_ch/infer \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \
--output=../output/table
```
```python
# show the original image
show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg')
plt.figure(figsize=(15,15))
plt.imshow(show_img)
plt.show()
# show the prediction result
from IPython.core.display import display, HTML
display(HTML('<html><body><table><tr><td colspan="5">alleadersh</td><td rowspan="2">不贰过,推</td><td rowspan="2">从自己参与浙江数</td><td rowspan="2">。另一方</td></tr><tr><td>AnSha</td><td>自己越</td><td>共商共建工作协商</td><td>w.east </td><td>抓好改革试点任务</td></tr><tr><td>Edime</td><td>ImisesElec</td><td>怀天下”。</td><td></td><td>22.26 </td><td>31.61</td><td>4.30 </td><td>794.94</td></tr><tr><td rowspan="2">ip</td><td> Profundi</td><td>:2019年12月1</td><td>Horspro</td><td>444.48</td><td>2.41 </td><td>87</td><td>679.98</td></tr><tr><td> iehaiTrain</td><td>组长蒋蕊</td><td>Toafterdec</td><td>203.43</td><td>23.54 </td><td>4</td><td>4266.62</td></tr><tr><td>Tyint </td><td> roudlyRol</td><td>谢您的好意,我知道</td><td>ErChows</td><td></td><td>48.90</td><td>1031</td><td>6</td></tr><tr><td>NaFlint</td><td></td><td>一辈的</td><td>aterreclam</td><td>7823.86</td><td>9829.23</td><td>7.96 </td><td> 3068</td></tr><tr><td>家上下游企业,5</td><td>Tr</td><td>景象。当地球上的我们</td><td>Urelaw</td><td>799.62</td><td>354.96</td><td>12.98</td><td>33 </td></tr><tr><td>赛事(</td><td> uestCh</td><td>复制的业务模式并</td><td>Listicjust</td><td>9.23</td><td></td><td>92</td><td>53.22</td></tr><tr><td> Ca</td><td> Iskole</td><td>扶贫"之名引导</td><td> Papua </td><td>7191.90</td><td>1.65</td><td>3.62</td><td>48</td></tr><tr><td rowspan="2">避讳</td><td>ir</td><td>但由于</td><td>Fficeof</td><td>0.22</td><td>6.37</td><td>7.17</td><td>3397.75</td></tr><tr><td>ndaTurk</td><td>百处遗址</td><td>gMa</td><td>1288.34</td><td>2053.66</td><td>2.29</td><td>885.45</td></tr></table></body></html>'))
```
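The recognized HTML can be post-processed further. As a minimal sketch (assuming `pandas` and an HTML parser such as `lxml` are installed; `html_table` stands in for the HTML string returned above), the result can be converted to Excel:
```python
import pandas as pd

# parse the recognized <table> markup into a DataFrame and save it as Excel
html_table = '<table><tr><td>demo</td></tr></table>'  # replace with the real result
df = pd.read_html(html_table)[0]
df.to_excel('table_result.xlsx', index=False)
```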
## 3. Table Attribute Recognition
### 3.1 Code, Environment, and Data Preparation
#### 3.1.1 Code Preparation
First, we need the code for training the table attribute model. PaddleClas integrates the PULC solution, which can quickly produce an attribute recognition model that runs in about 2 ms on CPU. Obtain the PaddleClas code by cloning the repository:
```python
! git clone -b develop https://gitee.com/paddlepaddle/PaddleClas
```
#### 3.1.2 Environment Setup
Next, install the dependencies required to train with PaddleClas:
```python
! pip install -r PaddleClas/requirements.txt --force-reinstall
! pip install protobuf==3.20.0
```
#### 3.1.3 Data Preparation
Finally, prepare the training data. We define six table attributes: source, number of tables, color, clarity, presence of obstructions, and angle. They are visualized below:
![](https://user-images.githubusercontent.com/45199522/190587903-ccdfa6fb-51e8-42de-b08b-a127cb04e304.png)
Here we provide a demo subset of the table attribute data for quick experimentation. Download it as follows:
```python
%cd PaddleClas/dataset
!wget https://paddleclas.bj.bcebos.com/data/PULC/table_attribute.tar
!tar -xf table_attribute.tar
%cd ../
```
### 3.2 Table Attribute Recognition Training
The overall training pipeline is as follows:
![](https://user-images.githubusercontent.com/45199522/190599426-3415b38e-e16e-4e68-9253-2ff531b1b5ca.png)
1. During training, the preprocessed image is fed into the backbone network, which extracts the features of the table image. The features are passed to an FC layer followed by a Sigmoid activation; a cross-entropy loss is computed against the ground-truth labels, and the optimizer updates the backbone parameters by gradient descent on this loss. After several epochs, the backbone can make good predictions on unseen images.
2. During inference, the preprocessed image is fed into the backbone, which, with the learned weights loaded, predicts a 6-dimensional vector whose elements are the probabilities of the corresponding attributes. Thresholding these values yields the final output, which describes the six attributes of the table. A minimal sketch of this thresholding step follows.
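The sketch below is an illustrative assumption, not PaddleClas's actual post-processing code: the 0.5 threshold and the negative labels ('Numerous', 'Colorful', 'With-Obstacles') are inferred from the prediction examples shown in section 3.3.2.
```python
import numpy as np

# (negative, positive) label pair per attribute dimension; the positive names
# match the example outputs in section 3.3.2, the negative names are assumptions
LABEL_PAIRS = [
    ('Photo', 'Scanned'),
    ('Numerous', 'Little'),
    ('Colorful', 'Black-and-White'),
    ('Blurry', 'Clear'),
    ('With-Obstacles', 'Without-Obstacles'),
    ('Tilted', 'Horizontal'),
]

def decode_attributes(probs, threshold=0.5):
    """Turn the 6-dim sigmoid output into binary flags and attribute names."""
    output = [int(p > threshold) for p in probs]
    attributes = [LABEL_PAIRS[i][flag] for i, flag in enumerate(output)]
    return {'attributes': attributes, 'output': output}

print(decode_attributes(np.array([0.9, 0.8, 0.7, 0.95, 0.6, 0.85])))
# {'attributes': ['Scanned', 'Little', 'Black-and-White', 'Clear',
#                 'Without-Obstacles', 'Horizontal'], 'output': [1, 1, 1, 1, 1, 1]}
```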
Once the data is ready, training can be launched with a single command:
```python
!python tools/train.py -c ./ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.device=cpu -o Global.epochs=10
```
### 3.3 Table Attribute Recognition Inference and Deployment
#### 3.3.1 Model Conversion
After training, the model has to be converted to an inference model for deployment. The conversion command is:
```python
!python tools/export_model.py -c ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.pretrained_model=output/PPLCNet_x1_0/best_model
```
This command generates an `inference` folder in the current directory containing the inference model with the best accuracy.
#### 3.3.2 Model Inference
Install the paddleclas package required for inference; here we install the develop wheel of paddleclas:
```python
!pip install https://paddleclas.bj.bcebos.com/whl/paddleclas-0.0.0-py3-none-any.whl
```
Change into the `deploy` directory to run inference:
```python
%cd deploy/
```
The inference commands are as follows:
```python
!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_9.jpg"
!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_3253.jpg"
```
Input table image:
![](https://user-images.githubusercontent.com/45199522/190596141-74f4feda-b082-46d7-908d-b0bd5839b430.png)
The prediction is:
```
val_9.jpg: {'attributes': ['Scanned', 'Little', 'Black-and-White', 'Clear', 'Without-Obstacles', 'Horizontal'], 'output': [1, 1, 1, 1, 1, 1]}
```
Input table image:
![](https://user-images.githubusercontent.com/45199522/190597086-2e685200-22d0-4042-9e46-f61f24e02e4e.png)
The prediction is:
```
val_3253.jpg: {'attributes': ['Photo', 'Little', 'Black-and-White', 'Blurry', 'Without-Obstacles', 'Tilted'], 'output': [0, 1, 1, 0, 1, 0]}
```
Comparing the two images: the first is fairly clear and its predicted attributes suggest it is easy to recognize, so the table recognition result can be trusted more; the second is blurry and tilted, so table recognition may make mistakes and the result needs further manual verification. With attribute recognition, "manual" and "intelligent" processing can be combined effectively, providing an accuracy safeguard when deploying table recognition.
# Curved Seal Text Recognition
- [1. Project Introduction](#1-project-introduction)
- [2. Environment Setup](#2-environment-setup)
  * [2.1 PaddleDetection Environment](#21-paddledetection-environment)
  * [2.2 PaddleOCR Environment](#22-paddleocr-environment)
- [3. Dataset Preparation](#3-dataset-preparation)
  * [3.1 Data Annotation](#31-data-annotation)
  * [3.2 Data Processing](#32-data-processing)
- [4. Seal Detection in Practice](#4-seal-detection-in-practice)
- [5. Seal Text Recognition in Practice](#5-seal-text-recognition-in-practice)
  * [5.1 End-to-End Seal Text Recognition](#51-end-to-end-seal-text-recognition)
  * [5.2 Two-Stage Seal Text Recognition](#52-two-stage-seal-text-recognition)
    + [5.2.1 Seal Text Detection](#521-seal-text-detection)
    + [5.2.2 Seal Text Recognition](#522-seal-text-recognition)
# 1. Project Introduction
Curved text recognition has wide application in OCR tasks, for example signboards in natural scenes, artistic text, and the common case of seal (stamp) text recognition.
This project takes seal recognition as an example and shows how to use PaddleDetection and PaddleOCR to complete seal detection and seal text recognition.
Project challenges:
1. Lack of training data
2. Uneven image quality: blurry images with unclear text
To address these problems, this project uses the PPOCRLabel tool in PaddleOCR for data annotation, PaddleDetection for seal region detection, and then PaddleOCR's end-to-end OCR algorithm and two-stage OCR algorithm for seal text recognition. The accuracy of each task is as follows:
| Task | Training samples | Accuracy |
| -------- | - | -------- |
| Seal detection | 1000 | 95% |
| Seal text recognition, end-to-end OCR | 700 | 47% |
| Seal text recognition, two-stage OCR | 700 | 55% |
Click to open the [AI Studio project](https://aistudio.baidu.com/aistudio/projectdetail/4586113)
# 2. Environment Setup
This project needs the runtime environments of both PaddleDetection and PaddleOCR: PaddleDetection handles the seal detection task and PaddleOCR handles the text recognition task.
## 2.1 PaddleDetection Environment
Download the PaddleDetection code:
```
!git clone https://github.com/PaddlePaddle/PaddleDetection.git
# if cloning from github is slow, clone from gitee instead
#git clone https://gitee.com/PaddlePaddle/PaddleDetection.git
```
Install PaddleDetection dependencies:
```
!cd PaddleDetection && pip install -r requirements.txt
```
## 2.2 PaddleOCR Environment
Download the PaddleOCR code:
```
!git clone https://github.com/PaddlePaddle/PaddleOCR.git
# if cloning from github is slow, clone from gitee instead
#git clone https://gitee.com/PaddlePaddle/PaddleOCR.git
```
Install PaddleOCR dependencies:
```
!cd PaddleOCR && git checkout dygraph && pip install -r requirements.txt
```
# 3. Dataset Preparation
## 3.1 Data Annotation
This project uses the [PPOCRLabel](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/PPOCRLabel) tool to annotate the seal detection data. The annotations cover the location of each seal as well as the locations and transcriptions of the text inside the seal.
Note: see the [documentation](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/PPOCRLabel) for how to use PPOCRLabel.
Steps for annotating the seal data with PPOCRLabel:
- Open the folder containing the dataset.
- Press the shortcut Q for 4-point (multi-point) annotation, used for the seal text to be recognized.
- Annotate the bounding curve of curved seal text with an even number of points (e.g. 4, 8, or 16), in reading order. Taking 16-point annotation as an example: start at the top-left of the text, place 8 points along the top towards the top-right, then continue from the bottom-right to the bottom-left with another 8 points, so that the 16 points form a curve enclosing the text; see the figure below. If the text is only slightly curved, 4- or 8-point annotation can be used to reduce labeling effort; note that the top and bottom must use the same number of points, and try to keep the total number of points at 18 or fewer.
- Non-curved text inside the seal can be annotated with an ordinary 4-point box.
- The text attached to each box defaults to "待识别" ("to be recognized") and must be replaced with the actual text inside the box.
- Press the shortcut W for rectangle annotation, used for seal region detection. Make sure the box encloses the entire seal; its text can be set to '印章区域' ("seal region") to simplify later processing.
- Horizontal text inside the seal can be annotated with rectangles or 4-point boxes as appropriate, one text line per box. If background text lies close to the seal text, try to avoid including it.
- After annotating an image, correct the text results on the right-hand panel, then click Check at the bottom (or press Ctrl+V) to confirm the image.
- When all images are annotated, click File -> Export Label in the top menu bar to export label.txt.
After annotation, the visualization looks like this:
![](https://ai-studio-static-online.cdn.bcebos.com/f5acbc4f50dd401a8f535ed6a263f94b0edff82c1aed4285836a9ead989b9c13)
The exported labels contain both the seal detection annotations and the seal text recognition annotations, for example:
```
img/1.png [{"transcription": "印章区域", "points": [[87, 245], [214, 245], [214, 369], [87, 369]], "difficult": false}, {"transcription": "国家税务总局泸水市税务局第二税务分局", "points": [[110, 314], [116, 290], [131, 275], [152, 273], [170, 277], [181, 289], [186, 303], [186, 312], [201, 311], [198, 289], [189, 272], [175, 259], [152, 252], [124, 257], [100, 280], [94, 312]], "difficult": false}, {"transcription": "征税专用章", "points": [[117, 334], [183, 334], [183, 352], [117, 352]], "difficult": false}]
```
Each label line contains the coordinates of the '印章区域' (seal region) box, the coordinates of the seal text boxes, and the text content.
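For reference, a minimal sketch of parsing one line of this label format (the sample line is abbreviated from the one above):
```
import json

line = 'img/1.png\t[{"transcription": "征税专用章", "points": [[117, 334], [183, 334], [183, 352], [117, 352]], "difficult": false}]'
img_path, label = line.strip().split('\t')
for anno in json.loads(label):
    # 4-point boxes mark the seal region or horizontal text;
    # boxes with more points mark curved text
    print(img_path, anno['transcription'], len(anno['points']))
```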
## 3.2 Data Processing
For convenience during annotation, seal region boxes and text boxes were not distinguished; the labels can be separated with a bit of Python code.
Annotated sample data is stored in this project under '/home/aistudio/work/seal_labeled_datas', as shown below:
![](https://ai-studio-static-online.cdn.bcebos.com/3d762970e2184177a2c633695a31029332a4cd805631430ea797309492e45402)
The annotations in the label file '/home/aistudio/work/seal_labeled_datas/Label.txt' look like this:
```
img/test1.png [{"transcription": "待识别", "points": [[408, 232], [537, 232], [537, 352], [408, 352]], "difficult": false}, {"transcription": "电子回单", "points": [[437, 305], [504, 305], [504, 322], [437, 322]], "difficult": false}, {"transcription": "云南省农村信用社", "points": [[417, 290], [434, 295], [438, 281], [446, 267], [455, 261], [472, 258], [489, 264], [498, 277], [502, 295], [526, 289], [518, 267], [503, 249], [475, 232], [446, 239], [429, 255], [418, 275]], "difficult": false}, {"transcription": "专用章", "points": [[437, 319], [503, 319], [503, 338], [437, 338]], "difficult": false}]
```
To make training convenient, the following Python code splits the annotations into those used for training seal detection and those used for training seal text recognition.
```
import numpy as np
import json
import cv2
import os
from shapely.geometry import Polygon


def poly2box(poly):
    xmin = np.min(np.array(poly)[:, 0])
    ymin = np.min(np.array(poly)[:, 1])
    xmax = np.max(np.array(poly)[:, 0])
    ymax = np.max(np.array(poly)[:, 1])
    return np.array([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]])


def draw_text_det_res(dt_boxes, src_im, color=(255, 255, 0)):
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(src_im, [box], True, color=color, thickness=2)
    return src_im


class LabelDecode(object):
    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        label = json.loads(data['label'])

        nBox = len(label)
        seal_boxes = self.get_seal_boxes(label)

        gt_label = []
        for seal_box in seal_boxes:
            seal_anno = {'seal_box': seal_box}
            boxes, txts, txt_tags = [], [], []

            for bno in range(0, nBox):
                box = label[bno]['points']
                txt = label[bno]['transcription']
                try:
                    ints = self.get_intersection(box, seal_box)
                except Exception as E:
                    print(E)
                    continue

                # a text box belongs to this seal if it lies inside the seal box
                if abs(Polygon(box).area - self.get_intersection(box, seal_box)) < 1e-3 and \
                        abs(Polygon(box).area - self.get_union(box, seal_box)) > 1e-3:
                    boxes.append(box)
                    txts.append(txt)
                    if txt in ['*', '###', '待识别']:
                        txt_tags.append(True)
                    else:
                        txt_tags.append(False)

            seal_anno['polys'] = boxes
            seal_anno['texts'] = txts
            seal_anno['ignore_tags'] = txt_tags

            gt_label.append(seal_anno)

        return gt_label

    def get_seal_boxes(self, label):
        nBox = len(label)
        seal_box = []
        for bno in range(0, nBox):
            box = label[bno]['points']
            if len(box) == 4:
                seal_box.append(box)

        if len(seal_box) == 0:
            return None

        seal_box = self.valid_seal_box(seal_box)
        return seal_box

    def is_seal_box(self, box, boxes):
        is_seal = True
        for poly in boxes:
            if list(np.array(box).shape) != list(np.array(poly).shape):
                if abs(Polygon(box).area - self.get_intersection(box, poly)) < 1e-3:
                    return False
            else:
                if np.sum(np.array(box) - np.array(poly)) < 1e-3:
                    # continue when the box is same with poly
                    continue
                if abs(Polygon(box).area - self.get_intersection(box, poly)) < 1e-3:
                    return False
        return is_seal

    def valid_seal_box(self, boxes):
        if len(boxes) == 1:
            return boxes

        new_boxes = []
        flag = True
        for k in range(0, len(boxes)):
            flag = True
            tmp_box = boxes[k]
            for i in range(0, len(boxes)):
                if k == i:
                    continue
                if abs(Polygon(tmp_box).area - self.get_intersection(tmp_box, boxes[i])) < 1e-3:
                    flag = False
                    continue
            if flag:
                new_boxes.append(tmp_box)

        return new_boxes

    def get_union(self, pD, pG):
        return Polygon(pD).union(Polygon(pG)).area

    def get_intersection_over_union(self, pD, pG):
        return self.get_intersection(pD, pG) / self.get_union(pD, pG)

    def get_intersection(self, pD, pG):
        return Polygon(pD).intersection(Polygon(pG)).area

    def expand_points_num(self, boxes):
        max_points_num = 0
        for box in boxes:
            if len(box) > max_points_num:
                max_points_num = len(box)
        ex_boxes = []
        for box in boxes:
            ex_box = box + [box[-1]] * (max_points_num - len(box))
            ex_boxes.append(ex_box)
        return ex_boxes


def gen_extract_label(data_dir, label_file, seal_gt, seal_ppocr_gt):
    label_decode_func = LabelDecode()
    gts = open(label_file, "r").readlines()

    seal_gt_list = []
    seal_ppocr_list = []

    for idx, line in enumerate(gts):
        img_path, label = line.strip().split("\t")
        data = {'label': label, 'img_path': img_path}
        res = label_decode_func(data)
        src_img = cv2.imread(os.path.join(data_dir, img_path))
        if res is None:
            print("ERROR! res is None!")
            continue

        anno = []
        for i, gt in enumerate(res):
            anno.append({'polys': gt['seal_box'], 'cls': 1})

        seal_gt_list.append(f"{img_path}\t{json.dumps(anno)}\n")
        seal_ppocr_list.append(f"{img_path}\t{json.dumps(res)}\n")

    if not os.path.exists(os.path.dirname(seal_gt)):
        os.makedirs(os.path.dirname(seal_gt))
    if not os.path.exists(os.path.dirname(seal_ppocr_gt)):
        os.makedirs(os.path.dirname(seal_ppocr_gt))

    with open(seal_gt, "w") as f:
        f.writelines(seal_gt_list)
    with open(seal_ppocr_gt, 'w') as f:
        f.writelines(seal_ppocr_list)


def vis_seal_ppocr(data_dir, label_file, save_dir):
    datas = open(label_file, 'r').readlines()
    for idx, line in enumerate(datas):
        img_path, label = line.strip().split('\t')
        img_path = os.path.join(data_dir, img_path)

        label = json.loads(label)
        src_im = cv2.imread(img_path)
        if src_im is None:
            continue

        for anno in label:
            seal_box = anno['seal_box']
            txt_boxes = anno['polys']

            # vis seal box
            src_im = draw_text_det_res([seal_box], src_im, color=(255, 255, 0))
            src_im = draw_text_det_res(txt_boxes, src_im, color=(255, 0, 0))

        save_path = os.path.join(save_dir, os.path.basename(img_path))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        cv2.imwrite(save_path, src_im)


def draw_html(img_dir, save_name):
    import glob
    images_dir = glob.glob(img_dir + "/*")
    print(len(images_dir))

    html_path = save_name
    with open(html_path, 'w') as html:
        html.write('<html>\n<body>\n')
        html.write('<table border="1">\n')
        html.write("<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />")
        html.write("<tr>\n")
        html.write(f'<td> \n GT')

        for i, filename in enumerate(sorted(images_dir)):
            if filename.endswith("txt"):
                continue
            print(filename)
            base = "{}".format(filename)
            html.write("<tr>\n")
            html.write(f'<td> {filename}\n GT')
            html.write('<td>GT 310\n<img src="%s" width=640></td>' % (base))
            html.write("</tr>\n")

        html.write('<style>\n')
        html.write('span {\n')
        html.write('    color: red;\n')
        html.write('}\n')
        html.write('</style>\n')
        html.write('</table>\n')
        html.write('</body>\n</html>\n')
    print("ok")


def crop_seal_from_img(label_file, data_dir, save_dir, save_gt_path):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    datas = open(label_file, 'r').readlines()
    all_gts = []
    count = 0
    for idx, line in enumerate(datas):
        img_path, label = line.strip().split('\t')
        img_path = os.path.join(data_dir, img_path)

        label = json.loads(label)
        src_im = cv2.imread(img_path)
        if src_im is None:
            continue

        for c, anno in enumerate(label):
            seal_poly = anno['seal_box']
            txt_boxes = anno['polys']
            txts = anno['texts']
            ignore_tags = anno['ignore_tags']

            box = poly2box(seal_poly)
            img_crop = src_im[box[0][1]:box[2][1], box[0][0]:box[2][0], :]

            save_path = os.path.join(save_dir, f"{idx}_{c}.jpg")
            cv2.imwrite(save_path, np.array(img_crop))

            img_gt = []
            for i in range(len(txts)):
                # shift the text boxes into the cropped seal's coordinate frame
                txt_boxes_crop = np.array(txt_boxes[i])
                txt_boxes_crop[:, 1] -= box[0, 1]
                txt_boxes_crop[:, 0] -= box[0, 0]
                img_gt.append({'transcription': txts[i], "points": txt_boxes_crop.tolist(), "ignore_tag": ignore_tags[i]})

            if len(img_gt) >= 1:
                count += 1
                save_gt = f"{os.path.basename(save_path)}\t{json.dumps(img_gt)}\n"
                all_gts.append(save_gt)

    print(f"The num of all image: {len(all_gts)}, and the number of useful image: {count}")
    if not os.path.exists(os.path.dirname(save_gt_path)):
        os.makedirs(os.path.dirname(save_gt_path))

    with open(save_gt_path, "w") as f:
        f.writelines(all_gts)
    print("Done")


if __name__ == "__main__":
    # data processing
    gen_extract_label("./seal_labeled_datas", "./seal_labeled_datas/Label.txt", "./seal_ppocr_gt/seal_det_img.txt", "./seal_ppocr_gt/seal_ppocr_img.txt")
    vis_seal_ppocr("./seal_labeled_datas", "./seal_ppocr_gt/seal_ppocr_img.txt", "./seal_ppocr_gt/seal_ppocr_vis/")
    draw_html("./seal_ppocr_gt/seal_ppocr_vis/", "./vis_seal_ppocr.html")
    seal_ppocr_img_label = "./seal_ppocr_gt/seal_ppocr_img.txt"
    crop_seal_from_img(seal_ppocr_img_label, "./seal_labeled_datas/", "./seal_img_crop", "./seal_img_crop/label.txt")
```
After processing, the generated files are:
```
├── seal_img_crop/
│   ├── 0_0.jpg
│   ├── ...
│   └── label.txt
├── seal_ppocr_gt/
│   ├── seal_det_img.txt
│   ├── seal_ppocr_img.txt
│   └── seal_ppocr_vis/
│       ├── test1.png
│       └── ...
└── vis_seal_ppocr.html
```
`seal_img_crop/label.txt` is the label file for seal text recognition, in the following format:
```
0_0.jpg [{"transcription": "\u7535\u5b50\u56de\u5355", "points": [[29, 73], [96, 73], [96, 90], [29, 90]], "ignore_tag": false}, {"transcription": "\u4e91\u5357\u7701\u519c\u6751\u4fe1\u7528\u793e", "points": [[9, 58], [26, 63], [30, 49], [38, 35], [47, 29], [64, 26], [81, 32], [90, 45], [94, 63], [118, 57], [110, 35], [95, 17], [67, 0], [38, 7], [21, 23], [10, 43]], "ignore_tag": false}, {"transcription": "\u4e13\u7528\u7ae0", "points": [[29, 87], [95, 87], [95, 106], [29, 106]], "ignore_tag": false}]
```
It can be used directly to train PaddleOCR's PGNet algorithm.
`seal_ppocr_gt/seal_det_img.txt` is the label file for seal detection, in the following format:
```
img/test1.png [{"polys": [[408, 232], [537, 232], [537, 352], [408, 352]], "cls": 1}]
```
To train a seal detection model with PaddleDetection, `seal_det_img.txt` has to be converted to the COCO or VOC annotation format.
The following code converts the seal detection annotations directly to VOC format.
```
import numpy as np
import json
import cv2
import os
from shapely.geometry import Polygon

seal_train_gt = "./seal_ppocr_gt/seal_det_img.txt"
# Note: for demonstration only; in practice the training and test set labels need to be converted separately
seal_valid_gt = "./seal_ppocr_gt/seal_det_img.txt"


def gen_main_train_txt(mode='train'):
    if mode == "train":
        file_path = seal_train_gt
    if mode in ['valid', 'test']:
        file_path = seal_valid_gt

    save_path = f"./seal_VOC/ImageSets/Main/{mode}.txt"
    save_train_path = f"./seal_VOC/{mode}.txt"
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    datas = open(file_path, 'r').readlines()
    img_names = []
    train_names = []
    for line in datas:
        img_name = line.strip().split('\t')[0]
        img_name = os.path.basename(img_name)
        (i_name, extension) = os.path.splitext(img_name)
        t_name = 'JPEGImages/' + str(img_name) + ' ' + 'Annotations/' + str(i_name) + '.xml\n'
        train_names.append(t_name)
        img_names.append(i_name + "\n")

    with open(save_train_path, "w") as f:
        f.writelines(train_names)
    with open(save_path, "w") as f:
        f.writelines(img_names)
    print(f"{mode} save done")


def gen_xml_label(mode='train'):
    if mode == "train":
        file_path = seal_train_gt
    if mode in ['valid', 'test']:
        file_path = seal_valid_gt

    datas = open(file_path, 'r').readlines()
    anno_path = "./seal_VOC/Annotations"
    img_path = "./seal_VOC/JPEGImages"
    if not os.path.exists(anno_path):
        os.makedirs(anno_path)
    if not os.path.exists(img_path):
        os.makedirs(img_path)

    for idx, line in enumerate(datas):
        img_name, label = line.strip().split('\t')
        img = cv2.imread(os.path.join("./seal_labeled_datas", img_name))
        cv2.imwrite(os.path.join(img_path, os.path.basename(img_name)), img)
        height, width, c = img.shape
        img_name = os.path.basename(img_name)
        (i_name, extension) = os.path.splitext(img_name)
        label = json.loads(label)

        xml_file = open((anno_path + '/' + i_name + '.xml'), 'w')
        xml_file.write('<annotation>\n')
        xml_file.write('    <folder>seal_VOC</folder>\n')
        xml_file.write('    <filename>' + str(img_name) + '</filename>\n')
        xml_file.write('    <path>' + 'Annotations/' + str(img_name) + '</path>\n')
        xml_file.write('    <size>\n')
        xml_file.write('        <width>' + str(width) + '</width>\n')
        xml_file.write('        <height>' + str(height) + '</height>\n')
        xml_file.write('        <depth>3</depth>\n')
        xml_file.write('    </size>\n')
        xml_file.write('    <segmented>0</segmented>\n')

        for anno in label:
            poly = anno['polys']
            if anno['cls'] == 1:
                gt_cls = 'redseal'
            xmin = np.min(np.array(poly)[:, 0])
            ymin = np.min(np.array(poly)[:, 1])
            xmax = np.max(np.array(poly)[:, 0])
            ymax = np.max(np.array(poly)[:, 1])
            xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
            xml_file.write('    <object>\n')
            xml_file.write('        <name>' + str(gt_cls) + '</name>\n')
            xml_file.write('        <pose>Unspecified</pose>\n')
            xml_file.write('        <truncated>0</truncated>\n')
            xml_file.write('        <difficult>0</difficult>\n')
            xml_file.write('        <bndbox>\n')
            xml_file.write('            <xmin>' + str(xmin) + '</xmin>\n')
            xml_file.write('            <ymin>' + str(ymin) + '</ymin>\n')
            xml_file.write('            <xmax>' + str(xmax) + '</xmax>\n')
            xml_file.write('            <ymax>' + str(ymax) + '</ymax>\n')
            xml_file.write('        </bndbox>\n')
            xml_file.write('    </object>\n')

        xml_file.write('</annotation>')
        xml_file.close()
    print(f'{mode} xml save done!')


gen_main_train_txt()
gen_main_train_txt('valid')
gen_xml_label('train')
gen_xml_label('valid')
```
After conversion, the VOC-format seal detection data is stored under ~/data/seal_VOC with the following structure:
```
├── Annotations/
├── ImageSets/
│   └── Main/
│       ├── train.txt
│       └── valid.txt
├── JPEGImages/
├── train.txt
├── valid.txt
└── label_list.txt
```
Annotations/ holds the labels, JPEGImages/ holds the image files, and label_list.txt lists the detection class names.
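Note that label_list.txt is not generated by the conversion script above: it simply lists the class names, one per line. Given the `redseal` class used by the converter, it would contain:
```
redseal
```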
The next section describes how to train the seal detection model with the PaddleDetection toolbox.
# 4. Seal Detection in Practice
In practice, seals mostly appear in contracts, invoices, announcements, and similar documents. Seal text recognition has to exclude the influence of background text in the image, so the seal region must be detected first.
With the PaddleDetection object detection library, the seal detection task is easy to implement. The training workflow is:
- Choose an algorithm
- Point the dataset configuration at our data
- Launch training
**Algorithm selection**
PaddleDetection offers many detection algorithms. Since the seal regions in our data are fairly distinct and performance matters, this project uses the PP-YOLO algorithm with a MobileNetV3 backbone for seal detection; the corresponding config file is configs/ppyolo/ppyolo_mbv3_large.yml.
**Modifying the config file**
The default data paths in the config file point to COCO and must be changed to the seal detection data paths.
Append the following to the end of 'configs/ppyolo/ppyolo_mbv3_large.yml':
```
metric: VOC
map_type: 11point
num_classes: 2

TrainDataset:
  !VOCDataSet
    dataset_dir: dataset/seal_VOC
    anno_path: train.txt
    label_list: label_list.txt
    data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']

EvalDataset:
  !VOCDataSet
    dataset_dir: dataset/seal_VOC
    anno_path: valid.txt
    label_list: label_list.txt
    data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']

TestDataset:
  !ImageFolder
    anno_path: dataset/seal_VOC/label_list.txt
```
The data paths in the config are relative to the PaddleDetection/dataset directory, so either move the processed seal detection data into PaddleDetection/dataset or create a symlink.
```
!ln -s seal_VOC ./PaddleDetection/dataset/
```
Since each image contains only a few seals, the number of boxes kept by NMS post-processing can be reduced: keep_top_k and nms_top_k are lowered from 100 and 1000 to 10 and 100. Append the following to the end of 'configs/ppyolo/ppyolo_mbv3_large.yml' to adjust the post-processing parameters:
```
BBoxPostProcess:
  decode:
    name: YOLOBox
    conf_thresh: 0.005
    downsample_ratio: 32
    clip_bbox: true
    scale_x_y: 1.05
  nms:
    name: MultiClassNMS
    keep_top_k: 10  # was 100
    nms_threshold: 0.45
    nms_top_k: 100  # was 1000
    score_threshold: 0.005
```
After that, PaddleDetection needs code for loading the seal data: create a file seal.py under PaddleDetection/ppdet/data/source/ with the following content:
```
import os
import numpy as np

from ppdet.core.workspace import register, serializable
from .dataset import DetDataset
import cv2
import json

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@register
@serializable
class SealDataSet(DetDataset):
    """
    Load the seal detection dataset from txt annotations.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): annotation file (txt) path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        load_crowd (bool): whether to load crowded ground-truth.
            False as default
        allow_empty (bool): whether to load empty entry. False as default
        empty_ratio (float): the ratio of empty record number to total
            record's, if empty_ratio is out of [0. ,1.), do not sample the
            records and use all the empty entries. 1. as default
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=False,
                 empty_ratio=1.):
        super(SealDataSet, self).__init__(dataset_dir, image_dir, anno_path,
                                          data_fields, sample_num)
        self.load_image_only = False
        self.load_semantic = False
        self.load_crowd = load_crowd
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    def parse_dataset(self):
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        records = []
        empty_records = []
        ct = 0

        assert anno_path.endswith('.txt'), \
            'invalid seal_gt file: ' + anno_path
        all_datas = open(anno_path, 'r').readlines()
        for idx, line in enumerate(all_datas):
            im_path, label = line.strip().split('\t')
            img_path = os.path.join(image_dir, im_path)
            label = json.loads(label)
            im_h, im_w, im_c = cv2.imread(img_path).shape

            coco_rec = {
                'im_file': img_path,
                'im_id': np.array([idx]),
                'h': im_h,
                'w': im_w,
            } if 'image' in self.data_fields else {}

            if not self.load_image_only:
                bboxes = []
                for anno in label:
                    poly = anno['polys']
                    # poly to box
                    x1 = np.min(np.array(poly)[:, 0])
                    y1 = np.min(np.array(poly)[:, 1])
                    x2 = np.max(np.array(poly)[:, 0])
                    y2 = np.max(np.array(poly)[:, 1])
                    eps = 1e-5
                    if x2 - x1 > eps and y2 - y1 > eps:
                        clean_box = [
                            round(float(x), 3) for x in [x1, y1, x2, y2]
                        ]
                        anno = {'clean_box': clean_box, 'gt_cls': int(anno['cls'])}
                        bboxes.append(anno)
                    else:
                        logger.info("invalid box")

                num_bbox = len(bboxes)
                if num_bbox <= 0:
                    continue

                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
                # gt_poly = [None] * num_bbox
                for i, box in enumerate(bboxes):
                    gt_class[i][0] = box['gt_cls']
                    gt_bbox[i, :] = box['clean_box']
                    is_crowd[i][0] = 0

                gt_rec = {
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_bbox': gt_bbox,
                    # 'gt_poly': gt_poly,
                }
                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        coco_rec[k] = v

            records.append(coco_rec)
            ct += 1
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        self.roidbs = records
```
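Note: a class decorated with `@register` only becomes visible to PaddleDetection's config system once its module is imported. The example config in this section keeps using `!VOCDataSet`, so this step is only needed if you switch the dataset type to `!SealDataSet`; in that case, adding an import to PaddleDetection/ppdet/data/source/__init__.py (an assumption based on how the built-in data sources are registered there) is one way to do it:
```
# in PaddleDetection/ppdet/data/source/__init__.py
from .seal import SealDataSet
```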
**Launching training**
The command for single-GPU training is:
```
!python3 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval
# for distributed training use:
!python3 -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval
```
When training finishes, the log prints the model accuracy:
```
[07/05 11:42:09] ppdet.engine INFO: Eval iter: 0
[07/05 11:42:14] ppdet.metrics.metrics INFO: Accumulating evaluatation results...
[07/05 11:42:14] ppdet.metrics.metrics INFO: mAP(0.50, 11point) = 99.31%
[07/05 11:42:14] ppdet.engine INFO: Total sample number: 112, averge FPS: 26.45840794253432
[07/05 11:42:14] ppdet.engine INFO: Best test bbox ap is 0.996.
```
We can use the trained model to inspect the predictions:
```
!python3 tools/infer.py -c configs/ppyolo/ppyolo_mbv3_large.yml -o weights=./output/ppyolo_mbv3_large/model_final.pdparams --infer_img=./test.jpg
```
The prediction looks like this:
![](https://ai-studio-static-online.cdn.bcebos.com/0f650c032b0f4d56bd639713924768cc820635e9977845008d233f465291a29e)
# 5. Seal Text Recognition in Practice
After the seal region is detected with PP-YOLO, the text recognition capabilities in PaddleOCR can be used to recognize the text inside the seal.
PaddleOCR's OCR algorithms include text detection algorithms, text recognition algorithms, and end-to-end OCR algorithms.
A text detection algorithm locates the text in the image, and a text recognition model then reads the detected text; chaining detection and recognition this way is called a two-stage OCR pipeline. In contrast, an end-to-end OCR method performs detection and recognition within a single algorithm.
| Text detection | Text recognition | End-to-end |
| -------- | -------- | -------- |
| DB\DB++\EAST\SAST\PSENet | SVTR\CRNN\NRTR\ABINet\SAR\... | PGNet |
This section walks through both the end-to-end and the two-stage approaches on the seal detection and recognition task.
## 5.1 End-to-End Seal Text Recognition
This section uses PaddleOCR's PGNet algorithm for seal text recognition.
PGNet is an end-to-end text detection and recognition algorithm; its config file in PaddleOCR is:
[PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/e2e/e2e_r50_vd_pg.yml)
The steps for detection and recognition with PGNet are:
- Modify the config file
- Launch training
The default PGNet config points at the totaltext dataset; for this training it must point at the label files and data directories produced in the previous section.
The training data configuration after modification:
```
Train:
  dataset:
    name: PGDataSet
    data_dir: ./train_data/seal_ppocr
    label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt]
    ratio_list: [1.0]
```
The test dataset configuration after modification:
```
Eval:
  dataset:
    name: PGDataSet
    data_dir: ./train_data/seal_ppocr_test
    label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt]
```
Start training with:
```
!python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml
```
After training, the final accuracy is 47.4%. The small amount of data and its uneven quality limit the training accuracy; with more training data, the accuracy would improve further.
To obtain the trained model, scan the QR code at the end of this document and fill in the questionnaire to join the official PaddleOCR group, where you can get download links for all vertical OCR models, the e-book "Dive into OCR", and a full set of OCR learning materials 🎁
## 5.2 Two-Stage Seal Text Recognition
The previous section trained PGNet for seal recognition. This section uses PaddleOCR's text detection and text recognition algorithms separately to detect and then recognize the seal text.
### 5.2.1 Seal Text Detection
PaddleOCR contains a rich set of text detection algorithms, including DB, DB++, EAST, SAST, and PSENet. Among them, DB, DB++, and PSENet support curved text detection; this project uses DB++ for curved seal text detection.
The DB++ text detection model released by PaddleOCR was trained on English text, so the model has to be retrained. Modify the data paths in the [DB++ config file](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/det/det_r50_db%2B%2B_icdar15.yml) (the default config is configs/det/det_r50_db++_icdar15.yml):
```
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/seal_ppocr
    label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt]
    ratio_list: [1.0]
```
The test dataset configuration after modification:
```
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/seal_ppocr_test
    label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt]
```
Start training:
```
!python3 tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.epoch_num=100
```
Given the small dataset, Global.epoch_num limits training to only 100 epochs.
After training, the visualized predictions on the test set look like this:
![](https://ai-studio-static-online.cdn.bcebos.com/498119182f0a414ab86ae2de752fa31c9ddc3a74a76847049cc57884602cb269)
To obtain the trained model, scan the QR code at the end of this document and fill in the questionnaire to join the official PaddleOCR group, where you can get download links for all vertical OCR models, the e-book "Dive into OCR", and a full set of OCR learning materials 🎁
### 5.2.2 Seal Text Recognition
The previous section trained the seal text detection model; this section trains the seal text recognition model. The recognition model uses the SVTR algorithm. SVTR, published at IJCAI, is an ultra-lightweight yet accurate text recognition model.
Before launching training, the dataset for seal text recognition has to be prepared: the code below crops the text regions out of the seals to build the training set.
```
import os
import json

import cv2
import numpy as np


def get_rotate_crop_image(img, points):
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img


def run(data_dir, label_file, save_dir):
    datas = open(label_file, 'r').readlines()
    for idx, line in enumerate(datas):
        img_path, label = line.strip().split('\t')
        img_path = os.path.join(data_dir, img_path)

        label = json.loads(label)
        src_im = cv2.imread(img_path)
        if src_im is None:
            continue

        for anno in label:
            txt_boxes = anno['polys']
            for bno, poly in enumerate(txt_boxes):
                poly = np.array(poly, dtype=np.float32)
                if len(poly) != 4:
                    # get_rotate_crop_image expects a 4*2 array; for curved
                    # boxes fall back to the 4-point bounding rectangle
                    x1, y1 = poly[:, 0].min(), poly[:, 1].min()
                    x2, y2 = poly[:, 0].max(), poly[:, 1].max()
                    poly = np.float32([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
                crop_im = get_rotate_crop_image(src_im, poly)
                save_path = os.path.join(save_dir, f'{idx}_{bno}.png')
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                cv2.imwrite(save_path, crop_im)
```
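A hypothetical invocation, reusing the label file produced in section 3.2 and writing the crops to the directory that the recognition config below expects:
```
run('./seal_labeled_datas', './seal_ppocr_gt/seal_ppocr_img.txt', './train_data/seal_ppocr_crop/')
```
The recognition label file (train_list.txt with image-path/text pairs) still needs to be assembled from the transcriptions of the cropped boxes.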
Once the data is ready, configure training. The SVTR config file is [configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml).
Modify the training data section of the SVTR config as follows:
```
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/seal_ppocr_crop/
    label_file_list:
      - ./train_data/seal_ppocr_crop/train_list.txt
```
Modify the evaluation section of the config:
```
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/seal_ppocr_crop/
    label_file_list:
      - ./train_data/seal_ppocr_crop_test/train_list.txt
```
Start training:
```
!python3 tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
```
After training, the test set metric reaches 61%.
Because the data is scarce, the accuracy on the training set is far higher than on the test set, i.e. the model overfits. Adding more data and applying data augmentation can alleviate this.
To obtain the trained model, scan the QR code below and fill in the questionnaire to join the official PaddleOCR group, where you can get download links for all vertical OCR models, the e-book "Dive into OCR", and a full set of OCR learning materials 🎁
![](https://ai-studio-static-online.cdn.bcebos.com/ea32877b717643289dc2121a2e573526d99d0f9eecc64ad4bd8dcf121cb5abde)
...@@ -12,7 +12,7 @@ Global: ...@@ -12,7 +12,7 @@ Global:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: ./inference/rec_inference infer_img: doc/imgs_words_en/word_10.png
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict90.txt character_dict_path: ppocr/utils/dict90.txt
max_text_length: &max_text_length 40 max_text_length: &max_text_length 40
......
...@@ -12,7 +12,7 @@ Global: ...@@ -12,7 +12,7 @@ Global:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg infer_img: doc/imgs_words_en/word_10.png
# for data or label process # for data or label process
character_dict_path: ./ppocr/utils/dict/spin_dict.txt character_dict_path: ./ppocr/utils/dict/spin_dict.txt
max_text_length: 25 max_text_length: 25
......
...@@ -43,7 +43,6 @@ Architecture: ...@@ -43,7 +43,6 @@ Architecture:
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
loc_type: 2
max_text_length: *max_text_length max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4 loc_reg_num: &loc_reg_num 4
......
...@@ -101,11 +101,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 ...@@ -101,11 +101,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) | |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | |SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | |RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) | |RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
<a name="2"></a> <a name="2"></a>
## 2. 端到端算法 ## 2. 端到端算法
......
...@@ -26,7 +26,7 @@ Zhang ...@@ -26,7 +26,7 @@ Zhang
|模型|骨干网络|配置文件|Acc|下载链接| |模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon| |RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。 注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。
......
...@@ -26,7 +26,7 @@ SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识 ...@@ -26,7 +26,7 @@ SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识
|模型|骨干网络|配置文件|Acc|下载链接| |模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon| |SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar)|
<a name="2"></a> <a name="2"></a>
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
| 参数名称 | 类型 | 默认值 | 含义 | | 参数名称 | 类型 | 默认值 | 含义 |
| :--: | :--: | :--: | :--: | | :--: | :--: | :--: | :--: |
| image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 | | image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 |
| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 | | vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 |
| drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 | | drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 |
| use_pdserving | bool | False | 是否使用Paddle Serving进行预测 | | use_pdserving | bool | False | 是否使用Paddle Serving进行预测 |
......
...@@ -98,8 +98,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r ...@@ -98,8 +98,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) | |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | |SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | |RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) | |RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
<a name="2"></a> <a name="2"></a>
......
...@@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval ...@@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link| |Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon| |RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper. Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper.
......
...@@ -25,7 +25,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval ...@@ -25,7 +25,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval
|Model|Backbone|config|Acc|Download link| |Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon| |SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
<a name="2"></a> <a name="2"></a>
......
...@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par ...@@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par
| parameters | type | default | implication | | parameters | type | default | implication |
| :--: | :--: | :--: | :--: | | :--: | :--: | :--: | :--: |
| image_dir | str | None, must be specified explicitly | Image or folder path | | image_dir | str | None, must be specified explicitly | Image or folder path |
| page_num | int | 0 | Valid when the input type is pdf file, specify to predict the previous page_num pages, all pages are predicted by default |
| vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization | | vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization |
| drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results | | drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results |
| use_pdserving | bool | False | Whether to use Paddle Serving for prediction | | use_pdserving | bool | False | Whether to use Paddle Serving for prediction |
......
...@@ -480,10 +480,11 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -480,10 +480,11 @@ class PaddleOCR(predict_system.TextSystem):
params.rec_image_shape = "3, 48, 320" params.rec_image_shape = "3, 48, 320"
else: else:
params.rec_image_shape = "3, 32, 320" params.rec_image_shape = "3, 32, 320"
# download model # download model if using paddle infer
maybe_download(params.det_model_dir, det_url) if not params.use_onnx:
maybe_download(params.rec_model_dir, rec_url) maybe_download(params.det_model_dir, det_url)
maybe_download(params.cls_model_dir, cls_url) maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.cls_model_dir, cls_url)
if params.det_algorithm not in SUPPORT_DET_MODEL: if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
......
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from paddle import ParamAttr from paddle import ParamAttr
...@@ -42,7 +43,6 @@ class TableAttentionHead(nn.Layer): ...@@ -42,7 +43,6 @@ class TableAttentionHead(nn.Layer):
def __init__(self, def __init__(self,
in_channels, in_channels,
hidden_size, hidden_size,
loc_type,
in_max_len=488, in_max_len=488,
max_text_length=800, max_text_length=800,
out_channels=30, out_channels=30,
...@@ -57,20 +57,16 @@ class TableAttentionHead(nn.Layer): ...@@ -57,20 +57,16 @@ class TableAttentionHead(nn.Layer):
self.structure_attention_cell = AttentionGRUCell( self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.out_channels, use_gru=False) self.input_size, hidden_size, self.out_channels, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.out_channels) self.structure_generator = nn.Linear(hidden_size, self.out_channels)
self.loc_type = loc_type
self.in_max_len = in_max_len self.in_max_len = in_max_len
if self.loc_type == 1: if self.in_max_len == 640:
self.loc_generator = nn.Linear(hidden_size, 4) self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else: else:
if self.in_max_len == 640: self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1) self.loc_generator = nn.Linear(self.input_size + hidden_size,
elif self.in_max_len == 800: loc_reg_num)
self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else:
self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size,
loc_reg_num)
def _char_to_onehot(self, input_char, onehot_dim): def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim) input_ont_hot = F.one_hot(input_char, onehot_dim)
The corresponding changes in its `forward`:

```diff
@@ -80,16 +76,13 @@ class TableAttentionHead(nn.Layer):
         # if and else branch are both needed when you want to assign a variable
         # if you modify the var in just one branch, then the modification will not work.
         fea = inputs[-1]
-        if len(fea.shape) == 3:
-            pass
-        else:
-            last_shape = int(np.prod(fea.shape[2:]))  # gry added
-            fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
-            fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
+        last_shape = int(np.prod(fea.shape[2:]))  # gry added
+        fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
+        fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
         batch_size = fea.shape[0]
         hidden = paddle.zeros((batch_size, self.hidden_size))
-        output_hiddens = []
+        output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size))
         if self.training and targets is not None:
             structure = targets[0]
             for i in range(self.max_text_length + 1):
@@ -97,7 +90,8 @@ class TableAttentionHead(nn.Layer):
                     structure[:, i], onehot_dim=self.out_channels)
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
-                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+                output_hiddens[:, i, :] = outputs
+                # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
             output = paddle.concat(output_hiddens, axis=1)
             structure_probs = self.structure_generator(output)
             if self.loc_type == 1:
@@ -118,30 +112,25 @@ class TableAttentionHead(nn.Layer):
             outputs = None
             alpha = None
             max_text_length = paddle.to_tensor(self.max_text_length)
-            i = 0
-            while i < max_text_length + 1:
+            for i in range(max_text_length + 1):
                 elem_onehots = self._char_to_onehot(
                     temp_elem, onehot_dim=self.out_channels)
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
-                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+                output_hiddens[:, i, :] = outputs
+                # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
                 structure_probs_step = self.structure_generator(outputs)
                 temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
-                i += 1
-            output = paddle.concat(output_hiddens, axis=1)
+            output = output_hiddens
             structure_probs = self.structure_generator(output)
             structure_probs = F.softmax(structure_probs)
-            if self.loc_type == 1:
-                loc_preds = self.loc_generator(output)
-                loc_preds = F.sigmoid(loc_preds)
-            else:
-                loc_fea = fea.transpose([0, 2, 1])
-                loc_fea = self.loc_fea_trans(loc_fea)
-                loc_fea = loc_fea.transpose([0, 2, 1])
-                loc_concat = paddle.concat([output, loc_fea], axis=2)
-                loc_preds = self.loc_generator(loc_concat)
-                loc_preds = F.sigmoid(loc_preds)
+            loc_fea = fea.transpose([0, 2, 1])
+            loc_fea = self.loc_fea_trans(loc_fea)
+            loc_fea = loc_fea.transpose([0, 2, 1])
+            loc_concat = paddle.concat([output, loc_fea], axis=2)
+            loc_preds = self.loc_generator(loc_concat)
+            loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
```
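The substance of this diff is replacing a growing Python list (`append` plus `unsqueeze`/`concat`) with a pre-allocated tensor written in place, and turning the `while` loop into a `for` loop, so the decode loop uses only fixed-shape tensor operations — which is friendlier to dynamic-to-static conversion and model export. A minimal standalone illustration of the two patterns (not the repository code itself):

```python
import paddle

batch_size, steps, hidden_size = 2, 5, 8

# Pattern 1: accumulate per-step outputs in a Python list, concat at the end
hiddens = []
for i in range(steps):
    out = paddle.randn((batch_size, hidden_size))
    hiddens.append(paddle.unsqueeze(out, axis=1))
output_a = paddle.concat(hiddens, axis=1)  # shape: [2, 5, 8]

# Pattern 2: pre-allocate the full tensor and assign per step (the form this commit adopts)
output_b = paddle.zeros((batch_size, steps, hidden_size))
for i in range(steps):
    out = paddle.randn((batch_size, hidden_size))
    output_b[:, i, :] = out  # shape stays [2, 5, 8]
```

Note that the training branch still runs `paddle.concat(output_hiddens, axis=1)` on what is now a tensor rather than a list; whether that is intentional depends on how `paddle.concat` treats a single-tensor argument.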
Documentation update for the table evaluation command:

````diff
@@ -114,7 +114,7 @@ python3 table/eval_table.py \
     --det_model_dir=path/to/det_model_dir \
     --rec_model_dir=path/to/rec_model_dir \
     --table_model_dir=path/to/table_model_dir \
-    --image_dir=../doc/table/1.png \
+    --image_dir=docs/table/table.jpg \
     --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
     --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
     --det_limit_side_len=736 \
@@ -145,6 +145,7 @@ python3 table/eval_table.py \
     --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
     --det_limit_side_len=736 \
     --det_limit_type=min \
+    --rec_image_shape=3,32,320 \
     --gt_path=path/to/gt.txt
 ```
````
The same update in a second documentation file:

````diff
@@ -118,7 +118,7 @@ python3 table/eval_table.py \
     --det_model_dir=path/to/det_model_dir \
     --rec_model_dir=path/to/rec_model_dir \
     --table_model_dir=path/to/table_model_dir \
-    --image_dir=../doc/table/1.png \
+    --image_dir=docs/table/table.jpg \
     --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
     --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
     --det_limit_side_len=736 \
@@ -149,6 +149,7 @@ python3 table/eval_table.py \
     --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
     --det_limit_side_len=736 \
     --det_limit_type=min \
+    --rec_image_shape=3,32,320 \
     --gt_path=path/to/gt.txt
 ```
````
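Both files get the same two fixes: the sample image path becomes `docs/table/table.jpg`, and `--rec_image_shape=3,32,320` pins the recognizer's input shape explicitly. Merged into one complete command, the updated usage looks roughly like this (paths are placeholders, assuming it is run from the `ppstructure` directory):

```python
! python3 table/eval_table.py \
    --det_model_dir=path/to/det_model_dir \
    --rec_model_dir=path/to/rec_model_dir \
    --table_model_dir=path/to/table_model_dir \
    --image_dir=docs/table/table.jpg \
    --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \
    --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
    --det_limit_side_len=736 \
    --det_limit_type=min \
    --rec_image_shape=3,32,320 \
    --gt_path=path/to/gt.txt
```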
Requirements update:

```diff
@@ -15,3 +15,4 @@ premailer
 openpyxl
 attrdict
 Polygon3
+PyMuPDF==1.18.7
```
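`PyMuPDF` is most likely what the PDF-aware inference changes further down (the `flag_pdf` branches) rely on to render PDF pages into image arrays. A rough sketch of that kind of loader, assuming the pinned 1.18.7 API (`pageCount`, `getPixmap`) and not copied from the repository:

```python
import cv2
import fitz  # PyMuPDF
import numpy as np

def pdf_to_images(pdf_path, zoom=2):
    """Render each PDF page to a BGR ndarray (illustrative sketch, not repo code)."""
    imgs = []
    doc = fitz.open(pdf_path)
    for pg in range(doc.pageCount):
        # Render at 2x to keep small text legible
        pm = doc[pg].getPixmap(matrix=fitz.Matrix(zoom, zoom), alpha=False)
        img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(
            (pm.height, pm.width, pm.n))
        imgs.append(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    doc.close()
    return imgs
```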
TIPC config (paddle2onnx chain, en_table_structure):

```text
===========================paddle2onnx_params===========================
model_name:en_table_structure
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/en_ppocr_mobile_v2.0_table_structure_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--det_save_file:./inference/en_ppocr_mobile_v2.0_table_structure_infer/model.onnx
--rec_model_dir:
--rec_save_file:
--opset_version:10
--enable_onnx_checker:True
inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt
--use_gpu:True|False
--det_model_dir:
--rec_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
```
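Spelled out as a command, the conversion this config drives is roughly the following (mirroring the flags above; `--enable_dev_version=True` is what the test driver script shown later appends):

```python
! paddle2onnx --model_dir ./inference/en_ppocr_mobile_v2.0_table_structure_infer/ \
    --model_filename inference.pdmodel \
    --params_filename inference.pdiparams \
    --save_file ./inference/en_ppocr_mobile_v2.0_table_structure_infer/model.onnx \
    --opset_version 10 \
    --enable_onnx_checker True \
    --enable_dev_version True
```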
TIPC config (layoutxlm_ser, multi-machine multi-GPU training chain):

```text
===========================train_params===========================
model_name:layoutxlm_ser
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.checkpoints:null
train_model_name:latest
train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output
--use_gpu:False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--ser_model_dir:
--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg
null:null
--benchmark:False
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
```
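Configs like this are consumed by the TIPC driver scripts; the usual invocation is along these lines (the config filename is a placeholder — use the actual file under `test_tipc/configs/layoutxlm_ser/`):

```python
# Illustrative TIPC invocation; <this_config> stands in for the real filename
! bash test_tipc/prepare.sh test_tipc/configs/layoutxlm_ser/<this_config>.txt lite_train_lite_infer
! bash test_tipc/test_train_inference_python.sh test_tipc/configs/layoutxlm_ser/<this_config>.txt lite_train_lite_infer
```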
TIPC config (layoutxlm_ser, single-machine training chain with AMP):

```text
===========================train_params===========================
model_name:layoutxlm_ser
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:amp
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.checkpoints:null
train_model_name:latest
train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--ser_model_dir:
--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg
null:null
--benchmark:False
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
```
TIPC config (paddle2onnx chain, slanet):

```text
===========================paddle2onnx_params===========================
model_name:slanet
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
--det_save_file:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/model.onnx
--rec_model_dir:
--rec_save_file:
--opset_version:10
--enable_onnx_checker:True
inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict_ch.txt
--use_gpu:True|False
--det_model_dir:
--rec_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
```
TIPC config (slanet, multi-machine multi-GPU training chain):

```text
===========================train_params===========================
model_name:slanet
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/slanet/SLANet.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:./inference/en_ppstructure_mobile_v2.0_SLANet_train
infer_export:null
infer_quant:False
inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table
--use_gpu:False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--table_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
null:null
--benchmark:False
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,488,488]}]
```
Changes to the TIPC data/model preparation script (download logic split per model variant, with slanet and en_table_structure added to the paddle2onnx mode):

```diff
@@ -700,10 +700,18 @@ if [ ${MODE} = "cpp_infer" ];then
         wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate
         cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
     elif [[ ${model_name} =~ "en_table_structure" ]];then
-        wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
         wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
         wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
-        cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
+        cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar
+        if [ ${model_name} == "en_table_structure" ]; then
+            wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
+            tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+        elif [ ${model_name} == "en_table_structure_PACT" ]; then
+            wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_slim_infer.tar --no-check-certificate
+            tar xf en_ppocr_mobile_v2.0_table_structure_slim_infer.tar
+        fi
+        cd ../
     elif [[ ${model_name} =~ "slanet" ]];then
         wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate
         wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
@@ -791,6 +799,12 @@ if [ ${MODE} = "paddle2onnx_infer" ];then
         wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
         wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate
         cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && cd ../
+    elif [[ ${model_name} =~ "slanet" ]];then
+        wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate
+        cd ./inference/ && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar && cd ../
+    elif [[ ${model_name} =~ "en_table_structure" ]];then
+        wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
+        cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && cd ../
     fi
     # wget data
```
Changes to the paddle2onnx test driver (`func_paddle2onnx`), routing slanet and en_table_structure through the det-style conversion and inference branches:

```diff
@@ -105,6 +105,19 @@ function func_paddle2onnx(){
         eval $trans_model_cmd
         last_status=${PIPESTATUS[0]}
         status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
+    elif [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
+        # trans det
+        set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
+        set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
+        set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
+        set_save_model=$(func_set_params "--save_file" "${det_save_file_value}")
+        set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}")
+        set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}")
+        trans_det_log="${LOG_PATH}/trans_model_det.log"
+        trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=True > ${trans_det_log} 2>&1 "
+        eval $trans_model_cmd
+        last_status=${PIPESTATUS[0]}
+        status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
     fi
     # python inference
@@ -117,7 +130,7 @@ function func_paddle2onnx(){
         set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
         set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
         infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
-    elif [[ ${model_name} =~ "det" ]]; then
+    elif [[ ${model_name} =~ "det" ]] || [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
         set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
         infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
     elif [[ ${model_name} =~ "rec" ]]; then
@@ -136,7 +149,7 @@ function func_paddle2onnx(){
         set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
         set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
         infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
-    elif [[ ${model_name} =~ "det" ]]; then
+    elif [[ ${model_name} =~ "det" ]]|| [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then
         set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
         infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
     elif [[ ${model_name} =~ "rec" ]]; then
```
Changes to the detection inference entry point, adding PDF/multi-page support:

```diff
@@ -282,44 +282,67 @@ if __name__ == "__main__":
     args = utility.parse_args()
     image_file_list = get_image_file_list(args.image_dir)
     text_detector = TextDetector(args)
-    count = 0
     total_time = 0
-    draw_img_save = "./inference_results"
+    draw_img_save_dir = args.draw_img_save_dir
+    os.makedirs(draw_img_save_dir, exist_ok=True)
     if args.warmup:
         img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
         for i in range(2):
             res = text_detector(img)
-    if not os.path.exists(draw_img_save):
-        os.makedirs(draw_img_save)
     save_results = []
-    for image_file in image_file_list:
-        img, flag, _ = check_and_read(image_file)
-        if not flag:
+    for idx, image_file in enumerate(image_file_list):
+        img, flag_gif, flag_pdf = check_and_read(image_file)
+        if not flag_gif and not flag_pdf:
             img = cv2.imread(image_file)
-        if img is None:
-            logger.info("error in loading image:{}".format(image_file))
-            continue
-        st = time.time()
-        dt_boxes, _ = text_detector(img)
-        elapse = time.time() - st
-        if count > 0:
+        if not flag_pdf:
+            if img is None:
+                logger.debug("error in loading image:{}".format(image_file))
+                continue
+            imgs = [img]
+        else:
+            page_num = args.page_num
+            if page_num > len(img) or page_num == 0:
+                page_num = len(img)
+            imgs = img[:page_num]
+        for index, img in enumerate(imgs):
+            st = time.time()
+            dt_boxes, _ = text_detector(img)
+            elapse = time.time() - st
             total_time += elapse
-        count += 1
-        save_pred = os.path.basename(image_file) + "\t" + str(
-            json.dumps([x.tolist() for x in dt_boxes])) + "\n"
-        save_results.append(save_pred)
-        logger.info(save_pred)
-        logger.info("The predict time of {}: {}".format(image_file, elapse))
-        src_im = utility.draw_text_det_res(dt_boxes, image_file)
-        img_name_pure = os.path.split(image_file)[-1]
-        img_path = os.path.join(draw_img_save,
-                                "det_res_{}".format(img_name_pure))
-        cv2.imwrite(img_path, src_im)
-        logger.info("The visualized image saved in {}".format(img_path))
-    with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
+            if len(imgs) > 1:
+                save_pred = os.path.basename(image_file) + '_' + str(
+                    index) + "\t" + str(
+                        json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+            else:
+                save_pred = os.path.basename(image_file) + "\t" + str(
+                    json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+            save_results.append(save_pred)
+            logger.info(save_pred)
+            if len(imgs) > 1:
+                logger.info("{}_{} The predict time of {}: {}".format(
+                    idx, index, image_file, elapse))
+            else:
+                logger.info("{} The predict time of {}: {}".format(
+                    idx, image_file, elapse))
+            src_im = utility.draw_text_det_res(dt_boxes, img)
+            if flag_gif:
+                save_file = image_file[:-3] + "png"
+            elif flag_pdf:
+                save_file = image_file.replace('.pdf',
+                                               '_' + str(index) + '.png')
+            else:
+                save_file = image_file
+            img_path = os.path.join(
+                draw_img_save_dir,
+                "det_res_{}".format(os.path.basename(save_file)))
+            cv2.imwrite(img_path, src_im)
+            logger.info("The visualized image saved in {}".format(img_path))
+    with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
         f.writelines(save_results)
         f.close()
     if args.benchmark:
```
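Together with the new `--page_num` argument added in `utility.py` below, this lets the detection script consume a PDF directly, predicting and visualizing up to `page_num` pages (0 means all pages). A possible invocation, with the model path as a placeholder and assuming the edited file is `tools/infer/predict_det.py`:

```python
! python3 tools/infer/predict_det.py \
    --det_model_dir=./inference/ch_PP-OCRv3_det_infer/ \
    --image_dir=./your_doc.pdf \
    --page_num=2
```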
The same multi-page handling in the system-level (det + rec) inference script:

```diff
@@ -159,50 +159,75 @@ def main(args):
     count = 0
     for idx, image_file in enumerate(image_file_list):
-        img, flag, _ = check_and_read(image_file)
-        if not flag:
+        img, flag_gif, flag_pdf = check_and_read(image_file)
+        if not flag_gif and not flag_pdf:
             img = cv2.imread(image_file)
-        if img is None:
-            logger.debug("error in loading image:{}".format(image_file))
-            continue
-        starttime = time.time()
-        dt_boxes, rec_res, time_dict = text_sys(img)
-        elapse = time.time() - starttime
-        total_time += elapse
-
-        logger.debug(
-            str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
-        for text, score in rec_res:
-            logger.debug("{}, {:.3f}".format(text, score))
-
-        res = [{
-            "transcription": rec_res[idx][0],
-            "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(),
-        } for idx in range(len(dt_boxes))]
-        save_pred = os.path.basename(image_file) + "\t" + json.dumps(
-            res, ensure_ascii=False) + "\n"
-        save_results.append(save_pred)
-
-        if is_visualize:
-            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-            boxes = dt_boxes
-            txts = [rec_res[i][0] for i in range(len(rec_res))]
-            scores = [rec_res[i][1] for i in range(len(rec_res))]
-
-            draw_img = draw_ocr_box_txt(
-                image,
-                boxes,
-                txts,
-                scores,
-                drop_score=drop_score,
-                font_path=font_path)
-            if flag:
-                image_file = image_file[:-3] + "png"
-            cv2.imwrite(
-                os.path.join(draw_img_save_dir, os.path.basename(image_file)),
-                draw_img[:, :, ::-1])
-            logger.debug("The visualized image saved in {}".format(
-                os.path.join(draw_img_save_dir, os.path.basename(image_file))))
+        if not flag_pdf:
+            if img is None:
+                logger.debug("error in loading image:{}".format(image_file))
+                continue
+            imgs = [img]
+        else:
+            page_num = args.page_num
+            if page_num > len(img) or page_num == 0:
+                page_num = len(img)
+            imgs = img[:page_num]
+        for index, img in enumerate(imgs):
+            starttime = time.time()
+            dt_boxes, rec_res, time_dict = text_sys(img)
+            elapse = time.time() - starttime
+            total_time += elapse
+            if len(imgs) > 1:
+                logger.debug(
+                    str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
+                    % (image_file, elapse))
+            else:
+                logger.debug(
+                    str(idx) + " Predict time of %s: %.3fs" % (image_file,
+                                                               elapse))
+            for text, score in rec_res:
+                logger.debug("{}, {:.3f}".format(text, score))
+
+            res = [{
+                "transcription": rec_res[i][0],
+                "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
+            } for i in range(len(dt_boxes))]
+            if len(imgs) > 1:
+                save_pred = os.path.basename(image_file) + '_' + str(
+                    index) + "\t" + json.dumps(
+                        res, ensure_ascii=False) + "\n"
+            else:
+                save_pred = os.path.basename(image_file) + "\t" + json.dumps(
+                    res, ensure_ascii=False) + "\n"
+            save_results.append(save_pred)
+
+            if is_visualize:
+                image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+                boxes = dt_boxes
+                txts = [rec_res[i][0] for i in range(len(rec_res))]
+                scores = [rec_res[i][1] for i in range(len(rec_res))]
+
+                draw_img = draw_ocr_box_txt(
+                    image,
+                    boxes,
+                    txts,
+                    scores,
+                    drop_score=drop_score,
+                    font_path=font_path)
+                if flag_gif:
+                    save_file = image_file[:-3] + "png"
+                elif flag_pdf:
+                    save_file = image_file.replace('.pdf',
+                                                   '_' + str(index) + '.png')
+                else:
+                    save_file = image_file
+                cv2.imwrite(
+                    os.path.join(draw_img_save_dir,
+                                 os.path.basename(save_file)),
+                    draw_img[:, :, ::-1])
+                logger.debug("The visualized image saved in {}".format(
+                    os.path.join(draw_img_save_dir, os.path.basename(
+                        save_file))))
     logger.info("The predict total time is {}".format(time.time() - _st))
     if args.benchmark:
```

This change also fixes a latent bug in the old result-building list comprehension, which reused (and shadowed) the outer loop variable `idx`; the new code uses a separate index `i`.
Supporting changes in `utility.py`: a new `--page_num` argument, and `draw_text_det_res` reworked to accept an image array instead of a file path:

```diff
@@ -45,6 +45,7 @@ def init_args():
     # params for text detector
     parser.add_argument("--image_dir", type=str)
+    parser.add_argument("--page_num", type=int, default=0)
     parser.add_argument("--det_algorithm", type=str, default='DB')
     parser.add_argument("--det_model_dir", type=str)
     parser.add_argument("--det_limit_side_len", type=float, default=960)
@@ -337,12 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path):
     return src_im


-def draw_text_det_res(dt_boxes, img_path):
-    src_im = cv2.imread(img_path)
+def draw_text_det_res(dt_boxes, img):
     for box in dt_boxes:
         box = np.array(box).astype(np.int32).reshape(-1, 2)
-        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
-    return src_im
+        cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
+    return img


 def resize_img(img, input_size=600):
```
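Passing the already-decoded array is what makes drawing on individual PDF pages possible — a rendered page has no standalone file path to re-read. A self-contained sketch of the new call pattern (dummy data, not repository code):

```python
import cv2
import numpy as np

def draw_text_det_res(dt_boxes, img):  # new signature: image array, not a path
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
    return img

canvas = np.full((100, 200, 3), 255, dtype=np.uint8)   # stand-in for a page image
boxes = [[[10, 10], [190, 10], [190, 90], [10, 90]]]   # one dummy text box
cv2.imwrite("det_res_demo.jpg", draw_text_det_res(boxes, canvas))
```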