diff --git "a/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" "b/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" new file mode 100644 index 0000000000000000000000000000000000000000..af7cc96b70410c614ef39e91c229d705c8bd400a --- /dev/null +++ "b/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" @@ -0,0 +1,472 @@ +# 智能运营:通用中文表格识别 + +- [1. 背景介绍](#1-背景介绍) +- [2. 中文表格识别](#2-中文表格识别) +- [2.1 环境准备](#21-环境准备) +- [2.2 准备数据集](#22-准备数据集) + - [2.2.1 划分训练测试集](#221-划分训练测试集) + - [2.2.2 查看数据集](#222-查看数据集) +- [2.3 训练](#23-训练) +- [2.4 验证](#24-验证) +- [2.5 训练引擎推理](#25-训练引擎推理) +- [2.6 模型导出](#26-模型导出) +- [2.7 预测引擎推理](#27-预测引擎推理) +- [2.8 表格识别](#28-表格识别) +- [3. 表格属性识别](#3-表格属性识别) +- [3.1 代码、环境、数据准备](#31-代码环境数据准备) + - [3.1.1 代码准备](#311-代码准备) + - [3.1.2 环境准备](#312-环境准备) + - [3.1.3 数据准备](#313-数据准备) +- [3.2 表格属性识别训练](#32-表格属性识别训练) +- [3.3 表格属性识别推理和部署](#33-表格属性识别推理和部署) + - [3.3.1 模型转换](#331-模型转换) + - [3.3.2 模型推理](#332-模型推理) + +## 1. 背景介绍 + +中文表格识别在金融行业有着广泛的应用,如保险理赔、财报分析和信息录入等领域。当前,金融行业的表格识别主要以手动录入为主,开发一种自动表格识别成为亟待解决的问题。 +![](https://ai-studio-static-online.cdn.bcebos.com/d1e7780f0c7745ada4be540decefd6288e4d59257d8141f6842682a4c05d28b6) + + +在金融行业中,表格图像主要有清单类的单元格密集型表格,申请表类的大单元格表格,拍照表格和倾斜表格四种主要形式。 + +![](https://ai-studio-static-online.cdn.bcebos.com/da82ae8ef8fd479aaa38e1049eb3a681cf020dc108fa458eb3ec79da53b45fd1) +![](https://ai-studio-static-online.cdn.bcebos.com/5ffff2093a144a6993a75eef71634a52276015ee43a04566b9c89d353198c746) + + +当前的表格识别算法不能很好地处理这些场景下的表格图像。在本例中,我们使用PP-Structurev2最新发布的表格识别模型SLANet来演示如何进行中文表格识别。同时,为了方便作业流程,我们使用表格属性识别模型对表格图像的属性进行识别,对表格的难易程度进行判断,加快人工进行校对速度。 + +本项目AI Studio链接:https://aistudio.baidu.com/aistudio/projectdetail/4588067 + +## 2. 中文表格识别 +### 2.1 环境准备 + + +```python +# 下载PaddleOCR代码 +! git clone -b dygraph https://gitee.com/paddlepaddle/PaddleOCR +``` + + +```python +# 安装PaddleOCR环境 +! 
pip install -r PaddleOCR/requirements.txt --force-reinstall +! pip install protobuf==3.19 +``` + +### 2.2 准备数据集 + +本例中使用的数据集采用表格[生成工具](https://github.com/WenmuZhou/TableGeneration)制作。 + +使用如下命令对数据集进行解压,并查看数据集大小 + + +```python +! cd data/data165849 && tar -xf table_gen_dataset.tar && cd - +! wc -l data/data165849/table_gen_dataset/gt.txt +``` + +#### 2.2.1 划分训练测试集 + +使用下述命令将数据集划分为训练集和测试集, 这里将90%划分为训练集,10%划分为测试集 + + +```python +import random +with open('/home/aistudio/data/data165849/table_gen_dataset/gt.txt') as f: + lines = f.readlines() +random.shuffle(lines) +train_len = int(len(lines)*0.9) +train_list = lines[:train_len] +val_list = lines[train_len:] + +# 保存结果 +with open('/home/aistudio/train.txt','w',encoding='utf-8') as f: + f.writelines(train_list) +with open('/home/aistudio/val.txt','w',encoding='utf-8') as f: + f.writelines(val_list) +``` + +划分完成后,数据集信息如下 + +|类型|数量|图片地址|标注文件路径| +|---|---|---|---| +|训练集|18000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/train.txt| +|测试集|2000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/val.txt| + +#### 2.2.2 查看数据集 + + +```python +import cv2 +import os, json +import numpy as np +from matplotlib import pyplot as plt +%matplotlib inline + +def parse_line(data_dir, line): + data_line = line.strip("\n") + info = json.loads(data_line) + file_name = info['filename'] + cells = info['html']['cells'].copy() + structure = info['html']['structure']['tokens'].copy() + + img_path = os.path.join(data_dir, file_name) + if not os.path.exists(img_path): + print(img_path) + return None + data = { + 'img_path': img_path, + 'cells': cells, + 'structure': structure, + 'file_name': file_name + } + return data + +def draw_bbox(img_path, points, color=(255, 0, 0), thickness=2): + if isinstance(img_path, str): + img_path = cv2.imread(img_path) + img_path = img_path.copy() + for point in points: + cv2.polylines(img_path, [point.astype(int)], True, color, thickness) + return img_path + + +def rebuild_html(data): + 
html_code = data['structure'] + cells = data['cells'] + to_insert = [i for i, tag in enumerate(html_code) if tag in ('', '>')] + + for i, cell in zip(to_insert[::-1], cells[::-1]): + if cell['tokens']: + text = ''.join(cell['tokens']) + # skip empty text + sp_char_list = ['', '', '\u2028', ' ', '', ''] + text_remove_style = skip_char(text, sp_char_list) + if len(text_remove_style) == 0: + continue + html_code.insert(i + 1, text) + + html_code = ''.join(html_code) + return html_code + + +def skip_char(text, sp_char_list): + """ + skip empty cell + @param text: text in cell + @param sp_char_list: style char and special code + @return: + """ + for sp_char in sp_char_list: + text = text.replace(sp_char, '') + return text + +save_dir = '/home/aistudio/vis' +os.makedirs(save_dir, exist_ok=True) +image_dir = '/home/aistudio/data/data165849/' +html_str = '' + +# 解析标注信息并还原html表格 +data = parse_line(image_dir, val_list[0]) + +img = cv2.imread(data['img_path']) +img_name = ''.join(os.path.basename(data['file_name']).split('.')[:-1]) +img_save_name = os.path.join(save_dir, img_name) +boxes = [np.array(x['bbox']) for x in data['cells']] +show_img = draw_bbox(data['img_path'], boxes) +cv2.imwrite(img_save_name + '_show.jpg', show_img) + +html = rebuild_html(data) +html_str += html +html_str += '
' + +# 显示标注的html字符串 +from IPython.core.display import display, HTML +display(HTML(html_str)) +# 显示单元格坐标 +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.3 训练 + +这里选用PP-Structurev2中的表格识别模型[SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/configs/table/SLANet.yml) + +SLANet是PP-Structurev2全新推出的表格识别模型,相比PP-Structurev1中TableRec-RARE,在速度不变的情况下精度提升4.7%。TEDS提升2% + + +|算法|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed| +| --- | --- | --- | ---| +| EDD[2] |x| 88.3% |x| +| TableRec-RARE(ours) | 71.73%| 93.88% |779ms| +| SLANet(ours) | 76.31%| 95.89%|766ms| + +进行训练之前先使用如下命令下载预训练模型 + + +```python +# 进入PaddleOCR工作目录 +os.chdir('/home/aistudio/PaddleOCR') +# 下载英文预训练模型 +! wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar --no-check-certificate +! cd ./pretrain_models/ && tar xf en_ppstructure_mobile_v2.0_SLANet_train.tar && cd ../ +``` + +使用如下命令即可启动训练,需要修改的配置有 + +|字段|修改值|含义| +|---|---|---| +|Global.pretrained_model|./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy.pdparams|指向英文表格预训练模型地址| +|Global.eval_batch_step|562|模型多少step评估一次,一般设置为一个epoch总的step数| +|Optimizer.lr.name|Const|学习率衰减器 | +|Optimizer.lr.learning_rate|0.0005|学习率设为之前的0.05倍 | +|Train.dataset.data_dir|/home/aistudio/data/data165849|指向训练集图片存放目录 | +|Train.dataset.label_file_list|/home/aistudio/data/data165849/table_gen_dataset/train.txt|指向训练集标注文件 | +|Train.loader.batch_size_per_card|32|训练时每张卡的batch_size | +|Train.loader.num_workers|1|训练集多进程数据读取的进程数,在aistudio中需要设为1 | +|Eval.dataset.data_dir|/home/aistudio/data/data165849|指向测试集图片存放目录 | +|Eval.dataset.label_file_list|/home/aistudio/data/data165849/table_gen_dataset/val.txt|指向测试集标注文件 | +|Eval.loader.batch_size_per_card|32|测试时每张卡的batch_size | +|Eval.loader.num_workers|1|测试集多进程数据读取的进程数,在aistudio中需要设为1 | + + +已经修改好的配置存储在 `/home/aistudio/SLANet_ch.yml` + + +```python +import 
os +os.chdir('/home/aistudio/PaddleOCR') +! python3 tools/train.py -c /home/aistudio/SLANet_ch.yml +``` + +大约在7个epoch后达到最高精度 97.49% + +### 2.4 验证 + +训练完成后,可使用如下命令在测试集上评估最优模型的精度 + + +```python +! python3 tools/eval.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams +``` + +### 2.5 训练引擎推理 +使用如下命令可使用训练引擎对单张图片进行推理 + + +```python +import os;os.chdir('/home/aistudio/PaddleOCR') +! python3 tools/infer_table.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.infer_img=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg +``` + + +```python +import cv2 +from matplotlib import pyplot as plt +%matplotlib inline + +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测的单元格 +show_img = cv2.imread('/home/aistudio/PaddleOCR/output/infer/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.6 模型导出 + +使用如下命令可将模型导出为inference模型 + + +```python +! python3 tools/export_model.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.save_inference_dir=/home/aistudio/SLANet_ch/infer +``` + +### 2.7 预测引擎推理 +使用如下命令可使用预测引擎对单张图片进行推理 + + + +```python +os.chdir('/home/aistudio/PaddleOCR/ppstructure') +! 
python3 table/predict_structure.py \ + --table_model_dir=/home/aistudio/SLANet_ch/infer \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \ + --output=../output/inference +``` + + +```python +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测的单元格 +show_img = cv2.imread('/home/aistudio/PaddleOCR/output/inference/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.8 表格识别 + +在表格结构模型训练完成后,可结合OCR检测识别模型,对表格内容进行识别。 + +首先下载PP-OCRv3文字检测识别模型 + + +```python +# 下载PP-OCRv3文本检测识别模型并解压 +! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar --no-check-certificate +! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar --no-check-certificate +! cd ./inference/ && tar xf ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar && cd ../ +``` + +模型下载完成后,使用如下命令进行表格识别 + + +```python +import os;os.chdir('/home/aistudio/PaddleOCR/ppstructure') +! 
python3 table/predict_table.py \ + --det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \ + --rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \ + --table_model_dir=/home/aistudio/SLANet_ch/infer \ + --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \ + --output=../output/table +``` + + +```python +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测结果 +from IPython.core.display import display, HTML +display(HTML('
alleadersh不贰过,推从自己参与浙江数。另一方
AnSha自己越共商共建工作协商w.east 抓好改革试点任务
EdimeImisesElec怀天下”。22.26 31.614.30 794.94
ip Profundi:2019年12月1Horspro444.482.41 87679.98
iehaiTrain组长蒋蕊Toafterdec203.4323.54 44266.62
Tyint roudlyRol谢您的好意,我知道ErChows48.9010316
NaFlint一辈的aterreclam7823.869829.237.96 3068
家上下游企业,5Tr景象。当地球上的我们Urelaw799.62354.9612.9833
赛事( uestCh复制的业务模式并Listicjust9.239253.22
Ca Iskole扶贫"之名引导 Papua 7191.901.653.6248
避讳ir但由于Fficeof0.226.377.173397.75
ndaTurk百处遗址gMa1288.342053.662.29885.45
')) +``` + +## 3. 表格属性识别 +### 3.1 代码、环境、数据准备 +#### 3.1.1 代码准备 +首先,我们需要准备训练表格属性的代码,PaddleClas集成了PULC方案,该方案可以快速获得一个在CPU上用时2ms的属性识别模型。PaddleClas代码可以clone下载得到。获取方式如下: + + + +```python +! git clone -b develop https://gitee.com/paddlepaddle/PaddleClas +``` + +#### 3.1.2 环境准备 +其次,我们需要安装训练PaddleClas相关的依赖包 + + +```python +! pip install -r PaddleClas/requirements.txt --force-reinstall +! pip install protobuf==3.20.0 +``` + + +#### 3.1.3 数据准备 + +最后,准备训练数据。在这里,我们一共定义了表格的6个属性,分别是表格来源、表格数量、表格颜色、表格清晰度、表格有无干扰、表格角度。其可视化如下: + +![](https://user-images.githubusercontent.com/45199522/190587903-ccdfa6fb-51e8-42de-b08b-a127cb04e304.png) + +这里,我们提供了一个表格属性的demo子集,可以快速迭代体验。下载方式如下: + + +```python +%cd PaddleClas/dataset +!wget https://paddleclas.bj.bcebos.com/data/PULC/table_attribute.tar +!tar -xf table_attribute.tar +%cd ../PaddleClas/dataset +%cd ../ +``` + +### 3.2 表格属性识别训练 +表格属性训练整体pipeline如下: + +![](https://user-images.githubusercontent.com/45199522/190599426-3415b38e-e16e-4e68-9253-2ff531b1b5ca.png) + +1.训练过程中,图片经过预处理之后,送入到骨干网络之中,骨干网络将抽取表格图片的特征,最终该特征连接输出的FC层,FC层经过Sigmoid激活函数后和真实标签做交叉熵损失函数,优化器通过对该损失函数做梯度下降来更新骨干网络的参数,经过多轮训练后,骨干网络的参数可以对未知图片做很好的预测; + +2.推理过程中,图片经过预处理之后,送入到骨干网络之中,骨干网络加载学习好的权重后对该表格图片做出预测,预测的结果为一个6维向量,该向量中的每个元素反映了每个属性对应的概率值,通过对该值进一步卡阈值之后,得到最终的输出,最终的输出描述了该表格的6个属性。 + +当准备好相关的数据之后,可以一键启动表格属性的训练,训练代码如下: + + +```python + +!python tools/train.py -c ./ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.device=cpu -o Global.epochs=10 +``` + +### 3.3 表格属性识别推理和部署 +#### 3.3.1 模型转换 +当训练好模型之后,需要将模型转换为推理模型进行部署。转换脚本如下: + + +```python +!python tools/export_model.py -c ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.pretrained_model=output/PPLCNet_x1_0/best_model +``` + +执行以上命令之后,会在当前目录上生成`inference`文件夹,该文件夹中保存了当前精度最高的推理模型。 + +#### 3.3.2 模型推理 +安装推理需要的paddleclas包, 此时需要通过下载安装paddleclas的develop的whl包 + + + +```python +!pip install https://paddleclas.bj.bcebos.com/whl/paddleclas-0.0.0-py3-none-any.whl +``` + +进入`deploy`目录下即可对模型进行推理 + + +```python +%cd deploy/ +``` + 
+推理命令如下: + + +```python +!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_9.jpg" +!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_3253.jpg" +``` + +推理的表格图片: + +![](https://user-images.githubusercontent.com/45199522/190596141-74f4feda-b082-46d7-908d-b0bd5839b430.png) + +预测结果如下: +``` +val_9.jpg: {'attributes': ['Scanned', 'Little', 'Black-and-White', 'Clear', 'Without-Obstacles', 'Horizontal'], 'output': [1, 1, 1, 1, 1, 1]} +``` + + +推理的表格图片: + +![](https://user-images.githubusercontent.com/45199522/190597086-2e685200-22d0-4042-9e46-f61f24e02e4e.png) + +预测结果如下: +``` +val_3253.jpg: {'attributes': ['Photo', 'Little', 'Black-and-White', 'Blurry', 'Without-Obstacles', 'Tilted'], 'output': [0, 1, 1, 0, 1, 0]} +``` + +对比两张图片可以发现,第一张图片比较清晰,表格属性的结果也偏向于比较容易识别,我们可以更相信表格识别的结果,第二张图片比较模糊,且存在倾斜现象,表格识别可能存在错误,需要我们人工进一步校验。通过表格的属性识别能力,可以进一步将“人工”和“智能”很好的结合起来,为表格识别能力的落地的精度提供保障。 diff --git "a/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" "b/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" new file mode 100644 index 0000000000000000000000000000000000000000..fce9ea772eed6575de10f50c0ff447aa1aee928b --- /dev/null +++ "b/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" @@ -0,0 +1,1033 @@ +# 印章弯曲文字识别 + +- [1. 项目介绍](#1-----) +- [2. 环境搭建](#2-----) + * [2.1 准备PaddleDetection环境](#21---paddledetection--) + * [2.2 准备PaddleOCR环境](#22---paddleocr--) +- [3. 数据集准备](#3------) + * [3.1 数据标注](#31-----) + * [3.2 数据处理](#32-----) +- [4. 印章检测实践](#4-------) +- [5. 
印章文字识别实践](#5---------) + * [5.1 端对端印章文字识别实践](#51------------) + * [5.2 两阶段印章文字识别实践](#52------------) + + [5.2.1 印章文字检测](#521-------) + + [5.2.2 印章文字识别](#522-------) + + +# 1. 项目介绍 + +弯曲文字识别在OCR任务中有着广泛的应用,比如:自然场景下的招牌,艺术文字,以及常见的印章文字识别。 + +在本项目中,将以印章识别任务为例,介绍如何使用PaddleDetection和PaddleOCR完成印章检测和印章文字识别任务。 + +项目难点: +1. 缺乏训练数据 +2. 图像质量参差不齐,图像模糊,文字不清晰 + +针对以上问题,本项目选用PaddleOCR里的PPOCRLabel工具完成数据标注。基于PaddleDetection完成印章区域检测,然后通过PaddleOCR里的端对端OCR算法和两阶段OCR算法分别完成印章文字识别任务。不同任务的精度效果如下: + + +| 任务 | 训练数据数量 | 精度 | +| -------- | - | -------- | +| 印章检测 | 1000 | 95% | +| 印章文字识别-端对端OCR方法 | 700 | 47% | +| 印章文字识别-两阶段OCR方法 | 700 | 55% | + +点击进入 [AI Studio 项目](https://aistudio.baidu.com/aistudio/projectdetail/4586113) + +# 2. 环境搭建 + +本项目需要准备PaddleDetection和PaddleOCR的项目运行环境,其中PaddleDetection用于实现印章检测任务,PaddleOCR用于实现文字识别任务 + + +## 2.1 准备PaddleDetection环境 + +下载PaddleDetection代码: +``` +!git clone https://github.com/PaddlePaddle/PaddleDetection.git +# 如果克隆github代码较慢,请从gitee上克隆代码 +#git clone https://gitee.com/PaddlePaddle/PaddleDetection.git +``` + +安装PaddleDetection依赖 +``` +!cd PaddleDetection && pip install -r requirements.txt +``` + +## 2.2 准备PaddleOCR环境 + +下载PaddleOCR代码: +``` +!git clone https://github.com/PaddlePaddle/PaddleOCR.git +# 如果克隆github代码较慢,请从gitee上克隆代码 +#git clone https://gitee.com/PaddlePaddle/PaddleOCR.git +``` + +安装PaddleOCR依赖 +``` +!cd PaddleOCR && git checkout dygraph && pip install -r requirements.txt +``` + +# 3. 
数据集准备 + +## 3.1 数据标注 + +本项目中使用[PPOCRLabel](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/PPOCRLabel)工具标注印章检测数据,标注内容包括印章的位置以及印章中文字的位置和文字内容。 + + +注:PPOCRLabel的使用方法参考[文档](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/PPOCRLabel)。 + +PPOCRLabel标注印章数据步骤: +- 打开数据集所在文件夹 +- 按下快捷键Q进行4点(多点)标注——针对印章文本识别, + - 印章弯曲文字包围框采用偶数点标注(比如4点,8点,16点),按照阅读顺序,以16点标注为例,从文字左上方开始标注->到文字右上方标注8个点->到文字右下方->文字左下方8个点,一共16个点,形成包围曲线,参考下图。如果文字弯曲程度不高,为了减小标注工作量,可以采用4点、8点标注,需要注意的是,文字上下点数相同。(总点数尽量不要超过18个) + - 对于需要识别的印章中非弯曲文字,采用4点框标注即可 + - 对应包围框的文字部分默认是“待识别”,需要修改为包围框内的具体文字内容 +- 快捷键W进行矩形标注——针对印章区域检测,印章检测区域保证标注框包围整个印章,包围框对应文字可以设置为'印章区域',方便后续处理。 +- 针对印章中的水平文字可以视情况考虑矩形或四点标注:保证按行标注即可。如果背景文字与印章文字比较接近,标注时尽量避开背景文字。 +- 标注完成后修改右侧文本结果,确认无误后点击下方check(或CTRL+V),确认本张图片的标注。 +- 所有图片标注完成后,在顶部菜单栏点击File -> Export Label导出label.txt。 + +标注完成后,可视化效果如下: +![](https://ai-studio-static-online.cdn.bcebos.com/f5acbc4f50dd401a8f535ed6a263f94b0edff82c1aed4285836a9ead989b9c13) + +数据标注完成后,标签中包含印章检测的标注和印章文字识别的标注,如下所示: +``` +img/1.png [{"transcription": "印章区域", "points": [[87, 245], [214, 245], [214, 369], [87, 369]], "difficult": false}, {"transcription": "国家税务总局泸水市税务局第二税务分局", "points": [[110, 314], [116, 290], [131, 275], [152, 273], [170, 277], [181, 289], [186, 303], [186, 312], [201, 311], [198, 289], [189, 272], [175, 259], [152, 252], [124, 257], [100, 280], [94, 312]], "difficult": false}, {"transcription": "征税专用章", "points": [[117, 334], [183, 334], [183, 352], [117, 352]], "difficult": false}] +``` +标注中包含表示'印章区域'的坐标和'印章文字'坐标以及文字内容。 + + + +## 3.2 数据处理 + +标注时为了方便标注,没有区分印章区域的标注框和文字区域的标注框,可以通过python代码完成标签的划分。 + +在本项目的'/home/aistudio/work/seal_labeled_datas'目录下,存放了标注的数据示例,如下: + + +![](https://ai-studio-static-online.cdn.bcebos.com/3d762970e2184177a2c633695a31029332a4cd805631430ea797309492e45402) + +标签文件'/home/aistudio/work/seal_labeled_datas/Label.txt'中的标注内容如下: + +``` +img/test1.png [{"transcription": "待识别", "points": [[408, 232], [537, 232], [537, 352], [408, 352]], "difficult": false}, {"transcription": 
"电子回单", "points": [[437, 305], [504, 305], [504, 322], [437, 322]], "difficult": false}, {"transcription": "云南省农村信用社", "points": [[417, 290], [434, 295], [438, 281], [446, 267], [455, 261], [472, 258], [489, 264], [498, 277], [502, 295], [526, 289], [518, 267], [503, 249], [475, 232], [446, 239], [429, 255], [418, 275]], "difficult": false}, {"transcription": "专用章", "points": [[437, 319], [503, 319], [503, 338], [437, 338]], "difficult": false}] +``` + + +为了方便训练,我们需要通过python代码将用于训练印章检测和训练印章文字识别的标注区分开。 + + +``` +import numpy as np +import json +import cv2 +import os +from shapely.geometry import Polygon + + +def poly2box(poly): + xmin = np.min(np.array(poly)[:, 0]) + ymin = np.min(np.array(poly)[:, 1]) + xmax = np.max(np.array(poly)[:, 0]) + ymax = np.max(np.array(poly)[:, 1]) + return np.array([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]) + + +def draw_text_det_res(dt_boxes, src_im, color=(255, 255, 0)): + for box in dt_boxes: + box = np.array(box).astype(np.int32).reshape(-1, 2) + cv2.polylines(src_im, [box], True, color=color, thickness=2) + return src_im + +class LabelDecode(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = json.loads(data['label']) + + nBox = len(label) + seal_boxes = self.get_seal_boxes(label) + + gt_label = [] + + for seal_box in seal_boxes: + seal_anno = {'seal_box': seal_box} + boxes, txts, txt_tags = [], [], [] + + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + try: + ints = self.get_intersection(box, seal_box) + except Exception as E: + print(E) + continue + + if abs(Polygon(box).area - self.get_intersection(box, seal_box)) < 1e-3 and \ + abs(Polygon(box).area - self.get_union(box, seal_box)) > 1e-3: + + boxes.append(box) + txts.append(txt) + if txt in ['*', '###', '待识别']: + txt_tags.append(True) + else: + txt_tags.append(False) + + seal_anno['polys'] = boxes + seal_anno['texts'] = txts + seal_anno['ignore_tags'] = txt_tags + + 
gt_label.append(seal_anno) + + return gt_label + + def get_seal_boxes(self, label): + + nBox = len(label) + seal_box = [] + for bno in range(0, nBox): + box = label[bno]['points'] + if len(box) == 4: + seal_box.append(box) + + if len(seal_box) == 0: + return None + + seal_box = self.valid_seal_box(seal_box) + return seal_box + + + def is_seal_box(self, box, boxes): + is_seal = True + for poly in boxes: + if list(box.shape()) != list(box.shape.shape()): + if abs(Polygon(box).area - self.get_intersection(box, poly)) < 1e-3: + return False + else: + if np.sum(np.array(box) - np.array(poly)) < 1e-3: + # continue when the box is same with poly + continue + if abs(Polygon(box).area - self.get_intersection(box, poly)) < 1e-3: + return False + return is_seal + + + def valid_seal_box(self, boxes): + if len(boxes) == 1: + return boxes + + new_boxes = [] + flag = True + for k in range(0, len(boxes)): + flag = True + tmp_box = boxes[k] + for i in range(0, len(boxes)): + if k == i: continue + if abs(Polygon(tmp_box).area - self.get_intersection(tmp_box, boxes[i])) < 1e-3: + flag = False + continue + if flag: + new_boxes.append(tmp_box) + + return new_boxes + + + def get_union(self, pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(self, pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(self, pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def expand_points_num(self, boxes): + max_points_num = 0 + for box in boxes: + if len(box) > max_points_num: + max_points_num = len(box) + ex_boxes = [] + for box in boxes: + ex_box = box + [box[-1]] * (max_points_num - len(box)) + ex_boxes.append(ex_box) + return ex_boxes + + +def gen_extract_label(data_dir, label_file, seal_gt, seal_ppocr_gt): + label_decode_func = LabelDecode() + gts = open(label_file, "r").readlines() + + seal_gt_list = [] + seal_ppocr_list = [] + + for idx, line in enumerate(gts): + img_path, label = line.strip().split("\t") + 
data = {'label': label, 'img_path':img_path} + res = label_decode_func(data) + src_img = cv2.imread(os.path.join(data_dir, img_path)) + if res is None: + print("ERROR! res is None!") + continue + + anno = [] + for i, gt in enumerate(res): + # print(i, box, type(box), ) + anno.append({'polys': gt['seal_box'], 'cls':1}) + + seal_gt_list.append(f"{img_path}\t{json.dumps(anno)}\n") + seal_ppocr_list.append(f"{img_path}\t{json.dumps(res)}\n") + + if not os.path.exists(os.path.dirname(seal_gt)): + os.makedirs(os.path.dirname(seal_gt)) + if not os.path.exists(os.path.dirname(seal_ppocr_gt)): + os.makedirs(os.path.dirname(seal_ppocr_gt)) + + with open(seal_gt, "w") as f: + f.writelines(seal_gt_list) + f.close() + + with open(seal_ppocr_gt, 'w') as f: + f.writelines(seal_ppocr_list) + f.close() + +def vis_seal_ppocr(data_dir, label_file, save_dir): + + datas = open(label_file, 'r').readlines() + for idx, line in enumerate(datas): + img_path, label = line.strip().split('\t') + img_path = os.path.join(data_dir, img_path) + + label = json.loads(label) + src_im = cv2.imread(img_path) + if src_im is None: + continue + + for anno in label: + seal_box = anno['seal_box'] + txt_boxes = anno['polys'] + + # vis seal box + src_im = draw_text_det_res([seal_box], src_im, color=(255, 255, 0)) + src_im = draw_text_det_res(txt_boxes, src_im, color=(255, 0, 0)) + + save_path = os.path.join(save_dir, os.path.basename(img_path)) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # print(src_im.shape) + cv2.imwrite(save_path, src_im) + + +def draw_html(img_dir, save_name): + import glob + + images_dir = glob.glob(img_dir + "/*") + print(len(images_dir)) + + html_path = save_name + with open(html_path, 'w') as html: + html.write('\n\n') + html.write('\n') + html.write("") + + html.write("\n") + html.write(f'\n") + html.write(f'' % (base)) + html.write("\n") + + html.write('\n') + html.write('
\n GT') + + for i, filename in enumerate(sorted(images_dir)): + if filename.endswith("txt"): continue + print(filename) + + base = "{}".format(filename) + if True: + html.write("
{filename}\n GT') + html.write('GT 310\n
\n') + html.write('\n\n') + print("ok") + + +def crop_seal_from_img(label_file, data_dir, save_dir, save_gt_path): + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + datas = open(label_file, 'r').readlines() + all_gts = [] + count = 0 + for idx, line in enumerate(datas): + img_path, label = line.strip().split('\t') + img_path = os.path.join(data_dir, img_path) + + label = json.loads(label) + src_im = cv2.imread(img_path) + if src_im is None: + continue + + for c, anno in enumerate(label): + seal_poly = anno['seal_box'] + txt_boxes = anno['polys'] + txts = anno['texts'] + ignore_tags = anno['ignore_tags'] + + box = poly2box(seal_poly) + img_crop = src_im[box[0][1]:box[2][1], box[0][0]:box[2][0], :] + + save_path = os.path.join(save_dir, f"{idx}_{c}.jpg") + cv2.imwrite(save_path, np.array(img_crop)) + + img_gt = [] + for i in range(len(txts)): + txt_boxes_crop = np.array(txt_boxes[i]) + txt_boxes_crop[:, 1] -= box[0, 1] + txt_boxes_crop[:, 0] -= box[0, 0] + img_gt.append({'transcription': txts[i], "points": txt_boxes_crop.tolist(), "ignore_tag": ignore_tags[i]}) + + if len(img_gt) >= 1: + count += 1 + save_gt = f"{os.path.basename(save_path)}\t{json.dumps(img_gt)}\n" + + all_gts.append(save_gt) + + print(f"The num of all image: {len(all_gts)}, and the number of useful image: {count}") + if not os.path.exists(os.path.dirname(save_gt_path)): + os.makedirs(os.path.dirname(save_gt_path)) + + with open(save_gt_path, "w") as f: + f.writelines(all_gts) + f.close() + print("Done") + + + +if __name__ == "__main__": + + # 数据处理 + gen_extract_label("./seal_labeled_datas", "./seal_labeled_datas/Label.txt", "./seal_ppocr_gt/seal_det_img.txt", "./seal_ppocr_gt/seal_ppocr_img.txt") + vis_seal_ppocr("./seal_labeled_datas", "./seal_ppocr_gt/seal_ppocr_img.txt", "./seal_ppocr_gt/seal_ppocr_vis/") + draw_html("./seal_ppocr_gt/seal_ppocr_vis/", "./vis_seal_ppocr.html") + seal_ppocr_img_label = "./seal_ppocr_gt/seal_ppocr_img.txt" + crop_seal_from_img(seal_ppocr_img_label, 
"./seal_labeled_datas/", "./seal_img_crop", "./seal_img_crop/label.txt") + +``` + +处理完成后,生成的文件如下: +``` +├── seal_img_crop/ +│ ├── 0_0.jpg +│ ├── ... +│ └── label.txt +├── seal_ppocr_gt/ +│ ├── seal_det_img.txt +│ ├── seal_ppocr_img.txt +│ └── seal_ppocr_vis/ +│ ├── test1.png +│ ├── ... +└── vis_seal_ppocr.html + +``` +其中`seal_img_crop/label.txt`文件为印章识别标签文件,其内容格式为: +``` +0_0.jpg [{"transcription": "\u7535\u5b50\u56de\u5355", "points": [[29, 73], [96, 73], [96, 90], [29, 90]], "ignore_tag": false}, {"transcription": "\u4e91\u5357\u7701\u519c\u6751\u4fe1\u7528\u793e", "points": [[9, 58], [26, 63], [30, 49], [38, 35], [47, 29], [64, 26], [81, 32], [90, 45], [94, 63], [118, 57], [110, 35], [95, 17], [67, 0], [38, 7], [21, 23], [10, 43]], "ignore_tag": false}, {"transcription": "\u4e13\u7528\u7ae0", "points": [[29, 87], [95, 87], [95, 106], [29, 106]], "ignore_tag": false}] +``` +可以直接用于PaddleOCR的PGNet算法的训练。 + +`seal_ppocr_gt/seal_det_img.txt`为印章检测标签文件,其内容格式为: +``` +img/test1.png [{"polys": [[408, 232], [537, 232], [537, 352], [408, 352]], "cls": 1}] +``` +为了使用PaddleDetection工具完成印章检测模型的训练,需要将`seal_det_img.txt`转换为COCO或者VOC的数据标注格式。 + +可以直接使用下述代码将印章检测标注转换成VOC格式。 + + +``` +import numpy as np +import json +import cv2 +import os +from shapely.geometry import Polygon + +seal_train_gt = "./seal_ppocr_gt/seal_det_img.txt" +# 注:仅用于示例,实际使用中需要分别转换训练集和测试集的标签 +seal_valid_gt = "./seal_ppocr_gt/seal_det_img.txt" + +def gen_main_train_txt(mode='train'): + if mode == "train": + file_path = seal_train_gt + if mode in ['valid', 'test']: + file_path = seal_valid_gt + + save_path = f"./seal_VOC/ImageSets/Main/{mode}.txt" + save_train_path = f"./seal_VOC/{mode}.txt" + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) + + datas = open(file_path, 'r').readlines() + img_names = [] + train_names = [] + for line in datas: + img_name = line.strip().split('\t')[0] + img_name = os.path.basename(img_name) + (i_name, extension) = os.path.splitext(img_name) + 
t_name = 'JPEGImages/'+str(img_name)+' '+'Annotations/'+str(i_name)+'.xml\n' + train_names.append(t_name) + img_names.append(i_name + "\n") + + with open(save_train_path, "w") as f: + f.writelines(train_names) + f.close() + + with open(save_path, "w") as f: + f.writelines(img_names) + f.close() + + print(f"{mode} save done") + + +def gen_xml_label(mode='train'): + if mode == "train": + file_path = seal_train_gt + if mode in ['valid', 'test']: + file_path = seal_valid_gt + + datas = open(file_path, 'r').readlines() + img_names = [] + train_names = [] + anno_path = "./seal_VOC/Annotations" + img_path = "./seal_VOC/JPEGImages" + + if not os.path.exists(anno_path): + os.makedirs(anno_path) + if not os.path.exists(img_path): + os.makedirs(img_path) + + for idx, line in enumerate(datas): + img_name, label = line.strip().split('\t') + img = cv2.imread(os.path.join("./seal_labeled_datas", img_name)) + cv2.imwrite(os.path.join(img_path, os.path.basename(img_name)), img) + height, width, c = img.shape + img_name = os.path.basename(img_name) + (i_name, extension) = os.path.splitext(img_name) + label = json.loads(label) + + xml_file = open(("./seal_VOC/Annotations" + '/' + i_name + '.xml'), 'w') + xml_file.write('\n') + xml_file.write(' seal_VOC\n') + xml_file.write(' ' + str(img_name) + '\n') + xml_file.write(' ' + 'Annotations/' + str(img_name) + '\n') + xml_file.write(' \n') + xml_file.write(' ' + str(width) + '\n') + xml_file.write(' ' + str(height) + '\n') + xml_file.write(' 3\n') + xml_file.write(' \n') + xml_file.write(' 0\n') + + for anno in label: + poly = anno['polys'] + if anno['cls'] == 1: + gt_cls = 'redseal' + xmin = np.min(np.array(poly)[:, 0]) + ymin = np.min(np.array(poly)[:, 1]) + xmax = np.max(np.array(poly)[:, 0]) + ymax = np.max(np.array(poly)[:, 1]) + xmin,ymin,xmax,ymax= int(xmin),int(ymin),int(xmax),int(ymax) + xml_file.write(' \n') + xml_file.write(' '+str(gt_cls)+'\n') + xml_file.write(' Unspecified\n') + xml_file.write(' 0\n') + xml_file.write(' 
0\n') + xml_file.write(' \n') + xml_file.write(' '+str(xmin)+'\n') + xml_file.write(' '+str(ymin)+'\n') + xml_file.write(' '+str(xmax)+'\n') + xml_file.write(' '+str(ymax)+'\n') + xml_file.write(' \n') + xml_file.write(' \n') + xml_file.write('') + xml_file.close() + print(f'{mode} xml save done!') + + +gen_main_train_txt() +gen_main_train_txt('valid') +gen_xml_label('train') +gen_xml_label('valid') + +``` + +数据处理完成后,转换为VOC格式的印章检测数据存储在~/data/seal_VOC目录下,目录组织结构为: + +``` +├── Annotations/ +├── ImageSets/ +│   └── Main/ +│   ├── train.txt +│   └── valid.txt +├── JPEGImages/ +├── train.txt +└── valid.txt +└── label_list.txt +``` + +Annotations下为数据的标签,JPEGImages目录下为图像文件,label_list.txt为标注检测框类别标签文件。 + +在接下来一节中,将介绍如何使用PaddleDetection工具库完成印章检测模型的训练。 + +# 4. 印章检测实践 + +在实际应用中,印章多是出现在合同,发票,公告等场景中,印章文字识别的任务需要排除图像中背景文字的影响,因此需要先检测出图像中的印章区域。 + + +借助PaddleDetection目标检测库可以很容易的实现印章检测任务,使用PaddleDetection训练印章检测任务流程如下: + +- 选择算法 +- 修改数据集配置路径 +- 启动训练 + + +**算法选择** + +PaddleDetection中有许多检测算法可以选择,考虑到每条数据中印章区域较为清晰,且考虑到性能需求。在本项目中,我们采用mobilenetv3为backbone的ppyolo算法完成印章检测任务,对应的配置文件是:configs/ppyolo/ppyolo_mbv3_large.yml + + + +**修改配置文件** + +配置文件中的默认数据路径是COCO, +需要修改为印章检测的数据路径,主要修改如下: +在配置文件'configs/ppyolo/ppyolo_mbv3_large.yml'末尾增加如下内容: +``` +metric: VOC +map_type: 11point +num_classes: 2 + +TrainDataset: + !VOCDataSet + dataset_dir: dataset/seal_VOC + anno_path: train.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +EvalDataset: + !VOCDataSet + dataset_dir: dataset/seal_VOC + anno_path: test.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +TestDataset: + !ImageFolder + anno_path: dataset/seal_VOC/label_list.txt +``` + +配置文件中设置的数据路径在PaddleDetection/dataset目录下,我们可以将处理后的印章检测训练数据移动到PaddleDetection/dataset目录下或者创建一个软连接。 + +``` +!ln -s seal_VOC ./PaddleDetection/dataset/ +``` + +另外图象中印章数量比较少,可以调整NMS后处理的检测框数量,即keep_top_k,nms_top_k 
@register
@serializable
class SealDataSet(DetDataset):
    """
    Load a seal detection dataset from a tab-separated label file.

    Every line of the annotation file holds an image path and a json label
    separated by a tab. The json label is a list of annotations, each with a
    'polys' point list and a 'cls' class id; every polygon is converted to
    its enclosing axis-aligned box.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images, relative to dataset_dir.
        anno_path (str): '.txt' label file path, relative to dataset_dir.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        load_crowd (bool): stored on the instance; not read by parse_dataset.
            False as default
        allow_empty (bool): stored on the instance; not read by
            parse_dataset. False as default
        empty_ratio (float): the ratio of empty record number to total
            record's, used by _sample_empty; if empty_ratio is out of
            [0. ,1.), do not sample the records and use all the empty
            entries. 1. as default
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=False,
                 empty_ratio=1.):
        super(SealDataSet, self).__init__(dataset_dir, image_dir, anno_path,
                                          data_fields, sample_num)
        self.load_image_only = False
        self.load_semantic = False
        self.load_crowd = load_crowd
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        """Randomly keep a subset of ``records`` sized by ``empty_ratio``."""
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    def parse_dataset(self):
        """Parse the label file and store the loaded records in self.roidbs.

        Blank lines, unreadable images and images without any valid box are
        skipped. Loading stops once ``sample_num`` records were collected
        (when ``sample_num`` is positive).
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        assert anno_path.endswith('.txt'), \
            'invalid seal_gt file: ' + anno_path

        # Context manager so the label file handle is always released.
        with open(anno_path, 'r') as f:
            all_datas = f.readlines()

        records = []
        ct = 0

        for idx, line in enumerate(all_datas):
            line = line.strip()
            if not line:
                continue
            im_path, label = line.split('\t', 1)
            img_path = os.path.join(image_dir, im_path)
            label = json.loads(label)

            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread returns None instead of raising on failure.
                logger.warning("skip unreadable image: {}".format(img_path))
                continue
            im_h, im_w = image.shape[:2]

            coco_rec = {
                'im_file': img_path,
                'im_id': np.array([idx]),
                'h': im_h,
                'w': im_w,
            } if 'image' in self.data_fields else {}

            if not self.load_image_only:
                bboxes = []
                for anno in label:
                    # poly to box
                    pts = np.array(anno['polys'])
                    x1 = np.min(pts[:, 0])
                    y1 = np.min(pts[:, 1])
                    x2 = np.max(pts[:, 0])
                    y2 = np.max(pts[:, 1])
                    eps = 1e-5
                    if x2 - x1 > eps and y2 - y1 > eps:
                        clean_box = [
                            round(float(x), 3) for x in [x1, y1, x2, y2]
                        ]
                        bboxes.append({
                            'clean_box': clean_box,
                            'gt_cls': int(anno['cls'])
                        })
                    else:
                        logger.info("invalid box")

                num_bbox = len(bboxes)
                if num_bbox <= 0:
                    continue

                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)

                for i, box in enumerate(bboxes):
                    gt_class[i][0] = box['gt_cls']
                    gt_bbox[i, :] = box['clean_box']
                    is_crowd[i][0] = 0

                gt_rec = {
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_bbox': gt_bbox,
                }

                # Only expose the ground-truth fields that were requested.
                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        coco_rec[k] = v

            records.append(coco_rec)
            ct += 1
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        self.roidbs = records
| PGNet | + + +本节中将分别介绍端对端的文字检测识别算法以及两阶段的文字检测识别算法在印章检测识别任务上的实践。 + + +## 5.1 端对端印章文字识别实践 + +本节介绍使用PaddleOCR里的PGNet算法完成印章文字识别。 + +PGNet属于端对端的文字检测识别算法,在PaddleOCR中的配置文件为: +[PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/e2e/e2e_r50_vd_pg.yml) + +使用PGNet完成文字检测识别任务的步骤为: +- 修改配置文件 +- 启动训练 + +PGNet默认配置文件的数据路径为totaltext数据集路径,本次训练中,需要修改为上一节数据处理后得到的标签文件和数据目录: + +训练数据配置修改后如下: +``` +Train: + dataset: + name: PGDataSet + data_dir: ./train_data/seal_ppocr + label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt] + ratio_list: [1.0] +``` +测试数据集配置修改后如下: +``` +Eval: + dataset: + name: PGDataSet + data_dir: ./train_data/seal_ppocr_test + label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt] +``` + +启动训练的命令为: +``` +!python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml +``` +模型训练完成后,可以得到最终的精度为47.4%。数据量较少,以及数据质量较差会影响模型的训练精度,如果有更多的数据参与训练,精度将进一步提升。 + +如需获取已训练模型,请扫文末的二维码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + +## 5.2 两阶段印章文字识别实践 + +上一节介绍了使用PGNet实现印章识别任务的训练流程。本小节将介绍使用PaddleOCR里的文字检测和文字识别算法分别完成印章文字的检测和识别。 + +### 5.2.1 印章文字检测 + +PaddleOCR中包含丰富的文字检测算法,包含DB,DB++,EAST,SAST,PSENet等等。其中DB,DB++,PSENet均支持弯曲文字检测,本项目中,使用DB++作为印章弯曲文字检测算法。 + +PaddleOCR中发布的db++文字检测算法模型是英文文本检测模型,因此需要重新训练模型。 + + +修改DB++配置文件(默认配置文件位于[configs/det/det_r50_db++_icdar15.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/det/det_r50_db%2B%2B_icdar15.yml)) +中的数据路径: + + +``` +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr + label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt] + ratio_list: [1.0] +``` +测试数据集配置修改后如下: +``` +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr_test + label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt] +``` + + +启动训练: +``` +!python3 tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.epoch_num=100 +``` + +考虑到数据较少,通过Global.epoch_num设置仅训练100个epoch。 +模型训练完成后,在测试集上预测的可视化效果如下: + 
def get_rotate_crop_image(img, points):
    """Crop a quadrilateral region from ``img`` and rectify it to a rectangle.

    Args:
        img: source image array as returned by ``cv2.imread``.
        points: float32 array holding the 4 corner points of the region; the
            target size is derived from the lengths of the sides 0-1/2-3
            (width) and 0-3/1-2 (height).

    Returns:
        The perspective-warped crop. It is rotated by 90 degrees with
        ``np.rot90`` when its height is at least 1.5 times its width.
    """
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img


def run(data_dir, label_file, save_dir):
    """Crop every annotated text region of a label file into ``save_dir``.

    Each line of ``label_file`` is ``<image path><TAB><json label>``; the
    json label is a list of annotations whose 'polys' entry holds the text
    region corner points. Unreadable images are skipped.

    Args:
        data_dir (str): directory the image paths are relative to.
        label_file (str): path of the tab-separated label file.
        save_dir (str): output directory, created if missing. Crops are
            saved as ``<line index>_<annotation index>.png`` so that several
            regions of one image do not overwrite each other.
    """
    # The snippet only imports cv2 and numpy at the top, so bring the
    # remaining standard-library modules into scope here.
    import json
    import os

    os.makedirs(save_dir, exist_ok=True)

    with open(label_file, 'r') as f:
        datas = f.readlines()

    for idx, line in enumerate(datas):
        line = line.strip()
        if not line:
            continue
        img_path, label = line.split('\t', 1)
        img_path = os.path.join(data_dir, img_path)

        label = json.loads(label)
        src_im = cv2.imread(img_path)
        if src_im is None:
            continue

        for box_idx, anno in enumerate(label):
            # json yields nested lists; the crop helper subtracts points and
            # hands them to cv2.getPerspectiveTransform, which needs float32.
            txt_box = np.array(anno['polys'], dtype=np.float32)
            crop_im = get_rotate_crop_image(src_im, txt_box)

            save_path = os.path.join(save_dir, f'{idx}_{box_idx}.png')
            cv2.imwrite(save_path, crop_im)
use_visualdl: False - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: doc/imgs_words_en/word_10.png # for data or label process character_dict_path: ./ppocr/utils/dict/spin_dict.txt max_text_length: 25 diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index 9355a236e15b60db18e8715c2702701fd5d36c71..9d286f4153eaab44bf0d259bbad4a0b3b8ada568 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -43,7 +43,6 @@ Architecture: Head: name: TableAttentionHead hidden_size: 256 - loc_type: 2 max_text_length: *max_text_length loc_reg_num: &loc_reg_num 4 diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 4351fdbfcb501945d6061bc1fcc3585bd76eab7d..e9bcc275d1a7157628188c337a312a49408d207b 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -101,11 +101,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | -|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | -|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | +|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) | +|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| |RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) | - ## 2. 
端到端算法 diff --git a/doc/doc_ch/algorithm_rec_robustscanner.md b/doc/doc_ch/algorithm_rec_robustscanner.md index 869f9a7c00b617de87ab3c96326e18e536bc18a8..a1ab3baf038f573c50fcc0b67fd8297459777cf8 100644 --- a/doc/doc_ch/algorithm_rec_robustscanner.md +++ b/doc/doc_ch/algorithm_rec_robustscanner.md @@ -26,7 +26,7 @@ Zhang |模型|骨干网络|配置文件|Acc|下载链接| | --- | --- | --- | --- | --- | -|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon| +|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| 注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。 diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md index c996992d2fa6297e6086ffae4bc36ad3e880873d..908a85a417c4070b95630b37b0830e08aae3ff4f 100644 --- a/doc/doc_ch/algorithm_rec_spin.md +++ b/doc/doc_ch/algorithm_rec_spin.md @@ -26,7 +26,7 @@ SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识 |模型|骨干网络|配置文件|Acc|下载链接| | --- | --- | --- | --- | --- | -|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon| +|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar)| diff --git a/doc/doc_ch/inference_args.md b/doc/doc_ch/inference_args.md index 36efc6fbf7a6ec62bc700964dc13261fecdb9bd5..24e7223e397c94fe65b0f26d993fc507b323ed16 100644 --- a/doc/doc_ch/inference_args.md +++ b/doc/doc_ch/inference_args.md @@ -7,6 +7,7 @@ | 参数名称 | 类型 | 默认值 | 含义 | | :--: | :--: | :--: | :--: | | image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 | +| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 | | vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 | | drop_score | float | 
0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 | | use_pdserving | bool | False | 是否使用Paddle Serving进行预测 | diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index f7ef7ad4b8fc162a6ed7b275f3beea6955b86452..90449e1729fcff898f27641d3f777c8f002f6a97 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -98,8 +98,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | -|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | -|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | +|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) | +|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| |RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) | diff --git a/doc/doc_en/algorithm_rec_robustscanner_en.md b/doc/doc_en/algorithm_rec_robustscanner_en.md index a324a6d547a9e448566276234c750ad4497abf9c..99372d51300e6571827db7181a4f8537db4f1493 100644 --- a/doc/doc_en/algorithm_rec_robustscanner_en.md +++ b/doc/doc_en/algorithm_rec_robustscanner_en.md @@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval |Model|Backbone|config|Acc|Download link| | --- | --- | --- | --- | --- | -|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon| 
+|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper. diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md index 43ab30ce7d96cbb64ddf87156fee3012d666b2bf..03f8d8f69986fc5eb14cfdf294fc25fafb06e269 100644 --- a/doc/doc_en/algorithm_rec_spin_en.md +++ b/doc/doc_en/algorithm_rec_spin_en.md @@ -25,7 +25,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval |Model|Backbone|config|Acc|Download link| | --- | --- | --- | --- | --- | -|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon| +|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) | diff --git a/doc/doc_en/inference_args_en.md b/doc/doc_en/inference_args_en.md index f2c99fc8297d47f27a219bf7d8e7f2ea518257f0..b28cd8436da62dcd10f96f17751db9384ebcaa8d 100644 --- a/doc/doc_en/inference_args_en.md +++ b/doc/doc_en/inference_args_en.md @@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par | parameters | type | default | implication | | :--: | :--: | :--: | :--: | | image_dir | str | None, must be specified explicitly | Image or folder path | +| page_num | int | 0 | Valid when the input type is pdf file, specify to predict the previous page_num pages, all pages are predicted by default | | vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization | | drop_score | float | 0.5 | Results with 
a recognition score less than this value will be discarded and will not be returned as results | | use_pdserving | bool | False | Whether to use Paddle Serving for prediction | diff --git a/paddleocr.py b/paddleocr.py index fa732fc110dc7873f8d89b2ca2a21817a1e6d20d..d34b8f78a56a8d8d5455c18e7e1cf1e75df8f3f9 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -480,10 +480,11 @@ class PaddleOCR(predict_system.TextSystem): params.rec_image_shape = "3, 48, 320" else: params.rec_image_shape = "3, 32, 320" - # download model - maybe_download(params.det_model_dir, det_url) - maybe_download(params.rec_model_dir, rec_url) - maybe_download(params.cls_model_dir, cls_url) + # download model if using paddle infer + if not params.use_onnx: + maybe_download(params.det_model_dir, det_url) + maybe_download(params.rec_model_dir, rec_url) + maybe_download(params.cls_model_dir, cls_url) if params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py index d3c86e22b02e08c18d8d5cb193f2ffb8b07ad785..50910c5b73aa2a41f329d7222fc8c632509b4c91 100644 --- a/ppocr/modeling/heads/table_att_head.py +++ b/ppocr/modeling/heads/table_att_head.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import paddle import paddle.nn as nn from paddle import ParamAttr @@ -42,7 +43,6 @@ class TableAttentionHead(nn.Layer): def __init__(self, in_channels, hidden_size, - loc_type, in_max_len=488, max_text_length=800, out_channels=30, @@ -57,20 +57,16 @@ class TableAttentionHead(nn.Layer): self.structure_attention_cell = AttentionGRUCell( self.input_size, hidden_size, self.out_channels, use_gru=False) self.structure_generator = nn.Linear(hidden_size, self.out_channels) - self.loc_type = loc_type self.in_max_len = in_max_len - if self.loc_type == 1: - self.loc_generator = 
nn.Linear(hidden_size, 4) + if self.in_max_len == 640: + self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1) + elif self.in_max_len == 800: + self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1) else: - if self.in_max_len == 640: - self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1) - elif self.in_max_len == 800: - self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1) - else: - self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1) - self.loc_generator = nn.Linear(self.input_size + hidden_size, - loc_reg_num) + self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1) + self.loc_generator = nn.Linear(self.input_size + hidden_size, + loc_reg_num) def _char_to_onehot(self, input_char, onehot_dim): input_ont_hot = F.one_hot(input_char, onehot_dim) @@ -80,16 +76,13 @@ class TableAttentionHead(nn.Layer): # if and else branch are both needed when you want to assign a variable # if you modify the var in just one branch, then the modification will not work. 
fea = inputs[-1] - if len(fea.shape) == 3: - pass - else: - last_shape = int(np.prod(fea.shape[2:])) # gry added - fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) - fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) + last_shape = int(np.prod(fea.shape[2:])) # gry added + fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) + fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) batch_size = fea.shape[0] hidden = paddle.zeros((batch_size, self.hidden_size)) - output_hiddens = [] + output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size)) if self.training and targets is not None: structure = targets[0] for i in range(self.max_text_length + 1): @@ -97,7 +90,8 @@ class TableAttentionHead(nn.Layer): structure[:, i], onehot_dim=self.out_channels) (outputs, hidden), alpha = self.structure_attention_cell( hidden, fea, elem_onehots) - output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output_hiddens[:, i, :] = outputs + # output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output = paddle.concat(output_hiddens, axis=1) structure_probs = self.structure_generator(output) if self.loc_type == 1: @@ -118,30 +112,25 @@ class TableAttentionHead(nn.Layer): outputs = None alpha = None max_text_length = paddle.to_tensor(self.max_text_length) - i = 0 - while i < max_text_length + 1: + for i in range(max_text_length + 1): elem_onehots = self._char_to_onehot( temp_elem, onehot_dim=self.out_channels) (outputs, hidden), alpha = self.structure_attention_cell( hidden, fea, elem_onehots) - output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output_hiddens[:, i, :] = outputs + # output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) structure_probs_step = self.structure_generator(outputs) temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") - i += 1 - output = paddle.concat(output_hiddens, axis=1) + output = output_hiddens structure_probs = self.structure_generator(output) 
structure_probs = F.softmax(structure_probs) - if self.loc_type == 1: - loc_preds = self.loc_generator(output) - loc_preds = F.sigmoid(loc_preds) - else: - loc_fea = fea.transpose([0, 2, 1]) - loc_fea = self.loc_fea_trans(loc_fea) - loc_fea = loc_fea.transpose([0, 2, 1]) - loc_concat = paddle.concat([output, loc_fea], axis=2) - loc_preds = self.loc_generator(loc_concat) - loc_preds = F.sigmoid(loc_preds) + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) return {'structure_probs': structure_probs, 'loc_preds': loc_preds} diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md index 08635516ba8301e6f98f175e5eba8c0a97b1708e..1d082f3878c56e42d175d13c75e1fe17916e7781 100644 --- a/ppstructure/table/README.md +++ b/ppstructure/table/README.md @@ -114,7 +114,7 @@ python3 table/eval_table.py \ --det_model_dir=path/to/det_model_dir \ --rec_model_dir=path/to/rec_model_dir \ --table_model_dir=path/to/table_model_dir \ - --image_dir=../doc/table/1.png \ + --image_dir=docs/table/table.jpg \ --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --det_limit_side_len=736 \ @@ -145,6 +145,7 @@ python3 table/eval_table.py \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --det_limit_side_len=736 \ --det_limit_type=min \ + --rec_image_shape=3,32,320 \ --gt_path=path/to/gt.txt ``` diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md index 1ef126261d9ce832cd1919a1b3991f341add998c..feccb70adfe20fa8c1cd06f33a10ee6fa043e69e 100644 --- a/ppstructure/table/README_ch.md +++ b/ppstructure/table/README_ch.md @@ -118,7 +118,7 @@ python3 table/eval_table.py \ --det_model_dir=path/to/det_model_dir \ --rec_model_dir=path/to/rec_model_dir \ 
--table_model_dir=path/to/table_model_dir \ - --image_dir=../doc/table/1.png \ + --image_dir=docs/table/table.jpg \ --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --det_limit_side_len=736 \ @@ -149,6 +149,7 @@ python3 table/eval_table.py \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --det_limit_side_len=736 \ --det_limit_type=min \ + --rec_image_shape=3,32,320 \ --gt_path=path/to/gt.txt ``` diff --git a/requirements.txt b/requirements.txt index 43cd8c1b082768ebad44a5cf58fc31980ebfe891..7a018b50952a876b4839eabbd72fac09d2bbd73b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ premailer openpyxl attrdict Polygon3 +PyMuPDF==1.18.7 diff --git a/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..068c4c6b1d2655b9dcda1120425de7d52d0d543d --- /dev/null +++ b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt @@ -0,0 +1,17 @@ +===========================paddle2onnx_params=========================== +model_name:en_table_structure +python:python3.7 +2onnx: paddle2onnx +--det_model_dir:./inference/en_ppocr_mobile_v2.0_table_structure_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--det_save_file:./inference/en_ppocr_mobile_v2.0_table_structure_infer/model.onnx +--rec_model_dir: +--rec_save_file: +--opset_version:10 +--enable_onnx_checker:True +inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt +--use_gpu:True|False +--det_model_dir: +--rec_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg \ No newline at end of file diff --git 
a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..96b43ceda84376f4d27e134d245229922d667e7e --- /dev/null +++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:layoutxlm_ser +python:python3.7 +gpu_list:192.168.0.1,192.168.0.2;0,1 +Global.use_gpu:True +Global.auto_cast:fp32 +Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8 +Architecture.Backbone.checkpoints:null +train_model_name:latest +train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Architecture.Backbone.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o +quant_export: +fpgm_export: +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:False +inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output +--use_gpu:False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--ser_model_dir: +--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg +null:null +--benchmark:False +null:null 
+===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..835395784022f4fcefbfff084dcef8a7bc2a146d --- /dev/null +++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:layoutxlm_ser +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:amp +Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8 +Architecture.Backbone.checkpoints:null +train_model_name:latest +train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Architecture.Backbone.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o +quant_export: +fpgm_export: +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:False +inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output +--use_gpu:True|False 
+--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--ser_model_dir: +--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg +null:null +--benchmark:False +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..45e4e9e858914dd8596cef10625df8160afe45fb --- /dev/null +++ b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt @@ -0,0 +1,17 @@ +===========================paddle2onnx_params=========================== +model_name:slanet +python:python3.7 +2onnx: paddle2onnx +--det_model_dir:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--det_save_file:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/model.onnx +--rec_model_dir: +--rec_save_file: +--opset_version:10 +--enable_onnx_checker:True +inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict_ch.txt +--use_gpu:True|False +--det_model_dir: +--rec_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg \ No newline at end of file diff --git a/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c9d8d654ad286264e8511c788e278cdbcd52ec9 --- /dev/null +++ b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:slanet +python:python3.7 
+gpu_list:192.168.0.1,192.168.0.2;0,1 +Global.use_gpu:True +Global.auto_cast:fp32 +Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128 +Global.pretrained_model:./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy +train_model_name:latest +train_infer_img_dir:./ppstructure/docs/table/table.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/slanet/SLANet.yml -o +quant_export: +fpgm_export: +distill_export:null +export1:null +export2:null +## +infer_model:./inference/en_ppstructure_mobile_v2.0_SLANet_train +infer_export:null +infer_quant:False +inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table +--use_gpu:False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--table_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg +null:null +--benchmark:False +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,488,488]}] diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 
ecb1e36bb1bb83c6ee2dcf1cb243e6ee60de5dd8..688deac0f379b50865fe6739529f9301ebcd919b 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -700,10 +700,18 @@ if [ ${MODE} = "cpp_infer" ];then wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../ elif [[ ${model_name} =~ "en_table_structure" ]];then - wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate - cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../ + + cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar + if [ ${model_name} == "en_table_structure" ]; then + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate + tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar + elif [ ${model_name} == "en_table_structure_PACT" ]; then + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_slim_infer.tar --no-check-certificate + tar xf en_ppocr_mobile_v2.0_table_structure_slim_infer.tar + fi + cd ../ elif [[ ${model_name} =~ "slanet" ]];then wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate wget -nc -P ./inference/ 
https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate @@ -791,6 +799,12 @@ if [ ${MODE} = "paddle2onnx_infer" ];then wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && cd ../ + elif [[ ${model_name} =~ "slanet" ]];then + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate + cd ./inference/ && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar && cd ../ + elif [[ ${model_name} =~ "en_table_structure" ]];then + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate + cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && cd ../ fi # wget data diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh index 04bfb590f7c6e64cf136d3feef8594994cb86877..f035e6bb645a1e7927844232c2bff72f0480e38e 100644 --- a/test_tipc/test_paddle2onnx.sh +++ b/test_tipc/test_paddle2onnx.sh @@ -105,6 +105,19 @@ function func_paddle2onnx(){ eval $trans_model_cmd last_status=${PIPESTATUS[0]} status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" + elif [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then + # trans det + set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}") + set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}") + set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}") + set_save_model=$(func_set_params "--save_file" "${det_save_file_value}") + set_opset_version=$(func_set_params 
"${opset_version_key}" "${opset_version_value}") + set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}") + trans_det_log="${LOG_PATH}/trans_model_det.log" + trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=True > ${trans_det_log} 2>&1 " + eval $trans_model_cmd + last_status=${PIPESTATUS[0]} + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" fi # python inference @@ -117,7 +130,7 @@ function func_paddle2onnx(){ set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " - elif [[ ${model_name} =~ "det" ]]; then + elif [[ ${model_name} =~ "det" ]] || [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " elif [[ ${model_name} =~ "rec" ]]; then @@ -136,7 +149,7 @@ function func_paddle2onnx(){ set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " - elif [[ ${model_name} =~ "det" ]]; then + elif [[ ${model_name} =~ "det" ]]|| [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") 
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " elif [[ ${model_name} =~ "rec" ]]; then diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 00fa2e9b7fafd949c59a0eebd43f2f88ae717320..52c225d2b3913cf8c0dc88abcc07f7ccfd3cc914 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -282,44 +282,67 @@ if __name__ == "__main__": args = utility.parse_args() image_file_list = get_image_file_list(args.image_dir) text_detector = TextDetector(args) - count = 0 total_time = 0 - draw_img_save = "./inference_results" + draw_img_save_dir = args.draw_img_save_dir + os.makedirs(draw_img_save_dir, exist_ok=True) if args.warmup: img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) for i in range(2): res = text_detector(img) - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) save_results = [] - for image_file in image_file_list: - img, flag, _ = check_and_read(image_file) - if not flag: + for idx, image_file in enumerate(image_file_list): + img, flag_gif, flag_pdf = check_and_read(image_file) + if not flag_gif and not flag_pdf: img = cv2.imread(image_file) - if img is None: - logger.info("error in loading image:{}".format(image_file)) - continue - st = time.time() - dt_boxes, _ = text_detector(img) - elapse = time.time() - st - if count > 0: + if not flag_pdf: + if img is None: + logger.debug("error in loading image:{}".format(image_file)) + continue + imgs = [img] + else: + page_num = args.page_num + if page_num > len(img) or page_num == 0: + page_num = len(img) + imgs = img[:page_num] + for index, img in enumerate(imgs): + st = time.time() + dt_boxes, _ = text_detector(img) + elapse = time.time() - st total_time += elapse - count += 1 - save_pred = os.path.basename(image_file) + "\t" + str( - json.dumps([x.tolist() for x in dt_boxes])) + "\n" - save_results.append(save_pred) - logger.info(save_pred) - logger.info("The predict 
time of {}: {}".format(image_file, elapse)) - src_im = utility.draw_text_det_res(dt_boxes, image_file) - img_name_pure = os.path.split(image_file)[-1] - img_path = os.path.join(draw_img_save, - "det_res_{}".format(img_name_pure)) - cv2.imwrite(img_path, src_im) - logger.info("The visualized image saved in {}".format(img_path)) + if len(imgs) > 1: + save_pred = os.path.basename(image_file) + '_' + str( + index) + "\t" + str( + json.dumps([x.tolist() for x in dt_boxes])) + "\n" + else: + save_pred = os.path.basename(image_file) + "\t" + str( + json.dumps([x.tolist() for x in dt_boxes])) + "\n" + save_results.append(save_pred) + logger.info(save_pred) + if len(imgs) > 1: + logger.info("{}_{} The predict time of {}: {}".format( + idx, index, image_file, elapse)) + else: + logger.info("{} The predict time of {}: {}".format( + idx, image_file, elapse)) + + src_im = utility.draw_text_det_res(dt_boxes, img) + + if flag_gif: + save_file = image_file[:-3] + "png" + elif flag_pdf: + save_file = image_file.replace('.pdf', + '_' + str(index) + '.png') + else: + save_file = image_file + img_path = os.path.join( + draw_img_save_dir, + "det_res_{}".format(os.path.basename(save_file))) + cv2.imwrite(img_path, src_im) + logger.info("The visualized image saved in {}".format(img_path)) - with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f: + with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f: f.writelines(save_results) f.close() if args.benchmark: diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index e0f2c41fa2aba23491efee920afbd76db1ec84e0..affd0d1bcd1283be02ead3cd61c01c375b49bdf9 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -159,50 +159,75 @@ def main(args): count = 0 for idx, image_file in enumerate(image_file_list): - img, flag, _ = check_and_read(image_file) - if not flag: + img, flag_gif, flag_pdf = check_and_read(image_file) + if not flag_gif and not flag_pdf: img = 
cv2.imread(image_file) - if img is None: - logger.debug("error in loading image:{}".format(image_file)) - continue - starttime = time.time() - dt_boxes, rec_res, time_dict = text_sys(img) - elapse = time.time() - starttime - total_time += elapse - - logger.debug( - str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) - for text, score in rec_res: - logger.debug("{}, {:.3f}".format(text, score)) - - res = [{ - "transcription": rec_res[idx][0], - "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(), - } for idx in range(len(dt_boxes))] - save_pred = os.path.basename(image_file) + "\t" + json.dumps( - res, ensure_ascii=False) + "\n" - save_results.append(save_pred) - - if is_visualize: - image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - boxes = dt_boxes - txts = [rec_res[i][0] for i in range(len(rec_res))] - scores = [rec_res[i][1] for i in range(len(rec_res))] - - draw_img = draw_ocr_box_txt( - image, - boxes, - txts, - scores, - drop_score=drop_score, - font_path=font_path) - if flag: - image_file = image_file[:-3] + "png" - cv2.imwrite( - os.path.join(draw_img_save_dir, os.path.basename(image_file)), - draw_img[:, :, ::-1]) - logger.debug("The visualized image saved in {}".format( - os.path.join(draw_img_save_dir, os.path.basename(image_file)))) + if not flag_pdf: + if img is None: + logger.debug("error in loading image:{}".format(image_file)) + continue + imgs = [img] + else: + page_num = args.page_num + if page_num > len(img) or page_num == 0: + page_num = len(img) + imgs = img[:page_num] + for index, img in enumerate(imgs): + starttime = time.time() + dt_boxes, rec_res, time_dict = text_sys(img) + elapse = time.time() - starttime + total_time += elapse + if len(imgs) > 1: + logger.debug( + str(idx) + '_' + str(index) + " Predict time of %s: %.3fs" + % (image_file, elapse)) + else: + logger.debug( + str(idx) + " Predict time of %s: %.3fs" % (image_file, + elapse)) + for text, score in rec_res: + logger.debug("{}, 
{:.3f}".format(text, score)) + + res = [{ + "transcription": rec_res[i][0], + "points": np.array(dt_boxes[i]).astype(np.int32).tolist(), + } for i in range(len(dt_boxes))] + if len(imgs) > 1: + save_pred = os.path.basename(image_file) + '_' + str( + index) + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + else: + save_pred = os.path.basename(image_file) + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + save_results.append(save_pred) + + if is_visualize: + image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + boxes = dt_boxes + txts = [rec_res[i][0] for i in range(len(rec_res))] + scores = [rec_res[i][1] for i in range(len(rec_res))] + + draw_img = draw_ocr_box_txt( + image, + boxes, + txts, + scores, + drop_score=drop_score, + font_path=font_path) + if flag_gif: + save_file = image_file[:-3] + "png" + elif flag_pdf: + save_file = image_file.replace('.pdf', + '_' + str(index) + '.png') + else: + save_file = image_file + cv2.imwrite( + os.path.join(draw_img_save_dir, + os.path.basename(save_file)), + draw_img[:, :, ::-1]) + logger.debug("The visualized image saved in {}".format( + os.path.join(draw_img_save_dir, os.path.basename( + save_file)))) logger.info("The predict total time is {}".format(time.time() - _st)) if args.benchmark: diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b9c9490bdb99f3bee67cb9460a9975b93b0d6366..e555dbec1b314510aaaf6b31f1b35bf60fefa98e 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -45,6 +45,7 @@ def init_args(): # params for text detector parser.add_argument("--image_dir", type=str) + parser.add_argument("--page_num", type=int, default=0) parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_limit_side_len", type=float, default=960) @@ -337,12 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path): return src_im -def draw_text_det_res(dt_boxes, img_path): - src_im = cv2.imread(img_path) +def 
draw_text_det_res(dt_boxes, img): for box in dt_boxes: box = np.array(box).astype(np.int32).reshape(-1, 2) - cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - return src_im + cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2) + return img def resize_img(img, input_size=600):